diff --git a/cognee-frontend/src/app/dashboard/Dashboard.tsx b/cognee-frontend/src/app/dashboard/Dashboard.tsx index 75e3d7518..fca5bc821 100644 --- a/cognee-frontend/src/app/dashboard/Dashboard.tsx +++ b/cognee-frontend/src/app/dashboard/Dashboard.tsx @@ -108,6 +108,13 @@ export default function Dashboard({ accessToken }: DashboardProps) { setDatasets(datasets); }, []); + const [searchValue, setSearchValue] = useState(""); + + const handleSearchDatasetInputChange = useCallback((event: React.ChangeEvent) => { + const newSearchValue = event.currentTarget.value; + setSearchValue(newSearchValue); + }, []); + const isCloudEnv = isCloudEnvironment(); return ( @@ -129,7 +136,7 @@ export default function Dashboard({ accessToken }: DashboardProps) {
- +
diff --git a/cognee-frontend/src/app/dashboard/DatasetsAccordion.tsx b/cognee-frontend/src/app/dashboard/DatasetsAccordion.tsx index 0c764ef92..6aebba99e 100644 --- a/cognee-frontend/src/app/dashboard/DatasetsAccordion.tsx +++ b/cognee-frontend/src/app/dashboard/DatasetsAccordion.tsx @@ -19,11 +19,13 @@ interface DatasetsChangePayload { export interface DatasetsAccordionProps extends Omit { onDatasetsChange?: (payload: DatasetsChangePayload) => void; useCloud?: boolean; + searchValue: string; } export default function DatasetsAccordion({ title, tools, + searchValue, switchCaretPosition = false, className, contentClassName, @@ -43,13 +45,7 @@ export default function DatasetsAccordion({ removeDataset, getDatasetData, removeDatasetData, - } = useDatasets(useCloud); - - useEffect(() => { - if (datasets.length === 0) { - refreshDatasets(); - } - }, [datasets.length, refreshDatasets]); + } = useDatasets(useCloud, searchValue); const [openDatasets, openDataset] = useState>(new Set()); @@ -237,11 +233,16 @@ export default function DatasetsAccordion({ contentClassName={contentClassName} >
- {datasets.length === 0 && ( + {datasets.length === 0 && !searchValue && (
No datasets here, add one by clicking +
)} + {datasets.length === 0 && searchValue && ( +
+ No datasets found, please adjust your search term +
+ )} {datasets.map((dataset) => { return ( ; +interface InstanceDatasetsAccordionProps extends Omit { + searchValue: string; +} -export default function InstanceDatasetsAccordion({ onDatasetsChange }: InstanceDatasetsAccordionProps) { +export default function InstanceDatasetsAccordion({ searchValue, onDatasetsChange }: InstanceDatasetsAccordionProps) { const { value: isLocalCogneeConnected, setTrue: setLocalCogneeConnected, @@ -19,7 +21,7 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance const { value: isCloudCogneeConnected, setTrue: setCloudCogneeConnected, - } = useBoolean(isCloudEnvironment()); + } = useBoolean(isCloudEnvironment() || isCloudApiKeySet()); const checkConnectionToCloudCognee = useCallback((apiKey?: string) => { if (apiKey) { @@ -71,6 +73,7 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance
)} + searchValue={searchValue} tools={isLocalCogneeConnected ? Connected : Not connected} switchCaretPosition={true} className="pt-3 pb-1.5" @@ -88,6 +91,7 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance )} + searchValue={searchValue} tools={Connected} switchCaretPosition={true} className="pt-3 pb-1.5" diff --git a/cognee-frontend/src/modules/ingestion/useDatasets.ts b/cognee-frontend/src/modules/ingestion/useDatasets.ts index 6a125b591..7ff4f698a 100644 --- a/cognee-frontend/src/modules/ingestion/useDatasets.ts +++ b/cognee-frontend/src/modules/ingestion/useDatasets.ts @@ -1,4 +1,4 @@ -import { useCallback, useState } from 'react'; +import { useCallback, useEffect, useLayoutEffect, useRef, useState } from 'react'; import { fetch } from '@/utils'; import { DataFile } from './useData'; @@ -11,7 +11,20 @@ export interface Dataset { status: string; } -function useDatasets(useCloud = false) { +function filterDatasets(datasets: Dataset[], searchValue: string) { + if (searchValue.trim() === "") { + return datasets; + } + + const lowercaseSearchValue = searchValue.toLowerCase(); + + return datasets.filter((dataset) => + dataset.name.toLowerCase().includes(lowercaseSearchValue) + ); +} + +function useDatasets(useCloud = false, searchValue: string = "") { + const allDatasets = useRef([]); const [datasets, setDatasets] = useState([]); // eslint-disable-next-line @typescript-eslint/no-explicit-any // const statusTimeout = useRef(null); @@ -57,26 +70,32 @@ function useDatasets(useCloud = false) { // }; // }, []); + useLayoutEffect(() => { + setDatasets(filterDatasets(allDatasets.current, searchValue)); + }, [searchValue]); + const addDataset = useCallback((datasetName: string) => { return createDataset({ name: datasetName }, useCloud) .then((dataset) => { - setDatasets((datasets) => [ - ...datasets, + const newDatasets = [ + ...allDatasets.current, dataset, - ]); + ]; + allDatasets.current = newDatasets; + setDatasets(filterDatasets(newDatasets, searchValue)); }); - }, [useCloud]); + }, [searchValue, useCloud]); const removeDataset = useCallback((datasetId: string) => { return fetch(`/v1/datasets/${datasetId}`, { method: 'DELETE', }, useCloud) .then(() => { - setDatasets((datasets) => - datasets.filter((dataset) => dataset.id !== datasetId) - ); + const newDatasets = allDatasets.current.filter((dataset) => dataset.id !== datasetId) + allDatasets.current = newDatasets; + setDatasets(filterDatasets(newDatasets, searchValue)); }); - }, [useCloud]); + }, [searchValue, useCloud]); const fetchDatasets = useCallback(() => { return fetch('/v1/datasets', { @@ -86,7 +105,8 @@ function useDatasets(useCloud = false) { }, useCloud) .then((response) => response.json()) .then((datasets) => { - setDatasets(datasets); + allDatasets.current = datasets; + setDatasets(filterDatasets(datasets, searchValue)); // if (datasets.length > 0) { // checkDatasetStatuses(datasets); @@ -97,28 +117,38 @@ function useDatasets(useCloud = false) { .catch((error) => { console.error('Error fetching datasets:', error); }); - }, [useCloud]); + }, [searchValue, useCloud]); + + useEffect(() => { + if (allDatasets.current.length === 0) { + fetchDatasets(); + } + }, [fetchDatasets]); const getDatasetData = useCallback((datasetId: string) => { return fetch(`/v1/datasets/${datasetId}/data`, {}, useCloud) .then((response) => response.json()) .then((data) => { - const datasetIndex = datasets.findIndex((dataset) => dataset.id === datasetId); + const datasetIndex = allDatasets.current.findIndex((dataset) => dataset.id === datasetId); if (datasetIndex >= 0) { - setDatasets((datasets) => [ - ...datasets.slice(0, datasetIndex), - { - ...datasets[datasetIndex], - data, - }, - ...datasets.slice(datasetIndex + 1), - ]); + const newDatasets = [ + ...allDatasets.current.slice(0, datasetIndex), + { + ...allDatasets.current[datasetIndex], + data, + }, + ...allDatasets.current.slice(datasetIndex + 1), + ]; + + allDatasets.current = newDatasets; + + setDatasets(filterDatasets(newDatasets, searchValue)); } return data; }); - }, [datasets, useCloud]); + }, [searchValue, useCloud]); const removeDatasetData = useCallback((datasetId: string, dataId: string) => { return fetch(`/v1/datasets/${datasetId}/data/${dataId}`, { diff --git a/cognee-starter-kit/src/pipelines/low_level.py b/cognee-starter-kit/src/pipelines/low_level.py index 80f4a22e9..07bcb1687 100644 --- a/cognee-starter-kit/src/pipelines/low_level.py +++ b/cognee-starter-kit/src/pipelines/low_level.py @@ -7,15 +7,19 @@ import json import logging from collections import defaultdict from pathlib import Path -from typing import Any, Iterable, List, Mapping +from typing import Any, Dict, Iterable, List, Mapping +from uuid import UUID, uuid4 + +from pydantic import BaseModel from cognee import config, prune, search, SearchType, visualize_graph from cognee.low_level import setup, DataPoint +from cognee.modules.data.models import Dataset +from cognee.modules.users.models import User from cognee.pipelines import run_tasks, Task from cognee.tasks.storage import add_data_points from cognee.tasks.storage.index_graph_edges import index_graph_edges from cognee.modules.users.methods import get_default_user -from cognee.modules.data.methods import load_or_create_datasets class Person(DataPoint): @@ -76,18 +80,6 @@ def remove_duplicates_preserve_order(seq: Iterable[Any]) -> list[Any]: return out -def collect_people(payloads: Iterable[Mapping[str, Any]]) -> list[Mapping[str, Any]]: - """Collect people from payloads.""" - people = [person for payload in payloads for person in payload.get("people", [])] - return people - - -def collect_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Mapping[str, Any]]: - """Collect companies from payloads.""" - companies = [company for payload in payloads for company in payload.get("companies", [])] - return companies - - def build_people_nodes(people: Iterable[Mapping[str, Any]]) -> dict: """Build person nodes keyed by name.""" nodes = {p["name"]: Person(name=p["name"]) for p in people if p.get("name")} @@ -176,10 +168,10 @@ def attach_employees_to_departments( target.employees = employees -def build_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Company]: +def build_companies(data: Data) -> list[Company]: """Build company nodes from payloads.""" - people = collect_people(payloads) - companies = collect_companies(payloads) + people = data.people + companies = data.companies people_nodes = build_people_nodes(people) groups = group_people_by_department(people) dept_names = collect_declared_departments(groups, companies) @@ -192,19 +184,29 @@ def build_companies(payloads: Iterable[Mapping[str, Any]]) -> list[Company]: return result -def load_default_payload() -> list[Mapping[str, Any]]: +class Data(BaseModel): + id: UUID + companies: List[Dict[str, Any]] + people: List[Dict[str, Any]] + + +def load_default_payload() -> Data: """Load the default payload from data files.""" companies = load_json_file(COMPANIES_JSON) people = load_json_file(PEOPLE_JSON) - payload = [{"companies": companies, "people": people}] - return payload + + data = Data( + id=uuid4(), + companies=companies, + people=people, + ) + + return data -def ingest_payloads(data: List[Any] | None) -> list[Company]: +def ingest_payloads(data: List[Data]) -> list[Company]: """Ingest payloads and build company nodes.""" - if not data or data == [None]: - data = load_default_payload() - companies = build_companies(data) + companies = build_companies(data[0]) return companies @@ -221,18 +223,16 @@ async def execute_pipeline() -> None: await setup() # Get user and dataset - user = await get_default_user() - datasets = await load_or_create_datasets(["demo_dataset"], [], user) - dataset_id = datasets[0].id + user: User = await get_default_user() # type: ignore + dataset = Dataset(id=uuid4(), name="demo_dataset") + data = load_default_payload() # Build and run pipeline tasks = [Task(ingest_payloads), Task(add_data_points)] - pipeline = run_tasks(tasks, dataset_id, None, user, "demo_pipeline") + pipeline = run_tasks(tasks, dataset, [data], user, "demo_pipeline") async for status in pipeline: logging.info("Pipeline status: %s", status) - # Post-process: index graph edges and visualize - await index_graph_edges() await visualize_graph(str(GRAPH_HTML)) # Run query against graph diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py index fb3612857..a534e9fbc 100644 --- a/cognee/api/v1/cognify/code_graph_pipeline.py +++ b/cognee/api/v1/cognify/code_graph_pipeline.py @@ -85,13 +85,13 @@ async def run_code_graph_pipeline( if include_docs: non_code_pipeline_run = run_tasks( - non_code_tasks, dataset.id, repo_path, user, "cognify_pipeline" + non_code_tasks, dataset, repo_path, user, "cognify_pipeline" ) async for run_status in non_code_pipeline_run: yield run_status async for run_status in run_tasks( - tasks, dataset.id, repo_path, user, "cognify_code_pipeline", incremental_loading=False + tasks, dataset, repo_path, user, "cognify_code_pipeline", incremental_loading=False ): yield run_status diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index 1292d243a..29a2cf27c 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -17,6 +17,7 @@ from cognee.modules.ontology.get_default_ontology_resolver import ( get_ontology_resolver_from_env, ) from cognee.modules.users.models import User +from cognee.modules.users.methods import get_default_user from cognee.tasks.documents import ( check_permissions_on_dataset, @@ -208,6 +209,9 @@ async def cognify( "ontology_config": {"ontology_resolver": get_default_ontology_resolver()} } + if user is None: + user = await get_default_user() + if temporal_cognify: tasks = await get_temporal_tasks(user, chunker, chunk_size) else: diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py index 65afdf275..15f5e3df3 100644 --- a/cognee/infrastructure/databases/graph/graph_db_interface.py +++ b/cognee/infrastructure/databases/graph/graph_db_interface.py @@ -1,13 +1,8 @@ -import inspect -from functools import wraps +from uuid import UUID from abc import abstractmethod, ABC -from datetime import datetime, timezone from typing import Optional, Dict, Any, List, Tuple, Type, Union -from uuid import NAMESPACE_OID, UUID, uuid5 from cognee.shared.logging_utils import get_logger from cognee.infrastructure.engine import DataPoint -from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger -from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine logger = get_logger() @@ -19,121 +14,6 @@ EdgeData = Tuple[ Node = Tuple[str, NodeData] # (node_id, properties) -def record_graph_changes(func): - """ - Decorator to record graph changes in the relationship database. - - Parameters: - ----------- - - - func: The asynchronous function to wrap, which likely modifies graph data. - - Returns: - -------- - - Returns the wrapped function that manages database relationships. - """ - - @wraps(func) - async def wrapper(self, *args, **kwargs): - """ - Wraps the given asynchronous function to handle database relationships. - - Tracks the caller's function and class name for context. When the wrapped function is - called, it manages database relationships for nodes or edges by adding entries to a - ledger and committing the changes to the database session. Errors during relationship - addition or session commit are logged and will not disrupt the execution of the wrapped - function. - - Parameters: - ----------- - - - *args: Positional arguments passed to the wrapped function. - - **kwargs: Keyword arguments passed to the wrapped function. - - Returns: - -------- - - Returns the result of the wrapped function call. - """ - db_engine = get_relational_engine() - frame = inspect.currentframe() - while frame: - if frame.f_back and frame.f_back.f_code.co_name != "wrapper": - caller_frame = frame.f_back - break - frame = frame.f_back - - caller_name = caller_frame.f_code.co_name - caller_class = ( - caller_frame.f_locals.get("self", None).__class__.__name__ - if caller_frame.f_locals.get("self", None) - else None - ) - creator = f"{caller_class}.{caller_name}" if caller_class else caller_name - - result = await func(self, *args, **kwargs) - - async with db_engine.get_async_session() as session: - if func.__name__ == "add_nodes": - nodes: List[DataPoint] = args[0] - - relationship_ledgers = [] - - for node in nodes: - node_id = UUID(str(node.id)) - relationship_ledgers.append( - GraphRelationshipLedger( - id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"), - source_node_id=node_id, - destination_node_id=node_id, - creator_function=f"{creator}.node", - node_label=getattr(node, "name", None) or str(node.id), - ) - ) - - try: - session.add_all(relationship_ledgers) - await session.flush() - except Exception as e: - logger.debug(f"Error adding relationship: {e}") - await session.rollback() - - elif func.__name__ == "add_edges": - edges = args[0] - - relationship_ledgers = [] - - for edge in edges: - source_id = UUID(str(edge[0])) - target_id = UUID(str(edge[1])) - rel_type = str(edge[2]) - relationship_ledgers.append( - GraphRelationshipLedger( - id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"), - source_node_id=source_id, - destination_node_id=target_id, - creator_function=f"{creator}.{rel_type}", - ) - ) - - try: - session.add_all(relationship_ledgers) - await session.flush() - except Exception as e: - logger.debug(f"Error adding relationship: {e}") - await session.rollback() - - try: - await session.commit() - except Exception as e: - logger.debug(f"Error committing session: {e}") - - return result - - return wrapper - - class GraphDBInterface(ABC): """ Define an interface for graph database operations to be implemented by concrete classes. @@ -189,7 +69,6 @@ class GraphDBInterface(ABC): raise NotImplementedError @abstractmethod - @record_graph_changes async def add_nodes(self, nodes: Union[List[Node], List[DataPoint]]) -> None: """ Add multiple nodes to the graph in a single operation. @@ -273,7 +152,6 @@ class GraphDBInterface(ABC): raise NotImplementedError @abstractmethod - @record_graph_changes async def add_edges( self, edges: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]] ) -> None: diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index 7b772097f..0b5a71b1f 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -17,7 +17,6 @@ from cognee.infrastructure.utils.run_sync import run_sync from cognee.infrastructure.files.storage import get_file_storage from cognee.infrastructure.databases.graph.graph_db_interface import ( GraphDBInterface, - record_graph_changes, ) from cognee.infrastructure.engine import DataPoint from cognee.modules.storage.utils import JSONEncoder @@ -378,7 +377,6 @@ class KuzuAdapter(GraphDBInterface): logger.error(f"Failed to add node: {e}") raise - @record_graph_changes async def add_nodes(self, nodes: List[DataPoint]) -> None: """ Add multiple nodes to the graph in a batch operation. @@ -675,7 +673,6 @@ class KuzuAdapter(GraphDBInterface): logger.error(f"Failed to add edge: {e}") raise - @record_graph_changes async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None: """ Add multiple edges in a batch operation. diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 520295ed2..e1097451b 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -16,7 +16,6 @@ from cognee.tasks.temporal_graph.models import Timestamp from cognee.shared.logging_utils import get_logger, ERROR from cognee.infrastructure.databases.graph.graph_db_interface import ( GraphDBInterface, - record_graph_changes, ) from cognee.modules.storage.utils import JSONEncoder @@ -175,7 +174,6 @@ class Neo4jAdapter(GraphDBInterface): return await self.query(query, params) - @record_graph_changes @override_distributed(queued_add_nodes) async def add_nodes(self, nodes: list[DataPoint]) -> None: """ @@ -446,7 +444,6 @@ class Neo4jAdapter(GraphDBInterface): return flattened - @record_graph_changes @override_distributed(queued_add_edges) async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None: """ diff --git a/cognee/infrastructure/databases/graph/neptune_driver/adapter.py b/cognee/infrastructure/databases/graph/neptune_driver/adapter.py index a2bc589af..9caa2a3c2 100644 --- a/cognee/infrastructure/databases/graph/neptune_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neptune_driver/adapter.py @@ -6,7 +6,6 @@ from uuid import UUID from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.graph.graph_db_interface import ( GraphDBInterface, - record_graph_changes, NodeData, EdgeData, Node, @@ -229,7 +228,6 @@ class NeptuneGraphDB(GraphDBInterface): logger.error(f"Failed to add node {node.id}: {error_msg}") raise Exception(f"Failed to add node: {error_msg}") from e - @record_graph_changes async def add_nodes(self, nodes: List[DataPoint]) -> None: """ Add multiple nodes to the graph in a single operation. @@ -534,7 +532,6 @@ class NeptuneGraphDB(GraphDBInterface): logger.error(f"Failed to add edge {source_id} -> {target_id}: {error_msg}") raise Exception(f"Failed to add edge: {error_msg}") from e - @record_graph_changes async def add_edges(self, edges: List[Tuple[str, str, str, Optional[Dict[str, Any]]]]) -> None: """ Add multiple edges to the graph in a single operation. diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 91f20898e..ce0606fc9 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -1,5 +1,6 @@ import asyncio from os import path +from uuid import UUID import lancedb from pydantic import BaseModel from lancedb.pydantic import LanceModel, Vector @@ -282,7 +283,7 @@ class LanceDBAdapter(VectorDBInterface): ] ) - async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): + async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]): collection = await self.get_collection(collection_name) # Delete one at a time to avoid commit conflicts diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 1986fae48..6f93c3395 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -1,5 +1,6 @@ import asyncio from typing import List, Optional, get_type_hints +from uuid import UUID from sqlalchemy.inspection import inspect from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.dialects.postgresql import insert @@ -384,7 +385,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): ] ) - async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): + async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]): async with self.get_async_session() as session: # Get PGVectorDataPoint Table from database PGVectorDataPoint = await self.get_table(collection_name) diff --git a/cognee/infrastructure/databases/vector/vector_db_interface.py b/cognee/infrastructure/databases/vector/vector_db_interface.py index 3a3df62eb..5fc2001e2 100644 --- a/cognee/infrastructure/databases/vector/vector_db_interface.py +++ b/cognee/infrastructure/databases/vector/vector_db_interface.py @@ -1,7 +1,7 @@ -from typing import List, Protocol, Optional, Union, Any +from typing import List, Protocol, Optional, Any from abc import abstractmethod +from uuid import UUID from cognee.infrastructure.engine import DataPoint -from .models.PayloadSchema import PayloadSchema class VectorDBInterface(Protocol): @@ -127,9 +127,7 @@ class VectorDBInterface(Protocol): raise NotImplementedError @abstractmethod - async def delete_data_points( - self, collection_name: str, data_point_ids: Union[List[str], list[str]] - ): + async def delete_data_points(self, collection_name: str, data_point_ids: List[UUID]): """ Delete specified data points from a collection. diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index 812380eaa..8736e1ce4 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -1,4 +1,3 @@ -import pickle from uuid import UUID, uuid4 from pydantic import BaseModel, Field, ConfigDict from datetime import datetime, timezone @@ -28,8 +27,6 @@ class DataPoint(BaseModel): - update_version - to_json - from_json - - to_pickle - - from_pickle - to_dict - from_dict """ diff --git a/cognee/modules/data/methods/create_dataset.py b/cognee/modules/data/methods/create_dataset.py index c080de0e8..0a3109611 100644 --- a/cognee/modules/data/methods/create_dataset.py +++ b/cognee/modules/data/methods/create_dataset.py @@ -1,12 +1,16 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select from sqlalchemy.orm import joinedload -from cognee.modules.data.models import Dataset +from cognee.infrastructure.databases.relational import with_async_session + +from cognee.modules.data.models import Dataset from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id + from cognee.modules.users.models import User +@with_async_session async def create_dataset(dataset_name: str, user: User, session: AsyncSession) -> Dataset: owner_id = user.id diff --git a/cognee/modules/engine/utils/__init__.py b/cognee/modules/engine/utils/__init__.py index 892315259..975d87a1d 100644 --- a/cognee/modules/engine/utils/__init__.py +++ b/cognee/modules/engine/utils/__init__.py @@ -1,4 +1,5 @@ from .generate_node_id import generate_node_id +from .generate_edge_id import generate_edge_id from .generate_node_name import generate_node_name from .generate_edge_name import generate_edge_name from .generate_event_datapoint import generate_event_datapoint diff --git a/cognee/modules/graph/methods/__init__.py b/cognee/modules/graph/methods/__init__.py index e0752c6b0..169e1b6dc 100644 --- a/cognee/modules/graph/methods/__init__.py +++ b/cognee/modules/graph/methods/__init__.py @@ -1 +1,8 @@ from .get_formatted_graph_data import get_formatted_graph_data +from .upsert_edges import upsert_edges +from .upsert_nodes import upsert_nodes +from .get_data_related_nodes import get_data_related_nodes +from .delete_data_related_nodes import delete_data_related_nodes +from .delete_data_related_edges import delete_data_related_edges +from .get_data_related_edges import get_data_related_edges +from .delete_data_nodes_and_edges import delete_data_nodes_and_edges diff --git a/cognee/modules/graph/methods/delete_data_nodes_and_edges.py b/cognee/modules/graph/methods/delete_data_nodes_and_edges.py new file mode 100644 index 000000000..dd9619cbe --- /dev/null +++ b/cognee/modules/graph/methods/delete_data_nodes_and_edges.py @@ -0,0 +1,43 @@ +from uuid import UUID +from typing import Dict, List + +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine +from cognee.modules.engine.utils import generate_edge_id +from cognee.modules.graph.methods import ( + delete_data_related_edges, + delete_data_related_nodes, + get_data_related_nodes, + get_data_related_edges, +) + + +async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID) -> None: + affected_nodes = await get_data_related_nodes(dataset_id, data_id) + + graph_engine = await get_graph_engine() + await graph_engine.delete_nodes([str(node.slug) for node in affected_nodes]) + + affected_vector_collections: Dict[str, List] = {} + for node in affected_nodes: + for indexed_field in node.indexed_fields: + collection_name = f"{node.type}_{indexed_field}" + if collection_name not in affected_vector_collections: + affected_vector_collections[collection_name] = [] + affected_vector_collections[collection_name].append(node) + + vector_engine = get_vector_engine() + for affected_collection, affected_nodes in affected_vector_collections.items(): + await vector_engine.delete_data_points( + affected_collection, [node.id for node in affected_nodes] + ) + + affected_relationships = await get_data_related_edges(dataset_id, data_id) + + await vector_engine.delete_data_points( + "EdgeType_relationship_name", + [generate_edge_id(edge.relationship_name) for edge in affected_relationships], + ) + + await delete_data_related_nodes(data_id) + await delete_data_related_edges(data_id) diff --git a/cognee/modules/graph/methods/delete_data_related_edges.py b/cognee/modules/graph/methods/delete_data_related_edges.py new file mode 100644 index 000000000..c8d6bc563 --- /dev/null +++ b/cognee/modules/graph/methods/delete_data_related_edges.py @@ -0,0 +1,13 @@ +from uuid import UUID +from sqlalchemy import delete, select +from sqlalchemy.ext.asyncio import AsyncSession + +from cognee.infrastructure.databases.relational import with_async_session +from cognee.modules.graph.models import Edge + + +@with_async_session +async def delete_data_related_edges(data_id: UUID, session: AsyncSession): + nodes = (await session.scalars(select(Edge).where(Edge.data_id == data_id))).all() + + await session.execute(delete(Edge).where(Edge.id.in_([node.id for node in nodes]))) diff --git a/cognee/modules/graph/methods/delete_data_related_nodes.py b/cognee/modules/graph/methods/delete_data_related_nodes.py new file mode 100644 index 000000000..adc8e8f20 --- /dev/null +++ b/cognee/modules/graph/methods/delete_data_related_nodes.py @@ -0,0 +1,13 @@ +from uuid import UUID +from sqlalchemy import delete, select +from sqlalchemy.ext.asyncio import AsyncSession + +from cognee.infrastructure.databases.relational import with_async_session +from cognee.modules.graph.models import Node + + +@with_async_session +async def delete_data_related_nodes(data_id: UUID, session: AsyncSession): + nodes = (await session.scalars(select(Node).where(Node.data_id == data_id))).all() + + await session.execute(delete(Node).where(Node.id.in_([node.id for node in nodes]))) diff --git a/cognee/modules/graph/methods/get_data_related_edges.py b/cognee/modules/graph/methods/get_data_related_edges.py new file mode 100644 index 000000000..cbbe2f296 --- /dev/null +++ b/cognee/modules/graph/methods/get_data_related_edges.py @@ -0,0 +1,17 @@ +from uuid import UUID +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from cognee.infrastructure.databases.relational import with_async_session +from cognee.modules.graph.models import Edge + + +@with_async_session +async def get_data_related_edges(dataset_id: UUID, data_id: UUID, session: AsyncSession): + return ( + await session.scalars( + select(Edge) + .where(Edge.data_id == data_id, Edge.dataset_id == dataset_id) + .distinct(Edge.relationship_name) + ) + ).all() diff --git a/cognee/modules/graph/methods/get_data_related_nodes.py b/cognee/modules/graph/methods/get_data_related_nodes.py new file mode 100644 index 000000000..f6bff754e --- /dev/null +++ b/cognee/modules/graph/methods/get_data_related_nodes.py @@ -0,0 +1,26 @@ +from uuid import UUID +from sqlalchemy import and_, exists, select +from sqlalchemy.ext.asyncio import AsyncSession + +from cognee.infrastructure.databases.relational import with_async_session +from cognee.modules.graph.models import Node + + +@with_async_session +async def get_data_related_nodes(dataset_id: UUID, data_id: UUID, session: AsyncSession): + NodeAlias = Node.__table__.alias("n2") + + subq = select(NodeAlias.c.id).where( + and_( + NodeAlias.c.slug == Node.slug, + NodeAlias.c.dataset_id == dataset_id, + NodeAlias.c.data_id != data_id, + ) + ) + + query_statement = select(Node).where( + and_(Node.data_id == data_id, Node.dataset_id == dataset_id, ~exists(subq)) + ) + + data_related_nodes = await session.scalars(query_statement) + return data_related_nodes.all() diff --git a/cognee/modules/graph/methods/set_current_user.py b/cognee/modules/graph/methods/set_current_user.py new file mode 100644 index 000000000..49d471f0e --- /dev/null +++ b/cognee/modules/graph/methods/set_current_user.py @@ -0,0 +1,8 @@ +import uuid +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession + + +async def set_current_user(session: AsyncSession, user_id: uuid.UUID, local: bool = False): + scope = "LOCAL " if local else "" + await session.execute(text(f"SET {scope}app.current_user_id = '{user_id}'")) diff --git a/cognee/modules/graph/methods/upsert_edges.py b/cognee/modules/graph/methods/upsert_edges.py new file mode 100644 index 000000000..9bfa8876e --- /dev/null +++ b/cognee/modules/graph/methods/upsert_edges.py @@ -0,0 +1,57 @@ +from uuid import UUID, uuid5, NAMESPACE_OID +from typing import Any, Dict, List, Tuple +from fastapi.encoders import jsonable_encoder +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.dialects.postgresql import insert + +from cognee.infrastructure.databases.relational import with_async_session +from cognee.modules.graph.models.Edge import Edge +from .set_current_user import set_current_user + + +@with_async_session +async def upsert_edges( + edges: List[Tuple[UUID, UUID, str, Dict[str, Any]]], + user_id: UUID, + data_id: UUID, + dataset_id: UUID, + session: AsyncSession, +): + """ + Adds edges to the edges table. + + Parameters: + ----------- + - edges (list): A list of edges to be added to the graph. + """ + if session.get_bind().dialect.name == "postgresql": + # Set the session-level RLS variable + await set_current_user(session, user_id) + + upsert_statement = ( + insert(Edge) + .values( + [ + { + "id": uuid5( + NAMESPACE_OID, + str(user_id) + str(dataset_id) + str(edge[0]) + str(edge[2]) + str(edge[1]), + ), + "user_id": user_id, + "data_id": data_id, + "dataset_id": dataset_id, + "source_node_id": edge[0], + "destination_node_id": edge[1], + "relationship_name": edge[2], + "label": edge[2], + "props": jsonable_encoder(edge[3]), + } + for edge in edges + ] + ) + .on_conflict_do_nothing(index_elements=["id"]) + ) + + await session.execute(upsert_statement) + + await session.commit() diff --git a/cognee/modules/graph/methods/upsert_nodes.py b/cognee/modules/graph/methods/upsert_nodes.py new file mode 100644 index 000000000..eeb159c84 --- /dev/null +++ b/cognee/modules/graph/methods/upsert_nodes.py @@ -0,0 +1,51 @@ +from typing import List +from uuid import NAMESPACE_OID, UUID, uuid5 +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.dialects.postgresql import insert + +from cognee.infrastructure.engine.models.DataPoint import DataPoint +from cognee.infrastructure.databases.relational import with_async_session +from cognee.modules.graph.models import Node +from .set_current_user import set_current_user + + +@with_async_session +async def upsert_nodes( + nodes: List[DataPoint], user_id: UUID, dataset_id: UUID, data_id: UUID, session: AsyncSession +): + """ + Adds nodes to the nodes table. + + Parameters: + ----------- + - nodes (list): A list of nodes to be added to the graph. + """ + if session.get_bind().dialect.name == "postgresql": + # Set the session-level RLS variable + await set_current_user(session, user_id) + + upsert_statement = ( + insert(Node) + .values( + [ + { + "id": uuid5( + NAMESPACE_OID, str(user_id) + str(dataset_id) + str(data_id) + str(node.id) + ), + "slug": node.id, + "user_id": user_id, + "data_id": data_id, + "dataset_id": dataset_id, + "type": node.type, + "indexed_fields": DataPoint.get_embeddable_property_names(node), + "label": getattr(node, "label", getattr(node, "name", str(node.id))), + } + for node in nodes + ] + ) + .on_conflict_do_nothing(index_elements=["id"]) + ) + + await session.execute(upsert_statement) + + await session.commit() diff --git a/cognee/modules/graph/models/Edge.py b/cognee/modules/graph/models/Edge.py new file mode 100644 index 000000000..72f7399be --- /dev/null +++ b/cognee/modules/graph/models/Edge.py @@ -0,0 +1,50 @@ +from sqlalchemy import ( + # event, + String, + JSON, + UUID, +) + +# from sqlalchemy.schema import DDL +from sqlalchemy.orm import Mapped, mapped_column + +from cognee.infrastructure.databases.relational import Base + + +class Edge(Base): + __tablename__ = "edges" + + id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), primary_key=True) + + user_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False) + + data_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), index=True, nullable=False) + + dataset_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False) + + source_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False) + destination_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False) + + relationship_name: Mapped[str | None] = mapped_column(String(255)) + + label: Mapped[str | None] = mapped_column(String(255)) + props: Mapped[dict | None] = mapped_column(JSON) + + # __table_args__ = ( + # {"postgresql_partition_by": "HASH (user_id)"}, # partitioning by user + # ) + + +# Enable row-level security (RLS) for edges +# enable_edge_rls = DDL(""" +# ALTER TABLE edges ENABLE ROW LEVEL SECURITY; +# """) +# create_user_isolation_policy = DDL(""" +# CREATE POLICY user_isolation_policy +# ON edges +# USING (user_id = current_setting('app.current_user_id')::uuid) +# WITH CHECK (user_id = current_setting('app.current_user_id')::uuid); +# """) + +# event.listen(Edge.__table__, "after_create", enable_edge_rls) +# event.listen(Edge.__table__, "after_create", create_user_isolation_policy) diff --git a/cognee/modules/graph/models/Node.py b/cognee/modules/graph/models/Node.py new file mode 100644 index 000000000..7c774578f --- /dev/null +++ b/cognee/modules/graph/models/Node.py @@ -0,0 +1,53 @@ +from sqlalchemy import ( + Index, + # event, + String, + JSON, + UUID, +) + +# from sqlalchemy.schema import DDL +from sqlalchemy.orm import Mapped, mapped_column + +from cognee.infrastructure.databases.relational import Base + + +class Node(Base): + __tablename__ = "nodes" + + id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), primary_key=True) + + slug: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False) + + user_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False) + + data_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False) + + dataset_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False) + + label: Mapped[str] = mapped_column(String(255)) + type: Mapped[str] = mapped_column(String(255), nullable=False) + indexed_fields: Mapped[list] = mapped_column(JSON) + + # props: Mapped[dict] = mapped_column(JSON) + + __table_args__ = ( + Index("index_node_dataset_slug", "dataset_id", "slug"), + Index("index_node_dataset_data", "dataset_id", "data_id"), + # {"postgresql_partition_by": "HASH (user_id)"}, # HASH partitioning on user_id + ) + + +# Enable row-level security (RLS) for nodes +# enable_node_rls = DDL(""" +# ALTER TABLE nodes ENABLE ROW LEVEL SECURITY; +# """) +# create_user_isolation_policy = DDL(""" +# CREATE POLICY user_isolation_policy +# ON nodes +# USING (user_id = current_setting('app.current_user_id')::uuid) +# WITH CHECK (user_id = current_setting('app.current_user_id')::uuid); +# """) + +# event.listen(Node.__table__, "after_create", enable_node_rls) +# event.listen(Node.__table__, "after_create", create_user_isolation_policy) diff --git a/cognee/modules/graph/models/__init__.py b/cognee/modules/graph/models/__init__.py new file mode 100644 index 000000000..5920e89c4 --- /dev/null +++ b/cognee/modules/graph/models/__init__.py @@ -0,0 +1,2 @@ +from .Edge import Edge +from .Node import Node diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index 5497207a4..924a61dfc 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -134,8 +134,8 @@ def _targets_generator( async def get_graph_from_model( data_point: DataPoint, - added_nodes: Dict[str, bool], - added_edges: Dict[str, bool], + added_nodes: Optional[Dict[str, bool]] = None, + added_edges: Optional[Dict[str, bool]] = None, visited_properties: Optional[Dict[str, bool]] = None, include_root: bool = True, ) -> Tuple[List[DataPoint], List[Tuple[str, str, str, Dict[str, Any]]]]: @@ -152,6 +152,12 @@ async def get_graph_from_model( Returns: Tuple of (nodes, edges) extracted from the model """ + if not added_nodes: + added_nodes = {} + + if not added_edges: + added_edges = {} + if str(data_point.id) in added_nodes: return [], [] diff --git a/cognee/modules/pipelines/operations/pipeline.py b/cognee/modules/pipelines/operations/pipeline.py index b59a171f7..eec5874ad 100644 --- a/cognee/modules/pipelines/operations/pipeline.py +++ b/cognee/modules/pipelines/operations/pipeline.py @@ -1,6 +1,6 @@ import asyncio from uuid import UUID -from typing import Union +from typing import Dict, Optional, Union from cognee.modules.pipelines.layers.setup_and_check_environment import ( setup_and_check_environment, @@ -29,12 +29,13 @@ update_status_lock = asyncio.Lock() async def run_pipeline( tasks: list[Task], data=None, - datasets: Union[str, list[str], list[UUID]] = None, - user: User = None, + datasets: Optional[Union[str, list[str], list[UUID]]] = None, + user: Optional[User] = None, pipeline_name: str = "custom_pipeline", - vector_db_config: dict = None, - graph_db_config: dict = None, + vector_db_config: Optional[dict] = None, + graph_db_config: Optional[dict] = None, incremental_loading: bool = False, + context: Optional[Dict] = None, ): validate_pipeline_tasks(tasks) await setup_and_check_environment(vector_db_config, graph_db_config) @@ -48,8 +49,8 @@ async def run_pipeline( tasks=tasks, data=data, pipeline_name=pipeline_name, - context={"dataset": dataset}, incremental_loading=incremental_loading, + context=context, ): yield run_info @@ -58,16 +59,16 @@ async def run_pipeline_per_dataset( dataset: Dataset, user: User, tasks: list[Task], - data=None, + data: Optional[list[Data]] = None, pipeline_name: str = "custom_pipeline", - context: dict = None, incremental_loading=False, + context: Optional[Dict] = None, ): # Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True await set_database_global_context_variables(dataset.id, dataset.owner_id) if not data: - data: list[Data] = await get_dataset_data(dataset_id=dataset.id) + data = await get_dataset_data(dataset_id=dataset.id) process_pipeline_status = await check_pipeline_run_qualification(dataset, data, pipeline_name) if process_pipeline_status: @@ -77,7 +78,7 @@ async def run_pipeline_per_dataset( return pipeline_run = run_tasks( - tasks, dataset.id, data, user, pipeline_name, context, incremental_loading + tasks, dataset, data, user, pipeline_name, context, incremental_loading ) async for pipeline_run_info in pipeline_run: diff --git a/cognee/modules/pipelines/operations/run_tasks.py b/cognee/modules/pipelines/operations/run_tasks.py index 4a0c77309..9a9cf3dcc 100644 --- a/cognee/modules/pipelines/operations/run_tasks.py +++ b/cognee/modules/pipelines/operations/run_tasks.py @@ -1,12 +1,12 @@ import os import asyncio -from uuid import UUID -from typing import Any, List from functools import wraps +from typing import Any, Dict, List, Optional from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.modules.data.models import Dataset from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed from cognee.modules.users.models import User from cognee.shared.logging_utils import get_logger @@ -24,7 +24,6 @@ from cognee.modules.pipelines.operations import ( log_pipeline_run_complete, log_pipeline_run_error, ) -from .run_tasks_with_telemetry import run_tasks_with_telemetry from .run_tasks_data_item import run_tasks_data_item from ..tasks.task import Task @@ -54,25 +53,18 @@ def override_run_tasks(new_gen): @override_run_tasks(run_tasks_distributed) async def run_tasks( tasks: List[Task], - dataset_id: UUID, - data: List[Any] = None, - user: User = None, + dataset: Dataset, + data: Optional[List[Any]] = None, + user: Optional[User] = None, pipeline_name: str = "unknown_pipeline", - context: dict = None, + context: Optional[Dict] = None, incremental_loading: bool = False, ): if not user: user = await get_default_user() - # Get Dataset object - db_engine = get_relational_engine() - async with db_engine.get_async_session() as session: - from cognee.modules.data.models import Dataset - - dataset = await session.get(Dataset, dataset_id) - pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name) - pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data) + pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset.id, data) pipeline_run_id = pipeline_run.pipeline_run_id yield PipelineRunStarted( @@ -99,7 +91,12 @@ async def run_tasks( pipeline_name, pipeline_id, pipeline_run_id, - context, + { + **(context or {}), + "user": user, + "data": data_item, + "dataset": dataset, + }, user, incremental_loading, ) @@ -121,7 +118,7 @@ async def run_tasks( ) await log_pipeline_run_complete( - pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data + pipeline_run_id, pipeline_id, pipeline_name, dataset.id, data ) yield PipelineRunCompleted( @@ -141,7 +138,7 @@ async def run_tasks( except Exception as error: await log_pipeline_run_error( - pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error + pipeline_run_id, pipeline_id, pipeline_name, dataset.id, data, error ) yield PipelineRunErrored( diff --git a/cognee/tasks/graph/extract_graph_from_data.py b/cognee/tasks/graph/extract_graph_from_data.py index 49b51af2d..4aec04526 100644 --- a/cognee/tasks/graph/extract_graph_from_data.py +++ b/cognee/tasks/graph/extract_graph_from_data.py @@ -1,10 +1,9 @@ import asyncio -from typing import Type, List, Optional +from typing import Dict, Type, List, Optional from pydantic import BaseModel from cognee.infrastructure.databases.graph import get_graph_engine from cognee.modules.ontology.ontology_env_config import get_ontology_env_config -from cognee.tasks.storage import index_graph_edges from cognee.tasks.storage.add_data_points import add_data_points from cognee.modules.ontology.ontology_config import Config from cognee.modules.ontology.get_default_ontology_resolver import ( @@ -32,6 +31,7 @@ async def integrate_chunk_graphs( chunk_graphs: list, graph_model: Type[BaseModel], ontology_resolver: BaseOntologyResolver, + context: Dict, ) -> List[DocumentChunk]: """Integrate chunk graphs with ontology validation and store in databases. @@ -85,19 +85,19 @@ async def integrate_chunk_graphs( ) if len(graph_nodes) > 0: - await add_data_points(graph_nodes) + await add_data_points(graph_nodes, context) if len(graph_edges) > 0: await graph_engine.add_edges(graph_edges) - await index_graph_edges(graph_edges) return data_chunks async def extract_graph_from_data( data_chunks: List[DocumentChunk], + context: Dict, graph_model: Type[BaseModel], - config: Config = None, + config: Optional[Config] = None, custom_prompt: Optional[str] = None, ) -> List[DocumentChunk]: """ @@ -136,16 +136,16 @@ async def extract_graph_from_data( and ontology_config.ontology_resolver and ontology_config.matching_strategy ): - config: Config = { + config = { "ontology_config": { "ontology_resolver": get_ontology_resolver_from_env(**ontology_config.to_dict()) } } else: - config: Config = { - "ontology_config": {"ontology_resolver": get_default_ontology_resolver()} - } + config = {"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}} ontology_resolver = config["ontology_config"]["ontology_resolver"] - return await integrate_chunk_graphs(data_chunks, chunk_graphs, graph_model, ontology_resolver) + return await integrate_chunk_graphs( + data_chunks, chunk_graphs, graph_model, ontology_resolver, context + ) diff --git a/cognee/tasks/storage/add_data_points.py b/cognee/tasks/storage/add_data_points.py index ad1693e82..793f3f173 100644 --- a/cognee/tasks/storage/add_data_points.py +++ b/cognee/tasks/storage/add_data_points.py @@ -1,8 +1,10 @@ import asyncio -from typing import List +from typing import Any, Dict, List, Optional from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.modules.graph.methods import upsert_edges, upsert_nodes from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model +from cognee.modules.users.models import User from .index_data_points import index_data_points from .index_graph_edges import index_graph_edges from cognee.tasks.storage.exceptions import ( @@ -10,7 +12,9 @@ from cognee.tasks.storage.exceptions import ( ) -async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]: +async def add_data_points( + data_points: List[DataPoint], context: Optional[Dict[str, Any]] = None +) -> List[DataPoint]: """ Add a batch of data points to the graph database by extracting nodes and edges, deduplicating them, and indexing them for retrieval. @@ -33,8 +37,15 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]: - Deduplicates nodes and edges across all results. - Updates the node index via `index_data_points`. - Inserts nodes and edges into the graph engine. - - Optionally updates the edge index via `index_graph_edges`. """ + user: Optional[User] = None + data = None + dataset = None + + if context: + data = context["data"] + dataset = context["dataset"] + user = context["user"] if not isinstance(data_points, list): raise InvalidDataPointsInAddDataPointsError("data_points must be a list.") @@ -71,6 +82,10 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]: await graph_engine.add_nodes(nodes) await index_data_points(nodes) + if user and dataset and data: + await upsert_nodes(nodes, user_id=user.id, dataset_id=dataset.id, data_id=data.id) + await upsert_edges(edges, user_id=user.id, dataset_id=dataset.id, data_id=data.id) + await graph_engine.add_edges(edges) await index_graph_edges(edges) diff --git a/cognee/tasks/temporal_awareness/build_graph_with_temporal_awareness.py b/cognee/tasks/temporal_awareness/build_graph_with_temporal_awareness.py index ecbf2b6be..11c34e3bf 100644 --- a/cognee/tasks/temporal_awareness/build_graph_with_temporal_awareness.py +++ b/cognee/tasks/temporal_awareness/build_graph_with_temporal_awareness.py @@ -1,13 +1,26 @@ import os +from typing import List from datetime import datetime from graphiti_core import Graphiti from graphiti_core.nodes import EpisodeType +from cognee.infrastructure.files.storage import get_file_storage +from cognee.modules.data.models import Data -async def build_graph_with_temporal_awareness(text_list): - url = os.getenv("GRAPH_DATABASE_URL") - password = os.getenv("GRAPH_DATABASE_PASSWORD") + +async def build_graph_with_temporal_awareness(data: List[Data]): + text_list: List[str] = [] + + for text_data in data: + file_dir = os.path.dirname(text_data.raw_data_location) + file_name = os.path.basename(text_data.raw_data_location) + file_storage = get_file_storage(file_dir) + async with file_storage.open(file_name, "r") as file: + text_list.append(file.read()) + + url = os.getenv("GRAPH_DATABASE_URL", "") + password = os.getenv("GRAPH_DATABASE_PASSWORD", "") graphiti = Graphiti(url, "neo4j", password) await graphiti.build_indices_and_constraints() @@ -22,4 +35,5 @@ async def build_graph_with_temporal_awareness(text_list): reference_time=datetime.now(), ) print(f"Added text: {text[:35]}...") + return graphiti diff --git a/cognee/tests/test_delete_custom_graph.py b/cognee/tests/test_delete_custom_graph.py new file mode 100644 index 000000000..32ef95659 --- /dev/null +++ b/cognee/tests/test_delete_custom_graph.py @@ -0,0 +1,107 @@ +import os +import pathlib +from typing import List +from uuid import uuid4 + +import cognee +from cognee.infrastructure.engine import DataPoint +from cognee.modules.data.models import Data, Dataset +from cognee.modules.engine.operations.setup import setup +from cognee.modules.graph.methods import delete_data_nodes_and_edges +from cognee.modules.users.models import User +from cognee.modules.users.methods import get_default_user +from cognee.shared.logging_utils import get_logger +from cognee.tasks.storage import add_data_points + +logger = get_logger() + + +async def main(): + data_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_delete_custom_graph") + ).resolve() + ) + cognee.config.data_root_directory(data_directory_path) + cognee_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_delete_custom_graph") + ).resolve() + ) + cognee.config.system_root_directory(cognee_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Organization(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class ForProfit(Organization): + name: str = "For-Profit" + metadata: dict = {"index_fields": ["name"]} + + class NonProfit(Organization): + name: str = "Non-Profit" + metadata: dict = {"index_fields": ["name"]} + + class Person(DataPoint): + name: str + works_for: List[Organization] + metadata: dict = {"index_fields": ["name"]} + + companyA = ForProfit(name="Company A") + companyB = NonProfit(name="Company B") + + person1 = Person(name="John", works_for=[companyA, companyB]) + person2 = Person(name="Jane", works_for=[companyB]) + + user: User = await get_default_user() # type: ignore + + dataset = Dataset(id=uuid4()) + data1 = Data(id=uuid4()) + data2 = Data(id=uuid4()) + + await add_data_points( + [person1], + context={ + "user": user, + "dataset": dataset, + "data": data1, + }, + ) + + await add_data_points( + [person2], + context={ + "user": user, + "dataset": dataset, + "data": data2, + }, + ) + + from cognee.infrastructure.databases.graph import get_graph_engine + + graph_engine = await get_graph_engine() + + nodes, edges = await graph_engine.get_graph_data() + assert len(nodes) == 4 and len(edges) == 3, ( + "Nodes and edges are not correctly added to the graph." + ) + + await delete_data_nodes_and_edges(dataset.id, data1.id) # type: ignore + + nodes, edges = await graph_engine.get_graph_data() + assert len(nodes) == 2 and len(edges) == 1, "Nodes and edges are not deleted properly." + + await delete_data_nodes_and_edges(dataset.id, data2.id) # type: ignore + + nodes, edges = await graph_engine.get_graph_data() + assert len(nodes) == 0 and len(edges) == 0, "Nodes and edges are not deleted." + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/cognee/tests/test_delete_default_graph.py b/cognee/tests/test_delete_default_graph.py new file mode 100644 index 000000000..7ffc66e31 --- /dev/null +++ b/cognee/tests/test_delete_default_graph.py @@ -0,0 +1,77 @@ +import os +import pathlib + +import cognee +from cognee.api.v1.visualize.visualize import visualize_graph +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.modules.data.methods import delete_data, get_dataset_data +from cognee.modules.engine.operations.setup import setup +from cognee.modules.graph.methods import ( + delete_data_nodes_and_edges, +) +from cognee.shared.logging_utils import get_logger + +logger = get_logger() + + +async def main(): + data_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph") + ).resolve() + ) + cognee.config.data_root_directory(data_directory_path) + cognee_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_delete_default_graph") + ).resolve() + ) + cognee.config.system_root_directory(cognee_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + vector_engine = get_vector_engine() + + assert not await vector_engine.has_collection("EdgeType_relationship_name") + assert not await vector_engine.has_collection("Entity_name") + + await cognee.add( + "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'" + ) + await cognee.add("Marie works for Apple as well. She is a software engineer on MacOS project.") + + cognify_result: dict = await cognee.cognify() + dataset_id = list(cognify_result.keys())[0] + + dataset_data = await get_dataset_data(dataset_id) + added_data = dataset_data[0] + + file_path = os.path.join( + pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_full.html" + ) + await visualize_graph(file_path) + + graph_engine = await get_graph_engine() + nodes, edges = await graph_engine.get_graph_data() + assert len(nodes) >= 12 and len(edges) >= 18, "Nodes and edges are not deleted." + + await delete_data_nodes_and_edges(dataset_id, added_data.id) # type: ignore + + await delete_data(added_data) + + file_path = os.path.join( + pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_after_delete.html" + ) + await visualize_graph(file_path) + + nodes, edges = await graph_engine.get_graph_data() + assert len(nodes) >= 8 and len(edges) >= 10, "Nodes and edges are not deleted." + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/examples/low_level/pipeline.py b/examples/low_level/pipeline.py index 085d313a7..a7d1750fa 100644 --- a/examples/low_level/pipeline.py +++ b/examples/low_level/pipeline.py @@ -100,7 +100,7 @@ async def main(): pipeline = run_tasks( [Task(ingest_files), Task(add_data_points)], - dataset_id=datasets[0].id, + dataset=datasets[0], data=data, incremental_loading=False, ) diff --git a/examples/low_level/product_recommendation.py b/examples/low_level/product_recommendation.py index 782311618..eb2e6ad4c 100644 --- a/examples/low_level/product_recommendation.py +++ b/examples/low_level/product_recommendation.py @@ -1,13 +1,18 @@ import os import json import asyncio +from typing import Dict, List +from uuid import NAMESPACE_OID, UUID, uuid4, uuid5 from neo4j import exceptions +from pydantic import BaseModel from cognee import prune # from cognee import visualize_graph from cognee.infrastructure.databases.graph import get_graph_engine from cognee.low_level import setup, DataPoint +from cognee.modules.data.models import Data, Dataset +from cognee.modules.users.methods import get_default_user from cognee.pipelines import run_tasks, Task from cognee.tasks.storage import add_data_points @@ -20,7 +25,6 @@ products_aggregator_node = Products() class Product(DataPoint): - id: str name: str type: str price: float @@ -36,7 +40,6 @@ preferences_aggregator_node = Preferences() class Preference(DataPoint): - id: str name: str value: str is_type: Preferences = preferences_aggregator_node @@ -50,7 +53,6 @@ customers_aggregator_node = Customers() class Customer(DataPoint): - id: str name: str has_preference: list[Preference] purchased: list[Product] @@ -58,17 +60,14 @@ class Customer(DataPoint): is_type: Customers = customers_aggregator_node -def ingest_files(): - customers_file_path = os.path.join(os.path.dirname(__file__), "customers.json") - customers = json.loads(open(customers_file_path, "r").read()) - +def ingest_customers(data): customers_data_points = {} products_data_points = {} preferences_data_points = {} - for customer in customers: + for customer in data[0].customers: new_customer = Customer( - id=customer["id"], + id=uuid5(NAMESPACE_OID, customer["id"]), name=customer["name"], liked=[], purchased=[], @@ -79,7 +78,7 @@ def ingest_files(): for product in customer["products"]: if product["id"] not in products_data_points: products_data_points[product["id"]] = Product( - id=product["id"], + id=uuid5(NAMESPACE_OID, product["id"]), type=product["type"], name=product["name"], price=product["price"], @@ -96,7 +95,7 @@ def ingest_files(): for preference in customer["preferences"]: if preference["id"] not in preferences_data_points: preferences_data_points[preference["id"]] = Preference( - id=preference["id"], + id=uuid5(NAMESPACE_OID, preference["id"]), name=preference["name"], value=preference["value"], ) @@ -104,7 +103,7 @@ def ingest_files(): new_preference = preferences_data_points[preference["id"]] new_customer.has_preference.append(new_preference) - return customers_data_points.values() + return list(customers_data_points.values()) async def main(): @@ -113,7 +112,28 @@ async def main(): await setup() - pipeline = run_tasks([Task(ingest_files), Task(add_data_points)]) + # Get user and dataset + user: User = await get_default_user() # type: ignore + main_dataset = Dataset(id=uuid4(), name="demo_dataset") + + customers_file_path = os.path.join(os.path.dirname(__file__), "customers.json") + customers = json.loads(open(customers_file_path, "r").read()) + + class Data(BaseModel): + id: UUID + customers: List[Dict] + + data = Data( + id=uuid4(), + customers=customers, + ) + + pipeline = run_tasks( + [Task(ingest_customers), Task(add_data_points)], + dataset=main_dataset, + data=[data], + user=user, + ) async for status in pipeline: print(status) diff --git a/examples/python/graphiti_example.py b/examples/python/graphiti_example.py index ece9c452b..ecc616b9b 100644 --- a/examples/python/graphiti_example.py +++ b/examples/python/graphiti_example.py @@ -1,6 +1,7 @@ import asyncio import cognee +from cognee.modules.data.methods import get_dataset_data, get_datasets from cognee.shared.logging_utils import setup_logging, ERROR from cognee.modules.pipelines import Task, run_tasks from cognee.tasks.temporal_awareness import build_graph_with_temporal_awareness @@ -35,10 +36,13 @@ async def main(): await cognee.add(text) tasks = [ - Task(build_graph_with_temporal_awareness, text_list=text_list), + Task(build_graph_with_temporal_awareness), ] - pipeline = run_tasks(tasks, user=user) + datasets = await get_datasets(user.id) + dataset_data = await get_dataset_data(datasets[0].id) # type: ignore + + pipeline = run_tasks(tasks, dataset=datasets[0], data=dataset_data, user=user) async for result in pipeline: print(result) diff --git a/notebooks/cognee_demo.ipynb b/notebooks/cognee_demo.ipynb index 51eeab560..616bb1b9f 100644 --- a/notebooks/cognee_demo.ipynb +++ b/notebooks/cognee_demo.ipynb @@ -483,7 +483,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "7c431fdef4921ae0", "metadata": { "ExecuteTime": { @@ -535,7 +535,7 @@ " Task(add_data_points, task_config={\"batch_size\": 10}),\n", " ]\n", "\n", - " pipeline_run = run_tasks(tasks, dataset.id, data_documents, user, \"cognify_pipeline\", context={\"dataset\": dataset})\n", + " pipeline_run = run_tasks(tasks, dataset, data_documents, user, \"cognify_pipeline\", context={\"dataset\": dataset})\n", " pipeline_run_status = None\n", "\n", " async for run_status in pipeline_run:\n", @@ -1831,7 +1831,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "cognee", "language": "python", "name": "python3" }, @@ -1845,7 +1845,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.7" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/poetry.lock b/poetry.lock index 551295733..eecce7777 100644 --- a/poetry.lock +++ b/poetry.lock @@ -9310,6 +9310,13 @@ optional = false python-versions = ">=3.8" groups = ["main"] files = [ + {file = "PyYAML-6.0.3-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c2514fceb77bc5e7a2f7adfaa1feb2fb311607c9cb518dbc378688ec73d8292f"}, + {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c57bb8c96f6d1808c030b1687b9b5fb476abaa47f0db9c0101f5e9f394e97f4"}, + {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efd7b85f94a6f21e4932043973a7ba2613b059c4a000551892ac9f1d11f5baf3"}, + {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6"}, + {file = "PyYAML-6.0.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6344df0d5755a2c9a276d4473ae6b90647e216ab4757f8426893b5dd2ac3f369"}, + {file = "PyYAML-6.0.3-cp38-cp38-win32.whl", hash = "sha256:3ff07ec89bae51176c0549bc4c63aa6202991da2d9a6129d7aef7f1407d3f295"}, + {file = "PyYAML-6.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:5cf4e27da7e3fbed4d6c3d8e797387aaad68102272f8f9752883bc32d61cb87b"}, {file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"}, {file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"}, {file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"}, diff --git a/uv.lock b/uv.lock index 570da9289..0cdf2c56d 100644 --- a/uv.lock +++ b/uv.lock @@ -856,7 +856,7 @@ wheels = [ [[package]] name = "cognee" -version = "0.3.4" +version = "0.3.5" source = { editable = "." } dependencies = [ { name = "aiofiles" },