Raphaël MANSUY 2025-12-04 19:18:36 +08:00
parent ebf704701b
commit 64050aeb1e


@@ -5,7 +5,7 @@ import re
import datetime
from datetime import timezone
from dataclasses import dataclass, field
from typing import Any, Union, final
from typing import Any, Awaitable, Callable, TypeVar, Union, final
import numpy as np
import configparser
import ssl
@@ -14,10 +14,13 @@ import itertools
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from tenacity import (
AsyncRetrying,
RetryCallState,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
wait_fixed,
)
from ..base import (
@@ -48,6 +51,8 @@ from dotenv import load_dotenv
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
T = TypeVar("T")
class PostgreSQLDB:
def __init__(self, config: dict[str, Any], **kwargs: Any):
@@ -77,12 +82,44 @@ class PostgreSQLDB:
# Server settings
self.server_settings = config.get("server_settings")
# Statement LRU cache size
self.statement_cache_size = int(config.get("statement_cache_size"))
# Statement LRU cache size (optional; may be None when not configured)
self.statement_cache_size = config.get("statement_cache_size")
if self.user is None or self.password is None or self.database is None:
raise ValueError("Missing database user, password, or database")
# Guard concurrent pool resets
self._pool_reconnect_lock = asyncio.Lock()
self._transient_exceptions = (
asyncio.TimeoutError,
TimeoutError,
ConnectionError,
OSError,
asyncpg.exceptions.InterfaceError,
asyncpg.exceptions.TooManyConnectionsError,
asyncpg.exceptions.CannotConnectNowError,
asyncpg.exceptions.PostgresConnectionError,
asyncpg.exceptions.ConnectionDoesNotExistError,
asyncpg.exceptions.ConnectionFailureError,
)
# Connection retry configuration
self.connection_retry_attempts = config["connection_retry_attempts"]
self.connection_retry_backoff = config["connection_retry_backoff"]
self.connection_retry_backoff_max = max(
self.connection_retry_backoff,
config["connection_retry_backoff_max"],
)
self.pool_close_timeout = config["pool_close_timeout"]
logger.info(
"PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs",
self.connection_retry_attempts,
self.connection_retry_backoff,
self.connection_retry_backoff_max,
self.pool_close_timeout,
)
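For orientation, a minimal sketch of the retry-related keys this constructor now expects in its config dict; only the key names come from the code above, the values are illustrative assumptions:

```python
# Illustrative config fragment for the new retry/pool settings.
# Key names match the constructor above; the values are assumptions.
retry_config = {
    "connection_retry_attempts": 3,       # attempts used by initdb/_run_with_retry
    "connection_retry_backoff": 0.5,      # base delay in seconds (0 disables waiting)
    "connection_retry_backoff_max": 5.0,  # cap on the exponential delay
    "pool_close_timeout": 5.0,            # max seconds to wait when resetting the pool
    "statement_cache_size": None,         # optional; skipped in asyncpg kwargs when None
}
```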
def _create_ssl_context(self) -> ssl.SSLContext | None:
"""Create SSL context based on configuration parameters."""
if not self.ssl_mode:
@@ -154,59 +191,85 @@ class PostgreSQLDB:
return None
async def initdb(self):
try:
# Prepare connection parameters
connection_params = {
"user": self.user,
"password": self.password,
"database": self.database,
"host": self.host,
"port": self.port,
"min_size": 1,
"max_size": self.max,
"statement_cache_size": self.statement_cache_size,
}
# Prepare connection parameters
connection_params = {
"user": self.user,
"password": self.password,
"database": self.database,
"host": self.host,
"port": self.port,
"min_size": 1,
"max_size": self.max,
}
# Only add statement_cache_size if it's configured
if self.statement_cache_size is not None:
connection_params["statement_cache_size"] = int(self.statement_cache_size)
logger.info(
f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}"
)
# Add SSL configuration if provided
ssl_context = self._create_ssl_context()
if ssl_context is not None:
connection_params["ssl"] = ssl_context
logger.info("PostgreSQL, SSL configuration applied")
elif self.ssl_mode:
# Handle simple SSL modes without custom context
if self.ssl_mode.lower() in ["require", "prefer"]:
connection_params["ssl"] = True
elif self.ssl_mode.lower() == "disable":
connection_params["ssl"] = False
logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
# Add SSL configuration if provided
ssl_context = self._create_ssl_context()
if ssl_context is not None:
connection_params["ssl"] = ssl_context
logger.info("PostgreSQL, SSL configuration applied")
elif self.ssl_mode:
# Handle simple SSL modes without custom context
if self.ssl_mode.lower() in ["require", "prefer"]:
connection_params["ssl"] = True
elif self.ssl_mode.lower() == "disable":
connection_params["ssl"] = False
logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
# Add server settings if provided
if self.server_settings:
try:
settings = {}
# The format is expected to be a query string, e.g., "key1=value1&key2=value2"
pairs = self.server_settings.split("&")
for pair in pairs:
if "=" in pair:
key, value = pair.split("=", 1)
settings[key] = value
if settings:
connection_params["server_settings"] = settings
logger.info(f"PostgreSQL, Server settings applied: {settings}")
except Exception as e:
logger.warning(
f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}"
)
# Add server settings if provided
if self.server_settings:
try:
settings = {}
# The format is expected to be a query string, e.g., "key1=value1&key2=value2"
pairs = self.server_settings.split("&")
for pair in pairs:
if "=" in pair:
key, value = pair.split("=", 1)
settings[key] = value
if settings:
connection_params["server_settings"] = settings
logger.info(f"PostgreSQL, Server settings applied: {settings}")
except Exception as e:
logger.warning(
f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}"
)
self.pool = await asyncpg.create_pool(**connection_params) # type: ignore
wait_strategy = (
wait_exponential(
multiplier=self.connection_retry_backoff,
min=self.connection_retry_backoff,
max=self.connection_retry_backoff_max,
)
if self.connection_retry_backoff > 0
else wait_fixed(0)
)
# Ensure VECTOR extension is available
async with self.pool.acquire() as connection:
await self.configure_vector_extension(connection)
async def _create_pool_once() -> None:
pool = await asyncpg.create_pool(**connection_params) # type: ignore
try:
async with pool.acquire() as connection:
await self.configure_vector_extension(connection)
except Exception:
await pool.close()
raise
self.pool = pool
try:
async for attempt in AsyncRetrying(
stop=stop_after_attempt(self.connection_retry_attempts),
retry=retry_if_exception_type(self._transient_exceptions),
wait=wait_strategy,
before_sleep=self._before_sleep,
reraise=True,
):
with attempt:
await _create_pool_once()
ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
logger.info(
@@ -218,12 +281,97 @@ class PostgreSQLDB:
)
raise
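To make the retry loop in `initdb` easier to follow in isolation, here is a minimal, self-contained sketch of the same tenacity pattern; `connect_once` and the numeric defaults are illustrative stand-ins, not part of this change. With `wait_exponential(multiplier=b, min=b, max=m)`, the delay before retry n is roughly `b * 2**(n-1)`, clamped to `[b, m]`.

```python
# Minimal sketch of the AsyncRetrying pattern used in initdb.
# `connect_once` is a hypothetical stand-in for asyncpg.create_pool + extension setup.
import asyncio
from tenacity import (
    AsyncRetrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

TRANSIENT = (ConnectionError, OSError, asyncio.TimeoutError)

async def connect_with_retry(attempts: int = 3, backoff: float = 0.5, backoff_max: float = 5.0):
    async def connect_once() -> str:
        raise ConnectionError("simulated transient failure")

    async for attempt in AsyncRetrying(
        stop=stop_after_attempt(attempts),
        retry=retry_if_exception_type(TRANSIENT),
        wait=wait_exponential(multiplier=backoff, min=backoff, max=backoff_max),
        reraise=True,  # re-raise the last error instead of wrapping it in RetryError
    ):
        with attempt:
            return await connect_once()

# asyncio.run(connect_with_retry())  # raises ConnectionError after 3 attempts
```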
async def _ensure_pool(self) -> None:
"""Ensure the connection pool is initialised."""
if self.pool is None:
async with self._pool_reconnect_lock:
if self.pool is None:
await self.initdb()
async def _reset_pool(self) -> None:
async with self._pool_reconnect_lock:
if self.pool is not None:
try:
await asyncio.wait_for(
self.pool.close(), timeout=self.pool_close_timeout
)
except asyncio.TimeoutError:
logger.error(
"PostgreSQL, Timed out closing connection pool after %.2fs",
self.pool_close_timeout,
)
except Exception as close_error: # pragma: no cover - defensive logging
logger.warning(
f"PostgreSQL, Failed to close existing connection pool cleanly: {close_error!r}"
)
self.pool = None
async def _before_sleep(self, retry_state: RetryCallState) -> None:
"""Hook invoked by tenacity before sleeping between retries."""
exc = retry_state.outcome.exception() if retry_state.outcome else None
logger.warning(
"PostgreSQL transient connection issue on attempt %s/%s: %r",
retry_state.attempt_number,
self.connection_retry_attempts,
exc,
)
await self._reset_pool()
async def _run_with_retry(
self,
operation: Callable[[asyncpg.Connection], Awaitable[T]],
*,
with_age: bool = False,
graph_name: str | None = None,
) -> T:
"""
Execute a database operation with automatic retry for transient failures.
Args:
operation: Async callable that receives an active connection.
with_age: Whether to configure Apache AGE on the connection.
graph_name: AGE graph name; required when with_age is True.
Returns:
The result returned by the operation.
Raises:
Exception: Propagates the last error if all retry attempts fail or a non-transient error occurs.
"""
wait_strategy = (
wait_exponential(
multiplier=self.connection_retry_backoff,
min=self.connection_retry_backoff,
max=self.connection_retry_backoff_max,
)
if self.connection_retry_backoff > 0
else wait_fixed(0)
)
async for attempt in AsyncRetrying(
stop=stop_after_attempt(self.connection_retry_attempts),
retry=retry_if_exception_type(self._transient_exceptions),
wait=wait_strategy,
before_sleep=self._before_sleep,
reraise=True,
):
with attempt:
await self._ensure_pool()
assert self.pool is not None
async with self.pool.acquire() as connection: # type: ignore[arg-type]
if with_age and graph_name:
await self.configure_age(connection, graph_name)
elif with_age and not graph_name:
raise ValueError("Graph name is required when with_age is True")
return await operation(connection)
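As a usage sketch (the query and function name are illustrative, not part of the change), any connection-scoped coroutine can be routed through this wrapper and will transparently survive a pool reset:

```python
# Hypothetical caller of the new retry wrapper; `db` is an initialised PostgreSQLDB.
async def count_chunks(db: "PostgreSQLDB", workspace: str) -> int:
    async def _operation(connection: asyncpg.Connection) -> int:
        return await connection.fetchval(
            "SELECT count(*) FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1", workspace
        )

    # Transient failures trigger _before_sleep -> _reset_pool, then a fresh attempt.
    return await db._run_with_retry(_operation)
```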
@staticmethod
async def configure_vector_extension(connection: asyncpg.Connection) -> None:
"""Create VECTOR extension if it doesn't exist for vector similarity operations."""
try:
await connection.execute("CREATE EXTENSION IF NOT EXISTS vector") # type: ignore
logger.info("VECTOR extension ensured for PostgreSQL")
logger.info("PostgreSQL, VECTOR extension enabled")
except Exception as e:
logger.warning(f"Could not create VECTOR extension: {e}")
# Don't raise - let the system continue without vector extension
@@ -233,7 +381,7 @@ class PostgreSQLDB:
"""Create AGE extension if it doesn't exist for graph operations."""
try:
await connection.execute("CREATE EXTENSION IF NOT EXISTS age") # type: ignore
logger.info("AGE extension ensured for PostgreSQL")
logger.info("PostgreSQL, AGE extension enabled")
except Exception as e:
logger.warning(f"Could not create AGE extension: {e}")
# Don't raise - let the system continue without AGE extension
@@ -1250,35 +1398,31 @@ class PostgreSQLDB:
with_age: bool = False,
graph_name: str | None = None,
) -> dict[str, Any] | None | list[dict[str, Any]]:
async with self.pool.acquire() as connection: # type: ignore
if with_age and graph_name:
await self.configure_age(connection, graph_name) # type: ignore
elif with_age and not graph_name:
raise ValueError("Graph name is required when with_age is True")
async def _operation(connection: asyncpg.Connection) -> Any:
prepared_params = tuple(params) if params else ()
if prepared_params:
rows = await connection.fetch(sql, *prepared_params)
else:
rows = await connection.fetch(sql)
try:
if params:
rows = await connection.fetch(sql, *params)
else:
rows = await connection.fetch(sql)
if multirows:
if rows:
columns = [col for col in rows[0].keys()]
return [dict(zip(columns, row)) for row in rows]
return []
if multirows:
if rows:
columns = [col for col in rows[0].keys()]
data = [dict(zip(columns, row)) for row in rows]
else:
data = []
else:
if rows:
columns = rows[0].keys()
data = dict(zip(columns, rows[0]))
else:
data = None
if rows:
columns = rows[0].keys()
return dict(zip(columns, rows[0]))
return None
return data
except Exception as e:
logger.error(f"PostgreSQL database, error:{e}")
raise
try:
return await self._run_with_retry(
_operation, with_age=with_age, graph_name=graph_name
)
except Exception as e:
logger.error(f"PostgreSQL database, error:{e}")
raise
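For reference, a hedged usage sketch of the refactored `query()`: the parameter list is passed positionally, and binding a Python list to `$2` lets asyncpg send it as an array for `id = ANY($2)` instead of string-formatting an `IN (...)` clause. The workspace and ids below are made-up values.

```python
# Illustrative call; values are assumptions, the calling convention matches the code above.
async def fetch_two_chunks(db: "PostgreSQLDB") -> list[dict]:
    return await db.query(
        "SELECT id, content FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id = ANY($2)",
        ["demo_workspace", ["chunk-1", "chunk-2"]],  # bound as $1 and $2 by asyncpg
        multirows=True,
    )
```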
async def execute(
self,
@@ -1289,30 +1433,33 @@ class PostgreSQLDB:
with_age: bool = False,
graph_name: str | None = None,
):
try:
async with self.pool.acquire() as connection: # type: ignore
if with_age and graph_name:
await self.configure_age(connection, graph_name)
elif with_age and not graph_name:
raise ValueError("Graph name is required when with_age is True")
async def _operation(connection: asyncpg.Connection) -> Any:
prepared_values = tuple(data.values()) if data else ()
try:
if not data:
return await connection.execute(sql)
return await connection.execute(sql, *prepared_values)
except (
asyncpg.exceptions.UniqueViolationError,
asyncpg.exceptions.DuplicateTableError,
asyncpg.exceptions.DuplicateObjectError,
asyncpg.exceptions.InvalidSchemaNameError,
) as e:
if ignore_if_exists:
logger.debug("PostgreSQL, ignoring duplicate during execute: %r", e)
return None
if upsert:
logger.info(
"PostgreSQL, duplicate detected but treated as upsert success: %r",
e,
)
return None
raise
if data is None:
await connection.execute(sql)
else:
await connection.execute(sql, *data.values())
except (
asyncpg.exceptions.UniqueViolationError,
asyncpg.exceptions.DuplicateTableError,
asyncpg.exceptions.DuplicateObjectError, # Catch "already exists" error
asyncpg.exceptions.InvalidSchemaNameError, # Also catch for AGE extension "already exists"
) as e:
if ignore_if_exists:
# If the flag is set, just ignore these specific errors
pass
elif upsert:
print("Key value duplicate, but upsert succeeded.")
else:
logger.error(f"Upsert error: {e}")
try:
await self._run_with_retry(
_operation, with_age=with_age, graph_name=graph_name
)
except Exception as e:
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
raise
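A usage sketch of the reworked duplicate handling (values are illustrative; the column order follows the `upsert_doc_full` template further down): with `upsert=True` or `ignore_if_exists=True`, a duplicate-key error is logged and swallowed instead of bubbling up.

```python
# Hypothetical insert that may collide with an existing row; `db` is a PostgreSQLDB.
async def store_doc(db: "PostgreSQLDB") -> None:
    await db.execute(
        """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace)
           VALUES ($1, $2, $3, $4)""",
        data={"id": "doc-1", "content": "hello", "doc_name": "a.txt", "workspace": "demo"},
        upsert=True,  # a UniqueViolationError is logged and treated as success
    )
```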
@@ -1406,6 +1553,49 @@ class ClientManager:
"POSTGRES_STATEMENT_CACHE_SIZE",
config.get("postgres", "statement_cache_size", fallback=None),
),
# Connection retry configuration
"connection_retry_attempts": min(
10,
int(
os.environ.get(
"POSTGRES_CONNECTION_RETRIES",
config.get("postgres", "connection_retries", fallback=3),
)
),
),
"connection_retry_backoff": min(
5.0,
float(
os.environ.get(
"POSTGRES_CONNECTION_RETRY_BACKOFF",
config.get(
"postgres", "connection_retry_backoff", fallback=0.5
),
)
),
),
"connection_retry_backoff_max": min(
60.0,
float(
os.environ.get(
"POSTGRES_CONNECTION_RETRY_BACKOFF_MAX",
config.get(
"postgres",
"connection_retry_backoff_max",
fallback=5.0,
),
)
),
),
"pool_close_timeout": min(
30.0,
float(
os.environ.get(
"POSTGRES_POOL_CLOSE_TIMEOUT",
config.get("postgres", "pool_close_timeout", fallback=5.0),
)
),
),
}
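The same settings can come from the environment; a small sketch of how the ceilings above behave (the variable values are made up):

```python
# Illustrative clamping behaviour for the new environment variables.
import os

os.environ["POSTGRES_CONNECTION_RETRIES"] = "25"            # ceiling 10  -> 10
os.environ["POSTGRES_CONNECTION_RETRY_BACKOFF"] = "0.5"     # ceiling 5.0 -> 0.5
os.environ["POSTGRES_CONNECTION_RETRY_BACKOFF_MAX"] = "90"  # ceiling 60.0 -> 60.0
os.environ["POSTGRES_POOL_CLOSE_TIMEOUT"] = "5"             # ceiling 30.0 -> 5.0

attempts = min(10, int(os.environ["POSTGRES_CONNECTION_RETRIES"]))                   # 10
backoff_max = min(60.0, float(os.environ["POSTGRES_CONNECTION_RETRY_BACKOFF_MAX"]))  # 60.0
```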
@classmethod
@@ -1653,12 +1843,33 @@ class PGKVStorage(BaseKVStorage):
# Query by id
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get data by ids"""
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.workspace}
if not ids:
return []
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace]
params = {"workspace": self.workspace, "ids": ids}
results = await self.db.query(sql, list(params.values()), multirows=True)
def _order_results(
rows: list[dict[str, Any]] | None,
) -> list[dict[str, Any] | None]:
"""Preserve the caller requested ordering for bulk id lookups."""
if not rows:
return [None for _ in ids]
id_map: dict[str, dict[str, Any]] = {}
for row in rows:
if row is None:
continue
row_id = row.get("id")
if row_id is not None:
id_map[str(row_id)] = row
ordered: list[dict[str, Any] | None] = []
for requested_id in ids:
ordered.append(id_map.get(str(requested_id)))
return ordered
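Behaviourally, the helper is equivalent to the following standalone sketch (sample ids and rows are made up): rows are keyed by id and re-emitted in the caller's requested order, with `None` standing in for ids that were not found.

```python
# Standalone sketch of the ordering behaviour; not part of the change itself.
requested_ids = ["b", "a", "missing"]
fetched_rows = [{"id": "a", "content": "A"}, {"id": "b", "content": "B"}]

id_map = {str(row["id"]): row for row in fetched_rows}
ordered = [id_map.get(str(rid)) for rid in requested_ids]
# ordered == [{"id": "b", "content": "B"}, {"id": "a", "content": "A"}, None]
```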
if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
# Parse llm_cache_list JSON string back to list for each result
for result in results:
@@ -1701,7 +1912,7 @@ class PGKVStorage(BaseKVStorage):
"update_time": create_time if update_time == 0 else update_time,
}
processed_results.append(processed_row)
return processed_results
return _order_results(processed_results)
# Special handling for FULL_ENTITIES namespace
if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
@@ -1735,15 +1946,16 @@ class PGKVStorage(BaseKVStorage):
result["create_time"] = create_time
result["update_time"] = create_time if update_time == 0 else update_time
return results if results else []
return _order_results(results)
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace),
ids=",".join([f"'{id}'" for id in keys]),
)
params = {"workspace": self.workspace}
if not keys:
return set()
table_name = namespace_to_table_name(self.namespace)
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.workspace, "ids": list(keys)}
try:
res = await self.db.query(sql, list(params.values()), multirows=True)
if res:
@@ -2193,7 +2405,23 @@ class PGVectorStorage(BaseVectorStorage):
try:
results = await self.db.query(query, list(params.values()), multirows=True)
return [dict(record) for record in results]
if not results:
return []
# Preserve caller requested ordering while normalizing asyncpg rows to dicts.
id_map: dict[str, dict[str, Any]] = {}
for record in results:
if record is None:
continue
record_dict = dict(record)
row_id = record_dict.get("id")
if row_id is not None:
id_map[str(row_id)] = record_dict
ordered_results: list[dict[str, Any] | None] = []
for requested_id in ids:
ordered_results.append(id_map.get(str(requested_id)))
return ordered_results
except Exception as e:
logger.error(
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
@@ -2306,11 +2534,12 @@ class PGDocStatusStorage(DocStatusStorage):
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace),
ids=",".join([f"'{id}'" for id in keys]),
)
params = {"workspace": self.workspace}
if not keys:
return set()
table_name = namespace_to_table_name(self.namespace)
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.workspace, "ids": list(keys)}
try:
res = await self.db.query(sql, list(params.values()), multirows=True)
if res:
@@ -2381,7 +2610,7 @@ class PGDocStatusStorage(DocStatusStorage):
if not results:
return []
processed_results = []
processed_map: dict[str, dict[str, Any]] = {}
for row in results:
# Parse chunks_list JSON string back to list
chunks_list = row.get("chunks_list", [])
@@ -2403,23 +2632,25 @@ class PGDocStatusStorage(DocStatusStorage):
created_at = self._format_datetime_with_timezone(row["created_at"])
updated_at = self._format_datetime_with_timezone(row["updated_at"])
processed_results.append(
{
"content_length": row["content_length"],
"content_summary": row["content_summary"],
"status": row["status"],
"chunks_count": row["chunks_count"],
"created_at": created_at,
"updated_at": updated_at,
"file_path": row["file_path"],
"chunks_list": chunks_list,
"metadata": metadata,
"error_msg": row.get("error_msg"),
"track_id": row.get("track_id"),
}
)
processed_map[str(row.get("id"))] = {
"content_length": row["content_length"],
"content_summary": row["content_summary"],
"status": row["status"],
"chunks_count": row["chunks_count"],
"created_at": created_at,
"updated_at": updated_at,
"file_path": row["file_path"],
"chunks_list": chunks_list,
"metadata": metadata,
"error_msg": row.get("error_msg"),
"track_id": row.get("track_id"),
}
return processed_results
ordered_results: list[dict[str, Any] | None] = []
for requested_id in ids:
ordered_results.append(processed_map.get(str(requested_id)))
return ordered_results
async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
"""Get document by file path
@@ -2621,26 +2852,33 @@ class PGDocStatusStorage(DocStatusStorage):
elif page_size > 200:
page_size = 200
if sort_field not in ["created_at", "updated_at", "id", "file_path"]:
# Whitelist validation for sort_field to prevent SQL injection
allowed_sort_fields = {"created_at", "updated_at", "id", "file_path"}
if sort_field not in allowed_sort_fields:
sort_field = "updated_at"
# Whitelist validation for sort_direction to prevent SQL injection
if sort_direction.lower() not in ["asc", "desc"]:
sort_direction = "desc"
else:
sort_direction = sort_direction.lower()
# Calculate offset
offset = (page - 1) * page_size
# Build WHERE clause
where_clause = "WHERE workspace=$1"
# Build parameterized query components
params = {"workspace": self.workspace}
param_count = 1
# Build WHERE clause with parameterized query
if status_filter is not None:
param_count += 1
where_clause += f" AND status=${param_count}"
where_clause = "WHERE workspace=$1 AND status=$2"
params["status"] = status_filter.value
else:
where_clause = "WHERE workspace=$1"
# Build ORDER BY clause
# Build ORDER BY clause using validated whitelist values
order_clause = f"ORDER BY {sort_field} {sort_direction.upper()}"
# Query for total count
@@ -2648,7 +2886,7 @@ class PGDocStatusStorage(DocStatusStorage):
count_result = await self.db.query(count_sql, list(params.values()))
total_count = count_result["total"] if count_result else 0
# Query for paginated data
# Query for paginated data with parameterized LIMIT and OFFSET
data_sql = f"""
SELECT * FROM LIGHTRAG_DOC_STATUS
{where_clause}
@@ -3179,7 +3417,8 @@ class PGGraphStorage(BaseGraphStorage):
{
"message": f"Error executing graph query: {query}",
"wrapped": query,
"detail": str(e),
"detail": repr(e),
"error_type": e.__class__.__name__,
}
) from e
@@ -4645,19 +4884,19 @@ SQL_TEMPLATES = {
""",
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content,
COALESCE(doc_name, '') as file_path
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id = ANY($2)
""",
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
chunk_order_index, full_doc_id, file_path,
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id = ANY($2)
""",
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id = ANY($2)
""",
"get_by_id_full_entities": """SELECT id, entity_names, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
@@ -4672,12 +4911,12 @@ SQL_TEMPLATES = {
"get_by_ids_full_entities": """SELECT id, entity_names, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id IN ({ids})
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id = ANY($2)
""",
"get_by_ids_full_relations": """SELECT id, relation_pairs, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id IN ({ids})
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2)
""",
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace)