Refactor PostgreSQL retry config to use centralized configuration
• Move retry config to ClientManager
• Remove env var parsing from PostgreSQLDB
• Add config params to test setup
(cherry picked from commit b3ed264707)
parent d5154bca73
commit 60a695539a
2 changed files with 309 additions and 109 deletions
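The diff below centralizes all retry tuning in ClientManager: environment variables and config.ini values are parsed and clamped in one place, and PostgreSQLDB simply reads plain keys from the config dict it receives. A minimal sketch of that flow, assuming only what the diff shows (the helper name build_retry_config is illustrative, not code from the repo; the env var names, caps, and defaults are taken from the diff):

    import os

    def build_retry_config() -> dict:
        # Caps and defaults mirror the values added to ClientManager in this diff.
        return {
            "connection_retry_attempts": min(
                10, int(os.environ.get("POSTGRES_CONNECTION_RETRIES", "3"))
            ),
            "connection_retry_backoff": min(
                5.0, float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF", "0.5"))
            ),
            "connection_retry_backoff_max": min(
                60.0, float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", "5.0"))
            ),
            "pool_close_timeout": min(
                30.0, float(os.environ.get("POSTGRES_POOL_CLOSE_TIMEOUT", "5.0"))
            ),
        }

PostgreSQLDB then consumes the pre-validated values directly (for example, self.connection_retry_attempts = config["connection_retry_attempts"]), so no environment parsing remains in the storage class.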
@@ -5,7 +5,7 @@ import re
 import datetime
 from datetime import timezone
 from dataclasses import dataclass, field
-from typing import Any, Union, final
+from typing import Any, Awaitable, Callable, TypeVar, Union, final
 import numpy as np
 import configparser
 import ssl
@@ -14,10 +14,13 @@ import itertools
 from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge

 from tenacity import (
+    AsyncRetrying,
+    RetryCallState,
     retry,
     retry_if_exception_type,
     stop_after_attempt,
     wait_exponential,
+    wait_fixed,
 )

 from ..base import (
@@ -48,6 +51,8 @@ from dotenv import load_dotenv
 # the OS environment variables take precedence over the .env file
 load_dotenv(dotenv_path=".env", override=False)

+T = TypeVar("T")
+

 class PostgreSQLDB:
     def __init__(self, config: dict[str, Any], **kwargs: Any):
@@ -83,6 +88,38 @@ class PostgreSQLDB:
         if self.user is None or self.password is None or self.database is None:
             raise ValueError("Missing database user, password, or database")

+        # Guard concurrent pool resets
+        self._pool_reconnect_lock = asyncio.Lock()
+
+        self._transient_exceptions = (
+            asyncio.TimeoutError,
+            TimeoutError,
+            ConnectionError,
+            OSError,
+            asyncpg.exceptions.InterfaceError,
+            asyncpg.exceptions.TooManyConnectionsError,
+            asyncpg.exceptions.CannotConnectNowError,
+            asyncpg.exceptions.PostgresConnectionError,
+            asyncpg.exceptions.ConnectionDoesNotExistError,
+            asyncpg.exceptions.ConnectionFailureError,
+        )
+
+        # Connection retry configuration
+        self.connection_retry_attempts = config["connection_retry_attempts"]
+        self.connection_retry_backoff = config["connection_retry_backoff"]
+        self.connection_retry_backoff_max = max(
+            self.connection_retry_backoff,
+            config["connection_retry_backoff_max"],
+        )
+        self.pool_close_timeout = config["pool_close_timeout"]
+        logger.info(
+            "PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs",
+            self.connection_retry_attempts,
+            self.connection_retry_backoff,
+            self.connection_retry_backoff_max,
+            self.pool_close_timeout,
+        )
+
     def _create_ssl_context(self) -> ssl.SSLContext | None:
         """Create SSL context based on configuration parameters."""
         if not self.ssl_mode:
@@ -154,63 +191,85 @@ class PostgreSQLDB:
         return None

     async def initdb(self):
+        # Prepare connection parameters
+        connection_params = {
+            "user": self.user,
+            "password": self.password,
+            "database": self.database,
+            "host": self.host,
+            "port": self.port,
+            "min_size": 1,
+            "max_size": self.max,
+        }
+
+        # Only add statement_cache_size if it's configured
+        if self.statement_cache_size is not None:
+            connection_params["statement_cache_size"] = int(self.statement_cache_size)
+            logger.info(
+                f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}"
+            )
+
+        # Add SSL configuration if provided
+        ssl_context = self._create_ssl_context()
+        if ssl_context is not None:
+            connection_params["ssl"] = ssl_context
+            logger.info("PostgreSQL, SSL configuration applied")
+        elif self.ssl_mode:
+            # Handle simple SSL modes without custom context
+            if self.ssl_mode.lower() in ["require", "prefer"]:
+                connection_params["ssl"] = True
+            elif self.ssl_mode.lower() == "disable":
+                connection_params["ssl"] = False
+            logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
+
+        # Add server settings if provided
+        if self.server_settings:
+            try:
+                settings = {}
+                # The format is expected to be a query string, e.g., "key1=value1&key2=value2"
+                pairs = self.server_settings.split("&")
+                for pair in pairs:
+                    if "=" in pair:
+                        key, value = pair.split("=", 1)
+                        settings[key] = value
+                if settings:
+                    connection_params["server_settings"] = settings
+                    logger.info(f"PostgreSQL, Server settings applied: {settings}")
+            except Exception as e:
+                logger.warning(
+                    f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}"
+                )
+
+        wait_strategy = (
+            wait_exponential(
+                multiplier=self.connection_retry_backoff,
+                min=self.connection_retry_backoff,
+                max=self.connection_retry_backoff_max,
+            )
+            if self.connection_retry_backoff > 0
+            else wait_fixed(0)
+        )
+
+        async def _create_pool_once() -> None:
+            pool = await asyncpg.create_pool(**connection_params)  # type: ignore
+            try:
+                async with pool.acquire() as connection:
+                    await self.configure_vector_extension(connection)
+            except Exception:
+                await pool.close()
+                raise
+            self.pool = pool
+
         try:
-            # Prepare connection parameters
-            connection_params = {
-                "user": self.user,
-                "password": self.password,
-                "database": self.database,
-                "host": self.host,
-                "port": self.port,
-                "min_size": 1,
-                "max_size": self.max,
-            }
-
-            # Only add statement_cache_size if it's configured
-            if self.statement_cache_size is not None:
-                connection_params["statement_cache_size"] = int(
-                    self.statement_cache_size
-                )
-                logger.info(
-                    f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}"
-                )
-
-            # Add SSL configuration if provided
-            ssl_context = self._create_ssl_context()
-            if ssl_context is not None:
-                connection_params["ssl"] = ssl_context
-                logger.info("PostgreSQL, SSL configuration applied")
-            elif self.ssl_mode:
-                # Handle simple SSL modes without custom context
-                if self.ssl_mode.lower() in ["require", "prefer"]:
-                    connection_params["ssl"] = True
-                elif self.ssl_mode.lower() == "disable":
-                    connection_params["ssl"] = False
-                logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
-
-            # Add server settings if provided
-            if self.server_settings:
-                try:
-                    settings = {}
-                    # The format is expected to be a query string, e.g., "key1=value1&key2=value2"
-                    pairs = self.server_settings.split("&")
-                    for pair in pairs:
-                        if "=" in pair:
-                            key, value = pair.split("=", 1)
-                            settings[key] = value
-                    if settings:
-                        connection_params["server_settings"] = settings
-                        logger.info(f"PostgreSQL, Server settings applied: {settings}")
-                except Exception as e:
-                    logger.warning(
-                        f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}"
-                    )
-
-            self.pool = await asyncpg.create_pool(**connection_params)  # type: ignore
-
-            # Ensure VECTOR extension is available
-            async with self.pool.acquire() as connection:
-                await self.configure_vector_extension(connection)
+            async for attempt in AsyncRetrying(
+                stop=stop_after_attempt(self.connection_retry_attempts),
+                retry=retry_if_exception_type(self._transient_exceptions),
+                wait=wait_strategy,
+                before_sleep=self._before_sleep,
+                reraise=True,
+            ):
+                with attempt:
+                    await _create_pool_once()

             ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
             logger.info(
@@ -222,12 +281,97 @@ class PostgreSQLDB:
             )
             raise

+    async def _ensure_pool(self) -> None:
+        """Ensure the connection pool is initialised."""
+        if self.pool is None:
+            async with self._pool_reconnect_lock:
+                if self.pool is None:
+                    await self.initdb()
+
+    async def _reset_pool(self) -> None:
+        async with self._pool_reconnect_lock:
+            if self.pool is not None:
+                try:
+                    await asyncio.wait_for(
+                        self.pool.close(), timeout=self.pool_close_timeout
+                    )
+                except asyncio.TimeoutError:
+                    logger.error(
+                        "PostgreSQL, Timed out closing connection pool after %.2fs",
+                        self.pool_close_timeout,
+                    )
+                except Exception as close_error:  # pragma: no cover - defensive logging
+                    logger.warning(
+                        f"PostgreSQL, Failed to close existing connection pool cleanly: {close_error!r}"
+                    )
+            self.pool = None
+
+    async def _before_sleep(self, retry_state: RetryCallState) -> None:
+        """Hook invoked by tenacity before sleeping between retries."""
+        exc = retry_state.outcome.exception() if retry_state.outcome else None
+        logger.warning(
+            "PostgreSQL transient connection issue on attempt %s/%s: %r",
+            retry_state.attempt_number,
+            self.connection_retry_attempts,
+            exc,
+        )
+        await self._reset_pool()
+
+    async def _run_with_retry(
+        self,
+        operation: Callable[[asyncpg.Connection], Awaitable[T]],
+        *,
+        with_age: bool = False,
+        graph_name: str | None = None,
+    ) -> T:
+        """
+        Execute a database operation with automatic retry for transient failures.
+
+        Args:
+            operation: Async callable that receives an active connection.
+            with_age: Whether to configure Apache AGE on the connection.
+            graph_name: AGE graph name; required when with_age is True.
+
+        Returns:
+            The result returned by the operation.
+
+        Raises:
+            Exception: Propagates the last error if all retry attempts fail or a non-transient error occurs.
+        """
+        wait_strategy = (
+            wait_exponential(
+                multiplier=self.connection_retry_backoff,
+                min=self.connection_retry_backoff,
+                max=self.connection_retry_backoff_max,
+            )
+            if self.connection_retry_backoff > 0
+            else wait_fixed(0)
+        )
+
+        async for attempt in AsyncRetrying(
+            stop=stop_after_attempt(self.connection_retry_attempts),
+            retry=retry_if_exception_type(self._transient_exceptions),
+            wait=wait_strategy,
+            before_sleep=self._before_sleep,
+            reraise=True,
+        ):
+            with attempt:
+                await self._ensure_pool()
+                assert self.pool is not None
+                async with self.pool.acquire() as connection:  # type: ignore[arg-type]
+                    if with_age and graph_name:
+                        await self.configure_age(connection, graph_name)
+                    elif with_age and not graph_name:
+                        raise ValueError("Graph name is required when with_age is True")
+
+                    return await operation(connection)
+
     @staticmethod
     async def configure_vector_extension(connection: asyncpg.Connection) -> None:
         """Create VECTOR extension if it doesn't exist for vector similarity operations."""
         try:
             await connection.execute("CREATE EXTENSION IF NOT EXISTS vector")  # type: ignore
-            logger.info("VECTOR extension ensured for PostgreSQL")
+            logger.info("PostgreSQL, VECTOR extension enabled")
         except Exception as e:
             logger.warning(f"Could not create VECTOR extension: {e}")
             # Don't raise - let the system continue without vector extension
@@ -237,7 +381,7 @@ class PostgreSQLDB:
         """Create AGE extension if it doesn't exist for graph operations."""
         try:
             await connection.execute("CREATE EXTENSION IF NOT EXISTS age")  # type: ignore
-            logger.info("AGE extension ensured for PostgreSQL")
+            logger.info("PostgreSQL, AGE extension enabled")
         except Exception as e:
             logger.warning(f"Could not create AGE extension: {e}")
             # Don't raise - let the system continue without AGE extension
@@ -1254,35 +1398,31 @@ class PostgreSQLDB:
         with_age: bool = False,
         graph_name: str | None = None,
     ) -> dict[str, Any] | None | list[dict[str, Any]]:
-        async with self.pool.acquire() as connection:  # type: ignore
-            if with_age and graph_name:
-                await self.configure_age(connection, graph_name)  # type: ignore
-            elif with_age and not graph_name:
-                raise ValueError("Graph name is required when with_age is True")
+        async def _operation(connection: asyncpg.Connection) -> Any:
+            prepared_params = tuple(params) if params else ()
+            if prepared_params:
+                rows = await connection.fetch(sql, *prepared_params)
+            else:
+                rows = await connection.fetch(sql)

-            try:
-                if params:
-                    rows = await connection.fetch(sql, *params)
-                else:
-                    rows = await connection.fetch(sql)
+            if multirows:
+                if rows:
+                    columns = [col for col in rows[0].keys()]
+                    return [dict(zip(columns, row)) for row in rows]
+                return []

-                if multirows:
-                    if rows:
-                        columns = [col for col in rows[0].keys()]
-                        data = [dict(zip(columns, row)) for row in rows]
-                    else:
-                        data = []
-                else:
-                    if rows:
-                        columns = rows[0].keys()
-                        data = dict(zip(columns, rows[0]))
-                    else:
-                        data = None
+            if rows:
+                columns = rows[0].keys()
+                return dict(zip(columns, rows[0]))
+            return None

-                return data
-            except Exception as e:
-                logger.error(f"PostgreSQL database, error:{e}")
-                raise
+        try:
+            return await self._run_with_retry(
+                _operation, with_age=with_age, graph_name=graph_name
+            )
+        except Exception as e:
+            logger.error(f"PostgreSQL database, error:{e}")
+            raise

     async def execute(
         self,
@@ -1293,30 +1433,33 @@ class PostgreSQLDB:
         with_age: bool = False,
         graph_name: str | None = None,
     ):
-        try:
-            async with self.pool.acquire() as connection:  # type: ignore
-                if with_age and graph_name:
-                    await self.configure_age(connection, graph_name)
-                elif with_age and not graph_name:
-                    raise ValueError("Graph name is required when with_age is True")
+        async def _operation(connection: asyncpg.Connection) -> Any:
+            prepared_values = tuple(data.values()) if data else ()
+            try:
+                if not data:
+                    return await connection.execute(sql)
+                return await connection.execute(sql, *prepared_values)
+            except (
+                asyncpg.exceptions.UniqueViolationError,
+                asyncpg.exceptions.DuplicateTableError,
+                asyncpg.exceptions.DuplicateObjectError,
+                asyncpg.exceptions.InvalidSchemaNameError,
+            ) as e:
+                if ignore_if_exists:
+                    logger.debug("PostgreSQL, ignoring duplicate during execute: %r", e)
+                    return None
+                if upsert:
+                    logger.info(
+                        "PostgreSQL, duplicate detected but treated as upsert success: %r",
+                        e,
+                    )
+                    return None
+                raise

-                if data is None:
-                    await connection.execute(sql)
-                else:
-                    await connection.execute(sql, *data.values())
-        except (
-            asyncpg.exceptions.UniqueViolationError,
-            asyncpg.exceptions.DuplicateTableError,
-            asyncpg.exceptions.DuplicateObjectError,  # Catch "already exists" error
-            asyncpg.exceptions.InvalidSchemaNameError,  # Also catch for AGE extension "already exists"
-        ) as e:
-            if ignore_if_exists:
-                # If the flag is set, just ignore these specific errors
-                pass
-            elif upsert:
-                print("Key value duplicate, but upsert succeeded.")
-            else:
-                logger.error(f"Upsert error: {e}")
+        try:
+            await self._run_with_retry(
+                _operation, with_age=with_age, graph_name=graph_name
+            )
         except Exception as e:
             logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
             raise
@@ -1410,6 +1553,49 @@ class ClientManager:
                 "POSTGRES_STATEMENT_CACHE_SIZE",
                 config.get("postgres", "statement_cache_size", fallback=None),
             ),
+            # Connection retry configuration
+            "connection_retry_attempts": min(
+                10,
+                int(
+                    os.environ.get(
+                        "POSTGRES_CONNECTION_RETRIES",
+                        config.get("postgres", "connection_retries", fallback=3),
+                    )
+                ),
+            ),
+            "connection_retry_backoff": min(
+                5.0,
+                float(
+                    os.environ.get(
+                        "POSTGRES_CONNECTION_RETRY_BACKOFF",
+                        config.get(
+                            "postgres", "connection_retry_backoff", fallback=0.5
+                        ),
+                    )
+                ),
+            ),
+            "connection_retry_backoff_max": min(
+                60.0,
+                float(
+                    os.environ.get(
+                        "POSTGRES_CONNECTION_RETRY_BACKOFF_MAX",
+                        config.get(
+                            "postgres",
+                            "connection_retry_backoff_max",
+                            fallback=5.0,
+                        ),
+                    )
+                ),
+            ),
+            "pool_close_timeout": min(
+                30.0,
+                float(
+                    os.environ.get(
+                        "POSTGRES_POOL_CLOSE_TIMEOUT",
+                        config.get("postgres", "pool_close_timeout", fallback=5.0),
+                    )
+                ),
+            ),
         }

     @classmethod
@@ -3183,7 +3369,8 @@ class PGGraphStorage(BaseGraphStorage):
                 {
                     "message": f"Error executing graph query: {query}",
                     "wrapped": query,
-                    "detail": str(e),
+                    "detail": repr(e),
+                    "error_type": e.__class__.__name__,
                 }
             ) from e

@@ -37,6 +37,19 @@ class TestPostgresRetryIntegration:
             "database": os.getenv("POSTGRES_DATABASE", "postgres"),
             "workspace": os.getenv("POSTGRES_WORKSPACE", "test_retry"),
             "max_connections": int(os.getenv("POSTGRES_MAX_CONNECTIONS", "10")),
+            # Connection retry configuration
+            "connection_retry_attempts": min(
+                10, int(os.getenv("POSTGRES_CONNECTION_RETRIES", "3"))
+            ),
+            "connection_retry_backoff": min(
+                5.0, float(os.getenv("POSTGRES_CONNECTION_RETRY_BACKOFF", "0.5"))
+            ),
+            "connection_retry_backoff_max": min(
+                60.0, float(os.getenv("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", "5.0"))
+            ),
+            "pool_close_timeout": min(
+                30.0, float(os.getenv("POSTGRES_POOL_CLOSE_TIMEOUT", "5.0"))
+            ),
         }

     @pytest.fixture
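The retry loops added to initdb() and _run_with_retry() above are driven by tenacity's AsyncRetrying iterator. A self-contained sketch of that pattern under the same style of settings (the flaky() coroutine and its call counter are illustrative stand-ins, not code from this repository):

    import asyncio
    from tenacity import (
        AsyncRetrying,
        retry_if_exception_type,
        stop_after_attempt,
        wait_fixed,
    )

    async def main() -> None:
        calls = {"count": 0}

        async def flaky() -> str:
            # Stand-in for acquiring a pool/connection: fails twice, then succeeds.
            calls["count"] += 1
            if calls["count"] < 3:
                raise ConnectionError("transient")
            return "ok"

        # Retry only transient error types, stop after 3 attempts, re-raise the last error.
        async for attempt in AsyncRetrying(
            stop=stop_after_attempt(3),
            retry=retry_if_exception_type(ConnectionError),
            wait=wait_fixed(0),
            reraise=True,
        ):
            with attempt:
                result = await flaky()
        print(result)  # "ok" on the third attempt

    asyncio.run(main())

On a transient failure the with attempt: block records the exception and the loop waits and retries; on success the iteration ends, and with reraise=True the last underlying exception is raised once the stop condition is reached.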