Add missing methods for MongoGraphStorage

This commit is contained in:
Ken Chen 2025-06-15 21:22:32 +08:00
parent 1389265695
commit cf441aa84c

View file

@ -22,27 +22,25 @@ import pipmaster as pm
if not pm.is_installed("pymongo"):
pm.install("pymongo")
if not pm.is_installed("motor"):
pm.install("motor")
from motor.motor_asyncio import ( # type: ignore
AsyncIOMotorClient,
AsyncIOMotorDatabase,
AsyncIOMotorCollection,
)
from pymongo import AsyncMongoClient # type: ignore
from pymongo.asynchronous.database import AsyncDatabase # type: ignore
from pymongo.asynchronous.collection import AsyncCollection # type: ignore
from pymongo.operations import SearchIndexModel # type: ignore
from pymongo.errors import PyMongoError # type: ignore
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
class ClientManager:
_instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@classmethod
async def get_client(cls) -> AsyncIOMotorDatabase:
async def get_client(cls) -> AsyncMongoClient:
async with cls._lock:
if cls._instances["db"] is None:
uri = os.environ.get(
@ -57,7 +55,7 @@ class ClientManager:
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
client = AsyncIOMotorClient(uri)
client = AsyncMongoClient(uri)
db = client.get_database(database_name)
cls._instances["db"] = db
cls._instances["ref_count"] = 0
@ -65,7 +63,7 @@ class ClientManager:
return cls._instances["db"]
@classmethod
async def release_client(cls, db: AsyncIOMotorDatabase):
async def release_client(cls, db: AsyncDatabase):
async with cls._lock:
if db is not None:
if db is cls._instances["db"]:
@ -77,8 +75,8 @@ class ClientManager:
@final
@dataclass
class MongoKVStorage(BaseKVStorage):
db: AsyncIOMotorDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(default=None)
db: AsyncDatabase = field(default=None)
_data: AsyncCollection = field(default=None)
def __post_init__(self):
self._collection_name = self.namespace
@ -214,8 +212,8 @@ class MongoKVStorage(BaseKVStorage):
@final
@dataclass
class MongoDocStatusStorage(DocStatusStorage):
db: AsyncIOMotorDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(default=None)
db: AsyncDatabase = field(default=None)
_data: AsyncCollection = field(default=None)
def __post_init__(self):
self._collection_name = self.namespace
@ -311,6 +309,9 @@ class MongoDocStatusStorage(DocStatusStorage):
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
return {"status": "error", "message": str(e)}
async def delete(self, ids: list[str]) -> None:
    """Delete documents whose ``_id`` is in *ids* from the doc-status collection.

    Args:
        ids: List of document IDs (Mongo ``_id`` values) to remove.
    """
    await self._data.delete_many({"_id": {"$in": ids}})
@final
@dataclass
@ -319,8 +320,8 @@ class MongoGraphStorage(BaseGraphStorage):
A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
"""
db: AsyncIOMotorDatabase = field(default=None)
collection: AsyncIOMotorCollection = field(default=None)
db: AsyncDatabase = field(default=None)
collection: AsyncCollection = field(default=None)
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
@ -350,6 +351,29 @@ class MongoGraphStorage(BaseGraphStorage):
# -------------------------------------------------------------------------
#
# Sample entity_relation document
# {
# "_id" : "CompanyA",
# "created_at" : 1749904575,
# "description" : "A major technology company",
# "edges" : [
# {
# "target" : "ProductX",
# "relation": "Develops", // To distinguish multiple same-target relations
# "weight" : Double("1"),
# "description" : "CompanyA develops ProductX",
# "keywords" : "develop, produce",
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
# "file_path" : "custom_kg",
# "created_at" : 1749904575
# }
# ],
# "entity_id" : "CompanyA",
# "entity_type" : "Organization",
# "file_path" : "custom_kg",
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec"
# }
async def _graph_lookup(
self, start_node_id: str, max_depth: int = None
) -> List[dict]:
@ -388,7 +412,7 @@ class MongoGraphStorage(BaseGraphStorage):
pipeline[1]["$graphLookup"]["maxDepth"] = max_depth
# Return the matching doc plus a field "reachableNodes"
cursor = self.collection.aggregate(pipeline)
cursor = await self.collection.aggregate(pipeline)
results = await cursor.to_list(None)
# If there's no matching node, results = [].
@ -413,44 +437,12 @@ class MongoGraphStorage(BaseGraphStorage):
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """
    Check if there's a direct single-hop edge from source_node_id to target_node_id.

    A single indexed find_one on the source doc's embedded edges array is
    sufficient — no $graphLookup aggregation is needed for a direct edge.
    """
    # Direct check if the target_node appears among the edges array.
    doc = await self.collection.find_one(
        {"_id": source_node_id, "edges.target": target_node_id}, {"_id": 1}
    )
    return doc is not None
#
# -------------------------------------------------------------------------
@ -464,82 +456,38 @@ class MongoGraphStorage(BaseGraphStorage):
The easiest approach is typically two queries:
- count of edges array in node_id's doc
- count of how many other docs have node_id in their edges.target.
But we'll do a $graphLookup demonstration for inbound edges:
1) Outbound edges: direct from node's edges array
2) Inbound edges: we can do a special $graphLookup from all docs
or do an explicit match.
For demonstration, let's do this in two steps (with second step $graphLookup).
"""
# --- 1) Outbound edges (direct from doc) ---
doc = await self.collection.find_one({"_id": node_id}, {"edges": 1})
doc = await self.collection.find_one({"_id": node_id}, {"edges.target": 1})
if not doc:
return 0
outbound_count = len(doc.get("edges", []))
# --- 2) Inbound edges:
# A simple way is: find all docs where "edges.target" == node_id.
# But let's do a $graphLookup from `node_id` in REVERSE.
# There's a trick to do "reverse" graphLookups: you'd store
# reversed edges or do a more advanced pipeline. Typically you'd do
# a direct match. We'll just do a direct match for inbound.
inbound_count_pipeline = [
{"$match": {"edges.target": node_id}},
{
"$project": {
"matchingEdgesCount": {
"$size": {
"$filter": {
"input": "$edges",
"as": "edge",
"cond": {"$eq": ["$$edge.target", node_id]},
}
}
}
}
},
{"$group": {"_id": None, "totalInbound": {"$sum": "$matchingEdgesCount"}}},
]
inbound_cursor = self.collection.aggregate(inbound_count_pipeline)
inbound_result = await inbound_cursor.to_list(None)
inbound_count = inbound_result[0]["totalInbound"] if inbound_result else 0
inbound_count = await self.collection.count_documents({"edges.target": node_id})
return outbound_count + inbound_count
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
    """Get the total degree (sum of relationships) of two nodes.

    Args:
        src_id: Label of the source node
        tgt_id: Label of the target node

    Returns:
        int: Sum of the degrees of both nodes
    """
    src_degree = await self.node_degree(src_id)
    trg_degree = await self.node_degree(tgt_id)

    # Convert None to 0 for addition (node_degree may yield None for a
    # missing node — TODO confirm against node_degree's contract).
    src_degree = 0 if src_degree is None else src_degree
    trg_degree = 0 if trg_degree is None else trg_degree

    return int(src_degree) + int(trg_degree)
#
# -------------------------------------------------------------------------
@ -556,58 +504,142 @@ class MongoGraphStorage(BaseGraphStorage):
async def get_edge(
    self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
    """Return the edge document linking the two nodes, or None if absent.

    The graph is treated as undirected here: the edge may be stored on
    either endpoint's document, so both directions are matched with $or.
    """
    doc = await self.collection.find_one(
        {
            "$or": [
                {"_id": source_node_id, "edges.target": target_node_id},
                {"_id": target_node_id, "edges.target": source_node_id},
            ]
        },
        {"edges": 1},
    )
    if not doc:
        return None
    # The matched doc belongs to one of the two endpoints; return the edge
    # entry that points at the other node.
    for e in doc.get("edges", []):
        if e.get("target") == target_node_id:
            return e
        if e.get("target") == source_node_id:
            return e
    return None
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
    """
    Retrieves all edges (relationships) for a particular node identified by its label.

    Args:
        source_node_id: Label of the node to get edges for

    Returns:
        list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
        None: If the node itself does not exist
    """
    doc = await self.get_node(source_node_id)
    if not doc:
        return None

    # Outgoing edges are embedded in the node's own document.
    edges = []
    for e in doc.get("edges", []):
        edges.append((source_node_id, e.get("target")))

    # Incoming edges: any document whose edges array targets this node.
    cursor = self.collection.find({"edges.target": source_node_id})
    source_docs = await cursor.to_list(None)
    if not source_docs:
        return edges

    for source_doc in source_docs:
        edges.append((source_doc.get("_id"), source_node_id))

    return edges
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
    """Fetch the full document for every id in *node_ids*, keyed by ``_id``."""
    cursor = self.collection.find({"_id": {"$in": node_ids}})
    return {doc.get("_id"): doc async for doc in cursor}
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
    """Compute the degree (outbound + inbound edge count) for each node id.

    Args:
        node_ids: List of node ``_id`` values to measure.

    Returns:
        Mapping of node id -> total degree. Ids matching no document and
        receiving no inbound edges do not appear in the result.
    """
    # Merge the outbound and inbound results with the same "_id" and sum the "degree"
    merged_results = {}

    # Outbound degrees: size of each node's own edges array.
    # $ifNull guards documents that lack an "edges" field entirely —
    # $size errors out on a missing or non-array operand.
    outbound_pipeline = [
        {"$match": {"_id": {"$in": node_ids}}},
        {"$project": {"_id": 1, "degree": {"$size": {"$ifNull": ["$edges", []]}}}},
    ]
    cursor = await self.collection.aggregate(outbound_pipeline)
    async for doc in cursor:
        merged_results[doc.get("_id")] = doc.get("degree")

    # Inbound degrees: unwind every doc that targets a requested node and
    # count the edge entries pointing at each one.
    inbound_pipeline = [
        {"$match": {"edges.target": {"$in": node_ids}}},
        {"$project": {"_id": 1, "edges.target": 1}},
        {"$unwind": "$edges"},
        {"$match": {"edges.target": {"$in": node_ids}}},
        {
            "$group": {
                "_id": "$edges.target",
                "degree": {
                    "$sum": 1
                },  # Count the number of incoming edges for each target node
            }
        },
    ]
    cursor = await self.collection.aggregate(inbound_pipeline)
    async for doc in cursor:
        merged_results[doc.get("_id")] = merged_results.get(
            doc.get("_id"), 0
        ) + doc.get("degree")

    return merged_results
async def get_nodes_edges_batch(
    self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
    """
    Batch retrieve edges for multiple nodes.

    For each node, returns both outgoing and incoming edges to properly represent
    the undirected graph nature.

    Args:
        node_ids: List of node IDs (entity_id) for which to retrieve edges.

    Returns:
        A dictionary mapping each node ID to its list of edge tuples (source, target).
        For each node, the list includes both:
        - Outgoing edges: (queried_node, connected_node)
        - Incoming edges: (connected_node, queried_node)
    """
    edge_map = {}

    # Pass 1: outgoing edges, read straight off each requested node's document.
    outbound_cursor = self.collection.find(
        {"_id": {"$in": node_ids}}, {"_id": 1, "edges.target": 1}
    )
    async for node_doc in outbound_cursor:
        nid = node_doc.get("_id")
        edge_map[nid] = [
            (nid, entry["target"]) for entry in node_doc.get("edges", [])
        ]

    # Pass 2: incoming edges, found by unwinding every document whose edges
    # array targets one of the requested nodes.
    inbound_pipeline = [
        {"$match": {"edges.target": {"$in": node_ids}}},
        {"$project": {"_id": 1, "edges.target": 1}},
        {"$unwind": "$edges"},
        {"$match": {"edges.target": {"$in": node_ids}}},
        {"$project": {"_id": "$_id", "target": "$edges.target"}},
    ]
    inbound_cursor = await self.collection.aggregate(inbound_pipeline)
    async for row in inbound_cursor:
        tgt = row.get("target")
        edge_map.setdefault(tgt, [])
        edge_map[tgt].append((row.get("_id"), tgt))

    return edge_map
#
# -------------------------------------------------------------------------
@ -675,20 +707,18 @@ class MongoGraphStorage(BaseGraphStorage):
Returns:
[id1, id2, ...] # Alphabetically sorted id list
"""
# Use MongoDB's distinct and aggregation to get all unique labels
pipeline = [
{"$group": {"_id": "$_id"}}, # Group by _id
{"$sort": {"_id": 1}}, # Sort alphabetically
]
cursor = self.collection.aggregate(pipeline)
cursor = self.collection.find({}, projection={"_id": 1}, sort=[("_id", 1)])
labels = []
async for doc in cursor:
labels.append(doc["_id"])
return labels
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
self,
node_label: str,
max_depth: int = 5,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
@ -893,17 +923,25 @@ class MongoGraphStorage(BaseGraphStorage):
if not edges:
return
update_tasks = []
for source, target in edges:
# Remove edge pointing to target from source node's edges array
update_tasks.append(
# Group edges by source node (document _id) for efficient updates
edge_groups = {}
for source_id, target_id in edges:
if source_id not in edge_groups:
edge_groups[source_id] = []
edge_groups[source_id].append(target_id)
# Bulk update documents to remove the edges
update_operations = []
for source_id, target_ids in edge_groups.items():
update_operations.append(
self.collection.update_one(
{"_id": source}, {"$pull": {"edges": {"target": target}}}
{"_id": source_id},
{"$pull": {"edges": {"target": {"$in": target_ids}}}},
)
)
if update_tasks:
await asyncio.gather(*update_tasks)
if update_operations:
await asyncio.gather(*update_operations)
logger.debug(f"Successfully deleted edges: {edges}")
@ -932,8 +970,8 @@ class MongoGraphStorage(BaseGraphStorage):
@final
@dataclass
class MongoVectorDBStorage(BaseVectorStorage):
db: AsyncIOMotorDatabase | None = field(default=None)
_data: AsyncIOMotorCollection | None = field(default=None)
db: AsyncDatabase | None = field(default=None)
_data: AsyncCollection | None = field(default=None)
def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
@ -1232,7 +1270,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
return {"status": "error", "message": str(e)}
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
async def get_or_create_collection(db: AsyncDatabase, collection_name: str):
collection_names = await db.list_collection_names()
if collection_name not in collection_names: