Add search on neighbor nodes which are source to selected one

This commit is contained in:
Ken Chen 2025-06-28 08:50:32 +08:00
parent f40bc43d5e
commit 7c8f65d020

View file

@ -710,11 +710,9 @@ class MongoGraphStorage(BaseGraphStorage):
return labels
def _construct_graph_node(
self, node_data: dict[str, str], seen_nodes: set
self, node_id, node_data: dict[str, str]
) -> KnowledgeGraphNode:
node_id = str(node_data["_id"])
graph_node = KnowledgeGraphNode(
return KnowledgeGraphNode(
id=node_id,
labels=[node_id],
properties={
@ -728,9 +726,6 @@ class MongoGraphStorage(BaseGraphStorage):
]
},
)
seen_nodes.add(node_id)
return graph_node, seen_nodes
async def get_knowledge_graph(
self,
@ -810,6 +805,25 @@ class MongoGraphStorage(BaseGraphStorage):
"as": "connected_edges",
},
},
{
"$unionWith": {
"coll": "chunk_entity_relation",
"pipeline": [
{"$match": {"_id": "LightRAG"}},
{
"$graphLookup": {
"from": "chunk_entity_relation_edges",
"startWith": "$_id",
"connectFromField": "source_node_id",
"connectToField": "target_node_id",
"as": "connected_edges",
"maxDepth": 3,
"depthField": "depth",
}
},
],
}
},
]
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
@ -817,8 +831,10 @@ class MongoGraphStorage(BaseGraphStorage):
async for doc in cursor:
# Add the start nodes
graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes)
result.nodes.append(graph_node)
node_id = str(doc["_id"])
if node_id not in seen_nodes:
seen_nodes.add(node_id)
result.nodes.append(self._construct_graph_node(node_id, doc))
if doc.get("connected_edges", []):
node_edges.extend(doc.get("connected_edges"))
@ -857,13 +873,14 @@ class MongoGraphStorage(BaseGraphStorage):
node_ids.append(edge["target_node_id"])
seen_nodes.add(edge["target_node_id"])
# Filter out all the node whose id is same as label so that we do not check existence next step
cursor = self.collection.find(
{"_id": {"$in": [node for node in node_ids if node != label]}}
)
async for doc in cursor:
graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes)
result.nodes.append(graph_node)
node_id = str(doc["_id"])
result.nodes.append(self._construct_graph_node(node_id, doc))
for edge in node_edges:
if (