From f40bc43d5e5d41996b8f3921f438a37dbf3e0a10 Mon Sep 17 00:00:00 2001 From: Ken Chen Date: Thu, 26 Jun 2025 23:11:31 +0800 Subject: [PATCH] Fix nodes & edges are missing when retrieving knowledge subgraph by selecting particular node_id --- lightrag/kg/mongo_impl.py | 99 ++++++++++++++++++++++++++++++--------- 1 file changed, 76 insertions(+), 23 deletions(-) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index fbea463b..23d6ae7f 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -709,6 +709,29 @@ class MongoGraphStorage(BaseGraphStorage): labels.append(doc["_id"]) return labels + def _construct_graph_node( + self, node_data: dict[str, str], seen_nodes: set + ) -> KnowledgeGraphNode: + node_id = str(node_data["_id"]) + + graph_node = KnowledgeGraphNode( + id=node_id, + labels=[node_id], + properties={ + k: v + for k, v in node_data.items() + if k + not in [ + "_id", + "connected_edges", + "edge_count", + ] + }, + ) + seen_nodes.add(node_id) + + return graph_node, seen_nodes + async def get_knowledge_graph( self, node_label: str, @@ -730,13 +753,22 @@ class MongoGraphStorage(BaseGraphStorage): seen_nodes = set() seen_edges = set() node_edges = [] + project_doc = { + "source_ids": 0, + "created_at": 0, + "entity_type": 0, + "file_path": 0, + } try: # Optimize pipeline to avoid memory issues with large datasets if label == "*": # For getting all nodes, use a simpler pipeline to avoid memory issues pipeline = [ - {"$limit": max_nodes}, # Limit early to reduce memory usage + {"$limit": max_nodes + 1}, # Limit early to reduce memory usage + { + "$project": project_doc + }, # Load without internal fields or unneeded ones shown in WebUI { "$graphLookup": { "from": self._edge_collection_name, @@ -748,6 +780,9 @@ class MongoGraphStorage(BaseGraphStorage): "as": "connected_edges", }, }, + # {"$addFields": {"edge_count": {"$size": "$connected_edges"}}}, + # {"$sort": {"edge_count": -1}}, + # {"$limit": max_nodes}, ] # Check if we need to set truncation flag @@ -763,6 +798,7 @@ class MongoGraphStorage(BaseGraphStorage): # For specific node queries, use the original pipeline but optimized pipeline = [ {"$match": {"_id": label}}, + {"$project": project_doc}, { "$graphLookup": { "from": self._edge_collection_name, @@ -774,34 +810,16 @@ class MongoGraphStorage(BaseGraphStorage): "as": "connected_edges", }, }, - {"$addFields": {"edge_count": {"$size": "$connected_edges"}}}, - {"$sort": {"edge_count": -1}}, - {"$limit": max_nodes}, ] cursor = await self.collection.aggregate(pipeline, allowDiskUse=True) nodes_processed = 0 async for doc in cursor: - # Add the start node - node_id = str(doc["_id"]) - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[node_id], - properties={ - k: v - for k, v in doc.items() - if k - not in [ - "_id", - "connected_edges", - "edge_count", - ] - }, - ) - ) - seen_nodes.add(node_id) + # Add the start nodes + graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes) + result.nodes.append(graph_node) + if doc.get("connected_edges", []): node_edges.extend(doc.get("connected_edges")) @@ -812,6 +830,41 @@ class MongoGraphStorage(BaseGraphStorage): result.is_truncated = True break + # When label != "*", cursor above only have one node and we need to get the subgraph by connected edges + # Sort the connected edges by depth ascending and weight descending + # And stores the source_node_id and target_node_id in sequence to retrieve the nodes again + if label != "*": + node_edges = sorted( + node_edges, + key=lambda x: (x["depth"], -x["weight"]), + ) + + # As order matters, we need to use another list to store the node_id + # And only take the first max_nodes ones + node_ids = [] + for edge in node_edges: + if ( + len(node_ids) < max_nodes + and edge["source_node_id"] not in seen_nodes + ): + node_ids.append(edge["source_node_id"]) + seen_nodes.add(edge["source_node_id"]) + + if ( + len(node_ids) < max_nodes + and edge["target_node_id"] not in seen_nodes + ): + node_ids.append(edge["target_node_id"]) + seen_nodes.add(edge["target_node_id"]) + + cursor = self.collection.find( + {"_id": {"$in": [node for node in node_ids if node != label]}} + ) + + async for doc in cursor: + graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes) + result.nodes.append(graph_node) + for edge in node_edges: if ( edge["source_node_id"] not in seen_nodes