Merge branch 'fix-memgraph' into fix-keyed-lock

This commit is contained in:
yangdx 2025-07-19 11:55:24 +08:00
commit cba97c62fe
2 changed files with 144 additions and 48 deletions

View file

@ -1,4 +1,6 @@
import os import os
import asyncio
import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import final from typing import final
import configparser import configparser
@ -11,11 +13,11 @@ import pipmaster as pm
if not pm.is_installed("neo4j"): if not pm.is_installed("neo4j"):
pm.install("neo4j") pm.install("neo4j")
from neo4j import ( from neo4j import (
AsyncGraphDatabase, AsyncGraphDatabase,
AsyncManagedTransaction, AsyncManagedTransaction,
) )
from neo4j.exceptions import TransientError, ResultFailedError
from dotenv import load_dotenv from dotenv import load_dotenv
@ -435,7 +437,7 @@ class MemgraphStorage(BaseGraphStorage):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
""" """
Upsert a node in the Memgraph database. Upsert a node in the Memgraph database with manual transaction-level retry logic for transient errors.
Args: Args:
node_id: The unique identifier for the node (used as label) node_id: The unique identifier for the node (used as label)
@ -452,7 +454,17 @@ class MemgraphStorage(BaseGraphStorage):
"Memgraph: node properties must contain an 'entity_id' field" "Memgraph: node properties must contain an 'entity_id' field"
) )
# Manual transaction-level retry following official Memgraph documentation
max_retries = 100
initial_wait_time = 0.2
backoff_factor = 1.1
jitter_factor = 0.1
for attempt in range(max_retries):
try: try:
logger.debug(
f"Attempting node upsert, attempt {attempt + 1}/{max_retries}"
)
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
@ -468,15 +480,51 @@ class MemgraphStorage(BaseGraphStorage):
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
await session.execute_write(execute_upsert) await session.execute_write(execute_upsert)
break # Success - exit retry loop
except (TransientError, ResultFailedError) as e:
# Check if the root cause is a TransientError
root_cause = e
while hasattr(root_cause, "__cause__") and root_cause.__cause__:
root_cause = root_cause.__cause__
# Check if this is a transient error that should be retried
is_transient = (
isinstance(root_cause, TransientError)
or isinstance(e, TransientError)
or "TransientError" in str(e)
or "Cannot resolve conflicting transactions" in str(e)
)
if is_transient:
if attempt < max_retries - 1:
# Calculate wait time with exponential backoff and jitter
jitter = random.uniform(0, jitter_factor) * initial_wait_time
wait_time = (
initial_wait_time * (backoff_factor**attempt) + jitter
)
logger.warning(
f"Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
)
await asyncio.sleep(wait_time)
else:
logger.error(
f"Memgraph transient error during node upsert after {max_retries} retries: {str(e)}"
)
raise
else:
# Non-transient error, don't retry
logger.error(f"Non-transient error during node upsert: {str(e)}")
raise
except Exception as e: except Exception as e:
logger.error(f"Error during upsert: {str(e)}") logger.error(f"Unexpected error during node upsert: {str(e)}")
raise raise
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None: ) -> None:
""" """
Upsert an edge and its properties between two nodes identified by their labels. Upsert an edge and its properties between two nodes identified by their labels with manual transaction-level retry logic for transient errors.
Ensures both source and target nodes exist and are unique before creating the edge. Ensures both source and target nodes exist and are unique before creating the edge.
Uses entity_id property to uniquely identify nodes. Uses entity_id property to uniquely identify nodes.
@ -492,8 +540,20 @@ class MemgraphStorage(BaseGraphStorage):
raise RuntimeError( raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first." "Memgraph driver is not initialized. Call 'await initialize()' first."
) )
try:
edge_properties = edge_data edge_properties = edge_data
# Manual transaction-level retry following official Memgraph documentation
max_retries = 100
initial_wait_time = 0.2
backoff_factor = 1.1
jitter_factor = 0.1
for attempt in range(max_retries):
try:
logger.debug(
f"Attempting edge upsert, attempt {attempt + 1}/{max_retries}"
)
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction): async def execute_upsert(tx: AsyncManagedTransaction):
@ -518,8 +578,44 @@ class MemgraphStorage(BaseGraphStorage):
await result.consume() # Ensure result is consumed await result.consume() # Ensure result is consumed
await session.execute_write(execute_upsert) await session.execute_write(execute_upsert)
break # Success - exit retry loop
except (TransientError, ResultFailedError) as e:
# Check if the root cause is a TransientError
root_cause = e
while hasattr(root_cause, "__cause__") and root_cause.__cause__:
root_cause = root_cause.__cause__
# Check if this is a transient error that should be retried
is_transient = (
isinstance(root_cause, TransientError)
or isinstance(e, TransientError)
or "TransientError" in str(e)
or "Cannot resolve conflicting transactions" in str(e)
)
if is_transient:
if attempt < max_retries - 1:
# Calculate wait time with exponential backoff and jitter
jitter = random.uniform(0, jitter_factor) * initial_wait_time
wait_time = (
initial_wait_time * (backoff_factor**attempt) + jitter
)
logger.warning(
f"Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
)
await asyncio.sleep(wait_time)
else:
logger.error(
f"Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}"
)
raise
else:
# Non-transient error, don't retry
logger.error(f"Non-transient error during edge upsert: {str(e)}")
raise
except Exception as e: except Exception as e:
logger.error(f"Error during edge upsert: {str(e)}") logger.error(f"Unexpected error during edge upsert: {str(e)}")
raise raise
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:

View file

@ -1280,7 +1280,7 @@ async def merge_nodes_and_edges(
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
# Sort the edge_key components to ensure consistent lock key generation # Sort the edge_key components to ensure consistent lock key generation
sorted_edge_key = sorted([edge_key[0], edge_key[1]]) sorted_edge_key = sorted([edge_key[0], edge_key[1]])
logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}") # logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
async with get_storage_keyed_lock( async with get_storage_keyed_lock(
sorted_edge_key, sorted_edge_key,
namespace=namespace, namespace=namespace,