Replace tenacity retries with manual Memgraph transaction retries

- Implement manual retry logic
- Add exponential backoff with jitter
- Improve error handling for transient errors
This commit is contained in:
yangdx 2025-07-19 11:31:21 +08:00
parent 99e58ac752
commit 9f5399c2f1
2 changed files with 144 additions and 86 deletions

View file

@@ -1,4 +1,6 @@
import os import os
import asyncio
import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import final from typing import final
import configparser import configparser
@@ -11,20 +13,11 @@ import pipmaster as pm
if not pm.is_installed("neo4j"): if not pm.is_installed("neo4j"):
pm.install("neo4j") pm.install("neo4j")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
from neo4j import ( from neo4j import (
AsyncGraphDatabase, AsyncGraphDatabase,
AsyncManagedTransaction, AsyncManagedTransaction,
) )
from neo4j.exceptions import TransientError from neo4j.exceptions import TransientError, ResultFailedError
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -111,25 +104,6 @@ class MemgraphStorage(BaseGraphStorage):
# Memgraph handles persistence automatically # Memgraph handles persistence automatically
pass pass
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type(TransientError),
reraise=True,
)
async def _execute_write_with_retry(self, session, operation_func):
"""
Execute a write operation with retry logic for Memgraph transient errors.
Args:
session: Neo4j session
operation_func: Async function that takes a transaction and executes the operation
Raises:
TransientError: If all retry attempts fail
"""
return await session.execute_write(operation_func)
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
""" """
Check if a node exists in the graph. Check if a node exists in the graph.
@@ -463,7 +437,7 @@ class MemgraphStorage(BaseGraphStorage):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
""" """
Upsert a node in the Memgraph database with retry logic for transient errors. Upsert a node in the Memgraph database with manual transaction-level retry logic for transient errors.
Args: Args:
node_id: The unique identifier for the node (used as label) node_id: The unique identifier for the node (used as label)
@@ -480,36 +454,77 @@ class MemgraphStorage(BaseGraphStorage):
"Memgraph: node properties must contain an 'entity_id' field" "Memgraph: node properties must contain an 'entity_id' field"
) )
try: # Manual transaction-level retry following official Memgraph documentation
async with self._driver.session(database=self._DATABASE) as session: max_retries = 100
workspace_label = self._get_workspace_label() initial_wait_time = 0.2
backoff_factor = 1.1
jitter_factor = 0.1
async def execute_upsert(tx: AsyncManagedTransaction): for attempt in range(max_retries):
query = f""" try:
MERGE (n:`{workspace_label}` {{entity_id: $entity_id}}) logger.debug(
SET n += $properties f"Attempting node upsert, attempt {attempt + 1}/{max_retries}"
SET n:`{entity_type}` )
""" async with self._driver.session(database=self._DATABASE) as session:
result = await tx.run( workspace_label = self._get_workspace_label()
query, entity_id=node_id, properties=properties
)
await result.consume() # Ensure result is fully consumed
await self._execute_write_with_retry(session, execute_upsert) async def execute_upsert(tx: AsyncManagedTransaction):
except TransientError as e: query = f"""
logger.error( MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
f"Memgraph transient error during node upsert after retries: {str(e)}" SET n += $properties
) SET n:`{entity_type}`
raise """
except Exception as e: result = await tx.run(
logger.error(f"Error during node upsert: {str(e)}") query, entity_id=node_id, properties=properties
raise )
await result.consume() # Ensure result is fully consumed
await session.execute_write(execute_upsert)
break # Success - exit retry loop
except (TransientError, ResultFailedError) as e:
# Check if the root cause is a TransientError
root_cause = e
while hasattr(root_cause, "__cause__") and root_cause.__cause__:
root_cause = root_cause.__cause__
# Check if this is a transient error that should be retried
is_transient = (
isinstance(root_cause, TransientError)
or isinstance(e, TransientError)
or "TransientError" in str(e)
or "Cannot resolve conflicting transactions" in str(e)
)
if is_transient:
if attempt < max_retries - 1:
# Calculate wait time with exponential backoff and jitter
jitter = random.uniform(0, jitter_factor) * initial_wait_time
wait_time = (
initial_wait_time * (backoff_factor**attempt) + jitter
)
logger.warning(
f"Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
)
await asyncio.sleep(wait_time)
else:
logger.error(
f"Memgraph transient error during node upsert after {max_retries} retries: {str(e)}"
)
raise
else:
# Non-transient error, don't retry
logger.error(f"Non-transient error during node upsert: {str(e)}")
raise
except Exception as e:
logger.error(f"Unexpected error during node upsert: {str(e)}")
raise
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None: ) -> None:
""" """
Upsert an edge and its properties between two nodes identified by their labels with retry logic for transient errors. Upsert an edge and its properties between two nodes identified by their labels with manual transaction-level retry logic for transient errors.
Ensures both source and target nodes exist and are unique before creating the edge. Ensures both source and target nodes exist and are unique before creating the edge.
Uses entity_id property to uniquely identify nodes. Uses entity_id property to uniquely identify nodes.
@@ -525,40 +540,83 @@ class MemgraphStorage(BaseGraphStorage):
raise RuntimeError( raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first." "Memgraph driver is not initialized. Call 'await initialize()' first."
) )
try:
edge_properties = edge_data
async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction): edge_properties = edge_data
workspace_label = self._get_workspace_label()
query = f"""
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
WITH source
MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
MERGE (source)-[r:DIRECTED]-(target)
SET r += $properties
RETURN r, source, target
"""
result = await tx.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
properties=edge_properties,
)
try:
await result.fetch(2)
finally:
await result.consume() # Ensure result is consumed
await self._execute_write_with_retry(session, execute_upsert) # Manual transaction-level retry following official Memgraph documentation
except TransientError as e: max_retries = 100
logger.error( initial_wait_time = 0.2
f"Memgraph transient error during edge upsert after retries: {str(e)}" backoff_factor = 1.1
) jitter_factor = 0.1
raise
except Exception as e: for attempt in range(max_retries):
logger.error(f"Error during edge upsert: {str(e)}") try:
raise logger.debug(
f"Attempting edge upsert, attempt {attempt + 1}/{max_retries}"
)
async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction):
workspace_label = self._get_workspace_label()
query = f"""
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
WITH source
MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
MERGE (source)-[r:DIRECTED]-(target)
SET r += $properties
RETURN r, source, target
"""
result = await tx.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
properties=edge_properties,
)
try:
await result.fetch(2)
finally:
await result.consume() # Ensure result is consumed
await session.execute_write(execute_upsert)
break # Success - exit retry loop
except (TransientError, ResultFailedError) as e:
# Check if the root cause is a TransientError
root_cause = e
while hasattr(root_cause, "__cause__") and root_cause.__cause__:
root_cause = root_cause.__cause__
# Check if this is a transient error that should be retried
is_transient = (
isinstance(root_cause, TransientError)
or isinstance(e, TransientError)
or "TransientError" in str(e)
or "Cannot resolve conflicting transactions" in str(e)
)
if is_transient:
if attempt < max_retries - 1:
# Calculate wait time with exponential backoff and jitter
jitter = random.uniform(0, jitter_factor) * initial_wait_time
wait_time = (
initial_wait_time * (backoff_factor**attempt) + jitter
)
logger.warning(
f"Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
)
await asyncio.sleep(wait_time)
else:
logger.error(
f"Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}"
)
raise
else:
# Non-transient error, don't retry
logger.error(f"Non-transient error during edge upsert: {str(e)}")
raise
except Exception as e:
logger.error(f"Unexpected error during edge upsert: {str(e)}")
raise
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
"""Delete a node with the specified label """Delete a node with the specified label

View file

@@ -1285,7 +1285,7 @@ async def merge_nodes_and_edges(
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
# Sort the edge_key components to ensure consistent lock key generation # Sort the edge_key components to ensure consistent lock key generation
sorted_edge_key = sorted([edge_key[0], edge_key[1]]) sorted_edge_key = sorted([edge_key[0], edge_key[1]])
logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}") # logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
async with get_storage_keyed_lock( async with get_storage_keyed_lock(
f"{sorted_edge_key[0]}-{sorted_edge_key[1]}", f"{sorted_edge_key[0]}-{sorted_edge_key[1]}",
namespace=namespace, namespace=namespace,