From ff1927d362f5a14656ed9df9a521af295604e26f Mon Sep 17 00:00:00 2001
From: DavIvek
Date: Thu, 26 Jun 2025 16:15:56 +0200
Subject: [PATCH] add Memgraph graph storage backend

---
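Notes: MemgraphStorage reads MEMGRAPH_URI (and optionally MEMGRAPH_USERNAME,
MEMGRAPH_PASSWORD, MEMGRAPH_DATABASE) from the environment first, then falls
back to the [memgraph] section of config.ini. A minimal sketch of enabling
the backend, assuming a local instance with authentication disabled:

    os.environ["MEMGRAPH_URI"] = "bolt://localhost:7687"
    rag = LightRAG(..., graph_storage="MemgraphStorage")

The lightrag_openai_demo.py change below shows the same selection in context.
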
 config.ini.example               |   3 +
 examples/lightrag_openai_demo.py |   1 +
 lightrag/kg/__init__.py          |   3 +
 lightrag/kg/memgraph_impl.py     | 423 ++++++++++++++++++++++++++++
 4 files changed, 430 insertions(+)
 create mode 100644 lightrag/kg/memgraph_impl.py

diff --git a/config.ini.example b/config.ini.example
index 63d9c2c0..94d300a1 100644
--- a/config.ini.example
+++ b/config.ini.example
@@ -21,3 +21,6 @@ password = your_password
 database = your_database
 workspace = default # optional, defaults to "default"
 max_connections = 12
+
+[memgraph]
+uri = bolt://localhost:7687
diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py
index fa0b37f1..e573ec41 100644
--- a/examples/lightrag_openai_demo.py
+++ b/examples/lightrag_openai_demo.py
@@ -82,6 +82,7 @@ async def initialize_rag():
         working_dir=WORKING_DIR,
         embedding_func=openai_embed,
         llm_model_func=gpt_4o_mini_complete,
+        graph_storage="MemgraphStorage",
     )

     await rag.initialize_storages()
diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py
index b4ba0983..3398b135 100644
--- a/lightrag/kg/__init__.py
+++ b/lightrag/kg/__init__.py
@@ -15,6 +15,7 @@ STORAGE_IMPLEMENTATIONS = {
         "Neo4JStorage",
         "PGGraphStorage",
         "MongoGraphStorage",
+        "MemgraphStorage",
         # "AGEStorage",
         # "TiDBGraphStorage",
         # "GremlinStorage",
@@ -56,6 +57,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
     "NetworkXStorage": [],
     "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
     "MongoGraphStorage": [],
+    "MemgraphStorage": ["MEMGRAPH_URI"],
     # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
     "AGEStorage": [
         "AGE_POSTGRES_DB",
@@ -108,6 +110,7 @@ STORAGES = {
     "PGDocStatusStorage": ".kg.postgres_impl",
     "FaissVectorDBStorage": ".kg.faiss_impl",
     "QdrantVectorDBStorage": ".kg.qdrant_impl",
+    "MemgraphStorage": ".kg.memgraph_impl",
 }
diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py
new file mode 100644
index 00000000..cb46cdc6
--- /dev/null
+++ b/lightrag/kg/memgraph_impl.py
@@ -0,0 +1,423 @@
+import os
+from dataclasses import dataclass
+from typing import final
+import configparser
+
+from ..utils import logger
+from ..base import BaseGraphStorage
+from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
+from ..constants import GRAPH_FIELD_SEP
+import pipmaster as pm
+
+if not pm.is_installed("neo4j"):
+    pm.install("neo4j")
+
+from neo4j import (
+    AsyncGraphDatabase,
+    AsyncManagedTransaction,
+)
+
+from dotenv import load_dotenv
+
+# use the .env that is inside the current folder
+load_dotenv(dotenv_path=".env", override=False)
+
+MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
+
+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+
+
+@final
+@dataclass
+class MemgraphStorage(BaseGraphStorage):
+    def __init__(self, namespace, global_config, embedding_func):
+        super().__init__(
+            namespace=namespace,
+            global_config=global_config,
+            embedding_func=embedding_func,
+        )
+        self._driver = None
+
+    async def initialize(self):
+        URI = os.environ.get(
+            "MEMGRAPH_URI",
+            config.get("memgraph", "uri", fallback="bolt://localhost:7687"),
+        )
+        USERNAME = os.environ.get(
+            "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")
+        )
+        PASSWORD = os.environ.get(
+            "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")
+        )
+        DATABASE = os.environ.get(
+            "MEMGRAPH_DATABASE",
+            config.get("memgraph", "database", fallback="memgraph"),
+        )
+
+        self._driver = AsyncGraphDatabase.driver(
+            URI,
+            auth=(USERNAME, PASSWORD),
+        )
+        self._DATABASE = DATABASE
+        try:
+            async with self._driver.session(database=DATABASE) as session:
+                # Verify connectivity before anything else, so a bad URI is
+                # reported as a connection error rather than an index warning
+                await session.run("RETURN 1")
+                # Create index for base nodes on entity_id if it doesn't exist
+                try:
+                    await session.run("CREATE INDEX ON :base(entity_id)")
+                    logger.info("Created index on :base(entity_id) in Memgraph.")
+                except Exception as e:
+                    # Memgraph raises if the index already exists, which is not an error
+                    logger.warning(
+                        f"Index creation on :base(entity_id) may have failed or already exists: {e}"
+                    )
+            logger.info(f"Connected to Memgraph at {URI}")
+        except Exception as e:
+            logger.error(f"Failed to connect to Memgraph at {URI}: {e}")
+            raise
+
+    async def finalize(self):
+        if self._driver is not None:
+            await self._driver.close()
+            self._driver = None
+
+    async def __aexit__(self, exc_type, exc, tb):
+        await self.finalize()
+
+    async def index_done_callback(self):
+        # Memgraph handles persistence automatically
+        pass
+
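+    # Lifecycle sketch (hypothetical values; LightRAG normally drives these
+    # calls through rag.initialize_storages() and rag.finalize_storages()):
+    #
+    #     storage = MemgraphStorage(
+    #         namespace="chunk_entity_relation",
+    #         global_config={},
+    #         embedding_func=None,
+    #     )
+    #     await storage.initialize()
+    #     try:
+    #         print(await storage.has_node("Alice"))
+    #     finally:
+    #         await storage.finalize()
+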
+    async def has_node(self, node_id: str) -> bool:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = (
+                "MATCH (n:base {entity_id: $entity_id}) "
+                "RETURN count(n) > 0 AS node_exists"
+            )
+            result = await session.run(query, entity_id=node_id)
+            single_result = await result.single()
+            await result.consume()
+            return single_result["node_exists"]
+
+    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = (
+                "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
+                "RETURN COUNT(r) > 0 AS edgeExists"
+            )
+            result = await session.run(
+                query,
+                source_entity_id=source_node_id,
+                target_entity_id=target_node_id,
+            )
+            single_result = await result.single()
+            await result.consume()
+            return single_result["edgeExists"]
+
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
+            result = await session.run(query, entity_id=node_id)
+            records = await result.fetch(2)
+            await result.consume()
+            if records:
+                node = records[0]["n"]
+                node_dict = dict(node)
+                if "labels" in node_dict:
+                    node_dict["labels"] = [
+                        label for label in node_dict["labels"] if label != "base"
+                    ]
+                return node_dict
+            return None
+
+    async def get_all_labels(self) -> list[str]:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = """
+            MATCH (n:base)
+            WHERE n.entity_id IS NOT NULL
+            RETURN DISTINCT n.entity_id AS label
+            ORDER BY label
+            """
+            result = await session.run(query)
+            labels = []
+            async for record in result:
+                labels.append(record["label"])
+            await result.consume()
+            return labels
+
+    async def get_node_edges(
+        self, source_node_id: str
+    ) -> list[tuple[str, str]] | None:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = """
+            MATCH (n:base {entity_id: $entity_id})
+            OPTIONAL MATCH (n)-[r]-(connected:base)
+            WHERE connected.entity_id IS NOT NULL
+            RETURN n, r, connected
+            """
+            results = await session.run(query, entity_id=source_node_id)
+            edges = []
+            async for record in results:
+                source_node = record["n"]
+                connected_node = record["connected"]
+                if not source_node or not connected_node:
+                    continue
+                source_label = source_node.get("entity_id")
+                target_label = connected_node.get("entity_id")
+                if source_label and target_label:
+                    edges.append((source_label, target_label))
+            await results.consume()
+            return edges
+
+    async def get_edge(
+        self, source_node_id: str, target_node_id: str
+    ) -> dict[str, str] | None:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = """
+            MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
+            RETURN properties(r) as edge_properties
+            """
+            result = await session.run(
+                query,
+                source_entity_id=source_node_id,
+                target_entity_id=target_node_id,
+            )
+            records = await result.fetch(2)
+            await result.consume()
+            if records:
+                edge_result = dict(records[0]["edge_properties"])
+                for key, default_value in {
+                    "weight": 0.0,
+                    "source_id": None,
+                    "description": None,
+                    "keywords": None,
+                }.items():
+                    if key not in edge_result:
+                        edge_result[key] = default_value
+                return edge_result
+            return None
+
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
+        properties = node_data
+        entity_type = properties.get("entity_type", "base")
+        if "entity_id" not in properties:
+            raise ValueError(
+                "Memgraph: node properties must contain an 'entity_id' field"
+            )
+        async with self._driver.session(database=self._DATABASE) as session:
+
+            async def execute_upsert(tx: AsyncManagedTransaction):
+                query = f"""
+                MERGE (n:base {{entity_id: $entity_id}})
+                SET n += $properties
+                SET n:`{entity_type}`
+                """
+                result = await tx.run(query, entity_id=node_id, properties=properties)
+                await result.consume()
+
+            await session.execute_write(execute_upsert)
+
+    async def upsert_edge(
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ) -> None:
+        edge_properties = edge_data
+        async with self._driver.session(database=self._DATABASE) as session:
+
+            async def execute_upsert(tx: AsyncManagedTransaction):
+                query = """
+                MATCH (source:base {entity_id: $source_entity_id})
+                WITH source
+                MATCH (target:base {entity_id: $target_entity_id})
+                MERGE (source)-[r:DIRECTED]-(target)
+                SET r += $properties
+                RETURN r, source, target
+                """
+                result = await tx.run(
+                    query,
+                    source_entity_id=source_node_id,
+                    target_entity_id=target_node_id,
+                    properties=edge_properties,
+                )
+                result = await result.consume()
+
+            await session.execute_write(execute_upsert)
+
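+    # Upsert sketch: both endpoints must already exist before upsert_edge can
+    # MERGE a relationship between them, and every node payload needs an
+    # "entity_id" field (names below are illustrative only):
+    #
+    #     await storage.upsert_node("Alice", {"entity_id": "Alice", "entity_type": "person"})
+    #     await storage.upsert_node("Bob", {"entity_id": "Bob", "entity_type": "person"})
+    #     await storage.upsert_edge("Alice", "Bob", {"weight": 1.0, "description": "knows"})
+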
+    async def delete_node(self, node_id: str) -> None:
+        async def _do_delete(tx: AsyncManagedTransaction):
+            query = """
+            MATCH (n:base {entity_id: $entity_id})
+            DETACH DELETE n
+            """
+            result = await tx.run(query, entity_id=node_id)
+            await result.consume()
+
+        async with self._driver.session(database=self._DATABASE) as session:
+            await session.execute_write(_do_delete)
+
+    async def remove_nodes(self, nodes: list[str]):
+        for node in nodes:
+            await self.delete_node(node)
+
+    async def remove_edges(self, edges: list[tuple[str, str]]):
+        for source, target in edges:
+
+            async def _do_delete_edge(tx: AsyncManagedTransaction):
+                query = """
+                MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
+                DELETE r
+                """
+                result = await tx.run(
+                    query, source_entity_id=source, target_entity_id=target
+                )
+                await result.consume()
+
+            async with self._driver.session(database=self._DATABASE) as session:
+                await session.execute_write(_do_delete_edge)
+
+    async def drop(self) -> dict[str, str]:
+        try:
+            async with self._driver.session(database=self._DATABASE) as session:
+                query = "MATCH (n) DETACH DELETE n"
+                result = await session.run(query)
+                await result.consume()
+                logger.info(
+                    f"Process {os.getpid()} drop Memgraph database {self._DATABASE}"
+                )
+                return {"status": "success", "message": "data dropped"}
+        except Exception as e:
+            logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}")
+            return {"status": "error", "message": str(e)}
+
+    async def node_degree(self, node_id: str) -> int:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = """
+            MATCH (n:base {entity_id: $entity_id})
+            OPTIONAL MATCH (n)-[r]-()
+            RETURN COUNT(r) AS degree
+            """
+            result = await session.run(query, entity_id=node_id)
+            record = await result.single()
+            await result.consume()
+            if not record:
+                return 0
+            return record["degree"]
+
+    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
+        src_degree = await self.node_degree(src_id)
+        trg_degree = await self.node_degree(tgt_id)
+        src_degree = 0 if src_degree is None else src_degree
+        trg_degree = 0 if trg_degree is None else trg_degree
+        return int(src_degree) + int(trg_degree)
+
+    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = """
+            UNWIND $chunk_ids AS chunk_id
+            MATCH (n:base)
+            WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
+            RETURN DISTINCT n
+            """
+            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
+            nodes = []
+            async for record in result:
+                node = record["n"]
+                node_dict = dict(node)
+                node_dict["id"] = node_dict.get("entity_id")
+                nodes.append(node_dict)
+            await result.consume()
+            return nodes
+
+    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = """
+            UNWIND $chunk_ids AS chunk_id
+            MATCH (a:base)-[r]-(b:base)
+            WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
+            RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
+            """
+            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
+            edges = []
+            async for record in result:
+                edge_properties = record["properties"]
+                edge_properties["source"] = record["source"]
+                edge_properties["target"] = record["target"]
+                edges.append(edge_properties)
+            await result.consume()
+            return edges
+
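+    # Chunk-lookup sketch: source_id stores chunk ids joined by GRAPH_FIELD_SEP,
+    # so a node tagged with chunks "c1" and "c2" is matched by either id
+    # (hypothetical ids):
+    #
+    #     await storage.upsert_node(
+    #         "Alice",
+    #         {"entity_id": "Alice", "source_id": GRAPH_FIELD_SEP.join(["c1", "c2"])},
+    #     )
+    #     nodes = await storage.get_nodes_by_chunk_ids(["c1"])
+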
+    async def get_knowledge_graph(
+        self,
+        node_label: str,
+        max_depth: int = 3,
+        max_nodes: int = MAX_GRAPH_NODES,
+    ) -> KnowledgeGraph:
+        result = KnowledgeGraph()
+        seen_nodes = set()
+        seen_edges = set()
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            if node_label == "*":
+                count_query = "MATCH (n) RETURN count(n) as total"
+                count_result = await session.run(count_query)
+                count_record = await count_result.single()
+                await count_result.consume()
+                if count_record and count_record["total"] > max_nodes:
+                    result.is_truncated = True
+                    logger.info(
+                        f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
+                    )
+                # Keep the max_nodes highest-degree nodes plus the edges among them
+                main_query = """
+                MATCH (n)
+                OPTIONAL MATCH (n)-[r]-()
+                WITH n, COALESCE(count(r), 0) AS degree
+                ORDER BY degree DESC
+                LIMIT $max_nodes
+                WITH collect({node: n}) AS filtered_nodes
+                UNWIND filtered_nodes AS node_info
+                WITH collect(node_info.node) AS kept_nodes, filtered_nodes
+                OPTIONAL MATCH (a)-[r]-(b)
+                WHERE a IN kept_nodes AND b IN kept_nodes
+                RETURN filtered_nodes AS node_info,
+                       collect(DISTINCT r) AS relationships
+                """
+                result_set = await session.run(main_query, {"max_nodes": max_nodes})
+                record = await result_set.single()
+                await result_set.consume()
+                if record:
+                    # Map element ids to the fully hydrated nodes so edge
+                    # endpoints can be resolved to entity_ids
+                    node_by_element_id = {}
+                    for node_info in record["node_info"]:
+                        node = node_info["node"]
+                        node_by_element_id[node.element_id] = node
+                        node_id = node.get("entity_id")
+                        if node_id and node_id not in seen_nodes:
+                            result.nodes.append(
+                                KnowledgeGraphNode(
+                                    id=node_id,
+                                    labels=[node_id],
+                                    properties=dict(node),
+                                )
+                            )
+                            seen_nodes.add(node_id)
+                    for rel in record["relationships"]:
+                        if rel is None:
+                            continue
+                        start = node_by_element_id.get(rel.start_node.element_id)
+                        end = node_by_element_id.get(rel.end_node.element_id)
+                        if start is None or end is None:
+                            continue
+                        edge_id = f"{start.get('entity_id')}-{end.get('entity_id')}"
+                        if edge_id not in seen_edges:
+                            result.edges.append(
+                                KnowledgeGraphEdge(
+                                    id=edge_id,
+                                    type=rel.type,
+                                    source=start.get("entity_id"),
+                                    target=end.get("entity_id"),
+                                    properties=dict(rel),
+                                )
+                            )
+                            seen_edges.add(edge_id)
+            else:
+                # BFS fallback for Memgraph (no APOC)
+                from collections import deque
+
+                # Get the starting node
+                start_query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
+                node_result = await session.run(start_query, entity_id=node_label)
+                node_record = await node_result.single()
+                await node_result.consume()
+                if not node_record:
+                    return result
+                start_node = node_record["n"]
+                queue = deque([(start_node, 0)])
+                visited = set()
+                bfs_nodes = []
+                while queue and len(bfs_nodes) < max_nodes:
+                    current_node, depth = queue.popleft()
+                    node_id = current_node.get("entity_id")
+                    if node_id in visited:
+                        continue
+                    visited.add(node_id)
+                    bfs_nodes.append(current_node)
+                    if depth < max_depth:
+                        # Get neighbors
+                        neighbor_query = """
+                        MATCH (n:base {entity_id: $entity_id})-[]-(m:base)
+                        RETURN m
+                        """
+                        neighbors_result = await session.run(
+                            neighbor_query, entity_id=node_id
+                        )
+                        neighbors = [record["m"] async for record in neighbors_result]
+                        await neighbors_result.consume()
+                        for neighbor in neighbors:
+                            neighbor_id = neighbor.get("entity_id")
+                            if neighbor_id not in visited:
+                                queue.append((neighbor, depth + 1))
+                # Build the subgraph over the visited nodes
+                subgraph_ids = [n.get("entity_id") for n in bfs_nodes]
+                # Nodes
+                for n in bfs_nodes:
+                    node_id = n.get("entity_id")
+                    if node_id not in seen_nodes:
+                        result.nodes.append(
+                            KnowledgeGraphNode(
+                                id=node_id,
+                                labels=[node_id],
+                                properties=dict(n),
+                            )
+                        )
+                        seen_nodes.add(node_id)
+                # Edges
+                if subgraph_ids:
+                    edge_query = """
+                    MATCH (a:base)-[r]-(b:base)
+                    WHERE a.entity_id IN $ids AND b.entity_id IN $ids
+                    RETURN DISTINCT r, a, b
+                    """
+                    edge_result = await session.run(edge_query, ids=subgraph_ids)
+                    async for record in edge_result:
+                        r = record["r"]
+                        a = record["a"]
+                        b = record["b"]
+                        edge_id = f"{a.get('entity_id')}-{b.get('entity_id')}"
+                        if edge_id not in seen_edges:
+                            result.edges.append(
+                                KnowledgeGraphEdge(
+                                    id=edge_id,
+                                    type="DIRECTED",
+                                    source=a.get("entity_id"),
+                                    target=b.get("entity_id"),
+                                    properties=dict(r),
+                                )
+                            )
+                            seen_edges.add(edge_id)
+                    await edge_result.consume()
+        logger.info(
+            f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+        )
+        return result
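+
+# Smoke-test sketch (assumes a reachable Memgraph at MEMGRAPH_URI; kept as a
+# comment so nothing executes on import, run the body manually if needed):
+#
+#     import asyncio
+#
+#     async def main():
+#         storage = MemgraphStorage(
+#             namespace="chunk_entity_relation", global_config={}, embedding_func=None
+#         )
+#         await storage.initialize()
+#         try:
+#             await storage.upsert_node("Alice", {"entity_id": "Alice"})
+#             kg = await storage.get_knowledge_graph("Alice", max_depth=2)
+#             print(len(kg.nodes), len(kg.edges))
+#         finally:
+#             await storage.finalize()
+#
+#     asyncio.run(main())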