From cf441aa84c2478da4064e2fe23df40d6b1c0bd66 Mon Sep 17 00:00:00 2001 From: Ken Chen Date: Sun, 15 Jun 2025 21:22:32 +0800 Subject: [PATCH 1/7] Add missing methods for MongoGraphStorage --- lightrag/kg/mongo_impl.py | 386 +++++++++++++++++++++----------------- 1 file changed, 212 insertions(+), 174 deletions(-) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index d49a36b7..e6576420 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -22,27 +22,25 @@ import pipmaster as pm if not pm.is_installed("pymongo"): pm.install("pymongo") -if not pm.is_installed("motor"): - pm.install("motor") - -from motor.motor_asyncio import ( # type: ignore - AsyncIOMotorClient, - AsyncIOMotorDatabase, - AsyncIOMotorCollection, -) +from pymongo import AsyncMongoClient # type: ignore +from pymongo.asynchronous.database import AsyncDatabase # type: ignore +from pymongo.asynchronous.collection import AsyncCollection # type: ignore from pymongo.operations import SearchIndexModel # type: ignore from pymongo.errors import PyMongoError # type: ignore config = configparser.ConfigParser() config.read("config.ini", "utf-8") +# Get maximum number of graph nodes from environment variable, default is 1000 +MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) + class ClientManager: _instances = {"db": None, "ref_count": 0} _lock = asyncio.Lock() @classmethod - async def get_client(cls) -> AsyncIOMotorDatabase: + async def get_client(cls) -> AsyncMongoClient: async with cls._lock: if cls._instances["db"] is None: uri = os.environ.get( @@ -57,7 +55,7 @@ class ClientManager: "MONGO_DATABASE", config.get("mongodb", "database", fallback="LightRAG"), ) - client = AsyncIOMotorClient(uri) + client = AsyncMongoClient(uri) db = client.get_database(database_name) cls._instances["db"] = db cls._instances["ref_count"] = 0 @@ -65,7 +63,7 @@ class ClientManager: return cls._instances["db"] @classmethod - async def release_client(cls, db: AsyncIOMotorDatabase): + async def release_client(cls, db: AsyncDatabase): async with cls._lock: if db is not None: if db is cls._instances["db"]: @@ -77,8 +75,8 @@ class ClientManager: @final @dataclass class MongoKVStorage(BaseKVStorage): - db: AsyncIOMotorDatabase = field(default=None) - _data: AsyncIOMotorCollection = field(default=None) + db: AsyncDatabase = field(default=None) + _data: AsyncCollection = field(default=None) def __post_init__(self): self._collection_name = self.namespace @@ -214,8 +212,8 @@ class MongoKVStorage(BaseKVStorage): @final @dataclass class MongoDocStatusStorage(DocStatusStorage): - db: AsyncIOMotorDatabase = field(default=None) - _data: AsyncIOMotorCollection = field(default=None) + db: AsyncDatabase = field(default=None) + _data: AsyncCollection = field(default=None) def __post_init__(self): self._collection_name = self.namespace @@ -311,6 +309,9 @@ class MongoDocStatusStorage(DocStatusStorage): logger.error(f"Error dropping doc status {self._collection_name}: {e}") return {"status": "error", "message": str(e)} + async def delete(self, ids: list[str]) -> None: + await self._data.delete_many({"_id": {"$in": ids}}) + @final @dataclass @@ -319,8 +320,8 @@ class MongoGraphStorage(BaseGraphStorage): A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries. 
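    A minimal usage sketch (illustrative only: the namespace value and
    my_embedding_func are placeholders, and a MongoDB instance must be
    reachable via the configured connection URI):

        storage = MongoGraphStorage(
            namespace="chunk_entity_relation",
            global_config={},
            embedding_func=my_embedding_func,
        )
        await storage.initialize()
        node = await storage.get_node("CompanyA")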
""" - db: AsyncIOMotorDatabase = field(default=None) - collection: AsyncIOMotorCollection = field(default=None) + db: AsyncDatabase = field(default=None) + collection: AsyncCollection = field(default=None) def __init__(self, namespace, global_config, embedding_func): super().__init__( @@ -350,6 +351,29 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # + # Sample entity_relation document + # { + # "_id" : "CompanyA", + # "created_at" : 1749904575, + # "description" : "A major technology company", + # "edges" : [ + # { + # "target" : "ProductX", + # "relation": "Develops", // To distinguish multiple same-target relations + # "weight" : Double("1"), + # "description" : "CompanyA develops ProductX", + # "keywords" : "develop, produce", + # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec", + # "file_path" : "custom_kg", + # "created_at" : 1749904575 + # } + # ], + # "entity_id" : "CompanyA", + # "entity_type" : "Organization", + # "file_path" : "custom_kg", + # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec" + # } + async def _graph_lookup( self, start_node_id: str, max_depth: int = None ) -> List[dict]: @@ -388,7 +412,7 @@ class MongoGraphStorage(BaseGraphStorage): pipeline[1]["$graphLookup"]["maxDepth"] = max_depth # Return the matching doc plus a field "reachableNodes" - cursor = self.collection.aggregate(pipeline) + cursor = await self.collection.aggregate(pipeline) results = await cursor.to_list(None) # If there's no matching node, results = []. @@ -413,44 +437,12 @@ class MongoGraphStorage(BaseGraphStorage): async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: """ Check if there's a direct single-hop edge from source_node_id to target_node_id. - - We'll do a $graphLookup with maxDepth=0 from the source node—meaning - "Look up zero expansions." Actually, for a direct edge check, we can do maxDepth=1 - and then see if the target node is in the "reachableNodes" at depth=0. - - But typically for a direct edge, we might just do a find_one. - Below is a demonstration approach. """ - # We can do a single-hop graphLookup (maxDepth=0 or 1). - # Then check if the target_node appears among the edges array. - pipeline = [ - {"$match": {"_id": source_node_id}}, - { - "$graphLookup": { - "from": self.collection.name, - "startWith": "$edges.target", - "connectFromField": "edges.target", - "connectToField": "_id", - "as": "reachableNodes", - "depthField": "depth", - "maxDepth": 0, # means: do not follow beyond immediate edges - } - }, - { - "$project": { - "_id": 0, - "reachableNodes._id": 1, # only keep the _id from the subdocs - } - }, - ] - cursor = self.collection.aggregate(pipeline) - results = await cursor.to_list(None) - if not results: - return False - - # results[0]["reachableNodes"] are the immediate neighbors - reachable_ids = [d["_id"] for d in results[0].get("reachableNodes", [])] - return target_node_id in reachable_ids + # Direct check if the target_node appears among the edges array. + doc = await self.collection.find_one( + {"_id": source_node_id, "edges.target": target_node_id}, {"_id": 1} + ) + return doc is not None # # ------------------------------------------------------------------------- @@ -464,82 +456,38 @@ class MongoGraphStorage(BaseGraphStorage): The easiest approach is typically two queries: - count of edges array in node_id's doc - count of how many other docs have node_id in their edges.target. 
- - But we'll do a $graphLookup demonstration for inbound edges: - 1) Outbound edges: direct from node's edges array - 2) Inbound edges: we can do a special $graphLookup from all docs - or do an explicit match. - - For demonstration, let's do this in two steps (with second step $graphLookup). """ # --- 1) Outbound edges (direct from doc) --- - doc = await self.collection.find_one({"_id": node_id}, {"edges": 1}) + doc = await self.collection.find_one({"_id": node_id}, {"edges.target": 1}) if not doc: return 0 + outbound_count = len(doc.get("edges", [])) # --- 2) Inbound edges: - # A simple way is: find all docs where "edges.target" == node_id. - # But let's do a $graphLookup from `node_id` in REVERSE. - # There's a trick to do "reverse" graphLookups: you'd store - # reversed edges or do a more advanced pipeline. Typically you'd do - # a direct match. We'll just do a direct match for inbound. - inbound_count_pipeline = [ - {"$match": {"edges.target": node_id}}, - { - "$project": { - "matchingEdgesCount": { - "$size": { - "$filter": { - "input": "$edges", - "as": "edge", - "cond": {"$eq": ["$$edge.target", node_id]}, - } - } - } - } - }, - {"$group": {"_id": None, "totalInbound": {"$sum": "$matchingEdgesCount"}}}, - ] - inbound_cursor = self.collection.aggregate(inbound_count_pipeline) - inbound_result = await inbound_cursor.to_list(None) - inbound_count = inbound_result[0]["totalInbound"] if inbound_result else 0 + inbound_count = await self.collection.count_documents({"edges.target": node_id}) return outbound_count + inbound_count async def edge_degree(self, src_id: str, tgt_id: str) -> int: - """ - If your graph can hold multiple edges from the same src to the same tgt - (e.g. different 'relation' values), you can sum them. If it's always - one edge, this is typically 1 or 0. + """Get the total degree (sum of relationships) of two nodes. - We'll do a single-hop $graphLookup from src_id, - then count how many edges reference tgt_id at depth=0. - """ - pipeline = [ - {"$match": {"_id": src_id}}, - { - "$graphLookup": { - "from": self.collection.name, - "startWith": "$edges.target", - "connectFromField": "edges.target", - "connectToField": "_id", - "as": "neighbors", - "depthField": "depth", - "maxDepth": 0, - } - }, - {"$project": {"edges": 1, "neighbors._id": 1, "neighbors.type": 1}}, - ] - cursor = self.collection.aggregate(pipeline) - results = await cursor.to_list(None) - if not results: - return 0 + Args: + src_id: Label of the source node + tgt_id: Label of the target node - # We can simply count how many edges in `results[0].edges` have target == tgt_id. 
- edges = results[0].get("edges", []) - count = sum(1 for e in edges if e.get("target") == tgt_id) - return count + Returns: + int: Sum of the degrees of both nodes + """ + src_degree = await self.node_degree(src_id) + trg_degree = await self.node_degree(tgt_id) + + # Convert None to 0 for addition + src_degree = 0 if src_degree is None else src_degree + trg_degree = 0 if trg_degree is None else trg_degree + + degrees = int(src_degree) + int(trg_degree) + return degrees # # ------------------------------------------------------------------------- @@ -556,58 +504,142 @@ class MongoGraphStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - pipeline = [ - {"$match": {"_id": source_node_id}}, + doc = await self.collection.find_one( { - "$graphLookup": { - "from": self.collection.name, - "startWith": "$edges.target", - "connectFromField": "edges.target", - "connectToField": "_id", - "as": "neighbors", - "depthField": "depth", - "maxDepth": 0, - } + "$or": [ + {"_id": source_node_id, "edges.target": target_node_id}, + {"_id": target_node_id, "edges.target": source_node_id}, + ] }, - {"$project": {"edges": 1}}, - ] - cursor = self.collection.aggregate(pipeline) - docs = await cursor.to_list(None) - if not docs: + {"edges": 1}, + ) + if not doc: return None - for e in docs[0].get("edges", []): + for e in doc.get("edges", []): if e.get("target") == target_node_id: return e + if e.get("target") == source_node_id: + return e + return None async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """ - Return a list of (source_id, target_id) for direct edges from source_node_id. - Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler. + Retrieves all edges (relationships) for a particular node identified by its label. 
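+
+        Both directions are considered: edges stored on this node's own
+        document, and edges on other documents that point back to it.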
+ + Args: + source_node_id: Label of the node to get edges for + + Returns: + list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges + None: If no edges found """ - pipeline = [ - {"$match": {"_id": source_node_id}}, - { - "$graphLookup": { - "from": self.collection.name, - "startWith": "$edges.target", - "connectFromField": "edges.target", - "connectToField": "_id", - "as": "neighbors", - "depthField": "depth", - "maxDepth": 0, - } - }, - {"$project": {"_id": 0, "edges": 1}}, - ] - cursor = self.collection.aggregate(pipeline) - result = await cursor.to_list(None) - if not result: + doc = await self.get_node(source_node_id) + if not doc: return None - edges = result[0].get("edges", []) - return [(source_node_id, e["target"]) for e in edges] + edges = [] + for e in doc.get("edges", []): + edges.append((source_node_id, e.get("target"))) + + cursor = self.collection.find({"edges.target": source_node_id}) + source_docs = await cursor.to_list(None) + if not source_docs: + return edges + + for doc in source_docs: + edges.append((doc.get("_id"), source_node_id)) + + return edges + + async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: + result = {} + + async for doc in self.collection.find({"_id": {"$in": node_ids}}): + result[doc.get("_id")] = doc + return result + + async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: + # merge the outbound and inbound results with the same "_id" and sum the "degree" + merged_results = {} + + # Outbound degrees + outbound_pipeline = [ + {"$match": {"_id": {"$in": node_ids}}}, + {"$project": {"_id": 1, "degree": {"$size": "$edges"}}}, + ] + + cursor = await self.collection.aggregate(outbound_pipeline) + async for doc in cursor: + merged_results[doc.get("_id")] = doc.get("degree") + + # Inbound degrees + inbound_pipeline = [ + {"$match": {"edges.target": {"$in": node_ids}}}, + {"$project": {"_id": 1, "edges.target": 1}}, + {"$unwind": "$edges"}, + {"$match": {"edges.target": {"$in": node_ids}}}, + { + "$group": { + "_id": "$edges.target", + "degree": { + "$sum": 1 + }, # Count the number of incoming edges for each target node + } + }, + ] + + cursor = await self.collection.aggregate(inbound_pipeline) + async for doc in cursor: + merged_results[doc.get("_id")] = merged_results.get( + doc.get("_id"), 0 + ) + doc.get("degree") + + return merged_results + + async def get_nodes_edges_batch( + self, node_ids: list[str] + ) -> dict[str, list[tuple[str, str]]]: + """ + Batch retrieve edges for multiple nodes. + For each node, returns both outgoing and incoming edges to properly represent + the undirected graph nature. + + Args: + node_ids: List of node IDs (entity_id) for which to retrieve edges. + + Returns: + A dictionary mapping each node ID to its list of edge tuples (source, target). 
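+                (e.g. {"A": [("A", "B"), ("C", "A")]} for a node "A" with
+                one outgoing and one incoming edge).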
+ For each node, the list includes both: + - Outgoing edges: (queried_node, connected_node) + - Incoming edges: (connected_node, queried_node) + """ + result = {} + + cursor = self.collection.find( + {"_id": {"$in": node_ids}}, {"_id": 1, "edges.target": 1} + ) + async for doc in cursor: + node_id = doc.get("_id") + edges = doc.get("edges", []) + result[node_id] = [(node_id, e["target"]) for e in edges] + + inbound_pipeline = [ + {"$match": {"edges.target": {"$in": node_ids}}}, + {"$project": {"_id": 1, "edges.target": 1}}, + {"$unwind": "$edges"}, + {"$match": {"edges.target": {"$in": node_ids}}}, + {"$project": {"_id": "$_id", "target": "$edges.target"}}, + ] + + cursor = await self.collection.aggregate(inbound_pipeline) + async for doc in cursor: + node_id = doc.get("target") + result[node_id] = result.get(node_id, []) + result[node_id].append((doc.get("_id"), node_id)) + + return result # # ------------------------------------------------------------------------- @@ -675,20 +707,18 @@ class MongoGraphStorage(BaseGraphStorage): Returns: [id1, id2, ...] # Alphabetically sorted id list """ - # Use MongoDB's distinct and aggregation to get all unique labels - pipeline = [ - {"$group": {"_id": "$_id"}}, # Group by _id - {"$sort": {"_id": 1}}, # Sort alphabetically - ] - cursor = self.collection.aggregate(pipeline) + cursor = self.collection.find({}, projection={"_id": 1}, sort=[("_id", 1)]) labels = [] async for doc in cursor: labels.append(doc["_id"]) return labels async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 + self, + node_label: str, + max_depth: int = 5, + max_nodes: int = MAX_GRAPH_NODES, ) -> KnowledgeGraph: """ Get complete connected subgraph for specified node (including the starting node itself) @@ -893,17 +923,25 @@ class MongoGraphStorage(BaseGraphStorage): if not edges: return - update_tasks = [] - for source, target in edges: - # Remove edge pointing to target from source node's edges array - update_tasks.append( + # Group edges by source node (document _id) for efficient updates + edge_groups = {} + for source_id, target_id in edges: + if source_id not in edge_groups: + edge_groups[source_id] = [] + edge_groups[source_id].append(target_id) + + # Bulk update documents to remove the edges + update_operations = [] + for source_id, target_ids in edge_groups.items(): + update_operations.append( self.collection.update_one( - {"_id": source}, {"$pull": {"edges": {"target": target}}} + {"_id": source_id}, + {"$pull": {"edges": {"target": {"$in": target_ids}}}}, ) ) - if update_tasks: - await asyncio.gather(*update_tasks) + if update_operations: + await asyncio.gather(*update_operations) logger.debug(f"Successfully deleted edges: {edges}") @@ -932,8 +970,8 @@ class MongoGraphStorage(BaseGraphStorage): @final @dataclass class MongoVectorDBStorage(BaseVectorStorage): - db: AsyncIOMotorDatabase | None = field(default=None) - _data: AsyncIOMotorCollection | None = field(default=None) + db: AsyncDatabase | None = field(default=None) + _data: AsyncCollection | None = field(default=None) def __post_init__(self): kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) @@ -1232,7 +1270,7 @@ class MongoVectorDBStorage(BaseVectorStorage): return {"status": "error", "message": str(e)} -async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str): +async def get_or_create_collection(db: AsyncDatabase, collection_name: str): collection_names = await db.list_collection_names() if collection_name not in collection_names: From 
a047d966abfaa745da6d415f234363717553ec70 Mon Sep 17 00:00:00 2001 From: Ken Chen Date: Sat, 21 Jun 2025 21:05:04 +0800 Subject: [PATCH 2/7] MongoGraph: Separate edges from node collection --- lightrag/kg/mongo_impl.py | 484 ++++++++++++++------------------------ 1 file changed, 173 insertions(+), 311 deletions(-) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index e6576420..1e03132a 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -4,7 +4,7 @@ import numpy as np import configparser import asyncio -from typing import Any, List, Union, final +from typing import Any, Union, final from ..base import ( BaseGraphStorage, @@ -321,7 +321,10 @@ class MongoGraphStorage(BaseGraphStorage): """ db: AsyncDatabase = field(default=None) + # node collection storing node_id, node_properties collection: AsyncCollection = field(default=None) + # edge collection storing source_node_id, target_node_id, and edge_properties + edgeCollection: AsyncCollection = field(default=None) def __init__(self, namespace, global_config, embedding_func): super().__init__( @@ -330,6 +333,7 @@ class MongoGraphStorage(BaseGraphStorage): embedding_func=embedding_func, ) self._collection_name = self.namespace + self._edge_collection_name = f"{self._collection_name}_edges" async def initialize(self): if self.db is None: @@ -337,6 +341,9 @@ class MongoGraphStorage(BaseGraphStorage): self.collection = await get_or_create_collection( self.db, self._collection_name ) + self.edge_collection = await get_or_create_collection( + self.db, self._edge_collection_name + ) logger.debug(f"Use MongoDB as KG {self._collection_name}") async def finalize(self): @@ -344,6 +351,7 @@ class MongoGraphStorage(BaseGraphStorage): await ClientManager.release_client(self.db) self.db = None self.collection = None + self.edge_collection = None # # ------------------------------------------------------------------------- @@ -374,52 +382,6 @@ class MongoGraphStorage(BaseGraphStorage): # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec" # } - async def _graph_lookup( - self, start_node_id: str, max_depth: int = None - ) -> List[dict]: - """ - Performs a $graphLookup starting from 'start_node_id' and returns - all reachable documents (including the start node itself). - - Pipeline Explanation: - - 1) $match: We match the start node document by _id = start_node_id. - - 2) $graphLookup: - "from": same collection, - "startWith": "$edges.target" (the immediate neighbors in 'edges'), - "connectFromField": "edges.target", - "connectToField": "_id", - "as": "reachableNodes", - "maxDepth": max_depth (if provided), - "depthField": "depth" (used for debugging or filtering). - - 3) We add an $project or $unwind as needed to extract data. - """ - pipeline = [ - {"$match": {"_id": start_node_id}}, - { - "$graphLookup": { - "from": self.collection.name, - "startWith": "$edges.target", - "connectFromField": "edges.target", - "connectToField": "_id", - "as": "reachableNodes", - "depthField": "depth", - } - }, - ] - - # If you want a limited depth (e.g., only 1 or 2 hops), set maxDepth - if max_depth is not None: - pipeline[1]["$graphLookup"]["maxDepth"] = max_depth - - # Return the matching doc plus a field "reachableNodes" - cursor = await self.collection.aggregate(pipeline) - results = await cursor.to_list(None) - - # If there's no matching node, results = []. - # Otherwise, results[0] is the start node doc, - # plus results[0]["reachableNodes"] is the array of connected docs. 
- return results - # # ------------------------------------------------------------------------- # BASIC QUERIES @@ -439,8 +401,9 @@ class MongoGraphStorage(BaseGraphStorage): Check if there's a direct single-hop edge from source_node_id to target_node_id. """ # Direct check if the target_node appears among the edges array. - doc = await self.collection.find_one( - {"_id": source_node_id, "edges.target": target_node_id}, {"_id": 1} + doc = await self.edge_collection.find_one( + {"source_node_id": source_node_id, "target_node_id": target_node_id}, + {"_id": 1}, ) return doc is not None @@ -453,21 +416,10 @@ class MongoGraphStorage(BaseGraphStorage): async def node_degree(self, node_id: str) -> int: """ Returns the total number of edges connected to node_id (both inbound and outbound). - The easiest approach is typically two queries: - - count of edges array in node_id's doc - - count of how many other docs have node_id in their edges.target. """ - # --- 1) Outbound edges (direct from doc) --- - doc = await self.collection.find_one({"_id": node_id}, {"edges.target": 1}) - if not doc: - return 0 - - outbound_count = len(doc.get("edges", [])) - - # --- 2) Inbound edges: - inbound_count = await self.collection.count_documents({"edges.target": node_id}) - - return outbound_count + inbound_count + return await self.edge_collection.count_documents( + {"$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]} + ) async def edge_degree(self, src_id: str, tgt_id: str) -> int: """Get the total degree (sum of relationships) of two nodes. @@ -482,12 +434,7 @@ class MongoGraphStorage(BaseGraphStorage): src_degree = await self.node_degree(src_id) trg_degree = await self.node_degree(tgt_id) - # Convert None to 0 for addition - src_degree = 0 if src_degree is None else src_degree - trg_degree = 0 if trg_degree is None else trg_degree - - degrees = int(src_degree) + int(trg_degree) - return degrees + return src_degree + trg_degree # # ------------------------------------------------------------------------- @@ -497,32 +444,27 @@ class MongoGraphStorage(BaseGraphStorage): async def get_node(self, node_id: str) -> dict[str, str] | None: """ - Return the full node document (including "edges"), or None if missing. + Return the full node document, or None if missing. 
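+        With edges now stored in the separate edge collection, the returned
+        document holds only node properties; use get_edge or get_node_edges
+        for relationship data.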
""" return await self.collection.find_one({"_id": node_id}) async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - doc = await self.collection.find_one( + return await self.edge_collection.find_one( { "$or": [ - {"_id": source_node_id, "edges.target": target_node_id}, - {"_id": target_node_id, "edges.target": source_node_id}, + { + "source_node_id": source_node_id, + "target_node_id": target_node_id, + }, + { + "source_node_id": target_node_id, + "target_node_id": source_node_id, + }, ] - }, - {"edges": 1}, + } ) - if not doc: - return None - - for e in doc.get("edges", []): - if e.get("target") == target_node_id: - return e - if e.get("target") == source_node_id: - return e - - return None async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """ @@ -535,23 +477,19 @@ class MongoGraphStorage(BaseGraphStorage): list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges None: If no edges found """ - doc = await self.get_node(source_node_id) - if not doc: - return None + cursor = self.edge_collection.find( + { + "$or": [ + {"source_node_id": source_node_id}, + {"target_node_id": source_node_id}, + ] + }, + {"source_node_id": 1, "target_node_id": 1}, + ) - edges = [] - for e in doc.get("edges", []): - edges.append((source_node_id, e.get("target"))) - - cursor = self.collection.find({"edges.target": source_node_id}) - source_docs = await cursor.to_list(None) - if not source_docs: - return edges - - for doc in source_docs: - edges.append((doc.get("_id"), source_node_id)) - - return edges + return [ + (e.get("source_node_id"), e.get("target_node_id")) async for e in cursor + ] async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: result = {} @@ -566,31 +504,21 @@ class MongoGraphStorage(BaseGraphStorage): # Outbound degrees outbound_pipeline = [ - {"$match": {"_id": {"$in": node_ids}}}, - {"$project": {"_id": 1, "degree": {"$size": "$edges"}}}, + {"$match": {"source_node_id": {"$in": node_ids}}}, + {"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}}, ] - cursor = await self.collection.aggregate(outbound_pipeline) + cursor = await self.edge_collection.aggregate(outbound_pipeline) async for doc in cursor: merged_results[doc.get("_id")] = doc.get("degree") # Inbound degrees inbound_pipeline = [ - {"$match": {"edges.target": {"$in": node_ids}}}, - {"$project": {"_id": 1, "edges.target": 1}}, - {"$unwind": "$edges"}, - {"$match": {"edges.target": {"$in": node_ids}}}, - { - "$group": { - "_id": "$edges.target", - "degree": { - "$sum": 1 - }, # Count the number of incoming edges for each target node - } - }, + {"$match": {"target_node_id": {"$in": node_ids}}}, + {"$group": {"_id": "$target_node_id", "degree": {"$sum": 1}}}, ] - cursor = await self.collection.aggregate(inbound_pipeline) + cursor = await self.edge_collection.aggregate(inbound_pipeline) async for doc in cursor: merged_results[doc.get("_id")] = merged_results.get( doc.get("_id"), 0 @@ -615,29 +543,27 @@ class MongoGraphStorage(BaseGraphStorage): - Outgoing edges: (queried_node, connected_node) - Incoming edges: (connected_node, queried_node) """ - result = {} + result = {node_id: [] for node_id in node_ids} - cursor = self.collection.find( - {"_id": {"$in": node_ids}}, {"_id": 1, "edges.target": 1} + # Query outgoing edges (where node is the source) + outgoing_cursor = self.edge_collection.find( + {"source_node_id": {"$in": node_ids}}, + {"source_node_id": 1, "target_node_id": 1}, ) - async for doc in cursor: - 
node_id = doc.get("_id") - edges = doc.get("edges", []) - result[node_id] = [(node_id, e["target"]) for e in edges] + async for edge in outgoing_cursor: + source = edge["source_node_id"] + target = edge["target_node_id"] + result[source].append((source, target)) - inbound_pipeline = [ - {"$match": {"edges.target": {"$in": node_ids}}}, - {"$project": {"_id": 1, "edges.target": 1}}, - {"$unwind": "$edges"}, - {"$match": {"edges.target": {"$in": node_ids}}}, - {"$project": {"_id": "$_id", "target": "$edges.target"}}, - ] - - cursor = await self.collection.aggregate(inbound_pipeline) - async for doc in cursor: - node_id = doc.get("target") - result[node_id] = result.get(node_id, []) - result[node_id].append((doc.get("_id"), node_id)) + # Query incoming edges (where node is the target) + incoming_cursor = self.edge_collection.find( + {"target_node_id": {"$in": node_ids}}, + {"source_node_id": 1, "target_node_id": 1}, + ) + async for edge in incoming_cursor: + source = edge["source_node_id"] + target = edge["target_node_id"] + result[target].append((source, target)) return result @@ -649,11 +575,9 @@ class MongoGraphStorage(BaseGraphStorage): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ - Insert or update a node document. If new, create an empty edges array. + Insert or update a node document. """ - # By default, preserve existing 'edges'. - # We'll only set 'edges' to [] on insert (no overwrite). - update_doc = {"$set": {**node_data}, "$setOnInsert": {"edges": []}} + update_doc = {"$set": {**node_data}} await self.collection.update_one({"_id": node_id}, update_doc, upsert=True) async def upsert_edge( @@ -666,16 +590,10 @@ class MongoGraphStorage(BaseGraphStorage): # Ensure source node exists await self.upsert_node(source_node_id, {}) - # Remove existing edge (if any) - await self.collection.update_one( - {"_id": source_node_id}, {"$pull": {"edges": {"target": target_node_id}}} - ) - - # Insert new edge - new_edge = {"target": target_node_id} - new_edge.update(edge_data) - await self.collection.update_one( - {"_id": source_node_id}, {"$push": {"edges": new_edge}} + await self.edge_collection.update_one( + {"source_node_id": source_node_id, "target_node_id": target_node_id}, + {"$set": edge_data}, + upsert=True, ) # @@ -689,8 +607,10 @@ class MongoGraphStorage(BaseGraphStorage): 1) Remove node's doc entirely. 2) Remove inbound edges from any doc that references node_id. 
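        With the separate edge collection, this means deleting every edge
        document whose source_node_id or target_node_id equals node_id.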
""" - # Remove inbound edges from all other docs - await self.collection.update_many({}, {"$pull": {"edges": {"target": node_id}}}) + # Remove all edges + await self.edge_collection.delete_many( + {"$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]} + ) # Remove the node doc await self.collection.delete_one({"_id": node_id}) @@ -734,151 +654,92 @@ class MongoGraphStorage(BaseGraphStorage): result = KnowledgeGraph() seen_nodes = set() seen_edges = set() + node_edges = [] try: - if label == "*": - # Get all nodes and edges - async for node_doc in self.collection.find({}): - node_id = str(node_doc["_id"]) - if node_id not in seen_nodes: - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[node_doc.get("_id")], - properties={ - k: v - for k, v in node_doc.items() - if k not in ["_id", "edges"] - }, - ) - ) - seen_nodes.add(node_id) + pipeline = [ + { + "$graphLookup": { + "from": self._edge_collection_name, + "startWith": "$_id", + "connectFromField": "target_node_id", + "connectToField": "source_node_id", + "maxDepth": max_depth, + "depthField": "depth", + "as": "connected_edges", + }, + }, + {"$addFields": {"edge_count": {"$size": "$connected_edges"}}}, + {"$sort": {"edge_count": -1}}, + {"$limit": max_nodes}, + ] - # Process edges - for edge in node_doc.get("edges", []): - edge_id = f"{node_id}-{edge['target']}" - if edge_id not in seen_edges: - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type=edge.get("relation", ""), - source=node_id, - target=edge["target"], - properties={ - k: v - for k, v in edge.items() - if k not in ["target", "relation"] - }, - ) - ) - seen_edges.add(edge_id) + if label == "*": + all_node_count = await self.collection.count_documents({}) + result.is_truncated = all_node_count > max_nodes else: # Verify if starting node exists - start_nodes = self.collection.find({"_id": label}) - start_nodes_exist = await start_nodes.to_list(length=1) - if not start_nodes_exist: + start_node = await self.collection.find_one({"_id": label}) + if not start_node: logger.warning(f"Starting node with label {label} does not exist!") return result - # Use $graphLookup for traversal - pipeline = [ - { - "$match": {"_id": label} - }, # Start with nodes having the specified label - { - "$graphLookup": { - "from": self._collection_name, - "startWith": "$edges.target", - "connectFromField": "edges.target", - "connectToField": "_id", - "maxDepth": max_depth, - "depthField": "depth", - "as": "connected_nodes", - } - }, - ] + # Add starting node to pipeline + pipeline.insert(0, {"$match": {"_id": label}}) - async for doc in self.collection.aggregate(pipeline): - # Add the start node - node_id = str(doc["_id"]) - if node_id not in seen_nodes: - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[ - doc.get( - "_id", - ) - ], - properties={ - k: v - for k, v in doc.items() - if k - not in [ - "_id", - "edges", - "connected_nodes", - "depth", - ] - }, - ) + cursor = await self.collection.aggregate(pipeline) + async for doc in cursor: + # Add the start node + node_id = str(doc["_id"]) + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node_id], + properties={ + k: v + for k, v in doc.items() + if k + not in [ + "_id", + "connected_edges", + "edge_count", + ] + }, + ) + ) + seen_nodes.add(node_id) + if doc.get("connected_edges", []): + node_edges.extend(doc.get("connected_edges")) + + for edge in node_edges: + if ( + edge["source_node_id"] not in seen_nodes + or edge["target_node_id"] not in seen_nodes + ): + continue + + 
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relationship", ""), + source=edge["source_node_id"], + target=edge["target_node_id"], + properties={ + k: v + for k, v in edge.items() + if k + not in [ + "_id", + "source_node_id", + "target_node_id", + "relationship", + ] + }, ) - seen_nodes.add(node_id) - - # Add edges from start node - for edge in doc.get("edges", []): - edge_id = f"{node_id}-{edge['target']}" - if edge_id not in seen_edges: - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type=edge.get("relation", ""), - source=node_id, - target=edge["target"], - properties={ - k: v - for k, v in edge.items() - if k not in ["target", "relation"] - }, - ) - ) - seen_edges.add(edge_id) - - # Add connected nodes and their edges - for connected in doc.get("connected_nodes", []): - node_id = str(connected["_id"]) - if node_id not in seen_nodes: - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[connected.get("_id")], - properties={ - k: v - for k, v in connected.items() - if k not in ["_id", "edges", "depth"] - }, - ) - ) - seen_nodes.add(node_id) - - # Add edges from connected nodes - for edge in connected.get("edges", []): - edge_id = f"{node_id}-{edge['target']}" - if edge_id not in seen_edges: - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type=edge.get("relation", ""), - source=node_id, - target=edge["target"], - properties={ - k: v - for k, v in edge.items() - if k not in ["target", "relation"] - }, - ) - ) - seen_edges.add(edge_id) + ) + seen_edges.add(edge_id) logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" @@ -903,9 +764,14 @@ class MongoGraphStorage(BaseGraphStorage): if not nodes: return - # 1. Remove all edges referencing these nodes (remove from edges array of other nodes) - await self.collection.update_many( - {}, {"$pull": {"edges": {"target": {"$in": nodes}}}} + # 1. Remove all edges referencing these nodes + await self.edge_collection.delete_many( + { + "$or": [ + {"source_node_id": {"$in": nodes}}, + {"target_node_id": {"$in": nodes}}, + ] + } ) # 2. 
Delete the node documents @@ -923,25 +789,14 @@ class MongoGraphStorage(BaseGraphStorage): if not edges: return - # Group edges by source node (document _id) for efficient updates - edge_groups = {} - for source_id, target_id in edges: - if source_id not in edge_groups: - edge_groups[source_id] = [] - edge_groups[source_id].append(target_id) - - # Bulk update documents to remove the edges - update_operations = [] - for source_id, target_ids in edge_groups.items(): - update_operations.append( - self.collection.update_one( - {"_id": source_id}, - {"$pull": {"edges": {"target": {"$in": target_ids}}}}, - ) - ) - - if update_operations: - await asyncio.gather(*update_operations) + await self.edge_collection.delete_many( + { + "$or": [ + {"source_node_id": source_id, "target_node_id": target_id} + for source_id, target_id in edges + ] + } + ) logger.debug(f"Successfully deleted edges: {edges}") @@ -958,9 +813,16 @@ class MongoGraphStorage(BaseGraphStorage): logger.info( f"Dropped {deleted_count} documents from graph {self._collection_name}" ) + + result = await self.edge_collection.delete_many({}) + edge_count = result.deleted_count + logger.info( + f"Dropped {edge_count} edges from graph {self._edge_collection_name}" + ) + return { "status": "success", - "message": f"{deleted_count} documents dropped", + "message": f"{deleted_count} documents and {edge_count} edges dropped", } except PyMongoError as e: logger.error(f"Error dropping graph {self._collection_name}: {e}") From a3865caaeac0b37e54b160a9362d8c3a2ee9f51f Mon Sep 17 00:00:00 2001 From: Ken Chen Date: Wed, 25 Jun 2025 22:17:17 +0800 Subject: [PATCH 3/7] Implement get_nodes_by_chunk_ids and get_edges_by_chunk_ids, --- lightrag/kg/mongo_impl.py | 114 ++++++++++++++++++++++++++---------- tests/test_graph_storage.py | 80 +++++++++++++++++++++++++ 2 files changed, 164 insertions(+), 30 deletions(-) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 1e03132a..2fdbb270 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -17,6 +17,8 @@ from ..base import ( from ..namespace import NameSpace, is_namespace from ..utils import logger, compute_mdhash_id from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +from ..constants import GRAPH_FIELD_SEP + import pipmaster as pm if not pm.is_installed("pymongo"): @@ -353,33 +355,33 @@ class MongoGraphStorage(BaseGraphStorage): self.collection = None self.edge_collection = None - # - # ------------------------------------------------------------------------- - # HELPER: $graphLookup pipeline - # ------------------------------------------------------------------------- - # + # Sample entity document + # "source_ids" is Array representation of "source_id" split by GRAPH_FIELD_SEP - # Sample entity_relation document # { # "_id" : "CompanyA", - # "created_at" : 1749904575, - # "description" : "A major technology company", - # "edges" : [ - # { - # "target" : "ProductX", - # "relation": "Develops", // To distinguish multiple same-target relations - # "weight" : Double("1"), - # "description" : "CompanyA develops ProductX", - # "keywords" : "develop, produce", - # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec", - # "file_path" : "custom_kg", - # "created_at" : 1749904575 - # } - # ], # "entity_id" : "CompanyA", # "entity_type" : "Organization", + # "description" : "A major technology company", + # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec", + # "source_ids": ["chunk-eeec0036b909839e8ec4fa150c939eec"], # "file_path" : "custom_kg", - 
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec" + # "created_at" : 1749904575 + # } + + # Sample relation document + # { + # "_id" : ObjectId("6856ac6e7c6bad9b5470b678"), // MongoDB build-in ObjectId + # "description" : "CompanyA develops ProductX", + # "source_node_id" : "CompanyA", + # "target_node_id" : "ProductX", + # "relationship": "Develops", // To distinguish multiple same-target relations + # "weight" : Double("1"), + # "keywords" : "develop, produce", + # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec", + # "source_ids": ["chunk-eeec0036b909839e8ec4fa150c939eec"], + # "file_path" : "custom_kg", + # "created_at" : 1749904575 # } # @@ -567,6 +569,45 @@ class MongoGraphStorage(BaseGraphStorage): return result + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all nodes that are associated with the given chunk_ids. + + Args: + chunk_ids (list[str]): A list of chunk IDs to find associated nodes for. + + Returns: + list[dict]: A list of nodes, where each node is a dictionary of its properties. + An empty list if no matching nodes are found. + """ + if not chunk_ids: + return [] + + cursor = self.collection.find({"source_ids": {"$in": chunk_ids}}) + return [doc async for doc in cursor] + + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all edges that are associated with the given chunk_ids. + + Args: + chunk_ids (list[str]): A list of chunk IDs to find associated edges for. + + Returns: + list[dict]: A list of edges, where each edge is a dictionary of its properties. + An empty list if no matching edges are found. + """ + if not chunk_ids: + return [] + + cursor = self.edge_collection.find({"source_ids": {"$in": chunk_ids}}) + + edges = [] + async for edge in cursor: + edge["source"] = edge["source_node_id"] + edge["target"] = edge["target_node_id"] + edges.append(edge) + + return edges + # # ------------------------------------------------------------------------- # UPSERTS @@ -578,6 +619,11 @@ class MongoGraphStorage(BaseGraphStorage): Insert or update a node document. 
""" update_doc = {"$set": {**node_data}} + if node_data.get("source_id", ""): + update_doc["$set"]["source_ids"] = node_data["source_id"].split( + GRAPH_FIELD_SEP + ) + await self.collection.update_one({"_id": node_id}, update_doc, upsert=True) async def upsert_edge( @@ -590,9 +636,15 @@ class MongoGraphStorage(BaseGraphStorage): # Ensure source node exists await self.upsert_node(source_node_id, {}) + update_doc = {"$set": edge_data} + if edge_data.get("source_id", ""): + update_doc["$set"]["source_ids"] = edge_data["source_id"].split( + GRAPH_FIELD_SEP + ) + await self.edge_collection.update_one( {"source_node_id": source_node_id, "target_node_id": target_node_id}, - {"$set": edge_data}, + update_doc, upsert=True, ) @@ -789,14 +841,16 @@ class MongoGraphStorage(BaseGraphStorage): if not edges: return - await self.edge_collection.delete_many( - { - "$or": [ - {"source_node_id": source_id, "target_node_id": target_id} - for source_id, target_id in edges - ] - } - ) + all_edge_pairs = [] + for source_id, target_id in edges: + all_edge_pairs.append( + {"source_node_id": source_id, "target_node_id": target_id} + ) + all_edge_pairs.append( + {"source_node_id": target_id, "target_node_id": source_id} + ) + + await self.edge_collection.delete_many({"$or": all_edge_pairs}) logger.debug(f"Successfully deleted edges: {edges}") diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py index fb78270d..64e66f48 100644 --- a/tests/test_graph_storage.py +++ b/tests/test_graph_storage.py @@ -30,6 +30,7 @@ from lightrag.kg import ( verify_storage_implementation, ) from lightrag.kg.shared_storage import initialize_share_data +from lightrag.constants import GRAPH_FIELD_SEP # 模拟的嵌入函数,返回随机向量 @@ -437,6 +438,9 @@ async def test_graph_batch_operations(storage): 5. 使用 get_nodes_edges_batch 批量获取多个节点的所有边 """ try: + chunk1_id = "1" + chunk2_id = "2" + chunk3_id = "3" # 1. 
插入测试数据 # 插入节点1: 人工智能 node1_id = "人工智能" @@ -445,6 +449,7 @@ async def test_graph_batch_operations(storage): "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。", "keywords": "AI,机器学习,深度学习", "entity_type": "技术领域", + "source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]), } print(f"插入节点1: {node1_id}") await storage.upsert_node(node1_id, node1_data) @@ -456,6 +461,7 @@ async def test_graph_batch_operations(storage): "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。", "keywords": "监督学习,无监督学习,强化学习", "entity_type": "技术领域", + "source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]), } print(f"插入节点2: {node2_id}") await storage.upsert_node(node2_id, node2_data) @@ -467,6 +473,7 @@ async def test_graph_batch_operations(storage): "description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。", "keywords": "神经网络,CNN,RNN", "entity_type": "技术领域", + "source_id": GRAPH_FIELD_SEP.join([chunk3_id]), } print(f"插入节点3: {node3_id}") await storage.upsert_node(node3_id, node3_data) @@ -498,6 +505,7 @@ async def test_graph_batch_operations(storage): "relationship": "包含", "weight": 1.0, "description": "人工智能领域包含机器学习这个子领域", + "source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]), } print(f"插入边1: {node1_id} -> {node2_id}") await storage.upsert_edge(node1_id, node2_id, edge1_data) @@ -507,6 +515,7 @@ async def test_graph_batch_operations(storage): "relationship": "包含", "weight": 1.0, "description": "机器学习领域包含深度学习这个子领域", + "source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]), } print(f"插入边2: {node2_id} -> {node3_id}") await storage.upsert_edge(node2_id, node3_id, edge2_data) @@ -516,6 +525,7 @@ async def test_graph_batch_operations(storage): "relationship": "包含", "weight": 1.0, "description": "人工智能领域包含自然语言处理这个子领域", + "source_id": GRAPH_FIELD_SEP.join([chunk3_id]), } print(f"插入边3: {node1_id} -> {node4_id}") await storage.upsert_edge(node1_id, node4_id, edge3_data) @@ -748,6 +758,76 @@ async def test_graph_batch_operations(storage): print("无向图特性验证成功:批量获取的节点边包含所有相关的边(无论方向)") + # 7. 测试 get_nodes_by_chunk_ids - 批量根据 chunk_ids 获取多个节点 + print("== 测试 get_nodes_by_chunk_ids") + + print("== 测试单个 chunk_id,匹配多个节点") + nodes = await storage.get_nodes_by_chunk_ids([chunk2_id]) + assert len(nodes) == 2, f"{chunk1_id} 应有2个节点,实际有 {len(nodes)} 个" + + has_node1 = any(node["entity_id"] == node1_id for node in nodes) + has_node2 = any(node["entity_id"] == node2_id for node in nodes) + + assert has_node1, f"节点 {node1_id} 应在返回结果中" + assert has_node2, f"节点 {node2_id} 应在返回结果中" + + print("== 测试多个 chunk_id,部分匹配多个节点") + nodes = await storage.get_nodes_by_chunk_ids([chunk2_id, chunk3_id]) + assert ( + len(nodes) == 3 + ), f"{chunk2_id}, {chunk3_id} 应有3个节点,实际有 {len(nodes)} 个" + + has_node1 = any(node["entity_id"] == node1_id for node in nodes) + has_node2 = any(node["entity_id"] == node2_id for node in nodes) + has_node3 = any(node["entity_id"] == node3_id for node in nodes) + + assert has_node1, f"节点 {node1_id} 应在返回结果中" + assert has_node2, f"节点 {node2_id} 应在返回结果中" + assert has_node3, f"节点 {node3_id} 应在返回结果中" + + # 8. 
测试 get_edges_by_chunk_ids - 批量根据 chunk_ids 获取多条边 + print("== 测试 get_edges_by_chunk_ids") + + print("== 测试单个 chunk_id,匹配多条边") + edges = await storage.get_edges_by_chunk_ids([chunk2_id]) + assert len(edges) == 2, f"{chunk2_id} 应有2条边,实际有 {len(edges)} 条" + + has_edge_node1_node2 = any( + edge["source"] == node1_id and edge["target"] == node2_id for edge in edges + ) + has_edge_node2_node3 = any( + edge["source"] == node2_id and edge["target"] == node3_id for edge in edges + ) + + assert has_edge_node1_node2, f"{chunk2_id} 应包含 {node1_id} 到 {node2_id} 的边" + assert has_edge_node2_node3, f"{chunk2_id} 应包含 {node2_id} 到 {node3_id} 的边" + + print("== 测试多个 chunk_id,部分匹配多条边") + edges = await storage.get_edges_by_chunk_ids([chunk2_id, chunk3_id]) + assert ( + len(edges) == 3 + ), f"{chunk2_id}, {chunk3_id} 应有3条边,实际有 {len(edges)} 条" + + has_edge_node1_node2 = any( + edge["source"] == node1_id and edge["target"] == node2_id for edge in edges + ) + has_edge_node2_node3 = any( + edge["source"] == node2_id and edge["target"] == node3_id for edge in edges + ) + has_edge_node1_node4 = any( + edge["source"] == node1_id and edge["target"] == node4_id for edge in edges + ) + + assert ( + has_edge_node1_node2 + ), f"{chunk2_id}, {chunk3_id} 应包含 {node1_id} 到 {node2_id} 的边" + assert ( + has_edge_node2_node3 + ), f"{chunk2_id}, {chunk3_id} 应包含 {node2_id} 到 {node3_id} 的边" + assert ( + has_edge_node1_node4 + ), f"{chunk2_id}, {chunk3_id} 应包含 {node1_id} 到 {node4_id} 的边" + print("\n批量操作测试完成") return True From 6364d076aa91291a588a69c522a32a5f6b5e5c58 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 26 Jun 2025 13:47:50 +0800 Subject: [PATCH 4/7] Enable MongoGraphStorage --- lightrag/kg/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py index bbddb285..b4ba0983 100644 --- a/lightrag/kg/__init__.py +++ b/lightrag/kg/__init__.py @@ -14,8 +14,8 @@ STORAGE_IMPLEMENTATIONS = { "NetworkXStorage", "Neo4JStorage", "PGGraphStorage", + "MongoGraphStorage", # "AGEStorage", - # "MongoGraphStorage", # "TiDBGraphStorage", # "GremlinStorage", ], From d512db26e4969975e670964c2964d2d1cd15a923 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 26 Jun 2025 13:50:19 +0800 Subject: [PATCH 5/7] Fix MongoDB set handling in delete operations --- lightrag/kg/mongo_impl.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 2fdbb270..109ba59d 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -159,6 +159,10 @@ class MongoKVStorage(BaseKVStorage): if not ids: return + # Convert to list if it's a set (MongoDB BSON cannot encode sets) + if isinstance(ids, set): + ids = list(ids) + try: result = await self._data.delete_many({"_id": {"$in": ids}}) logger.info( @@ -1044,6 +1048,10 @@ class MongoVectorDBStorage(BaseVectorStorage): if not ids: return + # Convert to list if it's a set (MongoDB BSON cannot encode sets) + if isinstance(ids, set): + ids = list(ids) + try: result = await self._data.delete_many({"_id": {"$in": ids}}) logger.debug( From 71565f47945a93c69312ec6963941984fdfa88e8 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 26 Jun 2025 13:51:15 +0800 Subject: [PATCH 6/7] Add get_all method to MongoKVStorage --- lightrag/kg/mongo_impl.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 109ba59d..6be02d1d 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -107,6 +107,19 @@ class 
MongoKVStorage(BaseKVStorage): existing_ids = {str(x["_id"]) async for x in cursor} return keys - existing_ids + async def get_all(self) -> dict[str, Any]: + """Get all data from storage + + Returns: + Dictionary containing all stored data + """ + cursor = self._data.find({}) + result = {} + async for doc in cursor: + doc_id = doc.pop("_id") + result[doc_id] = doc + return result + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} to {self.namespace}") if not data: From d60db573dc1db438239bb619436af343eb0ca7e3 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 26 Jun 2025 13:51:53 +0800 Subject: [PATCH 7/7] Add allowDiskUse flag to MongoDB aggregations - Enable disk use for large aggregations - Fix cursor handling for list_search_indexes - Improve query performance for big datasets - Update vector search index check - Set proper length for to_list results --- lightrag/kg/mongo_impl.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 6be02d1d..f5a87cbe 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -276,7 +276,7 @@ class MongoDocStatusStorage(DocStatusStorage): async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}] - cursor = self._data.aggregate(pipeline) + cursor = self._data.aggregate(pipeline, allowDiskUse=True) result = await cursor.to_list() counts = {} for doc in result: @@ -527,7 +527,7 @@ class MongoGraphStorage(BaseGraphStorage): {"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}}, ] - cursor = await self.edge_collection.aggregate(outbound_pipeline) + cursor = await self.edge_collection.aggregate(outbound_pipeline, allowDiskUse=True) async for doc in cursor: merged_results[doc.get("_id")] = doc.get("degree") @@ -537,7 +537,7 @@ class MongoGraphStorage(BaseGraphStorage): {"$group": {"_id": "$target_node_id", "degree": {"$sum": 1}}}, ] - cursor = await self.edge_collection.aggregate(inbound_pipeline) + cursor = await self.edge_collection.aggregate(inbound_pipeline, allowDiskUse=True) async for doc in cursor: merged_results[doc.get("_id")] = merged_results.get( doc.get("_id"), 0 @@ -756,7 +756,7 @@ class MongoGraphStorage(BaseGraphStorage): # Add starting node to pipeline pipeline.insert(0, {"$match": {"_id": label}}) - cursor = await self.collection.aggregate(pipeline) + cursor = await self.collection.aggregate(pipeline, allowDiskUse=True) async for doc in cursor: # Add the start node node_id = str(doc["_id"]) @@ -938,7 +938,8 @@ class MongoVectorDBStorage(BaseVectorStorage): try: index_name = "vector_knn_index" - indexes = await self._data.list_search_indexes().to_list(length=None) + indexes_cursor = await self._data.list_search_indexes() + indexes = await indexes_cursor.to_list(length=None) for index in indexes: if index["name"] == index_name: logger.debug("vector index already exist") @@ -1033,8 +1034,8 @@ class MongoVectorDBStorage(BaseVectorStorage): ] # Execute the aggregation pipeline - cursor = self._data.aggregate(pipeline) - results = await cursor.to_list() + cursor = await self._data.aggregate(pipeline, allowDiskUse=True) + results = await cursor.to_list(length=None) # Format and return the results with created_at field return [