Refactor: Extract retry decorator to reduce code duplication in Neo4J storage
• Define READ_RETRY_EXCEPTIONS constant • Create reusable READ_RETRY decorator • Replace 11 duplicate retry decorators • Improve code maintainability • Add missing retry to edge_degrees_batch
This commit is contained in:
parent
7aaa51cda9
commit
8c4d7a00ad
1 changed files with 28 additions and 140 deletions
|
|
@ -44,6 +44,23 @@ config.read("config.ini", "utf-8")
|
||||||
logging.getLogger("neo4j").setLevel(logging.ERROR)
|
logging.getLogger("neo4j").setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
READ_RETRY_EXCEPTIONS = (
|
||||||
|
neo4jExceptions.ServiceUnavailable,
|
||||||
|
neo4jExceptions.TransientError,
|
||||||
|
neo4jExceptions.SessionExpired,
|
||||||
|
ConnectionResetError,
|
||||||
|
OSError,
|
||||||
|
AttributeError,
|
||||||
|
)
|
||||||
|
|
||||||
|
READ_RETRY = retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
|
retry=retry_if_exception_type(READ_RETRY_EXCEPTIONS),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class Neo4JStorage(BaseGraphStorage):
|
class Neo4JStorage(BaseGraphStorage):
|
||||||
|
|
@ -352,20 +369,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
# Neo4J handles persistence automatically
|
# Neo4J handles persistence automatically
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@retry(
|
@READ_RETRY
|
||||||
stop=stop_after_attempt(3),
|
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
|
||||||
retry=retry_if_exception_type(
|
|
||||||
(
|
|
||||||
neo4jExceptions.ServiceUnavailable,
|
|
||||||
neo4jExceptions.TransientError,
|
|
||||||
neo4jExceptions.SessionExpired,
|
|
||||||
ConnectionResetError,
|
|
||||||
OSError,
|
|
||||||
AttributeError,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a node with the given label exists in the database
|
Check if a node with the given label exists in the database
|
||||||
|
|
@ -399,20 +403,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
await result.consume() # Ensure results are consumed even on error
|
await result.consume() # Ensure results are consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@retry(
|
@READ_RETRY
|
||||||
stop=stop_after_attempt(3),
|
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
|
||||||
retry=retry_if_exception_type(
|
|
||||||
(
|
|
||||||
neo4jExceptions.ServiceUnavailable,
|
|
||||||
neo4jExceptions.TransientError,
|
|
||||||
neo4jExceptions.SessionExpired,
|
|
||||||
ConnectionResetError,
|
|
||||||
OSError,
|
|
||||||
AttributeError,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if an edge exists between two nodes
|
Check if an edge exists between two nodes
|
||||||
|
|
@ -454,20 +445,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
await result.consume() # Ensure results are consumed even on error
|
await result.consume() # Ensure results are consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@retry(
|
@READ_RETRY
|
||||||
stop=stop_after_attempt(3),
|
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
|
||||||
retry=retry_if_exception_type(
|
|
||||||
(
|
|
||||||
neo4jExceptions.ServiceUnavailable,
|
|
||||||
neo4jExceptions.TransientError,
|
|
||||||
neo4jExceptions.SessionExpired,
|
|
||||||
ConnectionResetError,
|
|
||||||
OSError,
|
|
||||||
AttributeError,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
"""Get node by its label identifier, return only node properties
|
"""Get node by its label identifier, return only node properties
|
||||||
|
|
||||||
|
|
@ -521,20 +499,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@retry(
|
@READ_RETRY
|
||||||
stop=stop_after_attempt(3),
|
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
|
||||||
retry=retry_if_exception_type(
|
|
||||||
(
|
|
||||||
neo4jExceptions.ServiceUnavailable,
|
|
||||||
neo4jExceptions.TransientError,
|
|
||||||
neo4jExceptions.SessionExpired,
|
|
||||||
ConnectionResetError,
|
|
||||||
OSError,
|
|
||||||
AttributeError,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||||
"""
|
"""
|
||||||
Retrieve multiple nodes in one query using UNWIND.
|
Retrieve multiple nodes in one query using UNWIND.
|
||||||
|
|
@ -571,20 +536,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
await result.consume() # Make sure to consume the result fully
|
await result.consume() # Make sure to consume the result fully
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
@retry(
|
@READ_RETRY
|
||||||
stop=stop_after_attempt(3),
|
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
|
||||||
retry=retry_if_exception_type(
|
|
||||||
(
|
|
||||||
neo4jExceptions.ServiceUnavailable,
|
|
||||||
neo4jExceptions.TransientError,
|
|
||||||
neo4jExceptions.SessionExpired,
|
|
||||||
ConnectionResetError,
|
|
||||||
OSError,
|
|
||||||
AttributeError,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
"""Get the degree (number of relationships) of a node with the given label.
|
"""Get the degree (number of relationships) of a node with the given label.
|
||||||
If multiple nodes have the same label, returns the degree of the first node.
|
If multiple nodes have the same label, returns the degree of the first node.
|
||||||
|
|
@ -633,20 +585,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@retry(
|
@READ_RETRY
|
||||||
stop=stop_after_attempt(3),
|
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
|
||||||
retry=retry_if_exception_type(
|
|
||||||
(
|
|
||||||
neo4jExceptions.ServiceUnavailable,
|
|
||||||
neo4jExceptions.TransientError,
|
|
||||||
neo4jExceptions.SessionExpired,
|
|
||||||
ConnectionResetError,
|
|
||||||
OSError,
|
|
||||||
AttributeError,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
||||||
"""
|
"""
|
||||||
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
||||||
|
|
@ -705,6 +644,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
degrees = int(src_degree) + int(trg_degree)
|
degrees = int(src_degree) + int(trg_degree)
|
||||||
return degrees
|
return degrees
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def edge_degrees_batch(
|
async def edge_degrees_batch(
|
||||||
self, edge_pairs: list[tuple[str, str]]
|
self, edge_pairs: list[tuple[str, str]]
|
||||||
) -> dict[tuple[str, str], int]:
|
) -> dict[tuple[str, str], int]:
|
||||||
|
|
@ -731,20 +671,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
|
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
|
||||||
return edge_degrees
|
return edge_degrees
|
||||||
|
|
||||||
@retry(
|
@READ_RETRY
|
||||||
stop=stop_after_attempt(3),
|
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
|
||||||
retry=retry_if_exception_type(
|
|
||||||
(
|
|
||||||
neo4jExceptions.ServiceUnavailable,
|
|
||||||
neo4jExceptions.TransientError,
|
|
||||||
neo4jExceptions.SessionExpired,
|
|
||||||
ConnectionResetError,
|
|
||||||
OSError,
|
|
||||||
AttributeError,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> dict[str, str] | None:
|
) -> dict[str, str] | None:
|
||||||
|
|
@ -832,20 +759,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@retry(
|
@READ_RETRY
|
||||||
stop=stop_after_attempt(3),
|
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
|
||||||
retry=retry_if_exception_type(
|
|
||||||
(
|
|
||||||
neo4jExceptions.ServiceUnavailable,
|
|
||||||
neo4jExceptions.TransientError,
|
|
||||||
neo4jExceptions.SessionExpired,
|
|
||||||
ConnectionResetError,
|
|
||||||
OSError,
|
|
||||||
AttributeError,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
async def get_edges_batch(
|
async def get_edges_batch(
|
||||||
self, pairs: list[dict[str, str]]
|
self, pairs: list[dict[str, str]]
|
||||||
) -> dict[tuple[str, str], dict]:
|
) -> dict[tuple[str, str], dict]:
|
||||||
|
|
@ -896,20 +810,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
await result.consume()
|
await result.consume()
|
||||||
return edges_dict
|
return edges_dict
|
||||||
|
|
||||||
@retry(
|
@READ_RETRY
|
||||||
stop=stop_after_attempt(3),
|
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
|
||||||
retry=retry_if_exception_type(
|
|
||||||
(
|
|
||||||
neo4jExceptions.ServiceUnavailable,
|
|
||||||
neo4jExceptions.TransientError,
|
|
||||||
neo4jExceptions.SessionExpired,
|
|
||||||
ConnectionResetError,
|
|
||||||
OSError,
|
|
||||||
AttributeError,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
"""Retrieves all edges (relationships) for a particular node identified by its label.
|
"""Retrieves all edges (relationships) for a particular node identified by its label.
|
||||||
|
|
||||||
|
|
@ -977,20 +878,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@retry(
|
@READ_RETRY
|
||||||
stop=stop_after_attempt(3),
|
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
|
||||||
retry=retry_if_exception_type(
|
|
||||||
(
|
|
||||||
neo4jExceptions.ServiceUnavailable,
|
|
||||||
neo4jExceptions.TransientError,
|
|
||||||
neo4jExceptions.SessionExpired,
|
|
||||||
ConnectionResetError,
|
|
||||||
OSError,
|
|
||||||
AttributeError,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
async def get_nodes_edges_batch(
|
async def get_nodes_edges_batch(
|
||||||
self, node_ids: list[str]
|
self, node_ids: list[str]
|
||||||
) -> dict[str, list[tuple[str, str]]]:
|
) -> dict[str, list[tuple[str, str]]]:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue