From 5739f52d2952333ff3fcf8e19bbc84668963ced8 Mon Sep 17 00:00:00 2001
From: Ken Chen
Date: Sat, 28 Jun 2025 17:10:39 +0800
Subject: [PATCH] Rewrite get_knowledge_graph with label * by degree

---
 lightrag/kg/mongo_impl.py | 112 +++++++++++++++++++++++++++-----------
 1 file changed, 80 insertions(+), 32 deletions(-)

diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index ca0c6805..6b158d7a 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -743,10 +743,85 @@ class MongoGraphStorage(BaseGraphStorage):
                     "source_node_id",
                     "target_node_id",
                     "relationship",
+                    "source_ids",
                 ]
             },
         )
 
+    async def get_knowledge_graph_all_by_degree(
+        self, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
+    ) -> KnowledgeGraph:
+        """Return the whole graph, truncated to the highest-degree nodes.
+
+        When the node collection holds more than `max_nodes` documents, only the
+        `max_nodes` node ids with the highest degree (occurrences as edge source
+        plus occurrences as edge target) are kept, with the edges whose endpoints
+        are both kept, and `is_truncated` is True on the result. A kept node whose
+        neighbors were all dropped may look disconnected in the UI.
+
+        Args:
+            max_depth: Not used by this method.
+            max_nodes: Maximum number of nodes to return.
+        """
+        total_node_count = await self.collection.count_documents({})
+        result = KnowledgeGraph()
+        seen_edges = set()
+
+        result.is_truncated = total_node_count > max_nodes
+        if result.is_truncated:
+            # More nodes than max_nodes: rank node ids by degree and keep the top ones
+            pipeline = [
+                {"$project": {"source_node_id": 1, "_id": 0}},
+                {"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}},
+                {
+                    "$unionWith": {
+                        "coll": self._edge_collection_name,
+                        "pipeline": [
+                            {"$project": {"target_node_id": 1, "_id": 0}},
+                            {
+                                "$group": {
+                                    "_id": "$target_node_id",
+                                    "degree": {"$sum": 1},
+                                }
+                            },
+                        ],
+                    }
+                },
+                {"$group": {"_id": "$_id", "degree": {"$sum": "$degree"}}},
+                {"$sort": {"degree": -1}},
+                {"$limit": max_nodes},
+            ]
+            cursor = await self.edge_collection.aggregate(pipeline, allowDiskUse=True)
+            node_ids = [str(doc["_id"]) async for doc in cursor]
+
+            cursor = self.collection.find({"_id": {"$in": node_ids}}, {"source_ids": 0})
+            async for doc in cursor:
+                result.nodes.append(self._construct_graph_node(doc["_id"], doc))
+
+            # Only edges with both endpoints among the kept nodes can be shown
+            edge_cursor = self.edge_collection.find(
+                {
+                    "source_node_id": {"$in": node_ids},
+                    "target_node_id": {"$in": node_ids},
+                }
+            )
+        else:
+            # All nodes and edges are needed
+            cursor = self.collection.find({}, {"source_ids": 0})
+            async for doc in cursor:
+                result.nodes.append(self._construct_graph_node(doc["_id"], doc))
+
+            edge_cursor = self.edge_collection.find({})
+
+        async for edge in edge_cursor:
+            # Skip repeated (source, target) pairs so each edge id appears once
+            edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
+            if edge_id not in seen_edges:
+                seen_edges.add(edge_id)
+                result.edges.append(self._construct_graph_edge(edge_id, edge))
+
+        return result
+
     async def get_knowledge_graph(
         self,
         node_label: str,
@@ -778,31 +853,9 @@ class MongoGraphStorage(BaseGraphStorage):
         try:
             # Optimize pipeline to avoid memory issues with large datasets
             if label == "*":
-                # For getting all nodes, use a simpler pipeline to avoid memory issues
-                pipeline = [
-                    {"$limit": max_nodes + 1},  # Limit early to reduce memory usage
-                    {
-                        "$project": project_doc
-                    },  # Load without internal fields or unneeded ones shown in WebUI
-                    {
-                        "$graphLookup": {
-                            "from": self._edge_collection_name,
-                            "startWith": "$_id",
-                            "connectFromField": "target_node_id",
-                            "connectToField": "source_node_id",
-                            "maxDepth": max_depth,
-                            "depthField": "depth",
-                            "as": "connected_edges",
-                        },
-                    },
-                    # {"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
-                    # {"$sort": {"edge_count": -1}},
-                    # {"$limit": max_nodes},
-                ]
-
-                # Check if we need to set truncation flag
-                all_node_count = await self.collection.count_documents({})
-                result.is_truncated = all_node_count > max_nodes
+                return await self.get_knowledge_graph_all_by_degree(
+                    max_depth, max_nodes
+                )
             else:
                 # Verify if starting node exists
                 start_node = await self.collection.find_one({"_id": label})
@@ -827,7 +880,7 @@ class MongoGraphStorage(BaseGraphStorage):
                     },
                     {
                         "$unionWith": {
-                            "coll": "chunk_entity_relation",
+                            "coll": self._collection_name,
                             "pipeline": [
                                 {"$match": {"_id": label}},
                                 {"$project": project_doc},
@@ -929,13 +982,8 @@ class MongoGraphStorage(BaseGraphStorage):
             try:
                 simple_cursor = self.collection.find({}).limit(max_nodes)
                 async for doc in simple_cursor:
-                    node_id = str(doc["_id"])
                     result.nodes.append(
-                        KnowledgeGraphNode(
-                            id=node_id,
-                            labels=[node_id],
-                            properties={k: v for k, v in doc.items() if k != "_id"},
-                        )
+                        self._construct_graph_node(str(doc["_id"]), doc)
                     )
                 result.is_truncated = True
                 logger.info(