MongoGraph: Separate edges from node collection

This commit is contained in:
Ken Chen 2025-06-21 21:05:04 +08:00
parent cf441aa84c
commit a047d966ab

View file

@ -4,7 +4,7 @@ import numpy as np
import configparser
import asyncio
from typing import Any, List, Union, final
from typing import Any, Union, final
from ..base import (
BaseGraphStorage,
@ -321,7 +321,10 @@ class MongoGraphStorage(BaseGraphStorage):
"""
db: AsyncDatabase = field(default=None)
# node collection storing node_id, node_properties
collection: AsyncCollection = field(default=None)
# edge collection storing source_node_id, target_node_id, and edge_properties
edgeCollection: AsyncCollection = field(default=None)
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
@ -330,6 +333,7 @@ class MongoGraphStorage(BaseGraphStorage):
embedding_func=embedding_func,
)
self._collection_name = self.namespace
self._edge_collection_name = f"{self._collection_name}_edges"
async def initialize(self):
if self.db is None:
@ -337,6 +341,9 @@ class MongoGraphStorage(BaseGraphStorage):
self.collection = await get_or_create_collection(
self.db, self._collection_name
)
self.edge_collection = await get_or_create_collection(
self.db, self._edge_collection_name
)
logger.debug(f"Use MongoDB as KG {self._collection_name}")
async def finalize(self):
@ -344,6 +351,7 @@ class MongoGraphStorage(BaseGraphStorage):
await ClientManager.release_client(self.db)
self.db = None
self.collection = None
self.edge_collection = None
#
# -------------------------------------------------------------------------
@ -374,52 +382,6 @@ class MongoGraphStorage(BaseGraphStorage):
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec"
# }
async def _graph_lookup(
self, start_node_id: str, max_depth: int = None
) -> List[dict]:
"""
Performs a $graphLookup starting from 'start_node_id' and returns
all reachable documents (including the start node itself).
Pipeline Explanation:
- 1) $match: We match the start node document by _id = start_node_id.
- 2) $graphLookup:
"from": same collection,
"startWith": "$edges.target" (the immediate neighbors in 'edges'),
"connectFromField": "edges.target",
"connectToField": "_id",
"as": "reachableNodes",
"maxDepth": max_depth (if provided),
"depthField": "depth" (used for debugging or filtering).
- 3) We add an $project or $unwind as needed to extract data.
"""
pipeline = [
{"$match": {"_id": start_node_id}},
{
"$graphLookup": {
"from": self.collection.name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"as": "reachableNodes",
"depthField": "depth",
}
},
]
# If you want a limited depth (e.g., only 1 or 2 hops), set maxDepth
if max_depth is not None:
pipeline[1]["$graphLookup"]["maxDepth"] = max_depth
# Return the matching doc plus a field "reachableNodes"
cursor = await self.collection.aggregate(pipeline)
results = await cursor.to_list(None)
# If there's no matching node, results = [].
# Otherwise, results[0] is the start node doc,
# plus results[0]["reachableNodes"] is the array of connected docs.
return results
#
# -------------------------------------------------------------------------
# BASIC QUERIES
@ -439,8 +401,9 @@ class MongoGraphStorage(BaseGraphStorage):
Check if there's a direct single-hop edge from source_node_id to target_node_id.
"""
# Direct check if the target_node appears among the edges array.
doc = await self.collection.find_one(
{"_id": source_node_id, "edges.target": target_node_id}, {"_id": 1}
doc = await self.edge_collection.find_one(
{"source_node_id": source_node_id, "target_node_id": target_node_id},
{"_id": 1},
)
return doc is not None
@ -453,21 +416,10 @@ class MongoGraphStorage(BaseGraphStorage):
async def node_degree(self, node_id: str) -> int:
"""
Returns the total number of edges connected to node_id (both inbound and outbound).
The easiest approach is typically two queries:
- count of edges array in node_id's doc
- count of how many other docs have node_id in their edges.target.
"""
# --- 1) Outbound edges (direct from doc) ---
doc = await self.collection.find_one({"_id": node_id}, {"edges.target": 1})
if not doc:
return 0
outbound_count = len(doc.get("edges", []))
# --- 2) Inbound edges:
inbound_count = await self.collection.count_documents({"edges.target": node_id})
return outbound_count + inbound_count
return await self.edge_collection.count_documents(
{"$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]}
)
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""Get the total degree (sum of relationships) of two nodes.
@ -482,12 +434,7 @@ class MongoGraphStorage(BaseGraphStorage):
src_degree = await self.node_degree(src_id)
trg_degree = await self.node_degree(tgt_id)
# Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree)
return degrees
return src_degree + trg_degree
#
# -------------------------------------------------------------------------
@ -497,32 +444,27 @@ class MongoGraphStorage(BaseGraphStorage):
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""
Return the full node document (including "edges"), or None if missing.
Return the full node document, or None if missing.
"""
return await self.collection.find_one({"_id": node_id})
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
doc = await self.collection.find_one(
return await self.edge_collection.find_one(
{
"$or": [
{"_id": source_node_id, "edges.target": target_node_id},
{"_id": target_node_id, "edges.target": source_node_id},
{
"source_node_id": source_node_id,
"target_node_id": target_node_id,
},
{
"source_node_id": target_node_id,
"target_node_id": source_node_id,
},
]
},
{"edges": 1},
}
)
if not doc:
return None
for e in doc.get("edges", []):
if e.get("target") == target_node_id:
return e
if e.get("target") == source_node_id:
return e
return None
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""
@ -535,23 +477,19 @@ class MongoGraphStorage(BaseGraphStorage):
list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
None: If no edges found
"""
doc = await self.get_node(source_node_id)
if not doc:
return None
cursor = self.edge_collection.find(
{
"$or": [
{"source_node_id": source_node_id},
{"target_node_id": source_node_id},
]
},
{"source_node_id": 1, "target_node_id": 1},
)
edges = []
for e in doc.get("edges", []):
edges.append((source_node_id, e.get("target")))
cursor = self.collection.find({"edges.target": source_node_id})
source_docs = await cursor.to_list(None)
if not source_docs:
return edges
for doc in source_docs:
edges.append((doc.get("_id"), source_node_id))
return edges
return [
(e.get("source_node_id"), e.get("target_node_id")) async for e in cursor
]
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
result = {}
@ -566,31 +504,21 @@ class MongoGraphStorage(BaseGraphStorage):
# Outbound degrees
outbound_pipeline = [
{"$match": {"_id": {"$in": node_ids}}},
{"$project": {"_id": 1, "degree": {"$size": "$edges"}}},
{"$match": {"source_node_id": {"$in": node_ids}}},
{"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}},
]
cursor = await self.collection.aggregate(outbound_pipeline)
cursor = await self.edge_collection.aggregate(outbound_pipeline)
async for doc in cursor:
merged_results[doc.get("_id")] = doc.get("degree")
# Inbound degrees
inbound_pipeline = [
{"$match": {"edges.target": {"$in": node_ids}}},
{"$project": {"_id": 1, "edges.target": 1}},
{"$unwind": "$edges"},
{"$match": {"edges.target": {"$in": node_ids}}},
{
"$group": {
"_id": "$edges.target",
"degree": {
"$sum": 1
}, # Count the number of incoming edges for each target node
}
},
{"$match": {"target_node_id": {"$in": node_ids}}},
{"$group": {"_id": "$target_node_id", "degree": {"$sum": 1}}},
]
cursor = await self.collection.aggregate(inbound_pipeline)
cursor = await self.edge_collection.aggregate(inbound_pipeline)
async for doc in cursor:
merged_results[doc.get("_id")] = merged_results.get(
doc.get("_id"), 0
@ -615,29 +543,27 @@ class MongoGraphStorage(BaseGraphStorage):
- Outgoing edges: (queried_node, connected_node)
- Incoming edges: (connected_node, queried_node)
"""
result = {}
result = {node_id: [] for node_id in node_ids}
cursor = self.collection.find(
{"_id": {"$in": node_ids}}, {"_id": 1, "edges.target": 1}
# Query outgoing edges (where node is the source)
outgoing_cursor = self.edge_collection.find(
{"source_node_id": {"$in": node_ids}},
{"source_node_id": 1, "target_node_id": 1},
)
async for doc in cursor:
node_id = doc.get("_id")
edges = doc.get("edges", [])
result[node_id] = [(node_id, e["target"]) for e in edges]
async for edge in outgoing_cursor:
source = edge["source_node_id"]
target = edge["target_node_id"]
result[source].append((source, target))
inbound_pipeline = [
{"$match": {"edges.target": {"$in": node_ids}}},
{"$project": {"_id": 1, "edges.target": 1}},
{"$unwind": "$edges"},
{"$match": {"edges.target": {"$in": node_ids}}},
{"$project": {"_id": "$_id", "target": "$edges.target"}},
]
cursor = await self.collection.aggregate(inbound_pipeline)
async for doc in cursor:
node_id = doc.get("target")
result[node_id] = result.get(node_id, [])
result[node_id].append((doc.get("_id"), node_id))
# Query incoming edges (where node is the target)
incoming_cursor = self.edge_collection.find(
{"target_node_id": {"$in": node_ids}},
{"source_node_id": 1, "target_node_id": 1},
)
async for edge in incoming_cursor:
source = edge["source_node_id"]
target = edge["target_node_id"]
result[target].append((source, target))
return result
@ -649,11 +575,9 @@ class MongoGraphStorage(BaseGraphStorage):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""
Insert or update a node document. If new, create an empty edges array.
Insert or update a node document.
"""
# By default, preserve existing 'edges'.
# We'll only set 'edges' to [] on insert (no overwrite).
update_doc = {"$set": {**node_data}, "$setOnInsert": {"edges": []}}
update_doc = {"$set": {**node_data}}
await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
async def upsert_edge(
@ -666,16 +590,10 @@ class MongoGraphStorage(BaseGraphStorage):
# Ensure source node exists
await self.upsert_node(source_node_id, {})
# Remove existing edge (if any)
await self.collection.update_one(
{"_id": source_node_id}, {"$pull": {"edges": {"target": target_node_id}}}
)
# Insert new edge
new_edge = {"target": target_node_id}
new_edge.update(edge_data)
await self.collection.update_one(
{"_id": source_node_id}, {"$push": {"edges": new_edge}}
await self.edge_collection.update_one(
{"source_node_id": source_node_id, "target_node_id": target_node_id},
{"$set": edge_data},
upsert=True,
)
#
@ -689,8 +607,10 @@ class MongoGraphStorage(BaseGraphStorage):
1) Remove node's doc entirely.
2) Remove inbound edges from any doc that references node_id.
"""
# Remove inbound edges from all other docs
await self.collection.update_many({}, {"$pull": {"edges": {"target": node_id}}})
# Remove all edges
await self.edge_collection.delete_many(
{"$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]}
)
# Remove the node doc
await self.collection.delete_one({"_id": node_id})
@ -734,151 +654,92 @@ class MongoGraphStorage(BaseGraphStorage):
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
node_edges = []
try:
if label == "*":
# Get all nodes and edges
async for node_doc in self.collection.find({}):
node_id = str(node_doc["_id"])
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_doc.get("_id")],
properties={
k: v
for k, v in node_doc.items()
if k not in ["_id", "edges"]
},
)
)
seen_nodes.add(node_id)
pipeline = [
{
"$graphLookup": {
"from": self._edge_collection_name,
"startWith": "$_id",
"connectFromField": "target_node_id",
"connectToField": "source_node_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_edges",
},
},
{"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
{"$sort": {"edge_count": -1}},
{"$limit": max_nodes},
]
# Process edges
for edge in node_doc.get("edges", []):
edge_id = f"{node_id}-{edge['target']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relation", ""),
source=node_id,
target=edge["target"],
properties={
k: v
for k, v in edge.items()
if k not in ["target", "relation"]
},
)
)
seen_edges.add(edge_id)
if label == "*":
all_node_count = await self.collection.count_documents({})
result.is_truncated = all_node_count > max_nodes
else:
# Verify if starting node exists
start_nodes = self.collection.find({"_id": label})
start_nodes_exist = await start_nodes.to_list(length=1)
if not start_nodes_exist:
start_node = await self.collection.find_one({"_id": label})
if not start_node:
logger.warning(f"Starting node with label {label} does not exist!")
return result
# Use $graphLookup for traversal
pipeline = [
{
"$match": {"_id": label}
}, # Start with nodes having the specified label
{
"$graphLookup": {
"from": self._collection_name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_nodes",
}
},
]
# Add starting node to pipeline
pipeline.insert(0, {"$match": {"_id": label}})
async for doc in self.collection.aggregate(pipeline):
# Add the start node
node_id = str(doc["_id"])
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[
doc.get(
"_id",
)
],
properties={
k: v
for k, v in doc.items()
if k
not in [
"_id",
"edges",
"connected_nodes",
"depth",
]
},
)
cursor = await self.collection.aggregate(pipeline)
async for doc in cursor:
# Add the start node
node_id = str(doc["_id"])
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_id],
properties={
k: v
for k, v in doc.items()
if k
not in [
"_id",
"connected_edges",
"edge_count",
]
},
)
)
seen_nodes.add(node_id)
if doc.get("connected_edges", []):
node_edges.extend(doc.get("connected_edges"))
for edge in node_edges:
if (
edge["source_node_id"] not in seen_nodes
or edge["target_node_id"] not in seen_nodes
):
continue
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relationship", ""),
source=edge["source_node_id"],
target=edge["target_node_id"],
properties={
k: v
for k, v in edge.items()
if k
not in [
"_id",
"source_node_id",
"target_node_id",
"relationship",
]
},
)
seen_nodes.add(node_id)
# Add edges from start node
for edge in doc.get("edges", []):
edge_id = f"{node_id}-{edge['target']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relation", ""),
source=node_id,
target=edge["target"],
properties={
k: v
for k, v in edge.items()
if k not in ["target", "relation"]
},
)
)
seen_edges.add(edge_id)
# Add connected nodes and their edges
for connected in doc.get("connected_nodes", []):
node_id = str(connected["_id"])
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[connected.get("_id")],
properties={
k: v
for k, v in connected.items()
if k not in ["_id", "edges", "depth"]
},
)
)
seen_nodes.add(node_id)
# Add edges from connected nodes
for edge in connected.get("edges", []):
edge_id = f"{node_id}-{edge['target']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relation", ""),
source=node_id,
target=edge["target"],
properties={
k: v
for k, v in edge.items()
if k not in ["target", "relation"]
},
)
)
seen_edges.add(edge_id)
)
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
@ -903,9 +764,14 @@ class MongoGraphStorage(BaseGraphStorage):
if not nodes:
return
# 1. Remove all edges referencing these nodes (remove from edges array of other nodes)
await self.collection.update_many(
{}, {"$pull": {"edges": {"target": {"$in": nodes}}}}
# 1. Remove all edges referencing these nodes
await self.edge_collection.delete_many(
{
"$or": [
{"source_node_id": {"$in": nodes}},
{"target_node_id": {"$in": nodes}},
]
}
)
# 2. Delete the node documents
@ -923,25 +789,14 @@ class MongoGraphStorage(BaseGraphStorage):
if not edges:
return
# Group edges by source node (document _id) for efficient updates
edge_groups = {}
for source_id, target_id in edges:
if source_id not in edge_groups:
edge_groups[source_id] = []
edge_groups[source_id].append(target_id)
# Bulk update documents to remove the edges
update_operations = []
for source_id, target_ids in edge_groups.items():
update_operations.append(
self.collection.update_one(
{"_id": source_id},
{"$pull": {"edges": {"target": {"$in": target_ids}}}},
)
)
if update_operations:
await asyncio.gather(*update_operations)
await self.edge_collection.delete_many(
{
"$or": [
{"source_node_id": source_id, "target_node_id": target_id}
for source_id, target_id in edges
]
}
)
logger.debug(f"Successfully deleted edges: {edges}")
@ -958,9 +813,16 @@ class MongoGraphStorage(BaseGraphStorage):
logger.info(
f"Dropped {deleted_count} documents from graph {self._collection_name}"
)
result = await self.edge_collection.delete_many({})
edge_count = result.deleted_count
logger.info(
f"Dropped {edge_count} edges from graph {self._edge_collection_name}"
)
return {
"status": "success",
"message": f"{deleted_count} documents dropped",
"message": f"{deleted_count} documents and {edge_count} edges dropped",
}
except PyMongoError as e:
logger.error(f"Error dropping graph {self._collection_name}: {e}")