Replace tenacity retries with manual Memgraph transaction retries
- Implement manual retry logic - Add exponential backoff with jitter - Improve error handling for transient errors
This commit is contained in:
parent
99e58ac752
commit
9f5399c2f1
2 changed files with 144 additions and 86 deletions
|
|
@ -1,4 +1,6 @@
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import final
|
from typing import final
|
||||||
import configparser
|
import configparser
|
||||||
|
|
@ -11,20 +13,11 @@ import pipmaster as pm
|
||||||
|
|
||||||
if not pm.is_installed("neo4j"):
|
if not pm.is_installed("neo4j"):
|
||||||
pm.install("neo4j")
|
pm.install("neo4j")
|
||||||
if not pm.is_installed("tenacity"):
|
|
||||||
pm.install("tenacity")
|
|
||||||
|
|
||||||
from neo4j import (
|
from neo4j import (
|
||||||
AsyncGraphDatabase,
|
AsyncGraphDatabase,
|
||||||
AsyncManagedTransaction,
|
AsyncManagedTransaction,
|
||||||
)
|
)
|
||||||
from neo4j.exceptions import TransientError
|
from neo4j.exceptions import TransientError, ResultFailedError
|
||||||
from tenacity import (
|
|
||||||
retry,
|
|
||||||
stop_after_attempt,
|
|
||||||
wait_exponential,
|
|
||||||
retry_if_exception_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
|
@ -111,25 +104,6 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
# Memgraph handles persistence automatically
|
# Memgraph handles persistence automatically
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@retry(
|
|
||||||
stop=stop_after_attempt(5),
|
|
||||||
wait=wait_exponential(multiplier=1, min=1, max=10),
|
|
||||||
retry=retry_if_exception_type(TransientError),
|
|
||||||
reraise=True,
|
|
||||||
)
|
|
||||||
async def _execute_write_with_retry(self, session, operation_func):
|
|
||||||
"""
|
|
||||||
Execute a write operation with retry logic for Memgraph transient errors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Neo4j session
|
|
||||||
operation_func: Async function that takes a transaction and executes the operation
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TransientError: If all retry attempts fail
|
|
||||||
"""
|
|
||||||
return await session.execute_write(operation_func)
|
|
||||||
|
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a node exists in the graph.
|
Check if a node exists in the graph.
|
||||||
|
|
@ -463,7 +437,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||||
"""
|
"""
|
||||||
Upsert a node in the Memgraph database with retry logic for transient errors.
|
Upsert a node in the Memgraph database with manual transaction-level retry logic for transient errors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id: The unique identifier for the node (used as label)
|
node_id: The unique identifier for the node (used as label)
|
||||||
|
|
@ -480,36 +454,77 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
"Memgraph: node properties must contain an 'entity_id' field"
|
"Memgraph: node properties must contain an 'entity_id' field"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
# Manual transaction-level retry following official Memgraph documentation
|
||||||
async with self._driver.session(database=self._DATABASE) as session:
|
max_retries = 100
|
||||||
workspace_label = self._get_workspace_label()
|
initial_wait_time = 0.2
|
||||||
|
backoff_factor = 1.1
|
||||||
|
jitter_factor = 0.1
|
||||||
|
|
||||||
async def execute_upsert(tx: AsyncManagedTransaction):
|
for attempt in range(max_retries):
|
||||||
query = f"""
|
try:
|
||||||
MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
|
logger.debug(
|
||||||
SET n += $properties
|
f"Attempting node upsert, attempt {attempt + 1}/{max_retries}"
|
||||||
SET n:`{entity_type}`
|
)
|
||||||
"""
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
result = await tx.run(
|
workspace_label = self._get_workspace_label()
|
||||||
query, entity_id=node_id, properties=properties
|
|
||||||
)
|
|
||||||
await result.consume() # Ensure result is fully consumed
|
|
||||||
|
|
||||||
await self._execute_write_with_retry(session, execute_upsert)
|
async def execute_upsert(tx: AsyncManagedTransaction):
|
||||||
except TransientError as e:
|
query = f"""
|
||||||
logger.error(
|
MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
|
||||||
f"Memgraph transient error during node upsert after retries: {str(e)}"
|
SET n += $properties
|
||||||
)
|
SET n:`{entity_type}`
|
||||||
raise
|
"""
|
||||||
except Exception as e:
|
result = await tx.run(
|
||||||
logger.error(f"Error during node upsert: {str(e)}")
|
query, entity_id=node_id, properties=properties
|
||||||
raise
|
)
|
||||||
|
await result.consume() # Ensure result is fully consumed
|
||||||
|
|
||||||
|
await session.execute_write(execute_upsert)
|
||||||
|
break # Success - exit retry loop
|
||||||
|
|
||||||
|
except (TransientError, ResultFailedError) as e:
|
||||||
|
# Check if the root cause is a TransientError
|
||||||
|
root_cause = e
|
||||||
|
while hasattr(root_cause, "__cause__") and root_cause.__cause__:
|
||||||
|
root_cause = root_cause.__cause__
|
||||||
|
|
||||||
|
# Check if this is a transient error that should be retried
|
||||||
|
is_transient = (
|
||||||
|
isinstance(root_cause, TransientError)
|
||||||
|
or isinstance(e, TransientError)
|
||||||
|
or "TransientError" in str(e)
|
||||||
|
or "Cannot resolve conflicting transactions" in str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_transient:
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
# Calculate wait time with exponential backoff and jitter
|
||||||
|
jitter = random.uniform(0, jitter_factor) * initial_wait_time
|
||||||
|
wait_time = (
|
||||||
|
initial_wait_time * (backoff_factor**attempt) + jitter
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Memgraph transient error during node upsert after {max_retries} retries: {str(e)}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
# Non-transient error, don't retry
|
||||||
|
logger.error(f"Non-transient error during node upsert: {str(e)}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error during node upsert: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Upsert an edge and its properties between two nodes identified by their labels with retry logic for transient errors.
|
Upsert an edge and its properties between two nodes identified by their labels with manual transaction-level retry logic for transient errors.
|
||||||
Ensures both source and target nodes exist and are unique before creating the edge.
|
Ensures both source and target nodes exist and are unique before creating the edge.
|
||||||
Uses entity_id property to uniquely identify nodes.
|
Uses entity_id property to uniquely identify nodes.
|
||||||
|
|
||||||
|
|
@ -525,40 +540,83 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
||||||
)
|
)
|
||||||
try:
|
|
||||||
edge_properties = edge_data
|
|
||||||
async with self._driver.session(database=self._DATABASE) as session:
|
|
||||||
|
|
||||||
async def execute_upsert(tx: AsyncManagedTransaction):
|
edge_properties = edge_data
|
||||||
workspace_label = self._get_workspace_label()
|
|
||||||
query = f"""
|
|
||||||
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
|
|
||||||
WITH source
|
|
||||||
MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
|
|
||||||
MERGE (source)-[r:DIRECTED]-(target)
|
|
||||||
SET r += $properties
|
|
||||||
RETURN r, source, target
|
|
||||||
"""
|
|
||||||
result = await tx.run(
|
|
||||||
query,
|
|
||||||
source_entity_id=source_node_id,
|
|
||||||
target_entity_id=target_node_id,
|
|
||||||
properties=edge_properties,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
await result.fetch(2)
|
|
||||||
finally:
|
|
||||||
await result.consume() # Ensure result is consumed
|
|
||||||
|
|
||||||
await self._execute_write_with_retry(session, execute_upsert)
|
# Manual transaction-level retry following official Memgraph documentation
|
||||||
except TransientError as e:
|
max_retries = 100
|
||||||
logger.error(
|
initial_wait_time = 0.2
|
||||||
f"Memgraph transient error during edge upsert after retries: {str(e)}"
|
backoff_factor = 1.1
|
||||||
)
|
jitter_factor = 0.1
|
||||||
raise
|
|
||||||
except Exception as e:
|
for attempt in range(max_retries):
|
||||||
logger.error(f"Error during edge upsert: {str(e)}")
|
try:
|
||||||
raise
|
logger.debug(
|
||||||
|
f"Attempting edge upsert, attempt {attempt + 1}/{max_retries}"
|
||||||
|
)
|
||||||
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
|
|
||||||
|
async def execute_upsert(tx: AsyncManagedTransaction):
|
||||||
|
workspace_label = self._get_workspace_label()
|
||||||
|
query = f"""
|
||||||
|
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
|
||||||
|
WITH source
|
||||||
|
MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
|
||||||
|
MERGE (source)-[r:DIRECTED]-(target)
|
||||||
|
SET r += $properties
|
||||||
|
RETURN r, source, target
|
||||||
|
"""
|
||||||
|
result = await tx.run(
|
||||||
|
query,
|
||||||
|
source_entity_id=source_node_id,
|
||||||
|
target_entity_id=target_node_id,
|
||||||
|
properties=edge_properties,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await result.fetch(2)
|
||||||
|
finally:
|
||||||
|
await result.consume() # Ensure result is consumed
|
||||||
|
|
||||||
|
await session.execute_write(execute_upsert)
|
||||||
|
break # Success - exit retry loop
|
||||||
|
|
||||||
|
except (TransientError, ResultFailedError) as e:
|
||||||
|
# Check if the root cause is a TransientError
|
||||||
|
root_cause = e
|
||||||
|
while hasattr(root_cause, "__cause__") and root_cause.__cause__:
|
||||||
|
root_cause = root_cause.__cause__
|
||||||
|
|
||||||
|
# Check if this is a transient error that should be retried
|
||||||
|
is_transient = (
|
||||||
|
isinstance(root_cause, TransientError)
|
||||||
|
or isinstance(e, TransientError)
|
||||||
|
or "TransientError" in str(e)
|
||||||
|
or "Cannot resolve conflicting transactions" in str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_transient:
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
# Calculate wait time with exponential backoff and jitter
|
||||||
|
jitter = random.uniform(0, jitter_factor) * initial_wait_time
|
||||||
|
wait_time = (
|
||||||
|
initial_wait_time * (backoff_factor**attempt) + jitter
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
# Non-transient error, don't retry
|
||||||
|
logger.error(f"Non-transient error during edge upsert: {str(e)}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error during edge upsert: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
async def delete_node(self, node_id: str) -> None:
|
async def delete_node(self, node_id: str) -> None:
|
||||||
"""Delete a node with the specified label
|
"""Delete a node with the specified label
|
||||||
|
|
|
||||||
|
|
@ -1285,7 +1285,7 @@ async def merge_nodes_and_edges(
|
||||||
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
||||||
# Sort the edge_key components to ensure consistent lock key generation
|
# Sort the edge_key components to ensure consistent lock key generation
|
||||||
sorted_edge_key = sorted([edge_key[0], edge_key[1]])
|
sorted_edge_key = sorted([edge_key[0], edge_key[1]])
|
||||||
logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
|
# logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
|
||||||
async with get_storage_keyed_lock(
|
async with get_storage_keyed_lock(
|
||||||
f"{sorted_edge_key[0]}-{sorted_edge_key[1]}",
|
f"{sorted_edge_key[0]}-{sorted_edge_key[1]}",
|
||||||
namespace=namespace,
|
namespace=namespace,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue