Rewrite get_knowledge_graph with label * by degree
This commit is contained in:
parent
d0f4eee404
commit
5739f52d29
1 changed files with 80 additions and 32 deletions
|
|
@ -743,10 +743,85 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
"source_node_id",
|
||||
"target_node_id",
|
||||
"relationship",
|
||||
"source_ids",
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
async def get_knowledge_graph_all_by_degree(
    self, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
) -> KnowledgeGraph:
    """
    Return the whole graph, or only its highest-degree nodes when it is too large.

    If the node collection holds at most ``max_nodes`` documents, every node
    and every edge is returned. Otherwise the result is flagged as truncated
    and limited to the ``max_nodes`` node ids with the highest degree, where
    degree is the number of edge documents in which the id appears as
    ``source_node_id`` or ``target_node_id``. In that case only edges whose
    two endpoints were both selected are returned.

    Caveats of the truncated mode:
      * A node may be selected while some (or all) of its neighbours are
        not, so it can look disconnected in the UI even though it has
        relationships.
      * Ranking is computed from the edge collection, so a node that is not
        referenced by any edge is never selected.

    Args:
        max_depth: Not used by this method.
        max_nodes: Maximum number of nodes to return.

    Returns:
        A KnowledgeGraph whose ``is_truncated`` flag is True when the node
        collection contains more than ``max_nodes`` documents.
    """
    total_node_count = await self.collection.count_documents({})
    result = KnowledgeGraph()
    result.is_truncated = total_node_count > max_nodes

    if result.is_truncated:
        # The total node count exceeds max_nodes: rank node ids by degree
        # and keep only the top max_nodes of them.
        pipeline = [
            # Count how often each id occurs as an edge source ...
            {"$project": {"source_node_id": 1, "_id": 0}},
            {"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}},
            # ... append the counts of how often each id occurs as a target ...
            {
                "$unionWith": {
                    "coll": self._edge_collection_name,
                    "pipeline": [
                        {"$project": {"target_node_id": 1, "_id": 0}},
                        {
                            "$group": {
                                "_id": "$target_node_id",
                                "degree": {"$sum": 1},
                            }
                        },
                    ],
                }
            },
            # ... and merge both partial counts into one degree per id.
            {"$group": {"_id": "$_id", "degree": {"$sum": "$degree"}}},
            {"$sort": {"degree": -1}},
            {"$limit": max_nodes},
        ]
        cursor = await self.edge_collection.aggregate(pipeline, allowDiskUse=True)
        node_ids = [str(doc["_id"]) async for doc in cursor]

        # Load the selected nodes without their "source_ids" field.
        cursor = self.collection.find({"_id": {"$in": node_ids}}, {"source_ids": 0})
        async for doc in cursor:
            result.nodes.append(self._construct_graph_node(doc["_id"], doc))

        # The node budget is exhausted, so only edges that connect two of
        # the selected nodes are fetched.
        edge_cursor = self.edge_collection.find(
            {
                "$and": [
                    {"source_node_id": {"$in": node_ids}},
                    {"target_node_id": {"$in": node_ids}},
                ]
            }
        )
    else:
        # Everything fits within max_nodes: all nodes and edges are needed.
        cursor = self.collection.find({}, {"source_ids": 0})
        async for doc in cursor:
            result.nodes.append(self._construct_graph_node(doc["_id"], doc))

        edge_cursor = self.edge_collection.find({})

    # Emit each (source, target) pair once, even if several edge documents
    # share the same endpoints.
    seen_edges = set()
    async for edge in edge_cursor:
        edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
        if edge_id not in seen_edges:
            seen_edges.add(edge_id)
            result.edges.append(self._construct_graph_edge(edge_id, edge))

    return result
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self,
|
||||
node_label: str,
|
||||
|
|
@ -778,31 +853,9 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
try:
|
||||
# Optimize pipeline to avoid memory issues with large datasets
|
||||
if label == "*":
|
||||
# For getting all nodes, use a simpler pipeline to avoid memory issues
|
||||
pipeline = [
|
||||
{"$limit": max_nodes + 1}, # Limit early to reduce memory usage
|
||||
{
|
||||
"$project": project_doc
|
||||
}, # Load without internal fields or unneeded ones shown in WebUI
|
||||
{
|
||||
"$graphLookup": {
|
||||
"from": self._edge_collection_name,
|
||||
"startWith": "$_id",
|
||||
"connectFromField": "target_node_id",
|
||||
"connectToField": "source_node_id",
|
||||
"maxDepth": max_depth,
|
||||
"depthField": "depth",
|
||||
"as": "connected_edges",
|
||||
},
|
||||
},
|
||||
# {"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
|
||||
# {"$sort": {"edge_count": -1}},
|
||||
# {"$limit": max_nodes},
|
||||
]
|
||||
|
||||
# Check if we need to set truncation flag
|
||||
all_node_count = await self.collection.count_documents({})
|
||||
result.is_truncated = all_node_count > max_nodes
|
||||
return await self.get_knowledge_graph_all_by_degree(
|
||||
max_depth, max_nodes
|
||||
)
|
||||
else:
|
||||
# Verify if starting node exists
|
||||
start_node = await self.collection.find_one({"_id": label})
|
||||
|
|
@ -827,7 +880,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
},
|
||||
{
|
||||
"$unionWith": {
|
||||
"coll": "chunk_entity_relation",
|
||||
"coll": self._collection_name,
|
||||
"pipeline": [
|
||||
{"$match": {"_id": label}},
|
||||
{"$project": project_doc},
|
||||
|
|
@ -929,13 +982,8 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
try:
|
||||
simple_cursor = self.collection.find({}).limit(max_nodes)
|
||||
async for doc in simple_cursor:
|
||||
node_id = str(doc["_id"])
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=node_id,
|
||||
labels=[node_id],
|
||||
properties={k: v for k, v in doc.items() if k != "_id"},
|
||||
)
|
||||
self._construct_graph_node(str(doc["_id"]), doc)
|
||||
)
|
||||
result.is_truncated = True
|
||||
logger.info(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue