feat: implement full delete feature
This commit is contained in:
parent
a8dab3019e
commit
caf4801a6b
43 changed files with 783 additions and 277 deletions
|
|
@ -108,6 +108,13 @@ export default function Dashboard({ accessToken }: DashboardProps) {
|
|||
setDatasets(datasets);
|
||||
}, []);
|
||||
|
||||
const [searchValue, setSearchValue] = useState<string>("");
|
||||
|
||||
const handleSearchDatasetInputChange = useCallback((event: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const newSearchValue = event.currentTarget.value;
|
||||
setSearchValue(newSearchValue);
|
||||
}, []);
|
||||
|
||||
const isCloudEnv = isCloudEnvironment();
|
||||
|
||||
return (
|
||||
|
|
@ -129,7 +136,7 @@ export default function Dashboard({ accessToken }: DashboardProps) {
|
|||
<div className="px-5 py-4 lg:w-96 bg-white rounded-xl h-[calc(100%-2.75rem)]">
|
||||
<div className="relative mb-2">
|
||||
<label htmlFor="search-input"><SearchIcon className="absolute left-3 top-[10px] cursor-text" /></label>
|
||||
<input id="search-input" className="text-xs leading-3 w-full h-8 flex flex-row items-center gap-2.5 rounded-3xl pl-9 placeholder-gray-300 border-gray-300 border-[1px] focus:outline-indigo-600" placeholder="Search datasets..." />
|
||||
<input onChange={handleSearchDatasetInputChange} id="search-input" className="text-xs leading-3 w-full h-8 flex flex-row items-center gap-2.5 rounded-3xl pl-9 placeholder-gray-300 border-gray-300 border-[1px] focus:outline-indigo-600" placeholder="Search datasets..." />
|
||||
</div>
|
||||
|
||||
<AddDataToCognee
|
||||
|
|
@ -148,6 +155,7 @@ export default function Dashboard({ accessToken }: DashboardProps) {
|
|||
<div className="mt-7 mb-14">
|
||||
<CogneeInstancesAccordion>
|
||||
<InstanceDatasetsAccordion
|
||||
searchValue={searchValue}
|
||||
onDatasetsChange={handleDatasetsChange}
|
||||
/>
|
||||
</CogneeInstancesAccordion>
|
||||
|
|
|
|||
|
|
@ -19,11 +19,13 @@ interface DatasetsChangePayload {
|
|||
export interface DatasetsAccordionProps extends Omit<AccordionProps, "isOpen" | "openAccordion" | "closeAccordion" | "children"> {
|
||||
onDatasetsChange?: (payload: DatasetsChangePayload) => void;
|
||||
useCloud?: boolean;
|
||||
searchValue: string;
|
||||
}
|
||||
|
||||
export default function DatasetsAccordion({
|
||||
title,
|
||||
tools,
|
||||
searchValue,
|
||||
switchCaretPosition = false,
|
||||
className,
|
||||
contentClassName,
|
||||
|
|
@ -43,13 +45,7 @@ export default function DatasetsAccordion({
|
|||
removeDataset,
|
||||
getDatasetData,
|
||||
removeDatasetData,
|
||||
} = useDatasets(useCloud);
|
||||
|
||||
useEffect(() => {
|
||||
if (datasets.length === 0) {
|
||||
refreshDatasets();
|
||||
}
|
||||
}, [datasets.length, refreshDatasets]);
|
||||
} = useDatasets(useCloud, searchValue);
|
||||
|
||||
const [openDatasets, openDataset] = useState<Set<string>>(new Set());
|
||||
|
||||
|
|
@ -237,11 +233,16 @@ export default function DatasetsAccordion({
|
|||
contentClassName={contentClassName}
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
{datasets.length === 0 && (
|
||||
{datasets.length === 0 && !searchValue && (
|
||||
<div className="flex flex-row items-baseline-last text-sm text-gray-400 mt-2 px-2">
|
||||
<span>No datasets here, add one by clicking +</span>
|
||||
</div>
|
||||
)}
|
||||
{datasets.length === 0 && searchValue && (
|
||||
<div className="flex flex-row items-baseline-last text-sm text-gray-400 mt-2 px-2">
|
||||
<span>No datasets found, please adjust your search term</span>
|
||||
</div>
|
||||
)}
|
||||
{datasets.map((dataset) => {
|
||||
return (
|
||||
<Accordion
|
||||
|
|
|
|||
|
|
@ -1,16 +1,18 @@
|
|||
import classNames from "classnames";
|
||||
import { useCallback, useEffect } from "react";
|
||||
|
||||
import { fetch, isCloudEnvironment, useBoolean } from "@/utils";
|
||||
import { fetch, isCloudApiKeySet, isCloudEnvironment, useBoolean } from "@/utils";
|
||||
import { checkCloudConnection } from "@/modules/cloud";
|
||||
import { CaretIcon, CloseIcon, CloudIcon, LocalCogneeIcon } from "@/ui/Icons";
|
||||
import { CTAButton, GhostButton, IconButton, Input, Modal } from "@/ui/elements";
|
||||
|
||||
import DatasetsAccordion, { DatasetsAccordionProps } from "./DatasetsAccordion";
|
||||
|
||||
type InstanceDatasetsAccordionProps = Omit<DatasetsAccordionProps, "title">;
|
||||
interface InstanceDatasetsAccordionProps extends Omit<DatasetsAccordionProps, "title"> {
|
||||
searchValue: string;
|
||||
}
|
||||
|
||||
export default function InstanceDatasetsAccordion({ onDatasetsChange }: InstanceDatasetsAccordionProps) {
|
||||
export default function InstanceDatasetsAccordion({ searchValue, onDatasetsChange }: InstanceDatasetsAccordionProps) {
|
||||
const {
|
||||
value: isLocalCogneeConnected,
|
||||
setTrue: setLocalCogneeConnected,
|
||||
|
|
@ -19,7 +21,7 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
|
|||
const {
|
||||
value: isCloudCogneeConnected,
|
||||
setTrue: setCloudCogneeConnected,
|
||||
} = useBoolean(isCloudEnvironment());
|
||||
} = useBoolean(isCloudEnvironment() || isCloudApiKeySet());
|
||||
|
||||
const checkConnectionToCloudCognee = useCallback((apiKey?: string) => {
|
||||
if (apiKey) {
|
||||
|
|
@ -71,6 +73,7 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
|
|||
</div>
|
||||
</div>
|
||||
)}
|
||||
searchValue={searchValue}
|
||||
tools={isLocalCogneeConnected ? <span className="text-xs text-indigo-600">Connected</span> : <span className="text-xs text-gray-400">Not connected</span>}
|
||||
switchCaretPosition={true}
|
||||
className="pt-3 pb-1.5"
|
||||
|
|
@ -88,6 +91,7 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
|
|||
</div>
|
||||
</div>
|
||||
)}
|
||||
searchValue={searchValue}
|
||||
tools={<span className="text-xs text-indigo-600">Connected</span>}
|
||||
switchCaretPosition={true}
|
||||
className="pt-3 pb-1.5"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import { useCallback, useState } from 'react';
|
||||
import { useCallback, useEffect, useLayoutEffect, useRef, useState } from 'react';
|
||||
|
||||
import { fetch } from '@/utils';
|
||||
import { DataFile } from './useData';
|
||||
|
|
@ -11,7 +11,20 @@ export interface Dataset {
|
|||
status: string;
|
||||
}
|
||||
|
||||
function useDatasets(useCloud = false) {
|
||||
function filterDatasets(datasets: Dataset[], searchValue: string) {
|
||||
if (searchValue.trim() === "") {
|
||||
return datasets;
|
||||
}
|
||||
|
||||
const lowercaseSearchValue = searchValue.toLowerCase();
|
||||
|
||||
return datasets.filter((dataset) =>
|
||||
dataset.name.toLowerCase().includes(lowercaseSearchValue)
|
||||
);
|
||||
}
|
||||
|
||||
function useDatasets(useCloud = false, searchValue: string = "") {
|
||||
const allDatasets = useRef<Dataset[]>([]);
|
||||
const [datasets, setDatasets] = useState<Dataset[]>([]);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
// const statusTimeout = useRef<any>(null);
|
||||
|
|
@ -57,26 +70,32 @@ function useDatasets(useCloud = false) {
|
|||
// };
|
||||
// }, []);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
setDatasets(filterDatasets(allDatasets.current, searchValue));
|
||||
}, [searchValue]);
|
||||
|
||||
const addDataset = useCallback((datasetName: string) => {
|
||||
return createDataset({ name: datasetName }, useCloud)
|
||||
.then((dataset) => {
|
||||
setDatasets((datasets) => [
|
||||
...datasets,
|
||||
const newDatasets = [
|
||||
...allDatasets.current,
|
||||
dataset,
|
||||
]);
|
||||
];
|
||||
allDatasets.current = newDatasets;
|
||||
setDatasets(filterDatasets(newDatasets, searchValue));
|
||||
});
|
||||
}, [useCloud]);
|
||||
}, [searchValue, useCloud]);
|
||||
|
||||
const removeDataset = useCallback((datasetId: string) => {
|
||||
return fetch(`/v1/datasets/${datasetId}`, {
|
||||
method: 'DELETE',
|
||||
}, useCloud)
|
||||
.then(() => {
|
||||
setDatasets((datasets) =>
|
||||
datasets.filter((dataset) => dataset.id !== datasetId)
|
||||
);
|
||||
const newDatasets = allDatasets.current.filter((dataset) => dataset.id !== datasetId)
|
||||
allDatasets.current = newDatasets;
|
||||
setDatasets(filterDatasets(newDatasets, searchValue));
|
||||
});
|
||||
}, [useCloud]);
|
||||
}, [searchValue, useCloud]);
|
||||
|
||||
const fetchDatasets = useCallback(() => {
|
||||
return fetch('/v1/datasets', {
|
||||
|
|
@ -86,7 +105,8 @@ function useDatasets(useCloud = false) {
|
|||
}, useCloud)
|
||||
.then((response) => response.json())
|
||||
.then((datasets) => {
|
||||
setDatasets(datasets);
|
||||
allDatasets.current = datasets;
|
||||
setDatasets(filterDatasets(datasets, searchValue));
|
||||
|
||||
// if (datasets.length > 0) {
|
||||
// checkDatasetStatuses(datasets);
|
||||
|
|
@ -97,28 +117,38 @@ function useDatasets(useCloud = false) {
|
|||
.catch((error) => {
|
||||
console.error('Error fetching datasets:', error);
|
||||
});
|
||||
}, [useCloud]);
|
||||
}, [searchValue, useCloud]);
|
||||
|
||||
useEffect(() => {
|
||||
if (allDatasets.current.length === 0) {
|
||||
fetchDatasets();
|
||||
}
|
||||
}, [fetchDatasets]);
|
||||
|
||||
const getDatasetData = useCallback((datasetId: string) => {
|
||||
return fetch(`/v1/datasets/${datasetId}/data`, {}, useCloud)
|
||||
.then((response) => response.json())
|
||||
.then((data) => {
|
||||
const datasetIndex = datasets.findIndex((dataset) => dataset.id === datasetId);
|
||||
const datasetIndex = allDatasets.current.findIndex((dataset) => dataset.id === datasetId);
|
||||
|
||||
if (datasetIndex >= 0) {
|
||||
setDatasets((datasets) => [
|
||||
...datasets.slice(0, datasetIndex),
|
||||
{
|
||||
...datasets[datasetIndex],
|
||||
data,
|
||||
},
|
||||
...datasets.slice(datasetIndex + 1),
|
||||
]);
|
||||
const newDatasets = [
|
||||
...allDatasets.current.slice(0, datasetIndex),
|
||||
{
|
||||
...allDatasets.current[datasetIndex],
|
||||
data,
|
||||
},
|
||||
...allDatasets.current.slice(datasetIndex + 1),
|
||||
];
|
||||
|
||||
allDatasets.current = newDatasets;
|
||||
|
||||
setDatasets(filterDatasets(newDatasets, searchValue));
|
||||
}
|
||||
|
||||
return data;
|
||||
});
|
||||
}, [datasets, useCloud]);
|
||||
}, [searchValue, useCloud]);
|
||||
|
||||
const removeDatasetData = useCallback((datasetId: string, dataId: string) => {
|
||||
return fetch(`/v1/datasets/${datasetId}/data/${dataId}`, {
|
||||
|
|
|
|||
|
|
@ -7,15 +7,19 @@ import json
|
|||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, List, Mapping
|
||||
from typing import Any, Dict, Iterable, List, Mapping
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee import config, prune, search, SearchType, visualize_graph
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.modules.data.models import Dataset
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.pipelines import run_tasks, Task
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.tasks.storage.index_graph_edges import index_graph_edges
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.data.methods import load_or_create_datasets
|
||||
|
||||
|
||||
class Person(DataPoint):
|
||||
|
|
@ -76,18 +80,6 @@ def remove_duplicates_preserve_order(seq: Iterable[Any]) -> list[Any]:
|
|||
return out
|
||||
|
||||
|
||||
def collect_people(payloads: Iterable[Mapping[str, Any]]) -> list[Mapping[str, Any]]:
|
||||
"""Collect people from payloads."""
|
||||
people = [person for payload in payloads for person in payload.get("people", [])]
|
||||
return people
|
||||
|
||||
|
||||
def collect_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Mapping[str, Any]]:
|
||||
"""Collect companies from payloads."""
|
||||
companies = [company for payload in payloads for company in payload.get("companies", [])]
|
||||
return companies
|
||||
|
||||
|
||||
def build_people_nodes(people: Iterable[Mapping[str, Any]]) -> dict:
|
||||
"""Build person nodes keyed by name."""
|
||||
nodes = {p["name"]: Person(name=p["name"]) for p in people if p.get("name")}
|
||||
|
|
@ -176,10 +168,10 @@ def attach_employees_to_departments(
|
|||
target.employees = employees
|
||||
|
||||
|
||||
def build_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Company]:
|
||||
def build_companies(data: Data) -> list[Company]:
|
||||
"""Build company nodes from payloads."""
|
||||
people = collect_people(payloads)
|
||||
companies = collect_companies(payloads)
|
||||
people = data.people
|
||||
companies = data.companies
|
||||
people_nodes = build_people_nodes(people)
|
||||
groups = group_people_by_department(people)
|
||||
dept_names = collect_declared_departments(groups, companies)
|
||||
|
|
@ -192,19 +184,29 @@ def build_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Company]:
|
|||
return result
|
||||
|
||||
|
||||
def load_default_payload() -> list[Mapping[str, Any]]:
|
||||
class Data(BaseModel):
|
||||
id: UUID
|
||||
companies: List[Dict[str, Any]]
|
||||
people: List[Dict[str, Any]]
|
||||
|
||||
|
||||
def load_default_payload() -> Data:
|
||||
"""Load the default payload from data files."""
|
||||
companies = load_json_file(COMPANIES_JSON)
|
||||
people = load_json_file(PEOPLE_JSON)
|
||||
payload = [{"companies": companies, "people": people}]
|
||||
return payload
|
||||
|
||||
data = Data(
|
||||
id=uuid4(),
|
||||
companies=companies,
|
||||
people=people,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def ingest_payloads(data: List[Any] | None) -> list[Company]:
|
||||
def ingest_payloads(data: List[Data]) -> list[Company]:
|
||||
"""Ingest payloads and build company nodes."""
|
||||
if not data or data == [None]:
|
||||
data = load_default_payload()
|
||||
companies = build_companies(data)
|
||||
companies = build_companies(data[0])
|
||||
return companies
|
||||
|
||||
|
||||
|
|
@ -221,18 +223,16 @@ async def execute_pipeline() -> None:
|
|||
await setup()
|
||||
|
||||
# Get user and dataset
|
||||
user = await get_default_user()
|
||||
datasets = await load_or_create_datasets(["demo_dataset"], [], user)
|
||||
dataset_id = datasets[0].id
|
||||
user: User = await get_default_user() # type: ignore
|
||||
dataset = Dataset(id=uuid4(), name="demo_dataset")
|
||||
data = load_default_payload()
|
||||
|
||||
# Build and run pipeline
|
||||
tasks = [Task(ingest_payloads), Task(add_data_points)]
|
||||
pipeline = run_tasks(tasks, dataset_id, None, user, "demo_pipeline")
|
||||
pipeline = run_tasks(tasks, dataset, [data], user, "demo_pipeline")
|
||||
async for status in pipeline:
|
||||
logging.info("Pipeline status: %s", status)
|
||||
|
||||
# Post-process: index graph edges and visualize
|
||||
await index_graph_edges()
|
||||
await visualize_graph(str(GRAPH_HTML))
|
||||
|
||||
# Run query against graph
|
||||
|
|
|
|||
|
|
@ -85,13 +85,13 @@ async def run_code_graph_pipeline(
|
|||
|
||||
if include_docs:
|
||||
non_code_pipeline_run = run_tasks(
|
||||
non_code_tasks, dataset.id, repo_path, user, "cognify_pipeline"
|
||||
non_code_tasks, dataset, repo_path, user, "cognify_pipeline"
|
||||
)
|
||||
async for run_status in non_code_pipeline_run:
|
||||
yield run_status
|
||||
|
||||
async for run_status in run_tasks(
|
||||
tasks, dataset.id, repo_path, user, "cognify_code_pipeline", incremental_loading=False
|
||||
tasks, dataset, repo_path, user, "cognify_code_pipeline", incremental_loading=False
|
||||
):
|
||||
yield run_status
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from cognee.modules.ontology.get_default_ontology_resolver import (
|
|||
get_ontology_resolver_from_env,
|
||||
)
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
||||
from cognee.tasks.documents import (
|
||||
check_permissions_on_dataset,
|
||||
|
|
@ -208,6 +209,9 @@ async def cognify(
|
|||
"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
|
||||
}
|
||||
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
if temporal_cognify:
|
||||
tasks = await get_temporal_tasks(user, chunker, chunk_size)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,13 +1,8 @@
|
|||
import inspect
|
||||
from functools import wraps
|
||||
from uuid import UUID
|
||||
from abc import abstractmethod, ABC
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Dict, Any, List, Tuple, Type, Union
|
||||
from uuid import NAMESPACE_OID, UUID, uuid5
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
|
||||
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -19,121 +14,6 @@ EdgeData = Tuple[
|
|||
Node = Tuple[str, NodeData] # (node_id, properties)
|
||||
|
||||
|
||||
def record_graph_changes(func):
|
||||
"""
|
||||
Decorator to record graph changes in the relationship database.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- func: The asynchronous function to wrap, which likely modifies graph data.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the wrapped function that manages database relationships.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
"""
|
||||
Wraps the given asynchronous function to handle database relationships.
|
||||
|
||||
Tracks the caller's function and class name for context. When the wrapped function is
|
||||
called, it manages database relationships for nodes or edges by adding entries to a
|
||||
ledger and committing the changes to the database session. Errors during relationship
|
||||
addition or session commit are logged and will not disrupt the execution of the wrapped
|
||||
function.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- *args: Positional arguments passed to the wrapped function.
|
||||
- **kwargs: Keyword arguments passed to the wrapped function.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the result of the wrapped function call.
|
||||
"""
|
||||
db_engine = get_relational_engine()
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if frame.f_back and frame.f_back.f_code.co_name != "wrapper":
|
||||
caller_frame = frame.f_back
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
caller_name = caller_frame.f_code.co_name
|
||||
caller_class = (
|
||||
caller_frame.f_locals.get("self", None).__class__.__name__
|
||||
if caller_frame.f_locals.get("self", None)
|
||||
else None
|
||||
)
|
||||
creator = f"{caller_class}.{caller_name}" if caller_class else caller_name
|
||||
|
||||
result = await func(self, *args, **kwargs)
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
if func.__name__ == "add_nodes":
|
||||
nodes: List[DataPoint] = args[0]
|
||||
|
||||
relationship_ledgers = []
|
||||
|
||||
for node in nodes:
|
||||
node_id = UUID(str(node.id))
|
||||
relationship_ledgers.append(
|
||||
GraphRelationshipLedger(
|
||||
id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
|
||||
source_node_id=node_id,
|
||||
destination_node_id=node_id,
|
||||
creator_function=f"{creator}.node",
|
||||
node_label=getattr(node, "name", None) or str(node.id),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
session.add_all(relationship_ledgers)
|
||||
await session.flush()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding relationship: {e}")
|
||||
await session.rollback()
|
||||
|
||||
elif func.__name__ == "add_edges":
|
||||
edges = args[0]
|
||||
|
||||
relationship_ledgers = []
|
||||
|
||||
for edge in edges:
|
||||
source_id = UUID(str(edge[0]))
|
||||
target_id = UUID(str(edge[1]))
|
||||
rel_type = str(edge[2])
|
||||
relationship_ledgers.append(
|
||||
GraphRelationshipLedger(
|
||||
id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
|
||||
source_node_id=source_id,
|
||||
destination_node_id=target_id,
|
||||
creator_function=f"{creator}.{rel_type}",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
session.add_all(relationship_ledgers)
|
||||
await session.flush()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding relationship: {e}")
|
||||
await session.rollback()
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error committing session: {e}")
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class GraphDBInterface(ABC):
|
||||
"""
|
||||
Define an interface for graph database operations to be implemented by concrete classes.
|
||||
|
|
@ -189,7 +69,6 @@ class GraphDBInterface(ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@record_graph_changes
|
||||
async def add_nodes(self, nodes: Union[List[Node], List[DataPoint]]) -> None:
|
||||
"""
|
||||
Add multiple nodes to the graph in a single operation.
|
||||
|
|
@ -273,7 +152,6 @@ class GraphDBInterface(ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@record_graph_changes
|
||||
async def add_edges(
|
||||
self, edges: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]
|
||||
) -> None:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from cognee.infrastructure.utils.run_sync import run_sync
|
|||
from cognee.infrastructure.files.storage import get_file_storage
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||
GraphDBInterface,
|
||||
record_graph_changes,
|
||||
)
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
|
|
@ -378,7 +377,6 @@ class KuzuAdapter(GraphDBInterface):
|
|||
logger.error(f"Failed to add node: {e}")
|
||||
raise
|
||||
|
||||
@record_graph_changes
|
||||
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
||||
"""
|
||||
Add multiple nodes to the graph in a batch operation.
|
||||
|
|
@ -675,7 +673,6 @@ class KuzuAdapter(GraphDBInterface):
|
|||
logger.error(f"Failed to add edge: {e}")
|
||||
raise
|
||||
|
||||
@record_graph_changes
|
||||
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
|
||||
"""
|
||||
Add multiple edges in a batch operation.
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ from cognee.tasks.temporal_graph.models import Timestamp
|
|||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||
GraphDBInterface,
|
||||
record_graph_changes,
|
||||
)
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
|
||||
|
|
@ -175,7 +174,6 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return await self.query(query, params)
|
||||
|
||||
@record_graph_changes
|
||||
@override_distributed(queued_add_nodes)
|
||||
async def add_nodes(self, nodes: list[DataPoint]) -> None:
|
||||
"""
|
||||
|
|
@ -446,7 +444,6 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return flattened
|
||||
|
||||
@record_graph_changes
|
||||
@override_distributed(queued_add_edges)
|
||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from uuid import UUID
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||
GraphDBInterface,
|
||||
record_graph_changes,
|
||||
NodeData,
|
||||
EdgeData,
|
||||
Node,
|
||||
|
|
@ -229,7 +228,6 @@ class NeptuneGraphDB(GraphDBInterface):
|
|||
logger.error(f"Failed to add node {node.id}: {error_msg}")
|
||||
raise Exception(f"Failed to add node: {error_msg}") from e
|
||||
|
||||
@record_graph_changes
|
||||
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
||||
"""
|
||||
Add multiple nodes to the graph in a single operation.
|
||||
|
|
@ -534,7 +532,6 @@ class NeptuneGraphDB(GraphDBInterface):
|
|||
logger.error(f"Failed to add edge {source_id} -> {target_id}: {error_msg}")
|
||||
raise Exception(f"Failed to add edge: {error_msg}") from e
|
||||
|
||||
@record_graph_changes
|
||||
async def add_edges(self, edges: List[Tuple[str, str, str, Optional[Dict[str, Any]]]]) -> None:
|
||||
"""
|
||||
Add multiple edges to the graph in a single operation.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
from os import path
|
||||
from uuid import UUID
|
||||
import lancedb
|
||||
from pydantic import BaseModel
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
|
@ -282,7 +283,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
]
|
||||
)
|
||||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
|
||||
collection = await self.get_collection(collection_name)
|
||||
|
||||
# Delete one at a time to avoid commit conflicts
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
from typing import List, Optional, get_type_hints
|
||||
from uuid import UUID
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
|
@ -384,7 +385,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
]
|
||||
)
|
||||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
|
||||
async with self.get_async_session() as session:
|
||||
# Get PGVectorDataPoint Table from database
|
||||
PGVectorDataPoint = await self.get_table(collection_name)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import List, Protocol, Optional, Union, Any
|
||||
from typing import List, Protocol, Optional, Any
|
||||
from abc import abstractmethod
|
||||
from uuid import UUID
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from .models.PayloadSchema import PayloadSchema
|
||||
|
||||
|
||||
class VectorDBInterface(Protocol):
|
||||
|
|
@ -127,9 +127,7 @@ class VectorDBInterface(Protocol):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def delete_data_points(
|
||||
self, collection_name: str, data_point_ids: Union[List[str], list[str]]
|
||||
):
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: List[UUID]):
|
||||
"""
|
||||
Delete specified data points from a collection.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import pickle
|
||||
from uuid import UUID, uuid4
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from datetime import datetime, timezone
|
||||
|
|
@ -28,8 +27,6 @@ class DataPoint(BaseModel):
|
|||
- update_version
|
||||
- to_json
|
||||
- from_json
|
||||
- to_pickle
|
||||
- from_pickle
|
||||
- to_dict
|
||||
- from_dict
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from cognee.modules.data.models import Dataset
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
|
||||
from cognee.modules.data.models import Dataset
|
||||
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
|
||||
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
|
||||
@with_async_session
|
||||
async def create_dataset(dataset_name: str, user: User, session: AsyncSession) -> Dataset:
|
||||
owner_id = user.id
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from .generate_node_id import generate_node_id
|
||||
from .generate_edge_id import generate_edge_id
|
||||
from .generate_node_name import generate_node_name
|
||||
from .generate_edge_name import generate_edge_name
|
||||
from .generate_event_datapoint import generate_event_datapoint
|
||||
|
|
|
|||
|
|
@ -1 +1,8 @@
|
|||
from .get_formatted_graph_data import get_formatted_graph_data
|
||||
from .upsert_edges import upsert_edges
|
||||
from .upsert_nodes import upsert_nodes
|
||||
from .get_data_related_nodes import get_data_related_nodes
|
||||
from .delete_data_related_nodes import delete_data_related_nodes
|
||||
from .delete_data_related_edges import delete_data_related_edges
|
||||
from .get_data_related_edges import get_data_related_edges
|
||||
from .delete_data_nodes_and_edges import delete_data_nodes_and_edges
|
||||
|
|
|
|||
43
cognee/modules/graph/methods/delete_data_nodes_and_edges.py
Normal file
43
cognee/modules/graph/methods/delete_data_nodes_and_edges.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
from uuid import UUID
|
||||
from typing import Dict, List
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
|
||||
from cognee.modules.engine.utils import generate_edge_id
|
||||
from cognee.modules.graph.methods import (
|
||||
delete_data_related_edges,
|
||||
delete_data_related_nodes,
|
||||
get_data_related_nodes,
|
||||
get_data_related_edges,
|
||||
)
|
||||
|
||||
|
||||
async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID) -> None:
|
||||
affected_nodes = await get_data_related_nodes(dataset_id, data_id)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
await graph_engine.delete_nodes([str(node.slug) for node in affected_nodes])
|
||||
|
||||
affected_vector_collections: Dict[str, List] = {}
|
||||
for node in affected_nodes:
|
||||
for indexed_field in node.indexed_fields:
|
||||
collection_name = f"{node.type}_{indexed_field}"
|
||||
if collection_name not in affected_vector_collections:
|
||||
affected_vector_collections[collection_name] = []
|
||||
affected_vector_collections[collection_name].append(node)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
for affected_collection, affected_nodes in affected_vector_collections.items():
|
||||
await vector_engine.delete_data_points(
|
||||
affected_collection, [node.id for node in affected_nodes]
|
||||
)
|
||||
|
||||
affected_relationships = await get_data_related_edges(dataset_id, data_id)
|
||||
|
||||
await vector_engine.delete_data_points(
|
||||
"EdgeType_relationship_name",
|
||||
[generate_edge_id(edge.relationship_name) for edge in affected_relationships],
|
||||
)
|
||||
|
||||
await delete_data_related_nodes(data_id)
|
||||
await delete_data_related_edges(data_id)
|
||||
13
cognee/modules/graph/methods/delete_data_related_edges.py
Normal file
13
cognee/modules/graph/methods/delete_data_related_edges.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models import Edge
|
||||
|
||||
|
||||
@with_async_session
|
||||
async def delete_data_related_edges(data_id: UUID, session: AsyncSession):
|
||||
nodes = (await session.scalars(select(Edge).where(Edge.data_id == data_id))).all()
|
||||
|
||||
await session.execute(delete(Edge).where(Edge.id.in_([node.id for node in nodes])))
|
||||
13
cognee/modules/graph/methods/delete_data_related_nodes.py
Normal file
13
cognee/modules/graph/methods/delete_data_related_nodes.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models import Node
|
||||
|
||||
|
||||
@with_async_session
|
||||
async def delete_data_related_nodes(data_id: UUID, session: AsyncSession):
|
||||
nodes = (await session.scalars(select(Node).where(Node.data_id == data_id))).all()
|
||||
|
||||
await session.execute(delete(Node).where(Node.id.in_([node.id for node in nodes])))
|
||||
17
cognee/modules/graph/methods/get_data_related_edges.py
Normal file
17
cognee/modules/graph/methods/get_data_related_edges.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models import Edge
|
||||
|
||||
|
||||
@with_async_session
|
||||
async def get_data_related_edges(dataset_id: UUID, data_id: UUID, session: AsyncSession):
|
||||
return (
|
||||
await session.scalars(
|
||||
select(Edge)
|
||||
.where(Edge.data_id == data_id, Edge.dataset_id == dataset_id)
|
||||
.distinct(Edge.relationship_name)
|
||||
)
|
||||
).all()
|
||||
26
cognee/modules/graph/methods/get_data_related_nodes.py
Normal file
26
cognee/modules/graph/methods/get_data_related_nodes.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import and_, exists, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models import Node
|
||||
|
||||
|
||||
@with_async_session
|
||||
async def get_data_related_nodes(dataset_id: UUID, data_id: UUID, session: AsyncSession):
|
||||
NodeAlias = Node.__table__.alias("n2")
|
||||
|
||||
subq = select(NodeAlias.c.id).where(
|
||||
and_(
|
||||
NodeAlias.c.slug == Node.slug,
|
||||
NodeAlias.c.dataset_id == dataset_id,
|
||||
NodeAlias.c.data_id != data_id,
|
||||
)
|
||||
)
|
||||
|
||||
query_statement = select(Node).where(
|
||||
and_(Node.data_id == data_id, Node.dataset_id == dataset_id, ~exists(subq))
|
||||
)
|
||||
|
||||
data_related_nodes = await session.scalars(query_statement)
|
||||
return data_related_nodes.all()
|
||||
8
cognee/modules/graph/methods/set_current_user.py
Normal file
8
cognee/modules/graph/methods/set_current_user.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
import uuid
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
async def set_current_user(session: AsyncSession, user_id: uuid.UUID, local: bool = False):
|
||||
scope = "LOCAL " if local else ""
|
||||
await session.execute(text(f"SET {scope}app.current_user_id = '{user_id}'"))
|
||||
57
cognee/modules/graph/methods/upsert_edges.py
Normal file
57
cognee/modules/graph/methods/upsert_edges.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models.Edge import Edge
|
||||
from .set_current_user import set_current_user
|
||||
|
||||
|
||||
@with_async_session
|
||||
async def upsert_edges(
|
||||
edges: List[Tuple[UUID, UUID, str, Dict[str, Any]]],
|
||||
user_id: UUID,
|
||||
data_id: UUID,
|
||||
dataset_id: UUID,
|
||||
session: AsyncSession,
|
||||
):
|
||||
"""
|
||||
Adds edges to the edges table.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
- edges (list): A list of edges to be added to the graph.
|
||||
"""
|
||||
if session.get_bind().dialect.name == "postgresql":
|
||||
# Set the session-level RLS variable
|
||||
await set_current_user(session, user_id)
|
||||
|
||||
upsert_statement = (
|
||||
insert(Edge)
|
||||
.values(
|
||||
[
|
||||
{
|
||||
"id": uuid5(
|
||||
NAMESPACE_OID,
|
||||
str(user_id) + str(dataset_id) + str(edge[0]) + str(edge[2]) + str(edge[1]),
|
||||
),
|
||||
"user_id": user_id,
|
||||
"data_id": data_id,
|
||||
"dataset_id": dataset_id,
|
||||
"source_node_id": edge[0],
|
||||
"destination_node_id": edge[1],
|
||||
"relationship_name": edge[2],
|
||||
"label": edge[2],
|
||||
"props": jsonable_encoder(edge[3]),
|
||||
}
|
||||
for edge in edges
|
||||
]
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=["id"])
|
||||
)
|
||||
|
||||
await session.execute(upsert_statement)
|
||||
|
||||
await session.commit()
|
||||
51
cognee/modules/graph/methods/upsert_nodes.py
Normal file
51
cognee/modules/graph/methods/upsert_nodes.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
from typing import List
|
||||
from uuid import NAMESPACE_OID, UUID, uuid5
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models import Node
|
||||
from .set_current_user import set_current_user
|
||||
|
||||
|
||||
@with_async_session
|
||||
async def upsert_nodes(
|
||||
nodes: List[DataPoint], user_id: UUID, dataset_id: UUID, data_id: UUID, session: AsyncSession
|
||||
):
|
||||
"""
|
||||
Adds nodes to the nodes table.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
- nodes (list): A list of nodes to be added to the graph.
|
||||
"""
|
||||
if session.get_bind().dialect.name == "postgresql":
|
||||
# Set the session-level RLS variable
|
||||
await set_current_user(session, user_id)
|
||||
|
||||
upsert_statement = (
|
||||
insert(Node)
|
||||
.values(
|
||||
[
|
||||
{
|
||||
"id": uuid5(
|
||||
NAMESPACE_OID, str(user_id) + str(dataset_id) + str(data_id) + str(node.id)
|
||||
),
|
||||
"slug": node.id,
|
||||
"user_id": user_id,
|
||||
"data_id": data_id,
|
||||
"dataset_id": dataset_id,
|
||||
"type": node.type,
|
||||
"indexed_fields": DataPoint.get_embeddable_property_names(node),
|
||||
"label": getattr(node, "label", getattr(node, "name", str(node.id))),
|
||||
}
|
||||
for node in nodes
|
||||
]
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=["id"])
|
||||
)
|
||||
|
||||
await session.execute(upsert_statement)
|
||||
|
||||
await session.commit()
|
||||
50
cognee/modules/graph/models/Edge.py
Normal file
50
cognee/modules/graph/models/Edge.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from sqlalchemy import (
|
||||
# event,
|
||||
String,
|
||||
JSON,
|
||||
UUID,
|
||||
)
|
||||
|
||||
# from sqlalchemy.schema import DDL
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
|
||||
class Edge(Base):
|
||||
__tablename__ = "edges"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), primary_key=True)
|
||||
|
||||
user_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||
|
||||
data_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), index=True, nullable=False)
|
||||
|
||||
dataset_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||
|
||||
source_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||
destination_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||
|
||||
relationship_name: Mapped[str | None] = mapped_column(String(255))
|
||||
|
||||
label: Mapped[str | None] = mapped_column(String(255))
|
||||
props: Mapped[dict | None] = mapped_column(JSON)
|
||||
|
||||
# __table_args__ = (
|
||||
# {"postgresql_partition_by": "HASH (user_id)"}, # partitioning by user
|
||||
# )
|
||||
|
||||
|
||||
# Enable row-level security (RLS) for edges
|
||||
# enable_edge_rls = DDL("""
|
||||
# ALTER TABLE edges ENABLE ROW LEVEL SECURITY;
|
||||
# """)
|
||||
# create_user_isolation_policy = DDL("""
|
||||
# CREATE POLICY user_isolation_policy
|
||||
# ON edges
|
||||
# USING (user_id = current_setting('app.current_user_id')::uuid)
|
||||
# WITH CHECK (user_id = current_setting('app.current_user_id')::uuid);
|
||||
# """)
|
||||
|
||||
# event.listen(Edge.__table__, "after_create", enable_edge_rls)
|
||||
# event.listen(Edge.__table__, "after_create", create_user_isolation_policy)
|
||||
53
cognee/modules/graph/models/Node.py
Normal file
53
cognee/modules/graph/models/Node.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
from sqlalchemy import (
|
||||
Index,
|
||||
# event,
|
||||
String,
|
||||
JSON,
|
||||
UUID,
|
||||
)
|
||||
|
||||
# from sqlalchemy.schema import DDL
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
|
||||
class Node(Base):
|
||||
__tablename__ = "nodes"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), primary_key=True)
|
||||
|
||||
slug: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||
|
||||
user_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||
|
||||
data_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||
|
||||
dataset_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||
|
||||
label: Mapped[str] = mapped_column(String(255))
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
indexed_fields: Mapped[list] = mapped_column(JSON)
|
||||
|
||||
# props: Mapped[dict] = mapped_column(JSON)
|
||||
|
||||
__table_args__ = (
|
||||
Index("index_node_dataset_slug", "dataset_id", "slug"),
|
||||
Index("index_node_dataset_data", "dataset_id", "data_id"),
|
||||
# {"postgresql_partition_by": "HASH (user_id)"}, # HASH partitioning on user_id
|
||||
)
|
||||
|
||||
|
||||
# Enable row-level security (RLS) for nodes
|
||||
# enable_node_rls = DDL("""
|
||||
# ALTER TABLE nodes ENABLE ROW LEVEL SECURITY;
|
||||
# """)
|
||||
# create_user_isolation_policy = DDL("""
|
||||
# CREATE POLICY user_isolation_policy
|
||||
# ON nodes
|
||||
# USING (user_id = current_setting('app.current_user_id')::uuid)
|
||||
# WITH CHECK (user_id = current_setting('app.current_user_id')::uuid);
|
||||
# """)
|
||||
|
||||
# event.listen(Node.__table__, "after_create", enable_node_rls)
|
||||
# event.listen(Node.__table__, "after_create", create_user_isolation_policy)
|
||||
2
cognee/modules/graph/models/__init__.py
Normal file
2
cognee/modules/graph/models/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
from .Edge import Edge
|
||||
from .Node import Node
|
||||
|
|
@ -134,8 +134,8 @@ def _targets_generator(
|
|||
|
||||
async def get_graph_from_model(
|
||||
data_point: DataPoint,
|
||||
added_nodes: Dict[str, bool],
|
||||
added_edges: Dict[str, bool],
|
||||
added_nodes: Optional[Dict[str, bool]] = None,
|
||||
added_edges: Optional[Dict[str, bool]] = None,
|
||||
visited_properties: Optional[Dict[str, bool]] = None,
|
||||
include_root: bool = True,
|
||||
) -> Tuple[List[DataPoint], List[Tuple[str, str, str, Dict[str, Any]]]]:
|
||||
|
|
@ -152,6 +152,12 @@ async def get_graph_from_model(
|
|||
Returns:
|
||||
Tuple of (nodes, edges) extracted from the model
|
||||
"""
|
||||
if not added_nodes:
|
||||
added_nodes = {}
|
||||
|
||||
if not added_edges:
|
||||
added_edges = {}
|
||||
|
||||
if str(data_point.id) in added_nodes:
|
||||
return [], []
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
from uuid import UUID
|
||||
from typing import Union
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from cognee.modules.pipelines.layers.setup_and_check_environment import (
|
||||
setup_and_check_environment,
|
||||
|
|
@ -29,12 +29,13 @@ update_status_lock = asyncio.Lock()
|
|||
async def run_pipeline(
|
||||
tasks: list[Task],
|
||||
data=None,
|
||||
datasets: Union[str, list[str], list[UUID]] = None,
|
||||
user: User = None,
|
||||
datasets: Optional[Union[str, list[str], list[UUID]]] = None,
|
||||
user: Optional[User] = None,
|
||||
pipeline_name: str = "custom_pipeline",
|
||||
vector_db_config: dict = None,
|
||||
graph_db_config: dict = None,
|
||||
vector_db_config: Optional[dict] = None,
|
||||
graph_db_config: Optional[dict] = None,
|
||||
incremental_loading: bool = False,
|
||||
context: Optional[Dict] = None,
|
||||
):
|
||||
validate_pipeline_tasks(tasks)
|
||||
await setup_and_check_environment(vector_db_config, graph_db_config)
|
||||
|
|
@ -48,8 +49,8 @@ async def run_pipeline(
|
|||
tasks=tasks,
|
||||
data=data,
|
||||
pipeline_name=pipeline_name,
|
||||
context={"dataset": dataset},
|
||||
incremental_loading=incremental_loading,
|
||||
context=context,
|
||||
):
|
||||
yield run_info
|
||||
|
||||
|
|
@ -58,16 +59,16 @@ async def run_pipeline_per_dataset(
|
|||
dataset: Dataset,
|
||||
user: User,
|
||||
tasks: list[Task],
|
||||
data=None,
|
||||
data: Optional[list[Data]] = None,
|
||||
pipeline_name: str = "custom_pipeline",
|
||||
context: dict = None,
|
||||
incremental_loading=False,
|
||||
context: Optional[Dict] = None,
|
||||
):
|
||||
# Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
|
||||
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||
|
||||
if not data:
|
||||
data: list[Data] = await get_dataset_data(dataset_id=dataset.id)
|
||||
data = await get_dataset_data(dataset_id=dataset.id)
|
||||
|
||||
process_pipeline_status = await check_pipeline_run_qualification(dataset, data, pipeline_name)
|
||||
if process_pipeline_status:
|
||||
|
|
@ -77,7 +78,7 @@ async def run_pipeline_per_dataset(
|
|||
return
|
||||
|
||||
pipeline_run = run_tasks(
|
||||
tasks, dataset.id, data, user, pipeline_name, context, incremental_loading
|
||||
tasks, dataset, data, user, pipeline_name, context, incremental_loading
|
||||
)
|
||||
|
||||
async for pipeline_run_info in pipeline_run:
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
import os
|
||||
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
from typing import Any, List
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.data.models import Dataset
|
||||
from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
|
@ -24,7 +24,6 @@ from cognee.modules.pipelines.operations import (
|
|||
log_pipeline_run_complete,
|
||||
log_pipeline_run_error,
|
||||
)
|
||||
from .run_tasks_with_telemetry import run_tasks_with_telemetry
|
||||
from .run_tasks_data_item import run_tasks_data_item
|
||||
from ..tasks.task import Task
|
||||
|
||||
|
|
@ -54,25 +53,18 @@ def override_run_tasks(new_gen):
|
|||
@override_run_tasks(run_tasks_distributed)
|
||||
async def run_tasks(
|
||||
tasks: List[Task],
|
||||
dataset_id: UUID,
|
||||
data: List[Any] = None,
|
||||
user: User = None,
|
||||
dataset: Dataset,
|
||||
data: Optional[List[Any]] = None,
|
||||
user: Optional[User] = None,
|
||||
pipeline_name: str = "unknown_pipeline",
|
||||
context: dict = None,
|
||||
context: Optional[Dict] = None,
|
||||
incremental_loading: bool = False,
|
||||
):
|
||||
if not user:
|
||||
user = await get_default_user()
|
||||
|
||||
# Get Dataset object
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
from cognee.modules.data.models import Dataset
|
||||
|
||||
dataset = await session.get(Dataset, dataset_id)
|
||||
|
||||
pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
|
||||
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
|
||||
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset.id, data)
|
||||
pipeline_run_id = pipeline_run.pipeline_run_id
|
||||
|
||||
yield PipelineRunStarted(
|
||||
|
|
@ -99,7 +91,12 @@ async def run_tasks(
|
|||
pipeline_name,
|
||||
pipeline_id,
|
||||
pipeline_run_id,
|
||||
context,
|
||||
{
|
||||
**(context or {}),
|
||||
"user": user,
|
||||
"data": data_item,
|
||||
"dataset": dataset,
|
||||
},
|
||||
user,
|
||||
incremental_loading,
|
||||
)
|
||||
|
|
@ -121,7 +118,7 @@ async def run_tasks(
|
|||
)
|
||||
|
||||
await log_pipeline_run_complete(
|
||||
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data
|
||||
pipeline_run_id, pipeline_id, pipeline_name, dataset.id, data
|
||||
)
|
||||
|
||||
yield PipelineRunCompleted(
|
||||
|
|
@ -141,7 +138,7 @@ async def run_tasks(
|
|||
|
||||
except Exception as error:
|
||||
await log_pipeline_run_error(
|
||||
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error
|
||||
pipeline_run_id, pipeline_id, pipeline_name, dataset.id, data, error
|
||||
)
|
||||
|
||||
yield PipelineRunErrored(
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import asyncio
|
||||
from typing import Type, List, Optional
|
||||
from typing import Dict, Type, List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
|
||||
from cognee.tasks.storage import index_graph_edges
|
||||
from cognee.tasks.storage.add_data_points import add_data_points
|
||||
from cognee.modules.ontology.ontology_config import Config
|
||||
from cognee.modules.ontology.get_default_ontology_resolver import (
|
||||
|
|
@ -32,6 +31,7 @@ async def integrate_chunk_graphs(
|
|||
chunk_graphs: list,
|
||||
graph_model: Type[BaseModel],
|
||||
ontology_resolver: BaseOntologyResolver,
|
||||
context: Dict,
|
||||
) -> List[DocumentChunk]:
|
||||
"""Integrate chunk graphs with ontology validation and store in databases.
|
||||
|
||||
|
|
@ -85,19 +85,19 @@ async def integrate_chunk_graphs(
|
|||
)
|
||||
|
||||
if len(graph_nodes) > 0:
|
||||
await add_data_points(graph_nodes)
|
||||
await add_data_points(graph_nodes, context)
|
||||
|
||||
if len(graph_edges) > 0:
|
||||
await graph_engine.add_edges(graph_edges)
|
||||
await index_graph_edges(graph_edges)
|
||||
|
||||
return data_chunks
|
||||
|
||||
|
||||
async def extract_graph_from_data(
|
||||
data_chunks: List[DocumentChunk],
|
||||
context: Dict,
|
||||
graph_model: Type[BaseModel],
|
||||
config: Config = None,
|
||||
config: Optional[Config] = None,
|
||||
custom_prompt: Optional[str] = None,
|
||||
) -> List[DocumentChunk]:
|
||||
"""
|
||||
|
|
@ -136,16 +136,16 @@ async def extract_graph_from_data(
|
|||
and ontology_config.ontology_resolver
|
||||
and ontology_config.matching_strategy
|
||||
):
|
||||
config: Config = {
|
||||
config = {
|
||||
"ontology_config": {
|
||||
"ontology_resolver": get_ontology_resolver_from_env(**ontology_config.to_dict())
|
||||
}
|
||||
}
|
||||
else:
|
||||
config: Config = {
|
||||
"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
|
||||
}
|
||||
config = {"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}}
|
||||
|
||||
ontology_resolver = config["ontology_config"]["ontology_resolver"]
|
||||
|
||||
return await integrate_chunk_graphs(data_chunks, chunk_graphs, graph_model, ontology_resolver)
|
||||
return await integrate_chunk_graphs(
|
||||
data_chunks, chunk_graphs, graph_model, ontology_resolver, context
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import asyncio
|
||||
from typing import List
|
||||
from typing import Any, Dict, List, Optional
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.graph.methods import upsert_edges, upsert_nodes
|
||||
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
|
||||
from cognee.modules.users.models import User
|
||||
from .index_data_points import index_data_points
|
||||
from .index_graph_edges import index_graph_edges
|
||||
from cognee.tasks.storage.exceptions import (
|
||||
|
|
@ -10,7 +12,9 @@ from cognee.tasks.storage.exceptions import (
|
|||
)
|
||||
|
||||
|
||||
async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
|
||||
async def add_data_points(
|
||||
data_points: List[DataPoint], context: Optional[Dict[str, Any]] = None
|
||||
) -> List[DataPoint]:
|
||||
"""
|
||||
Add a batch of data points to the graph database by extracting nodes and edges,
|
||||
deduplicating them, and indexing them for retrieval.
|
||||
|
|
@ -33,8 +37,15 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
|
|||
- Deduplicates nodes and edges across all results.
|
||||
- Updates the node index via `index_data_points`.
|
||||
- Inserts nodes and edges into the graph engine.
|
||||
- Optionally updates the edge index via `index_graph_edges`.
|
||||
"""
|
||||
user: Optional[User] = None
|
||||
data = None
|
||||
dataset = None
|
||||
|
||||
if context:
|
||||
data = context["data"]
|
||||
dataset = context["dataset"]
|
||||
user = context["user"]
|
||||
|
||||
if not isinstance(data_points, list):
|
||||
raise InvalidDataPointsInAddDataPointsError("data_points must be a list.")
|
||||
|
|
@ -71,6 +82,10 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
|
|||
await graph_engine.add_nodes(nodes)
|
||||
await index_data_points(nodes)
|
||||
|
||||
if user and dataset and data:
|
||||
await upsert_nodes(nodes, user_id=user.id, dataset_id=dataset.id, data_id=data.id)
|
||||
await upsert_edges(edges, user_id=user.id, dataset_id=dataset.id, data_id=data.id)
|
||||
|
||||
await graph_engine.add_edges(edges)
|
||||
await index_graph_edges(edges)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,26 @@
|
|||
import os
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
|
||||
from graphiti_core import Graphiti
|
||||
from graphiti_core.nodes import EpisodeType
|
||||
|
||||
from cognee.infrastructure.files.storage import get_file_storage
|
||||
from cognee.modules.data.models import Data
|
||||
|
||||
async def build_graph_with_temporal_awareness(text_list):
|
||||
url = os.getenv("GRAPH_DATABASE_URL")
|
||||
password = os.getenv("GRAPH_DATABASE_PASSWORD")
|
||||
|
||||
async def build_graph_with_temporal_awareness(data: List[Data]):
|
||||
text_list: List[str] = []
|
||||
|
||||
for text_data in data:
|
||||
file_dir = os.path.dirname(text_data.raw_data_location)
|
||||
file_name = os.path.basename(text_data.raw_data_location)
|
||||
file_storage = get_file_storage(file_dir)
|
||||
async with file_storage.open(file_name, "r") as file:
|
||||
text_list.append(file.read())
|
||||
|
||||
url = os.getenv("GRAPH_DATABASE_URL", "")
|
||||
password = os.getenv("GRAPH_DATABASE_PASSWORD", "")
|
||||
graphiti = Graphiti(url, "neo4j", password)
|
||||
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
|
@ -22,4 +35,5 @@ async def build_graph_with_temporal_awareness(text_list):
|
|||
reference_time=datetime.now(),
|
||||
)
|
||||
print(f"Added text: {text[:35]}...")
|
||||
|
||||
return graphiti
|
||||
|
|
|
|||
107
cognee/tests/test_delete_custom_graph.py
Normal file
107
cognee/tests/test_delete_custom_graph.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
import os
|
||||
import pathlib
|
||||
from typing import List
|
||||
from uuid import uuid4
|
||||
|
||||
import cognee
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.data.models import Data, Dataset
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.graph.methods import delete_data_nodes_and_edges
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.tasks.storage import add_data_points
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def main():
|
||||
data_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_delete_custom_graph")
|
||||
).resolve()
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
cognee_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_delete_custom_graph")
|
||||
).resolve()
|
||||
)
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await setup()
|
||||
|
||||
class Organization(DataPoint):
|
||||
name: str
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
class ForProfit(Organization):
|
||||
name: str = "For-Profit"
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
class NonProfit(Organization):
|
||||
name: str = "Non-Profit"
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
class Person(DataPoint):
|
||||
name: str
|
||||
works_for: List[Organization]
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
companyA = ForProfit(name="Company A")
|
||||
companyB = NonProfit(name="Company B")
|
||||
|
||||
person1 = Person(name="John", works_for=[companyA, companyB])
|
||||
person2 = Person(name="Jane", works_for=[companyB])
|
||||
|
||||
user: User = await get_default_user() # type: ignore
|
||||
|
||||
dataset = Dataset(id=uuid4())
|
||||
data1 = Data(id=uuid4())
|
||||
data2 = Data(id=uuid4())
|
||||
|
||||
await add_data_points(
|
||||
[person1],
|
||||
context={
|
||||
"user": user,
|
||||
"dataset": dataset,
|
||||
"data": data1,
|
||||
},
|
||||
)
|
||||
|
||||
await add_data_points(
|
||||
[person2],
|
||||
context={
|
||||
"user": user,
|
||||
"dataset": dataset,
|
||||
"data": data2,
|
||||
},
|
||||
)
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 4 and len(edges) == 3, (
|
||||
"Nodes and edges are not correctly added to the graph."
|
||||
)
|
||||
|
||||
await delete_data_nodes_and_edges(dataset.id, data1.id) # type: ignore
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 2 and len(edges) == 1, "Nodes and edges are not deleted properly."
|
||||
|
||||
await delete_data_nodes_and_edges(dataset.id, data2.id) # type: ignore
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 0 and len(edges) == 0, "Nodes and edges are not deleted."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
77
cognee/tests/test_delete_default_graph.py
Normal file
77
cognee/tests/test_delete_default_graph.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import os
|
||||
import pathlib
|
||||
|
||||
import cognee
|
||||
from cognee.api.v1.visualize.visualize import visualize_graph
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.data.methods import delete_data, get_dataset_data
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.graph.methods import (
|
||||
delete_data_nodes_and_edges,
|
||||
)
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def main():
|
||||
data_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph")
|
||||
).resolve()
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
cognee_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_delete_default_graph")
|
||||
).resolve()
|
||||
)
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await setup()
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
assert not await vector_engine.has_collection("EdgeType_relationship_name")
|
||||
assert not await vector_engine.has_collection("Entity_name")
|
||||
|
||||
await cognee.add(
|
||||
"John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
||||
)
|
||||
await cognee.add("Marie works for Apple as well. She is a software engineer on MacOS project.")
|
||||
|
||||
cognify_result: dict = await cognee.cognify()
|
||||
dataset_id = list(cognify_result.keys())[0]
|
||||
|
||||
dataset_data = await get_dataset_data(dataset_id)
|
||||
added_data = dataset_data[0]
|
||||
|
||||
file_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_full.html"
|
||||
)
|
||||
await visualize_graph(file_path)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) >= 12 and len(edges) >= 18, "Nodes and edges are not deleted."
|
||||
|
||||
await delete_data_nodes_and_edges(dataset_id, added_data.id) # type: ignore
|
||||
|
||||
await delete_data(added_data)
|
||||
|
||||
file_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_after_delete.html"
|
||||
)
|
||||
await visualize_graph(file_path)
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) >= 8 and len(edges) >= 10, "Nodes and edges are not deleted."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
|
@ -100,7 +100,7 @@ async def main():
|
|||
|
||||
pipeline = run_tasks(
|
||||
[Task(ingest_files), Task(add_data_points)],
|
||||
dataset_id=datasets[0].id,
|
||||
dataset=datasets[0],
|
||||
data=data,
|
||||
incremental_loading=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,18 @@
|
|||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, List
|
||||
from uuid import NAMESPACE_OID, UUID, uuid4, uuid5
|
||||
from neo4j import exceptions
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee import prune
|
||||
|
||||
# from cognee import visualize_graph
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.low_level import setup, DataPoint
|
||||
from cognee.modules.data.models import Data, Dataset
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.pipelines import run_tasks, Task
|
||||
from cognee.tasks.storage import add_data_points
|
||||
|
||||
|
|
@ -20,7 +25,6 @@ products_aggregator_node = Products()
|
|||
|
||||
|
||||
class Product(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
price: float
|
||||
|
|
@ -36,7 +40,6 @@ preferences_aggregator_node = Preferences()
|
|||
|
||||
|
||||
class Preference(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
value: str
|
||||
is_type: Preferences = preferences_aggregator_node
|
||||
|
|
@ -50,7 +53,6 @@ customers_aggregator_node = Customers()
|
|||
|
||||
|
||||
class Customer(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
has_preference: list[Preference]
|
||||
purchased: list[Product]
|
||||
|
|
@ -58,17 +60,14 @@ class Customer(DataPoint):
|
|||
is_type: Customers = customers_aggregator_node
|
||||
|
||||
|
||||
def ingest_files():
|
||||
customers_file_path = os.path.join(os.path.dirname(__file__), "customers.json")
|
||||
customers = json.loads(open(customers_file_path, "r").read())
|
||||
|
||||
def ingest_customers(data):
|
||||
customers_data_points = {}
|
||||
products_data_points = {}
|
||||
preferences_data_points = {}
|
||||
|
||||
for customer in customers:
|
||||
for customer in data[0].customers:
|
||||
new_customer = Customer(
|
||||
id=customer["id"],
|
||||
id=uuid5(NAMESPACE_OID, customer["id"]),
|
||||
name=customer["name"],
|
||||
liked=[],
|
||||
purchased=[],
|
||||
|
|
@ -79,7 +78,7 @@ def ingest_files():
|
|||
for product in customer["products"]:
|
||||
if product["id"] not in products_data_points:
|
||||
products_data_points[product["id"]] = Product(
|
||||
id=product["id"],
|
||||
id=uuid5(NAMESPACE_OID, product["id"]),
|
||||
type=product["type"],
|
||||
name=product["name"],
|
||||
price=product["price"],
|
||||
|
|
@ -96,7 +95,7 @@ def ingest_files():
|
|||
for preference in customer["preferences"]:
|
||||
if preference["id"] not in preferences_data_points:
|
||||
preferences_data_points[preference["id"]] = Preference(
|
||||
id=preference["id"],
|
||||
id=uuid5(NAMESPACE_OID, preference["id"]),
|
||||
name=preference["name"],
|
||||
value=preference["value"],
|
||||
)
|
||||
|
|
@ -104,7 +103,7 @@ def ingest_files():
|
|||
new_preference = preferences_data_points[preference["id"]]
|
||||
new_customer.has_preference.append(new_preference)
|
||||
|
||||
return customers_data_points.values()
|
||||
return list(customers_data_points.values())
|
||||
|
||||
|
||||
async def main():
|
||||
|
|
@ -113,7 +112,28 @@ async def main():
|
|||
|
||||
await setup()
|
||||
|
||||
pipeline = run_tasks([Task(ingest_files), Task(add_data_points)])
|
||||
# Get user and dataset
|
||||
user: User = await get_default_user() # type: ignore
|
||||
main_dataset = Dataset(id=uuid4(), name="demo_dataset")
|
||||
|
||||
customers_file_path = os.path.join(os.path.dirname(__file__), "customers.json")
|
||||
customers = json.loads(open(customers_file_path, "r").read())
|
||||
|
||||
class Data(BaseModel):
|
||||
id: UUID
|
||||
customers: List[Dict]
|
||||
|
||||
data = Data(
|
||||
id=uuid4(),
|
||||
customers=customers,
|
||||
)
|
||||
|
||||
pipeline = run_tasks(
|
||||
[Task(ingest_customers), Task(add_data_points)],
|
||||
dataset=main_dataset,
|
||||
data=[data],
|
||||
user=user,
|
||||
)
|
||||
|
||||
async for status in pipeline:
|
||||
print(status)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
|
||||
import cognee
|
||||
from cognee.modules.data.methods import get_dataset_data, get_datasets
|
||||
from cognee.shared.logging_utils import setup_logging, ERROR
|
||||
from cognee.modules.pipelines import Task, run_tasks
|
||||
from cognee.tasks.temporal_awareness import build_graph_with_temporal_awareness
|
||||
|
|
@ -35,10 +36,13 @@ async def main():
|
|||
await cognee.add(text)
|
||||
|
||||
tasks = [
|
||||
Task(build_graph_with_temporal_awareness, text_list=text_list),
|
||||
Task(build_graph_with_temporal_awareness),
|
||||
]
|
||||
|
||||
pipeline = run_tasks(tasks, user=user)
|
||||
datasets = await get_datasets(user.id)
|
||||
dataset_data = await get_dataset_data(datasets[0].id) # type: ignore
|
||||
|
||||
pipeline = run_tasks(tasks, dataset=datasets[0], data=dataset_data, user=user)
|
||||
|
||||
async for result in pipeline:
|
||||
print(result)
|
||||
|
|
|
|||
8
notebooks/cognee_demo.ipynb
vendored
8
notebooks/cognee_demo.ipynb
vendored
|
|
@ -483,7 +483,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": null,
|
||||
"id": "7c431fdef4921ae0",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
|
|
@ -535,7 +535,7 @@
|
|||
" Task(add_data_points, task_config={\"batch_size\": 10}),\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" pipeline_run = run_tasks(tasks, dataset.id, data_documents, user, \"cognify_pipeline\", context={\"dataset\": dataset})\n",
|
||||
" pipeline_run = run_tasks(tasks, dataset, data_documents, user, \"cognify_pipeline\", context={\"dataset\": dataset})\n",
|
||||
" pipeline_run_status = None\n",
|
||||
"\n",
|
||||
" async for run_status in pipeline_run:\n",
|
||||
|
|
@ -1831,7 +1831,7 @@
|
|||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "cognee",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
|
@ -1845,7 +1845,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.7"
|
||||
"version": "3.10.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
|||
7
poetry.lock
generated
7
poetry.lock
generated
|
|
@ -9310,6 +9310,13 @@ optional = false
|
|||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c2514fceb77bc5e7a2f7adfaa1feb2fb311607c9cb518dbc378688ec73d8292f"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c57bb8c96f6d1808c030b1687b9b5fb476abaa47f0db9c0101f5e9f394e97f4"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efd7b85f94a6f21e4932043973a7ba2613b059c4a000551892ac9f1d11f5baf3"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6344df0d5755a2c9a276d4473ae6b90647e216ab4757f8426893b5dd2ac3f369"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-win32.whl", hash = "sha256:3ff07ec89bae51176c0549bc4c63aa6202991da2d9a6129d7aef7f1407d3f295"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:5cf4e27da7e3fbed4d6c3d8e797387aaad68102272f8f9752883bc32d61cb87b"},
|
||||
{file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"},
|
||||
{file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"},
|
||||
{file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"},
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -856,7 +856,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "cognee"
|
||||
version = "0.3.4"
|
||||
version = "0.3.5"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "aiofiles" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue