MongoGraph: Separate edges from node collection
parent cf441aa84c
commit a047d966ab
1 changed file with 173 additions and 311 deletions
@@ -4,7 +4,7 @@ import numpy as np
 import configparser
 import asyncio
 
-from typing import Any, List, Union, final
+from typing import Any, Union, final
 
 from ..base import (
     BaseGraphStorage,
@@ -321,7 +321,10 @@ class MongoGraphStorage(BaseGraphStorage):
     """
 
     db: AsyncDatabase = field(default=None)
+    # node collection storing node_id, node_properties
     collection: AsyncCollection = field(default=None)
+    # edge collection storing source_node_id, target_node_id, and edge_properties
+    edge_collection: AsyncCollection = field(default=None)
 
     def __init__(self, namespace, global_config, embedding_func):
         super().__init__(
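Note: for orientation, a minimal sketch of the document shapes this split implies. Only `_id`, `source_node_id`, and `target_node_id` come from the commit itself; the property fields below are illustrative assumptions.

# Illustrative only: assumed shapes of the two collections after the split.
node_doc = {
    "_id": "Alan Turing",               # node id doubles as the Mongo _id
    "entity_type": "person",            # node properties live on the node document
    "description": "Mathematician",
}
edge_doc = {
    "source_node_id": "Alan Turing",    # directed edge: source -> target
    "target_node_id": "Bletchley Park",
    "relationship": "worked_at",        # edge properties live on the edge document
    "weight": 1.0,
}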
@@ -330,6 +333,7 @@ class MongoGraphStorage(BaseGraphStorage):
             embedding_func=embedding_func,
         )
         self._collection_name = self.namespace
+        self._edge_collection_name = f"{self._collection_name}_edges"
 
     async def initialize(self):
         if self.db is None:
@@ -337,6 +341,9 @@ class MongoGraphStorage(BaseGraphStorage):
             self.collection = await get_or_create_collection(
                 self.db, self._collection_name
             )
+            self.edge_collection = await get_or_create_collection(
+                self.db, self._edge_collection_name
+            )
             logger.debug(f"Use MongoDB as KG {self._collection_name}")
 
     async def finalize(self):
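Note: the new lookups filter on `source_node_id` and `target_node_id`, so single-field indexes on those keys (and optionally a unique compound index matching the upsert filter) would back them. Index creation is not part of this commit; the sketch below uses the synchronous PyMongo client with an assumed local server and placeholder names.

from pymongo import MongoClient

edges = MongoClient("mongodb://localhost:27017")["lightrag"]["graph_edges"]  # assumed names
edges.create_index("source_node_id")
edges.create_index("target_node_id")
# Optional: enforce one document per directed pair, matching the upsert filter in upsert_edge.
edges.create_index([("source_node_id", 1), ("target_node_id", 1)], unique=True)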
@@ -344,6 +351,7 @@ class MongoGraphStorage(BaseGraphStorage):
             await ClientManager.release_client(self.db)
             self.db = None
             self.collection = None
+            self.edge_collection = None
 
     #
     # -------------------------------------------------------------------------
@@ -374,52 +382,6 @@ class MongoGraphStorage(BaseGraphStorage):
     #     "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec"
     # }
 
-    async def _graph_lookup(
-        self, start_node_id: str, max_depth: int = None
-    ) -> List[dict]:
-        """
-        Performs a $graphLookup starting from 'start_node_id' and returns
-        all reachable documents (including the start node itself).
-
-        Pipeline Explanation:
-         - 1) $match: We match the start node document by _id = start_node_id.
-         - 2) $graphLookup:
-              "from": same collection,
-              "startWith": "$edges.target" (the immediate neighbors in 'edges'),
-              "connectFromField": "edges.target",
-              "connectToField": "_id",
-              "as": "reachableNodes",
-              "maxDepth": max_depth (if provided),
-              "depthField": "depth" (used for debugging or filtering).
-         - 3) We add an $project or $unwind as needed to extract data.
-        """
-        pipeline = [
-            {"$match": {"_id": start_node_id}},
-            {
-                "$graphLookup": {
-                    "from": self.collection.name,
-                    "startWith": "$edges.target",
-                    "connectFromField": "edges.target",
-                    "connectToField": "_id",
-                    "as": "reachableNodes",
-                    "depthField": "depth",
-                }
-            },
-        ]
-
-        # If you want a limited depth (e.g., only 1 or 2 hops), set maxDepth
-        if max_depth is not None:
-            pipeline[1]["$graphLookup"]["maxDepth"] = max_depth
-
-        # Return the matching doc plus a field "reachableNodes"
-        cursor = await self.collection.aggregate(pipeline)
-        results = await cursor.to_list(None)
-
-        # If there's no matching node, results = [].
-        # Otherwise, results[0] is the start node doc,
-        # plus results[0]["reachableNodes"] is the array of connected docs.
-        return results
-
     #
     # -------------------------------------------------------------------------
     # BASIC QUERIES
@@ -439,8 +401,9 @@ class MongoGraphStorage(BaseGraphStorage):
         Check if there's a direct single-hop edge from source_node_id to target_node_id.
         """
-        # Direct check if the target_node appears among the edges array.
-        doc = await self.collection.find_one(
-            {"_id": source_node_id, "edges.target": target_node_id}, {"_id": 1}
+        doc = await self.edge_collection.find_one(
+            {"source_node_id": source_node_id, "target_node_id": target_node_id},
+            {"_id": 1},
         )
         return doc is not None
 
@@ -453,21 +416,10 @@ class MongoGraphStorage(BaseGraphStorage):
     async def node_degree(self, node_id: str) -> int:
         """
         Returns the total number of edges connected to node_id (both inbound and outbound).
-        The easiest approach is typically two queries:
-         - count of edges array in node_id's doc
-         - count of how many other docs have node_id in their edges.target.
         """
-        # --- 1) Outbound edges (direct from doc) ---
-        doc = await self.collection.find_one({"_id": node_id}, {"edges.target": 1})
-        if not doc:
-            return 0
-
-        outbound_count = len(doc.get("edges", []))
-
-        # --- 2) Inbound edges:
-        inbound_count = await self.collection.count_documents({"edges.target": node_id})
-
-        return outbound_count + inbound_count
+        return await self.edge_collection.count_documents(
+            {"$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]}
+        )
 
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
         """Get the total degree (sum of relationships) of two nodes.
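Note: node_degree now collapses the old outbound-plus-inbound arithmetic into a single `count_documents` call with `$or`. A standalone sanity check of that equivalence on toy data (synchronous PyMongo, assumed throwaway collection):

from pymongo import MongoClient

edges = MongoClient()["scratch"]["edges_demo"]  # assumed throwaway collection
edges.delete_many({})
edges.insert_many([
    {"source_node_id": "A", "target_node_id": "B"},
    {"source_node_id": "B", "target_node_id": "C"},
    {"source_node_id": "C", "target_node_id": "A"},
])

degree_a = edges.count_documents(
    {"$or": [{"source_node_id": "A"}, {"target_node_id": "A"}]}
)
print(degree_a)  # 2: one outbound edge (A->B) plus one inbound edge (C->A)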
@@ -482,12 +434,7 @@ class MongoGraphStorage(BaseGraphStorage):
         src_degree = await self.node_degree(src_id)
         trg_degree = await self.node_degree(tgt_id)
 
-        # Convert None to 0 for addition
-        src_degree = 0 if src_degree is None else src_degree
-        trg_degree = 0 if trg_degree is None else trg_degree
-
-        degrees = int(src_degree) + int(trg_degree)
-        return degrees
+        return src_degree + trg_degree
 
     #
     # -------------------------------------------------------------------------
@@ -497,32 +444,27 @@ class MongoGraphStorage(BaseGraphStorage):
 
     async def get_node(self, node_id: str) -> dict[str, str] | None:
         """
-        Return the full node document (including "edges"), or None if missing.
+        Return the full node document, or None if missing.
         """
         return await self.collection.find_one({"_id": node_id})
 
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
-        doc = await self.collection.find_one(
+        return await self.edge_collection.find_one(
             {
                 "$or": [
-                    {"_id": source_node_id, "edges.target": target_node_id},
-                    {"_id": target_node_id, "edges.target": source_node_id},
+                    {
+                        "source_node_id": source_node_id,
+                        "target_node_id": target_node_id,
+                    },
+                    {
+                        "source_node_id": target_node_id,
+                        "target_node_id": source_node_id,
+                    },
                 ]
-            },
-            {"edges": 1},
+            }
         )
-        if not doc:
-            return None
-
-        for e in doc.get("edges", []):
-            if e.get("target") == target_node_id:
-                return e
-            if e.get("target") == source_node_id:
-                return e
-
-        return None
 
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         """
@@ -535,23 +477,19 @@ class MongoGraphStorage(BaseGraphStorage):
             list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
             None: If no edges found
         """
-        doc = await self.get_node(source_node_id)
-        if not doc:
-            return None
+        cursor = self.edge_collection.find(
+            {
+                "$or": [
+                    {"source_node_id": source_node_id},
+                    {"target_node_id": source_node_id},
+                ]
+            },
+            {"source_node_id": 1, "target_node_id": 1},
+        )
 
-        edges = []
-        for e in doc.get("edges", []):
-            edges.append((source_node_id, e.get("target")))
-
-        cursor = self.collection.find({"edges.target": source_node_id})
-        source_docs = await cursor.to_list(None)
-        if not source_docs:
-            return edges
-
-        for doc in source_docs:
-            edges.append((doc.get("_id"), source_node_id))
-
-        return edges
+        return [
+            (e.get("source_node_id"), e.get("target_node_id")) async for e in cursor
+        ]
 
     async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
         result = {}
@@ -566,31 +504,21 @@ class MongoGraphStorage(BaseGraphStorage):
 
         # Outbound degrees
         outbound_pipeline = [
-            {"$match": {"_id": {"$in": node_ids}}},
-            {"$project": {"_id": 1, "degree": {"$size": "$edges"}}},
+            {"$match": {"source_node_id": {"$in": node_ids}}},
+            {"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}},
         ]
 
-        cursor = await self.collection.aggregate(outbound_pipeline)
+        cursor = await self.edge_collection.aggregate(outbound_pipeline)
         async for doc in cursor:
             merged_results[doc.get("_id")] = doc.get("degree")
 
         # Inbound degrees
         inbound_pipeline = [
-            {"$match": {"edges.target": {"$in": node_ids}}},
-            {"$project": {"_id": 1, "edges.target": 1}},
-            {"$unwind": "$edges"},
-            {"$match": {"edges.target": {"$in": node_ids}}},
-            {
-                "$group": {
-                    "_id": "$edges.target",
-                    "degree": {
-                        "$sum": 1
-                    },  # Count the number of incoming edges for each target node
-                }
-            },
+            {"$match": {"target_node_id": {"$in": node_ids}}},
+            {"$group": {"_id": "$target_node_id", "degree": {"$sum": 1}}},
         ]
 
-        cursor = await self.collection.aggregate(inbound_pipeline)
+        cursor = await self.edge_collection.aggregate(inbound_pipeline)
         async for doc in cursor:
             merged_results[doc.get("_id")] = merged_results.get(
                 doc.get("_id"), 0
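Note: batch degrees are now two `$group` passes over the edge collection, one keyed on `$source_node_id` and one on `$target_node_id`, merged per node. The same pattern on toy data (synchronous PyMongo; field names mirror the diff, collection name is an assumption):

from pymongo import MongoClient

edges = MongoClient()["scratch"]["edges_demo"]  # assumed throwaway collection
edges.delete_many({})
edges.insert_many([
    {"source_node_id": "A", "target_node_id": "B"},
    {"source_node_id": "A", "target_node_id": "C"},
    {"source_node_id": "B", "target_node_id": "A"},
])

node_ids = ["A", "B"]
degrees = {n: 0 for n in node_ids}
for key in ("source_node_id", "target_node_id"):
    pipeline = [
        {"$match": {key: {"$in": node_ids}}},
        {"$group": {"_id": f"${key}", "degree": {"$sum": 1}}},
    ]
    for row in edges.aggregate(pipeline):
        degrees[row["_id"]] += row["degree"]
print(degrees)  # {'A': 3, 'B': 2}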
@@ -615,29 +543,27 @@ class MongoGraphStorage(BaseGraphStorage):
             - Outgoing edges: (queried_node, connected_node)
             - Incoming edges: (connected_node, queried_node)
         """
-        result = {}
+        result = {node_id: [] for node_id in node_ids}
 
-        cursor = self.collection.find(
-            {"_id": {"$in": node_ids}}, {"_id": 1, "edges.target": 1}
+        # Query outgoing edges (where node is the source)
+        outgoing_cursor = self.edge_collection.find(
+            {"source_node_id": {"$in": node_ids}},
+            {"source_node_id": 1, "target_node_id": 1},
         )
-        async for doc in cursor:
-            node_id = doc.get("_id")
-            edges = doc.get("edges", [])
-            result[node_id] = [(node_id, e["target"]) for e in edges]
+        async for edge in outgoing_cursor:
+            source = edge["source_node_id"]
+            target = edge["target_node_id"]
+            result[source].append((source, target))
 
-        inbound_pipeline = [
-            {"$match": {"edges.target": {"$in": node_ids}}},
-            {"$project": {"_id": 1, "edges.target": 1}},
-            {"$unwind": "$edges"},
-            {"$match": {"edges.target": {"$in": node_ids}}},
-            {"$project": {"_id": "$_id", "target": "$edges.target"}},
-        ]
-
-        cursor = await self.collection.aggregate(inbound_pipeline)
-        async for doc in cursor:
-            node_id = doc.get("target")
-            result[node_id] = result.get(node_id, [])
-            result[node_id].append((doc.get("_id"), node_id))
+        # Query incoming edges (where node is the target)
+        incoming_cursor = self.edge_collection.find(
+            {"target_node_id": {"$in": node_ids}},
+            {"source_node_id": 1, "target_node_id": 1},
+        )
+        async for edge in incoming_cursor:
+            source = edge["source_node_id"]
+            target = edge["target_node_id"]
+            result[target].append((source, target))
 
         return result
 
@@ -649,11 +575,9 @@ class MongoGraphStorage(BaseGraphStorage):
 
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """
-        Insert or update a node document. If new, create an empty edges array.
+        Insert or update a node document.
         """
-        # By default, preserve existing 'edges'.
-        # We'll only set 'edges' to [] on insert (no overwrite).
-        update_doc = {"$set": {**node_data}, "$setOnInsert": {"edges": []}}
+        update_doc = {"$set": {**node_data}}
         await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
 
     async def upsert_edge(
@@ -666,16 +590,10 @@ class MongoGraphStorage(BaseGraphStorage):
         # Ensure source node exists
         await self.upsert_node(source_node_id, {})
 
-        # Remove existing edge (if any)
-        await self.collection.update_one(
-            {"_id": source_node_id}, {"$pull": {"edges": {"target": target_node_id}}}
-        )
-
-        # Insert new edge
-        new_edge = {"target": target_node_id}
-        new_edge.update(edge_data)
-        await self.collection.update_one(
-            {"_id": source_node_id}, {"$push": {"edges": new_edge}}
+        await self.edge_collection.update_one(
+            {"source_node_id": source_node_id, "target_node_id": target_node_id},
+            {"$set": edge_data},
+            upsert=True,
         )
 
     #
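Note: keying `update_one(..., upsert=True)` on the directed (source, target) pair replaces the old pull-then-push against an embedded array, so repeated upserts keep a single edge document and only overwrite the fields passed in `edge_data`. An illustrative check (synchronous PyMongo, assumed throwaway collection; the helper below is hypothetical, not the class method):

from pymongo import MongoClient

edges = MongoClient()["scratch"]["edges_demo"]  # assumed throwaway collection
edges.delete_many({})

def upsert_edge(source, target, edge_data):
    edges.update_one(
        {"source_node_id": source, "target_node_id": target},
        {"$set": edge_data},
        upsert=True,
    )

upsert_edge("A", "B", {"relationship": "knows", "weight": 1})
upsert_edge("A", "B", {"weight": 2})  # updates in place, does not duplicate
print(edges.count_documents({"source_node_id": "A", "target_node_id": "B"}))  # 1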
@@ -689,8 +607,10 @@ class MongoGraphStorage(BaseGraphStorage):
         1) Remove node's doc entirely.
         2) Remove inbound edges from any doc that references node_id.
         """
-        # Remove inbound edges from all other docs
-        await self.collection.update_many({}, {"$pull": {"edges": {"target": node_id}}})
+        # Remove all edges
+        await self.edge_collection.delete_many(
+            {"$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]}
+        )
 
         # Remove the node doc
         await self.collection.delete_one({"_id": node_id})
@@ -734,151 +654,92 @@ class MongoGraphStorage(BaseGraphStorage):
         result = KnowledgeGraph()
         seen_nodes = set()
         seen_edges = set()
+        node_edges = []
 
         try:
-            if label == "*":
-                # Get all nodes and edges
-                async for node_doc in self.collection.find({}):
-                    node_id = str(node_doc["_id"])
-                    if node_id not in seen_nodes:
-                        result.nodes.append(
-                            KnowledgeGraphNode(
-                                id=node_id,
-                                labels=[node_doc.get("_id")],
-                                properties={
-                                    k: v
-                                    for k, v in node_doc.items()
-                                    if k not in ["_id", "edges"]
-                                },
-                            )
-                        )
-                        seen_nodes.add(node_id)
+            pipeline = [
+                {
+                    "$graphLookup": {
+                        "from": self._edge_collection_name,
+                        "startWith": "$_id",
+                        "connectFromField": "target_node_id",
+                        "connectToField": "source_node_id",
+                        "maxDepth": max_depth,
+                        "depthField": "depth",
+                        "as": "connected_edges",
+                    },
+                },
+                {"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
+                {"$sort": {"edge_count": -1}},
+                {"$limit": max_nodes},
+            ]
 
-                    # Process edges
-                    for edge in node_doc.get("edges", []):
-                        edge_id = f"{node_id}-{edge['target']}"
-                        if edge_id not in seen_edges:
-                            result.edges.append(
-                                KnowledgeGraphEdge(
-                                    id=edge_id,
-                                    type=edge.get("relation", ""),
-                                    source=node_id,
-                                    target=edge["target"],
-                                    properties={
-                                        k: v
-                                        for k, v in edge.items()
-                                        if k not in ["target", "relation"]
-                                    },
-                                )
-                            )
-                            seen_edges.add(edge_id)
+            if label == "*":
+                all_node_count = await self.collection.count_documents({})
+                result.is_truncated = all_node_count > max_nodes
             else:
                 # Verify if starting node exists
-                start_nodes = self.collection.find({"_id": label})
-                start_nodes_exist = await start_nodes.to_list(length=1)
-                if not start_nodes_exist:
+                start_node = await self.collection.find_one({"_id": label})
+                if not start_node:
                     logger.warning(f"Starting node with label {label} does not exist!")
                     return result
 
-                # Use $graphLookup for traversal
-                pipeline = [
-                    {
-                        "$match": {"_id": label}
-                    },  # Start with nodes having the specified label
-                    {
-                        "$graphLookup": {
-                            "from": self._collection_name,
-                            "startWith": "$edges.target",
-                            "connectFromField": "edges.target",
-                            "connectToField": "_id",
-                            "maxDepth": max_depth,
-                            "depthField": "depth",
-                            "as": "connected_nodes",
-                        }
-                    },
-                ]
+                # Add starting node to pipeline
+                pipeline.insert(0, {"$match": {"_id": label}})
 
-                async for doc in self.collection.aggregate(pipeline):
-                    # Add the start node
-                    node_id = str(doc["_id"])
-                    if node_id not in seen_nodes:
-                        result.nodes.append(
-                            KnowledgeGraphNode(
-                                id=node_id,
-                                labels=[
-                                    doc.get(
-                                        "_id",
-                                    )
-                                ],
-                                properties={
-                                    k: v
-                                    for k, v in doc.items()
-                                    if k
-                                    not in [
-                                        "_id",
-                                        "edges",
-                                        "connected_nodes",
-                                        "depth",
-                                    ]
-                                },
-                            )
-                        )
+            cursor = await self.collection.aggregate(pipeline)
+            async for doc in cursor:
+                # Add the start node
+                node_id = str(doc["_id"])
+                result.nodes.append(
+                    KnowledgeGraphNode(
+                        id=node_id,
+                        labels=[node_id],
+                        properties={
+                            k: v
+                            for k, v in doc.items()
+                            if k
+                            not in [
+                                "_id",
+                                "connected_edges",
+                                "edge_count",
+                            ]
+                        },
+                    )
+                )
+                seen_nodes.add(node_id)
+                if doc.get("connected_edges", []):
+                    node_edges.extend(doc.get("connected_edges"))
+
+            for edge in node_edges:
+                if (
+                    edge["source_node_id"] not in seen_nodes
+                    or edge["target_node_id"] not in seen_nodes
+                ):
+                    continue
+
+                edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
+                if edge_id not in seen_edges:
+                    result.edges.append(
+                        KnowledgeGraphEdge(
+                            id=edge_id,
+                            type=edge.get("relationship", ""),
+                            source=edge["source_node_id"],
+                            target=edge["target_node_id"],
+                            properties={
+                                k: v
+                                for k, v in edge.items()
+                                if k
+                                not in [
+                                    "_id",
+                                    "source_node_id",
+                                    "target_node_id",
+                                    "relationship",
+                                ]
+                            },
+                        )
-                        seen_nodes.add(node_id)
-
-                    # Add edges from start node
-                    for edge in doc.get("edges", []):
-                        edge_id = f"{node_id}-{edge['target']}"
-                        if edge_id not in seen_edges:
-                            result.edges.append(
-                                KnowledgeGraphEdge(
-                                    id=edge_id,
-                                    type=edge.get("relation", ""),
-                                    source=node_id,
-                                    target=edge["target"],
-                                    properties={
-                                        k: v
-                                        for k, v in edge.items()
-                                        if k not in ["target", "relation"]
-                                    },
-                                )
-                            )
-                            seen_edges.add(edge_id)
-
-                    # Add connected nodes and their edges
-                    for connected in doc.get("connected_nodes", []):
-                        node_id = str(connected["_id"])
-                        if node_id not in seen_nodes:
-                            result.nodes.append(
-                                KnowledgeGraphNode(
-                                    id=node_id,
-                                    labels=[connected.get("_id")],
-                                    properties={
-                                        k: v
-                                        for k, v in connected.items()
-                                        if k not in ["_id", "edges", "depth"]
-                                    },
-                                )
-                            )
-                            seen_nodes.add(node_id)
-
-                            # Add edges from connected nodes
-                            for edge in connected.get("edges", []):
-                                edge_id = f"{node_id}-{edge['target']}"
-                                if edge_id not in seen_edges:
-                                    result.edges.append(
-                                        KnowledgeGraphEdge(
-                                            id=edge_id,
-                                            type=edge.get("relation", ""),
-                                            source=node_id,
-                                            target=edge["target"],
-                                            properties={
-                                                k: v
-                                                for k, v in edge.items()
-                                                if k not in ["target", "relation"]
-                                            },
-                                        )
-                                    )
-                                    seen_edges.add(edge_id)
+                    )
+                    seen_edges.add(edge_id)
 
         logger.info(
             f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
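Note: the rewritten get_knowledge_graph runs `$graphLookup` from the node collection into the edge collection: it seeds with each node's `_id`, matches edge documents whose `source_node_id` equals it, then keeps following `target_node_id` back into `source_node_id` up to `maxDepth`, collecting everything into `connected_edges` before ranking by `edge_count` and applying `max_nodes`. A minimal sketch of that traversal on toy collections (synchronous PyMongo, assumed throwaway names):

from pymongo import MongoClient

db = MongoClient()["scratch"]                  # assumed throwaway database
nodes, edges = db["nodes_demo"], db["edges_demo"]
nodes.delete_many({})
edges.delete_many({})
nodes.insert_many([{"_id": n} for n in ("A", "B", "C")])
edges.insert_many([
    {"source_node_id": "A", "target_node_id": "B", "relationship": "r1"},
    {"source_node_id": "B", "target_node_id": "C", "relationship": "r2"},
])

pipeline = [
    {"$match": {"_id": "A"}},
    {
        "$graphLookup": {
            "from": "edges_demo",
            "startWith": "$_id",
            "connectFromField": "target_node_id",
            "connectToField": "source_node_id",
            "maxDepth": 2,
            "as": "connected_edges",
        }
    },
]
doc = next(nodes.aggregate(pipeline))
print(len(doc["connected_edges"]))  # 2: A->B directly, plus B->C found one hop further out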
@@ -903,9 +764,14 @@ class MongoGraphStorage(BaseGraphStorage):
         if not nodes:
             return
 
-        # 1. Remove all edges referencing these nodes (remove from edges array of other nodes)
-        await self.collection.update_many(
-            {}, {"$pull": {"edges": {"target": {"$in": nodes}}}}
+        # 1. Remove all edges referencing these nodes
+        await self.edge_collection.delete_many(
+            {
+                "$or": [
+                    {"source_node_id": {"$in": nodes}},
+                    {"target_node_id": {"$in": nodes}},
+                ]
+            }
         )
 
         # 2. Delete the node documents
@@ -923,25 +789,14 @@ class MongoGraphStorage(BaseGraphStorage):
         if not edges:
             return
 
-        # Group edges by source node (document _id) for efficient updates
-        edge_groups = {}
-        for source_id, target_id in edges:
-            if source_id not in edge_groups:
-                edge_groups[source_id] = []
-            edge_groups[source_id].append(target_id)
-
-        # Bulk update documents to remove the edges
-        update_operations = []
-        for source_id, target_ids in edge_groups.items():
-            update_operations.append(
-                self.collection.update_one(
-                    {"_id": source_id},
-                    {"$pull": {"edges": {"target": {"$in": target_ids}}}},
-                )
-            )
-
-        if update_operations:
-            await asyncio.gather(*update_operations)
+        await self.edge_collection.delete_many(
+            {
+                "$or": [
+                    {"source_node_id": source_id, "target_node_id": target_id}
+                    for source_id, target_id in edges
+                ]
+            }
+        )
 
         logger.debug(f"Successfully deleted edges: {edges}")
 
@@ -958,9 +813,16 @@ class MongoGraphStorage(BaseGraphStorage):
             logger.info(
                 f"Dropped {deleted_count} documents from graph {self._collection_name}"
             )
 
+            result = await self.edge_collection.delete_many({})
+            edge_count = result.deleted_count
+            logger.info(
+                f"Dropped {edge_count} edges from graph {self._edge_collection_name}"
+            )
+
             return {
                 "status": "success",
-                "message": f"{deleted_count} documents dropped",
+                "message": f"{deleted_count} documents and {edge_count} edges dropped",
             }
         except PyMongoError as e:
             logger.error(f"Error dropping graph {self._collection_name}: {e}")