diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index d3d6c4eb..38320643 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -44,6 +44,23 @@ config.read("config.ini", "utf-8") logging.getLogger("neo4j").setLevel(logging.ERROR) +READ_RETRY_EXCEPTIONS = ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, +) + +READ_RETRY = retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type(READ_RETRY_EXCEPTIONS), + reraise=True, +) + + @final @dataclass class Neo4JStorage(BaseGraphStorage): @@ -352,20 +369,7 @@ class Neo4JStorage(BaseGraphStorage): # Neo4J handles persistence automatically pass - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.SessionExpired, - ConnectionResetError, - OSError, - AttributeError, - ) - ), - ) + @READ_RETRY async def has_node(self, node_id: str) -> bool: """ Check if a node with the given label exists in the database @@ -399,20 +403,7 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Ensure results are consumed even on error raise - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.SessionExpired, - ConnectionResetError, - OSError, - AttributeError, - ) - ), - ) + @READ_RETRY async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: """ Check if an edge exists between two nodes @@ -454,20 +445,7 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Ensure results are consumed even on error raise - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.SessionExpired, - ConnectionResetError, - OSError, - AttributeError, - ) - ), - ) + @READ_RETRY async def get_node(self, node_id: str) -> dict[str, str] | None: """Get node by its label identifier, return only node properties @@ -521,20 +499,7 @@ class Neo4JStorage(BaseGraphStorage): ) raise - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.SessionExpired, - ConnectionResetError, - OSError, - AttributeError, - ) - ), - ) + @READ_RETRY async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: """ Retrieve multiple nodes in one query using UNWIND. @@ -571,20 +536,7 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Make sure to consume the result fully return nodes - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.SessionExpired, - ConnectionResetError, - OSError, - AttributeError, - ) - ), - ) + @READ_RETRY async def node_degree(self, node_id: str) -> int: """Get the degree (number of relationships) of a node with the given label. If multiple nodes have the same label, returns the degree of the first node. @@ -633,20 +585,7 @@ class Neo4JStorage(BaseGraphStorage): ) raise - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.SessionExpired, - ConnectionResetError, - OSError, - AttributeError, - ) - ), - ) + @READ_RETRY async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: """ Retrieve the degree for multiple nodes in a single query using UNWIND. @@ -705,6 +644,7 @@ class Neo4JStorage(BaseGraphStorage): degrees = int(src_degree) + int(trg_degree) return degrees + @READ_RETRY async def edge_degrees_batch( self, edge_pairs: list[tuple[str, str]] ) -> dict[tuple[str, str], int]: @@ -731,20 +671,7 @@ class Neo4JStorage(BaseGraphStorage): edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0) return edge_degrees - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.SessionExpired, - ConnectionResetError, - OSError, - AttributeError, - ) - ), - ) + @READ_RETRY async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: @@ -832,20 +759,7 @@ class Neo4JStorage(BaseGraphStorage): ) raise - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.SessionExpired, - ConnectionResetError, - OSError, - AttributeError, - ) - ), - ) + @READ_RETRY async def get_edges_batch( self, pairs: list[dict[str, str]] ) -> dict[tuple[str, str], dict]: @@ -896,20 +810,7 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() return edges_dict - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.SessionExpired, - ConnectionResetError, - OSError, - AttributeError, - ) - ), - ) + @READ_RETRY async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """Retrieves all edges (relationships) for a particular node identified by its label. @@ -977,20 +878,7 @@ class Neo4JStorage(BaseGraphStorage): ) raise - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.SessionExpired, - ConnectionResetError, - OSError, - AttributeError, - ) - ), - ) + @READ_RETRY async def get_nodes_edges_batch( self, node_ids: list[str] ) -> dict[str, list[tuple[str, str]]]: