Add missing methods for MongoGraphStorage

This commit is contained in:
Ken Chen 2025-06-15 21:22:32 +08:00
parent 1389265695
commit cf441aa84c

View file

@ -22,27 +22,25 @@ import pipmaster as pm
if not pm.is_installed("pymongo"): if not pm.is_installed("pymongo"):
pm.install("pymongo") pm.install("pymongo")
if not pm.is_installed("motor"): from pymongo import AsyncMongoClient # type: ignore
pm.install("motor") from pymongo.asynchronous.database import AsyncDatabase # type: ignore
from pymongo.asynchronous.collection import AsyncCollection # type: ignore
from motor.motor_asyncio import ( # type: ignore
AsyncIOMotorClient,
AsyncIOMotorDatabase,
AsyncIOMotorCollection,
)
from pymongo.operations import SearchIndexModel # type: ignore from pymongo.operations import SearchIndexModel # type: ignore
from pymongo.errors import PyMongoError # type: ignore from pymongo.errors import PyMongoError # type: ignore
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
class ClientManager: class ClientManager:
_instances = {"db": None, "ref_count": 0} _instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock() _lock = asyncio.Lock()
@classmethod @classmethod
async def get_client(cls) -> AsyncIOMotorDatabase: async def get_client(cls) -> AsyncMongoClient:
async with cls._lock: async with cls._lock:
if cls._instances["db"] is None: if cls._instances["db"] is None:
uri = os.environ.get( uri = os.environ.get(
@ -57,7 +55,7 @@ class ClientManager:
"MONGO_DATABASE", "MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"), config.get("mongodb", "database", fallback="LightRAG"),
) )
client = AsyncIOMotorClient(uri) client = AsyncMongoClient(uri)
db = client.get_database(database_name) db = client.get_database(database_name)
cls._instances["db"] = db cls._instances["db"] = db
cls._instances["ref_count"] = 0 cls._instances["ref_count"] = 0
@ -65,7 +63,7 @@ class ClientManager:
return cls._instances["db"] return cls._instances["db"]
@classmethod @classmethod
async def release_client(cls, db: AsyncIOMotorDatabase): async def release_client(cls, db: AsyncDatabase):
async with cls._lock: async with cls._lock:
if db is not None: if db is not None:
if db is cls._instances["db"]: if db is cls._instances["db"]:
@ -77,8 +75,8 @@ class ClientManager:
@final @final
@dataclass @dataclass
class MongoKVStorage(BaseKVStorage): class MongoKVStorage(BaseKVStorage):
db: AsyncIOMotorDatabase = field(default=None) db: AsyncDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(default=None) _data: AsyncCollection = field(default=None)
def __post_init__(self): def __post_init__(self):
self._collection_name = self.namespace self._collection_name = self.namespace
@ -214,8 +212,8 @@ class MongoKVStorage(BaseKVStorage):
@final @final
@dataclass @dataclass
class MongoDocStatusStorage(DocStatusStorage): class MongoDocStatusStorage(DocStatusStorage):
db: AsyncIOMotorDatabase = field(default=None) db: AsyncDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(default=None) _data: AsyncCollection = field(default=None)
def __post_init__(self): def __post_init__(self):
self._collection_name = self.namespace self._collection_name = self.namespace
@ -311,6 +309,9 @@ class MongoDocStatusStorage(DocStatusStorage):
logger.error(f"Error dropping doc status {self._collection_name}: {e}") logger.error(f"Error dropping doc status {self._collection_name}: {e}")
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def delete(self, ids: list[str]) -> None:
await self._data.delete_many({"_id": {"$in": ids}})
@final @final
@dataclass @dataclass
@ -319,8 +320,8 @@ class MongoGraphStorage(BaseGraphStorage):
A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries. A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
""" """
db: AsyncIOMotorDatabase = field(default=None) db: AsyncDatabase = field(default=None)
collection: AsyncIOMotorCollection = field(default=None) collection: AsyncCollection = field(default=None)
def __init__(self, namespace, global_config, embedding_func): def __init__(self, namespace, global_config, embedding_func):
super().__init__( super().__init__(
@ -350,6 +351,29 @@ class MongoGraphStorage(BaseGraphStorage):
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# #
# Sample entity_relation document
# {
# "_id" : "CompanyA",
# "created_at" : 1749904575,
# "description" : "A major technology company",
# "edges" : [
# {
# "target" : "ProductX",
# "relation": "Develops", // To distinguish multiple same-target relations
# "weight" : Double("1"),
# "description" : "CompanyA develops ProductX",
# "keywords" : "develop, produce",
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
# "file_path" : "custom_kg",
# "created_at" : 1749904575
# }
# ],
# "entity_id" : "CompanyA",
# "entity_type" : "Organization",
# "file_path" : "custom_kg",
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec"
# }
async def _graph_lookup( async def _graph_lookup(
self, start_node_id: str, max_depth: int = None self, start_node_id: str, max_depth: int = None
) -> List[dict]: ) -> List[dict]:
@ -388,7 +412,7 @@ class MongoGraphStorage(BaseGraphStorage):
pipeline[1]["$graphLookup"]["maxDepth"] = max_depth pipeline[1]["$graphLookup"]["maxDepth"] = max_depth
# Return the matching doc plus a field "reachableNodes" # Return the matching doc plus a field "reachableNodes"
cursor = self.collection.aggregate(pipeline) cursor = await self.collection.aggregate(pipeline)
results = await cursor.to_list(None) results = await cursor.to_list(None)
# If there's no matching node, results = []. # If there's no matching node, results = [].
@ -413,44 +437,12 @@ class MongoGraphStorage(BaseGraphStorage):
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
""" """
Check if there's a direct single-hop edge from source_node_id to target_node_id. Check if there's a direct single-hop edge from source_node_id to target_node_id.
We'll do a $graphLookup with maxDepth=0 from the source node—meaning
"Look up zero expansions." Actually, for a direct edge check, we can do maxDepth=1
and then see if the target node is in the "reachableNodes" at depth=0.
But typically for a direct edge, we might just do a find_one.
Below is a demonstration approach.
""" """
# We can do a single-hop graphLookup (maxDepth=0 or 1). # Direct check if the target_node appears among the edges array.
# Then check if the target_node appears among the edges array. doc = await self.collection.find_one(
pipeline = [ {"_id": source_node_id, "edges.target": target_node_id}, {"_id": 1}
{"$match": {"_id": source_node_id}}, )
{ return doc is not None
"$graphLookup": {
"from": self.collection.name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"as": "reachableNodes",
"depthField": "depth",
"maxDepth": 0, # means: do not follow beyond immediate edges
}
},
{
"$project": {
"_id": 0,
"reachableNodes._id": 1, # only keep the _id from the subdocs
}
},
]
cursor = self.collection.aggregate(pipeline)
results = await cursor.to_list(None)
if not results:
return False
# results[0]["reachableNodes"] are the immediate neighbors
reachable_ids = [d["_id"] for d in results[0].get("reachableNodes", [])]
return target_node_id in reachable_ids
# #
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -464,82 +456,38 @@ class MongoGraphStorage(BaseGraphStorage):
The easiest approach is typically two queries: The easiest approach is typically two queries:
- count of edges array in node_id's doc - count of edges array in node_id's doc
- count of how many other docs have node_id in their edges.target. - count of how many other docs have node_id in their edges.target.
But we'll do a $graphLookup demonstration for inbound edges:
1) Outbound edges: direct from node's edges array
2) Inbound edges: we can do a special $graphLookup from all docs
or do an explicit match.
For demonstration, let's do this in two steps (with second step $graphLookup).
""" """
# --- 1) Outbound edges (direct from doc) --- # --- 1) Outbound edges (direct from doc) ---
doc = await self.collection.find_one({"_id": node_id}, {"edges": 1}) doc = await self.collection.find_one({"_id": node_id}, {"edges.target": 1})
if not doc: if not doc:
return 0 return 0
outbound_count = len(doc.get("edges", [])) outbound_count = len(doc.get("edges", []))
# --- 2) Inbound edges: # --- 2) Inbound edges:
# A simple way is: find all docs where "edges.target" == node_id. inbound_count = await self.collection.count_documents({"edges.target": node_id})
# But let's do a $graphLookup from `node_id` in REVERSE.
# There's a trick to do "reverse" graphLookups: you'd store
# reversed edges or do a more advanced pipeline. Typically you'd do
# a direct match. We'll just do a direct match for inbound.
inbound_count_pipeline = [
{"$match": {"edges.target": node_id}},
{
"$project": {
"matchingEdgesCount": {
"$size": {
"$filter": {
"input": "$edges",
"as": "edge",
"cond": {"$eq": ["$$edge.target", node_id]},
}
}
}
}
},
{"$group": {"_id": None, "totalInbound": {"$sum": "$matchingEdgesCount"}}},
]
inbound_cursor = self.collection.aggregate(inbound_count_pipeline)
inbound_result = await inbound_cursor.to_list(None)
inbound_count = inbound_result[0]["totalInbound"] if inbound_result else 0
return outbound_count + inbound_count return outbound_count + inbound_count
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
""" """Get the total degree (sum of relationships) of two nodes.
If your graph can hold multiple edges from the same src to the same tgt
(e.g. different 'relation' values), you can sum them. If it's always
one edge, this is typically 1 or 0.
We'll do a single-hop $graphLookup from src_id, Args:
then count how many edges reference tgt_id at depth=0. src_id: Label of the source node
""" tgt_id: Label of the target node
pipeline = [
{"$match": {"_id": src_id}},
{
"$graphLookup": {
"from": self.collection.name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"as": "neighbors",
"depthField": "depth",
"maxDepth": 0,
}
},
{"$project": {"edges": 1, "neighbors._id": 1, "neighbors.type": 1}},
]
cursor = self.collection.aggregate(pipeline)
results = await cursor.to_list(None)
if not results:
return 0
# We can simply count how many edges in `results[0].edges` have target == tgt_id. Returns:
edges = results[0].get("edges", []) int: Sum of the degrees of both nodes
count = sum(1 for e in edges if e.get("target") == tgt_id) """
return count src_degree = await self.node_degree(src_id)
trg_degree = await self.node_degree(tgt_id)
# Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree)
return degrees
# #
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -556,58 +504,142 @@ class MongoGraphStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
pipeline = [ doc = await self.collection.find_one(
{"$match": {"_id": source_node_id}},
{ {
"$graphLookup": { "$or": [
"from": self.collection.name, {"_id": source_node_id, "edges.target": target_node_id},
"startWith": "$edges.target", {"_id": target_node_id, "edges.target": source_node_id},
"connectFromField": "edges.target", ]
"connectToField": "_id",
"as": "neighbors",
"depthField": "depth",
"maxDepth": 0,
}
}, },
{"$project": {"edges": 1}}, {"edges": 1},
] )
cursor = self.collection.aggregate(pipeline) if not doc:
docs = await cursor.to_list(None)
if not docs:
return None return None
for e in docs[0].get("edges", []): for e in doc.get("edges", []):
if e.get("target") == target_node_id: if e.get("target") == target_node_id:
return e return e
if e.get("target") == source_node_id:
return e
return None return None
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
""" """
Return a list of (source_id, target_id) for direct edges from source_node_id. Retrieves all edges (relationships) for a particular node identified by its label.
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
Args:
source_node_id: Label of the node to get edges for
Returns:
list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
None: If no edges found
""" """
pipeline = [ doc = await self.get_node(source_node_id)
{"$match": {"_id": source_node_id}}, if not doc:
{
"$graphLookup": {
"from": self.collection.name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"as": "neighbors",
"depthField": "depth",
"maxDepth": 0,
}
},
{"$project": {"_id": 0, "edges": 1}},
]
cursor = self.collection.aggregate(pipeline)
result = await cursor.to_list(None)
if not result:
return None return None
edges = result[0].get("edges", []) edges = []
return [(source_node_id, e["target"]) for e in edges] for e in doc.get("edges", []):
edges.append((source_node_id, e.get("target")))
cursor = self.collection.find({"edges.target": source_node_id})
source_docs = await cursor.to_list(None)
if not source_docs:
return edges
for doc in source_docs:
edges.append((doc.get("_id"), source_node_id))
return edges
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
result = {}
async for doc in self.collection.find({"_id": {"$in": node_ids}}):
result[doc.get("_id")] = doc
return result
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
# merge the outbound and inbound results with the same "_id" and sum the "degree"
merged_results = {}
# Outbound degrees
outbound_pipeline = [
{"$match": {"_id": {"$in": node_ids}}},
{"$project": {"_id": 1, "degree": {"$size": "$edges"}}},
]
cursor = await self.collection.aggregate(outbound_pipeline)
async for doc in cursor:
merged_results[doc.get("_id")] = doc.get("degree")
# Inbound degrees
inbound_pipeline = [
{"$match": {"edges.target": {"$in": node_ids}}},
{"$project": {"_id": 1, "edges.target": 1}},
{"$unwind": "$edges"},
{"$match": {"edges.target": {"$in": node_ids}}},
{
"$group": {
"_id": "$edges.target",
"degree": {
"$sum": 1
}, # Count the number of incoming edges for each target node
}
},
]
cursor = await self.collection.aggregate(inbound_pipeline)
async for doc in cursor:
merged_results[doc.get("_id")] = merged_results.get(
doc.get("_id"), 0
) + doc.get("degree")
return merged_results
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
"""
Batch retrieve edges for multiple nodes.
For each node, returns both outgoing and incoming edges to properly represent
the undirected graph nature.
Args:
node_ids: List of node IDs (entity_id) for which to retrieve edges.
Returns:
A dictionary mapping each node ID to its list of edge tuples (source, target).
For each node, the list includes both:
- Outgoing edges: (queried_node, connected_node)
- Incoming edges: (connected_node, queried_node)
"""
result = {}
cursor = self.collection.find(
{"_id": {"$in": node_ids}}, {"_id": 1, "edges.target": 1}
)
async for doc in cursor:
node_id = doc.get("_id")
edges = doc.get("edges", [])
result[node_id] = [(node_id, e["target"]) for e in edges]
inbound_pipeline = [
{"$match": {"edges.target": {"$in": node_ids}}},
{"$project": {"_id": 1, "edges.target": 1}},
{"$unwind": "$edges"},
{"$match": {"edges.target": {"$in": node_ids}}},
{"$project": {"_id": "$_id", "target": "$edges.target"}},
]
cursor = await self.collection.aggregate(inbound_pipeline)
async for doc in cursor:
node_id = doc.get("target")
result[node_id] = result.get(node_id, [])
result[node_id].append((doc.get("_id"), node_id))
return result
# #
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -675,20 +707,18 @@ class MongoGraphStorage(BaseGraphStorage):
Returns: Returns:
[id1, id2, ...] # Alphabetically sorted id list [id1, id2, ...] # Alphabetically sorted id list
""" """
# Use MongoDB's distinct and aggregation to get all unique labels
pipeline = [
{"$group": {"_id": "$_id"}}, # Group by _id
{"$sort": {"_id": 1}}, # Sort alphabetically
]
cursor = self.collection.aggregate(pipeline) cursor = self.collection.find({}, projection={"_id": 1}, sort=[("_id", 1)])
labels = [] labels = []
async for doc in cursor: async for doc in cursor:
labels.append(doc["_id"]) labels.append(doc["_id"])
return labels return labels
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self,
node_label: str,
max_depth: int = 5,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Get complete connected subgraph for specified node (including the starting node itself) Get complete connected subgraph for specified node (including the starting node itself)
@ -893,17 +923,25 @@ class MongoGraphStorage(BaseGraphStorage):
if not edges: if not edges:
return return
update_tasks = [] # Group edges by source node (document _id) for efficient updates
for source, target in edges: edge_groups = {}
# Remove edge pointing to target from source node's edges array for source_id, target_id in edges:
update_tasks.append( if source_id not in edge_groups:
edge_groups[source_id] = []
edge_groups[source_id].append(target_id)
# Bulk update documents to remove the edges
update_operations = []
for source_id, target_ids in edge_groups.items():
update_operations.append(
self.collection.update_one( self.collection.update_one(
{"_id": source}, {"$pull": {"edges": {"target": target}}} {"_id": source_id},
{"$pull": {"edges": {"target": {"$in": target_ids}}}},
) )
) )
if update_tasks: if update_operations:
await asyncio.gather(*update_tasks) await asyncio.gather(*update_operations)
logger.debug(f"Successfully deleted edges: {edges}") logger.debug(f"Successfully deleted edges: {edges}")
@ -932,8 +970,8 @@ class MongoGraphStorage(BaseGraphStorage):
@final @final
@dataclass @dataclass
class MongoVectorDBStorage(BaseVectorStorage): class MongoVectorDBStorage(BaseVectorStorage):
db: AsyncIOMotorDatabase | None = field(default=None) db: AsyncDatabase | None = field(default=None)
_data: AsyncIOMotorCollection | None = field(default=None) _data: AsyncCollection | None = field(default=None)
def __post_init__(self): def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
@ -1232,7 +1270,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str): async def get_or_create_collection(db: AsyncDatabase, collection_name: str):
collection_names = await db.list_collection_names() collection_names = await db.list_collection_names()
if collection_name not in collection_names: if collection_name not in collection_names: