From 7118b23ca2851d933384612d87cd34c25b4bba5e Mon Sep 17 00:00:00 2001 From: DavIvek Date: Thu, 26 Jun 2025 16:33:19 +0200 Subject: [PATCH] reformatting --- lightrag/kg/memgraph_impl.py | 142 ++++++++++++++++++++++++----------- 1 file changed, 100 insertions(+), 42 deletions(-) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index cb46cdc6..df28b8b2 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -1,5 +1,4 @@ import os -import re from dataclasses import dataclass from typing import final import configparser @@ -28,6 +27,7 @@ MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) config = configparser.ConfigParser() config.read("config.ini", "utf-8") + @final @dataclass class MemgraphStorage(BaseGraphStorage): @@ -40,10 +40,19 @@ class MemgraphStorage(BaseGraphStorage): self._driver = None async def initialize(self): - URI = os.environ.get("MEMGRAPH_URI", config.get("memgraph", "uri", fallback="bolt://localhost:7687")) - USERNAME = os.environ.get("MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")) - PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")) - DATABASE = os.environ.get("MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph")) + URI = os.environ.get( + "MEMGRAPH_URI", + config.get("memgraph", "uri", fallback="bolt://localhost:7687"), + ) + USERNAME = os.environ.get( + "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="") + ) + PASSWORD = os.environ.get( + "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="") + ) + DATABASE = os.environ.get( + "MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph") + ) self._driver = AsyncGraphDatabase.driver( URI, @@ -58,7 +67,9 @@ class MemgraphStorage(BaseGraphStorage): logger.info("Created index on :base(entity_id) in Memgraph.") except Exception as e: # Index may already exist, which is not an error - logger.warning(f"Index creation on :base(entity_id) may have failed or already exists: {e}") + logger.warning( + f"Index creation on :base(entity_id) may have failed or already exists: {e}" + ) await session.run("RETURN 1") logger.info(f"Connected to Memgraph at {URI}") except Exception as e: @@ -78,7 +89,9 @@ class MemgraphStorage(BaseGraphStorage): pass async def has_node(self, node_id: str) -> bool: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" result = await session.run(query, entity_id=node_id) single_result = await result.single() @@ -86,7 +99,9 @@ class MemgraphStorage(BaseGraphStorage): return single_result["node_exists"] async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = ( "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " "RETURN COUNT(r) > 0 AS edgeExists" @@ -101,7 +116,9 @@ class MemgraphStorage(BaseGraphStorage): return single_result["edgeExists"] async def get_node(self, node_id: str) -> dict[str, str] | None: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" result = await session.run(query, entity_id=node_id) records = await result.fetch(2) @@ -110,12 +127,16 @@ class MemgraphStorage(BaseGraphStorage): node = records[0]["n"] node_dict = dict(node) if "labels" in node_dict: - node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"] + node_dict["labels"] = [ + label for label in node_dict["labels"] if label != "base" + ] return node_dict return None async def get_all_labels(self) -> list[str]: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ MATCH (n:base) WHERE n.entity_id IS NOT NULL @@ -130,7 +151,9 @@ class MemgraphStorage(BaseGraphStorage): return labels async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ MATCH (n:base {entity_id: $entity_id}) OPTIONAL MATCH (n)-[r]-(connected:base) @@ -151,8 +174,12 @@ class MemgraphStorage(BaseGraphStorage): await results.consume() return edges - async def get_edge(self, source_node_id: str, target_node_id: str) -> dict[str, str] | None: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> dict[str, str] | None: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) RETURN properties(r) as edge_properties @@ -181,23 +208,28 @@ class MemgraphStorage(BaseGraphStorage): properties = node_data entity_type = properties.get("entity_type", "base") if "entity_id" not in properties: - raise ValueError("Memgraph: node properties must contain an 'entity_id' field") + raise ValueError( + "Memgraph: node properties must contain an 'entity_id' field" + ) async with self._driver.session(database=self._DATABASE) as session: + async def execute_upsert(tx: AsyncManagedTransaction): - query = ( - f""" + query = f""" MERGE (n:base {{entity_id: $entity_id}}) SET n += $properties SET n:`{entity_type}` """ - ) result = await tx.run(query, entity_id=node_id, properties=properties) await result.consume() + await session.execute_write(execute_upsert) - async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]) -> None: + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: edge_properties = edge_data async with self._driver.session(database=self._DATABASE) as session: + async def execute_upsert(tx: AsyncManagedTransaction): query = """ MATCH (source:base {entity_id: $source_entity_id}) @@ -214,6 +246,7 @@ class MemgraphStorage(BaseGraphStorage): properties=edge_properties, ) await result.consume() + await session.execute_write(execute_upsert) async def delete_node(self, node_id: str) -> None: @@ -224,6 +257,7 @@ class MemgraphStorage(BaseGraphStorage): """ result = await tx.run(query, entity_id=node_id) await result.consume() + async with self._driver.session(database=self._DATABASE) as session: await session.execute_write(_do_delete) @@ -233,6 +267,7 @@ class MemgraphStorage(BaseGraphStorage): async def remove_edges(self, edges: list[tuple[str, str]]): for source, target in edges: + async def _do_delete_edge(tx: AsyncManagedTransaction): query = """ MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id}) @@ -242,6 +277,7 @@ class MemgraphStorage(BaseGraphStorage): query, source_entity_id=source, target_entity_id=target ) await result.consume() + async with self._driver.session(database=self._DATABASE) as session: await session.execute_write(_do_delete_edge) @@ -251,14 +287,18 @@ class MemgraphStorage(BaseGraphStorage): query = "MATCH (n) DETACH DELETE n" result = await session.run(query) await result.consume() - logger.info(f"Process {os.getpid()} drop Memgraph database {self._DATABASE}") + logger.info( + f"Process {os.getpid()} drop Memgraph database {self._DATABASE}" + ) return {"status": "success", "message": "data dropped"} except Exception as e: logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}") return {"status": "error", "message": str(e)} async def node_degree(self, node_id: str) -> int: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ MATCH (n:base {entity_id: $entity_id}) OPTIONAL MATCH (n)-[r]-() @@ -279,7 +319,9 @@ class MemgraphStorage(BaseGraphStorage): return int(src_degree) + int(trg_degree) async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ UNWIND $chunk_ids AS chunk_id MATCH (n:base) @@ -297,7 +339,9 @@ class MemgraphStorage(BaseGraphStorage): return nodes async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = """ UNWIND $chunk_ids AS chunk_id MATCH (a:base)-[r]-(b:base) @@ -323,7 +367,9 @@ class MemgraphStorage(BaseGraphStorage): result = KnowledgeGraph() seen_nodes = set() seen_edges = set() - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: if node_label == "*": count_query = "MATCH (n) RETURN count(n) as total" count_result = await session.run(count_query) @@ -331,7 +377,9 @@ class MemgraphStorage(BaseGraphStorage): await count_result.consume() if count_record and count_record["total"] > max_nodes: result.is_truncated = True - logger.info(f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}") + logger.info( + f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" + ) main_query = """ MATCH (n) OPTIONAL MATCH (n)-[r]-() @@ -352,6 +400,7 @@ class MemgraphStorage(BaseGraphStorage): else: # BFS fallback for Memgraph (no APOC) from collections import deque + # Get the starting node start_query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" node_result = await session.run(start_query, entity_id=node_label) @@ -360,7 +409,6 @@ class MemgraphStorage(BaseGraphStorage): if not node_record: return result start_node = node_record["n"] - start_node_id = start_node.get("entity_id") queue = deque([(start_node, 0)]) visited = set() bfs_nodes = [] @@ -377,8 +425,12 @@ class MemgraphStorage(BaseGraphStorage): MATCH (n:base {entity_id: $entity_id})-[]-(m:base) RETURN m """ - neighbors_result = await session.run(neighbor_query, entity_id=node_id) - neighbors = [rec["m"] for rec in await neighbors_result.to_list()] + neighbors_result = await session.run( + neighbor_query, entity_id=node_id + ) + neighbors = [ + rec["m"] for rec in await neighbors_result.to_list() + ] await neighbors_result.consume() for neighbor in neighbors: neighbor_id = neighbor.get("entity_id") @@ -390,11 +442,13 @@ class MemgraphStorage(BaseGraphStorage): for n in bfs_nodes: node_id = n.get("entity_id") if node_id not in seen_nodes: - result.nodes.append(KnowledgeGraphNode( - id=node_id, - labels=[node_id], - properties=dict(n), - )) + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node_id], + properties=dict(n), + ) + ) seen_nodes.add(node_id) # Edges if subgraph_ids: @@ -410,14 +464,18 @@ class MemgraphStorage(BaseGraphStorage): b = record["b"] edge_id = f"{a.get('entity_id')}-{b.get('entity_id')}" if edge_id not in seen_edges: - result.edges.append(KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=a.get("entity_id"), - target=b.get("entity_id"), - properties=dict(r), - )) + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=a.get("entity_id"), + target=b.get("entity_id"), + properties=dict(r), + ) + ) seen_edges.add(edge_id) await edge_result.consume() - logger.info(f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}") - return result \ No newline at end of file + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + return result