Add retry decorators to Neo4j read operations for resilience

(cherry picked from commit 7aaa51cda9)
This commit is contained in:
yangdx 2025-11-24 22:28:15 +08:00 committed by Raphaël MANSUY
parent 8bb23b9bfa
commit 7ce3680ca5

View file

@ -44,23 +44,6 @@ config.read("config.ini", "utf-8")
logging.getLogger("neo4j").setLevel(logging.ERROR)
READ_RETRY_EXCEPTIONS = (
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
READ_RETRY = retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(READ_RETRY_EXCEPTIONS),
reraise=True,
)
@final
@dataclass
class Neo4JStorage(BaseGraphStorage):
@ -369,7 +352,20 @@ class Neo4JStorage(BaseGraphStorage):
# Neo4J handles persistence automatically
pass
@READ_RETRY
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def has_node(self, node_id: str) -> bool:
"""
Check if a node with the given label exists in the database
@ -403,7 +399,20 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are consumed even on error
raise
@READ_RETRY
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""
Check if an edge exists between two nodes
@ -445,7 +454,20 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are consumed even on error
raise
@READ_RETRY
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties
@ -499,7 +521,20 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
@READ_RETRY
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""
Retrieve multiple nodes in one query using UNWIND.
@ -536,7 +571,20 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Make sure to consume the result fully
return nodes
@READ_RETRY
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def node_degree(self, node_id: str) -> int:
"""Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label, returns the degree of the first node.
@ -585,7 +633,20 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
@READ_RETRY
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""
Retrieve the degree for multiple nodes in a single query using UNWIND.
@ -644,7 +705,6 @@ class Neo4JStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree)
return degrees
@READ_RETRY
async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
@ -671,7 +731,20 @@ class Neo4JStorage(BaseGraphStorage):
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
return edge_degrees
@READ_RETRY
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
@ -759,7 +832,20 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
@READ_RETRY
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
@ -810,7 +896,20 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume()
return edges_dict
@READ_RETRY
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Retrieves all edges (relationships) for a particular node identified by its label.
@ -878,7 +977,20 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
@READ_RETRY
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]: