Merge pull request #2192 from danielaskdd/postgres-network-retry
## Refactor: Add PostgreSQL Connection Retry Mechanism with Network Robustness
This commit is contained in:
commit
b4d61eb8f5
3 changed files with 644 additions and 109 deletions
10
env.example
10
env.example
|
|
@ -303,6 +303,16 @@ POSTGRES_HNSW_M=16
|
||||||
POSTGRES_HNSW_EF=200
|
POSTGRES_HNSW_EF=200
|
||||||
POSTGRES_IVFFLAT_LISTS=100
|
POSTGRES_IVFFLAT_LISTS=100
|
||||||
|
|
||||||
|
### PostgreSQL Connection Retry Configuration (Network Robustness)
|
||||||
|
### Number of retry attempts (1-10, default: 3)
|
||||||
|
### Initial retry backoff in seconds (0.1-5.0, default: 0.5)
|
||||||
|
### Maximum retry backoff in seconds (from the initial backoff value up to 60.0, default: 5.0)
|
||||||
|
### Connection pool close timeout in seconds (1.0-30.0, default: 5.0)
|
||||||
|
# POSTGRES_CONNECTION_RETRIES=3
|
||||||
|
# POSTGRES_CONNECTION_RETRY_BACKOFF=0.5
|
||||||
|
# POSTGRES_CONNECTION_RETRY_BACKOFF_MAX=5.0
|
||||||
|
# POSTGRES_POOL_CLOSE_TIMEOUT=5.0
|
||||||
|
|
||||||
### PostgreSQL SSL Configuration (Optional)
|
### PostgreSQL SSL Configuration (Optional)
|
||||||
# POSTGRES_SSL_MODE=require
|
# POSTGRES_SSL_MODE=require
|
||||||
# POSTGRES_SSL_CERT=/path/to/client-cert.pem
|
# POSTGRES_SSL_CERT=/path/to/client-cert.pem
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import re
|
||||||
import datetime
|
import datetime
|
||||||
from datetime import timezone
|
from datetime import timezone
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Union, final
|
from typing import Any, Awaitable, Callable, TypeVar, Union, final
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import configparser
|
import configparser
|
||||||
import ssl
|
import ssl
|
||||||
|
|
@ -14,10 +14,13 @@ import itertools
|
||||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
|
AsyncRetrying,
|
||||||
|
RetryCallState,
|
||||||
retry,
|
retry,
|
||||||
retry_if_exception_type,
|
retry_if_exception_type,
|
||||||
stop_after_attempt,
|
stop_after_attempt,
|
||||||
wait_exponential,
|
wait_exponential,
|
||||||
|
wait_fixed,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..base import (
|
from ..base import (
|
||||||
|
|
@ -48,6 +51,8 @@ from dotenv import load_dotenv
|
||||||
# the OS environment variables take precedence over the .env file
|
# the OS environment variables take precedence over the .env file
|
||||||
load_dotenv(dotenv_path=".env", override=False)
|
load_dotenv(dotenv_path=".env", override=False)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class PostgreSQLDB:
|
class PostgreSQLDB:
|
||||||
def __init__(self, config: dict[str, Any], **kwargs: Any):
|
def __init__(self, config: dict[str, Any], **kwargs: Any):
|
||||||
|
|
@ -83,6 +88,38 @@ class PostgreSQLDB:
|
||||||
if self.user is None or self.password is None or self.database is None:
|
if self.user is None or self.password is None or self.database is None:
|
||||||
raise ValueError("Missing database user, password, or database")
|
raise ValueError("Missing database user, password, or database")
|
||||||
|
|
||||||
|
# Guard concurrent pool resets
|
||||||
|
self._pool_reconnect_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
self._transient_exceptions = (
|
||||||
|
asyncio.TimeoutError,
|
||||||
|
TimeoutError,
|
||||||
|
ConnectionError,
|
||||||
|
OSError,
|
||||||
|
asyncpg.exceptions.InterfaceError,
|
||||||
|
asyncpg.exceptions.TooManyConnectionsError,
|
||||||
|
asyncpg.exceptions.CannotConnectNowError,
|
||||||
|
asyncpg.exceptions.PostgresConnectionError,
|
||||||
|
asyncpg.exceptions.ConnectionDoesNotExistError,
|
||||||
|
asyncpg.exceptions.ConnectionFailureError,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Connection retry configuration
|
||||||
|
self.connection_retry_attempts = config["connection_retry_attempts"]
|
||||||
|
self.connection_retry_backoff = config["connection_retry_backoff"]
|
||||||
|
self.connection_retry_backoff_max = max(
|
||||||
|
self.connection_retry_backoff,
|
||||||
|
config["connection_retry_backoff_max"],
|
||||||
|
)
|
||||||
|
self.pool_close_timeout = config["pool_close_timeout"]
|
||||||
|
logger.info(
|
||||||
|
"PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs",
|
||||||
|
self.connection_retry_attempts,
|
||||||
|
self.connection_retry_backoff,
|
||||||
|
self.connection_retry_backoff_max,
|
||||||
|
self.pool_close_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
def _create_ssl_context(self) -> ssl.SSLContext | None:
|
def _create_ssl_context(self) -> ssl.SSLContext | None:
|
||||||
"""Create SSL context based on configuration parameters."""
|
"""Create SSL context based on configuration parameters."""
|
||||||
if not self.ssl_mode:
|
if not self.ssl_mode:
|
||||||
|
|
@ -154,63 +191,85 @@ class PostgreSQLDB:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def initdb(self):
|
async def initdb(self):
|
||||||
|
# Prepare connection parameters
|
||||||
|
connection_params = {
|
||||||
|
"user": self.user,
|
||||||
|
"password": self.password,
|
||||||
|
"database": self.database,
|
||||||
|
"host": self.host,
|
||||||
|
"port": self.port,
|
||||||
|
"min_size": 1,
|
||||||
|
"max_size": self.max,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Only add statement_cache_size if it's configured
|
||||||
|
if self.statement_cache_size is not None:
|
||||||
|
connection_params["statement_cache_size"] = int(self.statement_cache_size)
|
||||||
|
logger.info(
|
||||||
|
f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add SSL configuration if provided
|
||||||
|
ssl_context = self._create_ssl_context()
|
||||||
|
if ssl_context is not None:
|
||||||
|
connection_params["ssl"] = ssl_context
|
||||||
|
logger.info("PostgreSQL, SSL configuration applied")
|
||||||
|
elif self.ssl_mode:
|
||||||
|
# Handle simple SSL modes without custom context
|
||||||
|
if self.ssl_mode.lower() in ["require", "prefer"]:
|
||||||
|
connection_params["ssl"] = True
|
||||||
|
elif self.ssl_mode.lower() == "disable":
|
||||||
|
connection_params["ssl"] = False
|
||||||
|
logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
|
||||||
|
|
||||||
|
# Add server settings if provided
|
||||||
|
if self.server_settings:
|
||||||
|
try:
|
||||||
|
settings = {}
|
||||||
|
# The format is expected to be a query string, e.g., "key1=value1&key2=value2"
|
||||||
|
pairs = self.server_settings.split("&")
|
||||||
|
for pair in pairs:
|
||||||
|
if "=" in pair:
|
||||||
|
key, value = pair.split("=", 1)
|
||||||
|
settings[key] = value
|
||||||
|
if settings:
|
||||||
|
connection_params["server_settings"] = settings
|
||||||
|
logger.info(f"PostgreSQL, Server settings applied: {settings}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
wait_strategy = (
|
||||||
|
wait_exponential(
|
||||||
|
multiplier=self.connection_retry_backoff,
|
||||||
|
min=self.connection_retry_backoff,
|
||||||
|
max=self.connection_retry_backoff_max,
|
||||||
|
)
|
||||||
|
if self.connection_retry_backoff > 0
|
||||||
|
else wait_fixed(0)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _create_pool_once() -> None:
|
||||||
|
pool = await asyncpg.create_pool(**connection_params) # type: ignore
|
||||||
|
try:
|
||||||
|
async with pool.acquire() as connection:
|
||||||
|
await self.configure_vector_extension(connection)
|
||||||
|
except Exception:
|
||||||
|
await pool.close()
|
||||||
|
raise
|
||||||
|
self.pool = pool
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Prepare connection parameters
|
async for attempt in AsyncRetrying(
|
||||||
connection_params = {
|
stop=stop_after_attempt(self.connection_retry_attempts),
|
||||||
"user": self.user,
|
retry=retry_if_exception_type(self._transient_exceptions),
|
||||||
"password": self.password,
|
wait=wait_strategy,
|
||||||
"database": self.database,
|
before_sleep=self._before_sleep,
|
||||||
"host": self.host,
|
reraise=True,
|
||||||
"port": self.port,
|
):
|
||||||
"min_size": 1,
|
with attempt:
|
||||||
"max_size": self.max,
|
await _create_pool_once()
|
||||||
}
|
|
||||||
|
|
||||||
# Only add statement_cache_size if it's configured
|
|
||||||
if self.statement_cache_size is not None:
|
|
||||||
connection_params["statement_cache_size"] = int(
|
|
||||||
self.statement_cache_size
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add SSL configuration if provided
|
|
||||||
ssl_context = self._create_ssl_context()
|
|
||||||
if ssl_context is not None:
|
|
||||||
connection_params["ssl"] = ssl_context
|
|
||||||
logger.info("PostgreSQL, SSL configuration applied")
|
|
||||||
elif self.ssl_mode:
|
|
||||||
# Handle simple SSL modes without custom context
|
|
||||||
if self.ssl_mode.lower() in ["require", "prefer"]:
|
|
||||||
connection_params["ssl"] = True
|
|
||||||
elif self.ssl_mode.lower() == "disable":
|
|
||||||
connection_params["ssl"] = False
|
|
||||||
logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
|
|
||||||
|
|
||||||
# Add server settings if provided
|
|
||||||
if self.server_settings:
|
|
||||||
try:
|
|
||||||
settings = {}
|
|
||||||
# The format is expected to be a query string, e.g., "key1=value1&key2=value2"
|
|
||||||
pairs = self.server_settings.split("&")
|
|
||||||
for pair in pairs:
|
|
||||||
if "=" in pair:
|
|
||||||
key, value = pair.split("=", 1)
|
|
||||||
settings[key] = value
|
|
||||||
if settings:
|
|
||||||
connection_params["server_settings"] = settings
|
|
||||||
logger.info(f"PostgreSQL, Server settings applied: {settings}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(
|
|
||||||
f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.pool = await asyncpg.create_pool(**connection_params) # type: ignore
|
|
||||||
|
|
||||||
# Ensure VECTOR extension is available
|
|
||||||
async with self.pool.acquire() as connection:
|
|
||||||
await self.configure_vector_extension(connection)
|
|
||||||
|
|
||||||
ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
|
ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -222,12 +281,97 @@ class PostgreSQLDB:
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def _ensure_pool(self) -> None:
    """Create the connection pool on first use.

    Returns immediately when ``self.pool`` is already set. Otherwise it
    takes ``_pool_reconnect_lock``, re-checks ``self.pool`` (another task
    may have created it while this one waited for the lock) and only then
    calls ``initdb()``.
    """
    if self.pool is not None:
        return

    async with self._pool_reconnect_lock:
        # Re-check under the lock so only one waiter runs initdb().
        if self.pool is not None:
            return
        await self.initdb()
|
||||||
|
|
||||||
|
async def _reset_pool(self) -> None:
    """Close the current pool (if any) and clear ``self.pool``.

    Runs under ``_pool_reconnect_lock``. The close is bounded by
    ``self.pool_close_timeout`` seconds; a timeout or any other close
    error is logged and the pool reference is dropped anyway.
    """
    async with self._pool_reconnect_lock:
        stale_pool = self.pool
        if stale_pool is None:
            return

        try:
            await asyncio.wait_for(
                stale_pool.close(), timeout=self.pool_close_timeout
            )
        except asyncio.TimeoutError:
            logger.error(
                "PostgreSQL, Timed out closing connection pool after %.2fs",
                self.pool_close_timeout,
            )
        except Exception as close_error:  # pragma: no cover - defensive logging
            logger.warning(
                f"PostgreSQL, Failed to close existing connection pool cleanly: {close_error!r}"
            )

        # Drop the reference whether or not the close succeeded.
        self.pool = None
|
||||||
|
|
||||||
|
async def _before_sleep(self, retry_state: RetryCallState) -> None:
    """Hook invoked by tenacity before sleeping between retries.

    Logs the exception of the failed attempt and, when a pool currently
    exists, tears it down via ``_reset_pool()`` so the next attempt starts
    from a fresh pool.

    Args:
        retry_state: tenacity state for the attempt that just failed.
    """
    exc = retry_state.outcome.exception() if retry_state.outcome else None
    logger.warning(
        "PostgreSQL transient connection issue on attempt %s/%s: %r",
        retry_state.attempt_number,
        self.connection_retry_attempts,
        exc,
    )
    # _ensure_pool() awaits initdb() while holding _pool_reconnect_lock, and
    # initdb() uses this hook for its own retries. _reset_pool() takes the
    # same non-reentrant asyncio.Lock, so calling it unconditionally would
    # deadlock on the first failed attempt in that path. In that path
    # self.pool is still None and there is nothing to close, so skip the reset.
    if self.pool is not None:
        await self._reset_pool()
|
||||||
|
|
||||||
|
async def _run_with_retry(
    self,
    operation: Callable[[asyncpg.Connection], Awaitable[T]],
    *,
    with_age: bool = False,
    graph_name: str | None = None,
) -> T:
    """
    Execute a database operation with automatic retry for transient failures.

    Args:
        operation: Async callable that receives an active connection.
        with_age: Whether to configure Apache AGE on the connection.
        graph_name: AGE graph name; required when with_age is True.

    Returns:
        The result returned by the operation.

    Raises:
        ValueError: If with_age is True but no graph_name is given. Raised
            before any connection is attempted.
        Exception: Propagates the last error if all retry attempts fail or a non-transient error occurs.
    """
    # Validate arguments up front: this is a caller error, so it should not
    # cost a pool initialisation / connection acquire, nor be hidden behind
    # a connection failure when the database is unreachable.
    if with_age and not graph_name:
        raise ValueError("Graph name is required when with_age is True")

    wait_strategy = (
        wait_exponential(
            multiplier=self.connection_retry_backoff,
            min=self.connection_retry_backoff,
            max=self.connection_retry_backoff_max,
        )
        if self.connection_retry_backoff > 0
        else wait_fixed(0)
    )

    async for attempt in AsyncRetrying(
        stop=stop_after_attempt(self.connection_retry_attempts),
        retry=retry_if_exception_type(self._transient_exceptions),
        wait=wait_strategy,
        before_sleep=self._before_sleep,
        reraise=True,
    ):
        with attempt:
            await self._ensure_pool()
            # Take a local reference: another task may reset self.pool
            # between the check and the acquire. An explicit check is used
            # instead of `assert`, which is stripped under `python -O`.
            pool = self.pool
            if pool is None:
                # ConnectionError is listed in _transient_exceptions, so
                # this attempt is retried rather than failing outright.
                raise ConnectionError("PostgreSQL connection pool is unavailable")

            async with pool.acquire() as connection:  # type: ignore[arg-type]
                if with_age:
                    await self.configure_age(connection, graph_name)

                return await operation(connection)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def configure_vector_extension(connection: asyncpg.Connection) -> None:
|
async def configure_vector_extension(connection: asyncpg.Connection) -> None:
|
||||||
"""Create VECTOR extension if it doesn't exist for vector similarity operations."""
|
"""Create VECTOR extension if it doesn't exist for vector similarity operations."""
|
||||||
try:
|
try:
|
||||||
await connection.execute("CREATE EXTENSION IF NOT EXISTS vector") # type: ignore
|
await connection.execute("CREATE EXTENSION IF NOT EXISTS vector") # type: ignore
|
||||||
logger.info("VECTOR extension ensured for PostgreSQL")
|
logger.info("PostgreSQL, VECTOR extension enabled")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not create VECTOR extension: {e}")
|
logger.warning(f"Could not create VECTOR extension: {e}")
|
||||||
# Don't raise - let the system continue without vector extension
|
# Don't raise - let the system continue without vector extension
|
||||||
|
|
@ -237,7 +381,7 @@ class PostgreSQLDB:
|
||||||
"""Create AGE extension if it doesn't exist for graph operations."""
|
"""Create AGE extension if it doesn't exist for graph operations."""
|
||||||
try:
|
try:
|
||||||
await connection.execute("CREATE EXTENSION IF NOT EXISTS age") # type: ignore
|
await connection.execute("CREATE EXTENSION IF NOT EXISTS age") # type: ignore
|
||||||
logger.info("AGE extension ensured for PostgreSQL")
|
logger.info("PostgreSQL, AGE extension enabled")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not create AGE extension: {e}")
|
logger.warning(f"Could not create AGE extension: {e}")
|
||||||
# Don't raise - let the system continue without AGE extension
|
# Don't raise - let the system continue without AGE extension
|
||||||
|
|
@ -1254,35 +1398,31 @@ class PostgreSQLDB:
|
||||||
with_age: bool = False,
|
with_age: bool = False,
|
||||||
graph_name: str | None = None,
|
graph_name: str | None = None,
|
||||||
) -> dict[str, Any] | None | list[dict[str, Any]]:
|
) -> dict[str, Any] | None | list[dict[str, Any]]:
|
||||||
async with self.pool.acquire() as connection: # type: ignore
|
async def _operation(connection: asyncpg.Connection) -> Any:
|
||||||
if with_age and graph_name:
|
prepared_params = tuple(params) if params else ()
|
||||||
await self.configure_age(connection, graph_name) # type: ignore
|
if prepared_params:
|
||||||
elif with_age and not graph_name:
|
rows = await connection.fetch(sql, *prepared_params)
|
||||||
raise ValueError("Graph name is required when with_age is True")
|
else:
|
||||||
|
rows = await connection.fetch(sql)
|
||||||
|
|
||||||
try:
|
if multirows:
|
||||||
if params:
|
if rows:
|
||||||
rows = await connection.fetch(sql, *params)
|
columns = [col for col in rows[0].keys()]
|
||||||
else:
|
return [dict(zip(columns, row)) for row in rows]
|
||||||
rows = await connection.fetch(sql)
|
return []
|
||||||
|
|
||||||
if multirows:
|
if rows:
|
||||||
if rows:
|
columns = rows[0].keys()
|
||||||
columns = [col for col in rows[0].keys()]
|
return dict(zip(columns, rows[0]))
|
||||||
data = [dict(zip(columns, row)) for row in rows]
|
return None
|
||||||
else:
|
|
||||||
data = []
|
|
||||||
else:
|
|
||||||
if rows:
|
|
||||||
columns = rows[0].keys()
|
|
||||||
data = dict(zip(columns, rows[0]))
|
|
||||||
else:
|
|
||||||
data = None
|
|
||||||
|
|
||||||
return data
|
try:
|
||||||
except Exception as e:
|
return await self._run_with_retry(
|
||||||
logger.error(f"PostgreSQL database, error:{e}")
|
_operation, with_age=with_age, graph_name=graph_name
|
||||||
raise
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"PostgreSQL database, error:{e}")
|
||||||
|
raise
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
|
|
@ -1293,30 +1433,33 @@ class PostgreSQLDB:
|
||||||
with_age: bool = False,
|
with_age: bool = False,
|
||||||
graph_name: str | None = None,
|
graph_name: str | None = None,
|
||||||
):
|
):
|
||||||
try:
|
async def _operation(connection: asyncpg.Connection) -> Any:
|
||||||
async with self.pool.acquire() as connection: # type: ignore
|
prepared_values = tuple(data.values()) if data else ()
|
||||||
if with_age and graph_name:
|
try:
|
||||||
await self.configure_age(connection, graph_name)
|
if not data:
|
||||||
elif with_age and not graph_name:
|
return await connection.execute(sql)
|
||||||
raise ValueError("Graph name is required when with_age is True")
|
return await connection.execute(sql, *prepared_values)
|
||||||
|
except (
|
||||||
|
asyncpg.exceptions.UniqueViolationError,
|
||||||
|
asyncpg.exceptions.DuplicateTableError,
|
||||||
|
asyncpg.exceptions.DuplicateObjectError,
|
||||||
|
asyncpg.exceptions.InvalidSchemaNameError,
|
||||||
|
) as e:
|
||||||
|
if ignore_if_exists:
|
||||||
|
logger.debug("PostgreSQL, ignoring duplicate during execute: %r", e)
|
||||||
|
return None
|
||||||
|
if upsert:
|
||||||
|
logger.info(
|
||||||
|
"PostgreSQL, duplicate detected but treated as upsert success: %r",
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
raise
|
||||||
|
|
||||||
if data is None:
|
try:
|
||||||
await connection.execute(sql)
|
await self._run_with_retry(
|
||||||
else:
|
_operation, with_age=with_age, graph_name=graph_name
|
||||||
await connection.execute(sql, *data.values())
|
)
|
||||||
except (
|
|
||||||
asyncpg.exceptions.UniqueViolationError,
|
|
||||||
asyncpg.exceptions.DuplicateTableError,
|
|
||||||
asyncpg.exceptions.DuplicateObjectError, # Catch "already exists" error
|
|
||||||
asyncpg.exceptions.InvalidSchemaNameError, # Also catch for AGE extension "already exists"
|
|
||||||
) as e:
|
|
||||||
if ignore_if_exists:
|
|
||||||
# If the flag is set, just ignore these specific errors
|
|
||||||
pass
|
|
||||||
elif upsert:
|
|
||||||
print("Key value duplicate, but upsert succeeded.")
|
|
||||||
else:
|
|
||||||
logger.error(f"Upsert error: {e}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
|
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
|
||||||
raise
|
raise
|
||||||
|
|
@ -1410,6 +1553,49 @@ class ClientManager:
|
||||||
"POSTGRES_STATEMENT_CACHE_SIZE",
|
"POSTGRES_STATEMENT_CACHE_SIZE",
|
||||||
config.get("postgres", "statement_cache_size", fallback=None),
|
config.get("postgres", "statement_cache_size", fallback=None),
|
||||||
),
|
),
|
||||||
|
# Connection retry configuration
|
||||||
|
"connection_retry_attempts": min(
|
||||||
|
10,
|
||||||
|
int(
|
||||||
|
os.environ.get(
|
||||||
|
"POSTGRES_CONNECTION_RETRIES",
|
||||||
|
config.get("postgres", "connection_retries", fallback=3),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"connection_retry_backoff": min(
|
||||||
|
5.0,
|
||||||
|
float(
|
||||||
|
os.environ.get(
|
||||||
|
"POSTGRES_CONNECTION_RETRY_BACKOFF",
|
||||||
|
config.get(
|
||||||
|
"postgres", "connection_retry_backoff", fallback=0.5
|
||||||
|
),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"connection_retry_backoff_max": min(
|
||||||
|
60.0,
|
||||||
|
float(
|
||||||
|
os.environ.get(
|
||||||
|
"POSTGRES_CONNECTION_RETRY_BACKOFF_MAX",
|
||||||
|
config.get(
|
||||||
|
"postgres",
|
||||||
|
"connection_retry_backoff_max",
|
||||||
|
fallback=5.0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"pool_close_timeout": min(
|
||||||
|
30.0,
|
||||||
|
float(
|
||||||
|
os.environ.get(
|
||||||
|
"POSTGRES_POOL_CLOSE_TIMEOUT",
|
||||||
|
config.get("postgres", "pool_close_timeout", fallback=5.0),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -3183,7 +3369,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||||
{
|
{
|
||||||
"message": f"Error executing graph query: {query}",
|
"message": f"Error executing graph query: {query}",
|
||||||
"wrapped": query,
|
"wrapped": query,
|
||||||
"detail": str(e),
|
"detail": repr(e),
|
||||||
|
"error_type": e.__class__.__name__,
|
||||||
}
|
}
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
|
||||||
338
tests/test_postgres_retry_integration.py
Normal file
338
tests/test_postgres_retry_integration.py
Normal file
|
|
@ -0,0 +1,338 @@
|
||||||
|
"""
|
||||||
|
Integration test suite for PostgreSQL retry mechanism using real database.
|
||||||
|
|
||||||
|
This test suite connects to a real PostgreSQL database using credentials from .env
|
||||||
|
and tests the retry mechanism with actual network failures.
|
||||||
|
|
||||||
|
Prerequisites:
|
||||||
|
1. PostgreSQL server running and accessible
|
||||||
|
2. .env file with POSTGRES_* configuration
|
||||||
|
3. asyncpg installed: pip install asyncpg
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from unittest.mock import patch
|
||||||
|
from lightrag.kg.postgres_impl import PostgreSQLDB
|
||||||
|
|
||||||
|
asyncpg = pytest.importorskip("asyncpg")
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv(dotenv_path=".env", override=False)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPostgresRetryIntegration:
    """Integration tests for PostgreSQL retry mechanism with real database.

    Tests that use both fixtures request ``test_env`` before ``db_config``:
    pytest instantiates same-scope fixtures in argument order, and
    ``db_config`` reads the retry settings from the environment, so the
    monkeypatched values must be in place first.
    """

    @pytest.fixture
    def db_config(self):
        """Load database configuration from environment variables."""
        return {
            "host": os.getenv("POSTGRES_HOST", "localhost"),
            "port": int(os.getenv("POSTGRES_PORT", "5432")),
            "user": os.getenv("POSTGRES_USER", "postgres"),
            "password": os.getenv("POSTGRES_PASSWORD", ""),
            "database": os.getenv("POSTGRES_DATABASE", "postgres"),
            "workspace": os.getenv("POSTGRES_WORKSPACE", "test_retry"),
            "max_connections": int(os.getenv("POSTGRES_MAX_CONNECTIONS", "10")),
            # Connection retry configuration
            "connection_retry_attempts": min(
                10, int(os.getenv("POSTGRES_CONNECTION_RETRIES", "3"))
            ),
            "connection_retry_backoff": min(
                5.0, float(os.getenv("POSTGRES_CONNECTION_RETRY_BACKOFF", "0.5"))
            ),
            "connection_retry_backoff_max": min(
                60.0, float(os.getenv("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", "5.0"))
            ),
            "pool_close_timeout": min(
                30.0, float(os.getenv("POSTGRES_POOL_CLOSE_TIMEOUT", "5.0"))
            ),
        }

    @pytest.fixture
    def test_env(self, monkeypatch):
        """Set up test environment variables for retry configuration."""
        monkeypatch.setenv("POSTGRES_CONNECTION_RETRIES", "3")
        monkeypatch.setenv("POSTGRES_CONNECTION_RETRY_BACKOFF", "0.5")
        monkeypatch.setenv("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", "2.0")
        monkeypatch.setenv("POSTGRES_POOL_CLOSE_TIMEOUT", "3.0")

    @pytest.mark.asyncio
    async def test_real_connection_success(self, test_env, db_config):
        """
        Test successful connection to real PostgreSQL database.

        This validates that:
        1. Database credentials are correct
        2. Connection pool initializes properly
        3. Basic query works
        """
        print("\n" + "=" * 80)
        print("INTEGRATION TEST 1: Real Database Connection")
        print("=" * 80)
        print(
            f" → Connecting to {db_config['host']}:{db_config['port']}/{db_config['database']}"
        )

        db = PostgreSQLDB(db_config)

        try:
            # Initialize database connection
            await db.initdb()
            print(" ✓ Connection successful")

            # Test simple query
            result = await db.query("SELECT 1 as test", multirows=False)
            assert result is not None
            assert result.get("test") == 1
            print(" ✓ Query executed successfully")

            print("\n✅ Test passed: Real database connection works")
            print("=" * 80)
        finally:
            if db.pool:
                await db.pool.close()

    @pytest.mark.asyncio
    async def test_simulated_transient_error_with_real_db(self, test_env, db_config):
        """
        Test retry mechanism with simulated transient errors on real database.

        Simulates connection failures on first 2 attempts, then succeeds.
        """
        print("\n" + "=" * 80)
        print("INTEGRATION TEST 2: Simulated Transient Errors")
        print("=" * 80)

        db = PostgreSQLDB(db_config)
        attempt_count = {"value": 0}

        # Original create_pool function
        original_create_pool = asyncpg.create_pool

        async def mock_create_pool_with_failures(*args, **kwargs):
            """Mock that fails first 2 times, then calls real create_pool."""
            attempt_count["value"] += 1
            print(f" → Connection attempt {attempt_count['value']}")

            if attempt_count["value"] <= 2:
                print(" ✗ Simulating connection failure")
                raise asyncpg.exceptions.ConnectionFailureError(
                    f"Simulated failure on attempt {attempt_count['value']}"
                )

            print(" ✓ Allowing real connection")
            return await original_create_pool(*args, **kwargs)

        try:
            # Patch create_pool to simulate failures
            with patch(
                "asyncpg.create_pool", side_effect=mock_create_pool_with_failures
            ):
                await db.initdb()

            assert (
                attempt_count["value"] == 3
            ), f"Expected 3 attempts, got {attempt_count['value']}"
            assert db.pool is not None, "Pool should be initialized after retries"

            # Verify database is actually working
            result = await db.query("SELECT 1 as test", multirows=False)
            assert result.get("test") == 1

            print(
                f"\n✅ Test passed: Retry mechanism worked, connected after {attempt_count['value']} attempts"
            )
            print("=" * 80)
        finally:
            if db.pool:
                await db.pool.close()

    @pytest.mark.asyncio
    async def test_query_retry_with_real_db(self, test_env, db_config):
        """
        Test query-level retry with simulated connection issues.

        Tests that queries retry on transient failures by simulating
        a temporary database unavailability.
        """
        print("\n" + "=" * 80)
        print("INTEGRATION TEST 3: Query-Level Retry")
        print("=" * 80)

        db = PostgreSQLDB(db_config)

        try:
            # First initialize normally
            await db.initdb()
            print(" ✓ Database initialized")

            # Close the pool to simulate connection loss
            print(" → Simulating connection loss (closing pool)...")
            await db.pool.close()
            db.pool = None

            # Now query should trigger pool recreation and retry
            print(" → Attempting query (should auto-reconnect)...")
            result = await db.query("SELECT 1 as test", multirows=False)

            assert result.get("test") == 1, "Query should succeed after reconnection"
            assert db.pool is not None, "Pool should be recreated"

            print(" ✓ Query succeeded after automatic reconnection")
            print("\n✅ Test passed: Auto-reconnection works correctly")
            print("=" * 80)
        finally:
            if db.pool:
                await db.pool.close()

    @pytest.mark.asyncio
    async def test_concurrent_queries_with_real_db(self, test_env, db_config):
        """
        Test concurrent queries to validate thread safety and connection pooling.

        Runs multiple concurrent queries to ensure no deadlocks or race conditions.
        """
        print("\n" + "=" * 80)
        print("INTEGRATION TEST 4: Concurrent Queries")
        print("=" * 80)

        db = PostgreSQLDB(db_config)

        try:
            await db.initdb()
            print(" ✓ Database initialized")

            # Launch 10 concurrent queries
            num_queries = 10
            print(f" → Launching {num_queries} concurrent queries...")

            async def run_query(query_id):
                result = await db.query(
                    f"SELECT {query_id} as id, pg_sleep(0.1)", multirows=False
                )
                return result.get("id")

            start_time = time.time()
            tasks = [run_query(i) for i in range(num_queries)]
            results = await asyncio.gather(*tasks, return_exceptions=True)
            elapsed = time.time() - start_time

            # Check results
            successful = sum(1 for r in results if not isinstance(r, Exception))
            failed = sum(1 for r in results if isinstance(r, Exception))

            print(f" → Completed in {elapsed:.2f}s")
            print(f" → Results: {successful} successful, {failed} failed")

            assert (
                successful == num_queries
            ), f"All {num_queries} queries should succeed"
            assert failed == 0, "No queries should fail"

            print("\n✅ Test passed: All concurrent queries succeeded, no deadlocks")
            print("=" * 80)
        finally:
            if db.pool:
                await db.pool.close()

    @pytest.mark.asyncio
    async def test_pool_close_timeout_real(self, test_env, db_config):
        """
        Test pool close timeout protection with real database.
        """
        print("\n" + "=" * 80)
        print("INTEGRATION TEST 5: Pool Close Timeout")
        print("=" * 80)

        db = PostgreSQLDB(db_config)

        try:
            await db.initdb()
            print(" ✓ Database initialized")

            # Trigger pool reset (which includes close)
            print(" → Triggering pool reset...")
            start_time = time.time()
            await db._reset_pool()
            elapsed = time.time() - start_time

            print(f" ✓ Pool reset completed in {elapsed:.2f}s")
            assert db.pool is None, "Pool should be None after reset"
            assert (
                elapsed < db.pool_close_timeout + 1
            ), "Reset should complete within timeout"

            print("\n✅ Test passed: Pool reset handled correctly")
            print("=" * 80)
        finally:
            # The reset normally leaves db.pool as None; close only if the
            # reset itself failed so a failing test does not leak the pool.
            if db.pool:
                await db.pool.close()

    @pytest.mark.asyncio
    async def test_configuration_from_env(self, db_config):
        """
        Test that configuration is correctly loaded from environment variables.
        """
        print("\n" + "=" * 80)
        print("INTEGRATION TEST 6: Environment Configuration")
        print("=" * 80)

        db = PostgreSQLDB(db_config)

        print(" → Configuration loaded:")
        print(f" • Host: {db.host}")
        print(f" • Port: {db.port}")
        print(f" • Database: {db.database}")
        print(f" • User: {db.user}")
        print(f" • Workspace: {db.workspace}")
        print(f" • Max Connections: {db.max}")
        print(f" • Retry Attempts: {db.connection_retry_attempts}")
        print(f" • Retry Backoff: {db.connection_retry_backoff}s")
        print(f" • Max Backoff: {db.connection_retry_backoff_max}s")
        print(f" • Pool Close Timeout: {db.pool_close_timeout}s")

        # Verify required fields are present
        assert db.host, "Host should be configured"
        assert db.port, "Port should be configured"
        assert db.user, "User should be configured"
        assert db.database, "Database should be configured"

        print("\n✅ Test passed: All configuration loaded correctly from .env")
        print("=" * 80)
|
||||||
|
|
||||||
|
|
||||||
|
def run_integration_tests():
    """Print a banner and run this file's tests through pytest.

    Does nothing beyond printing a warning when POSTGRES_HOST is not set
    in the environment.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("POSTGRESQL RETRY MECHANISM - INTEGRATION TESTS")
    print("Testing with REAL database from .env configuration")
    print(rule)

    # Check if database configuration exists
    if not os.getenv("POSTGRES_HOST"):
        print("\n⚠️ WARNING: No POSTGRES_HOST in .env file")
        print("Please ensure .env file exists with PostgreSQL configuration.")
        return

    print("\nRunning integration tests...\n")

    # Run pytest with verbose output
    pytest_args = [
        __file__,
        "-v",
        "-s",  # Don't capture output
        "--tb=short",  # Short traceback format
        "--color=yes",
        "-x",  # Stop on first failure
    ]
    pytest.main(pytest_args)


if __name__ == "__main__":
    run_integration_tests()
|
||||||
Loading…
Add table
Reference in a new issue