diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 4e337fe9..38320643 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -16,8 +16,7 @@ import logging from ..utils import logger from ..base import BaseGraphStorage from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge -from ..constants import GRAPH_FIELD_SEP -from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock +from ..kg.shared_storage import get_data_init_lock import pipmaster as pm if not pm.is_installed("neo4j"): @@ -45,6 +44,23 @@ config.read("config.ini", "utf-8") logging.getLogger("neo4j").setLevel(logging.ERROR) +READ_RETRY_EXCEPTIONS = ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, +) + +READ_RETRY = retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type(READ_RETRY_EXCEPTIONS), + reraise=True, +) + + @final @dataclass class Neo4JStorage(BaseGraphStorage): @@ -68,7 +84,7 @@ class Neo4JStorage(BaseGraphStorage): def _get_workspace_label(self) -> str: """Return workspace label (guaranteed non-empty during initialization)""" - return self._get_composite_workspace() + return self.workspace def _is_chinese_text(self, text: str) -> bool: """Check if text contains Chinese characters.""" @@ -341,10 +357,9 @@ class Neo4JStorage(BaseGraphStorage): async def finalize(self): """Close the Neo4j driver and release all resources""" - async with get_graph_db_lock(): - if self._driver: - await self._driver.close() - self._driver = None + if self._driver: + await self._driver.close() + self._driver = None async def __aexit__(self, exc_type, exc, tb): """Ensure driver is closed when context manager exits""" @@ -354,6 +369,7 @@ class Neo4JStorage(BaseGraphStorage): # Neo4J handles persistence automatically pass + @READ_RETRY async def has_node(self, node_id: str) -> bool: """ Check if a node with the given label exists in the database @@ -372,6 +388,7 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: + result = None try: query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists" result = await session.run(query, entity_id=node_id) @@ -382,9 +399,11 @@ class Neo4JStorage(BaseGraphStorage): logger.error( f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}" ) - await result.consume() # Ensure results are consumed even on error + if result is not None: + await result.consume() # Ensure results are consumed even on error raise + @READ_RETRY async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: """ Check if an edge exists between two nodes @@ -404,6 +423,7 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: + result = None try: query = ( f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) " @@ -421,9 +441,11 @@ class Neo4JStorage(BaseGraphStorage): logger.error( f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" ) - await result.consume() # Ensure results are consumed even on error + if result is not None: + await result.consume() # Ensure results are consumed even on error raise + @READ_RETRY async def get_node(self, node_id: str) -> dict[str, str] | None: """Get node by its label identifier, return only node properties @@ -477,6 +499,7 @@ class Neo4JStorage(BaseGraphStorage): ) raise + @READ_RETRY async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: """ Retrieve multiple nodes in one query using UNWIND. @@ -513,6 +536,7 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Make sure to consume the result fully return nodes + @READ_RETRY async def node_degree(self, node_id: str) -> int: """Get the degree (number of relationships) of a node with the given label. If multiple nodes have the same label, returns the degree of the first node. @@ -561,6 +585,7 @@ class Neo4JStorage(BaseGraphStorage): ) raise + @READ_RETRY async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: """ Retrieve the degree for multiple nodes in a single query using UNWIND. @@ -619,6 +644,7 @@ class Neo4JStorage(BaseGraphStorage): degrees = int(src_degree) + int(trg_degree) return degrees + @READ_RETRY async def edge_degrees_batch( self, edge_pairs: list[tuple[str, str]] ) -> dict[tuple[str, str], int]: @@ -645,6 +671,7 @@ class Neo4JStorage(BaseGraphStorage): edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0) return edge_degrees + @READ_RETRY async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: @@ -732,6 +759,7 @@ class Neo4JStorage(BaseGraphStorage): ) raise + @READ_RETRY async def get_edges_batch( self, pairs: list[dict[str, str]] ) -> dict[tuple[str, str], dict]: @@ -782,6 +810,7 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() return edges_dict + @READ_RETRY async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """Retrieves all edges (relationships) for a particular node identified by its label. @@ -800,6 +829,7 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: + results = None try: workspace_label = self._get_workspace_label() query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) @@ -837,7 +867,10 @@ class Neo4JStorage(BaseGraphStorage): logger.error( f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}" ) - await results.consume() # Ensure results are consumed even on error + if results is not None: + await ( + results.consume() + ) # Ensure results are consumed even on error raise except Exception as e: logger.error( @@ -845,6 +878,7 @@ class Neo4JStorage(BaseGraphStorage): ) raise + @READ_RETRY async def get_nodes_edges_batch( self, node_ids: list[str] ) -> dict[str, list[tuple[str, str]]]: @@ -904,49 +938,6 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Ensure results are fully consumed return edges_dict - async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - workspace_label = self._get_workspace_label() - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - query = f""" - UNWIND $chunk_ids AS chunk_id - MATCH (n:`{workspace_label}`) - WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep) - RETURN DISTINCT n - """ - result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) - nodes = [] - async for record in result: - node = record["n"] - node_dict = dict(node) - # Add node id (entity_id) to the dictionary for easier access - node_dict["id"] = node_dict.get("entity_id") - nodes.append(node_dict) - await result.consume() - return nodes - - async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - workspace_label = self._get_workspace_label() - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - query = f""" - UNWIND $chunk_ids AS chunk_id - MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`) - WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) - RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties - """ - result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) - edges = [] - async for record in result: - edge_properties = record["properties"] - edge_properties["source"] = record["source"] - edge_properties["target"] = record["target"] - edges.append(edge_properties) - await result.consume() - return edges - @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), @@ -1636,6 +1627,7 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: + result = None try: query = f""" MATCH (n:`{workspace_label}`) @@ -1660,7 +1652,8 @@ class Neo4JStorage(BaseGraphStorage): logger.error( f"[{self.workspace}] Error getting popular labels: {str(e)}" ) - await result.consume() + if result is not None: + await result.consume() raise async def search_labels(self, query: str, limit: int = 50) -> list[str]: @@ -1807,24 +1800,23 @@ class Neo4JStorage(BaseGraphStorage): - On success: {"status": "success", "message": "workspace data dropped"} - On failure: {"status": "error", "message": ""} """ - async with get_graph_db_lock(): - workspace_label = self._get_workspace_label() - try: - async with self._driver.session(database=self._DATABASE) as session: - # Delete all nodes and relationships in current workspace only - query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" - result = await session.run(query) - await result.consume() # Ensure result is fully consumed + workspace_label = self._get_workspace_label() + try: + async with self._driver.session(database=self._DATABASE) as session: + # Delete all nodes and relationships in current workspace only + query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" + result = await session.run(query) + await result.consume() # Ensure result is fully consumed - # logger.debug( - # f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}" - # ) - return { - "status": "success", - "message": f"workspace '{workspace_label}' data dropped", - } - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" - ) - return {"status": "error", "message": str(e)} + # logger.debug( + # f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}" + # ) + return { + "status": "success", + "message": f"workspace '{workspace_label}' data dropped", + } + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" + ) + return {"status": "error", "message": str(e)}