Add missing methods for MongoGraphStorage
This commit is contained in:
parent
1389265695
commit
cf441aa84c
1 changed files with 212 additions and 174 deletions
|
|
@ -22,27 +22,25 @@ import pipmaster as pm
|
|||
if not pm.is_installed("pymongo"):
|
||||
pm.install("pymongo")
|
||||
|
||||
if not pm.is_installed("motor"):
|
||||
pm.install("motor")
|
||||
|
||||
from motor.motor_asyncio import ( # type: ignore
|
||||
AsyncIOMotorClient,
|
||||
AsyncIOMotorDatabase,
|
||||
AsyncIOMotorCollection,
|
||||
)
|
||||
from pymongo import AsyncMongoClient # type: ignore
|
||||
from pymongo.asynchronous.database import AsyncDatabase # type: ignore
|
||||
from pymongo.asynchronous.collection import AsyncCollection # type: ignore
|
||||
from pymongo.operations import SearchIndexModel # type: ignore
|
||||
from pymongo.errors import PyMongoError # type: ignore
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
# Get maximum number of graph nodes from environment variable, default is 1000
|
||||
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
|
||||
|
||||
|
||||
class ClientManager:
|
||||
_instances = {"db": None, "ref_count": 0}
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_client(cls) -> AsyncIOMotorDatabase:
|
||||
async def get_client(cls) -> AsyncMongoClient:
|
||||
async with cls._lock:
|
||||
if cls._instances["db"] is None:
|
||||
uri = os.environ.get(
|
||||
|
|
@ -57,7 +55,7 @@ class ClientManager:
|
|||
"MONGO_DATABASE",
|
||||
config.get("mongodb", "database", fallback="LightRAG"),
|
||||
)
|
||||
client = AsyncIOMotorClient(uri)
|
||||
client = AsyncMongoClient(uri)
|
||||
db = client.get_database(database_name)
|
||||
cls._instances["db"] = db
|
||||
cls._instances["ref_count"] = 0
|
||||
|
|
@ -65,7 +63,7 @@ class ClientManager:
|
|||
return cls._instances["db"]
|
||||
|
||||
@classmethod
|
||||
async def release_client(cls, db: AsyncIOMotorDatabase):
|
||||
async def release_client(cls, db: AsyncDatabase):
|
||||
async with cls._lock:
|
||||
if db is not None:
|
||||
if db is cls._instances["db"]:
|
||||
|
|
@ -77,8 +75,8 @@ class ClientManager:
|
|||
@final
|
||||
@dataclass
|
||||
class MongoKVStorage(BaseKVStorage):
|
||||
db: AsyncIOMotorDatabase = field(default=None)
|
||||
_data: AsyncIOMotorCollection = field(default=None)
|
||||
db: AsyncDatabase = field(default=None)
|
||||
_data: AsyncCollection = field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
self._collection_name = self.namespace
|
||||
|
|
@ -214,8 +212,8 @@ class MongoKVStorage(BaseKVStorage):
|
|||
@final
|
||||
@dataclass
|
||||
class MongoDocStatusStorage(DocStatusStorage):
|
||||
db: AsyncIOMotorDatabase = field(default=None)
|
||||
_data: AsyncIOMotorCollection = field(default=None)
|
||||
db: AsyncDatabase = field(default=None)
|
||||
_data: AsyncCollection = field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
self._collection_name = self.namespace
|
||||
|
|
@ -311,6 +309,9 @@ class MongoDocStatusStorage(DocStatusStorage):
|
|||
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
    """Delete the doc-status records whose ``_id`` is listed in ``ids``.

    Args:
        ids: Document ids to remove. An empty list is a no-op and does not
            issue a database call.
    """
    if not ids:
        # Nothing to delete; avoid a pointless round trip.
        return
    await self._data.delete_many({"_id": {"$in": ids}})
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
|
|
@ -319,8 +320,8 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
|
||||
"""
|
||||
|
||||
db: AsyncIOMotorDatabase = field(default=None)
|
||||
collection: AsyncIOMotorCollection = field(default=None)
|
||||
db: AsyncDatabase = field(default=None)
|
||||
collection: AsyncCollection = field(default=None)
|
||||
|
||||
def __init__(self, namespace, global_config, embedding_func):
|
||||
super().__init__(
|
||||
|
|
@ -350,6 +351,29 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
# -------------------------------------------------------------------------
|
||||
#
|
||||
|
||||
# Sample entity_relation document
|
||||
# {
|
||||
# "_id" : "CompanyA",
|
||||
# "created_at" : 1749904575,
|
||||
# "description" : "A major technology company",
|
||||
# "edges" : [
|
||||
# {
|
||||
# "target" : "ProductX",
|
||||
# "relation": "Develops", // To distinguish multiple same-target relations
|
||||
# "weight" : Double("1"),
|
||||
# "description" : "CompanyA develops ProductX",
|
||||
# "keywords" : "develop, produce",
|
||||
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
|
||||
# "file_path" : "custom_kg",
|
||||
# "created_at" : 1749904575
|
||||
# }
|
||||
# ],
|
||||
# "entity_id" : "CompanyA",
|
||||
# "entity_type" : "Organization",
|
||||
# "file_path" : "custom_kg",
|
||||
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec"
|
||||
# }
|
||||
|
||||
async def _graph_lookup(
|
||||
self, start_node_id: str, max_depth: int = None
|
||||
) -> List[dict]:
|
||||
|
|
@ -388,7 +412,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
pipeline[1]["$graphLookup"]["maxDepth"] = max_depth
|
||||
|
||||
# Return the matching doc plus a field "reachableNodes"
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
cursor = await self.collection.aggregate(pipeline)
|
||||
results = await cursor.to_list(None)
|
||||
|
||||
# If there's no matching node, results = [].
|
||||
|
|
@ -413,44 +437,12 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """
    Check if there's a direct single-hop edge between the two nodes.

    An edge is stored only in the ``edges`` array of one endpoint's document,
    but ``get_edge`` reads edges in either direction. To stay consistent with
    it, the edge is considered present when either endpoint lists the other
    as a target.

    Args:
        source_node_id: ``_id`` of one endpoint.
        target_node_id: ``_id`` of the other endpoint.

    Returns:
        True if a document for either endpoint has an edge targeting the
        other endpoint, False otherwise.
    """
    # Only existence matters, so project just the _id.
    doc = await self.collection.find_one(
        {
            "$or": [
                {"_id": source_node_id, "edges.target": target_node_id},
                {"_id": target_node_id, "edges.target": source_node_id},
            ]
        },
        {"_id": 1},
    )
    return doc is not None
|
||||
|
||||
#
|
||||
# -------------------------------------------------------------------------
|
||||
|
|
@ -464,82 +456,38 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
The easiest approach is typically two queries:
|
||||
- count of edges array in node_id's doc
|
||||
- count of how many other docs have node_id in their edges.target.
|
||||
|
||||
But we'll do a $graphLookup demonstration for inbound edges:
|
||||
1) Outbound edges: direct from node's edges array
|
||||
2) Inbound edges: we can do a special $graphLookup from all docs
|
||||
or do an explicit match.
|
||||
|
||||
For demonstration, let's do this in two steps (with second step $graphLookup).
|
||||
"""
|
||||
# --- 1) Outbound edges (direct from doc) ---
|
||||
doc = await self.collection.find_one({"_id": node_id}, {"edges": 1})
|
||||
doc = await self.collection.find_one({"_id": node_id}, {"edges.target": 1})
|
||||
if not doc:
|
||||
return 0
|
||||
|
||||
outbound_count = len(doc.get("edges", []))
|
||||
|
||||
# --- 2) Inbound edges:
|
||||
# A simple way is: find all docs where "edges.target" == node_id.
|
||||
# But let's do a $graphLookup from `node_id` in REVERSE.
|
||||
# There's a trick to do "reverse" graphLookups: you'd store
|
||||
# reversed edges or do a more advanced pipeline. Typically you'd do
|
||||
# a direct match. We'll just do a direct match for inbound.
|
||||
inbound_count_pipeline = [
|
||||
{"$match": {"edges.target": node_id}},
|
||||
{
|
||||
"$project": {
|
||||
"matchingEdgesCount": {
|
||||
"$size": {
|
||||
"$filter": {
|
||||
"input": "$edges",
|
||||
"as": "edge",
|
||||
"cond": {"$eq": ["$$edge.target", node_id]},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{"$group": {"_id": None, "totalInbound": {"$sum": "$matchingEdgesCount"}}},
|
||||
]
|
||||
inbound_cursor = self.collection.aggregate(inbound_count_pipeline)
|
||||
inbound_result = await inbound_cursor.to_list(None)
|
||||
inbound_count = inbound_result[0]["totalInbound"] if inbound_result else 0
|
||||
inbound_count = await self.collection.count_documents({"edges.target": node_id})
|
||||
|
||||
return outbound_count + inbound_count
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
    """Get the total degree (sum of relationships) of two nodes.

    Args:
        src_id: Label of the source node
        tgt_id: Label of the target node

    Returns:
        int: Sum of the degrees of both nodes
    """
    # The two lookups are independent, so run them concurrently instead of
    # paying two sequential round trips.
    src_degree, tgt_degree = await asyncio.gather(
        self.node_degree(src_id), self.node_degree(tgt_id)
    )

    # Treat a missing degree (None) as 0 so the sum is always an int.
    return int(src_degree or 0) + int(tgt_degree or 0)
|
||||
|
||||
#
|
||||
# -------------------------------------------------------------------------
|
||||
|
|
@ -556,58 +504,142 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> dict[str, str] | None:
|
||||
pipeline = [
|
||||
{"$match": {"_id": source_node_id}},
|
||||
doc = await self.collection.find_one(
|
||||
{
|
||||
"$graphLookup": {
|
||||
"from": self.collection.name,
|
||||
"startWith": "$edges.target",
|
||||
"connectFromField": "edges.target",
|
||||
"connectToField": "_id",
|
||||
"as": "neighbors",
|
||||
"depthField": "depth",
|
||||
"maxDepth": 0,
|
||||
}
|
||||
"$or": [
|
||||
{"_id": source_node_id, "edges.target": target_node_id},
|
||||
{"_id": target_node_id, "edges.target": source_node_id},
|
||||
]
|
||||
},
|
||||
{"$project": {"edges": 1}},
|
||||
]
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
docs = await cursor.to_list(None)
|
||||
if not docs:
|
||||
{"edges": 1},
|
||||
)
|
||||
if not doc:
|
||||
return None
|
||||
|
||||
for e in docs[0].get("edges", []):
|
||||
for e in doc.get("edges", []):
|
||||
if e.get("target") == target_node_id:
|
||||
return e
|
||||
if e.get("target") == source_node_id:
|
||||
return e
|
||||
|
||||
return None
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
"""
|
||||
Return a list of (source_id, target_id) for direct edges from source_node_id.
|
||||
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
|
||||
Retrieves all edges (relationships) for a particular node identified by its label.
|
||||
|
||||
Args:
|
||||
source_node_id: Label of the node to get edges for
|
||||
|
||||
Returns:
|
||||
list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
|
||||
None: If no edges found
|
||||
"""
|
||||
pipeline = [
|
||||
{"$match": {"_id": source_node_id}},
|
||||
{
|
||||
"$graphLookup": {
|
||||
"from": self.collection.name,
|
||||
"startWith": "$edges.target",
|
||||
"connectFromField": "edges.target",
|
||||
"connectToField": "_id",
|
||||
"as": "neighbors",
|
||||
"depthField": "depth",
|
||||
"maxDepth": 0,
|
||||
}
|
||||
},
|
||||
{"$project": {"_id": 0, "edges": 1}},
|
||||
]
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
result = await cursor.to_list(None)
|
||||
if not result:
|
||||
doc = await self.get_node(source_node_id)
|
||||
if not doc:
|
||||
return None
|
||||
|
||||
edges = result[0].get("edges", [])
|
||||
return [(source_node_id, e["target"]) for e in edges]
|
||||
edges = []
|
||||
for e in doc.get("edges", []):
|
||||
edges.append((source_node_id, e.get("target")))
|
||||
|
||||
cursor = self.collection.find({"edges.target": source_node_id})
|
||||
source_docs = await cursor.to_list(None)
|
||||
if not source_docs:
|
||||
return edges
|
||||
|
||||
for doc in source_docs:
|
||||
edges.append((doc.get("_id"), source_node_id))
|
||||
|
||||
return edges
|
||||
|
||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
    """Fetch several node documents with a single query.

    Args:
        node_ids: ``_id`` values of the nodes to load.

    Returns:
        Mapping of ``_id`` to the stored document. Ids with no matching
        document are absent from the mapping; empty input yields ``{}``
        without querying the database.
    """
    if not node_ids:
        return {}

    cursor = self.collection.find({"_id": {"$in": node_ids}})
    return {doc.get("_id"): doc async for doc in cursor}
|
||||
|
||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
    """Compute the degree (outbound + inbound edge count) of several nodes.

    Outbound edges are the entries of a node's own ``edges`` array; inbound
    edges are entries in other documents' ``edges`` arrays whose ``target``
    is the node. Both are gathered with one aggregation each and summed
    per node id.

    Args:
        node_ids: ``_id`` values of the nodes to measure.

    Returns:
        Mapping of node id to its total degree. Ids that neither match a
        document nor appear as an edge target are absent; empty input
        yields ``{}`` without querying the database.
    """
    if not node_ids:
        return {}

    # Merge the outbound and inbound results with the same "_id" and sum the "degree"
    merged_results = {}

    # Outbound degrees. $size errors when its argument is missing, which would
    # abort the whole batch; $ifNull makes a document without "edges" count as 0,
    # matching node_degree's doc.get("edges", []).
    outbound_pipeline = [
        {"$match": {"_id": {"$in": node_ids}}},
        {
            "$project": {
                "_id": 1,
                "degree": {"$size": {"$ifNull": ["$edges", []]}},
            }
        },
    ]

    cursor = await self.collection.aggregate(outbound_pipeline)
    async for doc in cursor:
        merged_results[doc.get("_id")] = doc.get("degree")

    # Inbound degrees
    inbound_pipeline = [
        {"$match": {"edges.target": {"$in": node_ids}}},
        {"$project": {"_id": 1, "edges.target": 1}},
        {"$unwind": "$edges"},
        # After unwinding, keep only the edges that point at a requested node.
        {"$match": {"edges.target": {"$in": node_ids}}},
        {
            "$group": {
                "_id": "$edges.target",
                # Count the number of incoming edges for each target node
                "degree": {"$sum": 1},
            }
        },
    ]

    cursor = await self.collection.aggregate(inbound_pipeline)
    async for doc in cursor:
        node_id = doc.get("_id")
        merged_results[node_id] = merged_results.get(node_id, 0) + doc.get("degree")

    return merged_results
|
||||
|
||||
async def get_nodes_edges_batch(
    self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
    """
    Batch retrieve edges for multiple nodes.
    For each node, returns both outgoing and incoming edges to properly represent
    the undirected graph nature.

    Args:
        node_ids: List of node IDs (entity_id) for which to retrieve edges.

    Returns:
        A dictionary mapping node IDs to their list of edge tuples (source, target).
        For each node, the list includes both:
        - Outgoing edges: (queried_node, connected_node)
        - Incoming edges: (connected_node, queried_node)
        Ids with no document and no incoming edge are absent; empty input
        yields ``{}`` without querying the database.
    """
    if not node_ids:
        return {}

    result = {}

    # Outgoing edges: read each requested node's own "edges" array.
    cursor = self.collection.find(
        {"_id": {"$in": node_ids}}, {"_id": 1, "edges.target": 1}
    )
    async for doc in cursor:
        node_id = doc.get("_id")
        # With the "edges.target" projection an edge lacking a target comes
        # back as {}; skip it rather than fail the whole batch on a KeyError.
        result[node_id] = [
            (node_id, e["target"]) for e in doc.get("edges", []) if "target" in e
        ]

    # Incoming edges: edges in any document that point at a requested node.
    inbound_pipeline = [
        {"$match": {"edges.target": {"$in": node_ids}}},
        {"$project": {"_id": 1, "edges.target": 1}},
        {"$unwind": "$edges"},
        # After unwinding, keep only the edges that point at a requested node.
        {"$match": {"edges.target": {"$in": node_ids}}},
        {"$project": {"_id": "$_id", "target": "$edges.target"}},
    ]

    cursor = await self.collection.aggregate(inbound_pipeline)
    async for doc in cursor:
        node_id = doc.get("target")
        result.setdefault(node_id, []).append((doc.get("_id"), node_id))

    return result
|
||||
|
||||
#
|
||||
# -------------------------------------------------------------------------
|
||||
|
|
@ -675,20 +707,18 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
Returns:
|
||||
[id1, id2, ...] # Alphabetically sorted id list
|
||||
"""
|
||||
# Use MongoDB's distinct and aggregation to get all unique labels
|
||||
pipeline = [
|
||||
{"$group": {"_id": "$_id"}}, # Group by _id
|
||||
{"$sort": {"_id": 1}}, # Sort alphabetically
|
||||
]
|
||||
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
cursor = self.collection.find({}, projection={"_id": 1}, sort=[("_id", 1)])
|
||||
labels = []
|
||||
async for doc in cursor:
|
||||
labels.append(doc["_id"])
|
||||
return labels
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self, node_label: str, max_depth: int = 5
|
||||
self,
|
||||
node_label: str,
|
||||
max_depth: int = 5,
|
||||
max_nodes: int = MAX_GRAPH_NODES,
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Get complete connected subgraph for specified node (including the starting node itself)
|
||||
|
|
@ -893,17 +923,25 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
if not edges:
|
||||
return
|
||||
|
||||
update_tasks = []
|
||||
for source, target in edges:
|
||||
# Remove edge pointing to target from source node's edges array
|
||||
update_tasks.append(
|
||||
# Group edges by source node (document _id) for efficient updates
|
||||
edge_groups = {}
|
||||
for source_id, target_id in edges:
|
||||
if source_id not in edge_groups:
|
||||
edge_groups[source_id] = []
|
||||
edge_groups[source_id].append(target_id)
|
||||
|
||||
# Bulk update documents to remove the edges
|
||||
update_operations = []
|
||||
for source_id, target_ids in edge_groups.items():
|
||||
update_operations.append(
|
||||
self.collection.update_one(
|
||||
{"_id": source}, {"$pull": {"edges": {"target": target}}}
|
||||
{"_id": source_id},
|
||||
{"$pull": {"edges": {"target": {"$in": target_ids}}}},
|
||||
)
|
||||
)
|
||||
|
||||
if update_tasks:
|
||||
await asyncio.gather(*update_tasks)
|
||||
if update_operations:
|
||||
await asyncio.gather(*update_operations)
|
||||
|
||||
logger.debug(f"Successfully deleted edges: {edges}")
|
||||
|
||||
|
|
@ -932,8 +970,8 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
@final
|
||||
@dataclass
|
||||
class MongoVectorDBStorage(BaseVectorStorage):
|
||||
db: AsyncIOMotorDatabase | None = field(default=None)
|
||||
_data: AsyncIOMotorCollection | None = field(default=None)
|
||||
db: AsyncDatabase | None = field(default=None)
|
||||
_data: AsyncCollection | None = field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
|
|
@ -1232,7 +1270,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
|||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
|
||||
async def get_or_create_collection(db: AsyncDatabase, collection_name: str):
|
||||
collection_names = await db.list_collection_names()
|
||||
|
||||
if collection_name not in collection_names:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue