diff --git a/cognee/complex_demos/crewai_demo/src/crewai_demo/custom_tools/cognee_ingestion.py b/cognee/complex_demos/crewai_demo/src/crewai_demo/custom_tools/cognee_ingestion.py
index 1e00a8d0b..2000d43a8 100644
--- a/cognee/complex_demos/crewai_demo/src/crewai_demo/custom_tools/cognee_ingestion.py
+++ b/cognee/complex_demos/crewai_demo/src/crewai_demo/custom_tools/cognee_ingestion.py
@@ -8,7 +8,7 @@ import asyncio
 class CogneeIngestionInput(BaseModel):
     text: str = Field(
         "",
-        description="The text of the reportThe format you should follow is {'text': 'your report'}",
+        description="The text of the report The format you should follow is {'text': 'your report'}",
     )
 
 
diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py
index 719c22e3c..c15a0427d 100644
--- a/cognee/infrastructure/databases/graph/graph_db_interface.py
+++ b/cognee/infrastructure/databases/graph/graph_db_interface.py
@@ -189,3 +189,11 @@ class GraphDBInterface(ABC):
     ) -> List[Tuple[NodeData, Dict[str, Any], NodeData]]:
         """Get all nodes connected to a given node with their relationships."""
         raise NotImplementedError
+
+    async def get_nodeset_subgraph(self, node_type, node_name):
+        """Get the subgraph around the nodes of ``node_type`` named in ``node_name``.
+
+        Adapters that support this projection override it and return a
+        ``(nodes, edges)`` pair; the default fails loudly instead of returning None.
+        """
+        raise NotImplementedError
diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py
index 380f3f713..7989a82db 100644
--- a/cognee/infrastructure/databases/graph/kuzu/adapter.py
+++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py
@@ -5,7 +5,7 @@ import json
 import os
 import shutil
 import asyncio
-from typing import Dict, Any, List, Union, Optional, Tuple
+from typing import Dict, Any, List, Union, Optional, Tuple, Type
 from datetime import datetime, timezone
 from uuid import UUID
 from contextlib import asynccontextmanager
@@ -728,6 +728,80 @@ class KuzuAdapter(GraphDBInterface):
             logger.error(f"Failed to get graph data: {e}")
             raise
 
+    async def get_nodeset_subgraph(
+        self, node_type: Type[Any], node_name: List[str]
+    ) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
+        """Fetch the subgraph around the named nodes of a given type.
+
+        Nodes whose ``type`` equals ``node_type.__name__`` and whose ``name``
+        is listed in ``node_name`` are selected together with their direct
+        neighbours (in either direction); every edge between two selected
+        nodes is returned.
+
+        Returns:
+            ``(nodes, edges)`` where each node is ``(id, data)`` and each edge is
+            ``(source_id, target_id, relationship_name, properties)``.
+            Both lists are empty when no node matches.
+        """
+        label = node_type.__name__
+        primary_query = """
+            UNWIND $names AS wantedName
+            MATCH (n:Node)
+            WHERE n.type = $label AND n.name = wantedName
+            RETURN DISTINCT n.id
+        """
+        primary_rows = await self.query(primary_query, {"names": node_name, "label": label})
+        primary_ids = [row[0] for row in primary_rows]
+        if not primary_ids:
+            return [], []
+
+        neighbor_query = """
+            MATCH (n:Node)-[:EDGE]-(nbr:Node)
+            WHERE n.id IN $ids
+            RETURN DISTINCT nbr.id
+        """
+        nbr_rows = await self.query(neighbor_query, {"ids": primary_ids})
+        neighbor_ids = [row[0] for row in nbr_rows]
+
+        all_ids = list({*primary_ids, *neighbor_ids})
+
+        nodes_query = """
+            MATCH (n:Node)
+            WHERE n.id IN $ids
+            RETURN n.id, n.name, n.type, n.properties
+        """
+        node_rows = await self.query(nodes_query, {"ids": all_ids})
+        nodes: List[Tuple[str, dict]] = []
+        for node_id, name, typ, props in node_rows:
+            data = {"id": node_id, "name": name, "type": typ}
+            if props:
+                try:
+                    data.update(json.loads(props))
+                except json.JSONDecodeError:
+                    logger.warning(f"Failed to parse JSON props for node {node_id}")
+            nodes.append((node_id, data))
+
+        # Match with an explicit direction: the undirected pattern would yield
+        # every edge twice, the second time with source and target swapped.
+        edges_query = """
+            MATCH (a:Node)-[r:EDGE]->(b:Node)
+            WHERE a.id IN $ids AND b.id IN $ids
+            RETURN a.id, b.id, r.relationship_name, r.properties
+        """
+        edge_rows = await self.query(edges_query, {"ids": all_ids})
+        edges: List[Tuple[str, str, str, dict]] = []
+        for from_id, to_id, rel_type, props in edge_rows:
+            data = {}
+            if props:
+                try:
+                    data = json.loads(props)
+                except json.JSONDecodeError:
+                    logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
+
+            edges.append((from_id, to_id, rel_type, data))
+
+        return nodes, edges
+
     async def get_filtered_graph_data(
         self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
     ):
diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
index 9d8f1dcd0..d0532fc53 100644
--- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
+++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
@@ -517,7 +517,7 @@ class Neo4jAdapter(GraphDBInterface):
 
         return (nodes, edges)
 
-    async def get_subgraph(
+    async def get_nodeset_subgraph(
         self, node_type: Type[Any], node_name: List[str]
     ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
         label = node_type.__name__
diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py
index 4a6b4a13e..1913fc970 100644
--- a/cognee/modules/graph/cognee_graph/CogneeGraph.py
+++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py
@@ -69,7 +69,7 @@ class CogneeGraph(CogneeAbstractGraph):
 
         try:
             if node_type is not None and node_name is not None:
-                nodes_data, edges_data = await adapter.get_subgraph(
+                nodes_data, edges_data = await adapter.get_nodeset_subgraph(
                     node_type=node_type, node_name=node_name
                 )
             elif len(memory_fragment_filter) == 0: