Merge branch 'kenspirit/main'

This commit is contained in:
yangdx 2025-06-26 13:52:42 +08:00
commit 778ad4f23a
3 changed files with 460 additions and 404 deletions

View file

@ -14,8 +14,8 @@ STORAGE_IMPLEMENTATIONS = {
"NetworkXStorage", "NetworkXStorage",
"Neo4JStorage", "Neo4JStorage",
"PGGraphStorage", "PGGraphStorage",
"MongoGraphStorage",
# "AGEStorage", # "AGEStorage",
# "MongoGraphStorage",
# "TiDBGraphStorage", # "TiDBGraphStorage",
# "GremlinStorage", # "GremlinStorage",
], ],

View file

@ -4,7 +4,7 @@ import numpy as np
import configparser import configparser
import asyncio import asyncio
from typing import Any, List, Union, final from typing import Any, Union, final
from ..base import ( from ..base import (
BaseGraphStorage, BaseGraphStorage,
@ -17,32 +17,32 @@ from ..base import (
from ..namespace import NameSpace, is_namespace from ..namespace import NameSpace, is_namespace
from ..utils import logger, compute_mdhash_id from ..utils import logger, compute_mdhash_id
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
import pipmaster as pm import pipmaster as pm
if not pm.is_installed("pymongo"): if not pm.is_installed("pymongo"):
pm.install("pymongo") pm.install("pymongo")
if not pm.is_installed("motor"): from pymongo import AsyncMongoClient # type: ignore
pm.install("motor") from pymongo.asynchronous.database import AsyncDatabase # type: ignore
from pymongo.asynchronous.collection import AsyncCollection # type: ignore
from motor.motor_asyncio import ( # type: ignore
AsyncIOMotorClient,
AsyncIOMotorDatabase,
AsyncIOMotorCollection,
)
from pymongo.operations import SearchIndexModel # type: ignore from pymongo.operations import SearchIndexModel # type: ignore
from pymongo.errors import PyMongoError # type: ignore from pymongo.errors import PyMongoError # type: ignore
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
class ClientManager: class ClientManager:
_instances = {"db": None, "ref_count": 0} _instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock() _lock = asyncio.Lock()
@classmethod @classmethod
async def get_client(cls) -> AsyncIOMotorDatabase: async def get_client(cls) -> AsyncMongoClient:
async with cls._lock: async with cls._lock:
if cls._instances["db"] is None: if cls._instances["db"] is None:
uri = os.environ.get( uri = os.environ.get(
@ -57,7 +57,7 @@ class ClientManager:
"MONGO_DATABASE", "MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"), config.get("mongodb", "database", fallback="LightRAG"),
) )
client = AsyncIOMotorClient(uri) client = AsyncMongoClient(uri)
db = client.get_database(database_name) db = client.get_database(database_name)
cls._instances["db"] = db cls._instances["db"] = db
cls._instances["ref_count"] = 0 cls._instances["ref_count"] = 0
@ -65,7 +65,7 @@ class ClientManager:
return cls._instances["db"] return cls._instances["db"]
@classmethod @classmethod
async def release_client(cls, db: AsyncIOMotorDatabase): async def release_client(cls, db: AsyncDatabase):
async with cls._lock: async with cls._lock:
if db is not None: if db is not None:
if db is cls._instances["db"]: if db is cls._instances["db"]:
@ -77,8 +77,8 @@ class ClientManager:
@final @final
@dataclass @dataclass
class MongoKVStorage(BaseKVStorage): class MongoKVStorage(BaseKVStorage):
db: AsyncIOMotorDatabase = field(default=None) db: AsyncDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(default=None) _data: AsyncCollection = field(default=None)
def __post_init__(self): def __post_init__(self):
self._collection_name = self.namespace self._collection_name = self.namespace
@ -107,6 +107,19 @@ class MongoKVStorage(BaseKVStorage):
existing_ids = {str(x["_id"]) async for x in cursor} existing_ids = {str(x["_id"]) async for x in cursor}
return keys - existing_ids return keys - existing_ids
async def get_all(self) -> dict[str, Any]:
    """Return every document in this storage's collection.

    Returns:
        A dict keyed by each document's ``_id``; the value is the same
        document with its ``_id`` field removed.
    """
    everything = {}
    async for record in self._data.find({}):
        # pop() strips "_id" from the document and hands it back as the key
        everything[record.pop("_id")] = record
    return everything
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}") logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
@ -159,6 +172,10 @@ class MongoKVStorage(BaseKVStorage):
if not ids: if not ids:
return return
# Convert to list if it's a set (MongoDB BSON cannot encode sets)
if isinstance(ids, set):
ids = list(ids)
try: try:
result = await self._data.delete_many({"_id": {"$in": ids}}) result = await self._data.delete_many({"_id": {"$in": ids}})
logger.info( logger.info(
@ -214,8 +231,8 @@ class MongoKVStorage(BaseKVStorage):
@final @final
@dataclass @dataclass
class MongoDocStatusStorage(DocStatusStorage): class MongoDocStatusStorage(DocStatusStorage):
db: AsyncIOMotorDatabase = field(default=None) db: AsyncDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(default=None) _data: AsyncCollection = field(default=None)
def __post_init__(self): def __post_init__(self):
self._collection_name = self.namespace self._collection_name = self.namespace
@ -259,7 +276,7 @@ class MongoDocStatusStorage(DocStatusStorage):
async def get_status_counts(self) -> dict[str, int]: async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status""" """Get counts of documents in each status"""
pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}] pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
cursor = self._data.aggregate(pipeline) cursor = self._data.aggregate(pipeline, allowDiskUse=True)
result = await cursor.to_list() result = await cursor.to_list()
counts = {} counts = {}
for doc in result: for doc in result:
@ -311,6 +328,9 @@ class MongoDocStatusStorage(DocStatusStorage):
logger.error(f"Error dropping doc status {self._collection_name}: {e}") logger.error(f"Error dropping doc status {self._collection_name}: {e}")
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def delete(self, ids: list[str]) -> None:
    """Delete the doc-status records whose ``_id`` is listed in ``ids``.

    Args:
        ids: Identifiers of the records to remove. A set is accepted and
            converted to a list before the query is built.

    Returns:
        None. Nothing is sent to MongoDB when ``ids`` is empty.
    """
    if not ids:
        return

    # MongoDB's BSON encoder cannot serialise a set, so normalise to a list
    if isinstance(ids, set):
        ids = list(ids)

    await self._data.delete_many({"_id": {"$in": ids}})
@final @final
@dataclass @dataclass
@ -319,8 +339,11 @@ class MongoGraphStorage(BaseGraphStorage):
A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries. A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
""" """
db: AsyncIOMotorDatabase = field(default=None) db: AsyncDatabase = field(default=None)
collection: AsyncIOMotorCollection = field(default=None) # node collection storing node_id, node_properties
collection: AsyncCollection = field(default=None)
# edge collection storing source_node_id, target_node_id, and edge_properties.
# Declared as `edge_collection` because that is the attribute initialize() and
# finalize() assign and every query method reads; without this declaration the
# attribute does not exist until initialize() has run.
edge_collection: AsyncCollection = field(default=None)
# Original camelCase declaration, kept so any existing reference still resolves.
edgeCollection: AsyncCollection = field(default=None)
def __init__(self, namespace, global_config, embedding_func): def __init__(self, namespace, global_config, embedding_func):
super().__init__( super().__init__(
@ -329,6 +352,7 @@ class MongoGraphStorage(BaseGraphStorage):
embedding_func=embedding_func, embedding_func=embedding_func,
) )
self._collection_name = self.namespace self._collection_name = self.namespace
self._edge_collection_name = f"{self._collection_name}_edges"
async def initialize(self): async def initialize(self):
if self.db is None: if self.db is None:
@ -336,6 +360,9 @@ class MongoGraphStorage(BaseGraphStorage):
self.collection = await get_or_create_collection( self.collection = await get_or_create_collection(
self.db, self._collection_name self.db, self._collection_name
) )
self.edge_collection = await get_or_create_collection(
self.db, self._edge_collection_name
)
logger.debug(f"Use MongoDB as KG {self._collection_name}") logger.debug(f"Use MongoDB as KG {self._collection_name}")
async def finalize(self): async def finalize(self):
@ -343,58 +370,36 @@ class MongoGraphStorage(BaseGraphStorage):
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
self.collection = None self.collection = None
self.edge_collection = None
# # Sample entity document
# ------------------------------------------------------------------------- # "source_ids" is the array form of "source_id", obtained by splitting it on GRAPH_FIELD_SEP
# HELPER: $graphLookup pipeline
# -------------------------------------------------------------------------
#
async def _graph_lookup( # {
self, start_node_id: str, max_depth: int = None # "_id" : "CompanyA",
) -> List[dict]: # "entity_id" : "CompanyA",
""" # "entity_type" : "Organization",
Performs a $graphLookup starting from 'start_node_id' and returns # "description" : "A major technology company",
all reachable documents (including the start node itself). # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
# "source_ids": ["chunk-eeec0036b909839e8ec4fa150c939eec"],
# "file_path" : "custom_kg",
# "created_at" : 1749904575
# }
Pipeline Explanation: # Sample relation document
- 1) $match: We match the start node document by _id = start_node_id. # {
- 2) $graphLookup: # "_id" : ObjectId("6856ac6e7c6bad9b5470b678"), // MongoDB build-in ObjectId
"from": same collection, # "description" : "CompanyA develops ProductX",
"startWith": "$edges.target" (the immediate neighbors in 'edges'), # "source_node_id" : "CompanyA",
"connectFromField": "edges.target", # "target_node_id" : "ProductX",
"connectToField": "_id", # "relationship": "Develops", // To distinguish multiple same-target relations
"as": "reachableNodes", # "weight" : Double("1"),
"maxDepth": max_depth (if provided), # "keywords" : "develop, produce",
"depthField": "depth" (used for debugging or filtering). # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
- 3) We add an $project or $unwind as needed to extract data. # "source_ids": ["chunk-eeec0036b909839e8ec4fa150c939eec"],
""" # "file_path" : "custom_kg",
pipeline = [ # "created_at" : 1749904575
{"$match": {"_id": start_node_id}}, # }
{
"$graphLookup": {
"from": self.collection.name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"as": "reachableNodes",
"depthField": "depth",
}
},
]
# If you want a limited depth (e.g., only 1 or 2 hops), set maxDepth
if max_depth is not None:
pipeline[1]["$graphLookup"]["maxDepth"] = max_depth
# Return the matching doc plus a field "reachableNodes"
cursor = self.collection.aggregate(pipeline)
results = await cursor.to_list(None)
# If there's no matching node, results = [].
# Otherwise, results[0] is the start node doc,
# plus results[0]["reachableNodes"] is the array of connected docs.
return results
# #
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -413,44 +418,13 @@ class MongoGraphStorage(BaseGraphStorage):
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
""" """
Check if there's a direct single-hop edge from source_node_id to target_node_id. Check if there's a direct single-hop edge from source_node_id to target_node_id.
We'll do a $graphLookup with maxDepth=0 from the source node—meaning
"Look up zero expansions." Actually, for a direct edge check, we can do maxDepth=1
and then see if the target node is in the "reachableNodes" at depth=0.
But typically for a direct edge, we might just do a find_one.
Below is a demonstration approach.
""" """
# We can do a single-hop graphLookup (maxDepth=0 or 1). # Direct check if the target_node appears among the edges array.
# Then check if the target_node appears among the edges array. doc = await self.edge_collection.find_one(
pipeline = [ {"source_node_id": source_node_id, "target_node_id": target_node_id},
{"$match": {"_id": source_node_id}}, {"_id": 1},
{ )
"$graphLookup": { return doc is not None
"from": self.collection.name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"as": "reachableNodes",
"depthField": "depth",
"maxDepth": 0, # means: do not follow beyond immediate edges
}
},
{
"$project": {
"_id": 0,
"reachableNodes._id": 1, # only keep the _id from the subdocs
}
},
]
cursor = self.collection.aggregate(pipeline)
results = await cursor.to_list(None)
if not results:
return False
# results[0]["reachableNodes"] are the immediate neighbors
reachable_ids = [d["_id"] for d in results[0].get("reachableNodes", [])]
return target_node_id in reachable_ids
# #
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -461,85 +435,25 @@ class MongoGraphStorage(BaseGraphStorage):
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
""" """
Returns the total number of edges connected to node_id (both inbound and outbound). Returns the total number of edges connected to node_id (both inbound and outbound).
The easiest approach is typically two queries:
- count of edges array in node_id's doc
- count of how many other docs have node_id in their edges.target.
But we'll do a $graphLookup demonstration for inbound edges:
1) Outbound edges: direct from node's edges array
2) Inbound edges: we can do a special $graphLookup from all docs
or do an explicit match.
For demonstration, let's do this in two steps (with second step $graphLookup).
""" """
# --- 1) Outbound edges (direct from doc) --- return await self.edge_collection.count_documents(
doc = await self.collection.find_one({"_id": node_id}, {"edges": 1}) {"$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]}
if not doc: )
return 0
outbound_count = len(doc.get("edges", []))
# --- 2) Inbound edges:
# A simple way is: find all docs where "edges.target" == node_id.
# But let's do a $graphLookup from `node_id` in REVERSE.
# There's a trick to do "reverse" graphLookups: you'd store
# reversed edges or do a more advanced pipeline. Typically you'd do
# a direct match. We'll just do a direct match for inbound.
inbound_count_pipeline = [
{"$match": {"edges.target": node_id}},
{
"$project": {
"matchingEdgesCount": {
"$size": {
"$filter": {
"input": "$edges",
"as": "edge",
"cond": {"$eq": ["$$edge.target", node_id]},
}
}
}
}
},
{"$group": {"_id": None, "totalInbound": {"$sum": "$matchingEdgesCount"}}},
]
inbound_cursor = self.collection.aggregate(inbound_count_pipeline)
inbound_result = await inbound_cursor.to_list(None)
inbound_count = inbound_result[0]["totalInbound"] if inbound_result else 0
return outbound_count + inbound_count
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
""" """Get the total degree (sum of relationships) of two nodes.
If your graph can hold multiple edges from the same src to the same tgt
(e.g. different 'relation' values), you can sum them. If it's always
one edge, this is typically 1 or 0.
We'll do a single-hop $graphLookup from src_id, Args:
then count how many edges reference tgt_id at depth=0. src_id: Label of the source node
""" tgt_id: Label of the target node
pipeline = [
{"$match": {"_id": src_id}},
{
"$graphLookup": {
"from": self.collection.name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"as": "neighbors",
"depthField": "depth",
"maxDepth": 0,
}
},
{"$project": {"edges": 1, "neighbors._id": 1, "neighbors.type": 1}},
]
cursor = self.collection.aggregate(pipeline)
results = await cursor.to_list(None)
if not results:
return 0
# We can simply count how many edges in `results[0].edges` have target == tgt_id. Returns:
edges = results[0].get("edges", []) int: Sum of the degrees of both nodes
count = sum(1 for e in edges if e.get("target") == tgt_id) """
return count src_degree = await self.node_degree(src_id)
trg_degree = await self.node_degree(tgt_id)
return src_degree + trg_degree
# #
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -549,65 +463,167 @@ class MongoGraphStorage(BaseGraphStorage):
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
""" """
Return the full node document (including "edges"), or None if missing. Return the full node document, or None if missing.
""" """
return await self.collection.find_one({"_id": node_id}) return await self.collection.find_one({"_id": node_id})
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
pipeline = [ return await self.edge_collection.find_one(
{"$match": {"_id": source_node_id}},
{ {
"$graphLookup": { "$or": [
"from": self.collection.name, {
"startWith": "$edges.target", "source_node_id": source_node_id,
"connectFromField": "edges.target", "target_node_id": target_node_id,
"connectToField": "_id", },
"as": "neighbors", {
"depthField": "depth", "source_node_id": target_node_id,
"maxDepth": 0, "target_node_id": source_node_id,
}
}, },
{"$project": {"edges": 1}},
] ]
cursor = self.collection.aggregate(pipeline) }
docs = await cursor.to_list(None) )
if not docs:
return None
for e in docs[0].get("edges", []):
if e.get("target") == target_node_id:
return e
return None
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
""" """
Return a list of (source_id, target_id) for direct edges from source_node_id. Retrieves all edges (relationships) for a particular node identified by its label.
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
"""
pipeline = [
{"$match": {"_id": source_node_id}},
{
"$graphLookup": {
"from": self.collection.name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"as": "neighbors",
"depthField": "depth",
"maxDepth": 0,
}
},
{"$project": {"_id": 0, "edges": 1}},
]
cursor = self.collection.aggregate(pipeline)
result = await cursor.to_list(None)
if not result:
return None
edges = result[0].get("edges", []) Args:
return [(source_node_id, e["target"]) for e in edges] source_node_id: Label of the node to get edges for
Returns:
list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
None: If no edges found
"""
cursor = self.edge_collection.find(
{
"$or": [
{"source_node_id": source_node_id},
{"target_node_id": source_node_id},
]
},
{"source_node_id": 1, "target_node_id": 1},
)
return [
(e.get("source_node_id"), e.get("target_node_id")) async for e in cursor
]
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
    """Fetch several node documents in a single query.

    Args:
        node_ids: ``_id`` values of the nodes to load.

    Returns:
        A dict mapping each found node's ``_id`` to its full document.
        IDs with no matching document are absent from the result; an
        empty ``node_ids`` yields an empty dict without querying MongoDB.
    """
    if not node_ids:
        return {}

    # list() because BSON cannot encode a set inside "$in"
    cursor = self.collection.find({"_id": {"$in": list(node_ids)}})
    return {doc["_id"]: doc async for doc in cursor}
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
    """Return the degree of every node in ``node_ids``.

    A node's degree here is the number of edge documents naming it as
    ``source_node_id`` plus the number naming it as ``target_node_id``.

    Args:
        node_ids: IDs of the nodes to count edges for.

    Returns:
        A dict mapping every requested node ID to its degree. Nodes that
        appear in no edge map to 0; an empty ``node_ids`` yields ``{}``
        without querying MongoDB.
    """
    if not node_ids:
        return {}

    node_ids = list(node_ids)
    # Seed with 0 so nodes without any edge are still present in the result
    degrees = dict.fromkeys(node_ids, 0)

    # One grouped count per endpoint field: outbound first, then inbound.
    # NOTE: an edge whose source and target are the same node is counted once
    # per endpoint, i.e. it contributes 2 to that node's degree.
    for endpoint in ("source_node_id", "target_node_id"):
        pipeline = [
            {"$match": {endpoint: {"$in": node_ids}}},
            {"$group": {"_id": f"${endpoint}", "degree": {"$sum": 1}}},
        ]
        cursor = await self.edge_collection.aggregate(pipeline, allowDiskUse=True)
        async for doc in cursor:
            degrees[doc["_id"]] = degrees.get(doc["_id"], 0) + doc["degree"]

    return degrees
async def get_nodes_edges_batch(
    self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
    """Collect the edges touching each of several nodes.

    Both directions are gathered so that the result reflects an undirected
    view of the graph.

    Args:
        node_ids: Node IDs (entity_id) whose edges should be retrieved.

    Returns:
        A dict mapping every requested node ID to a list of
        ``(source, target)`` tuples. For each node the list holds first the
        edges where it is the source, then the edges where it is the target.
    """
    edges_by_node = {nid: [] for nid in node_ids}
    wanted_fields = {"source_node_id": 1, "target_node_id": 1}

    # Pass 1 files each edge under its source node, pass 2 under its target
    for endpoint in ("source_node_id", "target_node_id"):
        cursor = self.edge_collection.find(
            {endpoint: {"$in": node_ids}},
            wanted_fields,
        )
        async for doc in cursor:
            pair = (doc["source_node_id"], doc["target_node_id"])
            edges_by_node[doc[endpoint]].append(pair)

    return edges_by_node
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Find the nodes whose ``source_ids`` contain any of the given chunk IDs.

    Args:
        chunk_ids (list[str]): Chunk IDs to look up.

    Returns:
        list[dict]: The matching node documents, or an empty list when
        ``chunk_ids`` is empty or nothing matches.
    """
    matched = []
    if chunk_ids:
        query = {"source_ids": {"$in": chunk_ids}}
        async for node in self.collection.find(query):
            matched.append(node)
    return matched
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Find the edges whose ``source_ids`` contain any of the given chunk IDs.

    Each returned edge document additionally carries ``source`` and
    ``target`` keys copied from ``source_node_id`` and ``target_node_id``.

    Args:
        chunk_ids (list[str]): Chunk IDs to look up.

    Returns:
        list[dict]: The matching edge documents, or an empty list when
        ``chunk_ids`` is empty or nothing matches.
    """
    if not chunk_ids:
        return []

    tagged = []
    query = {"source_ids": {"$in": chunk_ids}}
    async for doc in self.edge_collection.find(query):
        # Mirror the endpoint fields under the shorter key names
        doc.update(source=doc["source_node_id"], target=doc["target_node_id"])
        tagged.append(doc)
    return tagged
# #
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -617,11 +633,14 @@ class MongoGraphStorage(BaseGraphStorage):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
""" """
Insert or update a node document. If new, create an empty edges array. Insert or update a node document.
""" """
# By default, preserve existing 'edges'. update_doc = {"$set": {**node_data}}
# We'll only set 'edges' to [] on insert (no overwrite). if node_data.get("source_id", ""):
update_doc = {"$set": {**node_data}, "$setOnInsert": {"edges": []}} update_doc["$set"]["source_ids"] = node_data["source_id"].split(
GRAPH_FIELD_SEP
)
await self.collection.update_one({"_id": node_id}, update_doc, upsert=True) await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
async def upsert_edge( async def upsert_edge(
@ -634,16 +653,16 @@ class MongoGraphStorage(BaseGraphStorage):
# Ensure source node exists # Ensure source node exists
await self.upsert_node(source_node_id, {}) await self.upsert_node(source_node_id, {})
# Remove existing edge (if any) update_doc = {"$set": edge_data}
await self.collection.update_one( if edge_data.get("source_id", ""):
{"_id": source_node_id}, {"$pull": {"edges": {"target": target_node_id}}} update_doc["$set"]["source_ids"] = edge_data["source_id"].split(
GRAPH_FIELD_SEP
) )
# Insert new edge await self.edge_collection.update_one(
new_edge = {"target": target_node_id} {"source_node_id": source_node_id, "target_node_id": target_node_id},
new_edge.update(edge_data) update_doc,
await self.collection.update_one( upsert=True,
{"_id": source_node_id}, {"$push": {"edges": new_edge}}
) )
# #
@ -657,8 +676,10 @@ class MongoGraphStorage(BaseGraphStorage):
1) Remove node's doc entirely. 1) Remove node's doc entirely.
2) Remove inbound edges from any doc that references node_id. 2) Remove inbound edges from any doc that references node_id.
""" """
# Remove inbound edges from all other docs # Remove all edges
await self.collection.update_many({}, {"$pull": {"edges": {"target": node_id}}}) await self.edge_collection.delete_many(
{"$or": [{"source_node_id": node_id}, {"target_node_id": node_id}]}
)
# Remove the node doc # Remove the node doc
await self.collection.delete_one({"_id": node_id}) await self.collection.delete_one({"_id": node_id})
@ -675,20 +696,18 @@ class MongoGraphStorage(BaseGraphStorage):
Returns: Returns:
[id1, id2, ...] # Alphabetically sorted id list [id1, id2, ...] # Alphabetically sorted id list
""" """
# Use MongoDB's distinct and aggregation to get all unique labels
pipeline = [
{"$group": {"_id": "$_id"}}, # Group by _id
{"$sort": {"_id": 1}}, # Sort alphabetically
]
cursor = self.collection.aggregate(pipeline) cursor = self.collection.find({}, projection={"_id": 1}, sort=[("_id", 1)])
labels = [] labels = []
async for doc in cursor: async for doc in cursor:
labels.append(doc["_id"]) labels.append(doc["_id"])
return labels return labels
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self,
node_label: str,
max_depth: int = 5,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Get complete connected subgraph for specified node (including the starting node itself) Get complete connected subgraph for specified node (including the starting node itself)
@ -704,147 +723,88 @@ class MongoGraphStorage(BaseGraphStorage):
result = KnowledgeGraph() result = KnowledgeGraph()
seen_nodes = set() seen_nodes = set()
seen_edges = set() seen_edges = set()
node_edges = []
try: try:
if label == "*": pipeline = [
# Get all nodes and edges {
async for node_doc in self.collection.find({}): "$graphLookup": {
node_id = str(node_doc["_id"]) "from": self._edge_collection_name,
if node_id not in seen_nodes: "startWith": "$_id",
result.nodes.append( "connectFromField": "target_node_id",
KnowledgeGraphNode( "connectToField": "source_node_id",
id=node_id, "maxDepth": max_depth,
labels=[node_doc.get("_id")], "depthField": "depth",
properties={ "as": "connected_edges",
k: v
for k, v in node_doc.items()
if k not in ["_id", "edges"]
}, },
) },
) {"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
seen_nodes.add(node_id) {"$sort": {"edge_count": -1}},
{"$limit": max_nodes},
]
# Process edges if label == "*":
for edge in node_doc.get("edges", []): all_node_count = await self.collection.count_documents({})
edge_id = f"{node_id}-{edge['target']}" result.is_truncated = all_node_count > max_nodes
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relation", ""),
source=node_id,
target=edge["target"],
properties={
k: v
for k, v in edge.items()
if k not in ["target", "relation"]
},
)
)
seen_edges.add(edge_id)
else: else:
# Verify if starting node exists # Verify if starting node exists
start_nodes = self.collection.find({"_id": label}) start_node = await self.collection.find_one({"_id": label})
start_nodes_exist = await start_nodes.to_list(length=1) if not start_node:
if not start_nodes_exist:
logger.warning(f"Starting node with label {label} does not exist!") logger.warning(f"Starting node with label {label} does not exist!")
return result return result
# Use $graphLookup for traversal # Add starting node to pipeline
pipeline = [ pipeline.insert(0, {"$match": {"_id": label}})
{
"$match": {"_id": label}
}, # Start with nodes having the specified label
{
"$graphLookup": {
"from": self._collection_name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"maxDepth": max_depth,
"depthField": "depth",
"as": "connected_nodes",
}
},
]
async for doc in self.collection.aggregate(pipeline): cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
async for doc in cursor:
# Add the start node # Add the start node
node_id = str(doc["_id"]) node_id = str(doc["_id"])
if node_id not in seen_nodes:
result.nodes.append( result.nodes.append(
KnowledgeGraphNode( KnowledgeGraphNode(
id=node_id, id=node_id,
labels=[ labels=[node_id],
doc.get(
"_id",
)
],
properties={ properties={
k: v k: v
for k, v in doc.items() for k, v in doc.items()
if k if k
not in [ not in [
"_id", "_id",
"edges", "connected_edges",
"connected_nodes", "edge_count",
"depth",
] ]
}, },
) )
) )
seen_nodes.add(node_id) seen_nodes.add(node_id)
if doc.get("connected_edges", []):
node_edges.extend(doc.get("connected_edges"))
# Add edges from start node for edge in node_edges:
for edge in doc.get("edges", []): if (
edge_id = f"{node_id}-{edge['target']}" edge["source_node_id"] not in seen_nodes
or edge["target_node_id"] not in seen_nodes
):
continue
edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
if edge_id not in seen_edges: if edge_id not in seen_edges:
result.edges.append( result.edges.append(
KnowledgeGraphEdge( KnowledgeGraphEdge(
id=edge_id, id=edge_id,
type=edge.get("relation", ""), type=edge.get("relationship", ""),
source=node_id, source=edge["source_node_id"],
target=edge["target"], target=edge["target_node_id"],
properties={ properties={
k: v k: v
for k, v in edge.items() for k, v in edge.items()
if k not in ["target", "relation"] if k
}, not in [
) "_id",
) "source_node_id",
seen_edges.add(edge_id) "target_node_id",
"relationship",
# Add connected nodes and their edges ]
for connected in doc.get("connected_nodes", []):
node_id = str(connected["_id"])
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=node_id,
labels=[connected.get("_id")],
properties={
k: v
for k, v in connected.items()
if k not in ["_id", "edges", "depth"]
},
)
)
seen_nodes.add(node_id)
# Add edges from connected nodes
for edge in connected.get("edges", []):
edge_id = f"{node_id}-{edge['target']}"
if edge_id not in seen_edges:
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type=edge.get("relation", ""),
source=node_id,
target=edge["target"],
properties={
k: v
for k, v in edge.items()
if k not in ["target", "relation"]
}, },
) )
) )
@ -873,9 +833,14 @@ class MongoGraphStorage(BaseGraphStorage):
if not nodes: if not nodes:
return return
# 1. Remove all edges referencing these nodes (remove from edges array of other nodes) # 1. Remove all edges referencing these nodes
await self.collection.update_many( await self.edge_collection.delete_many(
{}, {"$pull": {"edges": {"target": {"$in": nodes}}}} {
"$or": [
{"source_node_id": {"$in": nodes}},
{"target_node_id": {"$in": nodes}},
]
}
) )
# 2. Delete the node documents # 2. Delete the node documents
@ -893,17 +858,16 @@ class MongoGraphStorage(BaseGraphStorage):
if not edges: if not edges:
return return
update_tasks = [] all_edge_pairs = []
for source, target in edges: for source_id, target_id in edges:
# Remove edge pointing to target from source node's edges array all_edge_pairs.append(
update_tasks.append( {"source_node_id": source_id, "target_node_id": target_id}
self.collection.update_one(
{"_id": source}, {"$pull": {"edges": {"target": target}}}
) )
all_edge_pairs.append(
{"source_node_id": target_id, "target_node_id": source_id}
) )
if update_tasks: await self.edge_collection.delete_many({"$or": all_edge_pairs})
await asyncio.gather(*update_tasks)
logger.debug(f"Successfully deleted edges: {edges}") logger.debug(f"Successfully deleted edges: {edges}")
@ -920,9 +884,16 @@ class MongoGraphStorage(BaseGraphStorage):
logger.info( logger.info(
f"Dropped {deleted_count} documents from graph {self._collection_name}" f"Dropped {deleted_count} documents from graph {self._collection_name}"
) )
result = await self.edge_collection.delete_many({})
edge_count = result.deleted_count
logger.info(
f"Dropped {edge_count} edges from graph {self._edge_collection_name}"
)
return { return {
"status": "success", "status": "success",
"message": f"{deleted_count} documents dropped", "message": f"{deleted_count} documents and {edge_count} edges dropped",
} }
except PyMongoError as e: except PyMongoError as e:
logger.error(f"Error dropping graph {self._collection_name}: {e}") logger.error(f"Error dropping graph {self._collection_name}: {e}")
@ -932,8 +903,8 @@ class MongoGraphStorage(BaseGraphStorage):
@final @final
@dataclass @dataclass
class MongoVectorDBStorage(BaseVectorStorage): class MongoVectorDBStorage(BaseVectorStorage):
db: AsyncIOMotorDatabase | None = field(default=None) db: AsyncDatabase | None = field(default=None)
_data: AsyncIOMotorCollection | None = field(default=None) _data: AsyncCollection | None = field(default=None)
def __post_init__(self): def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
@ -967,7 +938,8 @@ class MongoVectorDBStorage(BaseVectorStorage):
try: try:
index_name = "vector_knn_index" index_name = "vector_knn_index"
indexes = await self._data.list_search_indexes().to_list(length=None) indexes_cursor = await self._data.list_search_indexes()
indexes = await indexes_cursor.to_list(length=None)
for index in indexes: for index in indexes:
if index["name"] == index_name: if index["name"] == index_name:
logger.debug("vector index already exist") logger.debug("vector index already exist")
@ -1062,8 +1034,8 @@ class MongoVectorDBStorage(BaseVectorStorage):
] ]
# Execute the aggregation pipeline # Execute the aggregation pipeline
cursor = self._data.aggregate(pipeline) cursor = await self._data.aggregate(pipeline, allowDiskUse=True)
results = await cursor.to_list() results = await cursor.to_list(length=None)
# Format and return the results with created_at field # Format and return the results with created_at field
return [ return [
@ -1090,6 +1062,10 @@ class MongoVectorDBStorage(BaseVectorStorage):
if not ids: if not ids:
return return
# Convert to list if it's a set (MongoDB BSON cannot encode sets)
if isinstance(ids, set):
ids = list(ids)
try: try:
result = await self._data.delete_many({"_id": {"$in": ids}}) result = await self._data.delete_many({"_id": {"$in": ids}})
logger.debug( logger.debug(
@ -1232,7 +1208,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str): async def get_or_create_collection(db: AsyncDatabase, collection_name: str):
collection_names = await db.list_collection_names() collection_names = await db.list_collection_names()
if collection_name not in collection_names: if collection_name not in collection_names:

View file

@ -30,6 +30,7 @@ from lightrag.kg import (
verify_storage_implementation, verify_storage_implementation,
) )
from lightrag.kg.shared_storage import initialize_share_data from lightrag.kg.shared_storage import initialize_share_data
from lightrag.constants import GRAPH_FIELD_SEP
# 模拟的嵌入函数,返回随机向量 # 模拟的嵌入函数,返回随机向量
@ -437,6 +438,9 @@ async def test_graph_batch_operations(storage):
5. 使用 get_nodes_edges_batch 批量获取多个节点的所有边 5. 使用 get_nodes_edges_batch 批量获取多个节点的所有边
""" """
try: try:
chunk1_id = "1"
chunk2_id = "2"
chunk3_id = "3"
# 1. 插入测试数据 # 1. 插入测试数据
# 插入节点1: 人工智能 # 插入节点1: 人工智能
node1_id = "人工智能" node1_id = "人工智能"
@ -445,6 +449,7 @@ async def test_graph_batch_operations(storage):
"description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。", "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
"keywords": "AI,机器学习,深度学习", "keywords": "AI,机器学习,深度学习",
"entity_type": "技术领域", "entity_type": "技术领域",
"source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]),
} }
print(f"插入节点1: {node1_id}") print(f"插入节点1: {node1_id}")
await storage.upsert_node(node1_id, node1_data) await storage.upsert_node(node1_id, node1_data)
@ -456,6 +461,7 @@ async def test_graph_batch_operations(storage):
"description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。", "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
"keywords": "监督学习,无监督学习,强化学习", "keywords": "监督学习,无监督学习,强化学习",
"entity_type": "技术领域", "entity_type": "技术领域",
"source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]),
} }
print(f"插入节点2: {node2_id}") print(f"插入节点2: {node2_id}")
await storage.upsert_node(node2_id, node2_data) await storage.upsert_node(node2_id, node2_data)
@ -467,6 +473,7 @@ async def test_graph_batch_operations(storage):
"description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。", "description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
"keywords": "神经网络,CNN,RNN", "keywords": "神经网络,CNN,RNN",
"entity_type": "技术领域", "entity_type": "技术领域",
"source_id": GRAPH_FIELD_SEP.join([chunk3_id]),
} }
print(f"插入节点3: {node3_id}") print(f"插入节点3: {node3_id}")
await storage.upsert_node(node3_id, node3_data) await storage.upsert_node(node3_id, node3_data)
@ -498,6 +505,7 @@ async def test_graph_batch_operations(storage):
"relationship": "包含", "relationship": "包含",
"weight": 1.0, "weight": 1.0,
"description": "人工智能领域包含机器学习这个子领域", "description": "人工智能领域包含机器学习这个子领域",
"source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]),
} }
print(f"插入边1: {node1_id} -> {node2_id}") print(f"插入边1: {node1_id} -> {node2_id}")
await storage.upsert_edge(node1_id, node2_id, edge1_data) await storage.upsert_edge(node1_id, node2_id, edge1_data)
@ -507,6 +515,7 @@ async def test_graph_batch_operations(storage):
"relationship": "包含", "relationship": "包含",
"weight": 1.0, "weight": 1.0,
"description": "机器学习领域包含深度学习这个子领域", "description": "机器学习领域包含深度学习这个子领域",
"source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]),
} }
print(f"插入边2: {node2_id} -> {node3_id}") print(f"插入边2: {node2_id} -> {node3_id}")
await storage.upsert_edge(node2_id, node3_id, edge2_data) await storage.upsert_edge(node2_id, node3_id, edge2_data)
@ -516,6 +525,7 @@ async def test_graph_batch_operations(storage):
"relationship": "包含", "relationship": "包含",
"weight": 1.0, "weight": 1.0,
"description": "人工智能领域包含自然语言处理这个子领域", "description": "人工智能领域包含自然语言处理这个子领域",
"source_id": GRAPH_FIELD_SEP.join([chunk3_id]),
} }
print(f"插入边3: {node1_id} -> {node4_id}") print(f"插入边3: {node1_id} -> {node4_id}")
await storage.upsert_edge(node1_id, node4_id, edge3_data) await storage.upsert_edge(node1_id, node4_id, edge3_data)
@ -748,6 +758,76 @@ async def test_graph_batch_operations(storage):
print("无向图特性验证成功:批量获取的节点边包含所有相关的边(无论方向)") print("无向图特性验证成功:批量获取的节点边包含所有相关的边(无论方向)")
# 7. 测试 get_nodes_by_chunk_ids - 批量根据 chunk_ids 获取多个节点
print("== 测试 get_nodes_by_chunk_ids")
print("== 测试单个 chunk_id匹配多个节点")
nodes = await storage.get_nodes_by_chunk_ids([chunk2_id])
assert len(nodes) == 2, f"{chunk2_id} 应有2个节点实际有 {len(nodes)}"
has_node1 = any(node["entity_id"] == node1_id for node in nodes)
has_node2 = any(node["entity_id"] == node2_id for node in nodes)
assert has_node1, f"节点 {node1_id} 应在返回结果中"
assert has_node2, f"节点 {node2_id} 应在返回结果中"
print("== 测试多个 chunk_id部分匹配多个节点")
nodes = await storage.get_nodes_by_chunk_ids([chunk2_id, chunk3_id])
assert (
len(nodes) == 3
), f"{chunk2_id}, {chunk3_id} 应有3个节点实际有 {len(nodes)}"
has_node1 = any(node["entity_id"] == node1_id for node in nodes)
has_node2 = any(node["entity_id"] == node2_id for node in nodes)
has_node3 = any(node["entity_id"] == node3_id for node in nodes)
assert has_node1, f"节点 {node1_id} 应在返回结果中"
assert has_node2, f"节点 {node2_id} 应在返回结果中"
assert has_node3, f"节点 {node3_id} 应在返回结果中"
# 8. 测试 get_edges_by_chunk_ids - 批量根据 chunk_ids 获取多条边
print("== 测试 get_edges_by_chunk_ids")
print("== 测试单个 chunk_id匹配多条边")
edges = await storage.get_edges_by_chunk_ids([chunk2_id])
assert len(edges) == 2, f"{chunk2_id} 应有2条边实际有 {len(edges)}"
has_edge_node1_node2 = any(
edge["source"] == node1_id and edge["target"] == node2_id for edge in edges
)
has_edge_node2_node3 = any(
edge["source"] == node2_id and edge["target"] == node3_id for edge in edges
)
assert has_edge_node1_node2, f"{chunk2_id} 应包含 {node1_id}{node2_id} 的边"
assert has_edge_node2_node3, f"{chunk2_id} 应包含 {node2_id}{node3_id} 的边"
print("== 测试多个 chunk_id部分匹配多条边")
edges = await storage.get_edges_by_chunk_ids([chunk2_id, chunk3_id])
assert (
len(edges) == 3
), f"{chunk2_id}, {chunk3_id} 应有3条边实际有 {len(edges)}"
has_edge_node1_node2 = any(
edge["source"] == node1_id and edge["target"] == node2_id for edge in edges
)
has_edge_node2_node3 = any(
edge["source"] == node2_id and edge["target"] == node3_id for edge in edges
)
has_edge_node1_node4 = any(
edge["source"] == node1_id and edge["target"] == node4_id for edge in edges
)
assert (
has_edge_node1_node2
), f"{chunk2_id}, {chunk3_id} 应包含 {node1_id}{node2_id} 的边"
assert (
has_edge_node2_node3
), f"{chunk2_id}, {chunk3_id} 应包含 {node2_id}{node3_id} 的边"
assert (
has_edge_node1_node4
), f"{chunk2_id}, {chunk3_id} 应包含 {node1_id}{node4_id} 的边"
print("\n批量操作测试完成") print("\n批量操作测试完成")
return True return True