Merge branch 'fix-memgraph' into fix-keyed-lock

This commit is contained in:
yangdx 2025-07-19 11:55:24 +08:00
commit cba97c62fe
2 changed files with 144 additions and 48 deletions

View file

@ -1,4 +1,6 @@
import os import os
import asyncio
import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import final from typing import final
import configparser import configparser
@ -11,11 +13,11 @@ import pipmaster as pm
if not pm.is_installed("neo4j"): if not pm.is_installed("neo4j"):
pm.install("neo4j") pm.install("neo4j")
from neo4j import ( from neo4j import (
AsyncGraphDatabase, AsyncGraphDatabase,
AsyncManagedTransaction, AsyncManagedTransaction,
) )
from neo4j.exceptions import TransientError, ResultFailedError
from dotenv import load_dotenv from dotenv import load_dotenv
@ -435,7 +437,7 @@ class MemgraphStorage(BaseGraphStorage):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
""" """
Upsert a node in the Memgraph database. Upsert a node in the Memgraph database with manual transaction-level retry logic for transient errors.
Args: Args:
node_id: The unique identifier for the node (used as label) node_id: The unique identifier for the node (used as label)
@ -452,31 +454,77 @@ class MemgraphStorage(BaseGraphStorage):
"Memgraph: node properties must contain an 'entity_id' field" "Memgraph: node properties must contain an 'entity_id' field"
) )
try: # Manual transaction-level retry following official Memgraph documentation
async with self._driver.session(database=self._DATABASE) as session: max_retries = 100
workspace_label = self._get_workspace_label() initial_wait_time = 0.2
backoff_factor = 1.1
jitter_factor = 0.1
async def execute_upsert(tx: AsyncManagedTransaction): for attempt in range(max_retries):
query = f""" try:
MERGE (n:`{workspace_label}` {{entity_id: $entity_id}}) logger.debug(
SET n += $properties f"Attempting node upsert, attempt {attempt + 1}/{max_retries}"
SET n:`{entity_type}` )
""" async with self._driver.session(database=self._DATABASE) as session:
result = await tx.run( workspace_label = self._get_workspace_label()
query, entity_id=node_id, properties=properties
)
await result.consume() # Ensure result is fully consumed
await session.execute_write(execute_upsert) async def execute_upsert(tx: AsyncManagedTransaction):
except Exception as e: query = f"""
logger.error(f"Error during upsert: {str(e)}") MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
raise SET n += $properties
SET n:`{entity_type}`
"""
result = await tx.run(
query, entity_id=node_id, properties=properties
)
await result.consume() # Ensure result is fully consumed
await session.execute_write(execute_upsert)
break # Success - exit retry loop
except (TransientError, ResultFailedError) as e:
# Check if the root cause is a TransientError
root_cause = e
while hasattr(root_cause, "__cause__") and root_cause.__cause__:
root_cause = root_cause.__cause__
# Check if this is a transient error that should be retried
is_transient = (
isinstance(root_cause, TransientError)
or isinstance(e, TransientError)
or "TransientError" in str(e)
or "Cannot resolve conflicting transactions" in str(e)
)
if is_transient:
if attempt < max_retries - 1:
# Calculate wait time with exponential backoff and jitter
jitter = random.uniform(0, jitter_factor) * initial_wait_time
wait_time = (
initial_wait_time * (backoff_factor**attempt) + jitter
)
logger.warning(
f"Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
)
await asyncio.sleep(wait_time)
else:
logger.error(
f"Memgraph transient error during node upsert after {max_retries} retries: {str(e)}"
)
raise
else:
# Non-transient error, don't retry
logger.error(f"Non-transient error during node upsert: {str(e)}")
raise
except Exception as e:
logger.error(f"Unexpected error during node upsert: {str(e)}")
raise
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None: ) -> None:
""" """
Upsert an edge and its properties between two nodes identified by their labels. Upsert an edge and its properties between two nodes identified by their labels with manual transaction-level retry logic for transient errors.
Ensures both source and target nodes exist and are unique before creating the edge. Ensures both source and target nodes exist and are unique before creating the edge.
Uses entity_id property to uniquely identify nodes. Uses entity_id property to uniquely identify nodes.
@ -492,35 +540,83 @@ class MemgraphStorage(BaseGraphStorage):
raise RuntimeError( raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first." "Memgraph driver is not initialized. Call 'await initialize()' first."
) )
try:
edge_properties = edge_data
async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction): edge_properties = edge_data
workspace_label = self._get_workspace_label()
query = f"""
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
WITH source
MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
MERGE (source)-[r:DIRECTED]-(target)
SET r += $properties
RETURN r, source, target
"""
result = await tx.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
properties=edge_properties,
)
try:
await result.fetch(2)
finally:
await result.consume() # Ensure result is consumed
await session.execute_write(execute_upsert) # Manual transaction-level retry following official Memgraph documentation
except Exception as e: max_retries = 100
logger.error(f"Error during edge upsert: {str(e)}") initial_wait_time = 0.2
raise backoff_factor = 1.1
jitter_factor = 0.1
for attempt in range(max_retries):
try:
logger.debug(
f"Attempting edge upsert, attempt {attempt + 1}/{max_retries}"
)
async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction):
workspace_label = self._get_workspace_label()
query = f"""
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
WITH source
MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
MERGE (source)-[r:DIRECTED]-(target)
SET r += $properties
RETURN r, source, target
"""
result = await tx.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
properties=edge_properties,
)
try:
await result.fetch(2)
finally:
await result.consume() # Ensure result is consumed
await session.execute_write(execute_upsert)
break # Success - exit retry loop
except (TransientError, ResultFailedError) as e:
# Check if the root cause is a TransientError
root_cause = e
while hasattr(root_cause, "__cause__") and root_cause.__cause__:
root_cause = root_cause.__cause__
# Check if this is a transient error that should be retried
is_transient = (
isinstance(root_cause, TransientError)
or isinstance(e, TransientError)
or "TransientError" in str(e)
or "Cannot resolve conflicting transactions" in str(e)
)
if is_transient:
if attempt < max_retries - 1:
# Calculate wait time with exponential backoff and jitter
jitter = random.uniform(0, jitter_factor) * initial_wait_time
wait_time = (
initial_wait_time * (backoff_factor**attempt) + jitter
)
logger.warning(
f"Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
)
await asyncio.sleep(wait_time)
else:
logger.error(
f"Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}"
)
raise
else:
# Non-transient error, don't retry
logger.error(f"Non-transient error during edge upsert: {str(e)}")
raise
except Exception as e:
logger.error(f"Unexpected error during edge upsert: {str(e)}")
raise
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
"""Delete a node with the specified label """Delete a node with the specified label

View file

@ -1280,7 +1280,7 @@ async def merge_nodes_and_edges(
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
# Sort the edge_key components to ensure consistent lock key generation # Sort the edge_key components to ensure consistent lock key generation
sorted_edge_key = sorted([edge_key[0], edge_key[1]]) sorted_edge_key = sorted([edge_key[0], edge_key[1]])
logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}") # logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
async with get_storage_keyed_lock( async with get_storage_keyed_lock(
sorted_edge_key, sorted_edge_key,
namespace=namespace, namespace=namespace,