diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 38320643..d3d6c4eb 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -44,23 +44,6 @@ config.read("config.ini", "utf-8") logging.getLogger("neo4j").setLevel(logging.ERROR) -READ_RETRY_EXCEPTIONS = ( - neo4jExceptions.ServiceUnavailable, - neo4jExceptions.TransientError, - neo4jExceptions.SessionExpired, - ConnectionResetError, - OSError, - AttributeError, -) - -READ_RETRY = retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type(READ_RETRY_EXCEPTIONS), - reraise=True, -) - - @final @dataclass class Neo4JStorage(BaseGraphStorage): @@ -369,7 +352,20 @@ class Neo4JStorage(BaseGraphStorage): # Neo4J handles persistence automatically pass - @READ_RETRY + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def has_node(self, node_id: str) -> bool: """ Check if a node with the given label exists in the database @@ -403,7 +399,20 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Ensure results are consumed even on error raise - @READ_RETRY + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: """ Check if an edge exists between two nodes @@ -445,7 +454,20 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Ensure results are consumed even on error raise - @READ_RETRY + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def get_node(self, node_id: str) -> dict[str, str] | None: """Get node by its label identifier, return only node properties @@ -499,7 +521,20 @@ class Neo4JStorage(BaseGraphStorage): ) raise - @READ_RETRY + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: """ Retrieve multiple nodes in one query using UNWIND. @@ -536,7 +571,20 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Make sure to consume the result fully return nodes - @READ_RETRY + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def node_degree(self, node_id: str) -> int: """Get the degree (number of relationships) of a node with the given label. If multiple nodes have the same label, returns the degree of the first node. @@ -585,7 +633,20 @@ class Neo4JStorage(BaseGraphStorage): ) raise - @READ_RETRY + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: """ Retrieve the degree for multiple nodes in a single query using UNWIND. @@ -644,7 +705,6 @@ class Neo4JStorage(BaseGraphStorage): degrees = int(src_degree) + int(trg_degree) return degrees - @READ_RETRY async def edge_degrees_batch( self, edge_pairs: list[tuple[str, str]] ) -> dict[tuple[str, str], int]: @@ -671,7 +731,20 @@ class Neo4JStorage(BaseGraphStorage): edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0) return edge_degrees - @READ_RETRY + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: @@ -759,7 +832,20 @@ class Neo4JStorage(BaseGraphStorage): ) raise - @READ_RETRY + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def get_edges_batch( self, pairs: list[dict[str, str]] ) -> dict[tuple[str, str], dict]: @@ -810,7 +896,20 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() return edges_dict - @READ_RETRY + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """Retrieves all edges (relationships) for a particular node identified by its label. @@ -878,7 +977,20 @@ class Neo4JStorage(BaseGraphStorage): ) raise - @READ_RETRY + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + ( + neo4jExceptions.ServiceUnavailable, + neo4jExceptions.TransientError, + neo4jExceptions.SessionExpired, + ConnectionResetError, + OSError, + AttributeError, + ) + ), + ) async def get_nodes_edges_batch( self, node_ids: list[str] ) -> dict[str, list[tuple[str, str]]]: