Fix missing nodes and edges when retrieving a knowledge subgraph for a particular node_id
This commit is contained in:
parent
c740401b7f
commit
f40bc43d5e
1 changed files with 76 additions and 23 deletions
|
|
@ -709,6 +709,29 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
labels.append(doc["_id"])
|
||||
return labels
|
||||
|
||||
def _construct_graph_node(
    self, node_data: dict[str, object], seen_nodes: set
) -> tuple[KnowledgeGraphNode, set]:
    """Build a ``KnowledgeGraphNode`` from a node document and mark it as seen.

    Args:
        node_data: Node document. Must contain ``"_id"``; every other key
            except the bookkeeping keys ``"_id"``, ``"connected_edges"`` and
            ``"edge_count"`` is copied into the node's ``properties``.
        seen_nodes: Set of node ids collected so far. It is mutated in
            place: the id of the constructed node is added to it.

    Returns:
        A ``(graph_node, seen_nodes)`` tuple, where ``seen_nodes`` is the
        same set object that was passed in.

    Raises:
        KeyError: If ``node_data`` has no ``"_id"`` key.
    """
    # Keys that describe the lookup result rather than the node itself and
    # therefore must not leak into the node's properties.
    internal_keys = ("_id", "connected_edges", "edge_count")

    # The id is stringified and doubles as the node's single label.
    node_id = str(node_data["_id"])
    properties = {
        key: value
        for key, value in node_data.items()
        if key not in internal_keys
    }

    graph_node = KnowledgeGraphNode(
        id=node_id,
        labels=[node_id],
        properties=properties,
    )
    seen_nodes.add(node_id)

    # Callers unpack both values, so the tuple shape is part of the contract.
    return graph_node, seen_nodes
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self,
|
||||
node_label: str,
|
||||
|
|
@ -730,13 +753,22 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
seen_nodes = set()
|
||||
seen_edges = set()
|
||||
node_edges = []
|
||||
project_doc = {
|
||||
"source_ids": 0,
|
||||
"created_at": 0,
|
||||
"entity_type": 0,
|
||||
"file_path": 0,
|
||||
}
|
||||
|
||||
try:
|
||||
# Optimize pipeline to avoid memory issues with large datasets
|
||||
if label == "*":
|
||||
# For getting all nodes, use a simpler pipeline to avoid memory issues
|
||||
pipeline = [
|
||||
{"$limit": max_nodes}, # Limit early to reduce memory usage
|
||||
{"$limit": max_nodes + 1}, # Limit early to reduce memory usage
|
||||
{
|
||||
"$project": project_doc
|
||||
}, # Load without internal fields or unneeded ones shown in WebUI
|
||||
{
|
||||
"$graphLookup": {
|
||||
"from": self._edge_collection_name,
|
||||
|
|
@ -748,6 +780,9 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
"as": "connected_edges",
|
||||
},
|
||||
},
|
||||
# {"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
|
||||
# {"$sort": {"edge_count": -1}},
|
||||
# {"$limit": max_nodes},
|
||||
]
|
||||
|
||||
# Check if we need to set truncation flag
|
||||
|
|
@ -763,6 +798,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
# For specific node queries, use the original pipeline but optimized
|
||||
pipeline = [
|
||||
{"$match": {"_id": label}},
|
||||
{"$project": project_doc},
|
||||
{
|
||||
"$graphLookup": {
|
||||
"from": self._edge_collection_name,
|
||||
|
|
@ -774,34 +810,16 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
"as": "connected_edges",
|
||||
},
|
||||
},
|
||||
{"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
|
||||
{"$sort": {"edge_count": -1}},
|
||||
{"$limit": max_nodes},
|
||||
]
|
||||
|
||||
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
|
||||
nodes_processed = 0
|
||||
|
||||
async for doc in cursor:
|
||||
# Add the start node
|
||||
node_id = str(doc["_id"])
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=node_id,
|
||||
labels=[node_id],
|
||||
properties={
|
||||
k: v
|
||||
for k, v in doc.items()
|
||||
if k
|
||||
not in [
|
||||
"_id",
|
||||
"connected_edges",
|
||||
"edge_count",
|
||||
]
|
||||
},
|
||||
)
|
||||
)
|
||||
seen_nodes.add(node_id)
|
||||
# Add the start nodes
|
||||
graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes)
|
||||
result.nodes.append(graph_node)
|
||||
|
||||
if doc.get("connected_edges", []):
|
||||
node_edges.extend(doc.get("connected_edges"))
|
||||
|
||||
|
|
@ -812,6 +830,41 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
result.is_truncated = True
|
||||
break
|
||||
|
||||
# When label != "*", the cursor above yields only one node, so we need to build the subgraph from its connected edges
|
||||
# Sort the connected edges by depth ascending and weight descending
|
||||
# And stores the source_node_id and target_node_id in sequence to retrieve the nodes again
|
||||
if label != "*":
|
||||
node_edges = sorted(
|
||||
node_edges,
|
||||
key=lambda x: (x["depth"], -x["weight"]),
|
||||
)
|
||||
|
||||
# As order matters, we need to use another list to store the node_id
|
||||
# And only take the first max_nodes ones
|
||||
node_ids = []
|
||||
for edge in node_edges:
|
||||
if (
|
||||
len(node_ids) < max_nodes
|
||||
and edge["source_node_id"] not in seen_nodes
|
||||
):
|
||||
node_ids.append(edge["source_node_id"])
|
||||
seen_nodes.add(edge["source_node_id"])
|
||||
|
||||
if (
|
||||
len(node_ids) < max_nodes
|
||||
and edge["target_node_id"] not in seen_nodes
|
||||
):
|
||||
node_ids.append(edge["target_node_id"])
|
||||
seen_nodes.add(edge["target_node_id"])
|
||||
|
||||
cursor = self.collection.find(
|
||||
{"_id": {"$in": [node for node in node_ids if node != label]}}
|
||||
)
|
||||
|
||||
async for doc in cursor:
|
||||
graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes)
|
||||
result.nodes.append(graph_node)
|
||||
|
||||
for edge in node_edges:
|
||||
if (
|
||||
edge["source_node_id"] not in seen_nodes
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue