From ff1927d362f5a14656ed9df9a521af295604e26f Mon Sep 17 00:00:00 2001 From: DavIvek Date: Thu, 26 Jun 2025 16:15:56 +0200 Subject: [PATCH 01/19] add Memgraph graph storage backend --- config.ini.example | 3 + examples/graph_visual_with_neo4j.py | 2 +- examples/lightrag_openai_demo.py | 1 + lightrag/kg/__init__.py | 3 + lightrag/kg/memgraph_impl.py | 423 ++++++++++++++++++++++++++++ 5 files changed, 431 insertions(+), 1 deletion(-) create mode 100644 lightrag/kg/memgraph_impl.py diff --git a/config.ini.example b/config.ini.example index 63d9c2c0..94d300a1 100644 --- a/config.ini.example +++ b/config.ini.example @@ -21,3 +21,6 @@ password = your_password database = your_database workspace = default # 可选,默认为default max_connections = 12 + +[memgraph] +uri = bolt://localhost:7687 diff --git a/examples/graph_visual_with_neo4j.py b/examples/graph_visual_with_neo4j.py index 1cd2e7a3..e06c248c 100644 --- a/examples/graph_visual_with_neo4j.py +++ b/examples/graph_visual_with_neo4j.py @@ -11,7 +11,7 @@ BATCH_SIZE_EDGES = 100 # Neo4j connection credentials NEO4J_URI = "bolt://localhost:7687" NEO4J_USERNAME = "neo4j" -NEO4J_PASSWORD = "your_password" +NEO4J_PASSWORD = "david123" def xml_to_json(xml_file): diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py index fa0b37f1..e573ec41 100644 --- a/examples/lightrag_openai_demo.py +++ b/examples/lightrag_openai_demo.py @@ -82,6 +82,7 @@ async def initialize_rag(): working_dir=WORKING_DIR, embedding_func=openai_embed, llm_model_func=gpt_4o_mini_complete, + graph_storage="MemgraphStorage", ) await rag.initialize_storages() diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py index b4ba0983..3398b135 100644 --- a/lightrag/kg/__init__.py +++ b/lightrag/kg/__init__.py @@ -15,6 +15,7 @@ STORAGE_IMPLEMENTATIONS = { "Neo4JStorage", "PGGraphStorage", "MongoGraphStorage", + "MemgraphStorage", # "AGEStorage", # "TiDBGraphStorage", # "GremlinStorage", @@ -56,6 +57,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { "NetworkXStorage": [], "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], "MongoGraphStorage": [], + "MemgraphStorage": ["MEMGRAPH_URI"], # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], "AGEStorage": [ "AGE_POSTGRES_DB", @@ -108,6 +110,7 @@ STORAGES = { "PGDocStatusStorage": ".kg.postgres_impl", "FaissVectorDBStorage": ".kg.faiss_impl", "QdrantVectorDBStorage": ".kg.qdrant_impl", + "MemgraphStorage": ".kg.memgraph_impl", } diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py new file mode 100644 index 00000000..cb46cdc6 --- /dev/null +++ b/lightrag/kg/memgraph_impl.py @@ -0,0 +1,423 @@ +import os +import re +from dataclasses import dataclass +from typing import final +import configparser + +from ..utils import logger +from ..base import BaseGraphStorage +from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +from ..constants import GRAPH_FIELD_SEP +import pipmaster as pm + +if not pm.is_installed("neo4j"): + pm.install("neo4j") + +from neo4j import ( + AsyncGraphDatabase, + AsyncManagedTransaction, +) + +from dotenv import load_dotenv + +# use the .env that is inside the current folder +load_dotenv(dotenv_path=".env", override=False) + +MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) + +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") + +@final +@dataclass +class MemgraphStorage(BaseGraphStorage): + def __init__(self, namespace, global_config, embedding_func): + super().__init__( + namespace=namespace, 
+ global_config=global_config, + embedding_func=embedding_func, + ) + self._driver = None + + async def initialize(self): + URI = os.environ.get("MEMGRAPH_URI", config.get("memgraph", "uri", fallback="bolt://localhost:7687")) + USERNAME = os.environ.get("MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")) + PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")) + DATABASE = os.environ.get("MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph")) + + self._driver = AsyncGraphDatabase.driver( + URI, + auth=(USERNAME, PASSWORD), + ) + self._DATABASE = DATABASE + try: + async with self._driver.session(database=DATABASE) as session: + # Create index for base nodes on entity_id if it doesn't exist + try: + await session.run("""CREATE INDEX ON :base(entity_id)""") + logger.info("Created index on :base(entity_id) in Memgraph.") + except Exception as e: + # Index may already exist, which is not an error + logger.warning(f"Index creation on :base(entity_id) may have failed or already exists: {e}") + await session.run("RETURN 1") + logger.info(f"Connected to Memgraph at {URI}") + except Exception as e: + logger.error(f"Failed to connect to Memgraph at {URI}: {e}") + raise + + async def finalize(self): + if self._driver is not None: + await self._driver.close() + self._driver = None + + async def __aexit__(self, exc_type, exc, tb): + await self.finalize() + + async def index_done_callback(self): + # Memgraph handles persistence automatically + pass + + async def has_node(self, node_id: str) -> bool: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" + result = await session.run(query, entity_id=node_id) + single_result = await result.single() + await result.consume() + return single_result["node_exists"] + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = ( + "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " + "RETURN COUNT(r) > 0 AS edgeExists" + ) + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) + single_result = await result.single() + await result.consume() + return single_result["edgeExists"] + + async def get_node(self, node_id: str) -> dict[str, str] | None: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" + result = await session.run(query, entity_id=node_id) + records = await result.fetch(2) + await result.consume() + if records: + node = records[0]["n"] + node_dict = dict(node) + if "labels" in node_dict: + node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"] + return node_dict + return None + + async def get_all_labels(self) -> list[str]: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = """ + MATCH (n:base) + WHERE n.entity_id IS NOT NULL + RETURN DISTINCT n.entity_id AS label + ORDER BY label + """ + result = await session.run(query) + labels = [] + async for record in result: + labels.append(record["label"]) + await result.consume() + return labels + + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: + async 
with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = """ + MATCH (n:base {entity_id: $entity_id}) + OPTIONAL MATCH (n)-[r]-(connected:base) + WHERE connected.entity_id IS NOT NULL + RETURN n, r, connected + """ + results = await session.run(query, entity_id=source_node_id) + edges = [] + async for record in results: + source_node = record["n"] + connected_node = record["connected"] + if not source_node or not connected_node: + continue + source_label = source_node.get("entity_id") + target_label = connected_node.get("entity_id") + if source_label and target_label: + edges.append((source_label, target_label)) + await results.consume() + return edges + + async def get_edge(self, source_node_id: str, target_node_id: str) -> dict[str, str] | None: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = """ + MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) + RETURN properties(r) as edge_properties + """ + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) + records = await result.fetch(2) + await result.consume() + if records: + edge_result = dict(records[0]["edge_properties"]) + for key, default_value in { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + }.items(): + if key not in edge_result: + edge_result[key] = default_value + return edge_result + return None + + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: + properties = node_data + entity_type = properties.get("entity_type", "base") + if "entity_id" not in properties: + raise ValueError("Memgraph: node properties must contain an 'entity_id' field") + async with self._driver.session(database=self._DATABASE) as session: + async def execute_upsert(tx: AsyncManagedTransaction): + query = ( + f""" + MERGE (n:base {{entity_id: $entity_id}}) + SET n += $properties + SET n:`{entity_type}` + """ + ) + result = await tx.run(query, entity_id=node_id, properties=properties) + await result.consume() + await session.execute_write(execute_upsert) + + async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]) -> None: + edge_properties = edge_data + async with self._driver.session(database=self._DATABASE) as session: + async def execute_upsert(tx: AsyncManagedTransaction): + query = """ + MATCH (source:base {entity_id: $source_entity_id}) + WITH source + MATCH (target:base {entity_id: $target_entity_id}) + MERGE (source)-[r:DIRECTED]-(target) + SET r += $properties + RETURN r, source, target + """ + result = await tx.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + properties=edge_properties, + ) + await result.consume() + await session.execute_write(execute_upsert) + + async def delete_node(self, node_id: str) -> None: + async def _do_delete(tx: AsyncManagedTransaction): + query = """ + MATCH (n:base {entity_id: $entity_id}) + DETACH DELETE n + """ + result = await tx.run(query, entity_id=node_id) + await result.consume() + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete) + + async def remove_nodes(self, nodes: list[str]): + for node in nodes: + await self.delete_node(node) + + async def remove_edges(self, edges: list[tuple[str, str]]): + for source, target in edges: + async def _do_delete_edge(tx: AsyncManagedTransaction): + query = """ + MATCH 
(source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id}) + DELETE r + """ + result = await tx.run( + query, source_entity_id=source, target_entity_id=target + ) + await result.consume() + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete_edge) + + async def drop(self) -> dict[str, str]: + try: + async with self._driver.session(database=self._DATABASE) as session: + query = "MATCH (n) DETACH DELETE n" + result = await session.run(query) + await result.consume() + logger.info(f"Process {os.getpid()} drop Memgraph database {self._DATABASE}") + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}") + return {"status": "error", "message": str(e)} + + async def node_degree(self, node_id: str) -> int: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = """ + MATCH (n:base {entity_id: $entity_id}) + OPTIONAL MATCH (n)-[r]-() + RETURN COUNT(r) AS degree + """ + result = await session.run(query, entity_id=node_id) + record = await result.single() + await result.consume() + if not record: + return 0 + return record["degree"] + + async def edge_degree(self, src_id: str, tgt_id: str) -> int: + src_degree = await self.node_degree(src_id) + trg_degree = await self.node_degree(tgt_id) + src_degree = 0 if src_degree is None else src_degree + trg_degree = 0 if trg_degree is None else trg_degree + return int(src_degree) + int(trg_degree) + + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = """ + UNWIND $chunk_ids AS chunk_id + MATCH (n:base) + WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep) + RETURN DISTINCT n + """ + result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) + nodes = [] + async for record in result: + node = record["n"] + node_dict = dict(node) + node_dict["id"] = node_dict.get("entity_id") + nodes.append(node_dict) + await result.consume() + return nodes + + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + query = """ + UNWIND $chunk_ids AS chunk_id + MATCH (a:base)-[r]-(b:base) + WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) + RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties + """ + result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) + edges = [] + async for record in result: + edge_properties = record["properties"] + edge_properties["source"] = record["source"] + edge_properties["target"] = record["target"] + edges.append(edge_properties) + await result.consume() + return edges + + async def get_knowledge_graph( + self, + node_label: str, + max_depth: int = 3, + max_nodes: int = MAX_GRAPH_NODES, + ) -> KnowledgeGraph: + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + if node_label == "*": + count_query = "MATCH (n) RETURN count(n) as total" + count_result = await session.run(count_query) + count_record = await count_result.single() + await count_result.consume() + if count_record and count_record["total"] > max_nodes: + 
result.is_truncated = True + logger.info(f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}") + main_query = """ + MATCH (n) + OPTIONAL MATCH (n)-[r]-() + WITH n, COALESCE(count(r), 0) AS degree + ORDER BY degree DESC + LIMIT $max_nodes + WITH collect({node: n}) AS filtered_nodes + UNWIND filtered_nodes AS node_info + WITH collect(node_info.node) AS kept_nodes, filtered_nodes + OPTIONAL MATCH (a)-[r]-(b) + WHERE a IN kept_nodes AND b IN kept_nodes + RETURN filtered_nodes AS node_info, + collect(DISTINCT r) AS relationships + """ + result_set = await session.run(main_query, {"max_nodes": max_nodes}) + record = await result_set.single() + await result_set.consume() + else: + # BFS fallback for Memgraph (no APOC) + from collections import deque + # Get the starting node + start_query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" + node_result = await session.run(start_query, entity_id=node_label) + node_record = await node_result.single() + await node_result.consume() + if not node_record: + return result + start_node = node_record["n"] + start_node_id = start_node.get("entity_id") + queue = deque([(start_node, 0)]) + visited = set() + bfs_nodes = [] + while queue and len(bfs_nodes) < max_nodes: + current_node, depth = queue.popleft() + node_id = current_node.get("entity_id") + if node_id in visited: + continue + visited.add(node_id) + bfs_nodes.append(current_node) + if depth < max_depth: + # Get neighbors + neighbor_query = """ + MATCH (n:base {entity_id: $entity_id})-[]-(m:base) + RETURN m + """ + neighbors_result = await session.run(neighbor_query, entity_id=node_id) + neighbors = [rec["m"] for rec in await neighbors_result.to_list()] + await neighbors_result.consume() + for neighbor in neighbors: + neighbor_id = neighbor.get("entity_id") + if neighbor_id not in visited: + queue.append((neighbor, depth + 1)) + # Build subgraph + subgraph_ids = [n.get("entity_id") for n in bfs_nodes] + # Nodes + for n in bfs_nodes: + node_id = n.get("entity_id") + if node_id not in seen_nodes: + result.nodes.append(KnowledgeGraphNode( + id=node_id, + labels=[node_id], + properties=dict(n), + )) + seen_nodes.add(node_id) + # Edges + if subgraph_ids: + edge_query = """ + MATCH (a:base)-[r]-(b:base) + WHERE a.entity_id IN $ids AND b.entity_id IN $ids + RETURN DISTINCT r, a, b + """ + edge_result = await session.run(edge_query, ids=subgraph_ids) + async for record in edge_result: + r = record["r"] + a = record["a"] + b = record["b"] + edge_id = f"{a.get('entity_id')}-{b.get('entity_id')}" + if edge_id not in seen_edges: + result.edges.append(KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=a.get("entity_id"), + target=b.get("entity_id"), + properties=dict(r), + )) + seen_edges.add(edge_id) + await edge_result.consume() + logger.info(f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}") + return result \ No newline at end of file From 0d6bd3bac2c9f0e5befe955628838e975610b6f2 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Thu, 26 Jun 2025 16:18:25 +0200 Subject: [PATCH 02/19] Revert changes made to graph_visual_with_neo4j.py --- examples/graph_visual_with_neo4j.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/graph_visual_with_neo4j.py b/examples/graph_visual_with_neo4j.py index e06c248c..1cd2e7a3 100644 --- a/examples/graph_visual_with_neo4j.py +++ b/examples/graph_visual_with_neo4j.py @@ -11,7 +11,7 @@ BATCH_SIZE_EDGES = 100 # Neo4j connection credentials NEO4J_URI = "bolt://localhost:7687" 
NEO4J_USERNAME = "neo4j" -NEO4J_PASSWORD = "david123" +NEO4J_PASSWORD = "your_password" def xml_to_json(xml_file): From 80d4d5b0d50056cd89e347789d489896f0b39275 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Thu, 26 Jun 2025 16:26:51 +0200 Subject: [PATCH 03/19] Add Memgraph into README.md --- README.md | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/README.md b/README.md index 617dc5e6..2068f205 100644 --- a/README.md +++ b/README.md @@ -854,6 +854,41 @@ rag = LightRAG( +
+ Using Memgraph for Storage
+
+* Memgraph is a high-performance, in-memory graph database compatible with the Neo4j Bolt protocol.
+* You can run Memgraph locally with Docker for easy testing; see https://memgraph.com/download for installation options.
+
+Point LightRAG at your Memgraph instance, then select the backend when constructing LightRAG:
+
+```bash
+export MEMGRAPH_URI="bolt://localhost:7687"
+```
+
+```python
+# Setup logger for LightRAG
+setup_logger("lightrag", level="INFO")
+
+# The default graph storage is NetworkX; override it by passing
+# graph_storage="MemgraphStorage" when initializing LightRAG.
+async def initialize_rag():
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
+        graph_storage="MemgraphStorage",  # <-- overrides the default KG (NetworkX)
+    )
+
+    # Initialize database connections
+    await rag.initialize_storages()
+    # Initialize pipeline status for document processing
+    await initialize_pipeline_status()
+
+    return rag
+```
+
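+If you do not already have a Memgraph instance running, the official Docker image is the quickest way to start one locally (a minimal sketch; the image name and Bolt port below are the documented defaults, adjust to your setup):
+
+```bash
+# Start a local Memgraph instance and expose the Bolt port LightRAG connects to
+docker run -it -p 7687:7687 memgraph/memgraph
+```
+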
+ ## Edit Entities and Relations LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph. From 7118b23ca2851d933384612d87cd34c25b4bba5e Mon Sep 17 00:00:00 2001 From: DavIvek Date: Thu, 26 Jun 2025 16:33:19 +0200 Subject: [PATCH 04/19] reformatting --- lightrag/kg/memgraph_impl.py | 142 ++++++++++++++++++++++++----------- 1 file changed, 100 insertions(+), 42 deletions(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index cb46cdc6..df28b8b2 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -1,5 +1,4 @@ import os -import re from dataclasses import dataclass from typing import final import configparser @@ -28,6 +27,7 @@ MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) config = configparser.ConfigParser() config.read("config.ini", "utf-8") + @final @dataclass class MemgraphStorage(BaseGraphStorage): @@ -40,10 +40,19 @@ class MemgraphStorage(BaseGraphStorage): self._driver = None async def initialize(self): - URI = os.environ.get("MEMGRAPH_URI", config.get("memgraph", "uri", fallback="bolt://localhost:7687")) - USERNAME = os.environ.get("MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")) - PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")) - DATABASE = os.environ.get("MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph")) + URI = os.environ.get( + "MEMGRAPH_URI", + config.get("memgraph", "uri", fallback="bolt://localhost:7687"), + ) + USERNAME = os.environ.get( + "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="") + ) + PASSWORD = os.environ.get( + "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="") + ) + DATABASE = os.environ.get( + "MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph") + ) self._driver = AsyncGraphDatabase.driver( URI, @@ -58,7 +67,9 @@ class MemgraphStorage(BaseGraphStorage): logger.info("Created index on :base(entity_id) in Memgraph.") except Exception as e: # Index may already exist, which is not an error - logger.warning(f"Index creation on :base(entity_id) may have failed or already exists: {e}") + logger.warning( + f"Index creation on :base(entity_id) may have failed or already exists: {e}" + ) await session.run("RETURN 1") logger.info(f"Connected to Memgraph at {URI}") except Exception as e: @@ -78,7 +89,9 @@ class MemgraphStorage(BaseGraphStorage): pass async def has_node(self, node_id: str) -> bool: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" result = await session.run(query, entity_id=node_id) single_result = await result.single() @@ -86,7 +99,9 @@ class MemgraphStorage(BaseGraphStorage): return single_result["node_exists"] async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = ( "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " "RETURN COUNT(r) > 0 AS edgeExists" @@ -101,7 +116,9 @@ class MemgraphStorage(BaseGraphStorage): return 
single_result["edgeExists"] async def get_node(self, node_id: str) -> dict[str, str] | None: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" result = await session.run(query, entity_id=node_id) records = await result.fetch(2) @@ -110,12 +127,16 @@ class MemgraphStorage(BaseGraphStorage): node = records[0]["n"] node_dict = dict(node) if "labels" in node_dict: - node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"] + node_dict["labels"] = [ + label for label in node_dict["labels"] if label != "base" + ] return node_dict return None async def get_all_labels(self) -> list[str]: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ MATCH (n:base) WHERE n.entity_id IS NOT NULL @@ -130,7 +151,9 @@ class MemgraphStorage(BaseGraphStorage): return labels async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ MATCH (n:base {entity_id: $entity_id}) OPTIONAL MATCH (n)-[r]-(connected:base) @@ -151,8 +174,12 @@ class MemgraphStorage(BaseGraphStorage): await results.consume() return edges - async def get_edge(self, source_node_id: str, target_node_id: str) -> dict[str, str] | None: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> dict[str, str] | None: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) RETURN properties(r) as edge_properties @@ -181,23 +208,28 @@ class MemgraphStorage(BaseGraphStorage): properties = node_data entity_type = properties.get("entity_type", "base") if "entity_id" not in properties: - raise ValueError("Memgraph: node properties must contain an 'entity_id' field") + raise ValueError( + "Memgraph: node properties must contain an 'entity_id' field" + ) async with self._driver.session(database=self._DATABASE) as session: + async def execute_upsert(tx: AsyncManagedTransaction): - query = ( - f""" + query = f""" MERGE (n:base {{entity_id: $entity_id}}) SET n += $properties SET n:`{entity_type}` """ - ) result = await tx.run(query, entity_id=node_id, properties=properties) await result.consume() + await session.execute_write(execute_upsert) - async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]) -> None: + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: edge_properties = edge_data async with self._driver.session(database=self._DATABASE) as session: + async def execute_upsert(tx: AsyncManagedTransaction): query = """ MATCH (source:base {entity_id: $source_entity_id}) @@ -214,6 +246,7 @@ class MemgraphStorage(BaseGraphStorage): properties=edge_properties, ) await result.consume() + await session.execute_write(execute_upsert) async def delete_node(self, 
node_id: str) -> None: @@ -224,6 +257,7 @@ class MemgraphStorage(BaseGraphStorage): """ result = await tx.run(query, entity_id=node_id) await result.consume() + async with self._driver.session(database=self._DATABASE) as session: await session.execute_write(_do_delete) @@ -233,6 +267,7 @@ class MemgraphStorage(BaseGraphStorage): async def remove_edges(self, edges: list[tuple[str, str]]): for source, target in edges: + async def _do_delete_edge(tx: AsyncManagedTransaction): query = """ MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id}) @@ -242,6 +277,7 @@ class MemgraphStorage(BaseGraphStorage): query, source_entity_id=source, target_entity_id=target ) await result.consume() + async with self._driver.session(database=self._DATABASE) as session: await session.execute_write(_do_delete_edge) @@ -251,14 +287,18 @@ class MemgraphStorage(BaseGraphStorage): query = "MATCH (n) DETACH DELETE n" result = await session.run(query) await result.consume() - logger.info(f"Process {os.getpid()} drop Memgraph database {self._DATABASE}") + logger.info( + f"Process {os.getpid()} drop Memgraph database {self._DATABASE}" + ) return {"status": "success", "message": "data dropped"} except Exception as e: logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}") return {"status": "error", "message": str(e)} async def node_degree(self, node_id: str) -> int: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ MATCH (n:base {entity_id: $entity_id}) OPTIONAL MATCH (n)-[r]-() @@ -279,7 +319,9 @@ class MemgraphStorage(BaseGraphStorage): return int(src_degree) + int(trg_degree) async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ UNWIND $chunk_ids AS chunk_id MATCH (n:base) @@ -297,7 +339,9 @@ class MemgraphStorage(BaseGraphStorage): return nodes async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ UNWIND $chunk_ids AS chunk_id MATCH (a:base)-[r]-(b:base) @@ -323,7 +367,9 @@ class MemgraphStorage(BaseGraphStorage): result = KnowledgeGraph() seen_nodes = set() seen_edges = set() - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: if node_label == "*": count_query = "MATCH (n) RETURN count(n) as total" count_result = await session.run(count_query) @@ -331,7 +377,9 @@ class MemgraphStorage(BaseGraphStorage): await count_result.consume() if count_record and count_record["total"] > max_nodes: result.is_truncated = True - logger.info(f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}") + logger.info( + f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" + ) main_query = """ MATCH (n) OPTIONAL MATCH (n)-[r]-() @@ -352,6 +400,7 @@ class MemgraphStorage(BaseGraphStorage): else: # BFS fallback for Memgraph (no APOC) from collections 
import deque + # Get the starting node start_query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" node_result = await session.run(start_query, entity_id=node_label) @@ -360,7 +409,6 @@ class MemgraphStorage(BaseGraphStorage): if not node_record: return result start_node = node_record["n"] - start_node_id = start_node.get("entity_id") queue = deque([(start_node, 0)]) visited = set() bfs_nodes = [] @@ -377,8 +425,12 @@ class MemgraphStorage(BaseGraphStorage): MATCH (n:base {entity_id: $entity_id})-[]-(m:base) RETURN m """ - neighbors_result = await session.run(neighbor_query, entity_id=node_id) - neighbors = [rec["m"] for rec in await neighbors_result.to_list()] + neighbors_result = await session.run( + neighbor_query, entity_id=node_id + ) + neighbors = [ + rec["m"] for rec in await neighbors_result.to_list() + ] await neighbors_result.consume() for neighbor in neighbors: neighbor_id = neighbor.get("entity_id") @@ -390,11 +442,13 @@ class MemgraphStorage(BaseGraphStorage): for n in bfs_nodes: node_id = n.get("entity_id") if node_id not in seen_nodes: - result.nodes.append(KnowledgeGraphNode( - id=node_id, - labels=[node_id], - properties=dict(n), - )) + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node_id], + properties=dict(n), + ) + ) seen_nodes.add(node_id) # Edges if subgraph_ids: @@ -410,14 +464,18 @@ class MemgraphStorage(BaseGraphStorage): b = record["b"] edge_id = f"{a.get('entity_id')}-{b.get('entity_id')}" if edge_id not in seen_edges: - result.edges.append(KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=a.get("entity_id"), - target=b.get("entity_id"), - properties=dict(r), - )) + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=a.get("entity_id"), + target=b.get("entity_id"), + properties=dict(r), + ) + ) seen_edges.add(edge_id) await edge_result.consume() - logger.info(f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}") - return result \ No newline at end of file + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + return result From bd158d096bb26215b283496b4bf2aebe4d0e2292 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Fri, 27 Jun 2025 14:47:23 +0200 Subject: [PATCH 05/19] polish Memgraph implementation --- lightrag/kg/memgraph_impl.py | 803 ++++++++++++++++++++++++----------- 1 file changed, 551 insertions(+), 252 deletions(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index df28b8b2..bf870154 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -89,183 +89,419 @@ class MemgraphStorage(BaseGraphStorage): pass async def has_node(self, node_id: str) -> bool: + """ + Check if a node exists in the graph. + + Args: + node_id: The ID of the node to check. + + Returns: + bool: True if the node exists, False otherwise. + + Raises: + Exception: If there is an error checking the node existence. 
+ """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" - result = await session.run(query, entity_id=node_id) - single_result = await result.single() - await result.consume() - return single_result["node_exists"] + try: + query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" + result = await session.run(query, entity_id=node_id) + single_result = await result.single() + await result.consume() # Ensure result is fully consumed + return single_result["node_exists"] + except Exception as e: + logger.error(f"Error checking node existence for {node_id}: {str(e)}") + await result.consume() # Ensure the result is consumed even on error + raise async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + """ + Check if an edge exists between two nodes in the graph. + + Args: + source_node_id: The ID of the source node. + target_node_id: The ID of the target node. + + Returns: + bool: True if the edge exists, False otherwise. + + Raises: + Exception: If there is an error checking the edge existence. + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = ( - "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " - "RETURN COUNT(r) > 0 AS edgeExists" - ) - result = await session.run( - query, - source_entity_id=source_node_id, - target_entity_id=target_node_id, - ) - single_result = await result.single() - await result.consume() - return single_result["edgeExists"] + try: + query = ( + "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " + "RETURN COUNT(r) > 0 AS edgeExists" + ) + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) + single_result = await result.single() + await result.consume() # Ensure result is fully consumed + return single_result["edgeExists"] + except Exception as e: + logger.error( + f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" + ) + await result.consume() # Ensure the result is consumed even on error + raise async def get_node(self, node_id: str) -> dict[str, str] | None: + """Get node by its label identifier, return only node properties + + Args: + node_id: The node label to look up + + Returns: + dict: Node properties if found + None: If node not found + + Raises: + Exception: If there is an error executing the query + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" - result = await session.run(query, entity_id=node_id) - records = await result.fetch(2) - await result.consume() - if records: - node = records[0]["n"] - node_dict = dict(node) - if "labels" in node_dict: - node_dict["labels"] = [ - label for label in node_dict["labels"] if label != "base" - ] - return node_dict - return None + try: + query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" + result = await session.run(query, entity_id=node_id) + try: + records = await result.fetch( + 2 + ) # Get 2 records for duplication check + + if len(records) > 1: + logger.warning( + f"Multiple nodes found with label '{node_id}'. Using first node." 
+ ) + if records: + node = records[0]["n"] + node_dict = dict(node) + # Remove base label from labels list if it exists + if "labels" in node_dict: + node_dict["labels"] = [ + label + for label in node_dict["labels"] + if label != "base" + ] + return node_dict + return None + finally: + await result.consume() # Ensure result is fully consumed + except Exception as e: + logger.error(f"Error getting node for {node_id}: {str(e)}") + raise + + async def node_degree(self, node_id: str) -> int: + """Get the degree (number of relationships) of a node with the given label. + If multiple nodes have the same label, returns the degree of the first node. + If no node is found, returns 0. + + Args: + node_id: The label of the node + + Returns: + int: The number of relationships the node has, or 0 if no node found + + Raises: + Exception: If there is an error executing the query + """ + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + query = """ + MATCH (n:base {entity_id: $entity_id}) + OPTIONAL MATCH (n)-[r]-() + RETURN COUNT(r) AS degree + """ + result = await session.run(query, entity_id=node_id) + try: + record = await result.single() + + if not record: + logger.warning(f"No node found with label '{node_id}'") + return 0 + + degree = record["degree"] + return degree + finally: + await result.consume() # Ensure result is fully consumed + except Exception as e: + logger.error(f"Error getting node degree for {node_id}: {str(e)}") + raise async def get_all_labels(self) -> list[str]: + """ + Get all existing node labels in the database + Returns: + ["Person", "Company", ...] # Alphabetically sorted label list + + Raises: + Exception: If there is an error executing the query + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ - MATCH (n:base) - WHERE n.entity_id IS NOT NULL - RETURN DISTINCT n.entity_id AS label - ORDER BY label - """ - result = await session.run(query) - labels = [] - async for record in result: - labels.append(record["label"]) - await result.consume() - return labels + try: + query = """ + MATCH (n:base) + WHERE n.entity_id IS NOT NULL + RETURN DISTINCT n.entity_id AS label + ORDER BY label + """ + result = await session.run(query) + labels = [] + async for record in result: + labels.append(record["label"]) + await result.consume() + return labels + except Exception as e: + logger.error(f"Error getting all labels: {str(e)}") + await result.consume() # Ensure the result is consumed even on error + raise async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - query = """ - MATCH (n:base {entity_id: $entity_id}) - OPTIONAL MATCH (n)-[r]-(connected:base) - WHERE connected.entity_id IS NOT NULL - RETURN n, r, connected - """ - results = await session.run(query, entity_id=source_node_id) - edges = [] - async for record in results: - source_node = record["n"] - connected_node = record["connected"] - if not source_node or not connected_node: - continue - source_label = source_node.get("entity_id") - target_label = connected_node.get("entity_id") - if source_label and target_label: - edges.append((source_label, target_label)) - await results.consume() - return edges + """Retrieves all edges (relationships) for a particular node identified by its label. 
+ + Args: + source_node_id: Label of the node to get edges for + + Returns: + list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges + None: If no edges found + + Raises: + Exception: If there is an error executing the query + """ + try: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + query = """MATCH (n:base {entity_id: $entity_id}) + OPTIONAL MATCH (n)-[r]-(connected:base) + WHERE connected.entity_id IS NOT NULL + RETURN n, r, connected""" + results = await session.run(query, entity_id=source_node_id) + + edges = [] + async for record in results: + source_node = record["n"] + connected_node = record["connected"] + + # Skip if either node is None + if not source_node or not connected_node: + continue + + source_label = ( + source_node.get("entity_id") + if source_node.get("entity_id") + else None + ) + target_label = ( + connected_node.get("entity_id") + if connected_node.get("entity_id") + else None + ) + + if source_label and target_label: + edges.append((source_label, target_label)) + + await results.consume() # Ensure results are consumed + return edges + except Exception as e: + logger.error( + f"Error getting edges for node {source_node_id}: {str(e)}" + ) + await results.consume() # Ensure results are consumed even on error + raise + except Exception as e: + logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") + raise async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: + """Get edge properties between two nodes. + + Args: + source_node_id: Label of the source node + target_node_id: Label of the target node + + Returns: + dict: Edge properties if found, default properties if not found or on error + + Raises: + Exception: If there is an error executing the query + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ - MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) - RETURN properties(r) as edge_properties - """ - result = await session.run( - query, - source_entity_id=source_node_id, - target_entity_id=target_node_id, - ) - records = await result.fetch(2) - await result.consume() - if records: - edge_result = dict(records[0]["edge_properties"]) - for key, default_value in { - "weight": 0.0, - "source_id": None, - "description": None, - "keywords": None, - }.items(): - if key not in edge_result: - edge_result[key] = default_value - return edge_result - return None + try: + query = """ + MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) + RETURN properties(r) as edge_properties + """ + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) + records = await result.fetch(2) + await result.consume() + if records: + edge_result = dict(records[0]["edge_properties"]) + for key, default_value in { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + }.items(): + if key not in edge_result: + edge_result[key] = default_value + logger.warning( + f"Edge between {source_node_id} and {target_node_id} is missing property: {key}. 
Using default value: {default_value}" + ) + return edge_result + return None + except Exception as e: + logger.error( + f"Error getting edge between {source_node_id} and {target_node_id}: {str(e)}" + ) + await result.consume() # Ensure the result is consumed even on error + raise async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: + """ + Upsert a node in the Neo4j database. + + Args: + node_id: The unique identifier for the node (used as label) + node_data: Dictionary of node properties + """ properties = node_data - entity_type = properties.get("entity_type", "base") + entity_type = properties["entity_type"] if "entity_id" not in properties: - raise ValueError( - "Memgraph: node properties must contain an 'entity_id' field" - ) - async with self._driver.session(database=self._DATABASE) as session: + raise ValueError("Neo4j: node properties must contain an 'entity_id' field") - async def execute_upsert(tx: AsyncManagedTransaction): - query = f""" - MERGE (n:base {{entity_id: $entity_id}}) + try: + async with self._driver.session(database=self._DATABASE) as session: + + async def execute_upsert(tx: AsyncManagedTransaction): + query = ( + """ + MERGE (n:base {entity_id: $entity_id}) SET n += $properties - SET n:`{entity_type}` + SET n:`%s` """ - result = await tx.run(query, entity_id=node_id, properties=properties) - await result.consume() + % entity_type + ) + result = await tx.run( + query, entity_id=node_id, properties=properties + ) + await result.consume() # Ensure result is fully consumed - await session.execute_write(execute_upsert) + await session.execute_write(execute_upsert) + except Exception as e: + logger.error(f"Error during upsert: {str(e)}") + raise async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: - edge_properties = edge_data - async with self._driver.session(database=self._DATABASE) as session: + """ + Upsert an edge and its properties between two nodes identified by their labels. + Ensures both source and target nodes exist and are unique before creating the edge. + Uses entity_id property to uniquely identify nodes. 
- async def execute_upsert(tx: AsyncManagedTransaction): - query = """ - MATCH (source:base {entity_id: $source_entity_id}) - WITH source - MATCH (target:base {entity_id: $target_entity_id}) - MERGE (source)-[r:DIRECTED]-(target) - SET r += $properties - RETURN r, source, target - """ - result = await tx.run( - query, - source_entity_id=source_node_id, - target_entity_id=target_node_id, - properties=edge_properties, - ) - await result.consume() + Args: + source_node_id (str): Label of the source node (used as identifier) + target_node_id (str): Label of the target node (used as identifier) + edge_data (dict): Dictionary of properties to set on the edge - await session.execute_write(execute_upsert) + Raises: + Exception: If there is an error executing the query + """ + try: + edge_properties = edge_data + async with self._driver.session(database=self._DATABASE) as session: + + async def execute_upsert(tx: AsyncManagedTransaction): + query = """ + MATCH (source:base {entity_id: $source_entity_id}) + WITH source + MATCH (target:base {entity_id: $target_entity_id}) + MERGE (source)-[r:DIRECTED]-(target) + SET r += $properties + RETURN r, source, target + """ + result = await tx.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + properties=edge_properties, + ) + try: + await result.fetch(2) + finally: + await result.consume() # Ensure result is consumed + + await session.execute_write(execute_upsert) + except Exception as e: + logger.error(f"Error during edge upsert: {str(e)}") + raise async def delete_node(self, node_id: str) -> None: + """Delete a node with the specified label + + Args: + node_id: The label of the node to delete + + Raises: + Exception: If there is an error executing the query + """ + async def _do_delete(tx: AsyncManagedTransaction): query = """ MATCH (n:base {entity_id: $entity_id}) DETACH DELETE n """ result = await tx.run(query, entity_id=node_id) + logger.debug(f"Deleted node with label {node_id}") await result.consume() - async with self._driver.session(database=self._DATABASE) as session: - await session.execute_write(_do_delete) + try: + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete) + except Exception as e: + logger.error(f"Error during node deletion: {str(e)}") + raise async def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node labels to be deleted + """ for node in nodes: await self.delete_node(node) async def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + + Raises: + Exception: If there is an error executing the query + """ for source, target in edges: async def _do_delete_edge(tx: AsyncManagedTransaction): @@ -276,15 +512,32 @@ class MemgraphStorage(BaseGraphStorage): result = await tx.run( query, source_entity_id=source, target_entity_id=target ) - await result.consume() + logger.debug(f"Deleted edge from '{source}' to '{target}'") + await result.consume() # Ensure result is fully consumed - async with self._driver.session(database=self._DATABASE) as session: - await session.execute_write(_do_delete_edge) + try: + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete_edge) + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise async def drop(self) -> dict[str, str]: + """Drop all data from storage and clean up 
resources + + This method will delete all nodes and relationships in the Neo4j database. + + Returns: + dict[str, str]: Operation status and message + - On success: {"status": "success", "message": "data dropped"} + - On failure: {"status": "error", "message": ""} + + Raises: + Exception: If there is an error executing the query + """ try: async with self._driver.session(database=self._DATABASE) as session: - query = "MATCH (n) DETACH DELETE n" + query = "DROP GRAPH" result = await session.run(query) await result.consume() logger.info( @@ -295,30 +548,36 @@ class MemgraphStorage(BaseGraphStorage): logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}") return {"status": "error", "message": str(e)} - async def node_degree(self, node_id: str) -> int: - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - query = """ - MATCH (n:base {entity_id: $entity_id}) - OPTIONAL MATCH (n)-[r]-() - RETURN COUNT(r) AS degree - """ - result = await session.run(query, entity_id=node_id) - record = await result.single() - await result.consume() - if not record: - return 0 - return record["degree"] - async def edge_degree(self, src_id: str, tgt_id: str) -> int: + """Get the total degree (sum of relationships) of two nodes. + + Args: + src_id: Label of the source node + tgt_id: Label of the target node + + Returns: + int: Sum of the degrees of both nodes + """ src_degree = await self.node_degree(src_id) trg_degree = await self.node_degree(tgt_id) + + # Convert None to 0 for addition src_degree = 0 if src_degree is None else src_degree trg_degree = 0 if trg_degree is None else trg_degree - return int(src_degree) + int(trg_degree) + + degrees = int(src_degree) + int(trg_degree) + return degrees async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all nodes that are associated with the given chunk_ids. + + Args: + chunk_ids: List of chunk IDs to find associated nodes for + + Returns: + list[dict]: A list of nodes, where each node is a dictionary of its properties. + An empty list if no matching nodes are found. + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -335,10 +594,19 @@ class MemgraphStorage(BaseGraphStorage): node_dict = dict(node) node_dict["id"] = node_dict.get("entity_id") nodes.append(node_dict) - await result.consume() - return nodes + await result.consume() + return nodes async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all edges that are associated with the given chunk_ids. + + Args: + chunk_ids: List of chunk IDs to find associated edges for + + Returns: + list[dict]: A list of edges, where each edge is a dictionary of its properties. + An empty list if no matching edges are found. + """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -364,118 +632,149 @@ class MemgraphStorage(BaseGraphStorage): max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES, ) -> KnowledgeGraph: + """ + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. 
+ + Args: + node_label: Label of the starting node, * means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 + + Returns: + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit + + Raises: + Exception: If there is an error executing the query + """ result = KnowledgeGraph() seen_nodes = set() seen_edges = set() - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - if node_label == "*": - count_query = "MATCH (n) RETURN count(n) as total" - count_result = await session.run(count_query) - count_record = await count_result.single() - await count_result.consume() - if count_record and count_record["total"] > max_nodes: - result.is_truncated = True - logger.info( - f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" - ) - main_query = """ - MATCH (n) - OPTIONAL MATCH (n)-[r]-() - WITH n, COALESCE(count(r), 0) AS degree - ORDER BY degree DESC - LIMIT $max_nodes - WITH collect({node: n}) AS filtered_nodes - UNWIND filtered_nodes AS node_info - WITH collect(node_info.node) AS kept_nodes, filtered_nodes - OPTIONAL MATCH (a)-[r]-(b) - WHERE a IN kept_nodes AND b IN kept_nodes - RETURN filtered_nodes AS node_info, - collect(DISTINCT r) AS relationships - """ - result_set = await session.run(main_query, {"max_nodes": max_nodes}) - record = await result_set.single() - await result_set.consume() - else: - # BFS fallback for Memgraph (no APOC) - from collections import deque + try: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + if node_label == "*": + count_query = "MATCH (n) RETURN count(n) as total" + count_result = None + try: + count_result = await session.run(count_query) + count_record = await count_result.single() + if count_record and count_record["total"] > max_nodes: + result.is_truncated = True + logger.info( + f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" + ) + finally: + if count_result: + await count_result.consume() - # Get the starting node - start_query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" - node_result = await session.run(start_query, entity_id=node_label) - node_record = await node_result.single() - await node_result.consume() - if not node_record: - return result - start_node = node_record["n"] - queue = deque([(start_node, 0)]) - visited = set() - bfs_nodes = [] - while queue and len(bfs_nodes) < max_nodes: - current_node, depth = queue.popleft() - node_id = current_node.get("entity_id") - if node_id in visited: - continue - visited.add(node_id) - bfs_nodes.append(current_node) - if depth < max_depth: - # Get neighbors - neighbor_query = """ - MATCH (n:base {entity_id: $entity_id})-[]-(m:base) - RETURN m - """ - neighbors_result = await session.run( - neighbor_query, entity_id=node_id - ) - neighbors = [ - rec["m"] for rec in await neighbors_result.to_list() - ] - await neighbors_result.consume() - for neighbor in neighbors: - neighbor_id = neighbor.get("entity_id") - if neighbor_id not in visited: - queue.append((neighbor, depth + 1)) - # Build subgraph - subgraph_ids = [n.get("entity_id") for n in bfs_nodes] - # Nodes - for n in bfs_nodes: - node_id = n.get("entity_id") - if node_id not in seen_nodes: - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[node_id], - properties=dict(n), - ) - ) - 
seen_nodes.add(node_id) - # Edges - if subgraph_ids: - edge_query = """ - MATCH (a:base)-[r]-(b:base) - WHERE a.entity_id IN $ids AND b.entity_id IN $ids - RETURN DISTINCT r, a, b + # Run the main query to get nodes with highest degree + main_query = """ + MATCH (n) + OPTIONAL MATCH (n)-[r]-() + WITH n, COALESCE(count(r), 0) AS degree + ORDER BY degree DESC + LIMIT $max_nodes + WITH collect({node: n}) AS filtered_nodes + UNWIND filtered_nodes AS node_info + WITH collect(node_info.node) AS kept_nodes, filtered_nodes + OPTIONAL MATCH (a)-[r]-(b) + WHERE a IN kept_nodes AND b IN kept_nodes + RETURN filtered_nodes AS node_info, + collect(DISTINCT r) AS relationships """ - edge_result = await session.run(edge_query, ids=subgraph_ids) - async for record in edge_result: - r = record["r"] - a = record["a"] - b = record["b"] - edge_id = f"{a.get('entity_id')}-{b.get('entity_id')}" - if edge_id not in seen_edges: - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=a.get("entity_id"), - target=b.get("entity_id"), - properties=dict(r), + result_set = None + try: + result_set = await session.run( + main_query, {"max_nodes": max_nodes} + ) + record = await result_set.single() + finally: + if result_set: + await result_set.consume() + + else: + bfs_query = """ + MATCH (start) WHERE start.entity_id = $entity_id + WITH start + CALL { + WITH start + MATCH path = (start)-[*0..$max_depth]-(node) + WITH nodes(path) AS path_nodes, relationships(path) AS path_rels + UNWIND path_nodes AS n + WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists + WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels + RETURN all_nodes, all_rels + } + WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes + + // Apply node limiting here + WITH CASE + WHEN total_nodes <= $max_nodes THEN nodes + ELSE nodes[0..$max_nodes] + END AS limited_nodes, + relationships, + total_nodes, + total_nodes > $max_nodes AS is_truncated + UNWIND limited_nodes AS node + WITH collect({node: node}) AS node_info, relationships, total_nodes, is_truncated + RETURN node_info, relationships, total_nodes, is_truncated + """ + result_set = None + try: + result_set = await session.run( + bfs_query, + { + "entity_id": node_label, + "max_depth": max_depth, + "max_nodes": max_nodes, + }, + ) + record = await result_set.single() + if not record: + logger.debug(f"No record found for node {node_label}") + return result + + for node_info in record["node_info"]: + node = node_info["node"] + node_id = node.id + if node_id not in seen_nodes: + seen_nodes.add(node_id) + result.nodes.append( + KnowledgeGraphNode( + id=f"{node_id}", + labels=[node.get("entity_id")], + properties=dict(node), + ) ) - ) - seen_edges.add(edge_id) - await edge_result.consume() - logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" - ) - return result + + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + seen_edges.add(edge_id) + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + ) + ) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + return result + + finally: + if result_set: + await result_set.consume() + + except Exception as e: + logger.error(f"Error getting knowledge 
graph: {str(e)}") + return result From eed43e071cd887c7241f970d15ec91c3eeafe0a4 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Fri, 27 Jun 2025 14:49:57 +0200 Subject: [PATCH 06/19] revert lightrag_openai_demo.py changes --- examples/lightrag_openai_demo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py index e573ec41..fa0b37f1 100644 --- a/examples/lightrag_openai_demo.py +++ b/examples/lightrag_openai_demo.py @@ -82,7 +82,6 @@ async def initialize_rag(): working_dir=WORKING_DIR, embedding_func=openai_embed, llm_model_func=gpt_4o_mini_complete, - graph_storage="MemgraphStorage", ) await rag.initialize_storages() From 9aaa7d2dd3e6386e6f61ef3a93e0e69077949f1a Mon Sep 17 00:00:00 2001 From: DavIvek Date: Fri, 27 Jun 2025 15:09:22 +0200 Subject: [PATCH 07/19] fix drop function in Memgraph implementation --- lightrag/kg/memgraph_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index bf870154..36f0186b 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -537,7 +537,7 @@ class MemgraphStorage(BaseGraphStorage): """ try: async with self._driver.session(database=self._DATABASE) as session: - query = "DROP GRAPH" + query = "MATCH (n) DETACH DELETE n" result = await session.run(query) await result.consume() logger.info( From c0a3638d011ab2b3df586fbc0aaf970d37aeee2c Mon Sep 17 00:00:00 2001 From: DavIvek Date: Fri, 27 Jun 2025 15:35:20 +0200 Subject: [PATCH 08/19] fix memgraph_impl.py according to test_graph_storage.py --- lightrag/kg/memgraph_impl.py | 96 +++++++++++++++++++----------------- tests/test_graph_storage.py | 1 + 2 files changed, 52 insertions(+), 45 deletions(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index 36f0186b..41a1129b 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -594,8 +594,8 @@ class MemgraphStorage(BaseGraphStorage): node_dict = dict(node) node_dict["id"] = node_dict.get("entity_id") nodes.append(node_dict) - await result.consume() - return nodes + await result.consume() + return nodes async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: """Get all edges that are associated with the given chunk_ids. 
@@ -614,7 +614,12 @@ class MemgraphStorage(BaseGraphStorage): UNWIND $chunk_ids AS chunk_id MATCH (a:base)-[r]-(b:base) WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) - RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties + WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id + // Ensure we only return each unique edge once by ordering the source and target + WITH a, b, r, + CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source, + CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target + RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties """ result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) edges = [] @@ -650,10 +655,10 @@ class MemgraphStorage(BaseGraphStorage): result = KnowledgeGraph() seen_nodes = set() seen_edges = set() - try: - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: if node_label == "*": count_query = "MATCH (n) RETURN count(n) as total" count_result = None @@ -736,45 +741,46 @@ class MemgraphStorage(BaseGraphStorage): logger.debug(f"No record found for node {node_label}") return result - for node_info in record["node_info"]: - node = node_info["node"] - node_id = node.id - if node_id not in seen_nodes: - seen_nodes.add(node_id) - result.nodes.append( - KnowledgeGraphNode( - id=f"{node_id}", - labels=[node.get("entity_id")], - properties=dict(node), - ) - ) - - for rel in record["relationships"]: - edge_id = rel.id - if edge_id not in seen_edges: - seen_edges.add(edge_id) - start = rel.start_node - end = rel.end_node - result.edges.append( - KnowledgeGraphEdge( - id=f"{edge_id}", - type=rel.type, - source=f"{start.id}", - target=f"{end.id}", - properties=dict(rel), - ) - ) - - logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" - ) - - return result - finally: if result_set: await result_set.consume() - except Exception as e: - logger.error(f"Error getting knowledge graph: {str(e)}") - return result + if record: + for node_info in record["node_info"]: + node = node_info["node"] + node_id = node.id + if node_id not in seen_nodes: + seen_nodes.add(node_id) + result.nodes.append( + KnowledgeGraphNode( + id=f"{node_id}", + labels=[node.get("entity_id")], + properties=dict(node), + ) + ) + + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + seen_edges.add(edge_id) + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + ) + ) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + return result + + except Exception as e: + logger.error(f"Error getting knowledge graph: {str(e)}") + return result diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py index 64e66f48..3fd1abbc 100644 --- a/tests/test_graph_storage.py +++ b/tests/test_graph_storage.py @@ -9,6 +9,7 @@ - NetworkXStorage - Neo4JStorage - PGGraphStorage +- MemgraphStorage """ import asyncio From 4ea38456f060892fed2953fdc843760a920c1db5 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 5 Jul 2025 00:31:52 +0800 Subject: [PATCH 09/19] 
Improve graph query robustness and error handling --- lightrag/kg/memgraph_impl.py | 74 +++++++++++++++++++++++------------- 1 file changed, 47 insertions(+), 27 deletions(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index 41a1129b..397e5a99 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -660,16 +660,23 @@ class MemgraphStorage(BaseGraphStorage): ) as session: try: if node_label == "*": + # First check if database has any nodes count_query = "MATCH (n) RETURN count(n) as total" count_result = None + total_count = 0 try: count_result = await session.run(count_query) count_record = await count_result.single() - if count_record and count_record["total"] > max_nodes: - result.is_truncated = True - logger.info( - f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" - ) + if count_record: + total_count = count_record["total"] + if total_count == 0: + logger.debug("No nodes found in database") + return result + if total_count > max_nodes: + result.is_truncated = True + logger.info( + f"Graph truncated: {total_count} nodes found, limited to {max_nodes}" + ) finally: if count_result: await count_result.consume() @@ -695,6 +702,9 @@ class MemgraphStorage(BaseGraphStorage): main_query, {"max_nodes": max_nodes} ) record = await result_set.single() + if not record: + logger.debug("No record returned from main query") + return result finally: if result_set: await result_set.consume() @@ -738,14 +748,22 @@ class MemgraphStorage(BaseGraphStorage): ) record = await result_set.single() if not record: - logger.debug(f"No record found for node {node_label}") + logger.debug(f"No nodes found for entity_id: {node_label}") return result + # Check if the query indicates truncation + if "is_truncated" in record and record["is_truncated"]: + result.is_truncated = True + logger.info( + f"Graph truncated: breadth-first search limited to {max_nodes} nodes" + ) + finally: if result_set: await result_set.consume() - if record: + # Process the record if it exists + if record and record["node_info"]: for node_info in record["node_info"]: node = node_info["node"] node_id = node.id @@ -759,28 +777,30 @@ class MemgraphStorage(BaseGraphStorage): ) ) - for rel in record["relationships"]: - edge_id = rel.id - if edge_id not in seen_edges: - seen_edges.add(edge_id) - start = rel.start_node - end = rel.end_node - result.edges.append( - KnowledgeGraphEdge( - id=f"{edge_id}", - type=rel.type, - source=f"{start.id}", - target=f"{end.id}", - properties=dict(rel), + if "relationships" in record and record["relationships"]: + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + seen_edges.add(edge_id) + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + ) ) - ) - logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" - ) - - return result + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) except Exception as e: logger.error(f"Error getting knowledge graph: {str(e)}") - return result + # Return empty but properly initialized KnowledgeGraph on error + return KnowledgeGraph() + + return result From 2f7cef968d49a0986d0f14f3903862947d208812 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 5 Jul 2025 00:32:55 +0800 Subject: [PATCH 10/19] fix: ensure 
Milvus collections are loaded before operations - Resolves "collection not loaded" MilvusException errors --- lightrag/kg/milvus_impl.py | 44 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 6cffae88..eecf679a 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -539,6 +539,23 @@ class MilvusVectorDBStorage(BaseVectorStorage): ) raise + def _ensure_collection_loaded(self): + """Ensure the collection is loaded into memory for search operations""" + try: + # Check if collection exists first + if not self._client.has_collection(self.namespace): + logger.error(f"Collection {self.namespace} does not exist") + raise ValueError(f"Collection {self.namespace} does not exist") + + # Load the collection if it's not already loaded + # In Milvus, collections need to be loaded before they can be searched + self._client.load_collection(self.namespace) + logger.debug(f"Collection {self.namespace} loaded successfully") + + except Exception as e: + logger.error(f"Failed to load collection {self.namespace}: {e}") + raise + def _create_collection_if_not_exist(self): """Create collection if not exists and check existing collection compatibility""" @@ -565,6 +582,8 @@ class MilvusVectorDBStorage(BaseVectorStorage): f"Collection '{self.namespace}' confirmed to exist, validating compatibility..." ) self._validate_collection_compatibility() + # Ensure the collection is loaded after validation + self._ensure_collection_loaded() return except Exception as describe_error: logger.warning( @@ -587,6 +606,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): # Then create indexes self._create_indexes_after_collection() + # Load the newly created collection + self._ensure_collection_loaded() + logger.info(f"Successfully created Milvus collection: {self.namespace}") except Exception as e: @@ -615,6 +637,10 @@ class MilvusVectorDBStorage(BaseVectorStorage): collection_name=self.namespace, schema=schema ) self._create_indexes_after_collection() + + # Load the newly created collection + self._ensure_collection_loaded() + logger.info(f"Successfully force-created collection {self.namespace}") except Exception as create_error: @@ -670,6 +696,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): if not data: return + # Ensure collection is loaded before upserting + self._ensure_collection_loaded() + import time current_time = int(time.time()) @@ -700,6 +729,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): async def query( self, query: str, top_k: int, ids: list[str] | None = None ) -> list[dict[str, Any]]: + # Ensure collection is loaded before querying + self._ensure_collection_loaded() + embedding = await self.embedding_func( [query], _priority=5 ) # higher priority for query @@ -764,6 +796,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): entity_name: The name of the entity whose relations should be deleted """ try: + # Ensure collection is loaded before querying + self._ensure_collection_loaded() + # Search for relations where entity is either source or target expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"' @@ -802,6 +837,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): ids: List of vector IDs to be deleted """ try: + # Ensure collection is loaded before deleting + self._ensure_collection_loaded() + # Delete vectors by IDs result = self._client.delete(collection_name=self.namespace, pks=ids) @@ -825,6 +863,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): The vector data if 
found, or None if not found """ try: + # Ensure collection is loaded before querying + self._ensure_collection_loaded() + # Include all meta_fields (created_at is now always included) plus id output_fields = list(self.meta_fields) + ["id"] @@ -856,6 +897,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): return [] try: + # Ensure collection is loaded before querying + self._ensure_collection_loaded() + # Include all meta_fields (created_at is now always included) plus id output_fields = list(self.meta_fields) + ["id"] From fb979be9ff9f19aad60030e509764938540e97ea Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 5 Jul 2025 07:09:33 +0800 Subject: [PATCH 11/19] Fix linting --- lightrag/kg/milvus_impl.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index eecf679a..2226784f 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -546,12 +546,12 @@ class MilvusVectorDBStorage(BaseVectorStorage): if not self._client.has_collection(self.namespace): logger.error(f"Collection {self.namespace} does not exist") raise ValueError(f"Collection {self.namespace} does not exist") - + # Load the collection if it's not already loaded # In Milvus, collections need to be loaded before they can be searched self._client.load_collection(self.namespace) logger.debug(f"Collection {self.namespace} loaded successfully") - + except Exception as e: logger.error(f"Failed to load collection {self.namespace}: {e}") raise @@ -637,10 +637,10 @@ class MilvusVectorDBStorage(BaseVectorStorage): collection_name=self.namespace, schema=schema ) self._create_indexes_after_collection() - + # Load the newly created collection self._ensure_collection_loaded() - + logger.info(f"Successfully force-created collection {self.namespace}") except Exception as create_error: @@ -731,7 +731,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): ) -> list[dict[str, Any]]: # Ensure collection is loaded before querying self._ensure_collection_loaded() - + embedding = await self.embedding_func( [query], _priority=5 ) # higher priority for query @@ -798,7 +798,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): try: # Ensure collection is loaded before querying self._ensure_collection_loaded() - + # Search for relations where entity is either source or target expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"' @@ -839,7 +839,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): try: # Ensure collection is loaded before deleting self._ensure_collection_loaded() - + # Delete vectors by IDs result = self._client.delete(collection_name=self.namespace, pks=ids) @@ -865,7 +865,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): try: # Ensure collection is loaded before querying self._ensure_collection_loaded() - + # Include all meta_fields (created_at is now always included) plus id output_fields = list(self.meta_fields) + ["id"] @@ -899,7 +899,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): try: # Ensure collection is loaded before querying self._ensure_collection_loaded() - + # Include all meta_fields (created_at is now always included) plus id output_fields = list(self.meta_fields) + ["id"] From 3eaadb8a4432b03f67432774fd18684206208215 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 8 Jul 2025 03:06:19 +0800 Subject: [PATCH 12/19] Update env.example --- env.example | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/env.example b/env.example index 1efe4830..ef52bd53 100644 --- a/env.example +++ 
b/env.example @@ -159,7 +159,7 @@ NEO4J_PASSWORD='your_password' ### MongoDB Configuration MONGO_URI=mongodb://root:root@localhost:27017/ -#MONGO_URI=mongodb+srv://root:rooot@cluster0.xxxx.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0 +#MONGO_URI=mongodb+srv://xxxx MONGO_DATABASE=LightRAG # MONGODB_WORKSPACE=forced_workspace_name From 9a9674d590a43f1c02c7fed164057a33d13db9f2 Mon Sep 17 00:00:00 2001 From: frankj Date: Tue, 8 Jul 2025 10:24:19 +0800 Subject: [PATCH 13/19] Fix incorrect file path (404 Not Found) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Issue Description A 404 error occurred when accessing the repository link pointing to README_zh.md. Upon inspection, the actual file path is README-zh.md, indicating an incorrect path reference in the original link. Fix Details Corrected the broken link from README_zh.md to the correct path README-zh.md. Verification Method After modification, the target file opens normally in the browser. Hope this fix helps users access the Chinese documentation properly—thanks for the review! --- README-zh.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README-zh.md b/README-zh.md index 45335489..e9599099 100644 --- a/README-zh.md +++ b/README-zh.md @@ -30,7 +30,7 @@

- +

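PATCH 10 above rests on a single Milvus rule: a collection must be loaded into memory before search, query, or delete operations will succeed, otherwise the client raises a "collection not loaded" MilvusException. A minimal standalone sketch of the same check-then-load guard follows; the connection URI and the collection name `demo_chunks` are hypothetical, and only the `has_collection`/`load_collection` calls already used by the patch are assumed:

from pymilvus import MilvusClient

# Hypothetical connection and collection name, for illustration only.
client = MilvusClient(uri="http://localhost:19530")
collection_name = "demo_chunks"


def ensure_collection_loaded(client: MilvusClient, name: str) -> None:
    # A missing collection cannot be loaded, so fail fast with a clear error.
    if not client.has_collection(name):
        raise ValueError(f"Collection {name} does not exist")
    # Loading an already-loaded collection is harmless, which is why the patch
    # can call this guard unconditionally before every search/query/delete.
    client.load_collection(name)


ensure_collection_loaded(client, collection_name)
# Search/query/delete calls against `collection_name` are now safe to issue.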
From 8cbba6e9dbf2ae28760ae6986a81fa6d2bca371f Mon Sep 17 00:00:00 2001 From: Molion Surya Date: Tue, 8 Jul 2025 13:25:52 +0800 Subject: [PATCH 14/19] Fix #1746: [openai.py logic for streaming complete] --- lightrag/llm/openai.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 57f016cf..30491476 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -210,9 +210,16 @@ async def openai_complete_if_cache( async def inner(): # Track if we've started iterating iteration_started = False + final_chunk_usage = None + try: iteration_started = True async for chunk in response: + # Check if this chunk has usage information (final chunk) + if hasattr(chunk, "usage") and chunk.usage: + final_chunk_usage = chunk.usage + logger.debug(f"Received usage info in streaming chunk: {chunk.usage}") + # Check if choices exists and is not empty if not hasattr(chunk, "choices") or not chunk.choices: logger.warning(f"Received chunk without choices: {chunk}") @@ -222,16 +229,29 @@ async def openai_complete_if_cache( if not hasattr(chunk.choices[0], "delta") or not hasattr( chunk.choices[0].delta, "content" ): - logger.warning( - f"Received chunk without delta content: {chunk.choices[0]}" - ) + # This might be the final chunk, continue to check for usage continue + content = chunk.choices[0].delta.content if content is None: continue if r"\u" in content: content = safe_unicode_decode(content.encode("utf-8")) + yield content + + # After streaming is complete, track token usage + if token_tracker and final_chunk_usage: + # Use actual usage from the API + token_counts = { + "prompt_tokens": getattr(final_chunk_usage, "prompt_tokens", 0), + "completion_tokens": getattr(final_chunk_usage, "completion_tokens", 0), + "total_tokens": getattr(final_chunk_usage, "total_tokens", 0), + } + token_tracker.add_usage(token_counts) + logger.debug(f"Streaming token usage (from API): {token_counts}") + elif token_tracker: + logger.debug("No usage information available in streaming response") except Exception as e: logger.error(f"Error in stream response: {str(e)}") # Try to clean up resources if possible @@ -451,4 +471,4 @@ async def openai_embed( response = await openai_async_client.embeddings.create( model=model, input=texts, encoding_format="float" ) - return np.array([dp.embedding for dp in response.data]) + return np.array([dp.embedding for dp in response.data]) \ No newline at end of file From 5f330ec11a487753e9aa06a15fdeb5df782d9d49 Mon Sep 17 00:00:00 2001 From: SLKun Date: Mon, 7 Jul 2025 10:31:46 +0800 Subject: [PATCH 15/19] remove tag for entities and keywords extraction --- lightrag/operate.py | 4 +++- lightrag/utils.py | 11 +++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/lightrag/operate.py b/lightrag/operate.py index 88837435..4e219cf8 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -26,6 +26,7 @@ from .utils import ( get_conversation_turns, use_llm_func_with_cache, update_chunk_cache_list, + remove_think_tags, ) from .base import ( BaseGraphStorage, @@ -1703,7 +1704,8 @@ async def extract_keywords_only( result = await use_model_func(kw_prompt, keyword_extraction=True) # 6. 
Parse out JSON from the LLM response
-    match = re.search(r"\{.*\}", result, re.DOTALL)
+    result = remove_think_tags(result)
+    match = re.search(r"\{.*?\}", result, re.DOTALL)
     if not match:
         logger.error("No JSON-like structure found in the LLM respond.")
         return [], []
diff --git a/lightrag/utils.py b/lightrag/utils.py
index c6e2def9..386de3ab 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -1465,6 +1465,11 @@ async def update_chunk_cache_list(
     )
 
 
+def remove_think_tags(text: str) -> str:
+    """Remove <think> tags from the text"""
+    return re.sub(r"^(<think>.*?</think>|<think>)", "", text, flags=re.DOTALL).strip()
+
+
 async def use_llm_func_with_cache(
     input_text: str,
     use_llm_func: callable,
@@ -1531,6 +1536,7 @@ async def use_llm_func_with_cache(
             kwargs["max_tokens"] = max_tokens
 
         res: str = await use_llm_func(input_text, **kwargs)
+        res = remove_think_tags(res)
 
         if llm_response_cache.global_config.get("enable_llm_cache_for_entity_extract"):
             await save_to_cache(
@@ -1557,8 +1563,9 @@ async def use_llm_func_with_cache(
     if max_tokens is not None:
         kwargs["max_tokens"] = max_tokens
 
-    logger.info(f"Call LLM function with query text lenght: {len(input_text)}")
-    return await use_llm_func(input_text, **kwargs)
+    logger.info(f"Call LLM function with query text length: {len(input_text)}")
+    res = await use_llm_func(input_text, **kwargs)
+    return remove_think_tags(res)
 
 
 def get_content_summary(content: str, max_length: int = 250) -> str:
From 2a0cff3ed6ec69e0b5786bbcea7402b25b5c2dc0 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Tue, 8 Jul 2025 18:17:21 +0800
Subject: [PATCH 16/19]
From 4438897b6bf5e3db4fe3ff1a872805dce8377751 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Tue, 8 Jul 2025 16:27:38 +0200 Subject: [PATCH 17/19] add changes based on review --- env.example | 7 ++ lightrag/kg/memgraph_impl.py | 223 +++++++++++++++++++++-------------- 2 files changed, 143 insertions(+), 87 deletions(-) diff --git a/env.example b/env.example index ef52bd53..df88a518 100644 --- a/env.example +++ b/env.example @@ -179,3 +179,10 @@ QDRANT_URL=http://localhost:6333 ### Redis REDIS_URI=redis://localhost:6379 # REDIS_WORKSPACE=forced_workspace_name + +### Memgraph Configuration +MEMGRAPH_URI=bolt://localhost:7687 +MEMGRAPH_USERNAME= +MEMGRAPH_PASSWORD= +MEMGRAPH_DATABASE=memgraph +# MEMGRAPH_WORKSPACE=forced_workspace_name diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index 397e5a99..4c16b843 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -31,14 +31,23 @@ config.read("config.ini", "utf-8") @final @dataclass class MemgraphStorage(BaseGraphStorage): - def __init__(self, namespace, global_config, embedding_func): + def __init__(self, namespace, global_config, embedding_func, workspace=None): + memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE") + if memgraph_workspace and memgraph_workspace.strip(): + workspace = memgraph_workspace super().__init__( namespace=namespace, + workspace=workspace or "", global_config=global_config, embedding_func=embedding_func, ) self._driver = None + def _get_workspace_label(self) -> str: + """Get workspace label, return 'base' for compatibility when workspace is empty""" + workspace = getattr(self, "workspace", None) + return workspace if workspace else "base" + async def initialize(self): URI = os.environ.get( "MEMGRAPH_URI", @@ -63,12 +72,13 @@ class MemgraphStorage(BaseGraphStorage): async with self._driver.session(database=DATABASE) as session: # Create index for base nodes on entity_id if it doesn't exist try: - await session.run("""CREATE INDEX ON :base(entity_id)""") - logger.info("Created index on :base(entity_id) in Memgraph.") + workspace_label = self._get_workspace_label() + await session.run(f"""CREATE INDEX ON :{workspace_label}(entity_id)""") + logger.info(f"Created index on :{workspace_label}(entity_id) in Memgraph.") except Exception as e: # Index may already exist, which is not an error logger.warning( - f"Index creation on :base(entity_id) may have failed or already exists: {e}" + f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}" ) await session.run("RETURN 1") logger.info(f"Connected to Memgraph at {URI}") @@ -101,15 +111,18 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error checking the node existence. """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. 
Call 'await initialize()' first.") async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" + workspace_label = self._get_workspace_label() + query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists" result = await session.run(query, entity_id=node_id) single_result = await result.single() await result.consume() # Ensure result is fully consumed - return single_result["node_exists"] + return single_result["node_exists"] if single_result is not None else False except Exception as e: logger.error(f"Error checking node existence for {node_id}: {str(e)}") await result.consume() # Ensure the result is consumed even on error @@ -129,22 +142,21 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error checking the edge existence. """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: + workspace_label = self._get_workspace_label() query = ( - "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " + f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) " "RETURN COUNT(r) > 0 AS edgeExists" ) - result = await session.run( - query, - source_entity_id=source_node_id, - target_entity_id=target_node_id, - ) + result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id) # type: ignore single_result = await result.single() await result.consume() # Ensure result is fully consumed - return single_result["edgeExists"] + return single_result["edgeExists"] if single_result is not None else False except Exception as e: logger.error( f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" @@ -165,11 +177,14 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" + workspace_label = self._get_workspace_label() + query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n" result = await session.run(query, entity_id=node_id) try: records = await result.fetch( @@ -183,12 +198,12 @@ class MemgraphStorage(BaseGraphStorage): if records: node = records[0]["n"] node_dict = dict(node) - # Remove base label from labels list if it exists + # Remove workspace label from labels list if it exists if "labels" in node_dict: node_dict["labels"] = [ label for label in node_dict["labels"] - if label != "base" + if label != workspace_label ] return node_dict return None @@ -212,12 +227,15 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. 
Call 'await initialize()' first.") async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = """ - MATCH (n:base {entity_id: $entity_id}) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) OPTIONAL MATCH (n)-[r]-() RETURN COUNT(r) AS degree """ @@ -246,12 +264,15 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = """ - MATCH (n:base) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}`) WHERE n.entity_id IS NOT NULL RETURN DISTINCT n.entity_id AS label ORDER BY label @@ -280,13 +301,16 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") try: async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = """MATCH (n:base {entity_id: $entity_id}) - OPTIONAL MATCH (n)-[r]-(connected:base) + workspace_label = self._get_workspace_label() + query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) + OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`) WHERE connected.entity_id IS NOT NULL RETURN n, r, connected""" results = await session.run(query, entity_id=source_node_id) @@ -341,12 +365,15 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = """ - MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}}) RETURN properties(r) as edge_properties """ result = await session.run( @@ -386,6 +413,8 @@ class MemgraphStorage(BaseGraphStorage): node_id: The unique identifier for the node (used as label) node_data: Dictionary of node properties """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") properties = node_data entity_type = properties["entity_type"] if "entity_id" not in properties: @@ -393,15 +422,14 @@ class MemgraphStorage(BaseGraphStorage): try: async with self._driver.session(database=self._DATABASE) as session: - + workspace_label = self._get_workspace_label() async def execute_upsert(tx: AsyncManagedTransaction): query = ( - """ - MERGE (n:base {entity_id: $entity_id}) + f""" + MERGE (n:`{workspace_label}` {{entity_id: $entity_id}}) SET n += $properties - SET n:`%s` + SET n:`{entity_type}` """ - % entity_type ) result = await tx.run( query, entity_id=node_id, properties=properties @@ -429,15 +457,18 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. 
Call 'await initialize()' first.") try: edge_properties = edge_data async with self._driver.session(database=self._DATABASE) as session: async def execute_upsert(tx: AsyncManagedTransaction): - query = """ - MATCH (source:base {entity_id: $source_entity_id}) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}}) WITH source - MATCH (target:base {entity_id: $target_entity_id}) + MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}}) MERGE (source)-[r:DIRECTED]-(target) SET r += $properties RETURN r, source, target @@ -467,10 +498,13 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") async def _do_delete(tx: AsyncManagedTransaction): - query = """ - MATCH (n:base {entity_id: $entity_id}) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) DETACH DELETE n """ result = await tx.run(query, entity_id=node_id) @@ -490,6 +524,8 @@ class MemgraphStorage(BaseGraphStorage): Args: nodes: List of node labels to be deleted """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") for node in nodes: await self.delete_node(node) @@ -502,11 +538,14 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") for source, target in edges: async def _do_delete_edge(tx: AsyncManagedTransaction): - query = """ - MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id}) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}}) DELETE r """ result = await tx.run( @@ -523,9 +562,9 @@ class MemgraphStorage(BaseGraphStorage): raise async def drop(self) -> dict[str, str]: - """Drop all data from storage and clean up resources + """Drop all data from the current workspace and clean up resources - This method will delete all nodes and relationships in the Neo4j database. + This method will delete all nodes and relationships in the Memgraph database. Returns: dict[str, str]: Operation status and message @@ -535,17 +574,18 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. 
Call 'await initialize()' first.") try: async with self._driver.session(database=self._DATABASE) as session: - query = "MATCH (n) DETACH DELETE n" + workspace_label = self._get_workspace_label() + query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" result = await session.run(query) await result.consume() - logger.info( - f"Process {os.getpid()} drop Memgraph database {self._DATABASE}" - ) - return {"status": "success", "message": "data dropped"} + logger.info(f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}") + return {"status": "success", "message": "workspace data dropped"} except Exception as e: - logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}") + logger.error(f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}") return {"status": "error", "message": str(e)} async def edge_degree(self, src_id: str, tgt_id: str) -> int: @@ -558,6 +598,8 @@ class MemgraphStorage(BaseGraphStorage): Returns: int: Sum of the degrees of both nodes """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") src_degree = await self.node_degree(src_id) trg_degree = await self.node_degree(tgt_id) @@ -578,12 +620,15 @@ class MemgraphStorage(BaseGraphStorage): list[dict]: A list of nodes, where each node is a dictionary of its properties. An empty list if no matching nodes are found. """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ + query = f""" UNWIND $chunk_ids AS chunk_id - MATCH (n:base) + MATCH (n:`{workspace_label}`) WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep) RETURN DISTINCT n """ @@ -607,12 +652,15 @@ class MemgraphStorage(BaseGraphStorage): list[dict]: A list of edges, where each edge is a dictionary of its properties. An empty list if no matching edges are found. """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ + query = f""" UNWIND $chunk_ids AS chunk_id - MATCH (a:base)-[r]-(b:base) + MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`) WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id // Ensure we only return each unique edge once by ordering the source and target @@ -652,9 +700,13 @@ class MemgraphStorage(BaseGraphStorage): Raises: Exception: If there is an error executing the query """ + if self._driver is None: + raise RuntimeError("Memgraph driver is not initialized. 
Call 'await initialize()' first.") + result = KnowledgeGraph() seen_nodes = set() seen_edges = set() + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -682,19 +734,17 @@ class MemgraphStorage(BaseGraphStorage): await count_result.consume() # Run the main query to get nodes with highest degree - main_query = """ - MATCH (n) + main_query = f""" + MATCH (n:`{workspace_label}`) OPTIONAL MATCH (n)-[r]-() WITH n, COALESCE(count(r), 0) AS degree ORDER BY degree DESC LIMIT $max_nodes - WITH collect({node: n}) AS filtered_nodes - UNWIND filtered_nodes AS node_info - WITH collect(node_info.node) AS kept_nodes, filtered_nodes - OPTIONAL MATCH (a)-[r]-(b) + WITH collect(n) AS kept_nodes + MATCH (a)-[r]-(b) WHERE a IN kept_nodes AND b IN kept_nodes - RETURN filtered_nodes AS node_info, - collect(DISTINCT r) AS relationships + RETURN [node IN kept_nodes | {{node: node}}] AS node_info, + collect(DISTINCT r) AS relationships """ result_set = None try: @@ -710,31 +760,33 @@ class MemgraphStorage(BaseGraphStorage): await result_set.consume() else: - bfs_query = """ - MATCH (start) WHERE start.entity_id = $entity_id + bfs_query = f""" + MATCH (start:`{workspace_label}`) + WHERE start.entity_id = $entity_id WITH start - CALL { + CALL {{ WITH start - MATCH path = (start)-[*0..$max_depth]-(node) + MATCH path = (start)-[*0..{max_depth}]-(node) WITH nodes(path) AS path_nodes, relationships(path) AS path_rels UNWIND path_nodes AS n WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels RETURN all_nodes, all_rels - } + }} WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes - - // Apply node limiting here - WITH CASE - WHEN total_nodes <= $max_nodes THEN nodes - ELSE nodes[0..$max_nodes] + WITH + CASE + WHEN total_nodes <= {max_nodes} THEN nodes + ELSE nodes[0..{max_nodes}] END AS limited_nodes, relationships, total_nodes, - total_nodes > $max_nodes AS is_truncated - UNWIND limited_nodes AS node - WITH collect({node: node}) AS node_info, relationships, total_nodes, is_truncated - RETURN node_info, relationships, total_nodes, is_truncated + total_nodes > {max_nodes} AS is_truncated + RETURN + [node IN limited_nodes | {{node: node}}] AS node_info, + relationships, + total_nodes, + is_truncated """ result_set = None try: @@ -742,8 +794,6 @@ class MemgraphStorage(BaseGraphStorage): bfs_query, { "entity_id": node_label, - "max_depth": max_depth, - "max_nodes": max_nodes, }, ) record = await result_set.single() @@ -777,22 +827,21 @@ class MemgraphStorage(BaseGraphStorage): ) ) - if "relationships" in record and record["relationships"]: - for rel in record["relationships"]: - edge_id = rel.id - if edge_id not in seen_edges: - seen_edges.add(edge_id) - start = rel.start_node - end = rel.end_node - result.edges.append( - KnowledgeGraphEdge( - id=f"{edge_id}", - type=rel.type, - source=f"{start.id}", - target=f"{end.id}", - properties=dict(rel), - ) + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + seen_edges.add(edge_id) + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), ) + ) logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" From 
08eb68b8ed0f774843fd682e3de3cc4749b5b6e8 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Tue, 8 Jul 2025 20:21:20 +0200 Subject: [PATCH 18/19] run pre-commit --- lightrag/kg/memgraph_impl.py | 113 +++++++++++++++++++++++++---------- 1 file changed, 82 insertions(+), 31 deletions(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index 4c16b843..8c6d6574 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -73,8 +73,12 @@ class MemgraphStorage(BaseGraphStorage): # Create index for base nodes on entity_id if it doesn't exist try: workspace_label = self._get_workspace_label() - await session.run(f"""CREATE INDEX ON :{workspace_label}(entity_id)""") - logger.info(f"Created index on :{workspace_label}(entity_id) in Memgraph.") + await session.run( + f"""CREATE INDEX ON :{workspace_label}(entity_id)""" + ) + logger.info( + f"Created index on :{workspace_label}(entity_id) in Memgraph." + ) except Exception as e: # Index may already exist, which is not an error logger.warning( @@ -112,7 +116,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error checking the node existence. """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -122,7 +128,9 @@ class MemgraphStorage(BaseGraphStorage): result = await session.run(query, entity_id=node_id) single_result = await result.single() await result.consume() # Ensure result is fully consumed - return single_result["node_exists"] if single_result is not None else False + return ( + single_result["node_exists"] if single_result is not None else False + ) except Exception as e: logger.error(f"Error checking node existence for {node_id}: {str(e)}") await result.consume() # Ensure the result is consumed even on error @@ -143,7 +151,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error checking the edge existence. """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -153,10 +163,16 @@ class MemgraphStorage(BaseGraphStorage): f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) " "RETURN COUNT(r) > 0 AS edgeExists" ) - result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id) # type: ignore + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) # type: ignore single_result = await result.single() await result.consume() # Ensure result is fully consumed - return single_result["edgeExists"] if single_result is not None else False + return ( + single_result["edgeExists"] if single_result is not None else False + ) except Exception as e: logger.error( f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" @@ -178,13 +194,17 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. 
Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: workspace_label = self._get_workspace_label() - query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n" + query = ( + f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n" + ) result = await session.run(query, entity_id=node_id) try: records = await result.fetch( @@ -228,7 +248,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -265,7 +287,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -302,7 +326,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) try: async with self._driver.session( database=self._DATABASE, default_access_mode="READ" @@ -366,7 +392,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -414,7 +442,9 @@ class MemgraphStorage(BaseGraphStorage): node_data: Dictionary of node properties """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) properties = node_data entity_type = properties["entity_type"] if "entity_id" not in properties: @@ -423,14 +453,13 @@ class MemgraphStorage(BaseGraphStorage): try: async with self._driver.session(database=self._DATABASE) as session: workspace_label = self._get_workspace_label() + async def execute_upsert(tx: AsyncManagedTransaction): - query = ( - f""" + query = f""" MERGE (n:`{workspace_label}` {{entity_id: $entity_id}}) SET n += $properties SET n:`{entity_type}` """ - ) result = await tx.run( query, entity_id=node_id, properties=properties ) @@ -458,7 +487,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." 
+ ) try: edge_properties = edge_data async with self._driver.session(database=self._DATABASE) as session: @@ -499,7 +530,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) async def _do_delete(tx: AsyncManagedTransaction): workspace_label = self._get_workspace_label() @@ -525,7 +558,9 @@ class MemgraphStorage(BaseGraphStorage): nodes: List of node labels to be deleted """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) for node in nodes: await self.delete_node(node) @@ -539,7 +574,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) for source, target in edges: async def _do_delete_edge(tx: AsyncManagedTransaction): @@ -575,17 +612,23 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) try: async with self._driver.session(database=self._DATABASE) as session: workspace_label = self._get_workspace_label() query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" result = await session.run(query) await result.consume() - logger.info(f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}") + logger.info( + f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}" + ) return {"status": "success", "message": "workspace data dropped"} except Exception as e: - logger.error(f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}") + logger.error( + f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}" + ) return {"status": "error", "message": str(e)} async def edge_degree(self, src_id: str, tgt_id: str) -> int: @@ -599,7 +642,9 @@ class MemgraphStorage(BaseGraphStorage): int: Sum of the degrees of both nodes """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) src_degree = await self.node_degree(src_id) trg_degree = await self.node_degree(tgt_id) @@ -621,7 +666,9 @@ class MemgraphStorage(BaseGraphStorage): An empty list if no matching nodes are found. """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" @@ -653,7 +700,9 @@ class MemgraphStorage(BaseGraphStorage): An empty list if no matching edges are found. 
""" if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" @@ -701,7 +750,9 @@ class MemgraphStorage(BaseGraphStorage): Exception: If there is an error executing the query """ if self._driver is None: - raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.") + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) result = KnowledgeGraph() seen_nodes = set() @@ -761,7 +812,7 @@ class MemgraphStorage(BaseGraphStorage): else: bfs_query = f""" - MATCH (start:`{workspace_label}`) + MATCH (start:`{workspace_label}`) WHERE start.entity_id = $entity_id WITH start CALL {{ @@ -774,7 +825,7 @@ class MemgraphStorage(BaseGraphStorage): RETURN all_nodes, all_rels }} WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes - WITH + WITH CASE WHEN total_nodes <= {max_nodes} THEN nodes ELSE nodes[0..{max_nodes}] @@ -782,7 +833,7 @@ class MemgraphStorage(BaseGraphStorage): relationships, total_nodes, total_nodes > {max_nodes} AS is_truncated - RETURN + RETURN [node IN limited_nodes | {{node: node}}] AS node_info, relationships, total_nodes, From 3a0249a6b9bc09e2584f316f0b58a0b020ec0465 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 9 Jul 2025 03:36:17 +0800 Subject: [PATCH 19/19] Update env.example --- env.example | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/env.example b/env.example index df88a518..32a9f3ed 100644 --- a/env.example +++ b/env.example @@ -134,13 +134,14 @@ EMBEDDING_BINDING_HOST=http://localhost:11434 # LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage ### Graph Storage (Recommended for production deployment) # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage +# LIGHTRAG_GRAPH_STORAGE=MemgraphStorage #################################################################### ### Default workspace for all storage types ### For the purpose of isolation of data for each LightRAG instance ### Valid characters: a-z, A-Z, 0-9, and _ #################################################################### -# WORKSPACE=doc— +# WORKSPACE=space1 ### PostgreSQL Configuration POSTGRES_HOST=localhost