Merge branch 'kenspirit/main'
This commit is contained in:
commit
6d5e73a251
2 changed files with 379 additions and 130 deletions
|
|
@ -297,6 +297,8 @@ class BaseKVStorage(StorageNameSpace, ABC):
|
|||
|
||||
@dataclass
|
||||
class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
"""All operations related to edges in graph should be undirected."""
|
||||
|
||||
embedding_func: EmbeddingFunc
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
import numpy as np
|
||||
import configparser
|
||||
|
|
@ -35,6 +36,7 @@ config.read("config.ini", "utf-8")
|
|||
|
||||
# Get maximum number of graph nodes from environment variable, default is 1000
|
||||
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
|
||||
GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")
|
||||
|
||||
|
||||
class ClientManager:
|
||||
|
|
@ -417,11 +419,21 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
"""
|
||||
Check if there's a direct single-hop edge from source_node_id to target_node_id.
|
||||
Check if there's a direct single-hop edge between source_node_id and target_node_id.
|
||||
"""
|
||||
# Edges are undirected: match the pair in either direction.
|
||||
doc = await self.edge_collection.find_one(
|
||||
{"source_node_id": source_node_id, "target_node_id": target_node_id},
|
||||
{
|
||||
"$or": [
|
||||
{
|
||||
"source_node_id": source_node_id,
|
||||
"target_node_id": target_node_id,
|
||||
},
|
||||
{
|
||||
"source_node_id": target_node_id,
|
||||
"target_node_id": source_node_id,
|
||||
},
|
||||
]
|
||||
},
|
||||
{"_id": 1},
|
||||
)
|
||||
return doc is not None
|
||||
|
|
@ -651,7 +663,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
) -> None:
|
||||
"""
|
||||
Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
|
||||
Upsert an edge between source_node_id and target_node_id with optional 'relation'.
|
||||
If an edge with the same target exists, we remove it and re-insert with updated data.
|
||||
"""
|
||||
# Ensure source node exists
|
||||
|
|
@ -663,8 +675,22 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
GRAPH_FIELD_SEP
|
||||
)
|
||||
|
||||
edge_data["source_node_id"] = source_node_id
|
||||
edge_data["target_node_id"] = target_node_id
|
||||
|
||||
await self.edge_collection.update_one(
|
||||
{"source_node_id": source_node_id, "target_node_id": target_node_id},
|
||||
{
|
||||
"$or": [
|
||||
{
|
||||
"source_node_id": source_node_id,
|
||||
"target_node_id": target_node_id,
|
||||
},
|
||||
{
|
||||
"source_node_id": target_node_id,
|
||||
"target_node_id": source_node_id,
|
||||
},
|
||||
]
|
||||
},
|
||||
update_doc,
|
||||
upsert=True,
|
||||
)
|
||||
|
|
@ -678,7 +704,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
async def delete_node(self, node_id: str) -> None:
|
||||
"""
|
||||
1) Remove node's doc entirely.
|
||||
2) Remove inbound edges from any doc that references node_id.
|
||||
2) Remove inbound & outbound edges from any doc that references node_id.
|
||||
"""
|
||||
# Remove all edges
|
||||
await self.edge_collection.delete_many(
|
||||
|
|
@ -709,141 +735,369 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
labels.append(doc["_id"])
|
||||
return labels
|
||||
|
||||
def _construct_graph_node(
    self, node_id, node_data: dict[str, str]
) -> KnowledgeGraphNode:
    """
    Build a KnowledgeGraphNode from a node document.

    The node id is also used as the node's single label. Storage-internal
    fields are left out of the node's properties.
    """
    # Fields that must not be exposed as node properties.
    hidden_fields = ("_id", "connected_edges", "source_ids", "edge_count")

    properties = {}
    for field_name, field_value in node_data.items():
        if field_name in hidden_fields:
            continue
        properties[field_name] = field_value

    return KnowledgeGraphNode(id=node_id, labels=[node_id], properties=properties)
|
||||
|
||||
def _construct_graph_edge(self, edge_id: str, edge: dict[str, str]):
    """
    Build a KnowledgeGraphEdge from an edge document.

    The edge type comes from the optional "relationship" field (empty string
    when absent). Endpoint ids and storage-internal fields are left out of
    the edge's properties.
    """
    # Fields that are either mapped to dedicated attributes or kept internal.
    hidden_fields = (
        "_id",
        "source_node_id",
        "target_node_id",
        "relationship",
        "source_ids",
    )

    properties = {}
    for field_name, field_value in edge.items():
        if field_name in hidden_fields:
            continue
        properties[field_name] = field_value

    return KnowledgeGraphEdge(
        id=edge_id,
        type=edge.get("relationship", ""),
        source=edge["source_node_id"],
        target=edge["target_node_id"],
        properties=properties,
    )
|
||||
|
||||
async def get_knowledge_graph_all_by_degree(
    self, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
) -> KnowledgeGraph:
    """
    Return the whole graph, or only its highest-degree nodes when the node
    collection holds more than `max_nodes` documents.

    It's possible that a node with one or multiple relationships is retrieved
    while its neighbor is not. Such a node may look disconnected in the UI.

    Args:
        max_depth: Not used by this method.
        max_nodes: Maximum number of nodes to return.

    Returns:
        KnowledgeGraph with nodes and edges; `is_truncated` is True when the
        total node count exceeds `max_nodes`.
    """
    result = KnowledgeGraph()
    seen_edges = set()

    total_node_count = await self.collection.count_documents({})
    result.is_truncated = total_node_count > max_nodes

    if result.is_truncated:
        # The total node count exceeds max_nodes: rank node ids by degree and
        # keep only the top max_nodes. Degree is the number of edge documents
        # in which the id appears as source plus those where it is the target.
        pipeline = [
            {"$project": {"source_node_id": 1, "_id": 0}},
            {"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}},
            {
                "$unionWith": {
                    "coll": self._edge_collection_name,
                    "pipeline": [
                        {"$project": {"target_node_id": 1, "_id": 0}},
                        {
                            "$group": {
                                "_id": "$target_node_id",
                                "degree": {"$sum": 1},
                            }
                        },
                    ],
                }
            },
            # Merge the per-source and per-target counts of the same id.
            {"$group": {"_id": "$_id", "degree": {"$sum": "$degree"}}},
            {"$sort": {"degree": -1}},
            {"$limit": max_nodes},
        ]
        cursor = await self.edge_collection.aggregate(pipeline, allowDiskUse=True)

        node_ids = []
        async for doc in cursor:
            node_ids.append(str(doc["_id"]))

        cursor = self.collection.find({"_id": {"$in": node_ids}}, {"source_ids": 0})
        async for doc in cursor:
            result.nodes.append(self._construct_graph_node(doc["_id"], doc))

        # The node set is capped, so only edges with both endpoints among the
        # selected nodes are fetched.
        edge_cursor = self.edge_collection.find(
            {
                "$and": [
                    {"source_node_id": {"$in": node_ids}},
                    {"target_node_id": {"$in": node_ids}},
                ]
            }
        )
    else:
        # All nodes and edges are needed
        cursor = self.collection.find({}, {"source_ids": 0})

        async for doc in cursor:
            result.nodes.append(self._construct_graph_node(doc["_id"], doc))

        edge_cursor = self.edge_collection.find({})

    async for edge in edge_cursor:
        edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
        if edge_id not in seen_edges:
            seen_edges.add(edge_id)
            result.edges.append(self._construct_graph_edge(edge_id, edge))

    return result
|
||||
|
||||
async def _bidirectional_bfs_nodes(
    self,
    node_labels: list[str],
    seen_nodes: set[str],
    result: KnowledgeGraph,
    depth: int = 0,
    max_depth: int = 3,
    max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
    """
    Recursively collect nodes level by level, following edges in both directions.

    Args:
        node_labels: Ids of the nodes to load at this level.
        seen_nodes: Ids already added to `result`; updated in place.
        result: Graph being built; nodes are appended in place.
        depth: Depth of the current level.
        max_depth: Levels deeper than this are not visited.
        max_nodes: Upper bound on the number of nodes in `result`.

    Returns:
        The same `result` object. `is_truncated` is set to True when an
        unseen node had to be dropped because `max_nodes` was reached.
    """
    if depth > max_depth:
        return result

    cursor = self.collection.find({"_id": {"$in": node_labels}})

    async for node in cursor:
        node_id = node["_id"]
        if node_id in seen_nodes:
            continue
        # Check the cap before appending so the result never exceeds
        # max_nodes, and flag truncation only when a node is really dropped.
        if len(result.nodes) >= max_nodes:
            result.is_truncated = True
            return result
        seen_nodes.add(node_id)
        result.nodes.append(self._construct_graph_node(node_id, node))

    # The next level would be rejected by the depth check above, so skip
    # the edge query entirely.
    if depth >= max_depth:
        return result

    # Collect neighbors
    # Get both inbound and outbound one hop nodes
    cursor = self.edge_collection.find(
        {
            "$or": [
                {"source_node_id": {"$in": node_labels}},
                {"target_node_id": {"$in": node_labels}},
            ]
        }
    )

    neighbor_nodes = []
    queued = set()  # avoids sending duplicate ids to the next level's $in query
    async for edge in cursor:
        for endpoint in (edge["source_node_id"], edge["target_node_id"]):
            if endpoint not in seen_nodes and endpoint not in queued:
                queued.add(endpoint)
                neighbor_nodes.append(endpoint)

    if neighbor_nodes:
        result = await self._bidirectional_bfs_nodes(
            neighbor_nodes, seen_nodes, result, depth + 1, max_depth, max_nodes
        )

    return result
|
||||
|
||||
async def get_knowledge_subgraph_bidirectional_bfs(
    self,
    node_label: str,
    depth=0,
    max_depth: int = 3,
    max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
    """
    Build the subgraph around `node_label` by traversing edges in both directions.

    Nodes are gathered by `_bidirectional_bfs_nodes`; afterwards every edge
    whose two endpoints are both among the gathered nodes is added once.
    """
    visited_ids = set()
    graph = await self._bidirectional_bfs_nodes(
        [node_label], visited_ids, KnowledgeGraph(), depth, max_depth, max_nodes
    )

    # Fetch only the edges that connect two collected nodes.
    member_ids = list(visited_ids)
    edge_filter = {
        "$and": [
            {"source_node_id": {"$in": member_ids}},
            {"target_node_id": {"$in": member_ids}},
        ]
    }

    emitted_edge_ids = set()
    async for edge_doc in self.edge_collection.find(edge_filter):
        edge_key = f"{edge_doc['source_node_id']}-{edge_doc['target_node_id']}"
        if edge_key in emitted_edge_ids:
            continue
        graph.edges.append(self._construct_graph_edge(edge_key, edge_doc))
        emitted_edge_ids.add(edge_key)

    return graph
|
||||
|
||||
async def get_knowledge_subgraph_in_out_bound_bfs(
    self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
) -> KnowledgeGraph:
    """
    Build the subgraph around `node_label` from two separate traversals:
    one following edges outbound and one following edges inbound.

    Args:
        node_label: Id of the starting node.
        max_depth: Maximum traversal depth; 0 returns only the starting node.
        max_nodes: Maximum number of nodes to return, starting node included.

    Returns:
        KnowledgeGraph with the selected nodes and the edges between them.
        It is empty when the starting node does not exist. `is_truncated` is
        set to True when reachable nodes were dropped because of `max_nodes`.
    """
    seen_nodes = set()
    seen_edges = set()
    result = KnowledgeGraph()
    project_doc = {
        "source_ids": 0,
        "created_at": 0,
        "entity_type": 0,
        "file_path": 0,
    }

    # Verify if starting node exists
    start_node = await self.collection.find_one({"_id": node_label})
    if not start_node:
        logger.warning(f"Starting node with label {node_label} does not exist!")
        return result

    seen_nodes.add(node_label)
    result.nodes.append(self._construct_graph_node(node_label, start_node))

    if max_depth == 0:
        return result

    # In MongoDB, depth = 0 means one-hop
    max_depth = max_depth - 1

    pipeline = [
        {"$match": {"_id": node_label}},
        {"$project": project_doc},
        {
            # Outbound traversal: an edge's target becomes the next source.
            "$graphLookup": {
                "from": self._edge_collection_name,
                "startWith": "$_id",
                "connectFromField": "target_node_id",
                "connectToField": "source_node_id",
                "maxDepth": max_depth,
                "depthField": "depth",
                "as": "connected_edges",
            },
        },
        {
            "$unionWith": {
                "coll": self._collection_name,
                "pipeline": [
                    {"$match": {"_id": node_label}},
                    {"$project": project_doc},
                    {
                        # Inbound traversal: an edge's source becomes the next target.
                        "$graphLookup": {
                            "from": self._edge_collection_name,
                            "startWith": "$_id",
                            "connectFromField": "source_node_id",
                            "connectToField": "target_node_id",
                            "maxDepth": max_depth,
                            "depthField": "depth",
                            "as": "connected_edges",
                        }
                    },
                ],
            }
        },
    ]

    cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
    node_edges = []

    # Two records for node_label are returned capturing outbound and inbound connected_edges
    async for doc in cursor:
        node_edges.extend(doc.get("connected_edges") or [])

    # Sort the connected edges by depth ascending and weight descending.
    # An edge without a weight is ranked as weight 0 instead of raising.
    node_edges = sorted(
        node_edges,
        key=lambda x: (x["depth"], -(x.get("weight") or 0)),
    )

    # Order matters, so neighbor ids are kept in a list. The starting node
    # already occupies one slot of max_nodes.
    remaining_slots = max_nodes - len(result.nodes)
    node_ids = []
    for edge in node_edges:
        if result.is_truncated:
            break
        for endpoint in (edge["source_node_id"], edge["target_node_id"]):
            if endpoint in seen_nodes:
                continue
            if len(node_ids) >= remaining_slots:
                # A reachable node does not fit: report truncation.
                result.is_truncated = True
                continue
            node_ids.append(endpoint)
            seen_nodes.add(endpoint)

    # node_ids never contains node_label (it was marked as seen above), so
    # the starting node is not appended a second time.
    cursor = self.collection.find({"_id": {"$in": node_ids}})

    async for doc in cursor:
        result.nodes.append(self._construct_graph_node(str(doc["_id"]), doc))

    for edge in node_edges:
        # Keep only edges whose two endpoints were both selected.
        if (
            edge["source_node_id"] not in seen_nodes
            or edge["target_node_id"] not in seen_nodes
        ):
            continue

        edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
        if edge_id not in seen_edges:
            result.edges.append(self._construct_graph_edge(edge_id, edge))
            seen_edges.add(edge_id)

    return result
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self,
|
||||
node_label: str,
|
||||
max_depth: int = 5,
|
||||
max_depth: int = 3,
|
||||
max_nodes: int = MAX_GRAPH_NODES,
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Get complete connected subgraph for specified node (including the starting node itself)
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||
|
||||
Args:
|
||||
node_label: Label of the nodes to start from
|
||||
max_depth: Maximum depth of traversal (default: 5)
|
||||
node_label: Label of the starting node, * means all nodes
|
||||
max_depth: Maximum depth of the subgraph, Defaults to 3
|
||||
max_nodes: Maximum number of nodes to return, Defaults to 1000
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph object containing nodes and edges of the subgraph
|
||||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||
indicating whether the graph was truncated due to max_nodes limit
|
||||
|
||||
If a graph is like this and starting from B:
|
||||
A → B ← C ← F, B -> E, C → D
|
||||
|
||||
Outbound BFS:
|
||||
B → E
|
||||
|
||||
Inbound BFS:
|
||||
A → B
|
||||
C → B
|
||||
F → C
|
||||
|
||||
Bidirectional BFS:
|
||||
A → B
|
||||
B → E
|
||||
F → C
|
||||
C → B
|
||||
C → D
|
||||
"""
|
||||
label = node_label
|
||||
result = KnowledgeGraph()
|
||||
seen_nodes = set()
|
||||
seen_edges = set()
|
||||
node_edges = []
|
||||
start = time.perf_counter()
|
||||
|
||||
try:
|
||||
# Optimize pipeline to avoid memory issues with large datasets
|
||||
if label == "*":
|
||||
# For getting all nodes, use a simpler pipeline to avoid memory issues
|
||||
pipeline = [
|
||||
{"$limit": max_nodes}, # Limit early to reduce memory usage
|
||||
{
|
||||
"$graphLookup": {
|
||||
"from": self._edge_collection_name,
|
||||
"startWith": "$_id",
|
||||
"connectFromField": "target_node_id",
|
||||
"connectToField": "source_node_id",
|
||||
"maxDepth": max_depth,
|
||||
"depthField": "depth",
|
||||
"as": "connected_edges",
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# Check if we need to set truncation flag
|
||||
all_node_count = await self.collection.count_documents({})
|
||||
result.is_truncated = all_node_count > max_nodes
|
||||
else:
|
||||
# Verify if starting node exists
|
||||
start_node = await self.collection.find_one({"_id": label})
|
||||
if not start_node:
|
||||
logger.warning(f"Starting node with label {label} does not exist!")
|
||||
return result
|
||||
|
||||
# For specific node queries, use the original pipeline but optimized
|
||||
pipeline = [
|
||||
{"$match": {"_id": label}},
|
||||
{
|
||||
"$graphLookup": {
|
||||
"from": self._edge_collection_name,
|
||||
"startWith": "$_id",
|
||||
"connectFromField": "target_node_id",
|
||||
"connectToField": "source_node_id",
|
||||
"maxDepth": max_depth,
|
||||
"depthField": "depth",
|
||||
"as": "connected_edges",
|
||||
},
|
||||
},
|
||||
{"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
|
||||
{"$sort": {"edge_count": -1}},
|
||||
{"$limit": max_nodes},
|
||||
]
|
||||
|
||||
cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
|
||||
nodes_processed = 0
|
||||
|
||||
async for doc in cursor:
|
||||
# Add the start node
|
||||
node_id = str(doc["_id"])
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=node_id,
|
||||
labels=[node_id],
|
||||
properties={
|
||||
k: v
|
||||
for k, v in doc.items()
|
||||
if k
|
||||
not in [
|
||||
"_id",
|
||||
"connected_edges",
|
||||
"edge_count",
|
||||
]
|
||||
},
|
||||
)
|
||||
if node_label == "*":
|
||||
result = await self.get_knowledge_graph_all_by_degree(
|
||||
max_depth, max_nodes
|
||||
)
|
||||
elif GRAPH_BFS_MODE == "in_out_bound":
|
||||
result = await self.get_knowledge_subgraph_in_out_bound_bfs(
|
||||
node_label, max_depth, max_nodes
|
||||
)
|
||||
else:
|
||||
result = await self.get_knowledge_subgraph_bidirectional_bfs(
|
||||
node_label, 0, max_depth, max_nodes
|
||||
)
|
||||
seen_nodes.add(node_id)
|
||||
if doc.get("connected_edges", []):
|
||||
node_edges.extend(doc.get("connected_edges"))
|
||||
|
||||
nodes_processed += 1
|
||||
|
||||
# Additional safety check to prevent memory issues
|
||||
if nodes_processed >= max_nodes:
|
||||
result.is_truncated = True
|
||||
break
|
||||
|
||||
for edge in node_edges:
|
||||
if (
|
||||
edge["source_node_id"] not in seen_nodes
|
||||
or edge["target_node_id"] not in seen_nodes
|
||||
):
|
||||
continue
|
||||
|
||||
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
|
||||
if edge_id not in seen_edges:
|
||||
result.edges.append(
|
||||
KnowledgeGraphEdge(
|
||||
id=edge_id,
|
||||
type=edge.get("relationship", ""),
|
||||
source=edge["source_node_id"],
|
||||
target=edge["target_node_id"],
|
||||
properties={
|
||||
k: v
|
||||
for k, v in edge.items()
|
||||
if k
|
||||
not in [
|
||||
"_id",
|
||||
"source_node_id",
|
||||
"target_node_id",
|
||||
"relationship",
|
||||
]
|
||||
},
|
||||
)
|
||||
)
|
||||
seen_edges.add(edge_id)
|
||||
duration = time.perf_counter() - start
|
||||
|
||||
logger.info(
|
||||
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
|
||||
f"Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
|
||||
)
|
||||
|
||||
except PyMongoError as e:
|
||||
|
|
@ -856,13 +1110,8 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||
try:
|
||||
simple_cursor = self.collection.find({}).limit(max_nodes)
|
||||
async for doc in simple_cursor:
|
||||
node_id = str(doc["_id"])
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=node_id,
|
||||
labels=[node_id],
|
||||
properties={k: v for k, v in doc.items() if k != "_id"},
|
||||
)
|
||||
self._construct_graph_node(str(doc["_id"]), doc)
|
||||
)
|
||||
result.is_truncated = True
|
||||
logger.info(
|
||||
|
|
@ -1028,8 +1277,6 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
|||
return
|
||||
|
||||
# Add current time as Unix timestamp
|
||||
import time
|
||||
|
||||
current_time = int(time.time())
|
||||
|
||||
list_data = [
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue