Add missing methods for MongoGraphStorage
This commit is contained in:
parent
1389265695
commit
cf441aa84c
1 changed file with 212 additions and 174 deletions
|
|
@ -22,27 +22,25 @@ import pipmaster as pm
|
||||||
if not pm.is_installed("pymongo"):
|
if not pm.is_installed("pymongo"):
|
||||||
pm.install("pymongo")
|
pm.install("pymongo")
|
||||||
|
|
||||||
if not pm.is_installed("motor"):
|
from pymongo import AsyncMongoClient # type: ignore
|
||||||
pm.install("motor")
|
from pymongo.asynchronous.database import AsyncDatabase # type: ignore
|
||||||
|
from pymongo.asynchronous.collection import AsyncCollection # type: ignore
|
||||||
from motor.motor_asyncio import ( # type: ignore
|
|
||||||
AsyncIOMotorClient,
|
|
||||||
AsyncIOMotorDatabase,
|
|
||||||
AsyncIOMotorCollection,
|
|
||||||
)
|
|
||||||
from pymongo.operations import SearchIndexModel # type: ignore
|
from pymongo.operations import SearchIndexModel # type: ignore
|
||||||
from pymongo.errors import PyMongoError # type: ignore
|
from pymongo.errors import PyMongoError # type: ignore
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
config.read("config.ini", "utf-8")
|
config.read("config.ini", "utf-8")
|
||||||
|
|
||||||
|
# Get maximum number of graph nodes from environment variable, default is 1000
|
||||||
|
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
|
||||||
|
|
||||||
|
|
||||||
class ClientManager:
|
class ClientManager:
|
||||||
_instances = {"db": None, "ref_count": 0}
|
_instances = {"db": None, "ref_count": 0}
|
||||||
_lock = asyncio.Lock()
|
_lock = asyncio.Lock()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_client(cls) -> AsyncIOMotorDatabase:
|
async def get_client(cls) -> AsyncMongoClient:
|
||||||
async with cls._lock:
|
async with cls._lock:
|
||||||
if cls._instances["db"] is None:
|
if cls._instances["db"] is None:
|
||||||
uri = os.environ.get(
|
uri = os.environ.get(
|
||||||
|
|
@ -57,7 +55,7 @@ class ClientManager:
|
||||||
"MONGO_DATABASE",
|
"MONGO_DATABASE",
|
||||||
config.get("mongodb", "database", fallback="LightRAG"),
|
config.get("mongodb", "database", fallback="LightRAG"),
|
||||||
)
|
)
|
||||||
client = AsyncIOMotorClient(uri)
|
client = AsyncMongoClient(uri)
|
||||||
db = client.get_database(database_name)
|
db = client.get_database(database_name)
|
||||||
cls._instances["db"] = db
|
cls._instances["db"] = db
|
||||||
cls._instances["ref_count"] = 0
|
cls._instances["ref_count"] = 0
|
||||||
|
|
@ -65,7 +63,7 @@ class ClientManager:
|
||||||
return cls._instances["db"]
|
return cls._instances["db"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def release_client(cls, db: AsyncIOMotorDatabase):
|
async def release_client(cls, db: AsyncDatabase):
|
||||||
async with cls._lock:
|
async with cls._lock:
|
||||||
if db is not None:
|
if db is not None:
|
||||||
if db is cls._instances["db"]:
|
if db is cls._instances["db"]:
|
||||||
|
|
@ -77,8 +75,8 @@ class ClientManager:
|
||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MongoKVStorage(BaseKVStorage):
|
class MongoKVStorage(BaseKVStorage):
|
||||||
db: AsyncIOMotorDatabase = field(default=None)
|
db: AsyncDatabase = field(default=None)
|
||||||
_data: AsyncIOMotorCollection = field(default=None)
|
_data: AsyncCollection = field(default=None)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._collection_name = self.namespace
|
self._collection_name = self.namespace
|
||||||
|
|
@ -214,8 +212,8 @@ class MongoKVStorage(BaseKVStorage):
|
||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MongoDocStatusStorage(DocStatusStorage):
|
class MongoDocStatusStorage(DocStatusStorage):
|
||||||
db: AsyncIOMotorDatabase = field(default=None)
|
db: AsyncDatabase = field(default=None)
|
||||||
_data: AsyncIOMotorCollection = field(default=None)
|
_data: AsyncCollection = field(default=None)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._collection_name = self.namespace
|
self._collection_name = self.namespace
|
||||||
|
|
@ -311,6 +309,9 @@ class MongoDocStatusStorage(DocStatusStorage):
|
||||||
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
|
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
async def delete(self, ids: list[str]) -> None:
    """Delete every document in ``self._data`` whose ``_id`` is in ``ids``.

    Args:
        ids: ``_id`` values of the documents to remove.
    """
    id_filter = {"_id": {"$in": ids}}
    await self._data.delete_many(id_filter)
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -319,8 +320,8 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
|
A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
db: AsyncIOMotorDatabase = field(default=None)
|
db: AsyncDatabase = field(default=None)
|
||||||
collection: AsyncIOMotorCollection = field(default=None)
|
collection: AsyncCollection = field(default=None)
|
||||||
|
|
||||||
def __init__(self, namespace, global_config, embedding_func):
|
def __init__(self, namespace, global_config, embedding_func):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
|
@ -350,6 +351,29 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
#
|
#
|
||||||
|
|
||||||
|
# Sample entity_relation document
|
||||||
|
# {
|
||||||
|
# "_id" : "CompanyA",
|
||||||
|
# "created_at" : 1749904575,
|
||||||
|
# "description" : "A major technology company",
|
||||||
|
# "edges" : [
|
||||||
|
# {
|
||||||
|
# "target" : "ProductX",
|
||||||
|
# "relation": "Develops", // To distinguish multiple same-target relations
|
||||||
|
# "weight" : Double("1"),
|
||||||
|
# "description" : "CompanyA develops ProductX",
|
||||||
|
# "keywords" : "develop, produce",
|
||||||
|
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
|
||||||
|
# "file_path" : "custom_kg",
|
||||||
|
# "created_at" : 1749904575
|
||||||
|
# }
|
||||||
|
# ],
|
||||||
|
# "entity_id" : "CompanyA",
|
||||||
|
# "entity_type" : "Organization",
|
||||||
|
# "file_path" : "custom_kg",
|
||||||
|
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec"
|
||||||
|
# }
|
||||||
|
|
||||||
async def _graph_lookup(
|
async def _graph_lookup(
|
||||||
self, start_node_id: str, max_depth: int = None
|
self, start_node_id: str, max_depth: int = None
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
|
|
@ -388,7 +412,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
pipeline[1]["$graphLookup"]["maxDepth"] = max_depth
|
pipeline[1]["$graphLookup"]["maxDepth"] = max_depth
|
||||||
|
|
||||||
# Return the matching doc plus a field "reachableNodes"
|
# Return the matching doc plus a field "reachableNodes"
|
||||||
cursor = self.collection.aggregate(pipeline)
|
cursor = await self.collection.aggregate(pipeline)
|
||||||
results = await cursor.to_list(None)
|
results = await cursor.to_list(None)
|
||||||
|
|
||||||
# If there's no matching node, results = [].
|
# If there's no matching node, results = [].
|
||||||
|
|
@ -413,44 +437,12 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """
    Check if there's a direct single-hop edge from source_node_id to target_node_id.

    Only the ``edges`` array of the document whose ``_id`` is ``source_node_id``
    is consulted, so the check is directional.
    """
    query = {"_id": source_node_id, "edges.target": target_node_id}
    # Existence is all that matters, so project nothing but the _id.
    match = await self.collection.find_one(query, {"_id": 1})
    return match is not None
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
|
|
@ -464,82 +456,38 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
The easiest approach is typically two queries:
|
The easiest approach is typically two queries:
|
||||||
- count of edges array in node_id's doc
|
- count of edges array in node_id's doc
|
||||||
- count of how many other docs have node_id in their edges.target.
|
- count of how many other docs have node_id in their edges.target.
|
||||||
|
|
||||||
But we'll do a $graphLookup demonstration for inbound edges:
|
|
||||||
1) Outbound edges: direct from node's edges array
|
|
||||||
2) Inbound edges: we can do a special $graphLookup from all docs
|
|
||||||
or do an explicit match.
|
|
||||||
|
|
||||||
For demonstration, let's do this in two steps (with second step $graphLookup).
|
|
||||||
"""
|
"""
|
||||||
# --- 1) Outbound edges (direct from doc) ---
|
# --- 1) Outbound edges (direct from doc) ---
|
||||||
doc = await self.collection.find_one({"_id": node_id}, {"edges": 1})
|
doc = await self.collection.find_one({"_id": node_id}, {"edges.target": 1})
|
||||||
if not doc:
|
if not doc:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
outbound_count = len(doc.get("edges", []))
|
outbound_count = len(doc.get("edges", []))
|
||||||
|
|
||||||
# --- 2) Inbound edges:
|
# --- 2) Inbound edges:
|
||||||
# A simple way is: find all docs where "edges.target" == node_id.
|
inbound_count = await self.collection.count_documents({"edges.target": node_id})
|
||||||
# But let's do a $graphLookup from `node_id` in REVERSE.
|
|
||||||
# There's a trick to do "reverse" graphLookups: you'd store
|
|
||||||
# reversed edges or do a more advanced pipeline. Typically you'd do
|
|
||||||
# a direct match. We'll just do a direct match for inbound.
|
|
||||||
inbound_count_pipeline = [
|
|
||||||
{"$match": {"edges.target": node_id}},
|
|
||||||
{
|
|
||||||
"$project": {
|
|
||||||
"matchingEdgesCount": {
|
|
||||||
"$size": {
|
|
||||||
"$filter": {
|
|
||||||
"input": "$edges",
|
|
||||||
"as": "edge",
|
|
||||||
"cond": {"$eq": ["$$edge.target", node_id]},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{"$group": {"_id": None, "totalInbound": {"$sum": "$matchingEdgesCount"}}},
|
|
||||||
]
|
|
||||||
inbound_cursor = self.collection.aggregate(inbound_count_pipeline)
|
|
||||||
inbound_result = await inbound_cursor.to_list(None)
|
|
||||||
inbound_count = inbound_result[0]["totalInbound"] if inbound_result else 0
|
|
||||||
|
|
||||||
return outbound_count + inbound_count
|
return outbound_count + inbound_count
|
||||||
|
|
||||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
    """Get the total degree (sum of relationships) of two nodes.

    Args:
        src_id: Label of the source node
        tgt_id: Label of the target node

    Returns:
        int: Sum of the degrees of both nodes
    """
    total = 0
    # Endpoints are queried one after the other, source first.
    for endpoint in (src_id, tgt_id):
        degree = await self.node_degree(endpoint)
        # A missing degree counts as 0.
        if degree is not None:
            total += int(degree)
    return total
|
||||||
|
|
||||||
#
|
#
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
|
|
@ -556,58 +504,142 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> dict[str, str] | None:
|
) -> dict[str, str] | None:
|
||||||
pipeline = [
|
doc = await self.collection.find_one(
|
||||||
{"$match": {"_id": source_node_id}},
|
|
||||||
{
|
{
|
||||||
"$graphLookup": {
|
"$or": [
|
||||||
"from": self.collection.name,
|
{"_id": source_node_id, "edges.target": target_node_id},
|
||||||
"startWith": "$edges.target",
|
{"_id": target_node_id, "edges.target": source_node_id},
|
||||||
"connectFromField": "edges.target",
|
]
|
||||||
"connectToField": "_id",
|
|
||||||
"as": "neighbors",
|
|
||||||
"depthField": "depth",
|
|
||||||
"maxDepth": 0,
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{"$project": {"edges": 1}},
|
{"edges": 1},
|
||||||
]
|
)
|
||||||
cursor = self.collection.aggregate(pipeline)
|
if not doc:
|
||||||
docs = await cursor.to_list(None)
|
|
||||||
if not docs:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
for e in docs[0].get("edges", []):
|
for e in doc.get("edges", []):
|
||||||
if e.get("target") == target_node_id:
|
if e.get("target") == target_node_id:
|
||||||
return e
|
return e
|
||||||
|
if e.get("target") == source_node_id:
|
||||||
|
return e
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
"""
|
"""
|
||||||
Return a list of (source_id, target_id) for direct edges from source_node_id.
|
Retrieves all edges (relationships) for a particular node identified by its label.
|
||||||
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
|
|
||||||
|
Args:
|
||||||
|
source_node_id: Label of the node to get edges for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
|
||||||
|
None: If no edges found
|
||||||
"""
|
"""
|
||||||
pipeline = [
|
doc = await self.get_node(source_node_id)
|
||||||
{"$match": {"_id": source_node_id}},
|
if not doc:
|
||||||
{
|
|
||||||
"$graphLookup": {
|
|
||||||
"from": self.collection.name,
|
|
||||||
"startWith": "$edges.target",
|
|
||||||
"connectFromField": "edges.target",
|
|
||||||
"connectToField": "_id",
|
|
||||||
"as": "neighbors",
|
|
||||||
"depthField": "depth",
|
|
||||||
"maxDepth": 0,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{"$project": {"_id": 0, "edges": 1}},
|
|
||||||
]
|
|
||||||
cursor = self.collection.aggregate(pipeline)
|
|
||||||
result = await cursor.to_list(None)
|
|
||||||
if not result:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
edges = result[0].get("edges", [])
|
edges = []
|
||||||
return [(source_node_id, e["target"]) for e in edges]
|
for e in doc.get("edges", []):
|
||||||
|
edges.append((source_node_id, e.get("target")))
|
||||||
|
|
||||||
|
cursor = self.collection.find({"edges.target": source_node_id})
|
||||||
|
source_docs = await cursor.to_list(None)
|
||||||
|
if not source_docs:
|
||||||
|
return edges
|
||||||
|
|
||||||
|
for doc in source_docs:
|
||||||
|
edges.append((doc.get("_id"), source_node_id))
|
||||||
|
|
||||||
|
return edges
|
||||||
|
|
||||||
|
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
    """Fetch several node documents with a single ``find`` query.

    Args:
        node_ids: ``_id`` values of the nodes to load.

    Returns:
        Mapping of ``_id`` to the full document, containing only the ids
        for which a document was returned.
    """
    cursor = self.collection.find({"_id": {"$in": node_ids}})
    return {node_doc.get("_id"): node_doc async for node_doc in cursor}
|
||||||
|
|
||||||
|
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
    """Compute the degree of several nodes with two aggregation queries.

    A node's degree is the length of its own ``edges`` array (outbound) plus
    the number of ``edges`` entries in any document whose ``target`` is that
    node (inbound).

    Args:
        node_ids: ``_id`` values of the nodes to measure.

    Returns:
        Mapping of node id to degree. An id is present only if a document
        with that ``_id`` exists or at least one edge targets it.
    """
    if not node_ids:
        # Nothing can match an empty $in; skip both round trips.
        return {}

    # Outbound and inbound counts for the same "_id" are summed in here.
    degrees: dict[str, int] = {}

    # Outbound degrees
    outbound_pipeline = [
        {"$match": {"_id": {"$in": node_ids}}},
        {
            "$project": {
                "_id": 1,
                # $size errors on a missing or null field, which would fail
                # the whole aggregation; treat such documents as edge-less.
                "degree": {"$size": {"$ifNull": ["$edges", []]}},
            }
        },
    ]
    cursor = await self.collection.aggregate(outbound_pipeline)
    async for doc in cursor:
        degrees[doc.get("_id")] = doc.get("degree")

    # Inbound degrees
    inbound_pipeline = [
        {"$match": {"edges.target": {"$in": node_ids}}},
        {"$project": {"_id": 1, "edges.target": 1}},
        {"$unwind": "$edges"},
        # After unwinding, keep only the edges that point at a requested node.
        {"$match": {"edges.target": {"$in": node_ids}}},
        {
            "$group": {
                "_id": "$edges.target",
                # Count the number of incoming edges for each target node
                "degree": {"$sum": 1},
            }
        },
    ]
    cursor = await self.collection.aggregate(inbound_pipeline)
    async for doc in cursor:
        node_id = doc.get("_id")
        degrees[node_id] = degrees.get(node_id, 0) + doc.get("degree")

    return degrees
|
||||||
|
|
||||||
|
async def get_nodes_edges_batch(
    self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
    """
    Batch retrieve edges for multiple nodes.
    Both outgoing and incoming edges are collected for each node, so the
    result reflects the undirected nature of the graph.

    Args:
        node_ids: List of node IDs (entity_id) for which to retrieve edges.

    Returns:
        A dictionary mapping node IDs to lists of edge tuples (source, target).
        For each node, the list includes both:
        - Outgoing edges: (queried_node, connected_node)
        - Incoming edges: (connected_node, queried_node)
    """
    edges_by_node = {}

    # Outgoing edges: read straight from each requested node's edges array.
    outbound_cursor = self.collection.find(
        {"_id": {"$in": node_ids}}, {"_id": 1, "edges.target": 1}
    )
    async for node_doc in outbound_cursor:
        owner = node_doc.get("_id")
        outgoing = []
        for edge in node_doc.get("edges", []):
            outgoing.append((owner, edge["target"]))
        edges_by_node[owner] = outgoing

    # Incoming edges: unwind every document that references a requested node
    # and keep one row per matching edge.
    inbound_pipeline = [
        {"$match": {"edges.target": {"$in": node_ids}}},
        {"$project": {"_id": 1, "edges.target": 1}},
        {"$unwind": "$edges"},
        {"$match": {"edges.target": {"$in": node_ids}}},
        {"$project": {"_id": "$_id", "target": "$edges.target"}},
    ]
    inbound_cursor = await self.collection.aggregate(inbound_pipeline)
    async for row in inbound_cursor:
        queried = row.get("target")
        edges_by_node.setdefault(queried, []).append((row.get("_id"), queried))

    return edges_by_node
|
||||||
|
|
||||||
#
|
#
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
|
|
@ -675,20 +707,18 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
Returns:
|
Returns:
|
||||||
[id1, id2, ...] # Alphabetically sorted id list
|
[id1, id2, ...] # Alphabetically sorted id list
|
||||||
"""
|
"""
|
||||||
# Use MongoDB's distinct and aggregation to get all unique labels
|
|
||||||
pipeline = [
|
|
||||||
{"$group": {"_id": "$_id"}}, # Group by _id
|
|
||||||
{"$sort": {"_id": 1}}, # Sort alphabetically
|
|
||||||
]
|
|
||||||
|
|
||||||
cursor = self.collection.aggregate(pipeline)
|
cursor = self.collection.find({}, projection={"_id": 1}, sort=[("_id", 1)])
|
||||||
labels = []
|
labels = []
|
||||||
async for doc in cursor:
|
async for doc in cursor:
|
||||||
labels.append(doc["_id"])
|
labels.append(doc["_id"])
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
self, node_label: str, max_depth: int = 5
|
self,
|
||||||
|
node_label: str,
|
||||||
|
max_depth: int = 5,
|
||||||
|
max_nodes: int = MAX_GRAPH_NODES,
|
||||||
) -> KnowledgeGraph:
|
) -> KnowledgeGraph:
|
||||||
"""
|
"""
|
||||||
Get complete connected subgraph for specified node (including the starting node itself)
|
Get complete connected subgraph for specified node (including the starting node itself)
|
||||||
|
|
@ -893,17 +923,25 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
if not edges:
|
if not edges:
|
||||||
return
|
return
|
||||||
|
|
||||||
update_tasks = []
|
# Group edges by source node (document _id) for efficient updates
|
||||||
for source, target in edges:
|
edge_groups = {}
|
||||||
# Remove edge pointing to target from source node's edges array
|
for source_id, target_id in edges:
|
||||||
update_tasks.append(
|
if source_id not in edge_groups:
|
||||||
|
edge_groups[source_id] = []
|
||||||
|
edge_groups[source_id].append(target_id)
|
||||||
|
|
||||||
|
# Bulk update documents to remove the edges
|
||||||
|
update_operations = []
|
||||||
|
for source_id, target_ids in edge_groups.items():
|
||||||
|
update_operations.append(
|
||||||
self.collection.update_one(
|
self.collection.update_one(
|
||||||
{"_id": source}, {"$pull": {"edges": {"target": target}}}
|
{"_id": source_id},
|
||||||
|
{"$pull": {"edges": {"target": {"$in": target_ids}}}},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if update_tasks:
|
if update_operations:
|
||||||
await asyncio.gather(*update_tasks)
|
await asyncio.gather(*update_operations)
|
||||||
|
|
||||||
logger.debug(f"Successfully deleted edges: {edges}")
|
logger.debug(f"Successfully deleted edges: {edges}")
|
||||||
|
|
||||||
|
|
@ -932,8 +970,8 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MongoVectorDBStorage(BaseVectorStorage):
|
class MongoVectorDBStorage(BaseVectorStorage):
|
||||||
db: AsyncIOMotorDatabase | None = field(default=None)
|
db: AsyncDatabase | None = field(default=None)
|
||||||
_data: AsyncIOMotorCollection | None = field(default=None)
|
_data: AsyncCollection | None = field(default=None)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
|
|
@ -1232,7 +1270,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
|
async def get_or_create_collection(db: AsyncDatabase, collection_name: str):
|
||||||
collection_names = await db.list_collection_names()
|
collection_names = await db.list_collection_names()
|
||||||
|
|
||||||
if collection_name not in collection_names:
|
if collection_name not in collection_names:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue