Fix nodes & edges are missing when retrieving knowledge subgraph by selecting particular node_id

This commit is contained in:
Ken Chen 2025-06-26 23:11:31 +08:00
parent c740401b7f
commit f40bc43d5e

View file

@ -709,6 +709,29 @@ class MongoGraphStorage(BaseGraphStorage):
labels.append(doc["_id"])
return labels
def _construct_graph_node(
self, node_data: dict[str, str], seen_nodes: set
) -> KnowledgeGraphNode:
node_id = str(node_data["_id"])
graph_node = KnowledgeGraphNode(
id=node_id,
labels=[node_id],
properties={
k: v
for k, v in node_data.items()
if k
not in [
"_id",
"connected_edges",
"edge_count",
]
},
)
seen_nodes.add(node_id)
return graph_node, seen_nodes
async def get_knowledge_graph(
self,
node_label: str,
@ -730,13 +753,22 @@ class MongoGraphStorage(BaseGraphStorage):
seen_nodes = set()
seen_edges = set()
node_edges = []
project_doc = {
"source_ids": 0,
"created_at": 0,
"entity_type": 0,
"file_path": 0,
}
try:
# Optimize pipeline to avoid memory issues with large datasets
if label == "*":
# For getting all nodes, use a simpler pipeline to avoid memory issues
pipeline = [
{"$limit": max_nodes}, # Limit early to reduce memory usage
{"$limit": max_nodes + 1}, # Limit early to reduce memory usage
{
"$project": project_doc
}, # Load without internal fields or unneeded ones shown in WebUI
{
"$graphLookup": {
"from": self._edge_collection_name,
@ -748,6 +780,9 @@ class MongoGraphStorage(BaseGraphStorage):
"as": "connected_edges",
},
},
# {"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
# {"$sort": {"edge_count": -1}},
# {"$limit": max_nodes},
]
# Check if we need to set truncation flag
@ -763,6 +798,7 @@ class MongoGraphStorage(BaseGraphStorage):
# For specific node queries, use the original pipeline but optimized
pipeline = [
{"$match": {"_id": label}},
{"$project": project_doc},
{
"$graphLookup": {
"from": self._edge_collection_name,
@ -774,34 +810,16 @@ class MongoGraphStorage(BaseGraphStorage):
"as": "connected_edges",
},
},
{"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
{"$sort": {"edge_count": -1}},
{"$limit": max_nodes},
]
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
nodes_processed = 0
async for doc in cursor:
# Add the start node
node_id = str(doc["_id"])
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_id],
properties={
k: v
for k, v in doc.items()
if k
not in [
"_id",
"connected_edges",
"edge_count",
]
},
)
)
seen_nodes.add(node_id)
# Add the start nodes
graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes)
result.nodes.append(graph_node)
if doc.get("connected_edges", []):
node_edges.extend(doc.get("connected_edges"))
@ -812,6 +830,41 @@ class MongoGraphStorage(BaseGraphStorage):
result.is_truncated = True
break
# When label != "*", cursor above only have one node and we need to get the subgraph by connected edges
# Sort the connected edges by depth ascending and weight descending
# And stores the source_node_id and target_node_id in sequence to retrieve the nodes again
if label != "*":
node_edges = sorted(
node_edges,
key=lambda x: (x["depth"], -x["weight"]),
)
# As order matters, we need to use another list to store the node_id
# And only take the first max_nodes ones
node_ids = []
for edge in node_edges:
if (
len(node_ids) < max_nodes
and edge["source_node_id"] not in seen_nodes
):
node_ids.append(edge["source_node_id"])
seen_nodes.add(edge["source_node_id"])
if (
len(node_ids) < max_nodes
and edge["target_node_id"] not in seen_nodes
):
node_ids.append(edge["target_node_id"])
seen_nodes.add(edge["target_node_id"])
cursor = self.collection.find(
{"_id": {"$in": [node for node in node_ids if node != label]}}
)
async for doc in cursor:
graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes)
result.nodes.append(graph_node)
for edge in node_edges:
if (
edge["source_node_id"] not in seen_nodes