Fix missing nodes and edges when retrieving a knowledge subgraph for a particular node_id
This commit is contained in:
parent
c740401b7f
commit
f40bc43d5e
1 changed files with 76 additions and 23 deletions
|
|
@ -709,6 +709,29 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
labels.append(doc["_id"])
|
labels.append(doc["_id"])
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
def _construct_graph_node(
    self, node_data: dict[str, object], seen_nodes: set
) -> "tuple[KnowledgeGraphNode, set]":
    """Build a ``KnowledgeGraphNode`` from a node document and record its id.

    Args:
        node_data: Node document. Must contain ``_id``; every other key is
            copied into the node's ``properties`` except the bookkeeping
            keys listed below.
        seen_nodes: Set of node ids collected so far. It is mutated in
            place: the new node's id is added to it.

    Returns:
        A ``(graph_node, seen_nodes)`` tuple. ``seen_nodes`` is the same set
        object that was passed in, returned so callers can unpack both.

    Raises:
        KeyError: If ``node_data`` has no ``_id`` key.
    """
    node_id = str(node_data["_id"])

    # "_id" is already exposed as the node id/label; "connected_edges" and
    # "edge_count" are fields added by the aggregation pipeline stages
    # ($graphLookup / $addFields), not properties of the node itself.
    excluded_keys = ("_id", "connected_edges", "edge_count")

    graph_node = KnowledgeGraphNode(
        id=node_id,
        labels=[node_id],
        properties={
            key: value
            for key, value in node_data.items()
            if key not in excluded_keys
        },
    )
    seen_nodes.add(node_id)

    return graph_node, seen_nodes
|
||||||
|
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
self,
|
self,
|
||||||
node_label: str,
|
node_label: str,
|
||||||
|
|
@ -730,13 +753,22 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
seen_nodes = set()
|
seen_nodes = set()
|
||||||
seen_edges = set()
|
seen_edges = set()
|
||||||
node_edges = []
|
node_edges = []
|
||||||
|
project_doc = {
|
||||||
|
"source_ids": 0,
|
||||||
|
"created_at": 0,
|
||||||
|
"entity_type": 0,
|
||||||
|
"file_path": 0,
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Optimize pipeline to avoid memory issues with large datasets
|
# Optimize pipeline to avoid memory issues with large datasets
|
||||||
if label == "*":
|
if label == "*":
|
||||||
# For getting all nodes, use a simpler pipeline to avoid memory issues
|
# For getting all nodes, use a simpler pipeline to avoid memory issues
|
||||||
pipeline = [
|
pipeline = [
|
||||||
{"$limit": max_nodes}, # Limit early to reduce memory usage
|
{"$limit": max_nodes + 1}, # Limit early to reduce memory usage
|
||||||
|
{
|
||||||
|
"$project": project_doc
|
||||||
|
}, # Load without internal fields or unneeded ones shown in WebUI
|
||||||
{
|
{
|
||||||
"$graphLookup": {
|
"$graphLookup": {
|
||||||
"from": self._edge_collection_name,
|
"from": self._edge_collection_name,
|
||||||
|
|
@ -748,6 +780,9 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
"as": "connected_edges",
|
"as": "connected_edges",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
# {"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
|
||||||
|
# {"$sort": {"edge_count": -1}},
|
||||||
|
# {"$limit": max_nodes},
|
||||||
]
|
]
|
||||||
|
|
||||||
# Check if we need to set truncation flag
|
# Check if we need to set truncation flag
|
||||||
|
|
@ -763,6 +798,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
# For specific node queries, use the original pipeline but optimized
|
# For specific node queries, use the original pipeline but optimized
|
||||||
pipeline = [
|
pipeline = [
|
||||||
{"$match": {"_id": label}},
|
{"$match": {"_id": label}},
|
||||||
|
{"$project": project_doc},
|
||||||
{
|
{
|
||||||
"$graphLookup": {
|
"$graphLookup": {
|
||||||
"from": self._edge_collection_name,
|
"from": self._edge_collection_name,
|
||||||
|
|
@ -774,34 +810,16 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
"as": "connected_edges",
|
"as": "connected_edges",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
|
|
||||||
{"$sort": {"edge_count": -1}},
|
|
||||||
{"$limit": max_nodes},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
|
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
|
||||||
nodes_processed = 0
|
nodes_processed = 0
|
||||||
|
|
||||||
async for doc in cursor:
|
async for doc in cursor:
|
||||||
# Add the start node
|
# Add the start nodes
|
||||||
node_id = str(doc["_id"])
|
graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes)
|
||||||
result.nodes.append(
|
result.nodes.append(graph_node)
|
||||||
KnowledgeGraphNode(
|
|
||||||
id=node_id,
|
|
||||||
labels=[node_id],
|
|
||||||
properties={
|
|
||||||
k: v
|
|
||||||
for k, v in doc.items()
|
|
||||||
if k
|
|
||||||
not in [
|
|
||||||
"_id",
|
|
||||||
"connected_edges",
|
|
||||||
"edge_count",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
seen_nodes.add(node_id)
|
|
||||||
if doc.get("connected_edges", []):
|
if doc.get("connected_edges", []):
|
||||||
node_edges.extend(doc.get("connected_edges"))
|
node_edges.extend(doc.get("connected_edges"))
|
||||||
|
|
||||||
|
|
@ -812,6 +830,41 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
result.is_truncated = True
|
result.is_truncated = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# When label != "*", the cursor above yields only one node, so the subgraph has to be built from its connected edges
|
||||||
|
# Sort the connected edges by depth ascending and weight descending
|
||||||
|
# Then collect source_node_id and target_node_id in that order so the nodes can be fetched afterwards
|
||||||
|
if label != "*":
|
||||||
|
node_edges = sorted(
|
||||||
|
node_edges,
|
||||||
|
key=lambda x: (x["depth"], -x["weight"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# As order matters, we need to use another list to store the node_id
|
||||||
|
# And only take the first max_nodes ones
|
||||||
|
node_ids = []
|
||||||
|
for edge in node_edges:
|
||||||
|
if (
|
||||||
|
len(node_ids) < max_nodes
|
||||||
|
and edge["source_node_id"] not in seen_nodes
|
||||||
|
):
|
||||||
|
node_ids.append(edge["source_node_id"])
|
||||||
|
seen_nodes.add(edge["source_node_id"])
|
||||||
|
|
||||||
|
if (
|
||||||
|
len(node_ids) < max_nodes
|
||||||
|
and edge["target_node_id"] not in seen_nodes
|
||||||
|
):
|
||||||
|
node_ids.append(edge["target_node_id"])
|
||||||
|
seen_nodes.add(edge["target_node_id"])
|
||||||
|
|
||||||
|
cursor = self.collection.find(
|
||||||
|
{"_id": {"$in": [node for node in node_ids if node != label]}}
|
||||||
|
)
|
||||||
|
|
||||||
|
async for doc in cursor:
|
||||||
|
graph_node, seen_nodes = self._construct_graph_node(doc, seen_nodes)
|
||||||
|
result.nodes.append(graph_node)
|
||||||
|
|
||||||
for edge in node_edges:
|
for edge in node_edges:
|
||||||
if (
|
if (
|
||||||
edge["source_node_id"] not in seen_nodes
|
edge["source_node_id"] not in seen_nodes
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue