From e758204ab2a15b6cc269baaebfc6055b141fc146 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 10 Oct 2025 01:58:51 +0800 Subject: [PATCH 1/3] Add PostgreSQL connection retry mechanism with comprehensive error handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Implement connection retry with backoff • Add transient error detection • Pool management with timeout guards --- lightrag/kg/postgres_impl.py | 375 ++++++++++++++++------- tests/test_postgres_retry_integration.py | 324 ++++++++++++++++++++ 2 files changed, 590 insertions(+), 109 deletions(-) create mode 100644 tests/test_postgres_retry_integration.py diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 50c2108f..703730ab 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -5,7 +5,7 @@ import re import datetime from datetime import timezone from dataclasses import dataclass, field -from typing import Any, Union, final +from typing import Any, Awaitable, Callable, TypeVar, Union, final import numpy as np import configparser import ssl @@ -14,10 +14,13 @@ import itertools from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from tenacity import ( + AsyncRetrying, + RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential, + wait_fixed, ) from ..base import ( @@ -48,6 +51,8 @@ from dotenv import load_dotenv # the OS environment variables take precedence over the .env file load_dotenv(dotenv_path=".env", override=False) +T = TypeVar("T") + class PostgreSQLDB: def __init__(self, config: dict[str, Any], **kwargs: Any): @@ -80,9 +85,52 @@ class PostgreSQLDB: # Statement LRU cache size (keep as-is, allow None for optional configuration) self.statement_cache_size = config.get("statement_cache_size") + # Connection retry configuration + self.connection_retry_attempts = max( + 1, min(10, int(os.environ.get("POSTGRES_CONNECTION_RETRIES", 3))) + ) + self.connection_retry_backoff = max( + 0.1, + min(5.0, float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF", 0.5))), + ) + self.connection_retry_backoff_max = max( + self.connection_retry_backoff, + min( + 60.0, + float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", 5.0)), + ), + ) + self.pool_close_timeout = max( + 1.0, min(30.0, float(os.environ.get("POSTGRES_POOL_CLOSE_TIMEOUT", 5.0))) + ) + + self._transient_exceptions = ( + asyncio.TimeoutError, + TimeoutError, + ConnectionError, + OSError, + asyncpg.exceptions.InterfaceError, + asyncpg.exceptions.TooManyConnectionsError, + asyncpg.exceptions.CannotConnectNowError, + asyncpg.exceptions.PostgresConnectionError, + asyncpg.exceptions.ConnectionDoesNotExistError, + asyncpg.exceptions.ConnectionFailureError, + ) + + # Guard concurrent pool resets + self._pool_reconnect_lock = asyncio.Lock() + if self.user is None or self.password is None or self.database is None: raise ValueError("Missing database user, password, or database") + logger.info( + "PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs", + self.connection_retry_attempts, + self.connection_retry_backoff, + self.connection_retry_backoff_max, + self.pool_close_timeout, + ) + def _create_ssl_context(self) -> ssl.SSLContext | None: """Create SSL context based on configuration parameters.""" if not self.ssl_mode: @@ -154,63 +202,87 @@ class PostgreSQLDB: return None async def initdb(self): + # Prepare connection parameters + connection_params = { + "user": self.user, + "password": 
self.password, + "database": self.database, + "host": self.host, + "port": self.port, + "min_size": 1, + "max_size": self.max, + } + + # Only add statement_cache_size if it's configured + if self.statement_cache_size is not None: + connection_params["statement_cache_size"] = int( + self.statement_cache_size + ) + logger.info( + f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}" + ) + + # Add SSL configuration if provided + ssl_context = self._create_ssl_context() + if ssl_context is not None: + connection_params["ssl"] = ssl_context + logger.info("PostgreSQL, SSL configuration applied") + elif self.ssl_mode: + # Handle simple SSL modes without custom context + if self.ssl_mode.lower() in ["require", "prefer"]: + connection_params["ssl"] = True + elif self.ssl_mode.lower() == "disable": + connection_params["ssl"] = False + logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}") + + # Add server settings if provided + if self.server_settings: + try: + settings = {} + # The format is expected to be a query string, e.g., "key1=value1&key2=value2" + pairs = self.server_settings.split("&") + for pair in pairs: + if "=" in pair: + key, value = pair.split("=", 1) + settings[key] = value + if settings: + connection_params["server_settings"] = settings + logger.info(f"PostgreSQL, Server settings applied: {settings}") + except Exception as e: + logger.warning( + f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}" + ) + + wait_strategy = ( + wait_exponential( + multiplier=self.connection_retry_backoff, + min=self.connection_retry_backoff, + max=self.connection_retry_backoff_max, + ) + if self.connection_retry_backoff > 0 + else wait_fixed(0) + ) + + async def _create_pool_once() -> None: + pool = await asyncpg.create_pool(**connection_params) # type: ignore + try: + async with pool.acquire() as connection: + await self.configure_vector_extension(connection) + except Exception: + await pool.close() + raise + self.pool = pool + try: - # Prepare connection parameters - connection_params = { - "user": self.user, - "password": self.password, - "database": self.database, - "host": self.host, - "port": self.port, - "min_size": 1, - "max_size": self.max, - } - - # Only add statement_cache_size if it's configured - if self.statement_cache_size is not None: - connection_params["statement_cache_size"] = int( - self.statement_cache_size - ) - logger.info( - f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}" - ) - - # Add SSL configuration if provided - ssl_context = self._create_ssl_context() - if ssl_context is not None: - connection_params["ssl"] = ssl_context - logger.info("PostgreSQL, SSL configuration applied") - elif self.ssl_mode: - # Handle simple SSL modes without custom context - if self.ssl_mode.lower() in ["require", "prefer"]: - connection_params["ssl"] = True - elif self.ssl_mode.lower() == "disable": - connection_params["ssl"] = False - logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}") - - # Add server settings if provided - if self.server_settings: - try: - settings = {} - # The format is expected to be a query string, e.g., "key1=value1&key2=value2" - pairs = self.server_settings.split("&") - for pair in pairs: - if "=" in pair: - key, value = pair.split("=", 1) - settings[key] = value - if settings: - connection_params["server_settings"] = settings - logger.info(f"PostgreSQL, Server settings applied: {settings}") - except Exception as e: - logger.warning( - f"PostgreSQL, Failed to parse 
server_settings: {self.server_settings}, error: {e}" - ) - - self.pool = await asyncpg.create_pool(**connection_params) # type: ignore - - # Ensure VECTOR extension is available - async with self.pool.acquire() as connection: - await self.configure_vector_extension(connection) + async for attempt in AsyncRetrying( + stop=stop_after_attempt(self.connection_retry_attempts), + retry=retry_if_exception_type(self._transient_exceptions), + wait=wait_strategy, + before_sleep=self._before_sleep, + reraise=True, + ): + with attempt: + await _create_pool_once() ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL" logger.info( @@ -222,12 +294,97 @@ class PostgreSQLDB: ) raise + async def _ensure_pool(self) -> None: + """Ensure the connection pool is initialised.""" + if self.pool is None: + async with self._pool_reconnect_lock: + if self.pool is None: + await self.initdb() + + async def _reset_pool(self) -> None: + async with self._pool_reconnect_lock: + if self.pool is not None: + try: + await asyncio.wait_for( + self.pool.close(), timeout=self.pool_close_timeout + ) + except asyncio.TimeoutError: + logger.error( + "PostgreSQL, Timed out closing connection pool after %.2fs", + self.pool_close_timeout, + ) + except Exception as close_error: # pragma: no cover - defensive logging + logger.warning( + f"PostgreSQL, Failed to close existing connection pool cleanly: {close_error!r}" + ) + self.pool = None + + async def _before_sleep(self, retry_state: RetryCallState) -> None: + """Hook invoked by tenacity before sleeping between retries.""" + exc = retry_state.outcome.exception() if retry_state.outcome else None + logger.warning( + "PostgreSQL transient connection issue on attempt %s/%s: %r", + retry_state.attempt_number, + self.connection_retry_attempts, + exc, + ) + await self._reset_pool() + + async def _run_with_retry( + self, + operation: Callable[[asyncpg.Connection], Awaitable[T]], + *, + with_age: bool = False, + graph_name: str | None = None, + ) -> T: + """ + Execute a database operation with automatic retry for transient failures. + + Args: + operation: Async callable that receives an active connection. + with_age: Whether to configure Apache AGE on the connection. + graph_name: AGE graph name; required when with_age is True. + + Returns: + The result returned by the operation. + + Raises: + Exception: Propagates the last error if all retry attempts fail or a non-transient error occurs. 
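
+        Example:
+            A minimal illustrative sketch; ``fetch_server_version`` is a
+            hypothetical operation used only for demonstration, not part of
+            this module:
+
+                async def fetch_server_version(conn: asyncpg.Connection) -> str:
+                    return await conn.fetchval("SHOW server_version")
+
+                version = await db._run_with_retry(fetch_server_version)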
+ """ + wait_strategy = ( + wait_exponential( + multiplier=self.connection_retry_backoff, + min=self.connection_retry_backoff, + max=self.connection_retry_backoff_max, + ) + if self.connection_retry_backoff > 0 + else wait_fixed(0) + ) + + async for attempt in AsyncRetrying( + stop=stop_after_attempt(self.connection_retry_attempts), + retry=retry_if_exception_type(self._transient_exceptions), + wait=wait_strategy, + before_sleep=self._before_sleep, + reraise=True, + ): + with attempt: + await self._ensure_pool() + assert self.pool is not None + async with self.pool.acquire() as connection: # type: ignore[arg-type] + if with_age and graph_name: + await self.configure_age(connection, graph_name) + elif with_age and not graph_name: + raise ValueError("Graph name is required when with_age is True") + + return await operation(connection) + @staticmethod async def configure_vector_extension(connection: asyncpg.Connection) -> None: """Create VECTOR extension if it doesn't exist for vector similarity operations.""" try: await connection.execute("CREATE EXTENSION IF NOT EXISTS vector") # type: ignore - logger.info("VECTOR extension ensured for PostgreSQL") + logger.info("PostgreSQL, VECTOR extension enabled") except Exception as e: logger.warning(f"Could not create VECTOR extension: {e}") # Don't raise - let the system continue without vector extension @@ -237,7 +394,7 @@ class PostgreSQLDB: """Create AGE extension if it doesn't exist for graph operations.""" try: await connection.execute("CREATE EXTENSION IF NOT EXISTS age") # type: ignore - logger.info("AGE extension ensured for PostgreSQL") + logger.info("PostgreSQL, AGE extension enabled") except Exception as e: logger.warning(f"Could not create AGE extension: {e}") # Don't raise - let the system continue without AGE extension @@ -1254,35 +1411,31 @@ class PostgreSQLDB: with_age: bool = False, graph_name: str | None = None, ) -> dict[str, Any] | None | list[dict[str, Any]]: - async with self.pool.acquire() as connection: # type: ignore - if with_age and graph_name: - await self.configure_age(connection, graph_name) # type: ignore - elif with_age and not graph_name: - raise ValueError("Graph name is required when with_age is True") + async def _operation(connection: asyncpg.Connection) -> Any: + prepared_params = tuple(params) if params else () + if prepared_params: + rows = await connection.fetch(sql, *prepared_params) + else: + rows = await connection.fetch(sql) - try: - if params: - rows = await connection.fetch(sql, *params) - else: - rows = await connection.fetch(sql) + if multirows: + if rows: + columns = [col for col in rows[0].keys()] + return [dict(zip(columns, row)) for row in rows] + return [] - if multirows: - if rows: - columns = [col for col in rows[0].keys()] - data = [dict(zip(columns, row)) for row in rows] - else: - data = [] - else: - if rows: - columns = rows[0].keys() - data = dict(zip(columns, rows[0])) - else: - data = None + if rows: + columns = rows[0].keys() + return dict(zip(columns, rows[0])) + return None - return data - except Exception as e: - logger.error(f"PostgreSQL database, error:{e}") - raise + try: + return await self._run_with_retry( + _operation, with_age=with_age, graph_name=graph_name + ) + except Exception as e: + logger.error(f"PostgreSQL database, error:{e}") + raise async def execute( self, @@ -1293,30 +1446,33 @@ class PostgreSQLDB: with_age: bool = False, graph_name: str | None = None, ): - try: - async with self.pool.acquire() as connection: # type: ignore - if with_age and graph_name: - await 
self.configure_age(connection, graph_name) - elif with_age and not graph_name: - raise ValueError("Graph name is required when with_age is True") + async def _operation(connection: asyncpg.Connection) -> Any: + prepared_values = tuple(data.values()) if data else () + try: + if not data: + return await connection.execute(sql) + return await connection.execute(sql, *prepared_values) + except ( + asyncpg.exceptions.UniqueViolationError, + asyncpg.exceptions.DuplicateTableError, + asyncpg.exceptions.DuplicateObjectError, + asyncpg.exceptions.InvalidSchemaNameError, + ) as e: + if ignore_if_exists: + logger.debug("PostgreSQL, ignoring duplicate during execute: %r", e) + return None + if upsert: + logger.info( + "PostgreSQL, duplicate detected but treated as upsert success: %r", + e, + ) + return None + raise - if data is None: - await connection.execute(sql) - else: - await connection.execute(sql, *data.values()) - except ( - asyncpg.exceptions.UniqueViolationError, - asyncpg.exceptions.DuplicateTableError, - asyncpg.exceptions.DuplicateObjectError, # Catch "already exists" error - asyncpg.exceptions.InvalidSchemaNameError, # Also catch for AGE extension "already exists" - ) as e: - if ignore_if_exists: - # If the flag is set, just ignore these specific errors - pass - elif upsert: - print("Key value duplicate, but upsert succeeded.") - else: - logger.error(f"Upsert error: {e}") + try: + await self._run_with_retry( + _operation, with_age=with_age, graph_name=graph_name + ) except Exception as e: logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}") raise @@ -3183,7 +3339,8 @@ class PGGraphStorage(BaseGraphStorage): { "message": f"Error executing graph query: {query}", "wrapped": query, - "detail": str(e), + "detail": repr(e), + "error_type": e.__class__.__name__, } ) from e diff --git a/tests/test_postgres_retry_integration.py b/tests/test_postgres_retry_integration.py new file mode 100644 index 00000000..71e5c47d --- /dev/null +++ b/tests/test_postgres_retry_integration.py @@ -0,0 +1,324 @@ +""" +Integration test suite for PostgreSQL retry mechanism using real database. + +This test suite connects to a real PostgreSQL database using credentials from .env +and tests the retry mechanism with actual network failures. + +Prerequisites: +1. PostgreSQL server running and accessible +2. .env file with POSTGRES_* configuration +3. 
asyncpg installed: pip install asyncpg +""" + +import pytest +import asyncio +import os +import time +from dotenv import load_dotenv +from unittest.mock import patch +import asyncpg +from lightrag.kg.postgres_impl import PostgreSQLDB + +# Load environment variables +load_dotenv(dotenv_path=".env", override=False) + + +class TestPostgresRetryIntegration: + """Integration tests for PostgreSQL retry mechanism with real database.""" + + @pytest.fixture + def db_config(self): + """Load database configuration from environment variables.""" + return { + "host": os.getenv("POSTGRES_HOST", "localhost"), + "port": int(os.getenv("POSTGRES_PORT", "5432")), + "user": os.getenv("POSTGRES_USER", "postgres"), + "password": os.getenv("POSTGRES_PASSWORD", ""), + "database": os.getenv("POSTGRES_DATABASE", "postgres"), + "workspace": os.getenv("POSTGRES_WORKSPACE", "test_retry"), + "max_connections": int(os.getenv("POSTGRES_MAX_CONNECTIONS", "10")), + } + + @pytest.fixture + def test_env(self, monkeypatch): + """Set up test environment variables for retry configuration.""" + monkeypatch.setenv("POSTGRES_CONNECTION_RETRIES", "3") + monkeypatch.setenv("POSTGRES_CONNECTION_RETRY_BACKOFF", "0.5") + monkeypatch.setenv("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", "2.0") + monkeypatch.setenv("POSTGRES_POOL_CLOSE_TIMEOUT", "3.0") + + @pytest.mark.asyncio + async def test_real_connection_success(self, db_config, test_env): + """ + Test successful connection to real PostgreSQL database. + + This validates that: + 1. Database credentials are correct + 2. Connection pool initializes properly + 3. Basic query works + """ + print("\n" + "=" * 80) + print("INTEGRATION TEST 1: Real Database Connection") + print("=" * 80) + print( + f" → Connecting to {db_config['host']}:{db_config['port']}/{db_config['database']}" + ) + + db = PostgreSQLDB(db_config) + + try: + # Initialize database connection + await db.initdb() + print(" ✓ Connection successful") + + # Test simple query + result = await db.query("SELECT 1 as test", multirows=False) + assert result is not None + assert result.get("test") == 1 + print(" ✓ Query executed successfully") + + print("\n✅ Test passed: Real database connection works") + print("=" * 80) + finally: + if db.pool: + await db.pool.close() + + @pytest.mark.asyncio + async def test_simulated_transient_error_with_real_db(self, db_config, test_env): + """ + Test retry mechanism with simulated transient errors on real database. + + Simulates connection failures on first 2 attempts, then succeeds. 
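
+        With the retry settings from the ``test_env`` fixture (3 attempts,
+        0.5s initial backoff capped at 2.0s), the two injected failures are
+        absorbed by the retry loop and the third attempt reaches the real
+        database.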
+ """ + print("\n" + "=" * 80) + print("INTEGRATION TEST 2: Simulated Transient Errors") + print("=" * 80) + + db = PostgreSQLDB(db_config) + attempt_count = {"value": 0} + + # Original create_pool function + original_create_pool = asyncpg.create_pool + + async def mock_create_pool_with_failures(*args, **kwargs): + """Mock that fails first 2 times, then calls real create_pool.""" + attempt_count["value"] += 1 + print(f" → Connection attempt {attempt_count['value']}") + + if attempt_count["value"] <= 2: + print(" ✗ Simulating connection failure") + raise asyncpg.exceptions.ConnectionFailureError( + f"Simulated failure on attempt {attempt_count['value']}" + ) + + print(" ✓ Allowing real connection") + return await original_create_pool(*args, **kwargs) + + try: + # Patch create_pool to simulate failures + with patch( + "asyncpg.create_pool", side_effect=mock_create_pool_with_failures + ): + await db.initdb() + + assert ( + attempt_count["value"] == 3 + ), f"Expected 3 attempts, got {attempt_count['value']}" + assert db.pool is not None, "Pool should be initialized after retries" + + # Verify database is actually working + result = await db.query("SELECT 1 as test", multirows=False) + assert result.get("test") == 1 + + print( + f"\n✅ Test passed: Retry mechanism worked, connected after {attempt_count['value']} attempts" + ) + print("=" * 80) + finally: + if db.pool: + await db.pool.close() + + @pytest.mark.asyncio + async def test_query_retry_with_real_db(self, db_config, test_env): + """ + Test query-level retry with simulated connection issues. + + Tests that queries retry on transient failures by simulating + a temporary database unavailability. + """ + print("\n" + "=" * 80) + print("INTEGRATION TEST 3: Query-Level Retry") + print("=" * 80) + + db = PostgreSQLDB(db_config) + + try: + # First initialize normally + await db.initdb() + print(" ✓ Database initialized") + + # Close the pool to simulate connection loss + print(" → Simulating connection loss (closing pool)...") + await db.pool.close() + db.pool = None + + # Now query should trigger pool recreation and retry + print(" → Attempting query (should auto-reconnect)...") + result = await db.query("SELECT 1 as test", multirows=False) + + assert result.get("test") == 1, "Query should succeed after reconnection" + assert db.pool is not None, "Pool should be recreated" + + print(" ✓ Query succeeded after automatic reconnection") + print("\n✅ Test passed: Auto-reconnection works correctly") + print("=" * 80) + finally: + if db.pool: + await db.pool.close() + + @pytest.mark.asyncio + async def test_concurrent_queries_with_real_db(self, db_config, test_env): + """ + Test concurrent queries to validate thread safety and connection pooling. + + Runs multiple concurrent queries to ensure no deadlocks or race conditions. 
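
+        Each query holds a pooled connection for ~0.1s via ``pg_sleep``, so
+        with up to POSTGRES_MAX_CONNECTIONS connections available the batch
+        should complete far faster than running the ten queries serially.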
+ """ + print("\n" + "=" * 80) + print("INTEGRATION TEST 4: Concurrent Queries") + print("=" * 80) + + db = PostgreSQLDB(db_config) + + try: + await db.initdb() + print(" ✓ Database initialized") + + # Launch 10 concurrent queries + num_queries = 10 + print(f" → Launching {num_queries} concurrent queries...") + + async def run_query(query_id): + result = await db.query( + f"SELECT {query_id} as id, pg_sleep(0.1)", multirows=False + ) + return result.get("id") + + start_time = time.time() + tasks = [run_query(i) for i in range(num_queries)] + results = await asyncio.gather(*tasks, return_exceptions=True) + elapsed = time.time() - start_time + + # Check results + successful = sum(1 for r in results if not isinstance(r, Exception)) + failed = sum(1 for r in results if isinstance(r, Exception)) + + print(f" → Completed in {elapsed:.2f}s") + print(f" → Results: {successful} successful, {failed} failed") + + assert ( + successful == num_queries + ), f"All {num_queries} queries should succeed" + assert failed == 0, "No queries should fail" + + print("\n✅ Test passed: All concurrent queries succeeded, no deadlocks") + print("=" * 80) + finally: + if db.pool: + await db.pool.close() + + @pytest.mark.asyncio + async def test_pool_close_timeout_real(self, db_config, test_env): + """ + Test pool close timeout protection with real database. + """ + print("\n" + "=" * 80) + print("INTEGRATION TEST 5: Pool Close Timeout") + print("=" * 80) + + db = PostgreSQLDB(db_config) + + try: + await db.initdb() + print(" ✓ Database initialized") + + # Trigger pool reset (which includes close) + print(" → Triggering pool reset...") + start_time = time.time() + await db._reset_pool() + elapsed = time.time() - start_time + + print(f" ✓ Pool reset completed in {elapsed:.2f}s") + assert db.pool is None, "Pool should be None after reset" + assert ( + elapsed < db.pool_close_timeout + 1 + ), "Reset should complete within timeout" + + print("\n✅ Test passed: Pool reset handled correctly") + print("=" * 80) + finally: + # Already closed in test + pass + + @pytest.mark.asyncio + async def test_configuration_from_env(self, db_config): + """ + Test that configuration is correctly loaded from environment variables. 
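
+        No live connection is needed: the test only constructs PostgreSQLDB
+        and inspects the resulting attributes.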
+        """
+        print("\n" + "=" * 80)
+        print("INTEGRATION TEST 6: Environment Configuration")
+        print("=" * 80)
+
+        db = PostgreSQLDB(db_config)
+
+        print(" → Configuration loaded:")
+        print(f"   • Host: {db.host}")
+        print(f"   • Port: {db.port}")
+        print(f"   • Database: {db.database}")
+        print(f"   • User: {db.user}")
+        print(f"   • Workspace: {db.workspace}")
+        print(f"   • Max Connections: {db.max}")
+        print(f"   • Retry Attempts: {db.connection_retry_attempts}")
+        print(f"   • Retry Backoff: {db.connection_retry_backoff}s")
+        print(f"   • Max Backoff: {db.connection_retry_backoff_max}s")
+        print(f"   • Pool Close Timeout: {db.pool_close_timeout}s")
+
+        # Verify required fields are present
+        assert db.host, "Host should be configured"
+        assert db.port, "Port should be configured"
+        assert db.user, "User should be configured"
+        assert db.database, "Database should be configured"
+
+        print("\n✅ Test passed: All configuration loaded correctly from .env")
+        print("=" * 80)
+
+
+def run_integration_tests():
+    """Run all integration tests with detailed output."""
+    print("\n" + "=" * 80)
+    print("POSTGRESQL RETRY MECHANISM - INTEGRATION TESTS")
+    print("Testing with REAL database from .env configuration")
+    print("=" * 80)
+
+    # Check if database configuration exists
+    if not os.getenv("POSTGRES_HOST"):
+        print("\n⚠️ WARNING: No POSTGRES_HOST in .env file")
+        print("Please ensure .env file exists with PostgreSQL configuration.")
+        return
+
+    print("\nRunning integration tests...\n")
+
+    # Run pytest with verbose output
+    pytest.main(
+        [
+            __file__,
+            "-v",
+            "-s",  # Don't capture output
+            "--tb=short",  # Short traceback format
+            "--color=yes",
+            "-x",  # Stop on first failure
+        ]
+    )
+
+
+if __name__ == "__main__":
+    run_integration_tests()

From bd535e3e7a1835afdbef53c016622c2d8e0c0be3 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 10 Oct 2025 02:30:11 +0800
Subject: [PATCH 2/3] Add PostgreSQL connection retry configuration options

- Add retry environment variables
- Fix asyncpg import in retry tests (skip before importing lightrag, which
  itself imports asyncpg)
---
 env.example                              | 10 ++++++++++
 tests/test_postgres_retry_integration.py |  3 ++-
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/env.example b/env.example
index f59c2969..d86b98f7 100644
--- a/env.example
+++ b/env.example
@@ -303,6 +303,16 @@ POSTGRES_HNSW_M=16
 POSTGRES_HNSW_EF=200
 POSTGRES_IVFFLAT_LISTS=100
 
+### PostgreSQL Connection Retry Configuration (Network Robustness)
+### Number of retry attempts (1-10, default: 3)
+### Initial retry backoff in seconds (0.1-5.0, default: 0.5)
+### Maximum retry backoff in seconds (backoff-60.0, default: 5.0)
+### Connection pool close timeout in seconds (1.0-30.0, default: 5.0)
+# POSTGRES_CONNECTION_RETRIES=3
+# POSTGRES_CONNECTION_RETRY_BACKOFF=0.5
+# POSTGRES_CONNECTION_RETRY_BACKOFF_MAX=5.0
+# POSTGRES_POOL_CLOSE_TIMEOUT=5.0
+
 ### PostgreSQL SSL Configuration (Optional)
 # POSTGRES_SSL_MODE=require
 # POSTGRES_SSL_CERT=/path/to/client-cert.pem
diff --git a/tests/test_postgres_retry_integration.py b/tests/test_postgres_retry_integration.py
index 71e5c47d..137ba686 100644
--- a/tests/test_postgres_retry_integration.py
+++ b/tests/test_postgres_retry_integration.py
@@ -16,9 +16,10 @@ import os
 import time
 from dotenv import load_dotenv
 from unittest.mock import patch
-import asyncpg
+asyncpg = pytest.importorskip("asyncpg")
+
 from lightrag.kg.postgres_impl import PostgreSQLDB
 
 # Load environment variables
 load_dotenv(dotenv_path=".env", override=False)
 

From b3ed2647073727f79be56fdcd77d862fe1250130 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 10 Oct 2025 
03:44:13 +0800
Subject: [PATCH 3/3] Refactor PostgreSQL retry config to use centralized configuration
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

• Move retry config to ClientManager
• Remove env var parsing from PostgreSQLDB
• Add config params to test setup
• Keep the documented lower bounds when clamping retry values
---
 lightrag/kg/postgres_impl.py             | 93 +++++++++++++++++--------
 tests/test_postgres_retry_integration.py | 14 ++++
 2 files changed, 80 insertions(+), 27 deletions(-)

diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 703730ab..93f0cad7 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -85,24 +85,11 @@ class PostgreSQLDB:
         # Statement LRU cache size (keep as-is, allow None for optional configuration)
         self.statement_cache_size = config.get("statement_cache_size")
 
-        # Connection retry configuration
-        self.connection_retry_attempts = max(
-            1, min(10, int(os.environ.get("POSTGRES_CONNECTION_RETRIES", 3)))
-        )
-        self.connection_retry_backoff = max(
-            0.1,
-            min(5.0, float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF", 0.5))),
-        )
-        self.connection_retry_backoff_max = max(
-            self.connection_retry_backoff,
-            min(
-                60.0,
-                float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", 5.0)),
-            ),
-        )
-        self.pool_close_timeout = max(
-            1.0, min(30.0, float(os.environ.get("POSTGRES_POOL_CLOSE_TIMEOUT", 5.0)))
-        )
+        if self.user is None or self.password is None or self.database is None:
+            raise ValueError("Missing database user, password, or database")
+
+        # Guard concurrent pool resets
+        self._pool_reconnect_lock = asyncio.Lock()
 
         self._transient_exceptions = (
             asyncio.TimeoutError,
@@ -117,12 +104,14 @@ class PostgreSQLDB:
             asyncpg.exceptions.ConnectionFailureError,
         )
 
-        # Guard concurrent pool resets
-        self._pool_reconnect_lock = asyncio.Lock()
-
-        if self.user is None or self.password is None or self.database is None:
-            raise ValueError("Missing database user, password, or database")
-
+        # Connection retry configuration (clamped by the caller's config)
+        self.connection_retry_attempts = config["connection_retry_attempts"]
+        self.connection_retry_backoff = config["connection_retry_backoff"]
+        self.connection_retry_backoff_max = max(
+            self.connection_retry_backoff,
+            config["connection_retry_backoff_max"],
+        )
+        self.pool_close_timeout = config["pool_close_timeout"]
         logger.info(
             "PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs",
             self.connection_retry_attempts,
@@ -215,9 +204,7 @@ class PostgreSQLDB:
 
         # Only add statement_cache_size if it's configured
         if self.statement_cache_size is not None:
-            connection_params["statement_cache_size"] = int(
-                self.statement_cache_size
-            )
+            connection_params["statement_cache_size"] = int(self.statement_cache_size)
             logger.info(
                 f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}"
             )
@@ -1566,6 +1553,58 @@ class ClientManager:
                 "POSTGRES_STATEMENT_CACHE_SIZE",
                 config.get("postgres", "statement_cache_size", fallback=None),
             ),
+            # Connection retry configuration (clamped to the documented ranges)
+            "connection_retry_attempts": max(
+                1,
+                min(
+                    10,
+                    int(
+                        os.environ.get(
+                            "POSTGRES_CONNECTION_RETRIES",
+                            config.get("postgres", "connection_retries", fallback=3),
+                        )
+                    ),
+                ),
+            ),
+            "connection_retry_backoff": max(
+                0.1,
+                min(
+                    5.0,
+                    float(
+                        os.environ.get(
+                            "POSTGRES_CONNECTION_RETRY_BACKOFF",
+                            config.get(
+                                "postgres", "connection_retry_backoff", fallback=0.5
+                            ),
+                        )
+                    ),
+                ),
+            ),
+            "connection_retry_backoff_max": min(
+                60.0,
+                float(
+                    os.environ.get(
+                        "POSTGRES_CONNECTION_RETRY_BACKOFF_MAX",
+                        config.get(
+                            "postgres",
+                            "connection_retry_backoff_max",
+                            fallback=5.0,
+                        ),
+                    )
+                ),
+            ),
+            "pool_close_timeout": max(
+                1.0,
+                min(
+                    30.0,
+                    float(
+                        os.environ.get(
+                            "POSTGRES_POOL_CLOSE_TIMEOUT",
+                            config.get("postgres", "pool_close_timeout", fallback=5.0),
+                        )
+                    ),
+                ),
+            ),
         }
 
     @classmethod
diff --git a/tests/test_postgres_retry_integration.py b/tests/test_postgres_retry_integration.py
index 137ba686..515f3072 100644
--- a/tests/test_postgres_retry_integration.py
+++ b/tests/test_postgres_retry_integration.py
@@ -38,6 +38,20 @@ class TestPostgresRetryIntegration:
             "database": os.getenv("POSTGRES_DATABASE", "postgres"),
             "workspace": os.getenv("POSTGRES_WORKSPACE", "test_retry"),
             "max_connections": int(os.getenv("POSTGRES_MAX_CONNECTIONS", "10")),
+            # Connection retry configuration (same clamping as ClientManager)
+            "connection_retry_attempts": max(
+                1, min(10, int(os.getenv("POSTGRES_CONNECTION_RETRIES", "3")))
+            ),
+            "connection_retry_backoff": max(
+                0.1,
+                min(5.0, float(os.getenv("POSTGRES_CONNECTION_RETRY_BACKOFF", "0.5"))),
+            ),
+            "connection_retry_backoff_max": min(
+                60.0, float(os.getenv("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", "5.0"))
+            ),
+            "pool_close_timeout": max(
+                1.0, min(30.0, float(os.getenv("POSTGRES_POOL_CLOSE_TIMEOUT", "5.0")))
+            ),
         }
 
     @pytest.fixture