Refactor: Extract retry decorator to reduce code duplication in Neo4J storage

• Define READ_RETRY_EXCEPTIONS constant
• Create reusable READ_RETRY decorator
• Replace 11 duplicate retry decorators
• Improve code maintainability
• Add missing retry to edge_degrees_batch

(cherry picked from commit 8c4d7a00ad)
This commit is contained in:
yangdx 2025-11-25 01:35:21 +08:00 committed by Raphaël MANSUY
parent b28a701532
commit bd93f13012

View file

@ -16,8 +16,7 @@ import logging
from ..utils import logger from ..utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP from ..kg.shared_storage import get_data_init_lock
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
import pipmaster as pm import pipmaster as pm
if not pm.is_installed("neo4j"): if not pm.is_installed("neo4j"):
@ -45,6 +44,23 @@ config.read("config.ini", "utf-8")
logging.getLogger("neo4j").setLevel(logging.ERROR) logging.getLogger("neo4j").setLevel(logging.ERROR)
# Exception types that cause a READ_RETRY-decorated call to be attempted again.
# ConnectionResetError is already a subclass of OSError; it is listed
# separately only for explicitness.
# NOTE(review): AttributeError is normally a programming error rather than a
# transient fault - presumably it covers a driver/session attribute being
# unavailable during reconnects; confirm it should stay retryable.
READ_RETRY_EXCEPTIONS = (
    neo4jExceptions.ServiceUnavailable,
    neo4jExceptions.TransientError,
    neo4jExceptions.SessionExpired,
    ConnectionResetError,
    OSError,
    AttributeError,
)
# Shared retry decorator for the read-only graph queries in this module.
# Stops after 3 attempts, waits with exponential backoff bounded by min=4 and
# max=10 between attempts, and retries only on READ_RETRY_EXCEPTIONS.
# reraise=True is presumably tenacity's option to surface the original
# exception once attempts are exhausted (instead of a wrapping RetryError) -
# the import of `retry` is not visible here; confirm.
READ_RETRY = retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(READ_RETRY_EXCEPTIONS),
    reraise=True,
)
@final @final
@dataclass @dataclass
class Neo4JStorage(BaseGraphStorage): class Neo4JStorage(BaseGraphStorage):
@ -68,7 +84,7 @@ class Neo4JStorage(BaseGraphStorage):
def _get_workspace_label(self) -> str: def _get_workspace_label(self) -> str:
"""Return workspace label (guaranteed non-empty during initialization)""" """Return workspace label (guaranteed non-empty during initialization)"""
return self._get_composite_workspace() return self.workspace
def _is_chinese_text(self, text: str) -> bool: def _is_chinese_text(self, text: str) -> bool:
"""Check if text contains Chinese characters.""" """Check if text contains Chinese characters."""
@ -341,10 +357,9 @@ class Neo4JStorage(BaseGraphStorage):
async def finalize(self): async def finalize(self):
"""Close the Neo4j driver and release all resources""" """Close the Neo4j driver and release all resources"""
async with get_graph_db_lock(): if self._driver:
if self._driver: await self._driver.close()
await self._driver.close() self._driver = None
self._driver = None
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
"""Ensure driver is closed when context manager exits""" """Ensure driver is closed when context manager exits"""
@ -354,6 +369,7 @@ class Neo4JStorage(BaseGraphStorage):
# Neo4J handles persistence automatically # Neo4J handles persistence automatically
pass pass
@READ_RETRY
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
""" """
Check if a node with the given label exists in the database Check if a node with the given label exists in the database
@ -372,6 +388,7 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
result = None
try: try:
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists" query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
result = await session.run(query, entity_id=node_id) result = await session.run(query, entity_id=node_id)
@ -382,9 +399,11 @@ class Neo4JStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}" f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
) )
await result.consume() # Ensure results are consumed even on error if result is not None:
await result.consume() # Ensure results are consumed even on error
raise raise
@READ_RETRY
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
""" """
Check if an edge exists between two nodes Check if an edge exists between two nodes
@ -404,6 +423,7 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
result = None
try: try:
query = ( query = (
f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) " f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
@ -421,9 +441,11 @@ class Neo4JStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
) )
await result.consume() # Ensure results are consumed even on error if result is not None:
await result.consume() # Ensure results are consumed even on error
raise raise
@READ_RETRY
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties """Get node by its label identifier, return only node properties
@ -477,6 +499,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
@READ_RETRY
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
""" """
Retrieve multiple nodes in one query using UNWIND. Retrieve multiple nodes in one query using UNWIND.
@ -513,6 +536,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Make sure to consume the result fully await result.consume() # Make sure to consume the result fully
return nodes return nodes
@READ_RETRY
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
"""Get the degree (number of relationships) of a node with the given label. """Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label, returns the degree of the first node. If multiple nodes have the same label, returns the degree of the first node.
@ -561,6 +585,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
@READ_RETRY
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
""" """
Retrieve the degree for multiple nodes in a single query using UNWIND. Retrieve the degree for multiple nodes in a single query using UNWIND.
@ -619,6 +644,7 @@ class Neo4JStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree) degrees = int(src_degree) + int(trg_degree)
return degrees return degrees
@READ_RETRY
async def edge_degrees_batch( async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]] self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]: ) -> dict[tuple[str, str], int]:
@ -645,6 +671,7 @@ class Neo4JStorage(BaseGraphStorage):
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0) edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
return edge_degrees return edge_degrees
@READ_RETRY
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
@ -732,6 +759,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
@READ_RETRY
async def get_edges_batch( async def get_edges_batch(
self, pairs: list[dict[str, str]] self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]: ) -> dict[tuple[str, str], dict]:
@ -782,6 +810,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() await result.consume()
return edges_dict return edges_dict
@READ_RETRY
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Retrieves all edges (relationships) for a particular node identified by its label. """Retrieves all edges (relationships) for a particular node identified by its label.
@ -800,6 +829,7 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
results = None
try: try:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
@ -837,7 +867,10 @@ class Neo4JStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}" f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
) )
await results.consume() # Ensure results are consumed even on error if results is not None:
await (
results.consume()
) # Ensure results are consumed even on error
raise raise
except Exception as e: except Exception as e:
logger.error( logger.error(
@ -845,6 +878,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
@READ_RETRY
async def get_nodes_edges_batch( async def get_nodes_edges_batch(
self, node_ids: list[str] self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]: ) -> dict[str, list[tuple[str, str]]]:
@ -904,49 +938,6 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are fully consumed await result.consume() # Ensure results are fully consumed
return edges_dict return edges_dict
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Fetch the distinct workspace nodes that reference any of the given chunks.

    A node matches when its ``source_id`` property is non-null and, once split
    on ``GRAPH_FIELD_SEP``, contains one of ``chunk_ids``.

    Args:
        chunk_ids: Chunk identifiers to look for in each node's ``source_id``.

    Returns:
        One dict of node properties per matching node, each with an extra
        ``"id"`` key holding the node's ``entity_id`` (``None`` if absent).
    """
    label = self._get_workspace_label()
    # The query text is kept flush-left so the string sent to the server is
    # exactly the one shown in the original block.
    cypher = f"""
UNWIND $chunk_ids AS chunk_id
MATCH (n:`{label}`)
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
RETURN DISTINCT n
"""
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        cursor = await session.run(cypher, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
        matched = []
        async for row in cursor:
            props = dict(row["n"])
            # Expose entity_id under "id" as well, for easier access by callers.
            props["id"] = props.get("entity_id")
            matched.append(props)
        # Drain the result before the session closes.
        await cursor.consume()
        return matched
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
    """Fetch workspace relationships that reference any of the given chunks.

    A relationship matches when its ``source_id`` property is non-null and,
    once split on ``GRAPH_FIELD_SEP``, contains one of ``chunk_ids``. Both
    endpoints must carry the workspace label.

    Args:
        chunk_ids: Chunk identifiers to look for in each edge's ``source_id``.

    Returns:
        One dict per returned row: the relationship's properties plus
        ``"source"`` and ``"target"`` keys holding the endpoint ``entity_id``s.
    """
    label = self._get_workspace_label()
    # NOTE(review): the pattern is undirected (`-[r]-`), so a relationship may
    # be returned once per orientation - confirm callers expect that.
    # The query text is kept flush-left so the string sent to the server is
    # exactly the one shown in the original block.
    cypher = f"""
UNWIND $chunk_ids AS chunk_id
MATCH (a:`{label}`)-[r]-(b:`{label}`)
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
"""
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        cursor = await session.run(cypher, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
        collected = []
        async for row in cursor:
            props = row["properties"]
            # Attach the endpoints to the property map itself.
            props["source"] = row["source"]
            props["target"] = row["target"]
            collected.append(props)
        # Drain the result before the session closes.
        await cursor.consume()
        return collected
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
@ -1636,6 +1627,7 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
result = None
try: try:
query = f""" query = f"""
MATCH (n:`{workspace_label}`) MATCH (n:`{workspace_label}`)
@ -1660,7 +1652,8 @@ class Neo4JStorage(BaseGraphStorage):
logger.error( logger.error(
f"[{self.workspace}] Error getting popular labels: {str(e)}" f"[{self.workspace}] Error getting popular labels: {str(e)}"
) )
await result.consume() if result is not None:
await result.consume()
raise raise
async def search_labels(self, query: str, limit: int = 50) -> list[str]: async def search_labels(self, query: str, limit: int = 50) -> list[str]:
@ -1807,24 +1800,23 @@ class Neo4JStorage(BaseGraphStorage):
- On success: {"status": "success", "message": "workspace data dropped"} - On success: {"status": "success", "message": "workspace data dropped"}
- On failure: {"status": "error", "message": "<error details>"} - On failure: {"status": "error", "message": "<error details>"}
""" """
async with get_graph_db_lock(): workspace_label = self._get_workspace_label()
workspace_label = self._get_workspace_label() try:
try: async with self._driver.session(database=self._DATABASE) as session:
async with self._driver.session(database=self._DATABASE) as session: # Delete all nodes and relationships in current workspace only
# Delete all nodes and relationships in current workspace only query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" result = await session.run(query)
result = await session.run(query) await result.consume() # Ensure result is fully consumed
await result.consume() # Ensure result is fully consumed
# logger.debug( # logger.debug(
# f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}" # f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
# ) # )
return { return {
"status": "success", "status": "success",
"message": f"workspace '{workspace_label}' data dropped", "message": f"workspace '{workspace_label}' data dropped",
} }
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}