Merge branch 'kenspirit/main'

This commit is contained in:
yangdx 2025-06-29 00:30:58 +08:00
commit 6d5e73a251
2 changed files with 379 additions and 130 deletions

View file

@ -297,6 +297,8 @@ class BaseKVStorage(StorageNameSpace, ABC):
@dataclass
class BaseGraphStorage(StorageNameSpace, ABC):
"""All operations related to edges in graph should be undirected."""
embedding_func: EmbeddingFunc
@abstractmethod

View file

@ -1,4 +1,5 @@
import os
import time
from dataclasses import dataclass, field
import numpy as np
import configparser
@ -35,6 +36,7 @@ config.read("config.ini", "utf-8")
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")
class ClientManager:
@ -417,11 +419,21 @@ class MongoGraphStorage(BaseGraphStorage):
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""
Check if there's a direct single-hop edge from source_node_id to target_node_id.
Check if there's a direct single-hop edge between source_node_id and target_node_id.
"""
# Direct check if the target_node appears among the edges array.
doc = await self.edge_collection.find_one(
{"source_node_id": source_node_id, "target_node_id": target_node_id},
{
"$or": [
{
"source_node_id": source_node_id,
"target_node_id": target_node_id,
},
{
"source_node_id": target_node_id,
"target_node_id": source_node_id,
},
]
},
{"_id": 1},
)
return doc is not None
@ -651,7 +663,7 @@ class MongoGraphStorage(BaseGraphStorage):
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""
Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
Upsert an edge between source_node_id and target_node_id with optional 'relation'.
If an edge with the same target exists, we remove it and re-insert with updated data.
"""
# Ensure source node exists
@ -663,8 +675,22 @@ class MongoGraphStorage(BaseGraphStorage):
GRAPH_FIELD_SEP
)
edge_data["source_node_id"] = source_node_id
edge_data["target_node_id"] = target_node_id
await self.edge_collection.update_one(
{"source_node_id": source_node_id, "target_node_id": target_node_id},
{
"$or": [
{
"source_node_id": source_node_id,
"target_node_id": target_node_id,
},
{
"source_node_id": target_node_id,
"target_node_id": source_node_id,
},
]
},
update_doc,
upsert=True,
)
@ -678,7 +704,7 @@ class MongoGraphStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None:
"""
1) Remove node's doc entirely.
2) Remove inbound edges from any doc that references node_id.
2) Remove inbound & outbound edges from any doc that references node_id.
"""
# Remove all edges
await self.edge_collection.delete_many(
@ -709,141 +735,369 @@ class MongoGraphStorage(BaseGraphStorage):
labels.append(doc["_id"])
return labels
def _construct_graph_node(
    self, node_id, node_data: dict[str, str]
) -> KnowledgeGraphNode:
    """Build a ``KnowledgeGraphNode`` from a node document.

    Args:
        node_id: Value used as the node's ``id`` and as its single label.
        node_data: Node document; every entry except the excluded keys
            below is copied into the node's ``properties``.

    Returns:
        A ``KnowledgeGraphNode`` with ``labels == [node_id]``.
    """
    # Keys that are never exposed as node properties.
    excluded_keys = ("_id", "connected_edges", "source_ids", "edge_count")

    exposed_properties = {}
    for key, value in node_data.items():
        if key in excluded_keys:
            continue
        exposed_properties[key] = value

    return KnowledgeGraphNode(
        id=node_id,
        labels=[node_id],
        properties=exposed_properties,
    )
def _construct_graph_edge(self, edge_id: str, edge: dict[str, str]):
    """Build a ``KnowledgeGraphEdge`` from an edge document.

    Args:
        edge_id: Identifier to assign to the resulting edge.
        edge: Edge document. Must contain ``source_node_id`` and
            ``target_node_id``; ``relationship`` is optional and becomes
            the edge type (empty string when absent).

    Returns:
        A ``KnowledgeGraphEdge`` whose ``properties`` hold every entry of
        ``edge`` except the excluded keys below.
    """
    # Keys that are mapped to dedicated fields or never exposed as properties.
    excluded_keys = (
        "_id",
        "source_node_id",
        "target_node_id",
        "relationship",
        "source_ids",
    )

    edge_type = edge.get("relationship", "")
    source_id = edge["source_node_id"]
    target_id = edge["target_node_id"]

    exposed_properties = {}
    for key, value in edge.items():
        if key not in excluded_keys:
            exposed_properties[key] = value

    return KnowledgeGraphEdge(
        id=edge_id,
        type=edge_type,
        source=source_id,
        target=target_id,
        properties=exposed_properties,
    )
async def get_knowledge_graph_all_by_degree(
    self, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
) -> KnowledgeGraph:
    """Return the whole graph, or only its highest-degree nodes when too large.

    When the node collection holds more than ``max_nodes`` documents, only the
    ``max_nodes`` node ids with the highest degree are kept. Degree is the
    number of edge documents in which the id appears as ``source_node_id``
    plus the number in which it appears as ``target_node_id``. Only edges
    whose two endpoints are both kept are returned. Otherwise every node and
    every edge is returned.

    It's possible that a node with one or multiple relationships is retrieved
    while its neighbor is not. Such a node might seem disconnected in the UI.

    Args:
        max_depth: Not used by this method.
        max_nodes: Maximum number of nodes to return.

    Returns:
        A ``KnowledgeGraph`` whose ``is_truncated`` flag is True when the
        node collection holds more than ``max_nodes`` documents.
    """
    total_node_count = await self.collection.count_documents({})
    result = KnowledgeGraph()
    seen_edges = set()

    result.is_truncated = total_node_count > max_nodes

    if result.is_truncated:
        # The total node count exceeds max_nodes: rank node ids by degree
        # and keep only the top max_nodes of them.
        pipeline = [
            # Count occurrences as edge source ...
            {"$project": {"source_node_id": 1, "_id": 0}},
            {"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}},
            # ... append occurrences as edge target ...
            {
                "$unionWith": {
                    "coll": self._edge_collection_name,
                    "pipeline": [
                        {"$project": {"target_node_id": 1, "_id": 0}},
                        {
                            "$group": {
                                "_id": "$target_node_id",
                                "degree": {"$sum": 1},
                            }
                        },
                    ],
                }
            },
            # ... and add both counts together per node id.
            {"$group": {"_id": "$_id", "degree": {"$sum": "$degree"}}},
            {"$sort": {"degree": -1}},
            {"$limit": max_nodes},
        ]
        cursor = await self.edge_collection.aggregate(pipeline, allowDiskUse=True)
        node_ids = [str(doc["_id"]) async for doc in cursor]

        cursor = self.collection.find({"_id": {"$in": node_ids}}, {"source_ids": 0})
        async for doc in cursor:
            result.nodes.append(self._construct_graph_node(doc["_id"], doc))

        # As the node count reaches the limit, only the edges that connect
        # two of the selected nodes need to be fetched.
        edge_cursor = self.edge_collection.find(
            {
                "$and": [
                    {"source_node_id": {"$in": node_ids}},
                    {"target_node_id": {"$in": node_ids}},
                ]
            }
        )
    else:
        # All nodes and edges are needed.
        cursor = self.collection.find({}, {"source_ids": 0})
        async for doc in cursor:
            result.nodes.append(self._construct_graph_node(doc["_id"], doc))

        edge_cursor = self.edge_collection.find({})

    async for edge in edge_cursor:
        edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
        if edge_id not in seen_edges:
            seen_edges.add(edge_id)
            result.edges.append(self._construct_graph_edge(edge_id, edge))

    return result
async def _bidirectional_bfs_nodes(
    self,
    node_labels: list[str],
    seen_nodes: set[str],
    result: KnowledgeGraph,
    depth: int = 0,
    max_depth: int = 3,
    max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
    """Recursively collect nodes level by level, ignoring edge direction.

    Loads the node documents for ``node_labels``, appends the unseen ones to
    ``result.nodes``, then recurses on the not-yet-seen endpoints of every
    edge that touches ``node_labels`` as source or target.

    Args:
        node_labels: Node ids that form the current BFS level.
        seen_nodes: Ids already added to ``result``; mutated in place.
        result: Graph being built; mutated in place and also returned.
        depth: Depth of the current level (0 for the starting level).
        max_depth: Deepest level that is still expanded.
        max_nodes: Upper bound on ``len(result.nodes)``.

    Returns:
        ``result``. Its ``is_truncated`` flag is set to True when an unseen
        node was found but could not be added because ``max_nodes`` was
        already reached.
    """
    if depth > max_depth:
        return result

    cursor = self.collection.find({"_id": {"$in": node_labels}})
    async for node in cursor:
        node_id = node["_id"]
        if node_id in seen_nodes:
            continue
        if len(result.nodes) >= max_nodes:
            # Check the cap BEFORE appending so the result never exceeds
            # max_nodes; an unseen node that does not fit means truncation.
            result.is_truncated = True
            return result
        seen_nodes.add(node_id)
        result.nodes.append(self._construct_graph_node(node_id, node))

    # Collect neighbors: both inbound and outbound one-hop nodes.
    cursor = self.edge_collection.find(
        {
            "$or": [
                {"source_node_id": {"$in": node_labels}},
                {"target_node_id": {"$in": node_labels}},
            ]
        }
    )

    # A dict is used as an insertion-ordered set so that each neighbor id is
    # queried only once in the next level.
    neighbor_nodes: dict[str, None] = {}
    async for edge in cursor:
        for endpoint in (edge["source_node_id"], edge["target_node_id"]):
            if endpoint not in seen_nodes:
                neighbor_nodes[endpoint] = None

    if neighbor_nodes:
        result = await self._bidirectional_bfs_nodes(
            list(neighbor_nodes), seen_nodes, result, depth + 1, max_depth, max_nodes
        )

    return result
async def get_knowledge_subgraph_bidirectional_bfs(
    self,
    node_label: str,
    depth=0,
    max_depth: int = 3,
    max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
    """Build the subgraph around ``node_label`` ignoring edge direction.

    Nodes are gathered by ``_bidirectional_bfs_nodes`` starting from
    ``node_label``; afterwards every edge whose source and target were both
    gathered is attached to the graph.

    Args:
        node_label: Id of the node to start from.
        depth: Depth value handed to the node traversal for its first level.
        max_depth: Maximum depth handed to the node traversal.
        max_nodes: Node limit handed to the node traversal.

    Returns:
        The ``KnowledgeGraph`` holding the gathered nodes and their
        connecting edges.
    """
    visited_ids = set()
    graph = KnowledgeGraph()
    graph = await self._bidirectional_bfs_nodes(
        [node_label], visited_ids, graph, depth, max_depth, max_nodes
    )

    # Only edges with both endpoints among the visited nodes are wanted.
    member_ids = list(visited_ids)
    edge_filter = {
        "$and": [
            {"source_node_id": {"$in": member_ids}},
            {"target_node_id": {"$in": member_ids}},
        ]
    }

    emitted_edge_ids = set()
    async for edge_doc in self.edge_collection.find(edge_filter):
        edge_key = f"{edge_doc['source_node_id']}-{edge_doc['target_node_id']}"
        if edge_key in emitted_edge_ids:
            continue
        graph.edges.append(self._construct_graph_edge(edge_key, edge_doc))
        emitted_edge_ids.add(edge_key)

    return graph
async def get_knowledge_subgraph_in_out_bound_bfs(
    self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
) -> KnowledgeGraph:
    """Build the subgraph reachable from ``node_label`` along edge direction.

    Two ``$graphLookup`` traversals over the edge collection are run from the
    starting node: one chaining ``target_node_id`` to ``source_node_id``
    (outbound) and one chaining ``source_node_id`` to ``target_node_id``
    (inbound). The edges found are ordered by depth ascending and weight
    descending, and their endpoints are admitted in that order until the
    graph holds ``max_nodes`` nodes, the starting node included.

    Args:
        node_label: Id of the starting node.
        max_depth: Maximum number of hops from the starting node; 0 returns
            only the starting node.
        max_nodes: Maximum number of nodes to return.

    Returns:
        A ``KnowledgeGraph``. It is empty when the starting node does not
        exist. Its ``is_truncated`` flag is set to True when at least one
        reachable node was left out because of ``max_nodes``.
    """
    seen_nodes = set()
    seen_edges = set()
    result = KnowledgeGraph()
    project_doc = {
        "source_ids": 0,
        "created_at": 0,
        "entity_type": 0,
        "file_path": 0,
    }

    # Verify if starting node exists
    start_node = await self.collection.find_one({"_id": node_label})
    if not start_node:
        logger.warning(f"Starting node with label {node_label} does not exist!")
        return result

    seen_nodes.add(node_label)
    result.nodes.append(self._construct_graph_node(node_label, start_node))

    if max_depth == 0:
        return result

    # In MongoDB, depth = 0 means one-hop
    max_depth = max_depth - 1

    pipeline = [
        {"$match": {"_id": node_label}},
        {"$project": project_doc},
        {
            "$graphLookup": {
                "from": self._edge_collection_name,
                "startWith": "$_id",
                "connectFromField": "target_node_id",
                "connectToField": "source_node_id",
                "maxDepth": max_depth,
                "depthField": "depth",
                "as": "connected_edges",
            },
        },
        {
            "$unionWith": {
                "coll": self._collection_name,
                "pipeline": [
                    {"$match": {"_id": node_label}},
                    {"$project": project_doc},
                    {
                        "$graphLookup": {
                            "from": self._edge_collection_name,
                            "startWith": "$_id",
                            "connectFromField": "source_node_id",
                            "connectToField": "target_node_id",
                            "maxDepth": max_depth,
                            "depthField": "depth",
                            "as": "connected_edges",
                        }
                    },
                ],
            }
        },
    ]

    cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
    node_edges = []

    # Two records for node_label are returned capturing outbound and inbound connected_edges
    async for doc in cursor:
        node_edges.extend(doc.get("connected_edges") or [])

    # Sort the connected edges by depth ascending and weight descending.
    # An edge stored without a weight is treated as weight 0 instead of
    # aborting the whole query with a KeyError.
    node_edges = sorted(
        node_edges,
        key=lambda x: (x["depth"], -x.get("weight", 0)),
    )

    # As order matters, a separate list keeps the admitted node ids in
    # sequence. seen_nodes already contains the starting node, so capping on
    # its size keeps the total node count within max_nodes.
    node_ids = []
    for edge in node_edges:
        for endpoint in (edge["source_node_id"], edge["target_node_id"]):
            if endpoint in seen_nodes:
                continue
            if len(seen_nodes) >= max_nodes:
                result.is_truncated = True
                continue
            node_ids.append(endpoint)
            seen_nodes.add(endpoint)
        if result.is_truncated:
            # The limit is reached, so no later edge can admit another node.
            break

    # node_ids never contains node_label (it was marked as seen above), so
    # the starting node is not fetched or appended a second time.
    cursor = self.collection.find({"_id": {"$in": node_ids}})

    async for doc in cursor:
        result.nodes.append(self._construct_graph_node(str(doc["_id"]), doc))

    for edge in node_edges:
        # Keep only edges whose two endpoints were both admitted.
        if (
            edge["source_node_id"] not in seen_nodes
            or edge["target_node_id"] not in seen_nodes
        ):
            continue

        edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
        if edge_id not in seen_edges:
            result.edges.append(self._construct_graph_edge(edge_id, edge))
            seen_edges.add(edge_id)

    return result
async def get_knowledge_graph(
self,
node_label: str,
max_depth: int = 5,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Args:
node_label: Label of the nodes to start from
max_depth: Maximum depth of traversal (default: 5)
node_label: Label of the starting node, * means all nodes
max_depth: Maximum depth of the subgraph, Defaults to 3
max_nodes: Maximum nodes to return, Defaults to 1000
Returns:
KnowledgeGraph object containing nodes and edges of the subgraph
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
If a graph is like this and starting from B:
A -> B <- C <- F, B -> E, C -> D
Outbound BFS:
B -> E
Inbound BFS:
A -> B
C -> B
F -> C
Bidirectional BFS:
A -> B
B -> E
F -> C
C -> B
C -> D
"""
label = node_label
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
node_edges = []
start = time.perf_counter()
try:
# Optimize pipeline to avoid memory issues with large datasets
if label == "*":
# For getting all nodes, use a simpler pipeline to avoid memory issues
pipeline = [
{"$limit": max_nodes}, # Limit early to reduce memory usage
{
"$graphLookup": {
"from": self._edge_collection_name,
"startWith": "$_id",
"connectFromField": "target_node_id",
"connectToField": "source_node_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_edges",
},
},
]
# Check if we need to set truncation flag
all_node_count = await self.collection.count_documents({})
result.is_truncated = all_node_count > max_nodes
else:
# Verify if starting node exists
start_node = await self.collection.find_one({"_id": label})
if not start_node:
logger.warning(f"Starting node with label {label} does not exist!")
return result
# For specific node queries, use the original pipeline but optimized
pipeline = [
{"$match": {"_id": label}},
{
"$graphLookup": {
"from": self._edge_collection_name,
"startWith": "$_id",
"connectFromField": "target_node_id",
"connectToField": "source_node_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_edges",
},
},
{"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
{"$sort": {"edge_count": -1}},
{"$limit": max_nodes},
]
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
nodes_processed = 0
async for doc in cursor:
# Add the start node
node_id = str(doc["_id"])
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_id],
properties={
k: v
for k, v in doc.items()
if k
not in [
"_id",
"connected_edges",
"edge_count",
]
},
)
if node_label == "*":
result = await self.get_knowledge_graph_all_by_degree(
max_depth, max_nodes
)
elif GRAPH_BFS_MODE == "in_out_bound":
result = await self.get_knowledge_subgraph_in_out_bound_bfs(
node_label, max_depth, max_nodes
)
else:
result = await self.get_knowledge_subgraph_bidirectional_bfs(
node_label, 0, max_depth, max_nodes
)
seen_nodes.add(node_id)
if doc.get("connected_edges", []):
node_edges.extend(doc.get("connected_edges"))
nodes_processed += 1
# Additional safety check to prevent memory issues
if nodes_processed >= max_nodes:
result.is_truncated = True
break
for edge in node_edges:
if (
edge["source_node_id"] not in seen_nodes
or edge["target_node_id"] not in seen_nodes
):
continue
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relationship", ""),
source=edge["source_node_id"],
target=edge["target_node_id"],
properties={
k: v
for k, v in edge.items()
if k
not in [
"_id",
"source_node_id",
"target_node_id",
"relationship",
]
},
)
)
seen_edges.add(edge_id)
duration = time.perf_counter() - start
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
f"Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
)
except PyMongoError as e:
@ -856,13 +1110,8 @@ class MongoGraphStorage(BaseGraphStorage):
try:
simple_cursor = self.collection.find({}).limit(max_nodes)
async for doc in simple_cursor:
node_id = str(doc["_id"])
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[node_id],
properties={k: v for k, v in doc.items() if k != "_id"},
)
self._construct_graph_node(str(doc["_id"]), doc)
)
result.is_truncated = True
logger.info(
@ -1028,8 +1277,6 @@ class MongoVectorDBStorage(BaseVectorStorage):
return
# Add current time as Unix timestamp
import time
current_time = int(time.time())
list_data = [