feat: adds kuzu support to crewAI demo
This commit is contained in:
parent
b4b55b820d
commit
c9590ef760
5 changed files with 67 additions and 4 deletions
|
|
@ -8,7 +8,7 @@ import asyncio
|
|||
class CogneeIngestionInput(BaseModel):
|
||||
text: str = Field(
|
||||
"",
|
||||
description="The text of the reportThe format you should follow is {'text': 'your report'}",
|
||||
description="The text of the report The format you should follow is {'text': 'your report'}",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -189,3 +189,6 @@ class GraphDBInterface(ABC):
|
|||
) -> List[Tuple[NodeData, Dict[str, Any], NodeData]]:
|
||||
"""Get all nodes connected to a given node with their relationships."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_nodeset_subgraph(self, node_type, node_name):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import json
|
|||
import os
|
||||
import shutil
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Union, Optional, Tuple
|
||||
from typing import Dict, Any, List, Union, Optional, Tuple, Type
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
from contextlib import asynccontextmanager
|
||||
|
|
@ -728,6 +728,66 @@ class KuzuAdapter(GraphDBInterface):
|
|||
logger.error(f"Failed to get graph data: {e}")
|
||||
raise
|
||||
|
||||
async def get_nodeset_subgraph(
|
||||
self, node_type: Type[Any], node_name: List[str]
|
||||
) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
|
||||
label = node_type.__name__
|
||||
primary_query = """
|
||||
UNWIND $names AS wantedName
|
||||
MATCH (n:Node)
|
||||
WHERE n.type = $label AND n.name = wantedName
|
||||
RETURN DISTINCT n.id
|
||||
"""
|
||||
primary_rows = await self.query(primary_query, {"names": node_name, "label": label})
|
||||
primary_ids = [row[0] for row in primary_rows]
|
||||
if not primary_ids:
|
||||
return [], []
|
||||
|
||||
neighbor_query = """
|
||||
MATCH (n:Node)-[:EDGE]-(nbr:Node)
|
||||
WHERE n.id IN $ids
|
||||
RETURN DISTINCT nbr.id
|
||||
"""
|
||||
nbr_rows = await self.query(neighbor_query, {"ids": primary_ids})
|
||||
neighbor_ids = [row[0] for row in nbr_rows]
|
||||
|
||||
all_ids = list({*primary_ids, *neighbor_ids})
|
||||
|
||||
nodes_query = """
|
||||
MATCH (n:Node)
|
||||
WHERE n.id IN $ids
|
||||
RETURN n.id, n.name, n.type, n.properties
|
||||
"""
|
||||
node_rows = await self.query(nodes_query, {"ids": all_ids})
|
||||
nodes: List[Tuple[str, dict]] = []
|
||||
for node_id, name, typ, props in node_rows:
|
||||
data = {"id": node_id, "name": name, "type": typ}
|
||||
if props:
|
||||
try:
|
||||
data.update(json.loads(props))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse JSON props for node {node_id}")
|
||||
nodes.append((node_id, data))
|
||||
|
||||
edges_query = """
|
||||
MATCH (a:Node)-[r:EDGE]-(b:Node)
|
||||
WHERE a.id IN $ids AND b.id IN $ids
|
||||
RETURN a.id, b.id, r.relationship_name, r.properties
|
||||
"""
|
||||
edge_rows = await self.query(edges_query, {"ids": all_ids})
|
||||
edges: List[Tuple[str, str, str, dict]] = []
|
||||
for from_id, to_id, rel_type, props in edge_rows:
|
||||
data = {}
|
||||
if props:
|
||||
try:
|
||||
data = json.loads(props)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
|
||||
|
||||
edges.append((from_id, to_id, rel_type, data))
|
||||
|
||||
return nodes, edges
|
||||
|
||||
async def get_filtered_graph_data(
|
||||
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
|
||||
):
|
||||
|
|
|
|||
|
|
@ -517,7 +517,7 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return (nodes, edges)
|
||||
|
||||
async def get_subgraph(
|
||||
async def get_nodeset_subgraph(
|
||||
self, node_type: Type[Any], node_name: List[str]
|
||||
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
|
||||
label = node_type.__name__
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
|
||||
try:
|
||||
if node_type is not None and node_name is not None:
|
||||
nodes_data, edges_data = await adapter.get_subgraph(
|
||||
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
|
||||
node_type=node_type, node_name=node_name
|
||||
)
|
||||
elif len(memory_fragment_filter) == 0:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue