Refactor: Extract retry decorator to reduce code duplication in Neo4J storage
• Define READ_RETRY_EXCEPTIONS constant
• Create reusable READ_RETRY decorator
• Replace 11 duplicate retry decorators
• Improve code maintainability
• Add missing retry to edge_degrees_batch
(cherry picked from commit 8c4d7a00ad)
This commit is contained in:
parent
b28a701532
commit
bd93f13012
1 changed files with 66 additions and 74 deletions
|
|
@ -16,8 +16,7 @@ import logging
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
from ..base import BaseGraphStorage
|
from ..base import BaseGraphStorage
|
||||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
from ..constants import GRAPH_FIELD_SEP
|
from ..kg.shared_storage import get_data_init_lock
|
||||||
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
|
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
if not pm.is_installed("neo4j"):
|
if not pm.is_installed("neo4j"):
|
||||||
|
|
@ -45,6 +44,23 @@ config.read("config.ini", "utf-8")
|
||||||
logging.getLogger("neo4j").setLevel(logging.ERROR)
|
logging.getLogger("neo4j").setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
READ_RETRY_EXCEPTIONS = (
|
||||||
|
neo4jExceptions.ServiceUnavailable,
|
||||||
|
neo4jExceptions.TransientError,
|
||||||
|
neo4jExceptions.SessionExpired,
|
||||||
|
ConnectionResetError,
|
||||||
|
OSError,
|
||||||
|
AttributeError,
|
||||||
|
)
|
||||||
|
|
||||||
|
READ_RETRY = retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
|
retry=retry_if_exception_type(READ_RETRY_EXCEPTIONS),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class Neo4JStorage(BaseGraphStorage):
|
class Neo4JStorage(BaseGraphStorage):
|
||||||
|
|
@ -68,7 +84,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
|
|
||||||
def _get_workspace_label(self) -> str:
|
def _get_workspace_label(self) -> str:
|
||||||
"""Return workspace label (guaranteed non-empty during initialization)"""
|
"""Return workspace label (guaranteed non-empty during initialization)"""
|
||||||
return self._get_composite_workspace()
|
return self.workspace
|
||||||
|
|
||||||
def _is_chinese_text(self, text: str) -> bool:
|
def _is_chinese_text(self, text: str) -> bool:
|
||||||
"""Check if text contains Chinese characters."""
|
"""Check if text contains Chinese characters."""
|
||||||
|
|
@ -341,10 +357,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
"""Close the Neo4j driver and release all resources"""
|
"""Close the Neo4j driver and release all resources"""
|
||||||
async with get_graph_db_lock():
|
if self._driver:
|
||||||
if self._driver:
|
await self._driver.close()
|
||||||
await self._driver.close()
|
self._driver = None
|
||||||
self._driver = None
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
"""Ensure driver is closed when context manager exits"""
|
"""Ensure driver is closed when context manager exits"""
|
||||||
|
|
@ -354,6 +369,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
# Neo4J handles persistence automatically
|
# Neo4J handles persistence automatically
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a node with the given label exists in the database
|
Check if a node with the given label exists in the database
|
||||||
|
|
@ -372,6 +388,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
|
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
|
||||||
result = await session.run(query, entity_id=node_id)
|
result = await session.run(query, entity_id=node_id)
|
||||||
|
|
@ -382,9 +399,11 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
|
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
await result.consume() # Ensure results are consumed even on error
|
if result is not None:
|
||||||
|
await result.consume() # Ensure results are consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if an edge exists between two nodes
|
Check if an edge exists between two nodes
|
||||||
|
|
@ -404,6 +423,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
query = (
|
query = (
|
||||||
f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
|
f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
|
||||||
|
|
@ -421,9 +441,11 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
|
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
await result.consume() # Ensure results are consumed even on error
|
if result is not None:
|
||||||
|
await result.consume() # Ensure results are consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
"""Get node by its label identifier, return only node properties
|
"""Get node by its label identifier, return only node properties
|
||||||
|
|
||||||
|
|
@ -477,6 +499,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||||
"""
|
"""
|
||||||
Retrieve multiple nodes in one query using UNWIND.
|
Retrieve multiple nodes in one query using UNWIND.
|
||||||
|
|
@ -513,6 +536,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
await result.consume() # Make sure to consume the result fully
|
await result.consume() # Make sure to consume the result fully
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
"""Get the degree (number of relationships) of a node with the given label.
|
"""Get the degree (number of relationships) of a node with the given label.
|
||||||
If multiple nodes have the same label, returns the degree of the first node.
|
If multiple nodes have the same label, returns the degree of the first node.
|
||||||
|
|
@ -561,6 +585,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
||||||
"""
|
"""
|
||||||
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
||||||
|
|
@ -619,6 +644,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
degrees = int(src_degree) + int(trg_degree)
|
degrees = int(src_degree) + int(trg_degree)
|
||||||
return degrees
|
return degrees
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def edge_degrees_batch(
|
async def edge_degrees_batch(
|
||||||
self, edge_pairs: list[tuple[str, str]]
|
self, edge_pairs: list[tuple[str, str]]
|
||||||
) -> dict[tuple[str, str], int]:
|
) -> dict[tuple[str, str], int]:
|
||||||
|
|
@ -645,6 +671,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
|
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
|
||||||
return edge_degrees
|
return edge_degrees
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> dict[str, str] | None:
|
) -> dict[str, str] | None:
|
||||||
|
|
@ -732,6 +759,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def get_edges_batch(
|
async def get_edges_batch(
|
||||||
self, pairs: list[dict[str, str]]
|
self, pairs: list[dict[str, str]]
|
||||||
) -> dict[tuple[str, str], dict]:
|
) -> dict[tuple[str, str], dict]:
|
||||||
|
|
@ -782,6 +810,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
await result.consume()
|
await result.consume()
|
||||||
return edges_dict
|
return edges_dict
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
"""Retrieves all edges (relationships) for a particular node identified by its label.
|
"""Retrieves all edges (relationships) for a particular node identified by its label.
|
||||||
|
|
||||||
|
|
@ -800,6 +829,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
results = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
|
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
|
||||||
|
|
@ -837,7 +867,10 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
|
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
await results.consume() # Ensure results are consumed even on error
|
if results is not None:
|
||||||
|
await (
|
||||||
|
results.consume()
|
||||||
|
) # Ensure results are consumed even on error
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
@ -845,6 +878,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def get_nodes_edges_batch(
|
async def get_nodes_edges_batch(
|
||||||
self, node_ids: list[str]
|
self, node_ids: list[str]
|
||||||
) -> dict[str, list[tuple[str, str]]]:
|
) -> dict[str, list[tuple[str, str]]]:
|
||||||
|
|
@ -904,49 +938,6 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
await result.consume() # Ensure results are fully consumed
|
await result.consume() # Ensure results are fully consumed
|
||||||
return edges_dict
|
return edges_dict
|
||||||
|
|
||||||
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
|
||||||
workspace_label = self._get_workspace_label()
|
|
||||||
async with self._driver.session(
|
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
|
||||||
) as session:
|
|
||||||
query = f"""
|
|
||||||
UNWIND $chunk_ids AS chunk_id
|
|
||||||
MATCH (n:`{workspace_label}`)
|
|
||||||
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
|
|
||||||
RETURN DISTINCT n
|
|
||||||
"""
|
|
||||||
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
|
||||||
nodes = []
|
|
||||||
async for record in result:
|
|
||||||
node = record["n"]
|
|
||||||
node_dict = dict(node)
|
|
||||||
# Add node id (entity_id) to the dictionary for easier access
|
|
||||||
node_dict["id"] = node_dict.get("entity_id")
|
|
||||||
nodes.append(node_dict)
|
|
||||||
await result.consume()
|
|
||||||
return nodes
|
|
||||||
|
|
||||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
|
||||||
workspace_label = self._get_workspace_label()
|
|
||||||
async with self._driver.session(
|
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
|
||||||
) as session:
|
|
||||||
query = f"""
|
|
||||||
UNWIND $chunk_ids AS chunk_id
|
|
||||||
MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
|
|
||||||
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
|
|
||||||
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
|
|
||||||
"""
|
|
||||||
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
|
||||||
edges = []
|
|
||||||
async for record in result:
|
|
||||||
edge_properties = record["properties"]
|
|
||||||
edge_properties["source"] = record["source"]
|
|
||||||
edge_properties["target"] = record["target"]
|
|
||||||
edges.append(edge_properties)
|
|
||||||
await result.consume()
|
|
||||||
return edges
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
|
|
@ -1636,6 +1627,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n:`{workspace_label}`)
|
MATCH (n:`{workspace_label}`)
|
||||||
|
|
@ -1660,7 +1652,8 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error getting popular labels: {str(e)}"
|
f"[{self.workspace}] Error getting popular labels: {str(e)}"
|
||||||
)
|
)
|
||||||
await result.consume()
|
if result is not None:
|
||||||
|
await result.consume()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
|
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
|
||||||
|
|
@ -1807,24 +1800,23 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
- On success: {"status": "success", "message": "workspace data dropped"}
|
- On success: {"status": "success", "message": "workspace data dropped"}
|
||||||
- On failure: {"status": "error", "message": "<error details>"}
|
- On failure: {"status": "error", "message": "<error details>"}
|
||||||
"""
|
"""
|
||||||
async with get_graph_db_lock():
|
workspace_label = self._get_workspace_label()
|
||||||
workspace_label = self._get_workspace_label()
|
try:
|
||||||
try:
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
async with self._driver.session(database=self._DATABASE) as session:
|
# Delete all nodes and relationships in current workspace only
|
||||||
# Delete all nodes and relationships in current workspace only
|
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
|
||||||
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
|
result = await session.run(query)
|
||||||
result = await session.run(query)
|
await result.consume() # Ensure result is fully consumed
|
||||||
await result.consume() # Ensure result is fully consumed
|
|
||||||
|
|
||||||
# logger.debug(
|
# logger.debug(
|
||||||
# f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
|
# f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
|
||||||
# )
|
# )
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"message": f"workspace '{workspace_label}' data dropped",
|
"message": f"workspace '{workspace_label}' data dropped",
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
|
f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
|
||||||
)
|
)
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue