diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index e6576420..1e03132a 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -4,7 +4,7 @@ import numpy as np import configparser import asyncio -from typing import Any, List, Union, final +from typing import Any, Union, final from ..base import ( BaseGraphStorage, @@ -321,7 +321,10 @@ class MongoGraphStorage(BaseGraphStorage): """ db: AsyncDatabase = field(default=None) + # node collection storing node_id, node_properties collection: AsyncCollection = field(default=None) + # edge collection storing source_node_id, target_node_id, and edge_properties + edgeCollection: AsyncCollection = field(default=None) def __init__(self, namespace, global_config, embedding_func): super().__init__( @@ -330,6 +333,7 @@ class MongoGraphStorage(BaseGraphStorage): embedding_func=embedding_func, ) self._collection_name = self.namespace + self._edge_collection_name = f"{self._collection_name}_edges" async def initialize(self): if self.db is None: @@ -337,6 +341,9 @@ class MongoGraphStorage(BaseGraphStorage): self.collection = await get_or_create_collection( self.db, self._collection_name ) + self.edge_collection = await get_or_create_collection( + self.db, self._edge_collection_name + ) logger.debug(f"Use MongoDB as KG {self._collection_name}") async def finalize(self): @@ -344,6 +351,7 @@ class MongoGraphStorage(BaseGraphStorage): await ClientManager.release_client(self.db) self.db = None self.collection = None + self.edge_collection = None # # ------------------------------------------------------------------------- @@ -374,52 +382,6 @@ class MongoGraphStorage(BaseGraphStorage): # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec" # } - async def _graph_lookup( - self, start_node_id: str, max_depth: int = None - ) -> List[dict]: - """ - Performs a $graphLookup starting from 'start_node_id' and returns - all reachable documents (including the start node itself). - - Pipeline Explanation: - - 1) $match: We match the start node document by _id = start_node_id. - - 2) $graphLookup: - "from": same collection, - "startWith": "$edges.target" (the immediate neighbors in 'edges'), - "connectFromField": "edges.target", - "connectToField": "_id", - "as": "reachableNodes", - "maxDepth": max_depth (if provided), - "depthField": "depth" (used for debugging or filtering). - - 3) We add an $project or $unwind as needed to extract data. - """ - pipeline = [ - {"$match": {"_id": start_node_id}}, - { - "$graphLookup": { - "from": self.collection.name, - "startWith": "$edges.target", - "connectFromField": "edges.target", - "connectToField": "_id", - "as": "reachableNodes", - "depthField": "depth", - } - }, - ] - - # If you want a limited depth (e.g., only 1 or 2 hops), set maxDepth - if max_depth is not None: - pipeline[1]["$graphLookup"]["maxDepth"] = max_depth - - # Return the matching doc plus a field "reachableNodes" - cursor = await self.collection.aggregate(pipeline) - results = await cursor.to_list(None) - - # If there's no matching node, results = []. - # Otherwise, results[0] is the start node doc, - # plus results[0]["reachableNodes"] is the array of connected docs. - return results - # # ------------------------------------------------------------------------- # BASIC QUERIES @@ -439,8 +401,9 @@ class MongoGraphStorage(BaseGraphStorage): Check if there's a direct single-hop edge from source_node_id to target_node_id. """ # Direct check if the target_node appears among the edges array. - doc = await self.collection.find_one( - {"_id": source_node_id, "edges.target": target_node_id}, {"_id": 1} + doc = await self.edge_collection.find_one( + {"source_node_id": source_node_id, "target_node_id": target_node_id}, + {"_id": 1}, ) return doc is not None @@ -453,21 +416,10 @@ class MongoGraphStorage(BaseGraphStorage): async def node_degree(self, node_id: str) -> int: """ Returns the total number of edges connected to node_id (both inbound and outbound). - The easiest approach is typically two queries: - - count of edges array in node_id's doc - - count of how many other docs have node_id in their edges.target. """ - # --- 1) Outbound edges (direct from doc) --- - doc = await self.collection.find_one({"_id": node_id}, {"edges.target": 1}) - if not doc: - return 0 - - outbound_count = len(doc.get("edges", [])) - - # --- 2) Inbound edges: - inbound_count = await self.collection.count_documents({"edges.target": node_id}) - - return outbound_count + inbound_count + return await self.edge_collection.count_documents( + {"$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]} + ) async def edge_degree(self, src_id: str, tgt_id: str) -> int: """Get the total degree (sum of relationships) of two nodes. @@ -482,12 +434,7 @@ class MongoGraphStorage(BaseGraphStorage): src_degree = await self.node_degree(src_id) trg_degree = await self.node_degree(tgt_id) - # Convert None to 0 for addition - src_degree = 0 if src_degree is None else src_degree - trg_degree = 0 if trg_degree is None else trg_degree - - degrees = int(src_degree) + int(trg_degree) - return degrees + return src_degree + trg_degree # # ------------------------------------------------------------------------- @@ -497,32 +444,27 @@ class MongoGraphStorage(BaseGraphStorage): async def get_node(self, node_id: str) -> dict[str, str] | None: """ - Return the full node document (including "edges"), or None if missing. + Return the full node document, or None if missing. """ return await self.collection.find_one({"_id": node_id}) async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - doc = await self.collection.find_one( + return await self.edge_collection.find_one( { "$or": [ - {"_id": source_node_id, "edges.target": target_node_id}, - {"_id": target_node_id, "edges.target": source_node_id}, + { + "source_node_id": source_node_id, + "target_node_id": target_node_id, + }, + { + "source_node_id": target_node_id, + "target_node_id": source_node_id, + }, ] - }, - {"edges": 1}, + } ) - if not doc: - return None - - for e in doc.get("edges", []): - if e.get("target") == target_node_id: - return e - if e.get("target") == source_node_id: - return e - - return None async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """ @@ -535,23 +477,19 @@ class MongoGraphStorage(BaseGraphStorage): list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges None: If no edges found """ - doc = await self.get_node(source_node_id) - if not doc: - return None + cursor = self.edge_collection.find( + { + "$or": [ + {"source_node_id": source_node_id}, + {"target_node_id": source_node_id}, + ] + }, + {"source_node_id": 1, "target_node_id": 1}, + ) - edges = [] - for e in doc.get("edges", []): - edges.append((source_node_id, e.get("target"))) - - cursor = self.collection.find({"edges.target": source_node_id}) - source_docs = await cursor.to_list(None) - if not source_docs: - return edges - - for doc in source_docs: - edges.append((doc.get("_id"), source_node_id)) - - return edges + return [ + (e.get("source_node_id"), e.get("target_node_id")) async for e in cursor + ] async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: result = {} @@ -566,31 +504,21 @@ class MongoGraphStorage(BaseGraphStorage): # Outbound degrees outbound_pipeline = [ - {"$match": {"_id": {"$in": node_ids}}}, - {"$project": {"_id": 1, "degree": {"$size": "$edges"}}}, + {"$match": {"source_node_id": {"$in": node_ids}}}, + {"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}}, ] - cursor = await self.collection.aggregate(outbound_pipeline) + cursor = await self.edge_collection.aggregate(outbound_pipeline) async for doc in cursor: merged_results[doc.get("_id")] = doc.get("degree") # Inbound degrees inbound_pipeline = [ - {"$match": {"edges.target": {"$in": node_ids}}}, - {"$project": {"_id": 1, "edges.target": 1}}, - {"$unwind": "$edges"}, - {"$match": {"edges.target": {"$in": node_ids}}}, - { - "$group": { - "_id": "$edges.target", - "degree": { - "$sum": 1 - }, # Count the number of incoming edges for each target node - } - }, + {"$match": {"target_node_id": {"$in": node_ids}}}, + {"$group": {"_id": "$target_node_id", "degree": {"$sum": 1}}}, ] - cursor = await self.collection.aggregate(inbound_pipeline) + cursor = await self.edge_collection.aggregate(inbound_pipeline) async for doc in cursor: merged_results[doc.get("_id")] = merged_results.get( doc.get("_id"), 0 @@ -615,29 +543,27 @@ class MongoGraphStorage(BaseGraphStorage): - Outgoing edges: (queried_node, connected_node) - Incoming edges: (connected_node, queried_node) """ - result = {} + result = {node_id: [] for node_id in node_ids} - cursor = self.collection.find( - {"_id": {"$in": node_ids}}, {"_id": 1, "edges.target": 1} + # Query outgoing edges (where node is the source) + outgoing_cursor = self.edge_collection.find( + {"source_node_id": {"$in": node_ids}}, + {"source_node_id": 1, "target_node_id": 1}, ) - async for doc in cursor: - node_id = doc.get("_id") - edges = doc.get("edges", []) - result[node_id] = [(node_id, e["target"]) for e in edges] + async for edge in outgoing_cursor: + source = edge["source_node_id"] + target = edge["target_node_id"] + result[source].append((source, target)) - inbound_pipeline = [ - {"$match": {"edges.target": {"$in": node_ids}}}, - {"$project": {"_id": 1, "edges.target": 1}}, - {"$unwind": "$edges"}, - {"$match": {"edges.target": {"$in": node_ids}}}, - {"$project": {"_id": "$_id", "target": "$edges.target"}}, - ] - - cursor = await self.collection.aggregate(inbound_pipeline) - async for doc in cursor: - node_id = doc.get("target") - result[node_id] = result.get(node_id, []) - result[node_id].append((doc.get("_id"), node_id)) + # Query incoming edges (where node is the target) + incoming_cursor = self.edge_collection.find( + {"target_node_id": {"$in": node_ids}}, + {"source_node_id": 1, "target_node_id": 1}, + ) + async for edge in incoming_cursor: + source = edge["source_node_id"] + target = edge["target_node_id"] + result[target].append((source, target)) return result @@ -649,11 +575,9 @@ class MongoGraphStorage(BaseGraphStorage): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ - Insert or update a node document. If new, create an empty edges array. + Insert or update a node document. """ - # By default, preserve existing 'edges'. - # We'll only set 'edges' to [] on insert (no overwrite). - update_doc = {"$set": {**node_data}, "$setOnInsert": {"edges": []}} + update_doc = {"$set": {**node_data}} await self.collection.update_one({"_id": node_id}, update_doc, upsert=True) async def upsert_edge( @@ -666,16 +590,10 @@ class MongoGraphStorage(BaseGraphStorage): # Ensure source node exists await self.upsert_node(source_node_id, {}) - # Remove existing edge (if any) - await self.collection.update_one( - {"_id": source_node_id}, {"$pull": {"edges": {"target": target_node_id}}} - ) - - # Insert new edge - new_edge = {"target": target_node_id} - new_edge.update(edge_data) - await self.collection.update_one( - {"_id": source_node_id}, {"$push": {"edges": new_edge}} + await self.edge_collection.update_one( + {"source_node_id": source_node_id, "target_node_id": target_node_id}, + {"$set": edge_data}, + upsert=True, ) # @@ -689,8 +607,10 @@ class MongoGraphStorage(BaseGraphStorage): 1) Remove node's doc entirely. 2) Remove inbound edges from any doc that references node_id. """ - # Remove inbound edges from all other docs - await self.collection.update_many({}, {"$pull": {"edges": {"target": node_id}}}) + # Remove all edges + await self.edge_collection.delete_many( + {"$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]} + ) # Remove the node doc await self.collection.delete_one({"_id": node_id}) @@ -734,151 +654,92 @@ class MongoGraphStorage(BaseGraphStorage): result = KnowledgeGraph() seen_nodes = set() seen_edges = set() + node_edges = [] try: - if label == "*": - # Get all nodes and edges - async for node_doc in self.collection.find({}): - node_id = str(node_doc["_id"]) - if node_id not in seen_nodes: - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[node_doc.get("_id")], - properties={ - k: v - for k, v in node_doc.items() - if k not in ["_id", "edges"] - }, - ) - ) - seen_nodes.add(node_id) + pipeline = [ + { + "$graphLookup": { + "from": self._edge_collection_name, + "startWith": "$_id", + "connectFromField": "target_node_id", + "connectToField": "source_node_id", + "maxDepth": max_depth, + "depthField": "depth", + "as": "connected_edges", + }, + }, + {"$addFields": {"edge_count": {"$size": "$connected_edges"}}}, + {"$sort": {"edge_count": -1}}, + {"$limit": max_nodes}, + ] - # Process edges - for edge in node_doc.get("edges", []): - edge_id = f"{node_id}-{edge['target']}" - if edge_id not in seen_edges: - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type=edge.get("relation", ""), - source=node_id, - target=edge["target"], - properties={ - k: v - for k, v in edge.items() - if k not in ["target", "relation"] - }, - ) - ) - seen_edges.add(edge_id) + if label == "*": + all_node_count = await self.collection.count_documents({}) + result.is_truncated = all_node_count > max_nodes else: # Verify if starting node exists - start_nodes = self.collection.find({"_id": label}) - start_nodes_exist = await start_nodes.to_list(length=1) - if not start_nodes_exist: + start_node = await self.collection.find_one({"_id": label}) + if not start_node: logger.warning(f"Starting node with label {label} does not exist!") return result - # Use $graphLookup for traversal - pipeline = [ - { - "$match": {"_id": label} - }, # Start with nodes having the specified label - { - "$graphLookup": { - "from": self._collection_name, - "startWith": "$edges.target", - "connectFromField": "edges.target", - "connectToField": "_id", - "maxDepth": max_depth, - "depthField": "depth", - "as": "connected_nodes", - } - }, - ] + # Add starting node to pipeline + pipeline.insert(0, {"$match": {"_id": label}}) - async for doc in self.collection.aggregate(pipeline): - # Add the start node - node_id = str(doc["_id"]) - if node_id not in seen_nodes: - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[ - doc.get( - "_id", - ) - ], - properties={ - k: v - for k, v in doc.items() - if k - not in [ - "_id", - "edges", - "connected_nodes", - "depth", - ] - }, - ) + cursor = await self.collection.aggregate(pipeline) + async for doc in cursor: + # Add the start node + node_id = str(doc["_id"]) + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node_id], + properties={ + k: v + for k, v in doc.items() + if k + not in [ + "_id", + "connected_edges", + "edge_count", + ] + }, + ) + ) + seen_nodes.add(node_id) + if doc.get("connected_edges", []): + node_edges.extend(doc.get("connected_edges")) + + for edge in node_edges: + if ( + edge["source_node_id"] not in seen_nodes + or edge["target_node_id"] not in seen_nodes + ): + continue + + edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relationship", ""), + source=edge["source_node_id"], + target=edge["target_node_id"], + properties={ + k: v + for k, v in edge.items() + if k + not in [ + "_id", + "source_node_id", + "target_node_id", + "relationship", + ] + }, ) - seen_nodes.add(node_id) - - # Add edges from start node - for edge in doc.get("edges", []): - edge_id = f"{node_id}-{edge['target']}" - if edge_id not in seen_edges: - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type=edge.get("relation", ""), - source=node_id, - target=edge["target"], - properties={ - k: v - for k, v in edge.items() - if k not in ["target", "relation"] - }, - ) - ) - seen_edges.add(edge_id) - - # Add connected nodes and their edges - for connected in doc.get("connected_nodes", []): - node_id = str(connected["_id"]) - if node_id not in seen_nodes: - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[connected.get("_id")], - properties={ - k: v - for k, v in connected.items() - if k not in ["_id", "edges", "depth"] - }, - ) - ) - seen_nodes.add(node_id) - - # Add edges from connected nodes - for edge in connected.get("edges", []): - edge_id = f"{node_id}-{edge['target']}" - if edge_id not in seen_edges: - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type=edge.get("relation", ""), - source=node_id, - target=edge["target"], - properties={ - k: v - for k, v in edge.items() - if k not in ["target", "relation"] - }, - ) - ) - seen_edges.add(edge_id) + ) + seen_edges.add(edge_id) logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" @@ -903,9 +764,14 @@ class MongoGraphStorage(BaseGraphStorage): if not nodes: return - # 1. Remove all edges referencing these nodes (remove from edges array of other nodes) - await self.collection.update_many( - {}, {"$pull": {"edges": {"target": {"$in": nodes}}}} + # 1. Remove all edges referencing these nodes + await self.edge_collection.delete_many( + { + "$or": [ + {"source_node_id": {"$in": nodes}}, + {"target_node_id": {"$in": nodes}}, + ] + } ) # 2. Delete the node documents @@ -923,25 +789,14 @@ class MongoGraphStorage(BaseGraphStorage): if not edges: return - # Group edges by source node (document _id) for efficient updates - edge_groups = {} - for source_id, target_id in edges: - if source_id not in edge_groups: - edge_groups[source_id] = [] - edge_groups[source_id].append(target_id) - - # Bulk update documents to remove the edges - update_operations = [] - for source_id, target_ids in edge_groups.items(): - update_operations.append( - self.collection.update_one( - {"_id": source_id}, - {"$pull": {"edges": {"target": {"$in": target_ids}}}}, - ) - ) - - if update_operations: - await asyncio.gather(*update_operations) + await self.edge_collection.delete_many( + { + "$or": [ + {"source_node_id": source_id, "target_node_id": target_id} + for source_id, target_id in edges + ] + } + ) logger.debug(f"Successfully deleted edges: {edges}") @@ -958,9 +813,16 @@ class MongoGraphStorage(BaseGraphStorage): logger.info( f"Dropped {deleted_count} documents from graph {self._collection_name}" ) + + result = await self.edge_collection.delete_many({}) + edge_count = result.deleted_count + logger.info( + f"Dropped {edge_count} edges from graph {self._edge_collection_name}" + ) + return { "status": "success", - "message": f"{deleted_count} documents dropped", + "message": f"{deleted_count} documents and {edge_count} edges dropped", } except PyMongoError as e: logger.error(f"Error dropping graph {self._collection_name}: {e}")