From 7c8f65d020fbbe0ac91e5c2a2609f042e55853e7 Mon Sep 17 00:00:00 2001 From: Ken Chen Date: Sat, 28 Jun 2025 08:50:32 +0800 Subject: [PATCH] Add search on neighbor nodes which are source to selected one --- lightrag/kg/mongo_impl.py | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 23d6ae7f..39410e65 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -710,11 +710,9 @@ class MongoGraphStorage(BaseGraphStorage): return labels def _construct_graph_node( - self, node_data: dict[str, str], seen_nodes: set + self, node_id, node_data: dict[str, str] ) -> KnowledgeGraphNode: - node_id = str(node_data["_id"]) - - graph_node = KnowledgeGraphNode( + return KnowledgeGraphNode( id=node_id, labels=[node_id], properties={ @@ -728,9 +726,6 @@ class MongoGraphStorage(BaseGraphStorage): ] }, ) - seen_nodes.add(node_id) - - return graph_node, seen_nodes async def get_knowledge_graph( self, @@ -810,6 +805,25 @@ class MongoGraphStorage(BaseGraphStorage): "as": "connected_edges", }, }, + { + "$unionWith": { + "coll": "chunk_entity_relation", + "pipeline": [ + {"$match": {"_id": "LightRAG"}}, + { + "$graphLookup": { + "from": "chunk_entity_relation_edges", + "startWith": "$_id", + "connectFromField": "source_node_id", + "connectToField": "target_node_id", + "as": "connected_edges", + "maxDepth": 3, + "depthField": "depth", + } + }, + ], + } + }, ] cursor = await self.collection.aggregate(pipeline, allowDiskUse=True) @@ -817,8 +831,10 @@ class MongoGraphStorage(BaseGraphStorage): async for doc in cursor: # Add the start nodes - graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes) - result.nodes.append(graph_node) + node_id = str(doc["_id"]) + if node_id not in seen_nodes: + seen_nodes.add(node_id) + result.nodes.append(self._construct_graph_node(node_id, doc)) if doc.get("connected_edges", []): node_edges.extend(doc.get("connected_edges")) @@ -857,13 +873,14 @@ class MongoGraphStorage(BaseGraphStorage): node_ids.append(edge["target_node_id"]) seen_nodes.add(edge["target_node_id"]) + # Filter out all the node whose id is same as label so that we do not check existence next step cursor = self.collection.find( {"_id": {"$in": [node for node in node_ids if node != label]}} ) async for doc in cursor: - graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes) - result.nodes.append(graph_node) + node_id = str(doc["_id"]) + result.nodes.append(self._construct_graph_node(node_id, doc)) for edge in node_edges: if (