diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index 16185775..63ab13fb 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -11,11 +11,20 @@ import pipmaster as pm if not pm.is_installed("neo4j"): pm.install("neo4j") +if not pm.is_installed("tenacity"): + pm.install("tenacity") from neo4j import ( AsyncGraphDatabase, AsyncManagedTransaction, ) +from neo4j.exceptions import TransientError +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, +) from dotenv import load_dotenv @@ -102,6 +111,25 @@ class MemgraphStorage(BaseGraphStorage): # Memgraph handles persistence automatically pass + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=1, max=10), + retry=retry_if_exception_type(TransientError), + reraise=True, + ) + async def _execute_write_with_retry(self, session, operation_func): + """ + Execute a write operation with retry logic for Memgraph transient errors. + + Args: + session: Neo4j session + operation_func: Async function that takes a transaction and executes the operation + + Raises: + TransientError: If all retry attempts fail + """ + return await session.execute_write(operation_func) + async def has_node(self, node_id: str) -> bool: """ Check if a node exists in the graph. @@ -435,7 +463,7 @@ class MemgraphStorage(BaseGraphStorage): async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ - Upsert a node in the Memgraph database. + Upsert a node in the Memgraph database with retry logic for transient errors. Args: node_id: The unique identifier for the node (used as label) @@ -467,16 +495,21 @@ class MemgraphStorage(BaseGraphStorage): ) await result.consume() # Ensure result is fully consumed - await session.execute_write(execute_upsert) + await self._execute_write_with_retry(session, execute_upsert) + except TransientError as e: + logger.error( + f"Memgraph transient error during node upsert after retries: {str(e)}" + ) + raise except Exception as e: - logger.error(f"Error during upsert: {str(e)}") + logger.error(f"Error during node upsert: {str(e)}") raise async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: """ - Upsert an edge and its properties between two nodes identified by their labels. + Upsert an edge and its properties between two nodes identified by their labels with retry logic for transient errors. Ensures both source and target nodes exist and are unique before creating the edge. Uses entity_id property to uniquely identify nodes. @@ -517,7 +550,12 @@ class MemgraphStorage(BaseGraphStorage): finally: await result.consume() # Ensure result is consumed - await session.execute_write(execute_upsert) + await self._execute_write_with_retry(session, execute_upsert) + except TransientError as e: + logger.error( + f"Memgraph transient error during edge upsert after retries: {str(e)}" + ) + raise except Exception as e: logger.error(f"Error during edge upsert: {str(e)}") raise