diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index d49a36b7..e6576420 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -22,27 +22,25 @@ import pipmaster as pm if not pm.is_installed("pymongo"): pm.install("pymongo") -if not pm.is_installed("motor"): - pm.install("motor") - -from motor.motor_asyncio import ( # type: ignore - AsyncIOMotorClient, - AsyncIOMotorDatabase, - AsyncIOMotorCollection, -) +from pymongo import AsyncMongoClient # type: ignore +from pymongo.asynchronous.database import AsyncDatabase # type: ignore +from pymongo.asynchronous.collection import AsyncCollection # type: ignore from pymongo.operations import SearchIndexModel # type: ignore from pymongo.errors import PyMongoError # type: ignore config = configparser.ConfigParser() config.read("config.ini", "utf-8") +# Get maximum number of graph nodes from environment variable, default is 1000 +MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) + class ClientManager: _instances = {"db": None, "ref_count": 0} _lock = asyncio.Lock() @classmethod - async def get_client(cls) -> AsyncIOMotorDatabase: + async def get_client(cls) -> AsyncMongoClient: async with cls._lock: if cls._instances["db"] is None: uri = os.environ.get( @@ -57,7 +55,7 @@ class ClientManager: "MONGO_DATABASE", config.get("mongodb", "database", fallback="LightRAG"), ) - client = AsyncIOMotorClient(uri) + client = AsyncMongoClient(uri) db = client.get_database(database_name) cls._instances["db"] = db cls._instances["ref_count"] = 0 @@ -65,7 +63,7 @@ class ClientManager: return cls._instances["db"] @classmethod - async def release_client(cls, db: AsyncIOMotorDatabase): + async def release_client(cls, db: AsyncDatabase): async with cls._lock: if db is not None: if db is cls._instances["db"]: @@ -77,8 +75,8 @@ class ClientManager: @final @dataclass class MongoKVStorage(BaseKVStorage): - db: AsyncIOMotorDatabase = field(default=None) - _data: AsyncIOMotorCollection = field(default=None) + db: AsyncDatabase = field(default=None) + _data: AsyncCollection = field(default=None) def __post_init__(self): self._collection_name = self.namespace @@ -214,8 +212,8 @@ class MongoKVStorage(BaseKVStorage): @final @dataclass class MongoDocStatusStorage(DocStatusStorage): - db: AsyncIOMotorDatabase = field(default=None) - _data: AsyncIOMotorCollection = field(default=None) + db: AsyncDatabase = field(default=None) + _data: AsyncCollection = field(default=None) def __post_init__(self): self._collection_name = self.namespace @@ -311,6 +309,9 @@ class MongoDocStatusStorage(DocStatusStorage): logger.error(f"Error dropping doc status {self._collection_name}: {e}") return {"status": "error", "message": str(e)} + async def delete(self, ids: list[str]) -> None: + await self._data.delete_many({"_id": {"$in": ids}}) + @final @dataclass @@ -319,8 +320,8 @@ class MongoGraphStorage(BaseGraphStorage): A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries. """ - db: AsyncIOMotorDatabase = field(default=None) - collection: AsyncIOMotorCollection = field(default=None) + db: AsyncDatabase = field(default=None) + collection: AsyncCollection = field(default=None) def __init__(self, namespace, global_config, embedding_func): super().__init__( @@ -350,6 +351,29 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # + # Sample entity_relation document + # { + # "_id" : "CompanyA", + # "created_at" : 1749904575, + # "description" : "A major technology company", + # "edges" : [ + # { + # "target" : "ProductX", + # "relation": "Develops", // To distinguish multiple same-target relations + # "weight" : Double("1"), + # "description" : "CompanyA develops ProductX", + # "keywords" : "develop, produce", + # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec", + # "file_path" : "custom_kg", + # "created_at" : 1749904575 + # } + # ], + # "entity_id" : "CompanyA", + # "entity_type" : "Organization", + # "file_path" : "custom_kg", + # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec" + # } + async def _graph_lookup( self, start_node_id: str, max_depth: int = None ) -> List[dict]: @@ -388,7 +412,7 @@ class MongoGraphStorage(BaseGraphStorage): pipeline[1]["$graphLookup"]["maxDepth"] = max_depth # Return the matching doc plus a field "reachableNodes" - cursor = self.collection.aggregate(pipeline) + cursor = await self.collection.aggregate(pipeline) results = await cursor.to_list(None) # If there's no matching node, results = []. @@ -413,44 +437,12 @@ class MongoGraphStorage(BaseGraphStorage): async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: """ Check if there's a direct single-hop edge from source_node_id to target_node_id. - - We'll do a $graphLookup with maxDepth=0 from the source node—meaning - "Look up zero expansions." Actually, for a direct edge check, we can do maxDepth=1 - and then see if the target node is in the "reachableNodes" at depth=0. - - But typically for a direct edge, we might just do a find_one. - Below is a demonstration approach. """ - # We can do a single-hop graphLookup (maxDepth=0 or 1). - # Then check if the target_node appears among the edges array. - pipeline = [ - {"$match": {"_id": source_node_id}}, - { - "$graphLookup": { - "from": self.collection.name, - "startWith": "$edges.target", - "connectFromField": "edges.target", - "connectToField": "_id", - "as": "reachableNodes", - "depthField": "depth", - "maxDepth": 0, # means: do not follow beyond immediate edges - } - }, - { - "$project": { - "_id": 0, - "reachableNodes._id": 1, # only keep the _id from the subdocs - } - }, - ] - cursor = self.collection.aggregate(pipeline) - results = await cursor.to_list(None) - if not results: - return False - - # results[0]["reachableNodes"] are the immediate neighbors - reachable_ids = [d["_id"] for d in results[0].get("reachableNodes", [])] - return target_node_id in reachable_ids + # Direct check if the target_node appears among the edges array. + doc = await self.collection.find_one( + {"_id": source_node_id, "edges.target": target_node_id}, {"_id": 1} + ) + return doc is not None # # ------------------------------------------------------------------------- @@ -464,82 +456,38 @@ class MongoGraphStorage(BaseGraphStorage): The easiest approach is typically two queries: - count of edges array in node_id's doc - count of how many other docs have node_id in their edges.target. - - But we'll do a $graphLookup demonstration for inbound edges: - 1) Outbound edges: direct from node's edges array - 2) Inbound edges: we can do a special $graphLookup from all docs - or do an explicit match. - - For demonstration, let's do this in two steps (with second step $graphLookup). """ # --- 1) Outbound edges (direct from doc) --- - doc = await self.collection.find_one({"_id": node_id}, {"edges": 1}) + doc = await self.collection.find_one({"_id": node_id}, {"edges.target": 1}) if not doc: return 0 + outbound_count = len(doc.get("edges", [])) # --- 2) Inbound edges: - # A simple way is: find all docs where "edges.target" == node_id. - # But let's do a $graphLookup from `node_id` in REVERSE. - # There's a trick to do "reverse" graphLookups: you'd store - # reversed edges or do a more advanced pipeline. Typically you'd do - # a direct match. We'll just do a direct match for inbound. - inbound_count_pipeline = [ - {"$match": {"edges.target": node_id}}, - { - "$project": { - "matchingEdgesCount": { - "$size": { - "$filter": { - "input": "$edges", - "as": "edge", - "cond": {"$eq": ["$$edge.target", node_id]}, - } - } - } - } - }, - {"$group": {"_id": None, "totalInbound": {"$sum": "$matchingEdgesCount"}}}, - ] - inbound_cursor = self.collection.aggregate(inbound_count_pipeline) - inbound_result = await inbound_cursor.to_list(None) - inbound_count = inbound_result[0]["totalInbound"] if inbound_result else 0 + inbound_count = await self.collection.count_documents({"edges.target": node_id}) return outbound_count + inbound_count async def edge_degree(self, src_id: str, tgt_id: str) -> int: - """ - If your graph can hold multiple edges from the same src to the same tgt - (e.g. different 'relation' values), you can sum them. If it's always - one edge, this is typically 1 or 0. + """Get the total degree (sum of relationships) of two nodes. - We'll do a single-hop $graphLookup from src_id, - then count how many edges reference tgt_id at depth=0. - """ - pipeline = [ - {"$match": {"_id": src_id}}, - { - "$graphLookup": { - "from": self.collection.name, - "startWith": "$edges.target", - "connectFromField": "edges.target", - "connectToField": "_id", - "as": "neighbors", - "depthField": "depth", - "maxDepth": 0, - } - }, - {"$project": {"edges": 1, "neighbors._id": 1, "neighbors.type": 1}}, - ] - cursor = self.collection.aggregate(pipeline) - results = await cursor.to_list(None) - if not results: - return 0 + Args: + src_id: Label of the source node + tgt_id: Label of the target node - # We can simply count how many edges in `results[0].edges` have target == tgt_id. - edges = results[0].get("edges", []) - count = sum(1 for e in edges if e.get("target") == tgt_id) - return count + Returns: + int: Sum of the degrees of both nodes + """ + src_degree = await self.node_degree(src_id) + trg_degree = await self.node_degree(tgt_id) + + # Convert None to 0 for addition + src_degree = 0 if src_degree is None else src_degree + trg_degree = 0 if trg_degree is None else trg_degree + + degrees = int(src_degree) + int(trg_degree) + return degrees # # ------------------------------------------------------------------------- @@ -556,58 +504,142 @@ class MongoGraphStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - pipeline = [ - {"$match": {"_id": source_node_id}}, + doc = await self.collection.find_one( { - "$graphLookup": { - "from": self.collection.name, - "startWith": "$edges.target", - "connectFromField": "edges.target", - "connectToField": "_id", - "as": "neighbors", - "depthField": "depth", - "maxDepth": 0, - } + "$or": [ + {"_id": source_node_id, "edges.target": target_node_id}, + {"_id": target_node_id, "edges.target": source_node_id}, + ] }, - {"$project": {"edges": 1}}, - ] - cursor = self.collection.aggregate(pipeline) - docs = await cursor.to_list(None) - if not docs: + {"edges": 1}, + ) + if not doc: return None - for e in docs[0].get("edges", []): + for e in doc.get("edges", []): if e.get("target") == target_node_id: return e + if e.get("target") == source_node_id: + return e + return None async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """ - Return a list of (source_id, target_id) for direct edges from source_node_id. - Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler. + Retrieves all edges (relationships) for a particular node identified by its label. + + Args: + source_node_id: Label of the node to get edges for + + Returns: + list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges + None: If no edges found """ - pipeline = [ - {"$match": {"_id": source_node_id}}, - { - "$graphLookup": { - "from": self.collection.name, - "startWith": "$edges.target", - "connectFromField": "edges.target", - "connectToField": "_id", - "as": "neighbors", - "depthField": "depth", - "maxDepth": 0, - } - }, - {"$project": {"_id": 0, "edges": 1}}, - ] - cursor = self.collection.aggregate(pipeline) - result = await cursor.to_list(None) - if not result: + doc = await self.get_node(source_node_id) + if not doc: return None - edges = result[0].get("edges", []) - return [(source_node_id, e["target"]) for e in edges] + edges = [] + for e in doc.get("edges", []): + edges.append((source_node_id, e.get("target"))) + + cursor = self.collection.find({"edges.target": source_node_id}) + source_docs = await cursor.to_list(None) + if not source_docs: + return edges + + for doc in source_docs: + edges.append((doc.get("_id"), source_node_id)) + + return edges + + async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: + result = {} + + async for doc in self.collection.find({"_id": {"$in": node_ids}}): + result[doc.get("_id")] = doc + return result + + async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: + # merge the outbound and inbound results with the same "_id" and sum the "degree" + merged_results = {} + + # Outbound degrees + outbound_pipeline = [ + {"$match": {"_id": {"$in": node_ids}}}, + {"$project": {"_id": 1, "degree": {"$size": "$edges"}}}, + ] + + cursor = await self.collection.aggregate(outbound_pipeline) + async for doc in cursor: + merged_results[doc.get("_id")] = doc.get("degree") + + # Inbound degrees + inbound_pipeline = [ + {"$match": {"edges.target": {"$in": node_ids}}}, + {"$project": {"_id": 1, "edges.target": 1}}, + {"$unwind": "$edges"}, + {"$match": {"edges.target": {"$in": node_ids}}}, + { + "$group": { + "_id": "$edges.target", + "degree": { + "$sum": 1 + }, # Count the number of incoming edges for each target node + } + }, + ] + + cursor = await self.collection.aggregate(inbound_pipeline) + async for doc in cursor: + merged_results[doc.get("_id")] = merged_results.get( + doc.get("_id"), 0 + ) + doc.get("degree") + + return merged_results + + async def get_nodes_edges_batch( + self, node_ids: list[str] + ) -> dict[str, list[tuple[str, str]]]: + """ + Batch retrieve edges for multiple nodes. + For each node, returns both outgoing and incoming edges to properly represent + the undirected graph nature. + + Args: + node_ids: List of node IDs (entity_id) for which to retrieve edges. + + Returns: + A dictionary mapping each node ID to its list of edge tuples (source, target). + For each node, the list includes both: + - Outgoing edges: (queried_node, connected_node) + - Incoming edges: (connected_node, queried_node) + """ + result = {} + + cursor = self.collection.find( + {"_id": {"$in": node_ids}}, {"_id": 1, "edges.target": 1} + ) + async for doc in cursor: + node_id = doc.get("_id") + edges = doc.get("edges", []) + result[node_id] = [(node_id, e["target"]) for e in edges] + + inbound_pipeline = [ + {"$match": {"edges.target": {"$in": node_ids}}}, + {"$project": {"_id": 1, "edges.target": 1}}, + {"$unwind": "$edges"}, + {"$match": {"edges.target": {"$in": node_ids}}}, + {"$project": {"_id": "$_id", "target": "$edges.target"}}, + ] + + cursor = await self.collection.aggregate(inbound_pipeline) + async for doc in cursor: + node_id = doc.get("target") + result[node_id] = result.get(node_id, []) + result[node_id].append((doc.get("_id"), node_id)) + + return result # # ------------------------------------------------------------------------- @@ -675,20 +707,18 @@ class MongoGraphStorage(BaseGraphStorage): Returns: [id1, id2, ...] # Alphabetically sorted id list """ - # Use MongoDB's distinct and aggregation to get all unique labels - pipeline = [ - {"$group": {"_id": "$_id"}}, # Group by _id - {"$sort": {"_id": 1}}, # Sort alphabetically - ] - cursor = self.collection.aggregate(pipeline) + cursor = self.collection.find({}, projection={"_id": 1}, sort=[("_id", 1)]) labels = [] async for doc in cursor: labels.append(doc["_id"]) return labels async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 + self, + node_label: str, + max_depth: int = 5, + max_nodes: int = MAX_GRAPH_NODES, ) -> KnowledgeGraph: """ Get complete connected subgraph for specified node (including the starting node itself) @@ -893,17 +923,25 @@ class MongoGraphStorage(BaseGraphStorage): if not edges: return - update_tasks = [] - for source, target in edges: - # Remove edge pointing to target from source node's edges array - update_tasks.append( + # Group edges by source node (document _id) for efficient updates + edge_groups = {} + for source_id, target_id in edges: + if source_id not in edge_groups: + edge_groups[source_id] = [] + edge_groups[source_id].append(target_id) + + # Bulk update documents to remove the edges + update_operations = [] + for source_id, target_ids in edge_groups.items(): + update_operations.append( self.collection.update_one( - {"_id": source}, {"$pull": {"edges": {"target": target}}} + {"_id": source_id}, + {"$pull": {"edges": {"target": {"$in": target_ids}}}}, ) ) - if update_tasks: - await asyncio.gather(*update_tasks) + if update_operations: + await asyncio.gather(*update_operations) logger.debug(f"Successfully deleted edges: {edges}") @@ -932,8 +970,8 @@ class MongoGraphStorage(BaseGraphStorage): @final @dataclass class MongoVectorDBStorage(BaseVectorStorage): - db: AsyncIOMotorDatabase | None = field(default=None) - _data: AsyncIOMotorCollection | None = field(default=None) + db: AsyncDatabase | None = field(default=None) + _data: AsyncCollection | None = field(default=None) def __post_init__(self): kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) @@ -1232,7 +1270,7 @@ class MongoVectorDBStorage(BaseVectorStorage): return {"status": "error", "message": str(e)} -async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str): +async def get_or_create_collection(db: AsyncDatabase, collection_name: str): collection_names = await db.list_collection_names() if collection_name not in collection_names: