Fix missing nodes and edges when retrieving a knowledge subgraph for a particular node_id

This commit is contained in:
Ken Chen 2025-06-26 23:11:31 +08:00
parent c740401b7f
commit f40bc43d5e

View file

@ -709,6 +709,29 @@ class MongoGraphStorage(BaseGraphStorage):
labels.append(doc["_id"]) labels.append(doc["_id"])
return labels return labels
def _construct_graph_node(
    self, node_data: dict[str, str], seen_nodes: set
) -> tuple[KnowledgeGraphNode, set]:
    """Build a ``KnowledgeGraphNode`` from a node document and record its id.

    The node id is ``str(node_data["_id"])``; it is used both as the node's
    ``id`` and as its single label. Every other key of ``node_data`` is copied
    into ``properties`` except ``_id`` and the ``connected_edges`` /
    ``edge_count`` fields that the aggregation pipelines attach to a document.

    Args:
        node_data: Node document; must contain an ``_id`` key.
        seen_nodes: Set of node ids collected so far. It is mutated in place:
            the new node id is added to it.

    Returns:
        A ``(graph_node, seen_nodes)`` tuple, where ``seen_nodes`` is the same
        set object that was passed in.

    Raises:
        KeyError: If ``node_data`` has no ``_id`` key.
    """
    # Identifier and pipeline by-products; these are not node properties.
    excluded_keys = ("_id", "connected_edges", "edge_count")

    node_id = str(node_data["_id"])
    graph_node = KnowledgeGraphNode(
        id=node_id,
        labels=[node_id],
        properties={
            key: value
            for key, value in node_data.items()
            if key not in excluded_keys
        },
    )
    seen_nodes.add(node_id)
    return graph_node, seen_nodes
async def get_knowledge_graph( async def get_knowledge_graph(
self, self,
node_label: str, node_label: str,
@ -730,13 +753,22 @@ class MongoGraphStorage(BaseGraphStorage):
seen_nodes = set() seen_nodes = set()
seen_edges = set() seen_edges = set()
node_edges = [] node_edges = []
project_doc = {
"source_ids": 0,
"created_at": 0,
"entity_type": 0,
"file_path": 0,
}
try: try:
# Optimize pipeline to avoid memory issues with large datasets # Optimize pipeline to avoid memory issues with large datasets
if label == "*": if label == "*":
# For getting all nodes, use a simpler pipeline to avoid memory issues # For getting all nodes, use a simpler pipeline to avoid memory issues
pipeline = [ pipeline = [
{"$limit": max_nodes}, # Limit early to reduce memory usage {"$limit": max_nodes + 1}, # Limit early to reduce memory usage
{
"$project": project_doc
}, # Load without internal fields or unneeded ones shown in WebUI
{ {
"$graphLookup": { "$graphLookup": {
"from": self._edge_collection_name, "from": self._edge_collection_name,
@ -748,6 +780,9 @@ class MongoGraphStorage(BaseGraphStorage):
"as": "connected_edges", "as": "connected_edges",
}, },
}, },
# {"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
# {"$sort": {"edge_count": -1}},
# {"$limit": max_nodes},
] ]
# Check if we need to set truncation flag # Check if we need to set truncation flag
@ -763,6 +798,7 @@ class MongoGraphStorage(BaseGraphStorage):
# For specific node queries, use the original pipeline but optimized # For specific node queries, use the original pipeline but optimized
pipeline = [ pipeline = [
{"$match": {"_id": label}}, {"$match": {"_id": label}},
{"$project": project_doc},
{ {
"$graphLookup": { "$graphLookup": {
"from": self._edge_collection_name, "from": self._edge_collection_name,
@ -774,34 +810,16 @@ class MongoGraphStorage(BaseGraphStorage):
"as": "connected_edges", "as": "connected_edges",
}, },
}, },
{"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
{"$sort": {"edge_count": -1}},
{"$limit": max_nodes},
] ]
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True) cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
nodes_processed = 0 nodes_processed = 0
async for doc in cursor: async for doc in cursor:
# Add the start node # Add the start nodes
node_id = str(doc["_id"]) graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes)
result.nodes.append( result.nodes.append(graph_node)
KnowledgeGraphNode(
id=node_id,
labels=[node_id],
properties={
k: v
for k, v in doc.items()
if k
not in [
"_id",
"connected_edges",
"edge_count",
]
},
)
)
seen_nodes.add(node_id)
if doc.get("connected_edges", []): if doc.get("connected_edges", []):
node_edges.extend(doc.get("connected_edges")) node_edges.extend(doc.get("connected_edges"))
@ -812,6 +830,41 @@ class MongoGraphStorage(BaseGraphStorage):
result.is_truncated = True result.is_truncated = True
break break
# When label != "*", the cursor above yields only one node, so the subgraph has to be built from its connected edges.
# Sort the connected edges by depth ascending and weight descending,
# then collect source_node_id and target_node_id in that order so the nodes can be retrieved afterwards.
if label != "*":
node_edges = sorted(
node_edges,
key=lambda x: (x["depth"], -x["weight"]),
)
# As order matters, we need to use another list to store the node_id
# And only take the first max_nodes ones
node_ids = []
for edge in node_edges:
if (
len(node_ids) < max_nodes
and edge["source_node_id"] not in seen_nodes
):
node_ids.append(edge["source_node_id"])
seen_nodes.add(edge["source_node_id"])
if (
len(node_ids) < max_nodes
and edge["target_node_id"] not in seen_nodes
):
node_ids.append(edge["target_node_id"])
seen_nodes.add(edge["target_node_id"])
cursor = self.collection.find(
{"_id": {"$in": [node for node in node_ids if node != label]}}
)
async for doc in cursor:
graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes)
result.nodes.append(graph_node)
for edge in node_edges: for edge in node_edges:
if ( if (
edge["source_node_id"] not in seen_nodes edge["source_node_id"] not in seen_nodes