Raphaël MANSUY 2025-12-04 19:18:36 +08:00
parent ebf704701b
commit 64050aeb1e


@@ -5,7 +5,7 @@ import re
 import datetime
 from datetime import timezone
 from dataclasses import dataclass, field
-from typing import Any, Union, final
+from typing import Any, Awaitable, Callable, TypeVar, Union, final
 import numpy as np
 import configparser
 import ssl
@ -14,10 +14,13 @@ import itertools
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from tenacity import ( from tenacity import (
AsyncRetrying,
RetryCallState,
retry, retry,
retry_if_exception_type, retry_if_exception_type,
stop_after_attempt, stop_after_attempt,
wait_exponential, wait_exponential,
wait_fixed,
) )
from ..base import ( from ..base import (
@@ -48,6 +51,8 @@ from dotenv import load_dotenv
 # the OS environment variables take precedence over the .env file
 load_dotenv(dotenv_path=".env", override=False)

+T = TypeVar("T")
+

 class PostgreSQLDB:
     def __init__(self, config: dict[str, Any], **kwargs: Any):
@@ -77,12 +82,44 @@ class PostgreSQLDB:
         # Server settings
         self.server_settings = config.get("server_settings")

-        # Statement LRU cache size
-        self.statement_cache_size = int(config.get("statement_cache_size"))
+        # Statement LRU cache size (keep as-is, allow None for optional configuration)
+        self.statement_cache_size = config.get("statement_cache_size")

         if self.user is None or self.password is None or self.database is None:
             raise ValueError("Missing database user, password, or database")

+        # Guard concurrent pool resets
+        self._pool_reconnect_lock = asyncio.Lock()
+        self._transient_exceptions = (
+            asyncio.TimeoutError,
+            TimeoutError,
+            ConnectionError,
+            OSError,
+            asyncpg.exceptions.InterfaceError,
+            asyncpg.exceptions.TooManyConnectionsError,
+            asyncpg.exceptions.CannotConnectNowError,
+            asyncpg.exceptions.PostgresConnectionError,
+            asyncpg.exceptions.ConnectionDoesNotExistError,
+            asyncpg.exceptions.ConnectionFailureError,
+        )
+
+        # Connection retry configuration
+        self.connection_retry_attempts = config["connection_retry_attempts"]
+        self.connection_retry_backoff = config["connection_retry_backoff"]
+        self.connection_retry_backoff_max = max(
+            self.connection_retry_backoff,
+            config["connection_retry_backoff_max"],
+        )
+        self.pool_close_timeout = config["pool_close_timeout"]
+
+        logger.info(
+            "PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs",
+            self.connection_retry_attempts,
+            self.connection_retry_backoff,
+            self.connection_retry_backoff_max,
+            self.pool_close_timeout,
+        )
+
     def _create_ssl_context(self) -> ssl.SSLContext | None:
         """Create SSL context based on configuration parameters."""
         if not self.ssl_mode:
@@ -154,59 +191,85 @@ class PostgreSQLDB:
         return None

     async def initdb(self):
-        try:
-            # Prepare connection parameters
-            connection_params = {
-                "user": self.user,
-                "password": self.password,
-                "database": self.database,
-                "host": self.host,
-                "port": self.port,
-                "min_size": 1,
-                "max_size": self.max,
-                "statement_cache_size": self.statement_cache_size,
-            }
+        # Prepare connection parameters
+        connection_params = {
+            "user": self.user,
+            "password": self.password,
+            "database": self.database,
+            "host": self.host,
+            "port": self.port,
+            "min_size": 1,
+            "max_size": self.max,
+        }
+
+        # Only add statement_cache_size if it's configured
+        if self.statement_cache_size is not None:
+            connection_params["statement_cache_size"] = int(self.statement_cache_size)
             logger.info(
                 f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}"
             )

         # Add SSL configuration if provided
         ssl_context = self._create_ssl_context()
         if ssl_context is not None:
             connection_params["ssl"] = ssl_context
             logger.info("PostgreSQL, SSL configuration applied")
         elif self.ssl_mode:
             # Handle simple SSL modes without custom context
             if self.ssl_mode.lower() in ["require", "prefer"]:
                 connection_params["ssl"] = True
             elif self.ssl_mode.lower() == "disable":
                 connection_params["ssl"] = False
             logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")

         # Add server settings if provided
         if self.server_settings:
             try:
                 settings = {}
                 # The format is expected to be a query string, e.g., "key1=value1&key2=value2"
                 pairs = self.server_settings.split("&")
                 for pair in pairs:
                     if "=" in pair:
                         key, value = pair.split("=", 1)
                         settings[key] = value
                 if settings:
                     connection_params["server_settings"] = settings
                     logger.info(f"PostgreSQL, Server settings applied: {settings}")
             except Exception as e:
                 logger.warning(
                     f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}"
                 )

-            self.pool = await asyncpg.create_pool(**connection_params)  # type: ignore
-
-            # Ensure VECTOR extension is available
-            async with self.pool.acquire() as connection:
-                await self.configure_vector_extension(connection)
+        wait_strategy = (
+            wait_exponential(
+                multiplier=self.connection_retry_backoff,
+                min=self.connection_retry_backoff,
+                max=self.connection_retry_backoff_max,
+            )
+            if self.connection_retry_backoff > 0
+            else wait_fixed(0)
+        )
+
+        async def _create_pool_once() -> None:
+            pool = await asyncpg.create_pool(**connection_params)  # type: ignore
+            try:
+                async with pool.acquire() as connection:
+                    await self.configure_vector_extension(connection)
+            except Exception:
+                await pool.close()
+                raise
+            self.pool = pool
+
+        try:
+            async for attempt in AsyncRetrying(
+                stop=stop_after_attempt(self.connection_retry_attempts),
+                retry=retry_if_exception_type(self._transient_exceptions),
+                wait=wait_strategy,
+                before_sleep=self._before_sleep,
+                reraise=True,
+            ):
+                with attempt:
+                    await _create_pool_once()

             ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
             logger.info(
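Note: pool creation above is driven by tenacity's AsyncRetrying iterator rather than a single create_pool call. A minimal, self-contained sketch of that pattern, using illustrative values rather than the configured attempts/backoff:

    import asyncio
    from tenacity import AsyncRetrying, retry_if_exception_type, stop_after_attempt, wait_fixed

    async def connect_with_retry() -> str:
        # Retry only transient OS-level errors, at most 3 attempts, short fixed pause.
        async for attempt in AsyncRetrying(
            stop=stop_after_attempt(3),
            retry=retry_if_exception_type(OSError),
            wait=wait_fixed(0.1),
            reraise=True,
        ):
            with attempt:
                return "connected"  # the real code awaits asyncpg.create_pool(...) here

    asyncio.run(connect_with_retry())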
@@ -218,12 +281,97 @@ class PostgreSQLDB:
             )
             raise

+    async def _ensure_pool(self) -> None:
+        """Ensure the connection pool is initialised."""
+        if self.pool is None:
+            async with self._pool_reconnect_lock:
+                if self.pool is None:
+                    await self.initdb()
+
+    async def _reset_pool(self) -> None:
+        async with self._pool_reconnect_lock:
+            if self.pool is not None:
+                try:
+                    await asyncio.wait_for(
+                        self.pool.close(), timeout=self.pool_close_timeout
+                    )
+                except asyncio.TimeoutError:
+                    logger.error(
+                        "PostgreSQL, Timed out closing connection pool after %.2fs",
+                        self.pool_close_timeout,
+                    )
+                except Exception as close_error:  # pragma: no cover - defensive logging
+                    logger.warning(
+                        f"PostgreSQL, Failed to close existing connection pool cleanly: {close_error!r}"
+                    )
+            self.pool = None
+
+    async def _before_sleep(self, retry_state: RetryCallState) -> None:
+        """Hook invoked by tenacity before sleeping between retries."""
+        exc = retry_state.outcome.exception() if retry_state.outcome else None
+        logger.warning(
+            "PostgreSQL transient connection issue on attempt %s/%s: %r",
+            retry_state.attempt_number,
+            self.connection_retry_attempts,
+            exc,
+        )
+        await self._reset_pool()
+
+    async def _run_with_retry(
+        self,
+        operation: Callable[[asyncpg.Connection], Awaitable[T]],
+        *,
+        with_age: bool = False,
+        graph_name: str | None = None,
+    ) -> T:
+        """
+        Execute a database operation with automatic retry for transient failures.
+
+        Args:
+            operation: Async callable that receives an active connection.
+            with_age: Whether to configure Apache AGE on the connection.
+            graph_name: AGE graph name; required when with_age is True.
+
+        Returns:
+            The result returned by the operation.
+
+        Raises:
+            Exception: Propagates the last error if all retry attempts fail or a non-transient error occurs.
+        """
+        wait_strategy = (
+            wait_exponential(
+                multiplier=self.connection_retry_backoff,
+                min=self.connection_retry_backoff,
+                max=self.connection_retry_backoff_max,
+            )
+            if self.connection_retry_backoff > 0
+            else wait_fixed(0)
+        )
+
+        async for attempt in AsyncRetrying(
+            stop=stop_after_attempt(self.connection_retry_attempts),
+            retry=retry_if_exception_type(self._transient_exceptions),
+            wait=wait_strategy,
+            before_sleep=self._before_sleep,
+            reraise=True,
+        ):
+            with attempt:
+                await self._ensure_pool()
+                assert self.pool is not None
+                async with self.pool.acquire() as connection:  # type: ignore[arg-type]
+                    if with_age and graph_name:
+                        await self.configure_age(connection, graph_name)
+                    elif with_age and not graph_name:
+                        raise ValueError("Graph name is required when with_age is True")
+                    return await operation(connection)
+
     @staticmethod
     async def configure_vector_extension(connection: asyncpg.Connection) -> None:
         """Create VECTOR extension if it doesn't exist for vector similarity operations."""
         try:
             await connection.execute("CREATE EXTENSION IF NOT EXISTS vector")  # type: ignore
-            logger.info("VECTOR extension ensured for PostgreSQL")
+            logger.info("PostgreSQL, VECTOR extension enabled")
         except Exception as e:
             logger.warning(f"Could not create VECTOR extension: {e}")
             # Don't raise - let the system continue without vector extension
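Note: storage calls route through _run_with_retry by handing it an async callable that receives the acquired connection. A hedged usage sketch (db stands for an initialised PostgreSQLDB instance; the table and column names are illustrative):

    async def count_full_docs(db) -> int:
        # The callable gets a live asyncpg connection; retries and pool resets happen around it.
        async def _operation(connection) -> int:
            row = await connection.fetchrow("SELECT COUNT(*) AS total FROM LIGHTRAG_DOC_FULL")
            return row["total"]

        return await db._run_with_retry(_operation)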
@@ -233,7 +381,7 @@ class PostgreSQLDB:
         """Create AGE extension if it doesn't exist for graph operations."""
         try:
             await connection.execute("CREATE EXTENSION IF NOT EXISTS age")  # type: ignore
-            logger.info("AGE extension ensured for PostgreSQL")
+            logger.info("PostgreSQL, AGE extension enabled")
         except Exception as e:
             logger.warning(f"Could not create AGE extension: {e}")
             # Don't raise - let the system continue without AGE extension
@@ -1250,35 +1398,31 @@ class PostgreSQLDB:
         with_age: bool = False,
         graph_name: str | None = None,
     ) -> dict[str, Any] | None | list[dict[str, Any]]:
-        async with self.pool.acquire() as connection:  # type: ignore
-            if with_age and graph_name:
-                await self.configure_age(connection, graph_name)  # type: ignore
-            elif with_age and not graph_name:
-                raise ValueError("Graph name is required when with_age is True")
-
-            try:
-                if params:
-                    rows = await connection.fetch(sql, *params)
-                else:
-                    rows = await connection.fetch(sql)
-
-                if multirows:
-                    if rows:
-                        columns = [col for col in rows[0].keys()]
-                        data = [dict(zip(columns, row)) for row in rows]
-                    else:
-                        data = []
-                else:
-                    if rows:
-                        columns = rows[0].keys()
-                        data = dict(zip(columns, rows[0]))
-                    else:
-                        data = None
-
-                return data
-            except Exception as e:
-                logger.error(f"PostgreSQL database, error:{e}")
-                raise
+        async def _operation(connection: asyncpg.Connection) -> Any:
+            prepared_params = tuple(params) if params else ()
+            if prepared_params:
+                rows = await connection.fetch(sql, *prepared_params)
+            else:
+                rows = await connection.fetch(sql)
+
+            if multirows:
+                if rows:
+                    columns = [col for col in rows[0].keys()]
+                    return [dict(zip(columns, row)) for row in rows]
+                return []
+            if rows:
+                columns = rows[0].keys()
+                return dict(zip(columns, rows[0]))
+            return None
+
+        try:
+            return await self._run_with_retry(
+                _operation, with_age=with_age, graph_name=graph_name
+            )
+        except Exception as e:
+            logger.error(f"PostgreSQL database, error:{e}")
+            raise

     async def execute(
         self,
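Note on parameter handling: callers keep passing list(params.values()), and _operation simply spreads that sequence into asyncpg's positional placeholders. A small sketch of the binding it relies on (names and SQL are illustrative):

    async def fetch_doc_id(connection, workspace: str, doc_id: str):
        # Equivalent to the prepared_params spread above: $1=workspace, $2=doc_id.
        prepared_params = (workspace, doc_id)
        return await connection.fetchrow(
            "SELECT id FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2",
            *prepared_params,
        )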
@@ -1289,30 +1433,33 @@ class PostgreSQLDB:
         with_age: bool = False,
         graph_name: str | None = None,
     ):
-        try:
-            async with self.pool.acquire() as connection:  # type: ignore
-                if with_age and graph_name:
-                    await self.configure_age(connection, graph_name)
-                elif with_age and not graph_name:
-                    raise ValueError("Graph name is required when with_age is True")
-
-                if data is None:
-                    await connection.execute(sql)
-                else:
-                    await connection.execute(sql, *data.values())
-        except (
-            asyncpg.exceptions.UniqueViolationError,
-            asyncpg.exceptions.DuplicateTableError,
-            asyncpg.exceptions.DuplicateObjectError,  # Catch "already exists" error
-            asyncpg.exceptions.InvalidSchemaNameError,  # Also catch for AGE extension "already exists"
-        ) as e:
-            if ignore_if_exists:
-                # If the flag is set, just ignore these specific errors
-                pass
-            elif upsert:
-                print("Key value duplicate, but upsert succeeded.")
-            else:
-                logger.error(f"Upsert error: {e}")
+        async def _operation(connection: asyncpg.Connection) -> Any:
+            prepared_values = tuple(data.values()) if data else ()
+            try:
+                if not data:
+                    return await connection.execute(sql)
+                return await connection.execute(sql, *prepared_values)
+            except (
+                asyncpg.exceptions.UniqueViolationError,
+                asyncpg.exceptions.DuplicateTableError,
+                asyncpg.exceptions.DuplicateObjectError,
+                asyncpg.exceptions.InvalidSchemaNameError,
+            ) as e:
+                if ignore_if_exists:
+                    logger.debug("PostgreSQL, ignoring duplicate during execute: %r", e)
+                    return None
+                if upsert:
+                    logger.info(
+                        "PostgreSQL, duplicate detected but treated as upsert success: %r",
+                        e,
+                    )
+                    return None
+                raise
+
+        try:
+            await self._run_with_retry(
+                _operation, with_age=with_age, graph_name=graph_name
+            )
+        except Exception as e:
+            logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
+            raise
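A hedged call-site sketch for the rewritten execute(), assuming it keeps its existing data/upsert keyword arguments and using the columns from the upsert_doc_full template:

    async def store_doc(db) -> None:
        # Dict values are passed positionally, so key order must match $1..$4.
        await db.execute(
            "INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace) VALUES ($1, $2, $3, $4)",
            {"id": "doc-1", "content": "hello", "doc_name": "a.txt", "workspace": "default"},
            upsert=True,  # a duplicate id is logged and treated as success
        )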
@@ -1406,6 +1553,49 @@ class ClientManager:
                 "POSTGRES_STATEMENT_CACHE_SIZE",
                 config.get("postgres", "statement_cache_size", fallback=None),
             ),
+            # Connection retry configuration
+            "connection_retry_attempts": min(
+                10,
+                int(
+                    os.environ.get(
+                        "POSTGRES_CONNECTION_RETRIES",
+                        config.get("postgres", "connection_retries", fallback=3),
+                    )
+                ),
+            ),
+            "connection_retry_backoff": min(
+                5.0,
+                float(
+                    os.environ.get(
+                        "POSTGRES_CONNECTION_RETRY_BACKOFF",
+                        config.get(
+                            "postgres", "connection_retry_backoff", fallback=0.5
+                        ),
+                    )
+                ),
+            ),
+            "connection_retry_backoff_max": min(
+                60.0,
+                float(
+                    os.environ.get(
+                        "POSTGRES_CONNECTION_RETRY_BACKOFF_MAX",
+                        config.get(
+                            "postgres",
+                            "connection_retry_backoff_max",
+                            fallback=5.0,
+                        ),
+                    )
+                ),
+            ),
+            "pool_close_timeout": min(
+                30.0,
+                float(
+                    os.environ.get(
+                        "POSTGRES_POOL_CLOSE_TIMEOUT",
+                        config.get("postgres", "pool_close_timeout", fallback=5.0),
+                    )
+                ),
+            ),
         }

     @classmethod
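Note: each new setting prefers the environment variable, falls back to config.ini, and is capped at a hard ceiling (10 attempts, 5s backoff, 60s backoff max, 30s pool close timeout). A tiny runnable sketch of the clamping behaviour:

    import os

    os.environ["POSTGRES_CONNECTION_RETRIES"] = "50"  # deliberately above the ceiling
    attempts = min(10, int(os.environ.get("POSTGRES_CONNECTION_RETRIES", 3)))
    assert attempts == 10  # clamped, mirroring the min(10, ...) guard above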
@@ -1653,12 +1843,33 @@ class PGKVStorage(BaseKVStorage):
     # Query by id
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         """Get data by ids"""
-        sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
-            ids=",".join([f"'{id}'" for id in ids])
-        )
-        params = {"workspace": self.workspace}
+        if not ids:
+            return []
+
+        sql = SQL_TEMPLATES["get_by_ids_" + self.namespace]
+        params = {"workspace": self.workspace, "ids": ids}
         results = await self.db.query(sql, list(params.values()), multirows=True)

+        def _order_results(
+            rows: list[dict[str, Any]] | None,
+        ) -> list[dict[str, Any] | None]:
+            """Preserve the caller requested ordering for bulk id lookups."""
+            if not rows:
+                return [None for _ in ids]
+
+            id_map: dict[str, dict[str, Any]] = {}
+            for row in rows:
+                if row is None:
+                    continue
+                row_id = row.get("id")
+                if row_id is not None:
+                    id_map[str(row_id)] = row
+
+            ordered: list[dict[str, Any] | None] = []
+            for requested_id in ids:
+                ordered.append(id_map.get(str(requested_id)))
+            return ordered
+
         if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             # Parse llm_cache_list JSON string back to list for each result
             for result in results:
@@ -1701,7 +1912,7 @@ class PGKVStorage(BaseKVStorage):
                     "update_time": create_time if update_time == 0 else update_time,
                 }
                 processed_results.append(processed_row)
-            return processed_results
+            return _order_results(processed_results)

         # Special handling for FULL_ENTITIES namespace
         if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
@@ -1735,15 +1946,16 @@ class PGKVStorage(BaseKVStorage):
                 result["create_time"] = create_time
                 result["update_time"] = create_time if update_time == 0 else update_time

-        return results if results else []
+        return _order_results(results)

     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Filter out duplicated content"""
-        sql = SQL_TEMPLATES["filter_keys"].format(
-            table_name=namespace_to_table_name(self.namespace),
-            ids=",".join([f"'{id}'" for id in keys]),
-        )
-        params = {"workspace": self.workspace}
+        if not keys:
+            return set()
+
+        table_name = namespace_to_table_name(self.namespace)
+        sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
+        params = {"workspace": self.workspace, "ids": list(keys)}
         try:
             res = await self.db.query(sql, list(params.values()), multirows=True)
             if res:
@@ -2193,7 +2405,23 @@ class PGVectorStorage(BaseVectorStorage):
         try:
             results = await self.db.query(query, list(params.values()), multirows=True)
-            return [dict(record) for record in results]
+            if not results:
+                return []
+
+            # Preserve caller requested ordering while normalizing asyncpg rows to dicts.
+            id_map: dict[str, dict[str, Any]] = {}
+            for record in results:
+                if record is None:
+                    continue
+                record_dict = dict(record)
+                row_id = record_dict.get("id")
+                if row_id is not None:
+                    id_map[str(row_id)] = record_dict
+
+            ordered_results: list[dict[str, Any] | None] = []
+            for requested_id in ids:
+                ordered_results.append(id_map.get(str(requested_id)))
+            return ordered_results
         except Exception as e:
             logger.error(
                 f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
@@ -2306,11 +2534,12 @@ class PGDocStatusStorage(DocStatusStorage):
     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Filter out duplicated content"""
-        sql = SQL_TEMPLATES["filter_keys"].format(
-            table_name=namespace_to_table_name(self.namespace),
-            ids=",".join([f"'{id}'" for id in keys]),
-        )
-        params = {"workspace": self.workspace}
+        if not keys:
+            return set()
+
+        table_name = namespace_to_table_name(self.namespace)
+        sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
+        params = {"workspace": self.workspace, "ids": list(keys)}
         try:
             res = await self.db.query(sql, list(params.values()), multirows=True)
             if res:
@@ -2381,7 +2610,7 @@ class PGDocStatusStorage(DocStatusStorage):
         if not results:
             return []

-        processed_results = []
+        processed_map: dict[str, dict[str, Any]] = {}
         for row in results:
             # Parse chunks_list JSON string back to list
             chunks_list = row.get("chunks_list", [])
@@ -2403,23 +2632,25 @@ class PGDocStatusStorage(DocStatusStorage):
             created_at = self._format_datetime_with_timezone(row["created_at"])
             updated_at = self._format_datetime_with_timezone(row["updated_at"])

-            processed_results.append(
-                {
-                    "content_length": row["content_length"],
-                    "content_summary": row["content_summary"],
-                    "status": row["status"],
-                    "chunks_count": row["chunks_count"],
-                    "created_at": created_at,
-                    "updated_at": updated_at,
-                    "file_path": row["file_path"],
-                    "chunks_list": chunks_list,
-                    "metadata": metadata,
-                    "error_msg": row.get("error_msg"),
-                    "track_id": row.get("track_id"),
-                }
-            )
+            processed_map[str(row.get("id"))] = {
+                "content_length": row["content_length"],
+                "content_summary": row["content_summary"],
+                "status": row["status"],
+                "chunks_count": row["chunks_count"],
+                "created_at": created_at,
+                "updated_at": updated_at,
+                "file_path": row["file_path"],
+                "chunks_list": chunks_list,
+                "metadata": metadata,
+                "error_msg": row.get("error_msg"),
+                "track_id": row.get("track_id"),
+            }

-        return processed_results
+        ordered_results: list[dict[str, Any] | None] = []
+        for requested_id in ids:
+            ordered_results.append(processed_map.get(str(requested_id)))
+        return ordered_results

     async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
         """Get document by file path
@@ -2621,26 +2852,33 @@ class PGDocStatusStorage(DocStatusStorage):
         elif page_size > 200:
             page_size = 200

-        if sort_field not in ["created_at", "updated_at", "id", "file_path"]:
+        # Whitelist validation for sort_field to prevent SQL injection
+        allowed_sort_fields = {"created_at", "updated_at", "id", "file_path"}
+        if sort_field not in allowed_sort_fields:
             sort_field = "updated_at"

+        # Whitelist validation for sort_direction to prevent SQL injection
         if sort_direction.lower() not in ["asc", "desc"]:
             sort_direction = "desc"
+        else:
+            sort_direction = sort_direction.lower()

         # Calculate offset
         offset = (page - 1) * page_size

-        # Build WHERE clause
-        where_clause = "WHERE workspace=$1"
+        # Build parameterized query components
         params = {"workspace": self.workspace}
         param_count = 1

+        # Build WHERE clause with parameterized query
         if status_filter is not None:
             param_count += 1
-            where_clause += f" AND status=${param_count}"
+            where_clause = "WHERE workspace=$1 AND status=$2"
             params["status"] = status_filter.value
+        else:
+            where_clause = "WHERE workspace=$1"

-        # Build ORDER BY clause
+        # Build ORDER BY clause using validated whitelist values
         order_clause = f"ORDER BY {sort_field} {sort_direction.upper()}"

         # Query for total count
@@ -2648,7 +2886,7 @@ class PGDocStatusStorage(DocStatusStorage):
         count_result = await self.db.query(count_sql, list(params.values()))
         total_count = count_result["total"] if count_result else 0

-        # Query for paginated data
+        # Query for paginated data with parameterized LIMIT and OFFSET
         data_sql = f"""
             SELECT * FROM LIGHTRAG_DOC_STATUS
             {where_clause}
@@ -3179,7 +3417,8 @@ class PGGraphStorage(BaseGraphStorage):
                 {
                     "message": f"Error executing graph query: {query}",
                     "wrapped": query,
-                    "detail": str(e),
+                    "detail": repr(e),
+                    "error_type": e.__class__.__name__,
                 }
             ) from e
@@ -4645,19 +4884,19 @@ SQL_TEMPLATES = {
    """,
    "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content,
                                COALESCE(doc_name, '') as file_path
-                               FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
+                               FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id = ANY($2)
    """,
    "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
                                chunk_order_index, full_doc_id, file_path,
                                COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
                                EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
-                               FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
+                               FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id = ANY($2)
    """,
    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
                                EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
-                               FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
+                               FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id = ANY($2)
    """,
    "get_by_id_full_entities": """SELECT id, entity_names, count,
                                EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
@@ -4672,12 +4911,12 @@ SQL_TEMPLATES = {
    "get_by_ids_full_entities": """SELECT id, entity_names, count,
                                EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
-                               FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id IN ({ids})
+                               FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id = ANY($2)
    """,
    "get_by_ids_full_relations": """SELECT id, relation_pairs, count,
                                EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
-                               FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id IN ({ids})
+                               FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2)
    """,
    "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
    "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace)