diff --git a/lightrag/base.py b/lightrag/base.py index 12d142c1..c6811471 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -297,6 +297,8 @@ class BaseKVStorage(StorageNameSpace, ABC): @dataclass class BaseGraphStorage(StorageNameSpace, ABC): + """All operations related to edges in graph should be undirected.""" + embedding_func: EmbeddingFunc @abstractmethod diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index fbea463b..6483f6fd 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1,4 +1,5 @@ import os +import time from dataclasses import dataclass, field import numpy as np import configparser @@ -35,6 +36,7 @@ config.read("config.ini", "utf-8") # Get maximum number of graph nodes from environment variable, default is 1000 MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) +GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional") class ClientManager: @@ -417,11 +419,21 @@ class MongoGraphStorage(BaseGraphStorage): async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: """ - Check if there's a direct single-hop edge from source_node_id to target_node_id. + Check if there's a direct single-hop edge between source_node_id and target_node_id. """ - # Direct check if the target_node appears among the edges array. doc = await self.edge_collection.find_one( - {"source_node_id": source_node_id, "target_node_id": target_node_id}, + { + "$or": [ + { + "source_node_id": source_node_id, + "target_node_id": target_node_id, + }, + { + "source_node_id": target_node_id, + "target_node_id": source_node_id, + }, + ] + }, {"_id": 1}, ) return doc is not None @@ -651,7 +663,7 @@ class MongoGraphStorage(BaseGraphStorage): self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: """ - Upsert an edge from source_node_id -> target_node_id with optional 'relation'. + Upsert an edge between source_node_id and target_node_id with optional 'relation'. 
If an edge with the same target exists, we remove it and re-insert with updated data. """ # Ensure source node exists @@ -663,8 +675,22 @@ class MongoGraphStorage(BaseGraphStorage): GRAPH_FIELD_SEP ) + edge_data["source_node_id"] = source_node_id + edge_data["target_node_id"] = target_node_id + await self.edge_collection.update_one( - {"source_node_id": source_node_id, "target_node_id": target_node_id}, + { + "$or": [ + { + "source_node_id": source_node_id, + "target_node_id": target_node_id, + }, + { + "source_node_id": target_node_id, + "target_node_id": source_node_id, + }, + ] + }, update_doc, upsert=True, ) @@ -678,7 +704,7 @@ class MongoGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: """ 1) Remove node's doc entirely. - 2) Remove inbound edges from any doc that references node_id. + 2) Remove inbound & outbound edges from any doc that references node_id. """ # Remove all edges await self.edge_collection.delete_many( @@ -709,141 +735,369 @@ class MongoGraphStorage(BaseGraphStorage): labels.append(doc["_id"]) return labels + def _construct_graph_node( + self, node_id, node_data: dict[str, str] + ) -> KnowledgeGraphNode: + return KnowledgeGraphNode( + id=node_id, + labels=[node_id], + properties={ + k: v + for k, v in node_data.items() + if k + not in [ + "_id", + "connected_edges", + "source_ids", + "edge_count", + ] + }, + ) + + def _construct_graph_edge(self, edge_id: str, edge: dict[str, str]): + return KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relationship", ""), + source=edge["source_node_id"], + target=edge["target_node_id"], + properties={ + k: v + for k, v in edge.items() + if k + not in [ + "_id", + "source_node_id", + "target_node_id", + "relationship", + "source_ids", + ] + }, + ) + + async def get_knowledge_graph_all_by_degree( + self, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES + ) -> KnowledgeGraph: + """ + It's possible that the node with one or multiple relationships is retrieved, + while its 
neighbor is not. Then this node might seem like disconnected in UI. + """ + + total_node_count = await self.collection.count_documents({}) + result = KnowledgeGraph() + seen_edges = set() + + result.is_truncated = total_node_count > max_nodes + if result.is_truncated: + # Get all node_ids ranked by degree if max_nodes exceeds total node count + pipeline = [ + {"$project": {"source_node_id": 1, "_id": 0}}, + {"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}}, + { + "$unionWith": { + "coll": self._edge_collection_name, + "pipeline": [ + {"$project": {"target_node_id": 1, "_id": 0}}, + { + "$group": { + "_id": "$target_node_id", + "degree": {"$sum": 1}, + } + }, + ], + } + }, + {"$group": {"_id": "$_id", "degree": {"$sum": "$degree"}}}, + {"$sort": {"degree": -1}}, + {"$limit": max_nodes}, + ] + cursor = await self.edge_collection.aggregate(pipeline, allowDiskUse=True) + + node_ids = [] + async for doc in cursor: + node_id = str(doc["_id"]) + node_ids.append(node_id) + + cursor = self.collection.find({"_id": {"$in": node_ids}}, {"source_ids": 0}) + async for doc in cursor: + result.nodes.append(self._construct_graph_node(doc["_id"], doc)) + + # As node count reaches the limit, only need to fetch the edges that directly connect to these nodes + edge_cursor = self.edge_collection.find( + { + "$and": [ + {"source_node_id": {"$in": node_ids}}, + {"target_node_id": {"$in": node_ids}}, + ] + } + ) + else: + # All nodes and edges are needed + cursor = self.collection.find({}, {"source_ids": 0}) + + async for doc in cursor: + node_id = str(doc["_id"]) + result.nodes.append(self._construct_graph_node(doc["_id"], doc)) + + edge_cursor = self.edge_collection.find({}) + + async for edge in edge_cursor: + edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}" + if edge_id not in seen_edges: + seen_edges.add(edge_id) + result.edges.append(self._construct_graph_edge(edge_id, edge)) + + return result + + async def _bidirectional_bfs_nodes( + self, + node_labels: 
list[str], + seen_nodes: set[str], + result: KnowledgeGraph, + depth: int = 0, + max_depth: int = 3, + max_nodes: int = MAX_GRAPH_NODES, + ) -> KnowledgeGraph: + if depth > max_depth or len(result.nodes) > max_nodes: + return result + + cursor = self.collection.find({"_id": {"$in": node_labels}}) + + async for node in cursor: + node_id = node["_id"] + if node_id not in seen_nodes: + seen_nodes.add(node_id) + result.nodes.append(self._construct_graph_node(node_id, node)) + if len(result.nodes) > max_nodes: + return result + + # Collect neighbors + # Get both inbound and outbound one hop nodes + cursor = self.edge_collection.find( + { + "$or": [ + {"source_node_id": {"$in": node_labels}}, + {"target_node_id": {"$in": node_labels}}, + ] + } + ) + + neighbor_nodes = [] + async for edge in cursor: + if edge["source_node_id"] not in seen_nodes: + neighbor_nodes.append(edge["source_node_id"]) + if edge["target_node_id"] not in seen_nodes: + neighbor_nodes.append(edge["target_node_id"]) + + if neighbor_nodes: + result = await self._bidirectional_bfs_nodes( + neighbor_nodes, seen_nodes, result, depth + 1, max_depth, max_nodes + ) + + return result + + async def get_knowledge_subgraph_bidirectional_bfs( + self, + node_label: str, + depth=0, + max_depth: int = 3, + max_nodes: int = MAX_GRAPH_NODES, + ) -> KnowledgeGraph: + seen_nodes = set() + seen_edges = set() + result = KnowledgeGraph() + + result = await self._bidirectional_bfs_nodes( + [node_label], seen_nodes, result, depth, max_depth, max_nodes + ) + + # Get all edges from seen_nodes + all_node_ids = list(seen_nodes) + cursor = self.edge_collection.find( + { + "$and": [ + {"source_node_id": {"$in": all_node_ids}}, + {"target_node_id": {"$in": all_node_ids}}, + ] + } + ) + + async for edge in cursor: + edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}" + if edge_id not in seen_edges: + result.edges.append(self._construct_graph_edge(edge_id, edge)) + seen_edges.add(edge_id) + + return result + + async def 
get_knowledge_subgraph_in_out_bound_bfs( + self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES + ) -> KnowledgeGraph: + seen_nodes = set() + seen_edges = set() + result = KnowledgeGraph() + project_doc = { + "source_ids": 0, + "created_at": 0, + "entity_type": 0, + "file_path": 0, + } + + # Verify if starting node exists + start_node = await self.collection.find_one({"_id": node_label}) + if not start_node: + logger.warning(f"Starting node with label {node_label} does not exist!") + return result + + seen_nodes.add(node_label) + result.nodes.append(self._construct_graph_node(node_label, start_node)) + + if max_depth == 0: + return result + + # In MongoDB, depth = 0 means one-hop + max_depth = max_depth - 1 + + pipeline = [ + {"$match": {"_id": node_label}}, + {"$project": project_doc}, + { + "$graphLookup": { + "from": self._edge_collection_name, + "startWith": "$_id", + "connectFromField": "target_node_id", + "connectToField": "source_node_id", + "maxDepth": max_depth, + "depthField": "depth", + "as": "connected_edges", + }, + }, + { + "$unionWith": { + "coll": self._collection_name, + "pipeline": [ + {"$match": {"_id": node_label}}, + {"$project": project_doc}, + { + "$graphLookup": { + "from": self._edge_collection_name, + "startWith": "$_id", + "connectFromField": "source_node_id", + "connectToField": "target_node_id", + "maxDepth": max_depth, + "depthField": "depth", + "as": "connected_edges", + } + }, + ], + } + }, + ] + + cursor = await self.collection.aggregate(pipeline, allowDiskUse=True) + node_edges = [] + + # Two records for node_label are returned capturing outbound and inbound connected_edges + async for doc in cursor: + if doc.get("connected_edges", []): + node_edges.extend(doc.get("connected_edges")) + + # Sort the connected edges by depth ascending and weight descending + # And stores the source_node_id and target_node_id in sequence to retrieve the neighbouring nodes + node_edges = sorted( + node_edges, + key=lambda x: 
(x["depth"], -x["weight"]), + ) + + # As order matters, we need to use another list to store the node_id + # And only take the first max_nodes ones + node_ids = [] + for edge in node_edges: + if len(node_ids) < max_nodes and edge["source_node_id"] not in seen_nodes: + node_ids.append(edge["source_node_id"]) + seen_nodes.add(edge["source_node_id"]) + + if len(node_ids) < max_nodes and edge["target_node_id"] not in seen_nodes: + node_ids.append(edge["target_node_id"]) + seen_nodes.add(edge["target_node_id"]) + + # Filter out all the node whose id is same as node_label so that we do not check existence next step + cursor = self.collection.find({"_id": {"$in": node_ids}}) + + async for doc in cursor: + result.nodes.append(self._construct_graph_node(str(doc["_id"]), doc)) + + for edge in node_edges: + if ( + edge["source_node_id"] not in seen_nodes + or edge["target_node_id"] not in seen_nodes + ): + continue + + edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}" + if edge_id not in seen_edges: + result.edges.append(self._construct_graph_edge(edge_id, edge)) + seen_edges.add(edge_id) + + return result + async def get_knowledge_graph( self, node_label: str, - max_depth: int = 5, + max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES, ) -> KnowledgeGraph: """ - Get complete connected subgraph for specified node (including the starting node itself) + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. 
Args: - node_label: Label of the nodes to start from - max_depth: Maximum depth of traversal (default: 5) + node_label: Label of the starting node, * means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maximum nodes to return, Defaults to 1000 Returns: - KnowledgeGraph object containing nodes and edges of the subgraph + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit + + If a graph is like this and starting from B: + A → B ← C ← F, B -> E, C → D + + Outbound BFS: + B → E + + Inbound BFS: + A → B + C → B + F → C + + Bidirectional BFS: + A → B + B → E + F → C + C → B + C → D """ - label = node_label result = KnowledgeGraph() - seen_nodes = set() - seen_edges = set() - node_edges = [] + start = time.perf_counter() try: # Optimize pipeline to avoid memory issues with large datasets - if label == "*": - # For getting all nodes, use a simpler pipeline to avoid memory issues - pipeline = [ - {"$limit": max_nodes}, # Limit early to reduce memory usage - { - "$graphLookup": { - "from": self._edge_collection_name, - "startWith": "$_id", - "connectFromField": "target_node_id", - "connectToField": "source_node_id", - "maxDepth": max_depth, - "depthField": "depth", - "as": "connected_edges", - }, - }, - ] - - # Check if we need to set truncation flag - all_node_count = await self.collection.count_documents({}) - result.is_truncated = all_node_count > max_nodes - else: - # Verify if starting node exists - start_node = await self.collection.find_one({"_id": label}) - if not start_node: - logger.warning(f"Starting node with label {label} does not exist!") - return result - - # For specific node queries, use the original pipeline but optimized - pipeline = [ - {"$match": {"_id": label}}, - { - "$graphLookup": { - "from": self._edge_collection_name, - "startWith": "$_id", - "connectFromField": "target_node_id", - "connectToField": "source_node_id", - 
"maxDepth": max_depth, - "depthField": "depth", - "as": "connected_edges", - }, - }, - {"$addFields": {"edge_count": {"$size": "$connected_edges"}}}, - {"$sort": {"edge_count": -1}}, - {"$limit": max_nodes}, - ] - - cursor = await self.collection.aggregate(pipeline, allowDiskUse=True) - nodes_processed = 0 - - async for doc in cursor: - # Add the start node - node_id = str(doc["_id"]) - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[node_id], - properties={ - k: v - for k, v in doc.items() - if k - not in [ - "_id", - "connected_edges", - "edge_count", - ] - }, - ) + if node_label == "*": + result = await self.get_knowledge_graph_all_by_degree( + max_depth, max_nodes + ) + elif GRAPH_BFS_MODE == "in_out_bound": + result = await self.get_knowledge_subgraph_in_out_bound_bfs( + node_label, max_depth, max_nodes + ) + else: + result = await self.get_knowledge_subgraph_bidirectional_bfs( + node_label, 0, max_depth, max_nodes ) - seen_nodes.add(node_id) - if doc.get("connected_edges", []): - node_edges.extend(doc.get("connected_edges")) - nodes_processed += 1 - - # Additional safety check to prevent memory issues - if nodes_processed >= max_nodes: - result.is_truncated = True - break - - for edge in node_edges: - if ( - edge["source_node_id"] not in seen_nodes - or edge["target_node_id"] not in seen_nodes - ): - continue - - edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}" - if edge_id not in seen_edges: - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type=edge.get("relationship", ""), - source=edge["source_node_id"], - target=edge["target_node_id"], - properties={ - k: v - for k, v in edge.items() - if k - not in [ - "_id", - "source_node_id", - "target_node_id", - "relationship", - ] - }, - ) - ) - seen_edges.add(edge_id) + duration = time.perf_counter() - start logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}" + f"Subgraph 
query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}" ) except PyMongoError as e: @@ -856,13 +1110,8 @@ class MongoGraphStorage(BaseGraphStorage): try: simple_cursor = self.collection.find({}).limit(max_nodes) async for doc in simple_cursor: - node_id = str(doc["_id"]) result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[node_id], - properties={k: v for k, v in doc.items() if k != "_id"}, - ) + self._construct_graph_node(str(doc["_id"]), doc) ) result.is_truncated = True logger.info( @@ -1028,8 +1277,6 @@ class MongoVectorDBStorage(BaseVectorStorage): return # Add current time as Unix timestamp - import time - current_time = int(time.time()) list_data = [