MongoGraph: Separate edges from node collection
This commit is contained in:
parent
cf441aa84c
commit
a047d966ab
1 changed file with 173 additions and 311 deletions
|
|
@ -4,7 +4,7 @@ import numpy as np
|
||||||
import configparser
|
import configparser
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from typing import Any, List, Union, final
|
from typing import Any, Union, final
|
||||||
|
|
||||||
from ..base import (
|
from ..base import (
|
||||||
BaseGraphStorage,
|
BaseGraphStorage,
|
||||||
|
|
@ -321,7 +321,10 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
db: AsyncDatabase = field(default=None)
|
db: AsyncDatabase = field(default=None)
|
||||||
|
# node collection storing node_id, node_properties
|
||||||
collection: AsyncCollection = field(default=None)
|
collection: AsyncCollection = field(default=None)
|
||||||
|
# edge collection storing source_node_id, target_node_id, and edge_properties
|
||||||
|
edgeCollection: AsyncCollection = field(default=None)
|
||||||
|
|
||||||
def __init__(self, namespace, global_config, embedding_func):
|
def __init__(self, namespace, global_config, embedding_func):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
|
@ -330,6 +333,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
embedding_func=embedding_func,
|
embedding_func=embedding_func,
|
||||||
)
|
)
|
||||||
self._collection_name = self.namespace
|
self._collection_name = self.namespace
|
||||||
|
self._edge_collection_name = f"{self._collection_name}_edges"
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if self.db is None:
|
if self.db is None:
|
||||||
|
|
@ -337,6 +341,9 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
self.collection = await get_or_create_collection(
|
self.collection = await get_or_create_collection(
|
||||||
self.db, self._collection_name
|
self.db, self._collection_name
|
||||||
)
|
)
|
||||||
|
self.edge_collection = await get_or_create_collection(
|
||||||
|
self.db, self._edge_collection_name
|
||||||
|
)
|
||||||
logger.debug(f"Use MongoDB as KG {self._collection_name}")
|
logger.debug(f"Use MongoDB as KG {self._collection_name}")
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
|
|
@ -344,6 +351,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
self.collection = None
|
self.collection = None
|
||||||
|
self.edge_collection = None
|
||||||
|
|
||||||
#
|
#
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
|
|
@ -374,52 +382,6 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec"
|
# "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec"
|
||||||
# }
|
# }
|
||||||
|
|
||||||
async def _graph_lookup(
|
|
||||||
self, start_node_id: str, max_depth: int = None
|
|
||||||
) -> List[dict]:
|
|
||||||
"""
|
|
||||||
Performs a $graphLookup starting from 'start_node_id' and returns
|
|
||||||
all reachable documents (including the start node itself).
|
|
||||||
|
|
||||||
Pipeline Explanation:
|
|
||||||
- 1) $match: We match the start node document by _id = start_node_id.
|
|
||||||
- 2) $graphLookup:
|
|
||||||
"from": same collection,
|
|
||||||
"startWith": "$edges.target" (the immediate neighbors in 'edges'),
|
|
||||||
"connectFromField": "edges.target",
|
|
||||||
"connectToField": "_id",
|
|
||||||
"as": "reachableNodes",
|
|
||||||
"maxDepth": max_depth (if provided),
|
|
||||||
"depthField": "depth" (used for debugging or filtering).
|
|
||||||
- 3) We add an $project or $unwind as needed to extract data.
|
|
||||||
"""
|
|
||||||
pipeline = [
|
|
||||||
{"$match": {"_id": start_node_id}},
|
|
||||||
{
|
|
||||||
"$graphLookup": {
|
|
||||||
"from": self.collection.name,
|
|
||||||
"startWith": "$edges.target",
|
|
||||||
"connectFromField": "edges.target",
|
|
||||||
"connectToField": "_id",
|
|
||||||
"as": "reachableNodes",
|
|
||||||
"depthField": "depth",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
# If you want a limited depth (e.g., only 1 or 2 hops), set maxDepth
|
|
||||||
if max_depth is not None:
|
|
||||||
pipeline[1]["$graphLookup"]["maxDepth"] = max_depth
|
|
||||||
|
|
||||||
# Return the matching doc plus a field "reachableNodes"
|
|
||||||
cursor = await self.collection.aggregate(pipeline)
|
|
||||||
results = await cursor.to_list(None)
|
|
||||||
|
|
||||||
# If there's no matching node, results = [].
|
|
||||||
# Otherwise, results[0] is the start node doc,
|
|
||||||
# plus results[0]["reachableNodes"] is the array of connected docs.
|
|
||||||
return results
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
# BASIC QUERIES
|
# BASIC QUERIES
|
||||||
|
|
@ -439,8 +401,9 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
Check if there's a direct single-hop edge from source_node_id to target_node_id.
|
Check if there's a direct single-hop edge from source_node_id to target_node_id.
|
||||||
"""
|
"""
|
||||||
# Direct check if the target_node appears among the edges array.
|
# Direct check if the target_node appears among the edges array.
|
||||||
doc = await self.collection.find_one(
|
doc = await self.edge_collection.find_one(
|
||||||
{"_id": source_node_id, "edges.target": target_node_id}, {"_id": 1}
|
{"source_node_id": source_node_id, "target_node_id": target_node_id},
|
||||||
|
{"_id": 1},
|
||||||
)
|
)
|
||||||
return doc is not None
|
return doc is not None
|
||||||
|
|
||||||
|
|
@ -453,21 +416,10 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
"""
|
"""
|
||||||
Returns the total number of edges connected to node_id (both inbound and outbound).
|
Returns the total number of edges connected to node_id (both inbound and outbound).
|
||||||
The easiest approach is typically two queries:
|
|
||||||
- count of edges array in node_id's doc
|
|
||||||
- count of how many other docs have node_id in their edges.target.
|
|
||||||
"""
|
"""
|
||||||
# --- 1) Outbound edges (direct from doc) ---
|
return await self.edge_collection.count_documents(
|
||||||
doc = await self.collection.find_one({"_id": node_id}, {"edges.target": 1})
|
{"$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]}
|
||||||
if not doc:
|
)
|
||||||
return 0
|
|
||||||
|
|
||||||
outbound_count = len(doc.get("edges", []))
|
|
||||||
|
|
||||||
# --- 2) Inbound edges:
|
|
||||||
inbound_count = await self.collection.count_documents({"edges.target": node_id})
|
|
||||||
|
|
||||||
return outbound_count + inbound_count
|
|
||||||
|
|
||||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||||
"""Get the total degree (sum of relationships) of two nodes.
|
"""Get the total degree (sum of relationships) of two nodes.
|
||||||
|
|
@ -482,12 +434,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
src_degree = await self.node_degree(src_id)
|
src_degree = await self.node_degree(src_id)
|
||||||
trg_degree = await self.node_degree(tgt_id)
|
trg_degree = await self.node_degree(tgt_id)
|
||||||
|
|
||||||
# Convert None to 0 for addition
|
return src_degree + trg_degree
|
||||||
src_degree = 0 if src_degree is None else src_degree
|
|
||||||
trg_degree = 0 if trg_degree is None else trg_degree
|
|
||||||
|
|
||||||
degrees = int(src_degree) + int(trg_degree)
|
|
||||||
return degrees
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
|
|
@ -497,32 +444,27 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
"""
|
"""
|
||||||
Return the full node document (including "edges"), or None if missing.
|
Return the full node document, or None if missing.
|
||||||
"""
|
"""
|
||||||
return await self.collection.find_one({"_id": node_id})
|
return await self.collection.find_one({"_id": node_id})
|
||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> dict[str, str] | None:
|
) -> dict[str, str] | None:
|
||||||
doc = await self.collection.find_one(
|
return await self.edge_collection.find_one(
|
||||||
{
|
{
|
||||||
"$or": [
|
"$or": [
|
||||||
{"_id": source_node_id, "edges.target": target_node_id},
|
{
|
||||||
{"_id": target_node_id, "edges.target": source_node_id},
|
"source_node_id": source_node_id,
|
||||||
|
"target_node_id": target_node_id,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"source_node_id": target_node_id,
|
||||||
|
"target_node_id": source_node_id,
|
||||||
|
},
|
||||||
]
|
]
|
||||||
},
|
}
|
||||||
{"edges": 1},
|
|
||||||
)
|
)
|
||||||
if not doc:
|
|
||||||
return None
|
|
||||||
|
|
||||||
for e in doc.get("edges", []):
|
|
||||||
if e.get("target") == target_node_id:
|
|
||||||
return e
|
|
||||||
if e.get("target") == source_node_id:
|
|
||||||
return e
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -535,23 +477,19 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
|
list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
|
||||||
None: If no edges found
|
None: If no edges found
|
||||||
"""
|
"""
|
||||||
doc = await self.get_node(source_node_id)
|
cursor = self.edge_collection.find(
|
||||||
if not doc:
|
{
|
||||||
return None
|
"$or": [
|
||||||
|
{"source_node_id": source_node_id},
|
||||||
|
{"target_node_id": source_node_id},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{"source_node_id": 1, "target_node_id": 1},
|
||||||
|
)
|
||||||
|
|
||||||
edges = []
|
return [
|
||||||
for e in doc.get("edges", []):
|
(e.get("source_node_id"), e.get("target_node_id")) async for e in cursor
|
||||||
edges.append((source_node_id, e.get("target")))
|
]
|
||||||
|
|
||||||
cursor = self.collection.find({"edges.target": source_node_id})
|
|
||||||
source_docs = await cursor.to_list(None)
|
|
||||||
if not source_docs:
|
|
||||||
return edges
|
|
||||||
|
|
||||||
for doc in source_docs:
|
|
||||||
edges.append((doc.get("_id"), source_node_id))
|
|
||||||
|
|
||||||
return edges
|
|
||||||
|
|
||||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||||
result = {}
|
result = {}
|
||||||
|
|
@ -566,31 +504,21 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
# Outbound degrees
|
# Outbound degrees
|
||||||
outbound_pipeline = [
|
outbound_pipeline = [
|
||||||
{"$match": {"_id": {"$in": node_ids}}},
|
{"$match": {"source_node_id": {"$in": node_ids}}},
|
||||||
{"$project": {"_id": 1, "degree": {"$size": "$edges"}}},
|
{"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}},
|
||||||
]
|
]
|
||||||
|
|
||||||
cursor = await self.collection.aggregate(outbound_pipeline)
|
cursor = await self.edge_collection.aggregate(outbound_pipeline)
|
||||||
async for doc in cursor:
|
async for doc in cursor:
|
||||||
merged_results[doc.get("_id")] = doc.get("degree")
|
merged_results[doc.get("_id")] = doc.get("degree")
|
||||||
|
|
||||||
# Inbound degrees
|
# Inbound degrees
|
||||||
inbound_pipeline = [
|
inbound_pipeline = [
|
||||||
{"$match": {"edges.target": {"$in": node_ids}}},
|
{"$match": {"target_node_id": {"$in": node_ids}}},
|
||||||
{"$project": {"_id": 1, "edges.target": 1}},
|
{"$group": {"_id": "$target_node_id", "degree": {"$sum": 1}}},
|
||||||
{"$unwind": "$edges"},
|
|
||||||
{"$match": {"edges.target": {"$in": node_ids}}},
|
|
||||||
{
|
|
||||||
"$group": {
|
|
||||||
"_id": "$edges.target",
|
|
||||||
"degree": {
|
|
||||||
"$sum": 1
|
|
||||||
}, # Count the number of incoming edges for each target node
|
|
||||||
}
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
cursor = await self.collection.aggregate(inbound_pipeline)
|
cursor = await self.edge_collection.aggregate(inbound_pipeline)
|
||||||
async for doc in cursor:
|
async for doc in cursor:
|
||||||
merged_results[doc.get("_id")] = merged_results.get(
|
merged_results[doc.get("_id")] = merged_results.get(
|
||||||
doc.get("_id"), 0
|
doc.get("_id"), 0
|
||||||
|
|
@ -615,29 +543,27 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
- Outgoing edges: (queried_node, connected_node)
|
- Outgoing edges: (queried_node, connected_node)
|
||||||
- Incoming edges: (connected_node, queried_node)
|
- Incoming edges: (connected_node, queried_node)
|
||||||
"""
|
"""
|
||||||
result = {}
|
result = {node_id: [] for node_id in node_ids}
|
||||||
|
|
||||||
cursor = self.collection.find(
|
# Query outgoing edges (where node is the source)
|
||||||
{"_id": {"$in": node_ids}}, {"_id": 1, "edges.target": 1}
|
outgoing_cursor = self.edge_collection.find(
|
||||||
|
{"source_node_id": {"$in": node_ids}},
|
||||||
|
{"source_node_id": 1, "target_node_id": 1},
|
||||||
)
|
)
|
||||||
async for doc in cursor:
|
async for edge in outgoing_cursor:
|
||||||
node_id = doc.get("_id")
|
source = edge["source_node_id"]
|
||||||
edges = doc.get("edges", [])
|
target = edge["target_node_id"]
|
||||||
result[node_id] = [(node_id, e["target"]) for e in edges]
|
result[source].append((source, target))
|
||||||
|
|
||||||
inbound_pipeline = [
|
# Query incoming edges (where node is the target)
|
||||||
{"$match": {"edges.target": {"$in": node_ids}}},
|
incoming_cursor = self.edge_collection.find(
|
||||||
{"$project": {"_id": 1, "edges.target": 1}},
|
{"target_node_id": {"$in": node_ids}},
|
||||||
{"$unwind": "$edges"},
|
{"source_node_id": 1, "target_node_id": 1},
|
||||||
{"$match": {"edges.target": {"$in": node_ids}}},
|
)
|
||||||
{"$project": {"_id": "$_id", "target": "$edges.target"}},
|
async for edge in incoming_cursor:
|
||||||
]
|
source = edge["source_node_id"]
|
||||||
|
target = edge["target_node_id"]
|
||||||
cursor = await self.collection.aggregate(inbound_pipeline)
|
result[target].append((source, target))
|
||||||
async for doc in cursor:
|
|
||||||
node_id = doc.get("target")
|
|
||||||
result[node_id] = result.get(node_id, [])
|
|
||||||
result[node_id].append((doc.get("_id"), node_id))
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -649,11 +575,9 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||||
"""
|
"""
|
||||||
Insert or update a node document. If new, create an empty edges array.
|
Insert or update a node document.
|
||||||
"""
|
"""
|
||||||
# By default, preserve existing 'edges'.
|
update_doc = {"$set": {**node_data}}
|
||||||
# We'll only set 'edges' to [] on insert (no overwrite).
|
|
||||||
update_doc = {"$set": {**node_data}, "$setOnInsert": {"edges": []}}
|
|
||||||
await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
|
await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
|
||||||
|
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
|
|
@ -666,16 +590,10 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
# Ensure source node exists
|
# Ensure source node exists
|
||||||
await self.upsert_node(source_node_id, {})
|
await self.upsert_node(source_node_id, {})
|
||||||
|
|
||||||
# Remove existing edge (if any)
|
await self.edge_collection.update_one(
|
||||||
await self.collection.update_one(
|
{"source_node_id": source_node_id, "target_node_id": target_node_id},
|
||||||
{"_id": source_node_id}, {"$pull": {"edges": {"target": target_node_id}}}
|
{"$set": edge_data},
|
||||||
)
|
upsert=True,
|
||||||
|
|
||||||
# Insert new edge
|
|
||||||
new_edge = {"target": target_node_id}
|
|
||||||
new_edge.update(edge_data)
|
|
||||||
await self.collection.update_one(
|
|
||||||
{"_id": source_node_id}, {"$push": {"edges": new_edge}}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|
@ -689,8 +607,10 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
1) Remove node's doc entirely.
|
1) Remove node's doc entirely.
|
||||||
2) Remove inbound edges from any doc that references node_id.
|
2) Remove inbound edges from any doc that references node_id.
|
||||||
"""
|
"""
|
||||||
# Remove inbound edges from all other docs
|
# Remove all edges
|
||||||
await self.collection.update_many({}, {"$pull": {"edges": {"target": node_id}}})
|
await self.edge_collection.delete_many(
|
||||||
|
{"$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]}
|
||||||
|
)
|
||||||
|
|
||||||
# Remove the node doc
|
# Remove the node doc
|
||||||
await self.collection.delete_one({"_id": node_id})
|
await self.collection.delete_one({"_id": node_id})
|
||||||
|
|
@ -734,151 +654,92 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
result = KnowledgeGraph()
|
result = KnowledgeGraph()
|
||||||
seen_nodes = set()
|
seen_nodes = set()
|
||||||
seen_edges = set()
|
seen_edges = set()
|
||||||
|
node_edges = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if label == "*":
|
pipeline = [
|
||||||
# Get all nodes and edges
|
{
|
||||||
async for node_doc in self.collection.find({}):
|
"$graphLookup": {
|
||||||
node_id = str(node_doc["_id"])
|
"from": self._edge_collection_name,
|
||||||
if node_id not in seen_nodes:
|
"startWith": "$_id",
|
||||||
result.nodes.append(
|
"connectFromField": "target_node_id",
|
||||||
KnowledgeGraphNode(
|
"connectToField": "source_node_id",
|
||||||
id=node_id,
|
"maxDepth": max_depth,
|
||||||
labels=[node_doc.get("_id")],
|
"depthField": "depth",
|
||||||
properties={
|
"as": "connected_edges",
|
||||||
k: v
|
},
|
||||||
for k, v in node_doc.items()
|
},
|
||||||
if k not in ["_id", "edges"]
|
{"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
|
||||||
},
|
{"$sort": {"edge_count": -1}},
|
||||||
)
|
{"$limit": max_nodes},
|
||||||
)
|
]
|
||||||
seen_nodes.add(node_id)
|
|
||||||
|
|
||||||
# Process edges
|
if label == "*":
|
||||||
for edge in node_doc.get("edges", []):
|
all_node_count = await self.collection.count_documents({})
|
||||||
edge_id = f"{node_id}-{edge['target']}"
|
result.is_truncated = all_node_count > max_nodes
|
||||||
if edge_id not in seen_edges:
|
|
||||||
result.edges.append(
|
|
||||||
KnowledgeGraphEdge(
|
|
||||||
id=edge_id,
|
|
||||||
type=edge.get("relation", ""),
|
|
||||||
source=node_id,
|
|
||||||
target=edge["target"],
|
|
||||||
properties={
|
|
||||||
k: v
|
|
||||||
for k, v in edge.items()
|
|
||||||
if k not in ["target", "relation"]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
seen_edges.add(edge_id)
|
|
||||||
else:
|
else:
|
||||||
# Verify if starting node exists
|
# Verify if starting node exists
|
||||||
start_nodes = self.collection.find({"_id": label})
|
start_node = await self.collection.find_one({"_id": label})
|
||||||
start_nodes_exist = await start_nodes.to_list(length=1)
|
if not start_node:
|
||||||
if not start_nodes_exist:
|
|
||||||
logger.warning(f"Starting node with label {label} does not exist!")
|
logger.warning(f"Starting node with label {label} does not exist!")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# Use $graphLookup for traversal
|
# Add starting node to pipeline
|
||||||
pipeline = [
|
pipeline.insert(0, {"$match": {"_id": label}})
|
||||||
{
|
|
||||||
"$match": {"_id": label}
|
|
||||||
}, # Start with nodes having the specified label
|
|
||||||
{
|
|
||||||
"$graphLookup": {
|
|
||||||
"from": self._collection_name,
|
|
||||||
"startWith": "$edges.target",
|
|
||||||
"connectFromField": "edges.target",
|
|
||||||
"connectToField": "_id",
|
|
||||||
"maxDepth": max_depth,
|
|
||||||
"depthField": "depth",
|
|
||||||
"as": "connected_nodes",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
async for doc in self.collection.aggregate(pipeline):
|
cursor = await self.collection.aggregate(pipeline)
|
||||||
# Add the start node
|
async for doc in cursor:
|
||||||
node_id = str(doc["_id"])
|
# Add the start node
|
||||||
if node_id not in seen_nodes:
|
node_id = str(doc["_id"])
|
||||||
result.nodes.append(
|
result.nodes.append(
|
||||||
KnowledgeGraphNode(
|
KnowledgeGraphNode(
|
||||||
id=node_id,
|
id=node_id,
|
||||||
labels=[
|
labels=[node_id],
|
||||||
doc.get(
|
properties={
|
||||||
"_id",
|
k: v
|
||||||
)
|
for k, v in doc.items()
|
||||||
],
|
if k
|
||||||
properties={
|
not in [
|
||||||
k: v
|
"_id",
|
||||||
for k, v in doc.items()
|
"connected_edges",
|
||||||
if k
|
"edge_count",
|
||||||
not in [
|
]
|
||||||
"_id",
|
},
|
||||||
"edges",
|
)
|
||||||
"connected_nodes",
|
)
|
||||||
"depth",
|
seen_nodes.add(node_id)
|
||||||
]
|
if doc.get("connected_edges", []):
|
||||||
},
|
node_edges.extend(doc.get("connected_edges"))
|
||||||
)
|
|
||||||
|
for edge in node_edges:
|
||||||
|
if (
|
||||||
|
edge["source_node_id"] not in seen_nodes
|
||||||
|
or edge["target_node_id"] not in seen_nodes
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
|
||||||
|
if edge_id not in seen_edges:
|
||||||
|
result.edges.append(
|
||||||
|
KnowledgeGraphEdge(
|
||||||
|
id=edge_id,
|
||||||
|
type=edge.get("relationship", ""),
|
||||||
|
source=edge["source_node_id"],
|
||||||
|
target=edge["target_node_id"],
|
||||||
|
properties={
|
||||||
|
k: v
|
||||||
|
for k, v in edge.items()
|
||||||
|
if k
|
||||||
|
not in [
|
||||||
|
"_id",
|
||||||
|
"source_node_id",
|
||||||
|
"target_node_id",
|
||||||
|
"relationship",
|
||||||
|
]
|
||||||
|
},
|
||||||
)
|
)
|
||||||
seen_nodes.add(node_id)
|
)
|
||||||
|
seen_edges.add(edge_id)
|
||||||
# Add edges from start node
|
|
||||||
for edge in doc.get("edges", []):
|
|
||||||
edge_id = f"{node_id}-{edge['target']}"
|
|
||||||
if edge_id not in seen_edges:
|
|
||||||
result.edges.append(
|
|
||||||
KnowledgeGraphEdge(
|
|
||||||
id=edge_id,
|
|
||||||
type=edge.get("relation", ""),
|
|
||||||
source=node_id,
|
|
||||||
target=edge["target"],
|
|
||||||
properties={
|
|
||||||
k: v
|
|
||||||
for k, v in edge.items()
|
|
||||||
if k not in ["target", "relation"]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
seen_edges.add(edge_id)
|
|
||||||
|
|
||||||
# Add connected nodes and their edges
|
|
||||||
for connected in doc.get("connected_nodes", []):
|
|
||||||
node_id = str(connected["_id"])
|
|
||||||
if node_id not in seen_nodes:
|
|
||||||
result.nodes.append(
|
|
||||||
KnowledgeGraphNode(
|
|
||||||
id=node_id,
|
|
||||||
labels=[connected.get("_id")],
|
|
||||||
properties={
|
|
||||||
k: v
|
|
||||||
for k, v in connected.items()
|
|
||||||
if k not in ["_id", "edges", "depth"]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
seen_nodes.add(node_id)
|
|
||||||
|
|
||||||
# Add edges from connected nodes
|
|
||||||
for edge in connected.get("edges", []):
|
|
||||||
edge_id = f"{node_id}-{edge['target']}"
|
|
||||||
if edge_id not in seen_edges:
|
|
||||||
result.edges.append(
|
|
||||||
KnowledgeGraphEdge(
|
|
||||||
id=edge_id,
|
|
||||||
type=edge.get("relation", ""),
|
|
||||||
source=node_id,
|
|
||||||
target=edge["target"],
|
|
||||||
properties={
|
|
||||||
k: v
|
|
||||||
for k, v in edge.items()
|
|
||||||
if k not in ["target", "relation"]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
seen_edges.add(edge_id)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||||
|
|
@ -903,9 +764,14 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
if not nodes:
|
if not nodes:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 1. Remove all edges referencing these nodes (remove from edges array of other nodes)
|
# 1. Remove all edges referencing these nodes
|
||||||
await self.collection.update_many(
|
await self.edge_collection.delete_many(
|
||||||
{}, {"$pull": {"edges": {"target": {"$in": nodes}}}}
|
{
|
||||||
|
"$or": [
|
||||||
|
{"source_node_id": {"$in": nodes}},
|
||||||
|
{"target_node_id": {"$in": nodes}},
|
||||||
|
]
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Delete the node documents
|
# 2. Delete the node documents
|
||||||
|
|
@ -923,25 +789,14 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
if not edges:
|
if not edges:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Group edges by source node (document _id) for efficient updates
|
await self.edge_collection.delete_many(
|
||||||
edge_groups = {}
|
{
|
||||||
for source_id, target_id in edges:
|
"$or": [
|
||||||
if source_id not in edge_groups:
|
{"source_node_id": source_id, "target_node_id": target_id}
|
||||||
edge_groups[source_id] = []
|
for source_id, target_id in edges
|
||||||
edge_groups[source_id].append(target_id)
|
]
|
||||||
|
}
|
||||||
# Bulk update documents to remove the edges
|
)
|
||||||
update_operations = []
|
|
||||||
for source_id, target_ids in edge_groups.items():
|
|
||||||
update_operations.append(
|
|
||||||
self.collection.update_one(
|
|
||||||
{"_id": source_id},
|
|
||||||
{"$pull": {"edges": {"target": {"$in": target_ids}}}},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if update_operations:
|
|
||||||
await asyncio.gather(*update_operations)
|
|
||||||
|
|
||||||
logger.debug(f"Successfully deleted edges: {edges}")
|
logger.debug(f"Successfully deleted edges: {edges}")
|
||||||
|
|
||||||
|
|
@ -958,9 +813,16 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Dropped {deleted_count} documents from graph {self._collection_name}"
|
f"Dropped {deleted_count} documents from graph {self._collection_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
result = await self.edge_collection.delete_many({})
|
||||||
|
edge_count = result.deleted_count
|
||||||
|
logger.info(
|
||||||
|
f"Dropped {edge_count} edges from graph {self._edge_collection_name}"
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"message": f"{deleted_count} documents dropped",
|
"message": f"{deleted_count} documents and {edge_count} edges dropped",
|
||||||
}
|
}
|
||||||
except PyMongoError as e:
|
except PyMongoError as e:
|
||||||
logger.error(f"Error dropping graph {self._collection_name}: {e}")
|
logger.error(f"Error dropping graph {self._collection_name}: {e}")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue