Rewrite get_knowledge_graph with label * by degree

This commit is contained in:
Ken Chen 2025-06-28 17:10:39 +08:00
parent d0f4eee404
commit 5739f52d29

View file

@ -743,10 +743,85 @@ class MongoGraphStorage(BaseGraphStorage):
"source_node_id",
"target_node_id",
"relationship",
"source_ids",
]
},
)
async def get_knowledge_graph_all_by_degree(
self, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
) -> KnowledgeGraph:
"""
It's possible that the node with one or multiple relationships is retrieved,
while its neighbor is not. Then this node might seem like disconnected in UI.
"""
total_node_count = await self.collection.count_documents({})
result = KnowledgeGraph()
seen_edges = set()
result.is_truncated = total_node_count > max_nodes
if result.is_truncated:
# Get all node_ids ranked by degree if max_nodes exceeds total node count
pipeline = [
{"$project": {"source_node_id": 1, "_id": 0}},
{"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}},
{
"$unionWith": {
"coll": self._edge_collection_name,
"pipeline": [
{"$project": {"target_node_id": 1, "_id": 0}},
{
"$group": {
"_id": "$target_node_id",
"degree": {"$sum": 1},
}
},
],
}
},
{"$group": {"_id": "$_id", "degree": {"$sum": "$degree"}}},
{"$sort": {"degree": -1}},
{"$limit": max_nodes},
]
cursor = await self.edge_collection.aggregate(pipeline, allowDiskUse=True)
node_ids = []
async for doc in cursor:
node_id = str(doc["_id"])
node_ids.append(node_id)
cursor = self.collection.find({"_id": {"$in": node_ids}}, {"source_ids": 0})
async for doc in cursor:
result.nodes.append(self._construct_graph_node(doc["_id"], doc))
# As node count reaches the limit, only need to fetch the edges that directly connect to these nodes
edge_cursor = self.edge_collection.find(
{
"$and": [
{"source_node_id": {"$in": node_ids}},
{"target_node_id": {"$in": node_ids}},
]
}
)
else:
# All nodes and edges are needed
cursor = self.collection.find({}, {"source_ids": 0})
async for doc in cursor:
node_id = str(doc["_id"])
result.nodes.append(self._construct_graph_node(doc["_id"], doc))
edge_cursor = self.edge_collection.find({})
async for edge in edge_cursor:
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
if edge_id not in seen_edges:
seen_edges.add(edge_id)
result.edges.append(self._construct_graph_edge(edge_id, edge))
return result
async def get_knowledge_graph(
self,
node_label: str,
@ -778,31 +853,9 @@ class MongoGraphStorage(BaseGraphStorage):
try:
# Optimize pipeline to avoid memory issues with large datasets
if label == "*":
# For getting all nodes, use a simpler pipeline to avoid memory issues
pipeline = [
{"$limit": max_nodes + 1}, # Limit early to reduce memory usage
{
"$project": project_doc
}, # Load without internal fields or unneeded ones shown in WebUI
{
"$graphLookup": {
"from": self._edge_collection_name,
"startWith": "$_id",
"connectFromField": "target_node_id",
"connectToField": "source_node_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_edges",
},
},
# {"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
# {"$sort": {"edge_count": -1}},
# {"$limit": max_nodes},
]
# Check if we need to set truncation flag
all_node_count = await self.collection.count_documents({})
result.is_truncated = all_node_count > max_nodes
return await self.get_knowledge_graph_all_by_degree(
max_depth, max_nodes
)
else:
# Verify if starting node exists
start_node = await self.collection.find_one({"_id": label})
@ -827,7 +880,7 @@ class MongoGraphStorage(BaseGraphStorage):
},
{
"$unionWith": {
"coll": "chunk_entity_relation",
"coll": self._collection_name,
"pipeline": [
{"$match": {"_id": label}},
{"$project": project_doc},
@ -929,13 +982,8 @@ class MongoGraphStorage(BaseGraphStorage):
try:
simple_cursor = self.collection.find({}).limit(max_nodes)
async for doc in simple_cursor:
node_id = str(doc["_id"])
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_id],
properties={k: v for k, v in doc.items() if k != "_id"},
)
self._construct_graph_node(str(doc["_id"]), doc)
)
result.is_truncated = True
logger.info(