Replace tenacity retries with manual Memgraph transaction retries

- Implement manual retry logic
- Add exponential backoff with jitter
- Improve error handling for transient errors
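For reference, the backoff described above works out to a geometric wait plus a small random jitter. Below is a minimal standalone sketch of that calculation; the constant values and the formula are copied from the diff further down, while the helper name and constant names are only illustrative:

import random

# Values copied from the retry loops introduced in this commit; the names are illustrative.
INITIAL_WAIT_TIME = 0.2  # seconds
BACKOFF_FACTOR = 1.1
JITTER_FACTOR = 0.1

def retry_wait(attempt: int) -> float:
    # Geometric backoff plus a small random jitter, as in the upsert retry loops below.
    jitter = random.uniform(0, JITTER_FACTOR) * INITIAL_WAIT_TIME
    return INITIAL_WAIT_TIME * (BACKOFF_FACTOR ** attempt) + jitter

# Roughly 0.2 s on attempt 0, ~0.52 s by attempt 10, ~23.5 s by attempt 50,
# each with up to 0.02 s of jitter so concurrent writers do not retry in lock-step.
for attempt in (0, 1, 10, 50):
    print(attempt, round(retry_wait(attempt), 3))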
yangdx 2025-07-19 11:31:21 +08:00
parent 99e58ac752
commit 9f5399c2f1
2 changed files with 144 additions and 86 deletions


@@ -1,4 +1,6 @@
import os
import asyncio
import random
from dataclasses import dataclass
from typing import final
import configparser
@@ -11,20 +13,11 @@ import pipmaster as pm
if not pm.is_installed("neo4j"):
pm.install("neo4j")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
from neo4j import (
AsyncGraphDatabase,
AsyncManagedTransaction,
)
from neo4j.exceptions import TransientError
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from neo4j.exceptions import TransientError, ResultFailedError
from dotenv import load_dotenv
@@ -111,25 +104,6 @@ class MemgraphStorage(BaseGraphStorage):
# Memgraph handles persistence automatically
pass
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type(TransientError),
reraise=True,
)
async def _execute_write_with_retry(self, session, operation_func):
"""
Execute a write operation with retry logic for Memgraph transient errors.
Args:
session: Neo4j session
operation_func: Async function that takes a transaction and executes the operation
Raises:
TransientError: If all retry attempts fail
"""
return await session.execute_write(operation_func)
async def has_node(self, node_id: str) -> bool:
"""
Check if a node exists in the graph.
@@ -463,7 +437,7 @@ class MemgraphStorage(BaseGraphStorage):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""
Upsert a node in the Memgraph database with retry logic for transient errors.
Upsert a node in the Memgraph database with manual transaction-level retry logic for transient errors.
Args:
node_id: The unique identifier for the node (used as label)
@@ -480,36 +454,77 @@
"Memgraph: node properties must contain an 'entity_id' field"
)
try:
async with self._driver.session(database=self._DATABASE) as session:
workspace_label = self._get_workspace_label()
# Manual transaction-level retry following official Memgraph documentation
max_retries = 100
initial_wait_time = 0.2
backoff_factor = 1.1
jitter_factor = 0.1
async def execute_upsert(tx: AsyncManagedTransaction):
query = f"""
MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
SET n += $properties
SET n:`{entity_type}`
"""
result = await tx.run(
query, entity_id=node_id, properties=properties
)
await result.consume() # Ensure result is fully consumed
for attempt in range(max_retries):
try:
logger.debug(
f"Attempting node upsert, attempt {attempt + 1}/{max_retries}"
)
async with self._driver.session(database=self._DATABASE) as session:
workspace_label = self._get_workspace_label()
await self._execute_write_with_retry(session, execute_upsert)
except TransientError as e:
logger.error(
f"Memgraph transient error during node upsert after retries: {str(e)}"
)
raise
except Exception as e:
logger.error(f"Error during node upsert: {str(e)}")
raise
async def execute_upsert(tx: AsyncManagedTransaction):
query = f"""
MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
SET n += $properties
SET n:`{entity_type}`
"""
result = await tx.run(
query, entity_id=node_id, properties=properties
)
await result.consume() # Ensure result is fully consumed
await session.execute_write(execute_upsert)
break # Success - exit retry loop
except (TransientError, ResultFailedError) as e:
# Check if the root cause is a TransientError
root_cause = e
while hasattr(root_cause, "__cause__") and root_cause.__cause__:
root_cause = root_cause.__cause__
# Check if this is a transient error that should be retried
is_transient = (
isinstance(root_cause, TransientError)
or isinstance(e, TransientError)
or "TransientError" in str(e)
or "Cannot resolve conflicting transactions" in str(e)
)
if is_transient:
if attempt < max_retries - 1:
# Calculate wait time with exponential backoff and jitter
jitter = random.uniform(0, jitter_factor) * initial_wait_time
wait_time = (
initial_wait_time * (backoff_factor**attempt) + jitter
)
logger.warning(
f"Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
)
await asyncio.sleep(wait_time)
else:
logger.error(
f"Memgraph transient error during node upsert after {max_retries} retries: {str(e)}"
)
raise
else:
# Non-transient error, don't retry
logger.error(f"Non-transient error during node upsert: {str(e)}")
raise
except Exception as e:
logger.error(f"Unexpected error during node upsert: {str(e)}")
raise
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""
Upsert an edge and its properties between two nodes identified by their labels with retry logic for transient errors.
Upsert an edge and its properties between two nodes identified by their labels with manual transaction-level retry logic for transient errors.
Ensures both source and target nodes exist and are unique before creating the edge.
Uses entity_id property to uniquely identify nodes.
@@ -525,40 +540,83 @@ class MemgraphStorage(BaseGraphStorage):
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
try:
edge_properties = edge_data
async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction):
workspace_label = self._get_workspace_label()
query = f"""
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
WITH source
MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
MERGE (source)-[r:DIRECTED]-(target)
SET r += $properties
RETURN r, source, target
"""
result = await tx.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
properties=edge_properties,
)
try:
await result.fetch(2)
finally:
await result.consume() # Ensure result is consumed
edge_properties = edge_data
await self._execute_write_with_retry(session, execute_upsert)
except TransientError as e:
logger.error(
f"Memgraph transient error during edge upsert after retries: {str(e)}"
)
raise
except Exception as e:
logger.error(f"Error during edge upsert: {str(e)}")
raise
# Manual transaction-level retry following official Memgraph documentation
max_retries = 100
initial_wait_time = 0.2
backoff_factor = 1.1
jitter_factor = 0.1
for attempt in range(max_retries):
try:
logger.debug(
f"Attempting edge upsert, attempt {attempt + 1}/{max_retries}"
)
async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction):
workspace_label = self._get_workspace_label()
query = f"""
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
WITH source
MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
MERGE (source)-[r:DIRECTED]-(target)
SET r += $properties
RETURN r, source, target
"""
result = await tx.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
properties=edge_properties,
)
try:
await result.fetch(2)
finally:
await result.consume() # Ensure result is consumed
await session.execute_write(execute_upsert)
break # Success - exit retry loop
except (TransientError, ResultFailedError) as e:
# Check if the root cause is a TransientError
root_cause = e
while hasattr(root_cause, "__cause__") and root_cause.__cause__:
root_cause = root_cause.__cause__
# Check if this is a transient error that should be retried
is_transient = (
isinstance(root_cause, TransientError)
or isinstance(e, TransientError)
or "TransientError" in str(e)
or "Cannot resolve conflicting transactions" in str(e)
)
if is_transient:
if attempt < max_retries - 1:
# Calculate wait time with exponential backoff and jitter
jitter = random.uniform(0, jitter_factor) * initial_wait_time
wait_time = (
initial_wait_time * (backoff_factor**attempt) + jitter
)
logger.warning(
f"Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
)
await asyncio.sleep(wait_time)
else:
logger.error(
f"Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}"
)
raise
else:
# Non-transient error, don't retry
logger.error(f"Non-transient error during edge upsert: {str(e)}")
raise
except Exception as e:
logger.error(f"Unexpected error during edge upsert: {str(e)}")
raise
async def delete_node(self, node_id: str) -> None:
"""Delete a node with the specified label


@@ -1285,7 +1285,7 @@ async def merge_nodes_and_edges(
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
# Sort the edge_key components to ensure consistent lock key generation
sorted_edge_key = sorted([edge_key[0], edge_key[1]])
logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
# logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
async with get_storage_keyed_lock(
f"{sorted_edge_key[0]}-{sorted_edge_key[1]}",
namespace=namespace,