import os
import re
from dataclasses import dataclass
from typing import final
import configparser

from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP

import pipmaster as pm

if not pm.is_installed("neo4j"):
    pm.install("neo4j")

from neo4j import (
    AsyncGraphDatabase,
    AsyncManagedTransaction,
)
from dotenv import load_dotenv

# use the .env that is inside the current folder
load_dotenv(dotenv_path=".env", override=False)

# Hard cap on how many nodes get_knowledge_graph may return.
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")


@final
@dataclass
class MemgraphStorage(BaseGraphStorage):
    """Graph storage backend backed by a Memgraph instance over Bolt.

    Every node carries the ``:base`` label and is keyed by its ``entity_id``
    property; relationships are created with type ``DIRECTED`` but all
    matches are undirected (``-[r]-``).  Connection settings are resolved
    from environment variables first, then ``config.ini`` ([memgraph]
    section), then built-in localhost defaults.
    """

    def __init__(self, namespace, global_config, embedding_func):
        super().__init__(
            namespace=namespace,
            global_config=global_config,
            embedding_func=embedding_func,
        )
        # Created lazily in initialize(); None means "not connected".
        self._driver = None

    async def initialize(self):
        """Open the async driver, ensure the entity_id index, and smoke-test.

        Raises:
            Exception: whatever the neo4j driver raises when the connection
                test (``RETURN 1``) fails; the driver is closed first so a
                failed initialize does not leak a connection pool.
        """
        URI = os.environ.get(
            "MEMGRAPH_URI",
            config.get("memgraph", "uri", fallback="bolt://localhost:7687"),
        )
        USERNAME = os.environ.get(
            "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")
        )
        PASSWORD = os.environ.get(
            "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")
        )
        DATABASE = os.environ.get(
            "MEMGRAPH_DATABASE",
            config.get("memgraph", "database", fallback="memgraph"),
        )

        self._driver = AsyncGraphDatabase.driver(
            URI,
            auth=(USERNAME, PASSWORD),
        )
        self._DATABASE = DATABASE
        try:
            async with self._driver.session(database=DATABASE) as session:
                # Create index for base nodes on entity_id if it doesn't exist
                try:
                    await session.run("""CREATE INDEX ON :base(entity_id)""")
                    logger.info("Created index on :base(entity_id) in Memgraph.")
                except Exception as e:
                    # Index may already exist, which is not an error
                    logger.warning(
                        f"Index creation on :base(entity_id) may have failed or already exists: {e}"
                    )
                await session.run("RETURN 1")
                logger.info(f"Connected to Memgraph at {URI}")
        except Exception as e:
            logger.error(f"Failed to connect to Memgraph at {URI}: {e}")
            # Fix: close the driver on a failed connection test so we don't
            # leak the pool; callers may retry initialize().
            await self._driver.close()
            self._driver = None
            raise

    async def finalize(self):
        """Close the driver and release all pooled connections."""
        if self._driver is not None:
            await self._driver.close()
            self._driver = None

    async def __aexit__(self, exc_type, exc, tb):
        await self.finalize()

    async def index_done_callback(self):
        # Memgraph handles persistence automatically
        pass

    async def has_node(self, node_id: str) -> bool:
        """Return True if a ``:base`` node with this entity_id exists."""
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = (
                "MATCH (n:base {entity_id: $entity_id}) "
                "RETURN count(n) > 0 AS node_exists"
            )
            result = await session.run(query, entity_id=node_id)
            single_result = await result.single()
            await result.consume()
            return single_result["node_exists"]

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Return True if any (undirected) edge links the two entities."""
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = (
                "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
                "RETURN COUNT(r) > 0 AS edgeExists"
            )
            result = await session.run(
                query,
                source_entity_id=source_node_id,
                target_entity_id=target_node_id,
            )
            single_result = await result.single()
            await result.consume()
            return single_result["edgeExists"]

    async def get_node(self, node_id: str) -> dict[str, str] | None:
        """Return the node's properties as a dict, or None if absent.

        The internal ``base`` label is stripped from any ``labels`` property.
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
            result = await session.run(query, entity_id=node_id)
            # fetch(2) so a duplicate-id situation is at least observable.
            records = await result.fetch(2)
            await result.consume()
            if records:
                node = records[0]["n"]
                node_dict = dict(node)
                if "labels" in node_dict:
                    node_dict["labels"] = [
                        label for label in node_dict["labels"] if label != "base"
                    ]
                return node_dict
            return None

    async def get_all_labels(self) -> list[str]:
        """Return all distinct entity_id values, sorted ascending."""
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
            MATCH (n:base)
            WHERE n.entity_id IS NOT NULL
            RETURN DISTINCT n.entity_id AS label
            ORDER BY label
            """
            result = await session.run(query)
            labels = []
            async for record in result:
                labels.append(record["label"])
            await result.consume()
            return labels

    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
        """Return (source_entity_id, target_entity_id) pairs for every edge
        touching the given node; empty list if the node has no edges."""
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
            MATCH (n:base {entity_id: $entity_id})
            OPTIONAL MATCH (n)-[r]-(connected:base)
            WHERE connected.entity_id IS NOT NULL
            RETURN n, r, connected
            """
            results = await session.run(query, entity_id=source_node_id)
            edges = []
            async for record in results:
                source_node = record["n"]
                connected_node = record["connected"]
                # OPTIONAL MATCH yields a row with NULL partner for isolated nodes.
                if not source_node or not connected_node:
                    continue
                source_label = source_node.get("entity_id")
                target_label = connected_node.get("entity_id")
                if source_label and target_label:
                    edges.append((source_label, target_label))
            await results.consume()
            return edges

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> dict[str, str] | None:
        """Return the edge's properties between two entities, or None.

        Missing standard keys (weight, source_id, description, keywords)
        are filled with defaults so callers can rely on their presence.
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
            MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
            RETURN properties(r) as edge_properties
            """
            result = await session.run(
                query,
                source_entity_id=source_node_id,
                target_entity_id=target_node_id,
            )
            records = await result.fetch(2)
            await result.consume()
            if records:
                edge_result = dict(records[0]["edge_properties"])
                for key, default_value in {
                    "weight": 0.0,
                    "source_id": None,
                    "description": None,
                    "keywords": None,
                }.items():
                    if key not in edge_result:
                        edge_result[key] = default_value
                return edge_result
            return None

    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        """Create or update a ``:base`` node keyed by entity_id.

        Raises:
            ValueError: if node_data lacks an 'entity_id' field.
        """
        properties = node_data
        entity_type = properties.get("entity_type", "base")
        if "entity_id" not in properties:
            raise ValueError(
                "Memgraph: node properties must contain an 'entity_id' field"
            )
        async with self._driver.session(database=self._DATABASE) as session:

            async def execute_upsert(tx: AsyncManagedTransaction):
                # NOTE(review): entity_type is interpolated into the query as a
                # label (labels cannot be parameterized in Cypher); it is
                # expected to come from trusted extraction output — confirm.
                query = f"""
                MERGE (n:base {{entity_id: $entity_id}})
                SET n += $properties
                SET n:`{entity_type}`
                """
                result = await tx.run(query, entity_id=node_id, properties=properties)
                await result.consume()

            await session.execute_write(execute_upsert)

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ) -> None:
        """Create or update the edge between two existing ``:base`` nodes.

        Both endpoints must already exist; otherwise the MATCH yields no
        rows and the MERGE is a no-op.
        """
        edge_properties = edge_data
        async with self._driver.session(database=self._DATABASE) as session:

            async def execute_upsert(tx: AsyncManagedTransaction):
                query = """
                MATCH (source:base {entity_id: $source_entity_id})
                WITH source
                MATCH (target:base {entity_id: $target_entity_id})
                MERGE (source)-[r:DIRECTED]-(target)
                SET r += $properties
                RETURN r, source, target
                """
                result = await tx.run(
                    query,
                    source_entity_id=source_node_id,
                    target_entity_id=target_node_id,
                    properties=edge_properties,
                )
                await result.consume()

            await session.execute_write(execute_upsert)

    async def delete_node(self, node_id: str) -> None:
        """Delete the node and all of its relationships (DETACH DELETE)."""

        async def _do_delete(tx: AsyncManagedTransaction):
            query = """
            MATCH (n:base {entity_id: $entity_id})
            DETACH DELETE n
            """
            result = await tx.run(query, entity_id=node_id)
            await result.consume()

        async with self._driver.session(database=self._DATABASE) as session:
            await session.execute_write(_do_delete)

    async def remove_nodes(self, nodes: list[str]):
        """Delete each node in the list, one transaction per node."""
        for node in nodes:
            await self.delete_node(node)

    async def remove_edges(self, edges: list[tuple[str, str]]):
        """Delete every edge between each (source, target) pair."""
        for source, target in edges:
            # Bind loop variables as defaults to avoid late-binding surprises
            # if the closure ever outlives the iteration.
            async def _do_delete_edge(
                tx: AsyncManagedTransaction, source=source, target=target
            ):
                query = """
                MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
                DELETE r
                """
                result = await tx.run(
                    query, source_entity_id=source, target_entity_id=target
                )
                await result.consume()

            async with self._driver.session(database=self._DATABASE) as session:
                await session.execute_write(_do_delete_edge)

    async def drop(self) -> dict[str, str]:
        """Delete every node and relationship in the database.

        Returns:
            {"status": "success"|"error", "message": ...} — never raises.
        """
        try:
            async with self._driver.session(database=self._DATABASE) as session:
                query = "MATCH (n) DETACH DELETE n"
                result = await session.run(query)
                await result.consume()
                logger.info(
                    f"Process {os.getpid()} drop Memgraph database {self._DATABASE}"
                )
                return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}")
            return {"status": "error", "message": str(e)}

    async def node_degree(self, node_id: str) -> int:
        """Return the number of relationships touching the node (0 if absent)."""
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
            MATCH (n:base {entity_id: $entity_id})
            OPTIONAL MATCH (n)-[r]-()
            RETURN COUNT(r) AS degree
            """
            result = await session.run(query, entity_id=node_id)
            record = await result.single()
            await result.consume()
            if not record:
                return 0
            return record["degree"]

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return the combined degree of the two endpoint nodes.

        node_degree() always returns an int, so no None-guard is needed.
        """
        src_degree = await self.node_degree(src_id)
        trg_degree = await self.node_degree(tgt_id)
        return int(src_degree) + int(trg_degree)

    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
        """Return all nodes whose GRAPH_FIELD_SEP-joined source_id contains
        any of the given chunk ids; each dict gains an 'id' key."""
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
            UNWIND $chunk_ids AS chunk_id
            MATCH (n:base)
            WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
            RETURN DISTINCT n
            """
            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
            nodes = []
            async for record in result:
                node = record["n"]
                node_dict = dict(node)
                node_dict["id"] = node_dict.get("entity_id")
                nodes.append(node_dict)
            await result.consume()
            return nodes

    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
        """Return all edges whose source_id contains any of the given chunk
        ids, as property dicts augmented with 'source' and 'target'."""
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
            UNWIND $chunk_ids AS chunk_id
            MATCH (a:base)-[r]-(b:base)
            WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
            RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
            """
            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
            edges = []
            async for record in result:
                edge_properties = record["properties"]
                edge_properties["source"] = record["source"]
                edge_properties["target"] = record["target"]
                edges.append(edge_properties)
            await result.consume()
            return edges

    async def get_knowledge_graph(
        self,
        node_label: str,
        max_depth: int = 3,
        max_nodes: int = MAX_GRAPH_NODES,
    ) -> KnowledgeGraph:
        """Return a subgraph around node_label (or the whole graph for "*").

        For "*", the highest-degree nodes are kept up to max_nodes and
        result.is_truncated is set when the graph is larger.  Otherwise a
        breadth-first expansion from the labeled node runs up to max_depth
        hops / max_nodes nodes (Memgraph has no APOC, so BFS is done
        client-side).

        Args:
            node_label: entity_id of the start node, or "*" for everything.
            max_depth: maximum BFS depth (ignored for "*").
            max_nodes: cap on returned nodes.
        """
        result = KnowledgeGraph()
        seen_nodes = set()
        seen_edges = set()
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            if node_label == "*":
                count_query = "MATCH (n) RETURN count(n) as total"
                count_result = await session.run(count_query)
                count_record = await count_result.single()
                await count_result.consume()
                if count_record and count_record["total"] > max_nodes:
                    result.is_truncated = True
                    logger.info(
                        f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
                    )
                main_query = """
                MATCH (n)
                OPTIONAL MATCH (n)-[r]-()
                WITH n, COALESCE(count(r), 0) AS degree
                ORDER BY degree DESC
                LIMIT $max_nodes
                WITH collect({node: n}) AS filtered_nodes
                UNWIND filtered_nodes AS node_info
                WITH collect(node_info.node) AS kept_nodes, filtered_nodes
                OPTIONAL MATCH (a)-[r]-(b)
                WHERE a IN kept_nodes AND b IN kept_nodes
                RETURN filtered_nodes AS node_info,
                       collect(DISTINCT r) AS relationships
                """
                result_set = await session.run(main_query, {"max_nodes": max_nodes})
                record = await result_set.single()
                await result_set.consume()
                # Fix: the fetched record was previously discarded, so "*"
                # always returned an empty graph. Populate nodes and edges.
                if record:
                    for node_info in record["node_info"]:
                        node = node_info["node"]
                        node_id = node.get("entity_id")
                        if node_id is None or node_id in seen_nodes:
                            continue
                        result.nodes.append(
                            KnowledgeGraphNode(
                                id=node_id,
                                labels=[node_id],
                                properties=dict(node),
                            )
                        )
                        seen_nodes.add(node_id)
                    for rel in record["relationships"]:
                        if rel is None:
                            # OPTIONAL MATCH can collect a single NULL.
                            continue
                        src = rel.start_node.get("entity_id")
                        tgt = rel.end_node.get("entity_id")
                        if not src or not tgt:
                            continue
                        edge_id = f"{src}-{tgt}"
                        if edge_id not in seen_edges:
                            result.edges.append(
                                KnowledgeGraphEdge(
                                    id=edge_id,
                                    type="DIRECTED",
                                    source=src,
                                    target=tgt,
                                    properties=dict(rel),
                                )
                            )
                            seen_edges.add(edge_id)
            else:
                # BFS fallback for Memgraph (no APOC)
                from collections import deque

                # Get the starting node
                start_query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
                node_result = await session.run(start_query, entity_id=node_label)
                node_record = await node_result.single()
                await node_result.consume()
                if not node_record:
                    return result
                start_node = node_record["n"]
                start_node_id = start_node.get("entity_id")
                queue = deque([(start_node, 0)])
                visited = set()
                bfs_nodes = []
                while queue and len(bfs_nodes) < max_nodes:
                    current_node, depth = queue.popleft()
                    node_id = current_node.get("entity_id")
                    if node_id in visited:
                        continue
                    visited.add(node_id)
                    bfs_nodes.append(current_node)
                    if depth < max_depth:
                        # Get neighbors
                        neighbor_query = """
                        MATCH (n:base {entity_id: $entity_id})-[]-(m:base)
                        RETURN m
                        """
                        neighbors_result = await session.run(
                            neighbor_query, entity_id=node_id
                        )
                        # Fix: AsyncResult has no to_list(); consume the
                        # stream with async-for instead.
                        neighbors = [rec["m"] async for rec in neighbors_result]
                        await neighbors_result.consume()
                        for neighbor in neighbors:
                            neighbor_id = neighbor.get("entity_id")
                            if neighbor_id not in visited:
                                queue.append((neighbor, depth + 1))
                # Build subgraph
                subgraph_ids = [n.get("entity_id") for n in bfs_nodes]
                # Nodes
                for n in bfs_nodes:
                    node_id = n.get("entity_id")
                    if node_id not in seen_nodes:
                        result.nodes.append(
                            KnowledgeGraphNode(
                                id=node_id,
                                labels=[node_id],
                                properties=dict(n),
                            )
                        )
                        seen_nodes.add(node_id)
                # Edges
                if subgraph_ids:
                    edge_query = """
                    MATCH (a:base)-[r]-(b:base)
                    WHERE a.entity_id IN $ids AND b.entity_id IN $ids
                    RETURN DISTINCT r, a, b
                    """
                    edge_result = await session.run(edge_query, ids=subgraph_ids)
                    async for record in edge_result:
                        r = record["r"]
                        a = record["a"]
                        b = record["b"]
                        edge_id = f"{a.get('entity_id')}-{b.get('entity_id')}"
                        if edge_id not in seen_edges:
                            result.edges.append(
                                KnowledgeGraphEdge(
                                    id=edge_id,
                                    type="DIRECTED",
                                    source=a.get("entity_id"),
                                    target=b.get("entity_id"),
                                    properties=dict(r),
                                )
                            )
                            seen_edges.add(edge_id)
                    await edge_result.consume()
        logger.info(
            f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
        )
        return result