From 9f5399c2f12cc7c7470536c5d8fbfbc15fa4603c Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 19 Jul 2025 11:31:21 +0800
Subject: [PATCH] Replace tenacity retries with manual Memgraph transaction
 retries

- Implement manual retry logic
- Add exponential backoff with jitter
- Improve error handling for transient errors
---
 lightrag/kg/memgraph_impl.py | 228 ++++++++++++++++++++++-------------
 lightrag/operate.py          |   2 +-
 2 files changed, 144 insertions(+), 86 deletions(-)

diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py
index 63ab13fb..86958a1a 100644
--- a/lightrag/kg/memgraph_impl.py
+++ b/lightrag/kg/memgraph_impl.py
@@ -1,4 +1,6 @@
 import os
+import asyncio
+import random
 from dataclasses import dataclass
 from typing import final
 import configparser
@@ -11,20 +13,11 @@
 import pipmaster as pm
 if not pm.is_installed("neo4j"):
     pm.install("neo4j")
-if not pm.is_installed("tenacity"):
-    pm.install("tenacity")
-
 from neo4j import (
     AsyncGraphDatabase,
     AsyncManagedTransaction,
 )
-from neo4j.exceptions import TransientError
-from tenacity import (
-    retry,
-    stop_after_attempt,
-    wait_exponential,
-    retry_if_exception_type,
-)
+from neo4j.exceptions import TransientError, ResultFailedError

 from dotenv import load_dotenv
@@ -111,25 +104,6 @@ class MemgraphStorage(BaseGraphStorage):
         # Memgraph handles persistence automatically
         pass

-    @retry(
-        stop=stop_after_attempt(5),
-        wait=wait_exponential(multiplier=1, min=1, max=10),
-        retry=retry_if_exception_type(TransientError),
-        reraise=True,
-    )
-    async def _execute_write_with_retry(self, session, operation_func):
-        """
-        Execute a write operation with retry logic for Memgraph transient errors.
-
-        Args:
-            session: Neo4j session
-            operation_func: Async function that takes a transaction and executes the operation
-
-        Raises:
-            TransientError: If all retry attempts fail
-        """
-        return await session.execute_write(operation_func)
-
     async def has_node(self, node_id: str) -> bool:
         """
         Check if a node exists in the graph.
@@ -463,7 +437,7 @@ class MemgraphStorage(BaseGraphStorage):

     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """
-        Upsert a node in the Memgraph database with retry logic for transient errors.
+        Upsert a node in the Memgraph database with manual transaction-level retry logic for transient errors.

         Args:
             node_id: The unique identifier for the node (used as label)
@@ -480,36 +454,77 @@ class MemgraphStorage(BaseGraphStorage):
                 "Memgraph: node properties must contain an 'entity_id' field"
             )

-        try:
-            async with self._driver.session(database=self._DATABASE) as session:
-                workspace_label = self._get_workspace_label()
+        # Manual transaction-level retry following official Memgraph documentation
+        max_retries = 100
+        initial_wait_time = 0.2
+        backoff_factor = 1.1
+        jitter_factor = 0.1

-                async def execute_upsert(tx: AsyncManagedTransaction):
-                    query = f"""
-                    MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
-                    SET n += $properties
-                    SET n:`{entity_type}`
-                    """
-                    result = await tx.run(
-                        query, entity_id=node_id, properties=properties
-                    )
-                    await result.consume()  # Ensure result is fully consumed
+        for attempt in range(max_retries):
+            try:
+                logger.debug(
+                    f"Attempting node upsert, attempt {attempt + 1}/{max_retries}"
+                )
+                async with self._driver.session(database=self._DATABASE) as session:
+                    workspace_label = self._get_workspace_label()

-                await self._execute_write_with_retry(session, execute_upsert)
-        except TransientError as e:
-            logger.error(
-                f"Memgraph transient error during node upsert after retries: {str(e)}"
-            )
-            raise
-        except Exception as e:
-            logger.error(f"Error during node upsert: {str(e)}")
-            raise
+                    async def execute_upsert(tx: AsyncManagedTransaction):
+                        query = f"""
+                        MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
+                        SET n += $properties
+                        SET n:`{entity_type}`
+                        """
+                        result = await tx.run(
+                            query, entity_id=node_id, properties=properties
+                        )
+                        await result.consume()  # Ensure result is fully consumed
+
+                    await session.execute_write(execute_upsert)
+                    break  # Success - exit retry loop
+
+            except (TransientError, ResultFailedError) as e:
+                # Check if the root cause is a TransientError
+                root_cause = e
+                while hasattr(root_cause, "__cause__") and root_cause.__cause__:
+                    root_cause = root_cause.__cause__
+
+                # Check if this is a transient error that should be retried
+                is_transient = (
+                    isinstance(root_cause, TransientError)
+                    or isinstance(e, TransientError)
+                    or "TransientError" in str(e)
+                    or "Cannot resolve conflicting transactions" in str(e)
+                )
+
+                if is_transient:
+                    if attempt < max_retries - 1:
+                        # Calculate wait time with exponential backoff and jitter
+                        jitter = random.uniform(0, jitter_factor) * initial_wait_time
+                        wait_time = (
+                            initial_wait_time * (backoff_factor**attempt) + jitter
+                        )
+                        logger.warning(
+                            f"Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
+                        )
+                        await asyncio.sleep(wait_time)
+                    else:
+                        logger.error(
+                            f"Memgraph transient error during node upsert after {max_retries} retries: {str(e)}"
+                        )
+                        raise
+                else:
+                    # Non-transient error, don't retry
+                    logger.error(f"Non-transient error during node upsert: {str(e)}")
+                    raise
+            except Exception as e:
+                logger.error(f"Unexpected error during node upsert: {str(e)}")
+                raise

     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ) -> None:
         """
-        Upsert an edge and its properties between two nodes identified by their labels with retry logic for transient errors.
+        Upsert an edge and its properties between two nodes identified by their labels with manual transaction-level retry logic for transient errors.
         Ensures both source and target nodes exist and are unique before creating the edge.
         Uses entity_id property to uniquely identify nodes.
@@ -525,40 +540,83 @@ class MemgraphStorage(BaseGraphStorage):
             raise RuntimeError(
                 "Memgraph driver is not initialized. Call 'await initialize()' first."
             )
-        try:
-            edge_properties = edge_data
-            async with self._driver.session(database=self._DATABASE) as session:
-
-                async def execute_upsert(tx: AsyncManagedTransaction):
-                    workspace_label = self._get_workspace_label()
-                    query = f"""
-                    MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
-                    WITH source
-                    MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
-                    MERGE (source)-[r:DIRECTED]-(target)
-                    SET r += $properties
-                    RETURN r, source, target
-                    """
-                    result = await tx.run(
-                        query,
-                        source_entity_id=source_node_id,
-                        target_entity_id=target_node_id,
-                        properties=edge_properties,
-                    )
-                    try:
-                        await result.fetch(2)
-                    finally:
-                        await result.consume()  # Ensure result is consumed
+        edge_properties = edge_data

-                await self._execute_write_with_retry(session, execute_upsert)
-        except TransientError as e:
-            logger.error(
-                f"Memgraph transient error during edge upsert after retries: {str(e)}"
-            )
-            raise
-        except Exception as e:
-            logger.error(f"Error during edge upsert: {str(e)}")
-            raise
+        # Manual transaction-level retry following official Memgraph documentation
+        max_retries = 100
+        initial_wait_time = 0.2
+        backoff_factor = 1.1
+        jitter_factor = 0.1
+
+        for attempt in range(max_retries):
+            try:
+                logger.debug(
+                    f"Attempting edge upsert, attempt {attempt + 1}/{max_retries}"
+                )
+                async with self._driver.session(database=self._DATABASE) as session:
+
+                    async def execute_upsert(tx: AsyncManagedTransaction):
+                        workspace_label = self._get_workspace_label()
+                        query = f"""
+                        MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
+                        WITH source
+                        MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
+                        MERGE (source)-[r:DIRECTED]-(target)
+                        SET r += $properties
+                        RETURN r, source, target
+                        """
+                        result = await tx.run(
+                            query,
+                            source_entity_id=source_node_id,
+                            target_entity_id=target_node_id,
+                            properties=edge_properties,
+                        )
+                        try:
+                            await result.fetch(2)
+                        finally:
+                            await result.consume()  # Ensure result is consumed
+
+                    await session.execute_write(execute_upsert)
+                    break  # Success - exit retry loop
+
+            except (TransientError, ResultFailedError) as e:
+                # Check if the root cause is a TransientError
+                root_cause = e
+                while hasattr(root_cause, "__cause__") and root_cause.__cause__:
+                    root_cause = root_cause.__cause__
+
+                # Check if this is a transient error that should be retried
+                is_transient = (
+                    isinstance(root_cause, TransientError)
+                    or isinstance(e, TransientError)
+                    or "TransientError" in str(e)
+                    or "Cannot resolve conflicting transactions" in str(e)
+                )
+
+                if is_transient:
+                    if attempt < max_retries - 1:
+                        # Calculate wait time with exponential backoff and jitter
+                        jitter = random.uniform(0, jitter_factor) * initial_wait_time
+                        wait_time = (
+                            initial_wait_time * (backoff_factor**attempt) + jitter
+                        )
+                        logger.warning(
+                            f"Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
+                        )
+                        await asyncio.sleep(wait_time)
+                    else:
+                        logger.error(
+                            f"Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}"
+                        )
+                        raise
+                else:
+                    # Non-transient error, don't retry
+                    logger.error(f"Non-transient error during edge upsert: {str(e)}")
+                    raise
+            except Exception as e:
+                logger.error(f"Unexpected error during edge upsert: {str(e)}")
+                raise

     async def delete_node(self, node_id: str) -> None:
         """Delete a node with the specified label
diff --git a/lightrag/operate.py b/lightrag/operate.py
index d978d77a..11476c7f 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1285,7 +1285,7 @@ async def merge_nodes_and_edges(
             namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
             # Sort the edge_key components to ensure consistent lock key generation
             sorted_edge_key = sorted([edge_key[0], edge_key[1]])
-            logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
+            # logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
             async with get_storage_keyed_lock(
                 f"{sorted_edge_key[0]}-{sorted_edge_key[1]}",
                 namespace=namespace,
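
Reviewer note (not part of the patch): the Python sketch below is a minimal, self-contained approximation of the retry pattern this change introduces — bounded attempts, exponential backoff with jitter, and walking an exception's __cause__ chain to decide whether a failure is transient. The names run_with_retry, make_flaky_operation, and the local TransientError class are hypothetical stand-ins so the example runs without Memgraph or the neo4j driver; the real code embeds this loop directly in upsert_node and upsert_edge.

import asyncio
import random


class TransientError(Exception):
    """Stand-in for neo4j.exceptions.TransientError so this sketch runs offline."""


async def run_with_retry(
    operation,
    max_retries: int = 100,
    initial_wait_time: float = 0.2,
    backoff_factor: float = 1.1,
    jitter_factor: float = 0.1,
):
    """Retry an async operation on transient errors, mirroring the loop in this patch."""
    for attempt in range(max_retries):
        try:
            return await operation()
        except Exception as e:
            # Walk the __cause__ chain (as the patch does) to find the root cause.
            root_cause = e
            while root_cause.__cause__ is not None:
                root_cause = root_cause.__cause__

            is_transient = isinstance(root_cause, TransientError) or (
                "Cannot resolve conflicting transactions" in str(e)
            )
            if not is_transient or attempt == max_retries - 1:
                raise  # non-transient error, or out of attempts

            # Exponential backoff with jitter, same formula as the patch.
            jitter = random.uniform(0, jitter_factor) * initial_wait_time
            wait_time = initial_wait_time * (backoff_factor**attempt) + jitter
            print(f"Attempt #{attempt + 1} failed ({e}); retrying in {wait_time:.3f}s")
            await asyncio.sleep(wait_time)


def make_flaky_operation(failures: int = 2):
    """Build an async operation that raises a transient error `failures` times, then succeeds."""
    calls = {"count": 0}

    async def operation():
        calls["count"] += 1
        if calls["count"] <= failures:
            raise TransientError("Cannot resolve conflicting transactions")
        return "ok"

    return operation


if __name__ == "__main__":
    # Expected output: two retry warnings, then "ok".
    print(asyncio.run(run_with_retry(make_flaky_operation())))

With the defaults used in the patch (initial_wait_time=0.2, backoff_factor=1.1), the waits start small and grow slowly (roughly 0.2 s, 0.22 s, 0.24 s, ...), so conflicting concurrent writes are retried quickly, while the max_retries cap and the non-transient check keep genuine failures from looping indefinitely.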