Replace tenacity retries with manual Memgraph transaction retries
- Implement manual retry logic
- Add exponential backoff with jitter
- Improve error handling for transient errors
This commit is contained in:
parent
99e58ac752
commit
9f5399c2f1
2 changed files with 144 additions and 86 deletions
|
|
@ -1,4 +1,6 @@
|
|||
import os
|
||||
import asyncio
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import final
|
||||
import configparser
|
||||
|
|
@ -11,20 +13,11 @@ import pipmaster as pm
|
|||
|
||||
if not pm.is_installed("neo4j"):
|
||||
pm.install("neo4j")
|
||||
if not pm.is_installed("tenacity"):
|
||||
pm.install("tenacity")
|
||||
|
||||
from neo4j import (
|
||||
AsyncGraphDatabase,
|
||||
AsyncManagedTransaction,
|
||||
)
|
||||
from neo4j.exceptions import TransientError
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
from neo4j.exceptions import TransientError, ResultFailedError
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
|
@ -111,25 +104,6 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
# Memgraph handles persistence automatically
|
||||
pass
|
||||
|
||||
@retry(
    retry=retry_if_exception_type(TransientError),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    stop=stop_after_attempt(5),
    reraise=True,
)
async def _execute_write_with_retry(self, session, operation_func):
    """
    Run a managed write transaction, retrying on Memgraph transient errors.

    The tenacity decorator re-invokes this coroutine whenever it raises
    TransientError: at most 5 attempts in total, with an exponential wait
    between attempts that is clamped to the 1-10 second range.

    Args:
        session: Neo4j session on which ``execute_write`` is called
        operation_func: Async function that receives the transaction and performs the write

    Returns:
        Whatever ``session.execute_write(operation_func)`` returns.

    Raises:
        TransientError: Re-raised as-is (``reraise=True``) once the last attempt has failed
    """
    outcome = await session.execute_write(operation_func)
    return outcome
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if a node exists in the graph.
|
||||
|
|
@ -463,7 +437,7 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
"""
|
||||
Upsert a node in the Memgraph database with retry logic for transient errors.
|
||||
Upsert a node in the Memgraph database with manual transaction-level retry logic for transient errors.
|
||||
|
||||
Args:
|
||||
node_id: The unique identifier for the node (used as label)
|
||||
|
|
@ -480,36 +454,77 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
"Memgraph: node properties must contain an 'entity_id' field"
|
||||
)
|
||||
|
||||
try:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
workspace_label = self._get_workspace_label()
|
||||
# Manual transaction-level retry following official Memgraph documentation
|
||||
max_retries = 100
|
||||
initial_wait_time = 0.2
|
||||
backoff_factor = 1.1
|
||||
jitter_factor = 0.1
|
||||
|
||||
async def execute_upsert(tx: AsyncManagedTransaction):
|
||||
query = f"""
|
||||
MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
|
||||
SET n += $properties
|
||||
SET n:`{entity_type}`
|
||||
"""
|
||||
result = await tx.run(
|
||||
query, entity_id=node_id, properties=properties
|
||||
)
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
logger.debug(
|
||||
f"Attempting node upsert, attempt {attempt + 1}/{max_retries}"
|
||||
)
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
workspace_label = self._get_workspace_label()
|
||||
|
||||
await self._execute_write_with_retry(session, execute_upsert)
|
||||
except TransientError as e:
|
||||
logger.error(
|
||||
f"Memgraph transient error during node upsert after retries: {str(e)}"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error during node upsert: {str(e)}")
|
||||
raise
|
||||
async def execute_upsert(tx: AsyncManagedTransaction):
|
||||
query = f"""
|
||||
MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
|
||||
SET n += $properties
|
||||
SET n:`{entity_type}`
|
||||
"""
|
||||
result = await tx.run(
|
||||
query, entity_id=node_id, properties=properties
|
||||
)
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
|
||||
await session.execute_write(execute_upsert)
|
||||
break # Success - exit retry loop
|
||||
|
||||
except (TransientError, ResultFailedError) as e:
|
||||
# Check if the root cause is a TransientError
|
||||
root_cause = e
|
||||
while hasattr(root_cause, "__cause__") and root_cause.__cause__:
|
||||
root_cause = root_cause.__cause__
|
||||
|
||||
# Check if this is a transient error that should be retried
|
||||
is_transient = (
|
||||
isinstance(root_cause, TransientError)
|
||||
or isinstance(e, TransientError)
|
||||
or "TransientError" in str(e)
|
||||
or "Cannot resolve conflicting transactions" in str(e)
|
||||
)
|
||||
|
||||
if is_transient:
|
||||
if attempt < max_retries - 1:
|
||||
# Calculate wait time with exponential backoff and jitter
|
||||
jitter = random.uniform(0, jitter_factor) * initial_wait_time
|
||||
wait_time = (
|
||||
initial_wait_time * (backoff_factor**attempt) + jitter
|
||||
)
|
||||
logger.warning(
|
||||
f"Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.error(
|
||||
f"Memgraph transient error during node upsert after {max_retries} retries: {str(e)}"
|
||||
)
|
||||
raise
|
||||
else:
|
||||
# Non-transient error, don't retry
|
||||
logger.error(f"Non-transient error during node upsert: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during node upsert: {str(e)}")
|
||||
raise
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
) -> None:
|
||||
"""
|
||||
Upsert an edge and its properties between two nodes identified by their labels with retry logic for transient errors.
|
||||
Upsert an edge and its properties between two nodes identified by their labels with manual transaction-level retry logic for transient errors.
|
||||
Ensures both source and target nodes exist and are unique before creating the edge.
|
||||
Uses entity_id property to uniquely identify nodes.
|
||||
|
||||
|
|
@ -525,40 +540,83 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
raise RuntimeError(
|
||||
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
||||
)
|
||||
try:
|
||||
edge_properties = edge_data
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
|
||||
async def execute_upsert(tx: AsyncManagedTransaction):
|
||||
workspace_label = self._get_workspace_label()
|
||||
query = f"""
|
||||
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
|
||||
WITH source
|
||||
MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
|
||||
MERGE (source)-[r:DIRECTED]-(target)
|
||||
SET r += $properties
|
||||
RETURN r, source, target
|
||||
"""
|
||||
result = await tx.run(
|
||||
query,
|
||||
source_entity_id=source_node_id,
|
||||
target_entity_id=target_node_id,
|
||||
properties=edge_properties,
|
||||
)
|
||||
try:
|
||||
await result.fetch(2)
|
||||
finally:
|
||||
await result.consume() # Ensure result is consumed
|
||||
edge_properties = edge_data
|
||||
|
||||
await self._execute_write_with_retry(session, execute_upsert)
|
||||
except TransientError as e:
|
||||
logger.error(
|
||||
f"Memgraph transient error during edge upsert after retries: {str(e)}"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error during edge upsert: {str(e)}")
|
||||
raise
|
||||
# Manual transaction-level retry following official Memgraph documentation
|
||||
max_retries = 100
|
||||
initial_wait_time = 0.2
|
||||
backoff_factor = 1.1
|
||||
jitter_factor = 0.1
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
logger.debug(
|
||||
f"Attempting edge upsert, attempt {attempt + 1}/{max_retries}"
|
||||
)
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
|
||||
async def execute_upsert(tx: AsyncManagedTransaction):
|
||||
workspace_label = self._get_workspace_label()
|
||||
query = f"""
|
||||
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
|
||||
WITH source
|
||||
MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
|
||||
MERGE (source)-[r:DIRECTED]-(target)
|
||||
SET r += $properties
|
||||
RETURN r, source, target
|
||||
"""
|
||||
result = await tx.run(
|
||||
query,
|
||||
source_entity_id=source_node_id,
|
||||
target_entity_id=target_node_id,
|
||||
properties=edge_properties,
|
||||
)
|
||||
try:
|
||||
await result.fetch(2)
|
||||
finally:
|
||||
await result.consume() # Ensure result is consumed
|
||||
|
||||
await session.execute_write(execute_upsert)
|
||||
break # Success - exit retry loop
|
||||
|
||||
except (TransientError, ResultFailedError) as e:
|
||||
# Check if the root cause is a TransientError
|
||||
root_cause = e
|
||||
while hasattr(root_cause, "__cause__") and root_cause.__cause__:
|
||||
root_cause = root_cause.__cause__
|
||||
|
||||
# Check if this is a transient error that should be retried
|
||||
is_transient = (
|
||||
isinstance(root_cause, TransientError)
|
||||
or isinstance(e, TransientError)
|
||||
or "TransientError" in str(e)
|
||||
or "Cannot resolve conflicting transactions" in str(e)
|
||||
)
|
||||
|
||||
if is_transient:
|
||||
if attempt < max_retries - 1:
|
||||
# Calculate wait time with exponential backoff and jitter
|
||||
jitter = random.uniform(0, jitter_factor) * initial_wait_time
|
||||
wait_time = (
|
||||
initial_wait_time * (backoff_factor**attempt) + jitter
|
||||
)
|
||||
logger.warning(
|
||||
f"Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.error(
|
||||
f"Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}"
|
||||
)
|
||||
raise
|
||||
else:
|
||||
# Non-transient error, don't retry
|
||||
logger.error(f"Non-transient error during edge upsert: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during edge upsert: {str(e)}")
|
||||
raise
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
"""Delete a node with the specified label
|
||||
|
|
|
|||
|
|
@ -1285,7 +1285,7 @@ async def merge_nodes_and_edges(
|
|||
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
||||
# Sort the edge_key components to ensure consistent lock key generation
|
||||
sorted_edge_key = sorted([edge_key[0], edge_key[1]])
|
||||
logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
|
||||
# logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
|
||||
async with get_storage_keyed_lock(
|
||||
f"{sorted_edge_key[0]}-{sorted_edge_key[1]}",
|
||||
namespace=namespace,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue