Refactor: Extract retry decorator to reduce code duplication in Neo4J storage

• Define READ_RETRY_EXCEPTIONS constant
• Create reusable READ_RETRY decorator
• Replace 11 duplicate retry decorators
• Improve code maintainability
• Add missing retry to edge_degrees_batch
This commit is contained in:
yangdx 2025-11-25 01:35:21 +08:00
parent 7aaa51cda9
commit 8c4d7a00ad

View file

@ -44,6 +44,23 @@ config.read("config.ini", "utf-8")
logging.getLogger("neo4j").setLevel(logging.ERROR)
READ_RETRY_EXCEPTIONS = (
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
READ_RETRY = retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(READ_RETRY_EXCEPTIONS),
reraise=True,
)
@final
@dataclass
class Neo4JStorage(BaseGraphStorage):
@ -352,20 +369,7 @@ class Neo4JStorage(BaseGraphStorage):
# Neo4J handles persistence automatically
pass
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
@READ_RETRY
async def has_node(self, node_id: str) -> bool:
"""
Check if a node with the given label exists in the database
@ -399,20 +403,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are consumed even on error
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
@READ_RETRY
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""
Check if an edge exists between two nodes
@ -454,20 +445,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are consumed even on error
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
@READ_RETRY
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties
@ -521,20 +499,7 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
@READ_RETRY
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""
Retrieve multiple nodes in one query using UNWIND.
@ -571,20 +536,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Make sure to consume the result fully
return nodes
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
@READ_RETRY
async def node_degree(self, node_id: str) -> int:
"""Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label, returns the degree of the first node.
@ -633,20 +585,7 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
@READ_RETRY
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""
Retrieve the degree for multiple nodes in a single query using UNWIND.
@ -705,6 +644,7 @@ class Neo4JStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree)
return degrees
@READ_RETRY
async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
@ -731,20 +671,7 @@ class Neo4JStorage(BaseGraphStorage):
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
return edge_degrees
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
@READ_RETRY
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
@ -832,20 +759,7 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
@READ_RETRY
async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
@ -896,20 +810,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume()
return edges_dict
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
@READ_RETRY
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Retrieves all edges (relationships) for a particular node identified by its label.
@ -977,20 +878,7 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
@READ_RETRY
async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]: