diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py
index bbddb285..b4ba0983 100644
--- a/lightrag/kg/__init__.py
+++ b/lightrag/kg/__init__.py
@@ -14,8 +14,8 @@ STORAGE_IMPLEMENTATIONS = {
         "NetworkXStorage",
         "Neo4JStorage",
         "PGGraphStorage",
+        "MongoGraphStorage",
         # "AGEStorage",
-        # "MongoGraphStorage",
         # "TiDBGraphStorage",
         # "GremlinStorage",
     ],
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index d49a36b7..f5a87cbe 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -4,7 +4,7 @@ import numpy as np
 import configparser
 import asyncio
 
-from typing import Any, List, Union, final
+from typing import Any, Union, final
 
 from ..base import (
     BaseGraphStorage,
@@ -17,32 +17,32 @@ from ..base import (
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger, compute_mdhash_id
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
+from ..constants import GRAPH_FIELD_SEP
+
 import pipmaster as pm
 
 if not pm.is_installed("pymongo"):
     pm.install("pymongo")
 
-if not pm.is_installed("motor"):
-    pm.install("motor")
-
-from motor.motor_asyncio import (  # type: ignore
-    AsyncIOMotorClient,
-    AsyncIOMotorDatabase,
-    AsyncIOMotorCollection,
-)
+from pymongo import AsyncMongoClient  # type: ignore
+from pymongo.asynchronous.database import AsyncDatabase  # type: ignore
+from pymongo.asynchronous.collection import AsyncCollection  # type: ignore
 from pymongo.operations import SearchIndexModel  # type: ignore
 from pymongo.errors import PyMongoError  # type: ignore
 
 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
 
+# Maximum number of graph nodes returned by get_knowledge_graph,
+# configurable via environment variable (default: 1000)
+MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
+
 
 class ClientManager:
     _instances = {"db": None, "ref_count": 0}
     _lock = asyncio.Lock()
 
     @classmethod
-    async def get_client(cls) -> AsyncIOMotorDatabase:
+    async def get_client(cls) -> AsyncDatabase:
         async with cls._lock:
             if cls._instances["db"] is None:
                 uri = os.environ.get(
@@ -57,7 +57,7 @@ class ClientManager:
                     "MONGO_DATABASE",
                     config.get("mongodb", "database", fallback="LightRAG"),
                 )
-                client = AsyncIOMotorClient(uri)
+                client = AsyncMongoClient(uri)
                 db = client.get_database(database_name)
                 cls._instances["db"] = db
                 cls._instances["ref_count"] = 0
@@ -65,7 +65,7 @@ class ClientManager:
             return cls._instances["db"]
 
     @classmethod
-    async def release_client(cls, db: AsyncIOMotorDatabase):
+    async def release_client(cls, db: AsyncDatabase):
         async with cls._lock:
             if db is not None:
                 if db is cls._instances["db"]:
@@ -77,8 +77,8 @@ class ClientManager:
 @final
 @dataclass
 class MongoKVStorage(BaseKVStorage):
-    db: AsyncIOMotorDatabase = field(default=None)
-    _data: AsyncIOMotorCollection = field(default=None)
+    db: AsyncDatabase = field(default=None)
+    _data: AsyncCollection = field(default=None)
 
     def __post_init__(self):
         self._collection_name = self.namespace
@@ -107,6 +107,19 @@ class MongoKVStorage(BaseKVStorage):
         existing_ids = {str(x["_id"]) async for x in cursor}
         return keys - existing_ids
 
+    async def get_all(self) -> dict[str, Any]:
+        """Get all data from storage
+
+        Returns:
+            Dictionary containing all stored data
+        """
+        cursor = self._data.find({})
+        result = {}
+        async for doc in cursor:
+            doc_id = doc.pop("_id")
+            result[doc_id] = doc
+        return result
+
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
@@ -159,6 +172,10 @@ class MongoKVStorage(BaseKVStorage):
         if not ids:
            return
 
+        # Convert to list if it's a set (MongoDB BSON cannot encode sets)
+        if isinstance(ids, set):
+            ids = list(ids)
+
         try:
             result = await self._data.delete_many({"_id": {"$in": ids}})
             logger.info(
@@ -214,8 +231,8 @@ class MongoKVStorage(BaseKVStorage):
 @final
 @dataclass
 class MongoDocStatusStorage(DocStatusStorage):
-    db: AsyncIOMotorDatabase = field(default=None)
-    _data: AsyncIOMotorCollection = field(default=None)
+    db: AsyncDatabase = field(default=None)
+    _data: AsyncCollection = field(default=None)
 
     def __post_init__(self):
         self._collection_name = self.namespace
@@ -259,7 +276,7 @@ class MongoDocStatusStorage(DocStatusStorage):
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
-        cursor = self._data.aggregate(pipeline)
+        cursor = await self._data.aggregate(pipeline, allowDiskUse=True)
         result = await cursor.to_list()
         counts = {}
         for doc in result:
@@ -311,6 +328,9 @@ class MongoDocStatusStorage(DocStatusStorage):
             logger.error(f"Error dropping doc status {self._collection_name}: {e}")
             return {"status": "error", "message": str(e)}
 
+    async def delete(self, ids: list[str]) -> None:
+        await self._data.delete_many({"_id": {"$in": ids}})
+
 
 @final
 @dataclass
@@ -319,8 +339,11 @@ class MongoGraphStorage(BaseGraphStorage):
     A concrete implementation using MongoDB's $graphLookup to demonstrate
     multi-hop queries.
     """
 
-    db: AsyncIOMotorDatabase = field(default=None)
-    collection: AsyncIOMotorCollection = field(default=None)
+    db: AsyncDatabase = field(default=None)
+    # node collection, storing each node's _id and properties
+    collection: AsyncCollection = field(default=None)
+    # edge collection, storing source_node_id, target_node_id, and edge properties
+    edge_collection: AsyncCollection = field(default=None)
 
     def __init__(self, namespace, global_config, embedding_func):
         super().__init__(
@@ -329,6 +352,7 @@ class MongoGraphStorage(BaseGraphStorage):
             embedding_func=embedding_func,
         )
         self._collection_name = self.namespace
+        self._edge_collection_name = f"{self._collection_name}_edges"
 
     async def initialize(self):
         if self.db is None:
@@ -336,6 +360,9 @@ class MongoGraphStorage(BaseGraphStorage):
         self.collection = await get_or_create_collection(
             self.db, self._collection_name
         )
+        self.edge_collection = await get_or_create_collection(
+            self.db, self._edge_collection_name
+        )
         logger.debug(f"Use MongoDB as KG {self._collection_name}")
 
     async def finalize(self):
@@ -343,58 +370,36 @@ class MongoGraphStorage(BaseGraphStorage):
         await ClientManager.release_client(self.db)
         self.db = None
         self.collection = None
+        self.edge_collection = None
 
-    #
-    # -------------------------------------------------------------------------
-    # HELPER: $graphLookup pipeline
-    # -------------------------------------------------------------------------
-    #
+    # Sample entity document
+    # "source_ids" is the array form of "source_id", split on GRAPH_FIELD_SEP
 
-    async def _graph_lookup(
-        self, start_node_id: str, max_depth: int = None
-    ) -> List[dict]:
-        """
-        Performs a $graphLookup starting from 'start_node_id' and returns
-        all reachable documents (including the start node itself).
+    # {
+    #     "_id" : "CompanyA",
+    #     "entity_id" : "CompanyA",
+    #     "entity_type" : "Organization",
+    #     "description" : "A major technology company",
+    #     "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
+    #     "source_ids": ["chunk-eeec0036b909839e8ec4fa150c939eec"],
+    #     "file_path" : "custom_kg",
+    #     "created_at" : 1749904575
+    # }
 
-        Pipeline Explanation:
-
-        1) $match: We match the start node document by _id = start_node_id.
-
-        2) $graphLookup:
-           "from": same collection,
-           "startWith": "$edges.target" (the immediate neighbors in 'edges'),
-           "connectFromField": "edges.target",
-           "connectToField": "_id",
-           "as": "reachableNodes",
-           "maxDepth": max_depth (if provided),
-           "depthField": "depth" (used for debugging or filtering).
-
-        3) We add an $project or $unwind as needed to extract data.
-        """
-        pipeline = [
-            {"$match": {"_id": start_node_id}},
-            {
-                "$graphLookup": {
-                    "from": self.collection.name,
-                    "startWith": "$edges.target",
-                    "connectFromField": "edges.target",
-                    "connectToField": "_id",
-                    "as": "reachableNodes",
-                    "depthField": "depth",
-                }
-            },
-        ]
-
-        # If you want a limited depth (e.g., only 1 or 2 hops), set maxDepth
-        if max_depth is not None:
-            pipeline[1]["$graphLookup"]["maxDepth"] = max_depth
-
-        # Return the matching doc plus a field "reachableNodes"
-        cursor = self.collection.aggregate(pipeline)
-        results = await cursor.to_list(None)
-
-        # If there's no matching node, results = [].
-        # Otherwise, results[0] is the start node doc,
-        # plus results[0]["reachableNodes"] is the array of connected docs.
-        return results
+    # Sample relation document
+    # {
+    #     "_id" : ObjectId("6856ac6e7c6bad9b5470b678"),  // MongoDB built-in ObjectId
+    #     "description" : "CompanyA develops ProductX",
+    #     "source_node_id" : "CompanyA",
+    #     "target_node_id" : "ProductX",
+    #     "relationship": "Develops",  // To distinguish multiple same-target relations
+    #     "weight" : Double("1"),
+    #     "keywords" : "develop, produce",
+    #     "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
+    #     "source_ids": ["chunk-eeec0036b909839e8ec4fa150c939eec"],
+    #     "file_path" : "custom_kg",
+    #     "created_at" : 1749904575
+    # }
 
     #
     # -------------------------------------------------------------------------
@@ -413,44 +418,13 @@ class MongoGraphStorage(BaseGraphStorage):
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         """
         Check if there's a direct single-hop edge from source_node_id to target_node_id.
-
-        We'll do a $graphLookup with maxDepth=0 from the source node—meaning
-        "Look up zero expansions." Actually, for a direct edge check, we can do maxDepth=1
-        and then see if the target node is in the "reachableNodes" at depth=0.
-
-        But typically for a direct edge, we might just do a find_one.
-        Below is a demonstration approach.
         """
-        # We can do a single-hop graphLookup (maxDepth=0 or 1).
-        # Then check if the target_node appears among the edges array.
-        pipeline = [
-            {"$match": {"_id": source_node_id}},
-            {
-                "$graphLookup": {
-                    "from": self.collection.name,
-                    "startWith": "$edges.target",
-                    "connectFromField": "edges.target",
-                    "connectToField": "_id",
-                    "as": "reachableNodes",
-                    "depthField": "depth",
-                    "maxDepth": 0,  # means: do not follow beyond immediate edges
-                }
-            },
-            {
-                "$project": {
-                    "_id": 0,
-                    "reachableNodes._id": 1,  # only keep the _id from the subdocs
-                }
-            },
-        ]
-        cursor = self.collection.aggregate(pipeline)
-        results = await cursor.to_list(None)
-        if not results:
-            return False
-
-        # results[0]["reachableNodes"] are the immediate neighbors
-        reachable_ids = [d["_id"] for d in results[0].get("reachableNodes", [])]
-        return target_node_id in reachable_ids
+        # Direct lookup in the edge collection; the graph is treated as
+        # undirected, so check both orientations of the pair (matching get_edge).
+        doc = await self.edge_collection.find_one(
+            {
+                "$or": [
+                    {
+                        "source_node_id": source_node_id,
+                        "target_node_id": target_node_id,
+                    },
+                    {
+                        "source_node_id": target_node_id,
+                        "target_node_id": source_node_id,
+                    },
+                ]
+            },
+            {"_id": 1},
+        )
+        return doc is not None
 
     #
     # -------------------------------------------------------------------------
@@ -461,85 +435,25 @@ class MongoGraphStorage(BaseGraphStorage):
 
     async def node_degree(self, node_id: str) -> int:
         """
         Returns the total number of edges connected to node_id (both inbound and outbound).
-        The easiest approach is typically two queries:
-         - count of edges array in node_id's doc
-         - count of how many other docs have node_id in their edges.target.
-
-        But we'll do a $graphLookup demonstration for inbound edges:
-        1) Outbound edges: direct from node's edges array
-        2) Inbound edges: we can do a special $graphLookup from all docs
-           or do an explicit match.
-
-        For demonstration, let's do this in two steps (with second step $graphLookup).
         """
-        # --- 1) Outbound edges (direct from doc) ---
-        doc = await self.collection.find_one({"_id": node_id}, {"edges": 1})
-        if not doc:
-            return 0
-        outbound_count = len(doc.get("edges", []))
-
-        # --- 2) Inbound edges:
-        #     A simple way is: find all docs where "edges.target" == node_id.
-        #     But let's do a $graphLookup from `node_id` in REVERSE.
-        #     There's a trick to do "reverse" graphLookups: you'd store
-        #     reversed edges or do a more advanced pipeline. Typically you'd do
-        #     a direct match. We'll just do a direct match for inbound.
-        inbound_count_pipeline = [
-            {"$match": {"edges.target": node_id}},
-            {
-                "$project": {
-                    "matchingEdgesCount": {
-                        "$size": {
-                            "$filter": {
-                                "input": "$edges",
-                                "as": "edge",
-                                "cond": {"$eq": ["$$edge.target", node_id]},
-                            }
-                        }
-                    }
-                }
-            },
-            {"$group": {"_id": None, "totalInbound": {"$sum": "$matchingEdgesCount"}}},
-        ]
-        inbound_cursor = self.collection.aggregate(inbound_count_pipeline)
-        inbound_result = await inbound_cursor.to_list(None)
-        inbound_count = inbound_result[0]["totalInbound"] if inbound_result else 0
-
-        return outbound_count + inbound_count
+        return await self.edge_collection.count_documents(
+            {"$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]}
+        )
 
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        """
-        If your graph can hold multiple edges from the same src to the same tgt
-        (e.g. different 'relation' values), you can sum them. If it's always
-        one edge, this is typically 1 or 0.
+        """Get the combined degree of two nodes (the sum of their individual degrees).
 
-        We'll do a single-hop $graphLookup from src_id,
-        then count how many edges reference tgt_id at depth=0.
-        """
-        pipeline = [
-            {"$match": {"_id": src_id}},
-            {
-                "$graphLookup": {
-                    "from": self.collection.name,
-                    "startWith": "$edges.target",
-                    "connectFromField": "edges.target",
-                    "connectToField": "_id",
-                    "as": "neighbors",
-                    "depthField": "depth",
-                    "maxDepth": 0,
-                }
-            },
-            {"$project": {"edges": 1, "neighbors._id": 1, "neighbors.type": 1}},
-        ]
-        cursor = self.collection.aggregate(pipeline)
-        results = await cursor.to_list(None)
-        if not results:
-            return 0
+        Args:
+            src_id: Label of the source node
+            tgt_id: Label of the target node
 
-        # We can simply count how many edges in `results[0].edges` have target == tgt_id.
-        edges = results[0].get("edges", [])
-        count = sum(1 for e in edges if e.get("target") == tgt_id)
-        return count
+        Returns:
+            int: Sum of the degrees of both nodes
+        """
+        src_degree = await self.node_degree(src_id)
+        tgt_degree = await self.node_degree(tgt_id)
+
+        return src_degree + tgt_degree
 
     #
     # -------------------------------------------------------------------------
@@ -549,65 +463,167 @@ class MongoGraphStorage(BaseGraphStorage):
 
     async def get_node(self, node_id: str) -> dict[str, str] | None:
         """
-        Return the full node document (including "edges"), or None if missing.
+        Return the full node document, or None if missing.
         """
         return await self.collection.find_one({"_id": node_id})
 
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
-        pipeline = [
-            {"$match": {"_id": source_node_id}},
+        return await self.edge_collection.find_one(
             {
-                "$graphLookup": {
-                    "from": self.collection.name,
-                    "startWith": "$edges.target",
-                    "connectFromField": "edges.target",
-                    "connectToField": "_id",
-                    "as": "neighbors",
-                    "depthField": "depth",
-                    "maxDepth": 0,
-                }
-            },
-            {"$project": {"edges": 1}},
-        ]
-        cursor = self.collection.aggregate(pipeline)
-        docs = await cursor.to_list(None)
-        if not docs:
-            return None
-
-        for e in docs[0].get("edges", []):
-            if e.get("target") == target_node_id:
-                return e
-        return None
+                "$or": [
+                    {
+                        "source_node_id": source_node_id,
+                        "target_node_id": target_node_id,
+                    },
+                    {
+                        "source_node_id": target_node_id,
+                        "target_node_id": source_node_id,
+                    },
+                ]
+            }
+        )
 
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         """
-        Return a list of (source_id, target_id) for direct edges from source_node_id.
-        Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
-        """
-        pipeline = [
-            {"$match": {"_id": source_node_id}},
-            {
-                "$graphLookup": {
-                    "from": self.collection.name,
-                    "startWith": "$edges.target",
-                    "connectFromField": "edges.target",
-                    "connectToField": "_id",
-                    "as": "neighbors",
-                    "depthField": "depth",
-                    "maxDepth": 0,
-                }
-            },
-            {"$project": {"_id": 0, "edges": 1}},
-        ]
-        cursor = self.collection.aggregate(pipeline)
-        result = await cursor.to_list(None)
-        if not result:
-            return None
+        Retrieves all edges (relationships) for a particular node identified by its label.
-        edges = result[0].get("edges", [])
-        return [(source_node_id, e["target"]) for e in edges]
+
+        Args:
+            source_node_id: Label of the node to get edges for
+
+        Returns:
+            list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
+            An empty list if the node has no edges
+        """
+        cursor = self.edge_collection.find(
+            {
+                "$or": [
+                    {"source_node_id": source_node_id},
+                    {"target_node_id": source_node_id},
+                ]
+            },
+            {"source_node_id": 1, "target_node_id": 1},
+        )
+
+        return [
+            (e.get("source_node_id"), e.get("target_node_id")) async for e in cursor
+        ]
+
+    async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
+        result = {}
+
+        async for doc in self.collection.find({"_id": {"$in": node_ids}}):
+            result[doc.get("_id")] = doc
+        return result
+
+    async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
+        # Every requested node starts at degree 0; outbound and inbound edge
+        # counts grouped by "_id" are then summed into the same entry.
+        merged_results = {node_id: 0 for node_id in node_ids}
+
+        # Outbound degrees
+        outbound_pipeline = [
+            {"$match": {"source_node_id": {"$in": node_ids}}},
+            {"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}},
+        ]
+
+        cursor = await self.edge_collection.aggregate(
+            outbound_pipeline, allowDiskUse=True
+        )
+        async for doc in cursor:
+            merged_results[doc.get("_id")] += doc.get("degree")
+
+        # Inbound degrees
+        inbound_pipeline = [
+            {"$match": {"target_node_id": {"$in": node_ids}}},
+            {"$group": {"_id": "$target_node_id", "degree": {"$sum": 1}}},
+        ]
+
+        cursor = await self.edge_collection.aggregate(
+            inbound_pipeline, allowDiskUse=True
+        )
+        async for doc in cursor:
+            merged_results[doc.get("_id")] += doc.get("degree")
+
+        return merged_results
+
+    async def get_nodes_edges_batch(
+        self, node_ids: list[str]
+    ) -> dict[str, list[tuple[str, str]]]:
+        """
+        Batch retrieve edges for multiple nodes.
+        For each node, returns both outgoing and incoming edges to properly represent
+        the undirected graph nature.
+
+        Args:
+            node_ids: List of node IDs (entity_id) for which to retrieve edges.
+
+        Returns:
+            A dictionary mapping each node ID to its list of edge tuples (source, target).
+            For each node, the list includes both:
+            - Outgoing edges: (queried_node, connected_node)
+            - Incoming edges: (connected_node, queried_node)
+        """
+        result = {node_id: [] for node_id in node_ids}
+
+        # Query outgoing edges (where node is the source)
+        outgoing_cursor = self.edge_collection.find(
+            {"source_node_id": {"$in": node_ids}},
+            {"source_node_id": 1, "target_node_id": 1},
+        )
+        async for edge in outgoing_cursor:
+            source = edge["source_node_id"]
+            target = edge["target_node_id"]
+            result[source].append((source, target))
+
+        # Query incoming edges (where node is the target)
+        incoming_cursor = self.edge_collection.find(
+            {"target_node_id": {"$in": node_ids}},
+            {"source_node_id": 1, "target_node_id": 1},
+        )
+        async for edge in incoming_cursor:
+            source = edge["source_node_id"]
+            target = edge["target_node_id"]
+            result[target].append((source, target))
+
+        return result
+
+    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        """Get all nodes that are associated with the given chunk_ids.
+
+        Args:
+            chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
+
+        Returns:
+            list[dict]: A list of nodes, where each node is a dictionary of its properties.
+                An empty list if no matching nodes are found.
+ """ + if not chunk_ids: + return [] + + cursor = self.collection.find({"source_ids": {"$in": chunk_ids}}) + return [doc async for doc in cursor] + + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all edges that are associated with the given chunk_ids. + + Args: + chunk_ids (list[str]): A list of chunk IDs to find associated edges for. + + Returns: + list[dict]: A list of edges, where each edge is a dictionary of its properties. + An empty list if no matching edges are found. + """ + if not chunk_ids: + return [] + + cursor = self.edge_collection.find({"source_ids": {"$in": chunk_ids}}) + + edges = [] + async for edge in cursor: + edge["source"] = edge["source_node_id"] + edge["target"] = edge["target_node_id"] + edges.append(edge) + + return edges # # ------------------------------------------------------------------------- @@ -617,11 +633,14 @@ class MongoGraphStorage(BaseGraphStorage): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ - Insert or update a node document. If new, create an empty edges array. + Insert or update a node document. """ - # By default, preserve existing 'edges'. - # We'll only set 'edges' to [] on insert (no overwrite). - update_doc = {"$set": {**node_data}, "$setOnInsert": {"edges": []}} + update_doc = {"$set": {**node_data}} + if node_data.get("source_id", ""): + update_doc["$set"]["source_ids"] = node_data["source_id"].split( + GRAPH_FIELD_SEP + ) + await self.collection.update_one({"_id": node_id}, update_doc, upsert=True) async def upsert_edge( @@ -634,16 +653,16 @@ class MongoGraphStorage(BaseGraphStorage): # Ensure source node exists await self.upsert_node(source_node_id, {}) - # Remove existing edge (if any) - await self.collection.update_one( - {"_id": source_node_id}, {"$pull": {"edges": {"target": target_node_id}}} - ) + update_doc = {"$set": edge_data} + if edge_data.get("source_id", ""): + update_doc["$set"]["source_ids"] = edge_data["source_id"].split( + GRAPH_FIELD_SEP + ) - # Insert new edge - new_edge = {"target": target_node_id} - new_edge.update(edge_data) - await self.collection.update_one( - {"_id": source_node_id}, {"$push": {"edges": new_edge}} + await self.edge_collection.update_one( + {"source_node_id": source_node_id, "target_node_id": target_node_id}, + update_doc, + upsert=True, ) # @@ -657,8 +676,10 @@ class MongoGraphStorage(BaseGraphStorage): 1) Remove node's doc entirely. 2) Remove inbound edges from any doc that references node_id. """ - # Remove inbound edges from all other docs - await self.collection.update_many({}, {"$pull": {"edges": {"target": node_id}}}) + # Remove all edges + await self.edge_collection.delete_many( + {"$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]} + ) # Remove the node doc await self.collection.delete_one({"_id": node_id}) @@ -675,20 +696,18 @@ class MongoGraphStorage(BaseGraphStorage): Returns: [id1, id2, ...] 
        """
-        # Use MongoDB's distinct and aggregation to get all unique labels
-        pipeline = [
-            {"$group": {"_id": "$_id"}},  # Group by _id
-            {"$sort": {"_id": 1}},  # Sort alphabetically
-        ]
-        cursor = self.collection.aggregate(pipeline)
+        cursor = self.collection.find({}, projection={"_id": 1}, sort=[("_id", 1)])
         labels = []
         async for doc in cursor:
             labels.append(doc["_id"])
         return labels
 
     async def get_knowledge_graph(
-        self, node_label: str, max_depth: int = 5
+        self,
+        node_label: str,
+        max_depth: int = 5,
+        max_nodes: int = MAX_GRAPH_NODES,
     ) -> KnowledgeGraph:
         """
         Get complete connected subgraph for specified node (including the starting node itself)
@@ -704,151 +723,92 @@ class MongoGraphStorage(BaseGraphStorage):
         result = KnowledgeGraph()
         seen_nodes = set()
         seen_edges = set()
+        node_edges = []
 
         try:
-            if label == "*":
-                # Get all nodes and edges
-                async for node_doc in self.collection.find({}):
-                    node_id = str(node_doc["_id"])
-                    if node_id not in seen_nodes:
-                        result.nodes.append(
-                            KnowledgeGraphNode(
-                                id=node_id,
-                                labels=[node_doc.get("_id")],
-                                properties={
-                                    k: v
-                                    for k, v in node_doc.items()
-                                    if k not in ["_id", "edges"]
-                                },
-                            )
-                        )
-                        seen_nodes.add(node_id)
+            pipeline = [
+                {
+                    "$graphLookup": {
+                        "from": self._edge_collection_name,
+                        "startWith": "$_id",
+                        "connectFromField": "target_node_id",
+                        "connectToField": "source_node_id",
+                        "maxDepth": max_depth,
+                        "depthField": "depth",
+                        "as": "connected_edges",
+                    },
+                },
+                {"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
+                {"$sort": {"edge_count": -1}},
+                {"$limit": max_nodes},
+            ]
 
-                    # Process edges
-                    for edge in node_doc.get("edges", []):
-                        edge_id = f"{node_id}-{edge['target']}"
-                        if edge_id not in seen_edges:
-                            result.edges.append(
-                                KnowledgeGraphEdge(
-                                    id=edge_id,
-                                    type=edge.get("relation", ""),
-                                    source=node_id,
-                                    target=edge["target"],
-                                    properties={
-                                        k: v
-                                        for k, v in edge.items()
-                                        if k not in ["target", "relation"]
-                                    },
-                                )
-                            )
-                            seen_edges.add(edge_id)
+            if label == "*":
+                all_node_count = await self.collection.count_documents({})
+                result.is_truncated = all_node_count > max_nodes
             else:
                 # Verify if starting node exists
-                start_nodes = self.collection.find({"_id": label})
-                start_nodes_exist = await start_nodes.to_list(length=1)
-                if not start_nodes_exist:
+                start_node = await self.collection.find_one({"_id": label})
+                if not start_node:
                     logger.warning(f"Starting node with label {label} does not exist!")
                     return result
 
-                # Use $graphLookup for traversal
-                pipeline = [
-                    {
-                        "$match": {"_id": label}
-                    },  # Start with nodes having the specified label
-                    {
-                        "$graphLookup": {
-                            "from": self._collection_name,
-                            "startWith": "$edges.target",
-                            "connectFromField": "edges.target",
-                            "connectToField": "_id",
-                            "maxDepth": max_depth,
-                            "depthField": "depth",
-                            "as": "connected_nodes",
-                        }
-                    },
-                ]
+                # Anchor the pipeline at the starting node
+                pipeline.insert(0, {"$match": {"_id": label}})
 
-                async for doc in self.collection.aggregate(pipeline):
-                    # Add the start node
-                    node_id = str(doc["_id"])
-                    if node_id not in seen_nodes:
-                        result.nodes.append(
-                            KnowledgeGraphNode(
-                                id=node_id,
-                                labels=[
-                                    doc.get(
-                                        "_id",
-                                    )
-                                ],
-                                properties={
-                                    k: v
-                                    for k, v in doc.items()
-                                    if k
-                                    not in [
-                                        "_id",
-                                        "edges",
-                                        "connected_nodes",
-                                        "depth",
-                                    ]
-                                },
-                            )
+            cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
+            async for doc in cursor:
+                # Add each node returned by the traversal
+                node_id = str(doc["_id"])
+                result.nodes.append(
+                    KnowledgeGraphNode(
+                        id=node_id,
+                        labels=[node_id],
+                        properties={
k: v + for k, v in doc.items() + if k + not in [ + "_id", + "connected_edges", + "edge_count", + ] + }, + ) + ) + seen_nodes.add(node_id) + if doc.get("connected_edges", []): + node_edges.extend(doc.get("connected_edges")) + + for edge in node_edges: + if ( + edge["source_node_id"] not in seen_nodes + or edge["target_node_id"] not in seen_nodes + ): + continue + + edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relationship", ""), + source=edge["source_node_id"], + target=edge["target_node_id"], + properties={ + k: v + for k, v in edge.items() + if k + not in [ + "_id", + "source_node_id", + "target_node_id", + "relationship", + ] + }, ) - seen_nodes.add(node_id) - - # Add edges from start node - for edge in doc.get("edges", []): - edge_id = f"{node_id}-{edge['target']}" - if edge_id not in seen_edges: - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type=edge.get("relation", ""), - source=node_id, - target=edge["target"], - properties={ - k: v - for k, v in edge.items() - if k not in ["target", "relation"] - }, - ) - ) - seen_edges.add(edge_id) - - # Add connected nodes and their edges - for connected in doc.get("connected_nodes", []): - node_id = str(connected["_id"]) - if node_id not in seen_nodes: - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[connected.get("_id")], - properties={ - k: v - for k, v in connected.items() - if k not in ["_id", "edges", "depth"] - }, - ) - ) - seen_nodes.add(node_id) - - # Add edges from connected nodes - for edge in connected.get("edges", []): - edge_id = f"{node_id}-{edge['target']}" - if edge_id not in seen_edges: - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type=edge.get("relation", ""), - source=node_id, - target=edge["target"], - properties={ - k: v - for k, v in edge.items() - if k not in ["target", "relation"] - }, - ) - ) - seen_edges.add(edge_id) + ) + seen_edges.add(edge_id) logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" @@ -873,9 +833,14 @@ class MongoGraphStorage(BaseGraphStorage): if not nodes: return - # 1. Remove all edges referencing these nodes (remove from edges array of other nodes) - await self.collection.update_many( - {}, {"$pull": {"edges": {"target": {"$in": nodes}}}} + # 1. Remove all edges referencing these nodes + await self.edge_collection.delete_many( + { + "$or": [ + {"source_node_id": {"$in": nodes}}, + {"target_node_id": {"$in": nodes}}, + ] + } ) # 2. 
Delete the node documents @@ -893,17 +858,16 @@ class MongoGraphStorage(BaseGraphStorage): if not edges: return - update_tasks = [] - for source, target in edges: - # Remove edge pointing to target from source node's edges array - update_tasks.append( - self.collection.update_one( - {"_id": source}, {"$pull": {"edges": {"target": target}}} - ) + all_edge_pairs = [] + for source_id, target_id in edges: + all_edge_pairs.append( + {"source_node_id": source_id, "target_node_id": target_id} + ) + all_edge_pairs.append( + {"source_node_id": target_id, "target_node_id": source_id} ) - if update_tasks: - await asyncio.gather(*update_tasks) + await self.edge_collection.delete_many({"$or": all_edge_pairs}) logger.debug(f"Successfully deleted edges: {edges}") @@ -920,9 +884,16 @@ class MongoGraphStorage(BaseGraphStorage): logger.info( f"Dropped {deleted_count} documents from graph {self._collection_name}" ) + + result = await self.edge_collection.delete_many({}) + edge_count = result.deleted_count + logger.info( + f"Dropped {edge_count} edges from graph {self._edge_collection_name}" + ) + return { "status": "success", - "message": f"{deleted_count} documents dropped", + "message": f"{deleted_count} documents and {edge_count} edges dropped", } except PyMongoError as e: logger.error(f"Error dropping graph {self._collection_name}: {e}") @@ -932,8 +903,8 @@ class MongoGraphStorage(BaseGraphStorage): @final @dataclass class MongoVectorDBStorage(BaseVectorStorage): - db: AsyncIOMotorDatabase | None = field(default=None) - _data: AsyncIOMotorCollection | None = field(default=None) + db: AsyncDatabase | None = field(default=None) + _data: AsyncCollection | None = field(default=None) def __post_init__(self): kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) @@ -967,7 +938,8 @@ class MongoVectorDBStorage(BaseVectorStorage): try: index_name = "vector_knn_index" - indexes = await self._data.list_search_indexes().to_list(length=None) + indexes_cursor = await self._data.list_search_indexes() + indexes = await indexes_cursor.to_list(length=None) for index in indexes: if index["name"] == index_name: logger.debug("vector index already exist") @@ -1062,8 +1034,8 @@ class MongoVectorDBStorage(BaseVectorStorage): ] # Execute the aggregation pipeline - cursor = self._data.aggregate(pipeline) - results = await cursor.to_list() + cursor = await self._data.aggregate(pipeline, allowDiskUse=True) + results = await cursor.to_list(length=None) # Format and return the results with created_at field return [ @@ -1090,6 +1062,10 @@ class MongoVectorDBStorage(BaseVectorStorage): if not ids: return + # Convert to list if it's a set (MongoDB BSON cannot encode sets) + if isinstance(ids, set): + ids = list(ids) + try: result = await self._data.delete_many({"_id": {"$in": ids}}) logger.debug( @@ -1232,7 +1208,7 @@ class MongoVectorDBStorage(BaseVectorStorage): return {"status": "error", "message": str(e)} -async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str): +async def get_or_create_collection(db: AsyncDatabase, collection_name: str): collection_names = await db.list_collection_names() if collection_name not in collection_names: diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py index fb78270d..64e66f48 100644 --- a/tests/test_graph_storage.py +++ b/tests/test_graph_storage.py @@ -30,6 +30,7 @@ from lightrag.kg import ( verify_storage_implementation, ) from lightrag.kg.shared_storage import initialize_share_data +from lightrag.constants import GRAPH_FIELD_SEP # 
模拟的嵌入函数,返回随机向量
@@ -437,6 +438,9 @@ async def test_graph_batch_operations(storage):
     5. 使用 get_nodes_edges_batch 批量获取多个节点的所有边
     """
     try:
+        chunk1_id = "1"
+        chunk2_id = "2"
+        chunk3_id = "3"
         # 1. 插入测试数据
         # 插入节点1: 人工智能
         node1_id = "人工智能"
@@ -445,6 +449,7 @@ async def test_graph_batch_operations(storage):
             "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
             "keywords": "AI,机器学习,深度学习",
             "entity_type": "技术领域",
+            "source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]),
         }
         print(f"插入节点1: {node1_id}")
         await storage.upsert_node(node1_id, node1_data)
@@ -456,6 +461,7 @@ async def test_graph_batch_operations(storage):
             "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
             "keywords": "监督学习,无监督学习,强化学习",
             "entity_type": "技术领域",
+            "source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]),
         }
         print(f"插入节点2: {node2_id}")
         await storage.upsert_node(node2_id, node2_data)
@@ -467,6 +473,7 @@ async def test_graph_batch_operations(storage):
             "description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
             "keywords": "神经网络,CNN,RNN",
             "entity_type": "技术领域",
+            "source_id": GRAPH_FIELD_SEP.join([chunk3_id]),
         }
         print(f"插入节点3: {node3_id}")
         await storage.upsert_node(node3_id, node3_data)
@@ -498,6 +505,7 @@ async def test_graph_batch_operations(storage):
             "relationship": "包含",
             "weight": 1.0,
             "description": "人工智能领域包含机器学习这个子领域",
+            "source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]),
         }
         print(f"插入边1: {node1_id} -> {node2_id}")
         await storage.upsert_edge(node1_id, node2_id, edge1_data)
@@ -507,6 +515,7 @@ async def test_graph_batch_operations(storage):
             "relationship": "包含",
             "weight": 1.0,
             "description": "机器学习领域包含深度学习这个子领域",
+            "source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]),
         }
         print(f"插入边2: {node2_id} -> {node3_id}")
         await storage.upsert_edge(node2_id, node3_id, edge2_data)
@@ -516,6 +525,7 @@ async def test_graph_batch_operations(storage):
             "relationship": "包含",
             "weight": 1.0,
             "description": "人工智能领域包含自然语言处理这个子领域",
+            "source_id": GRAPH_FIELD_SEP.join([chunk3_id]),
         }
         print(f"插入边3: {node1_id} -> {node4_id}")
         await storage.upsert_edge(node1_id, node4_id, edge3_data)
@@ -748,6 +758,76 @@ async def test_graph_batch_operations(storage):
         print("无向图特性验证成功:批量获取的节点边包含所有相关的边(无论方向)")
 
+        # 7. Test get_nodes_by_chunk_ids: batch-fetch nodes by chunk_ids
+        print("== Testing get_nodes_by_chunk_ids")
+
+        print("== Single chunk_id matching multiple nodes")
+        nodes = await storage.get_nodes_by_chunk_ids([chunk2_id])
+        assert len(nodes) == 2, f"{chunk2_id} should match 2 nodes, got {len(nodes)}"
+
+        has_node1 = any(node["entity_id"] == node1_id for node in nodes)
+        has_node2 = any(node["entity_id"] == node2_id for node in nodes)
+
+        assert has_node1, f"Node {node1_id} should be in the results"
+        assert has_node2, f"Node {node2_id} should be in the results"
+
+        print("== Multiple chunk_ids, each partially matching nodes")
+        nodes = await storage.get_nodes_by_chunk_ids([chunk2_id, chunk3_id])
+        assert (
+            len(nodes) == 3
+        ), f"{chunk2_id}, {chunk3_id} should match 3 nodes, got {len(nodes)}"
+
+        has_node1 = any(node["entity_id"] == node1_id for node in nodes)
+        has_node2 = any(node["entity_id"] == node2_id for node in nodes)
+        has_node3 = any(node["entity_id"] == node3_id for node in nodes)
+
+        assert has_node1, f"Node {node1_id} should be in the results"
+        assert has_node2, f"Node {node2_id} should be in the results"
+        assert has_node3, f"Node {node3_id} should be in the results"
+
+        # 8. Test get_edges_by_chunk_ids: batch-fetch edges by chunk_ids
+        print("== Testing get_edges_by_chunk_ids")
+
+        print("== Single chunk_id matching multiple edges")
+        edges = await storage.get_edges_by_chunk_ids([chunk2_id])
+        assert len(edges) == 2, f"{chunk2_id} should match 2 edges, got {len(edges)}"
+
+        has_edge_node1_node2 = any(
+            edge["source"] == node1_id and edge["target"] == node2_id for edge in edges
+        )
+        has_edge_node2_node3 = any(
+            edge["source"] == node2_id and edge["target"] == node3_id for edge in edges
+        )
+
+        assert has_edge_node1_node2, f"{chunk2_id} should contain the edge {node1_id} -> {node2_id}"
+        assert has_edge_node2_node3, f"{chunk2_id} should contain the edge {node2_id} -> {node3_id}"
+
+        print("== Multiple chunk_ids, each partially matching edges")
+        edges = await storage.get_edges_by_chunk_ids([chunk2_id, chunk3_id])
+        assert (
+            len(edges) == 3
+        ), f"{chunk2_id}, {chunk3_id} should match 3 edges, got {len(edges)}"
+
+        has_edge_node1_node2 = any(
+            edge["source"] == node1_id and edge["target"] == node2_id for edge in edges
+        )
+        has_edge_node2_node3 = any(
+            edge["source"] == node2_id and edge["target"] == node3_id for edge in edges
+        )
+        has_edge_node1_node4 = any(
+            edge["source"] == node1_id and edge["target"] == node4_id for edge in edges
+        )
+
+        assert (
+            has_edge_node1_node2
+        ), f"{chunk2_id}, {chunk3_id} should contain the edge {node1_id} -> {node2_id}"
+        assert (
+            has_edge_node2_node3
+        ), f"{chunk2_id}, {chunk3_id} should contain the edge {node2_id} -> {node3_id}"
+        assert (
+            has_edge_node1_node4
+        ), f"{chunk2_id}, {chunk3_id} should contain the edge {node1_id} -> {node4_id}"
+
         print("\n批量操作测试完成")
         return True
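As a usage illustration of the node/edge layout this diff introduces, the following minimal sketch exercises the same $graphLookup shape that get_knowledge_graph runs, but as a standalone script. It is not part of the patch, and it assumes a MongoDB instance at mongodb://localhost:27017, PyMongo 4.9+ (which ships AsyncMongoClient), and made-up demo names (database "lightrag_demo", collections "kg_demo" / "kg_demo_edges" following the f"{name}_edges" convention). Note that the traversal only follows edges in the source_node_id -> target_node_id direction, which is why the storage methods above also match on target_node_id for degree and neighbor queries.

import asyncio

from pymongo import AsyncMongoClient  # PyMongo's asynchronous API (pymongo >= 4.9)


async def main():
    client = AsyncMongoClient("mongodb://localhost:27017")
    db = client.get_database("lightrag_demo")  # hypothetical demo database
    nodes = db["kg_demo"]  # node collection: _id is the entity name
    edges = db["kg_demo_edges"]  # edge collection, f"{name}_edges" convention

    # Seed two nodes and one edge in the schema documented in the diff
    await nodes.delete_many({})
    await edges.delete_many({})
    await nodes.insert_many(
        [
            {"_id": "CompanyA", "entity_id": "CompanyA", "entity_type": "Organization"},
            {"_id": "ProductX", "entity_id": "ProductX", "entity_type": "Product"},
        ]
    )
    await edges.insert_one(
        {
            "source_node_id": "CompanyA",
            "target_node_id": "ProductX",
            "relationship": "Develops",
            "weight": 1.0,
        }
    )

    # Same pipeline shape as get_knowledge_graph: starting from the node's _id,
    # walk the edge collection along source_node_id -> target_node_id links.
    pipeline = [
        {"$match": {"_id": "CompanyA"}},
        {
            "$graphLookup": {
                "from": "kg_demo_edges",
                "startWith": "$_id",
                "connectFromField": "target_node_id",
                "connectToField": "source_node_id",
                "maxDepth": 2,
                "depthField": "depth",
                "as": "connected_edges",
            }
        },
        {"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
    ]
    cursor = await nodes.aggregate(pipeline)  # aggregate() is awaited in the async API
    async for doc in cursor:
        print(doc["_id"], "edge_count:", doc["edge_count"])
        for e in doc["connected_edges"]:
            print("   ", e["source_node_id"], "->", e["target_node_id"], e["relationship"])

    await client.close()


asyncio.run(main())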