Refactor: Extract retry decorator to reduce code duplication in Neo4J storage

• Define READ_RETRY_EXCEPTIONS constant
• Create reusable READ_RETRY decorator
• Replace 11 duplicate retry decorators
• Improve code maintainability
• Add missing retry to edge_degrees_batch
This commit is contained in:
yangdx 2025-11-25 01:35:21 +08:00
parent 7aaa51cda9
commit 8c4d7a00ad

View file

@ -44,6 +44,23 @@ config.read("config.ini", "utf-8")
logging.getLogger("neo4j").setLevel(logging.ERROR) logging.getLogger("neo4j").setLevel(logging.ERROR)
READ_RETRY_EXCEPTIONS = (
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
READ_RETRY = retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(READ_RETRY_EXCEPTIONS),
reraise=True,
)
@final @final
@dataclass @dataclass
class Neo4JStorage(BaseGraphStorage): class Neo4JStorage(BaseGraphStorage):
@ -352,20 +369,7 @@ class Neo4JStorage(BaseGraphStorage):
# Neo4J handles persistence automatically # Neo4J handles persistence automatically
pass pass
@retry( @READ_RETRY
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
""" """
Check if a node with the given label exists in the database Check if a node with the given label exists in the database
@ -399,20 +403,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are consumed even on error await result.consume() # Ensure results are consumed even on error
raise raise
@retry( @READ_RETRY
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
""" """
Check if an edge exists between two nodes Check if an edge exists between two nodes
@ -454,20 +445,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are consumed even on error await result.consume() # Ensure results are consumed even on error
raise raise
@retry( @READ_RETRY
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties """Get node by its label identifier, return only node properties
@ -521,20 +499,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
@retry( @READ_RETRY
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
""" """
Retrieve multiple nodes in one query using UNWIND. Retrieve multiple nodes in one query using UNWIND.
@ -571,20 +536,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Make sure to consume the result fully await result.consume() # Make sure to consume the result fully
return nodes return nodes
@retry( @READ_RETRY
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
"""Get the degree (number of relationships) of a node with the given label. """Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label, returns the degree of the first node. If multiple nodes have the same label, returns the degree of the first node.
@ -633,20 +585,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
@retry( @READ_RETRY
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
""" """
Retrieve the degree for multiple nodes in a single query using UNWIND. Retrieve the degree for multiple nodes in a single query using UNWIND.
@ -705,6 +644,7 @@ class Neo4JStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree) degrees = int(src_degree) + int(trg_degree)
return degrees return degrees
@READ_RETRY
async def edge_degrees_batch( async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]] self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]: ) -> dict[tuple[str, str], int]:
@ -731,20 +671,7 @@ class Neo4JStorage(BaseGraphStorage):
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0) edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
return edge_degrees return edge_degrees
@retry( @READ_RETRY
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
@ -832,20 +759,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
@retry( @READ_RETRY
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_edges_batch( async def get_edges_batch(
self, pairs: list[dict[str, str]] self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]: ) -> dict[tuple[str, str], dict]:
@ -896,20 +810,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() await result.consume()
return edges_dict return edges_dict
@retry( @READ_RETRY
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Retrieves all edges (relationships) for a particular node identified by its label. """Retrieves all edges (relationships) for a particular node identified by its label.
@ -977,20 +878,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
@retry( @READ_RETRY
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
),
)
async def get_nodes_edges_batch( async def get_nodes_edges_batch(
self, node_ids: list[str] self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]: ) -> dict[str, list[tuple[str, str]]]: