Refactor PostgreSQL retry config to use centralized configuration

• Move retry config to ClientManager
• Remove env var parsing from PostgreSQLDB
• Add config params to test setup

(cherry picked from commit b3ed264707)
This commit is contained in:
yangdx 2025-10-10 03:44:13 +08:00 committed by Raphaël MANSUY
parent d5154bca73
commit 60a695539a
2 changed files with 309 additions and 109 deletions

View file

@ -5,7 +5,7 @@ import re
import datetime
from datetime import timezone
from dataclasses import dataclass, field
from typing import Any, Union, final
from typing import Any, Awaitable, Callable, TypeVar, Union, final
import numpy as np
import configparser
import ssl
@ -14,10 +14,13 @@ import itertools
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from tenacity import (
AsyncRetrying,
RetryCallState,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
wait_fixed,
)
from ..base import (
@ -48,6 +51,8 @@ from dotenv import load_dotenv
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
T = TypeVar("T")
class PostgreSQLDB:
def __init__(self, config: dict[str, Any], **kwargs: Any):
@ -83,6 +88,38 @@ class PostgreSQLDB:
if self.user is None or self.password is None or self.database is None:
raise ValueError("Missing database user, password, or database")
# Guard concurrent pool resets
self._pool_reconnect_lock = asyncio.Lock()
self._transient_exceptions = (
asyncio.TimeoutError,
TimeoutError,
ConnectionError,
OSError,
asyncpg.exceptions.InterfaceError,
asyncpg.exceptions.TooManyConnectionsError,
asyncpg.exceptions.CannotConnectNowError,
asyncpg.exceptions.PostgresConnectionError,
asyncpg.exceptions.ConnectionDoesNotExistError,
asyncpg.exceptions.ConnectionFailureError,
)
# Connection retry configuration
self.connection_retry_attempts = config["connection_retry_attempts"]
self.connection_retry_backoff = config["connection_retry_backoff"]
self.connection_retry_backoff_max = max(
self.connection_retry_backoff,
config["connection_retry_backoff_max"],
)
self.pool_close_timeout = config["pool_close_timeout"]
logger.info(
"PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs",
self.connection_retry_attempts,
self.connection_retry_backoff,
self.connection_retry_backoff_max,
self.pool_close_timeout,
)
def _create_ssl_context(self) -> ssl.SSLContext | None:
"""Create SSL context based on configuration parameters."""
if not self.ssl_mode:
@ -154,63 +191,85 @@ class PostgreSQLDB:
return None
async def initdb(self):
# Prepare connection parameters
connection_params = {
"user": self.user,
"password": self.password,
"database": self.database,
"host": self.host,
"port": self.port,
"min_size": 1,
"max_size": self.max,
}
# Only add statement_cache_size if it's configured
if self.statement_cache_size is not None:
connection_params["statement_cache_size"] = int(self.statement_cache_size)
logger.info(
f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}"
)
# Add SSL configuration if provided
ssl_context = self._create_ssl_context()
if ssl_context is not None:
connection_params["ssl"] = ssl_context
logger.info("PostgreSQL, SSL configuration applied")
elif self.ssl_mode:
# Handle simple SSL modes without custom context
if self.ssl_mode.lower() in ["require", "prefer"]:
connection_params["ssl"] = True
elif self.ssl_mode.lower() == "disable":
connection_params["ssl"] = False
logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
# Add server settings if provided
if self.server_settings:
try:
settings = {}
# The format is expected to be a query string, e.g., "key1=value1&key2=value2"
pairs = self.server_settings.split("&")
for pair in pairs:
if "=" in pair:
key, value = pair.split("=", 1)
settings[key] = value
if settings:
connection_params["server_settings"] = settings
logger.info(f"PostgreSQL, Server settings applied: {settings}")
except Exception as e:
logger.warning(
f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}"
)
wait_strategy = (
wait_exponential(
multiplier=self.connection_retry_backoff,
min=self.connection_retry_backoff,
max=self.connection_retry_backoff_max,
)
if self.connection_retry_backoff > 0
else wait_fixed(0)
)
async def _create_pool_once() -> None:
pool = await asyncpg.create_pool(**connection_params) # type: ignore
try:
async with pool.acquire() as connection:
await self.configure_vector_extension(connection)
except Exception:
await pool.close()
raise
self.pool = pool
try:
# Prepare connection parameters
connection_params = {
"user": self.user,
"password": self.password,
"database": self.database,
"host": self.host,
"port": self.port,
"min_size": 1,
"max_size": self.max,
}
# Only add statement_cache_size if it's configured
if self.statement_cache_size is not None:
connection_params["statement_cache_size"] = int(
self.statement_cache_size
)
logger.info(
f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}"
)
# Add SSL configuration if provided
ssl_context = self._create_ssl_context()
if ssl_context is not None:
connection_params["ssl"] = ssl_context
logger.info("PostgreSQL, SSL configuration applied")
elif self.ssl_mode:
# Handle simple SSL modes without custom context
if self.ssl_mode.lower() in ["require", "prefer"]:
connection_params["ssl"] = True
elif self.ssl_mode.lower() == "disable":
connection_params["ssl"] = False
logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
# Add server settings if provided
if self.server_settings:
try:
settings = {}
# The format is expected to be a query string, e.g., "key1=value1&key2=value2"
pairs = self.server_settings.split("&")
for pair in pairs:
if "=" in pair:
key, value = pair.split("=", 1)
settings[key] = value
if settings:
connection_params["server_settings"] = settings
logger.info(f"PostgreSQL, Server settings applied: {settings}")
except Exception as e:
logger.warning(
f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}"
)
self.pool = await asyncpg.create_pool(**connection_params) # type: ignore
# Ensure VECTOR extension is available
async with self.pool.acquire() as connection:
await self.configure_vector_extension(connection)
async for attempt in AsyncRetrying(
stop=stop_after_attempt(self.connection_retry_attempts),
retry=retry_if_exception_type(self._transient_exceptions),
wait=wait_strategy,
before_sleep=self._before_sleep,
reraise=True,
):
with attempt:
await _create_pool_once()
ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
logger.info(
@ -222,12 +281,97 @@ class PostgreSQLDB:
)
raise
async def _ensure_pool(self) -> None:
"""Ensure the connection pool is initialised."""
if self.pool is None:
async with self._pool_reconnect_lock:
if self.pool is None:
await self.initdb()
async def _reset_pool(self) -> None:
async with self._pool_reconnect_lock:
if self.pool is not None:
try:
await asyncio.wait_for(
self.pool.close(), timeout=self.pool_close_timeout
)
except asyncio.TimeoutError:
logger.error(
"PostgreSQL, Timed out closing connection pool after %.2fs",
self.pool_close_timeout,
)
except Exception as close_error: # pragma: no cover - defensive logging
logger.warning(
f"PostgreSQL, Failed to close existing connection pool cleanly: {close_error!r}"
)
self.pool = None
async def _before_sleep(self, retry_state: RetryCallState) -> None:
"""Hook invoked by tenacity before sleeping between retries."""
exc = retry_state.outcome.exception() if retry_state.outcome else None
logger.warning(
"PostgreSQL transient connection issue on attempt %s/%s: %r",
retry_state.attempt_number,
self.connection_retry_attempts,
exc,
)
await self._reset_pool()
async def _run_with_retry(
self,
operation: Callable[[asyncpg.Connection], Awaitable[T]],
*,
with_age: bool = False,
graph_name: str | None = None,
) -> T:
"""
Execute a database operation with automatic retry for transient failures.
Args:
operation: Async callable that receives an active connection.
with_age: Whether to configure Apache AGE on the connection.
graph_name: AGE graph name; required when with_age is True.
Returns:
The result returned by the operation.
Raises:
Exception: Propagates the last error if all retry attempts fail or a non-transient error occurs.
"""
wait_strategy = (
wait_exponential(
multiplier=self.connection_retry_backoff,
min=self.connection_retry_backoff,
max=self.connection_retry_backoff_max,
)
if self.connection_retry_backoff > 0
else wait_fixed(0)
)
async for attempt in AsyncRetrying(
stop=stop_after_attempt(self.connection_retry_attempts),
retry=retry_if_exception_type(self._transient_exceptions),
wait=wait_strategy,
before_sleep=self._before_sleep,
reraise=True,
):
with attempt:
await self._ensure_pool()
assert self.pool is not None
async with self.pool.acquire() as connection: # type: ignore[arg-type]
if with_age and graph_name:
await self.configure_age(connection, graph_name)
elif with_age and not graph_name:
raise ValueError("Graph name is required when with_age is True")
return await operation(connection)
@staticmethod
async def configure_vector_extension(connection: asyncpg.Connection) -> None:
"""Create VECTOR extension if it doesn't exist for vector similarity operations."""
try:
await connection.execute("CREATE EXTENSION IF NOT EXISTS vector") # type: ignore
logger.info("VECTOR extension ensured for PostgreSQL")
logger.info("PostgreSQL, VECTOR extension enabled")
except Exception as e:
logger.warning(f"Could not create VECTOR extension: {e}")
# Don't raise - let the system continue without vector extension
@ -237,7 +381,7 @@ class PostgreSQLDB:
"""Create AGE extension if it doesn't exist for graph operations."""
try:
await connection.execute("CREATE EXTENSION IF NOT EXISTS age") # type: ignore
logger.info("AGE extension ensured for PostgreSQL")
logger.info("PostgreSQL, AGE extension enabled")
except Exception as e:
logger.warning(f"Could not create AGE extension: {e}")
# Don't raise - let the system continue without AGE extension
@ -1254,35 +1398,31 @@ class PostgreSQLDB:
with_age: bool = False,
graph_name: str | None = None,
) -> dict[str, Any] | None | list[dict[str, Any]]:
async with self.pool.acquire() as connection: # type: ignore
if with_age and graph_name:
await self.configure_age(connection, graph_name) # type: ignore
elif with_age and not graph_name:
raise ValueError("Graph name is required when with_age is True")
async def _operation(connection: asyncpg.Connection) -> Any:
prepared_params = tuple(params) if params else ()
if prepared_params:
rows = await connection.fetch(sql, *prepared_params)
else:
rows = await connection.fetch(sql)
try:
if params:
rows = await connection.fetch(sql, *params)
else:
rows = await connection.fetch(sql)
if multirows:
if rows:
columns = [col for col in rows[0].keys()]
return [dict(zip(columns, row)) for row in rows]
return []
if multirows:
if rows:
columns = [col for col in rows[0].keys()]
data = [dict(zip(columns, row)) for row in rows]
else:
data = []
else:
if rows:
columns = rows[0].keys()
data = dict(zip(columns, rows[0]))
else:
data = None
if rows:
columns = rows[0].keys()
return dict(zip(columns, rows[0]))
return None
return data
except Exception as e:
logger.error(f"PostgreSQL database, error:{e}")
raise
try:
return await self._run_with_retry(
_operation, with_age=with_age, graph_name=graph_name
)
except Exception as e:
logger.error(f"PostgreSQL database, error:{e}")
raise
async def execute(
self,
@ -1293,30 +1433,33 @@ class PostgreSQLDB:
with_age: bool = False,
graph_name: str | None = None,
):
try:
async with self.pool.acquire() as connection: # type: ignore
if with_age and graph_name:
await self.configure_age(connection, graph_name)
elif with_age and not graph_name:
raise ValueError("Graph name is required when with_age is True")
async def _operation(connection: asyncpg.Connection) -> Any:
prepared_values = tuple(data.values()) if data else ()
try:
if not data:
return await connection.execute(sql)
return await connection.execute(sql, *prepared_values)
except (
asyncpg.exceptions.UniqueViolationError,
asyncpg.exceptions.DuplicateTableError,
asyncpg.exceptions.DuplicateObjectError,
asyncpg.exceptions.InvalidSchemaNameError,
) as e:
if ignore_if_exists:
logger.debug("PostgreSQL, ignoring duplicate during execute: %r", e)
return None
if upsert:
logger.info(
"PostgreSQL, duplicate detected but treated as upsert success: %r",
e,
)
return None
raise
if data is None:
await connection.execute(sql)
else:
await connection.execute(sql, *data.values())
except (
asyncpg.exceptions.UniqueViolationError,
asyncpg.exceptions.DuplicateTableError,
asyncpg.exceptions.DuplicateObjectError, # Catch "already exists" error
asyncpg.exceptions.InvalidSchemaNameError, # Also catch for AGE extension "already exists"
) as e:
if ignore_if_exists:
# If the flag is set, just ignore these specific errors
pass
elif upsert:
print("Key value duplicate, but upsert succeeded.")
else:
logger.error(f"Upsert error: {e}")
try:
await self._run_with_retry(
_operation, with_age=with_age, graph_name=graph_name
)
except Exception as e:
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
raise
@ -1410,6 +1553,49 @@ class ClientManager:
"POSTGRES_STATEMENT_CACHE_SIZE",
config.get("postgres", "statement_cache_size", fallback=None),
),
# Connection retry configuration
"connection_retry_attempts": min(
10,
int(
os.environ.get(
"POSTGRES_CONNECTION_RETRIES",
config.get("postgres", "connection_retries", fallback=3),
)
),
),
"connection_retry_backoff": min(
5.0,
float(
os.environ.get(
"POSTGRES_CONNECTION_RETRY_BACKOFF",
config.get(
"postgres", "connection_retry_backoff", fallback=0.5
),
)
),
),
"connection_retry_backoff_max": min(
60.0,
float(
os.environ.get(
"POSTGRES_CONNECTION_RETRY_BACKOFF_MAX",
config.get(
"postgres",
"connection_retry_backoff_max",
fallback=5.0,
),
)
),
),
"pool_close_timeout": min(
30.0,
float(
os.environ.get(
"POSTGRES_POOL_CLOSE_TIMEOUT",
config.get("postgres", "pool_close_timeout", fallback=5.0),
)
),
),
}
@classmethod
@ -3183,7 +3369,8 @@ class PGGraphStorage(BaseGraphStorage):
{
"message": f"Error executing graph query: {query}",
"wrapped": query,
"detail": str(e),
"detail": repr(e),
"error_type": e.__class__.__name__,
}
) from e

View file

@ -37,6 +37,19 @@ class TestPostgresRetryIntegration:
"database": os.getenv("POSTGRES_DATABASE", "postgres"),
"workspace": os.getenv("POSTGRES_WORKSPACE", "test_retry"),
"max_connections": int(os.getenv("POSTGRES_MAX_CONNECTIONS", "10")),
# Connection retry configuration
"connection_retry_attempts": min(
10, int(os.getenv("POSTGRES_CONNECTION_RETRIES", "3"))
),
"connection_retry_backoff": min(
5.0, float(os.getenv("POSTGRES_CONNECTION_RETRY_BACKOFF", "0.5"))
),
"connection_retry_backoff_max": min(
60.0, float(os.getenv("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", "5.0"))
),
"pool_close_timeout": min(
30.0, float(os.getenv("POSTGRES_POOL_CLOSE_TIMEOUT", "5.0"))
),
}
@pytest.fixture