From 7ce46bacb60c67ad9bac558a94150982434a3d44 Mon Sep 17 00:00:00 2001
From: kevinnkansah
Date: Sun, 5 Oct 2025 23:29:04 +0200
Subject: [PATCH] feat: add options for Postgres connection

(cherry picked from commit 108cdbe133b2d2ac62adfbba63b5ceb4a220f5ef)
---
 env.example                  |  65 ++---
 lightrag/kg/postgres_impl.py | 513 ++++++++++-------------------------
 2 files changed, 167 insertions(+), 411 deletions(-)

diff --git a/env.example b/env.example
index 7a41daa4..b82a4cf0 100644
--- a/env.example
+++ b/env.example
@@ -23,13 +23,13 @@ WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
 # WORKING_DIR=
 
 ### Tiktoken cache directory (Store cached files in this folder for offline deployment)
-# TIKTOKEN_CACHE_DIR=/app/data/tiktoken
+# TIKTOKEN_CACHE_DIR=./temp/tiktoken
 
 ### Ollama Emulating Model and Tag
 # OLLAMA_EMULATING_MODEL_NAME=lightrag
 OLLAMA_EMULATING_MODEL_TAG=latest
 
-### Max nodes return from graph retrieval in webui
+### Max nodes return from graph retrieval in webui
 # MAX_GRAPH_NODES=1000
 
 ### Logging level
@@ -50,32 +50,35 @@ OLLAMA_EMULATING_MODEL_TAG=latest
 # JWT_ALGORITHM=HS256
 
 ### API-Key to access LightRAG Server API
-### Use this key in HTTP requests with the 'X-API-Key' header
-### Example: curl -H "X-API-Key: your-secure-api-key-here" http://localhost:9621/query
 # LIGHTRAG_API_KEY=your-secure-api-key-here
 # WHITELIST_PATHS=/health,/api/*
 
 ######################################################################################
 ### Query Configuration
 ###
-### How to control the context length sent to LLM:
+### How to control the context length sent to LLM:
 ### MAX_ENTITY_TOKENS + MAX_RELATION_TOKENS < MAX_TOTAL_TOKENS
-### Chunk_Tokens = MAX_TOTAL_TOKENS - Actual_Entity_Tokens - Actual_Relation_Tokens
+### Chunk_Tokens = MAX_TOTAL_TOKENS - Actual_Entity_Tokens - Actual_Relation_Tokens
 ######################################################################################
 
-# LLM response cache for query (Not valid for streaming response)
+# LLM response cache for query (Not valid for streaming response)
 ENABLE_LLM_CACHE=true
 # COSINE_THRESHOLD=0.2
 ### Number of entities or relations retrieved from KG
 # TOP_K=40
-### Maximum number or chunks for naive vector search
+### Maximum number of chunks for naive vector search
 # CHUNK_TOP_K=20
-### control the actual entities send to LLM
+### control the actual entities send to LLM
 # MAX_ENTITY_TOKENS=6000
 ### control the actual relations send to LLM
 # MAX_RELATION_TOKENS=8000
-### control the maximum tokens send to LLM (include entities, relations and chunks)
+### control the maximum tokens send to LLM (include entities, relations and chunks)
 # MAX_TOTAL_TOKENS=30000
+### maximum number of related chunks per source entity or relation
+### The chunk picker uses this value to determine the total number of chunks selected from KG(knowledge graph)
+### Higher values increase re-ranking time
+# RELATED_CHUNK_NUMBER=5
+
 ### chunk selection strategies
 ### VECTOR: Pick KG chunks by vector similarity, delivered chunks to the LLM aligning more closely with naive retrieval
 ### WEIGHT: Pick KG chunks by entity and chunk weight, delivered more solely KG related chunks to the LLM
@@ -90,7 +93,7 @@ ENABLE_LLM_CACHE=true
 RERANK_BINDING=null
 ### Enable rerank by default in query params when RERANK_BINDING is not null
 # RERANK_BY_DEFAULT=True
-### rerank score chunk filter(set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enough)
+### rerank score chunk filter(set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enough)
 # MIN_RERANK_SCORE=0.0
 
 ### For local deployment with vLLM
@@ -128,7 +131,7 @@ SUMMARY_LANGUAGE=English
 # CHUNK_SIZE=1200
 # CHUNK_OVERLAP_SIZE=100
 
-### Number of summary segments or tokens to trigger LLM summary on entity/relation merge (at least 3 is recommended)
+### Number of summary segments or tokens to trigger LLM summary on entity/relation merge (at least 3 is recommended)
 # FORCE_LLM_SUMMARY_ON_MERGE=8
 ### Max description token size to trigger LLM summary
 # SUMMARY_MAX_TOKENS = 1200
@@ -137,19 +140,6 @@ SUMMARY_LANGUAGE=English
 ### Maximum context size sent to LLM for description summary
 # SUMMARY_CONTEXT_SIZE=12000
 
-### control the maximum chunk_ids stored in vector and graph db
-# MAX_SOURCE_IDS_PER_ENTITY=300
-# MAX_SOURCE_IDS_PER_RELATION=300
-### control chunk_ids limitation method: KEEP, FIFO (KEEP: Keep oldest, FIFO: First in first out)
-# SOURCE_IDS_LIMIT_METHOD=KEEP
-### Maximum number of file paths stored in entity/relation file_path field
-# MAX_FILE_PATHS=30
-
-### maximum number of related chunks per source entity or relation
-### The chunk picker uses this value to determine the total number of chunks selected from KG(knowledge graph)
-### Higher values increase re-ranking time
-# RELATED_CHUNK_NUMBER=5
-
 ###############################
 ### Concurrency Configuration
 ###############################
@@ -189,7 +179,7 @@ LLM_BINDING_API_KEY=your_api_key
 # OPENAI_LLM_TEMPERATURE=0.9
 ### Set the max_tokens to mitigate endless output of some LLM (less than LLM_TIMEOUT * llm_output_tokens/second, i.e. 9000 = 180s * 50 tokens/s)
 ### Typically, max_tokens does not include prompt content, though some models, such as Gemini Models, are exceptions
-### For vLLM/SGLang deployed models, or most of OpenAI compatible API provider
+### For vLLM/SGLang deployed models, or most of OpenAI compatible API provider
 # OPENAI_LLM_MAX_TOKENS=9000
 ### For OpenAI o1-mini or newer modles
 OPENAI_LLM_MAX_COMPLETION_TOKENS=9000
@@ -203,7 +193,7 @@ OPENAI_LLM_MAX_COMPLETION_TOKENS=9000
 # OPENAI_LLM_REASONING_EFFORT=minimal
 ### OpenRouter Specific Parameters
 # OPENAI_LLM_EXTRA_BODY='{"reasoning": {"enabled": false}}'
-### Qwen3 Specific Parameters deploy by vLLM
+### Qwen3 Specific Parameters deployed by vLLM
 # OPENAI_LLM_EXTRA_BODY='{"chat_template_kwargs": {"enable_thinking": false}}'
 
 ### use the following command to see all support options for Ollama LLM
@@ -257,8 +247,8 @@ OLLAMA_EMBEDDING_NUM_CTX=8192
 ### lightrag-server --embedding-binding ollama --help
 
 ####################################################################
-### WORKSPACE sets workspace name for all storage types
-### for the purpose of isolating data from LightRAG instances.
+### WORKSPACE sets workspace name for all storage types
+### for the purpose of isolating data from LightRAG instances.
### Valid workspace name constraints: a-z, A-Z, 0-9, and _ #################################################################### # WORKSPACE=space1 @@ -313,16 +303,6 @@ POSTGRES_HNSW_M=16 POSTGRES_HNSW_EF=200 POSTGRES_IVFFLAT_LISTS=100 -### PostgreSQL Connection Retry Configuration (Network Robustness) -### Number of retry attempts (1-10, default: 3) -### Initial retry backoff in seconds (0.1-5.0, default: 0.5) -### Maximum retry backoff in seconds (backoff-60.0, default: 5.0) -### Connection pool close timeout in seconds (1.0-30.0, default: 5.0) -# POSTGRES_CONNECTION_RETRIES=3 -# POSTGRES_CONNECTION_RETRY_BACKOFF=0.5 -# POSTGRES_CONNECTION_RETRY_BACKOFF_MAX=5.0 -# POSTGRES_POOL_CLOSE_TIMEOUT=5.0 - ### PostgreSQL SSL Configuration (Optional) # POSTGRES_SSL_MODE=require # POSTGRES_SSL_CERT=/path/to/client-cert.pem @@ -330,13 +310,10 @@ POSTGRES_IVFFLAT_LISTS=100 # POSTGRES_SSL_ROOT_CERT=/path/to/ca-cert.pem # POSTGRES_SSL_CRL=/path/to/crl.pem -### PostgreSQL Server Settings (for Supabase Supavisor) +### PostgreSQL Server Options (for Supabase Supavisor) # Use this to pass extra options to the PostgreSQL connection string. # For Supabase, you might need to set it like this: -# POSTGRES_SERVER_SETTINGS="options=reference%3D[project-ref]" - -# Default is 100 set to 0 to disable -# POSTGRES_STATEMENT_CACHE_SIZE=100 +# POSTGRES_SERVER_OPTIONS="options=reference%3D[project-ref]" ### Neo4j Configuration NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 4a04e897..37b39b49 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -5,7 +5,7 @@ import re import datetime from datetime import timezone from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable, TypeVar, Union, final +from typing import Any, Union, final import numpy as np import configparser import ssl @@ -14,13 +14,10 @@ import itertools from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from tenacity import ( - AsyncRetrying, - RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential, - wait_fixed, ) from ..base import ( @@ -51,8 +48,6 @@ from dotenv import load_dotenv # the OS environment variables take precedence over the .env file load_dotenv(dotenv_path=".env", override=False) -T = TypeVar("T") - class PostgreSQLDB: def __init__(self, config: dict[str, Any], **kwargs: Any): @@ -82,55 +77,9 @@ class PostgreSQLDB: # Server settings self.server_settings = config.get("server_settings") - # Statement LRU cache size (keep as-is, allow None for optional configuration) - self.statement_cache_size = config.get("statement_cache_size") - - # Connection retry configuration - self.connection_retry_attempts = max( - 1, min(10, int(os.environ.get("POSTGRES_CONNECTION_RETRIES", 3))) - ) - self.connection_retry_backoff = max( - 0.1, - min(5.0, float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF", 0.5))), - ) - self.connection_retry_backoff_max = max( - self.connection_retry_backoff, - min( - 60.0, - float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", 5.0)), - ), - ) - self.pool_close_timeout = max( - 1.0, min(30.0, float(os.environ.get("POSTGRES_POOL_CLOSE_TIMEOUT", 5.0))) - ) - - self._transient_exceptions = ( - asyncio.TimeoutError, - TimeoutError, - ConnectionError, - OSError, - asyncpg.exceptions.InterfaceError, - asyncpg.exceptions.TooManyConnectionsError, - asyncpg.exceptions.CannotConnectNowError, - asyncpg.exceptions.PostgresConnectionError, - 
asyncpg.exceptions.ConnectionDoesNotExistError, - asyncpg.exceptions.ConnectionFailureError, - ) - - # Guard concurrent pool resets - self._pool_reconnect_lock = asyncio.Lock() - if self.user is None or self.password is None or self.database is None: raise ValueError("Missing database user, password, or database") - logger.info( - "PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs", - self.connection_retry_attempts, - self.connection_retry_backoff, - self.connection_retry_backoff_max, - self.pool_close_timeout, - ) - def _create_ssl_context(self) -> ssl.SSLContext | None: """Create SSL context based on configuration parameters.""" if not self.ssl_mode: @@ -202,87 +151,54 @@ class PostgreSQLDB: return None async def initdb(self): - # Prepare connection parameters - connection_params = { - "user": self.user, - "password": self.password, - "database": self.database, - "host": self.host, - "port": self.port, - "min_size": 1, - "max_size": self.max, - } - - # Only add statement_cache_size if it's configured - if self.statement_cache_size is not None: - connection_params["statement_cache_size"] = int( - self.statement_cache_size - ) - logger.info( - f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}" - ) - - # Add SSL configuration if provided - ssl_context = self._create_ssl_context() - if ssl_context is not None: - connection_params["ssl"] = ssl_context - logger.info("PostgreSQL, SSL configuration applied") - elif self.ssl_mode: - # Handle simple SSL modes without custom context - if self.ssl_mode.lower() in ["require", "prefer"]: - connection_params["ssl"] = True - elif self.ssl_mode.lower() == "disable": - connection_params["ssl"] = False - logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}") - - # Add server settings if provided - if self.server_settings: - try: - settings = {} - # The format is expected to be a query string, e.g., "key1=value1&key2=value2" - pairs = self.server_settings.split("&") - for pair in pairs: - if "=" in pair: - key, value = pair.split("=", 1) - settings[key] = value - if settings: - connection_params["server_settings"] = settings - logger.info(f"PostgreSQL, Server settings applied: {settings}") - except Exception as e: - logger.warning( - f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}" - ) - - wait_strategy = ( - wait_exponential( - multiplier=self.connection_retry_backoff, - min=self.connection_retry_backoff, - max=self.connection_retry_backoff_max, - ) - if self.connection_retry_backoff > 0 - else wait_fixed(0) - ) - - async def _create_pool_once() -> None: - pool = await asyncpg.create_pool(**connection_params) # type: ignore - try: - async with pool.acquire() as connection: - await self.configure_vector_extension(connection) - except Exception: - await pool.close() - raise - self.pool = pool - try: - async for attempt in AsyncRetrying( - stop=stop_after_attempt(self.connection_retry_attempts), - retry=retry_if_exception_type(self._transient_exceptions), - wait=wait_strategy, - before_sleep=self._before_sleep, - reraise=True, - ): - with attempt: - await _create_pool_once() + # Prepare connection parameters + connection_params = { + "user": self.user, + "password": self.password, + "database": self.database, + "host": self.host, + "port": self.port, + "min_size": 1, + "max_size": self.max, + } + + # Add SSL configuration if provided + ssl_context = self._create_ssl_context() + if ssl_context is not None: + connection_params["ssl"] = ssl_context + 
logger.info("PostgreSQL, SSL configuration applied") + elif self.ssl_mode: + # Handle simple SSL modes without custom context + if self.ssl_mode.lower() in ["require", "prefer"]: + connection_params["ssl"] = True + elif self.ssl_mode.lower() == "disable": + connection_params["ssl"] = False + logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}") + + # Add server settings if provided + if self.server_settings: + try: + settings = {} + # The format is expected to be a query string, e.g., "key1=value1&key2=value2" + pairs = self.server_settings.split("&") + for pair in pairs: + if "=" in pair: + key, value = pair.split("=", 1) + settings[key] = value + if settings: + connection_params["server_settings"] = settings + logger.info(f"PostgreSQL, Server settings applied: {settings}") + except Exception as e: + logger.warning( + f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}" + ) + + self.pool = await asyncpg.create_pool(**connection_params) # type: ignore + + # Ensure VECTOR extension is available + async with self.pool.acquire() as connection: + await self.configure_vector_extension(connection) ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL" logger.info( @@ -294,97 +210,12 @@ class PostgreSQLDB: ) raise - async def _ensure_pool(self) -> None: - """Ensure the connection pool is initialised.""" - if self.pool is None: - async with self._pool_reconnect_lock: - if self.pool is None: - await self.initdb() - - async def _reset_pool(self) -> None: - async with self._pool_reconnect_lock: - if self.pool is not None: - try: - await asyncio.wait_for( - self.pool.close(), timeout=self.pool_close_timeout - ) - except asyncio.TimeoutError: - logger.error( - "PostgreSQL, Timed out closing connection pool after %.2fs", - self.pool_close_timeout, - ) - except Exception as close_error: # pragma: no cover - defensive logging - logger.warning( - f"PostgreSQL, Failed to close existing connection pool cleanly: {close_error!r}" - ) - self.pool = None - - async def _before_sleep(self, retry_state: RetryCallState) -> None: - """Hook invoked by tenacity before sleeping between retries.""" - exc = retry_state.outcome.exception() if retry_state.outcome else None - logger.warning( - "PostgreSQL transient connection issue on attempt %s/%s: %r", - retry_state.attempt_number, - self.connection_retry_attempts, - exc, - ) - await self._reset_pool() - - async def _run_with_retry( - self, - operation: Callable[[asyncpg.Connection], Awaitable[T]], - *, - with_age: bool = False, - graph_name: str | None = None, - ) -> T: - """ - Execute a database operation with automatic retry for transient failures. - - Args: - operation: Async callable that receives an active connection. - with_age: Whether to configure Apache AGE on the connection. - graph_name: AGE graph name; required when with_age is True. - - Returns: - The result returned by the operation. - - Raises: - Exception: Propagates the last error if all retry attempts fail or a non-transient error occurs. 
- """ - wait_strategy = ( - wait_exponential( - multiplier=self.connection_retry_backoff, - min=self.connection_retry_backoff, - max=self.connection_retry_backoff_max, - ) - if self.connection_retry_backoff > 0 - else wait_fixed(0) - ) - - async for attempt in AsyncRetrying( - stop=stop_after_attempt(self.connection_retry_attempts), - retry=retry_if_exception_type(self._transient_exceptions), - wait=wait_strategy, - before_sleep=self._before_sleep, - reraise=True, - ): - with attempt: - await self._ensure_pool() - assert self.pool is not None - async with self.pool.acquire() as connection: # type: ignore[arg-type] - if with_age and graph_name: - await self.configure_age(connection, graph_name) - elif with_age and not graph_name: - raise ValueError("Graph name is required when with_age is True") - - return await operation(connection) - @staticmethod async def configure_vector_extension(connection: asyncpg.Connection) -> None: """Create VECTOR extension if it doesn't exist for vector similarity operations.""" try: await connection.execute("CREATE EXTENSION IF NOT EXISTS vector") # type: ignore - logger.info("PostgreSQL, VECTOR extension enabled") + logger.info("VECTOR extension ensured for PostgreSQL") except Exception as e: logger.warning(f"Could not create VECTOR extension: {e}") # Don't raise - let the system continue without vector extension @@ -394,7 +225,7 @@ class PostgreSQLDB: """Create AGE extension if it doesn't exist for graph operations.""" try: await connection.execute("CREATE EXTENSION IF NOT EXISTS age") # type: ignore - logger.info("PostgreSQL, AGE extension enabled") + logger.info("AGE extension ensured for PostgreSQL") except Exception as e: logger.warning(f"Could not create AGE extension: {e}") # Don't raise - let the system continue without AGE extension @@ -1411,31 +1242,35 @@ class PostgreSQLDB: with_age: bool = False, graph_name: str | None = None, ) -> dict[str, Any] | None | list[dict[str, Any]]: - async def _operation(connection: asyncpg.Connection) -> Any: - prepared_params = tuple(params) if params else () - if prepared_params: - rows = await connection.fetch(sql, *prepared_params) - else: - rows = await connection.fetch(sql) + async with self.pool.acquire() as connection: # type: ignore + if with_age and graph_name: + await self.configure_age(connection, graph_name) # type: ignore + elif with_age and not graph_name: + raise ValueError("Graph name is required when with_age is True") - if multirows: - if rows: - columns = [col for col in rows[0].keys()] - return [dict(zip(columns, row)) for row in rows] - return [] + try: + if params: + rows = await connection.fetch(sql, *params) + else: + rows = await connection.fetch(sql) - if rows: - columns = rows[0].keys() - return dict(zip(columns, rows[0])) - return None + if multirows: + if rows: + columns = [col for col in rows[0].keys()] + data = [dict(zip(columns, row)) for row in rows] + else: + data = [] + else: + if rows: + columns = rows[0].keys() + data = dict(zip(columns, rows[0])) + else: + data = None - try: - return await self._run_with_retry( - _operation, with_age=with_age, graph_name=graph_name - ) - except Exception as e: - logger.error(f"PostgreSQL database, error:{e}") - raise + return data + except Exception as e: + logger.error(f"PostgreSQL database, error:{e}") + raise async def execute( self, @@ -1446,33 +1281,30 @@ class PostgreSQLDB: with_age: bool = False, graph_name: str | None = None, ): - async def _operation(connection: asyncpg.Connection) -> Any: - prepared_values = tuple(data.values()) if data 
else () - try: - if not data: - return await connection.execute(sql) - return await connection.execute(sql, *prepared_values) - except ( - asyncpg.exceptions.UniqueViolationError, - asyncpg.exceptions.DuplicateTableError, - asyncpg.exceptions.DuplicateObjectError, - asyncpg.exceptions.InvalidSchemaNameError, - ) as e: - if ignore_if_exists: - logger.debug("PostgreSQL, ignoring duplicate during execute: %r", e) - return None - if upsert: - logger.info( - "PostgreSQL, duplicate detected but treated as upsert success: %r", - e, - ) - return None - raise - try: - await self._run_with_retry( - _operation, with_age=with_age, graph_name=graph_name - ) + async with self.pool.acquire() as connection: # type: ignore + if with_age and graph_name: + await self.configure_age(connection, graph_name) + elif with_age and not graph_name: + raise ValueError("Graph name is required when with_age is True") + + if data is None: + await connection.execute(sql) + else: + await connection.execute(sql, *data.values()) + except ( + asyncpg.exceptions.UniqueViolationError, + asyncpg.exceptions.DuplicateTableError, + asyncpg.exceptions.DuplicateObjectError, # Catch "already exists" error + asyncpg.exceptions.InvalidSchemaNameError, # Also catch for AGE extension "already exists" + ) as e: + if ignore_if_exists: + # If the flag is set, just ignore these specific errors + pass + elif upsert: + print("Key value duplicate, but upsert succeeded.") + else: + logger.error(f"Upsert error: {e}") except Exception as e: logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}") raise @@ -1559,13 +1391,9 @@ class ClientManager: ), # Server settings for Supabase "server_settings": os.environ.get( - "POSTGRES_SERVER_SETTINGS", + "POSTGRES_SERVER_OPTIONS", config.get("postgres", "server_options", fallback=None), ), - "statement_cache_size": os.environ.get( - "POSTGRES_STATEMENT_CACHE_SIZE", - config.get("postgres", "statement_cache_size", fallback=None), - ), } @classmethod @@ -1813,33 +1641,12 @@ class PGKVStorage(BaseKVStorage): # Query by id async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: """Get data by ids""" - if not ids: - return [] - - sql = SQL_TEMPLATES["get_by_ids_" + self.namespace] - params = {"workspace": self.workspace, "ids": ids} + sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( + ids=",".join([f"'{id}'" for id in ids]) + ) + params = {"workspace": self.workspace} results = await self.db.query(sql, list(params.values()), multirows=True) - def _order_results( - rows: list[dict[str, Any]] | None, - ) -> list[dict[str, Any] | None]: - """Preserve the caller requested ordering for bulk id lookups.""" - if not rows: - return [None for _ in ids] - - id_map: dict[str, dict[str, Any]] = {} - for row in rows: - if row is None: - continue - row_id = row.get("id") - if row_id is not None: - id_map[str(row_id)] = row - - ordered: list[dict[str, Any] | None] = [] - for requested_id in ids: - ordered.append(id_map.get(str(requested_id))) - return ordered - if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): # Parse llm_cache_list JSON string back to list for each result for result in results: @@ -1882,7 +1689,7 @@ class PGKVStorage(BaseKVStorage): "update_time": create_time if update_time == 0 else update_time, } processed_results.append(processed_row) - return _order_results(processed_results) + return processed_results # Special handling for FULL_ENTITIES namespace if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES): @@ -1916,16 
+1723,15 @@ class PGKVStorage(BaseKVStorage): result["create_time"] = create_time result["update_time"] = create_time if update_time == 0 else update_time - return _order_results(results) + return results if results else [] async def filter_keys(self, keys: set[str]) -> set[str]: """Filter out duplicated content""" - if not keys: - return set() - - table_name = namespace_to_table_name(self.namespace) - sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" - params = {"workspace": self.workspace, "ids": list(keys)} + sql = SQL_TEMPLATES["filter_keys"].format( + table_name=namespace_to_table_name(self.namespace), + ids=",".join([f"'{id}'" for id in keys]), + ) + params = {"workspace": self.workspace} try: res = await self.db.query(sql, list(params.values()), multirows=True) if res: @@ -2375,23 +2181,7 @@ class PGVectorStorage(BaseVectorStorage): try: results = await self.db.query(query, list(params.values()), multirows=True) - if not results: - return [] - - # Preserve caller requested ordering while normalizing asyncpg rows to dicts. - id_map: dict[str, dict[str, Any]] = {} - for record in results: - if record is None: - continue - record_dict = dict(record) - row_id = record_dict.get("id") - if row_id is not None: - id_map[str(row_id)] = record_dict - - ordered_results: list[dict[str, Any] | None] = [] - for requested_id in ids: - ordered_results.append(id_map.get(str(requested_id))) - return ordered_results + return [dict(record) for record in results] except Exception as e: logger.error( f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}" @@ -2504,12 +2294,11 @@ class PGDocStatusStorage(DocStatusStorage): async def filter_keys(self, keys: set[str]) -> set[str]: """Filter out duplicated content""" - if not keys: - return set() - - table_name = namespace_to_table_name(self.namespace) - sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" - params = {"workspace": self.workspace, "ids": list(keys)} + sql = SQL_TEMPLATES["filter_keys"].format( + table_name=namespace_to_table_name(self.namespace), + ids=",".join([f"'{id}'" for id in keys]), + ) + params = {"workspace": self.workspace} try: res = await self.db.query(sql, list(params.values()), multirows=True) if res: @@ -2580,7 +2369,7 @@ class PGDocStatusStorage(DocStatusStorage): if not results: return [] - processed_map: dict[str, dict[str, Any]] = {} + processed_results = [] for row in results: # Parse chunks_list JSON string back to list chunks_list = row.get("chunks_list", []) @@ -2602,25 +2391,23 @@ class PGDocStatusStorage(DocStatusStorage): created_at = self._format_datetime_with_timezone(row["created_at"]) updated_at = self._format_datetime_with_timezone(row["updated_at"]) - processed_map[str(row.get("id"))] = { - "content_length": row["content_length"], - "content_summary": row["content_summary"], - "status": row["status"], - "chunks_count": row["chunks_count"], - "created_at": created_at, - "updated_at": updated_at, - "file_path": row["file_path"], - "chunks_list": chunks_list, - "metadata": metadata, - "error_msg": row.get("error_msg"), - "track_id": row.get("track_id"), - } + processed_results.append( + { + "content_length": row["content_length"], + "content_summary": row["content_summary"], + "status": row["status"], + "chunks_count": row["chunks_count"], + "created_at": created_at, + "updated_at": updated_at, + "file_path": row["file_path"], + "chunks_list": chunks_list, + "metadata": metadata, + "error_msg": row.get("error_msg"), + "track_id": row.get("track_id"), + } + ) 
- ordered_results: list[dict[str, Any] | None] = [] - for requested_id in ids: - ordered_results.append(processed_map.get(str(requested_id))) - - return ordered_results + return processed_results async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]: """Get document by file path @@ -2822,33 +2609,26 @@ class PGDocStatusStorage(DocStatusStorage): elif page_size > 200: page_size = 200 - # Whitelist validation for sort_field to prevent SQL injection - allowed_sort_fields = {"created_at", "updated_at", "id", "file_path"} - if sort_field not in allowed_sort_fields: + if sort_field not in ["created_at", "updated_at", "id", "file_path"]: sort_field = "updated_at" - # Whitelist validation for sort_direction to prevent SQL injection if sort_direction.lower() not in ["asc", "desc"]: sort_direction = "desc" - else: - sort_direction = sort_direction.lower() # Calculate offset offset = (page - 1) * page_size - # Build parameterized query components + # Build WHERE clause + where_clause = "WHERE workspace=$1" params = {"workspace": self.workspace} param_count = 1 - # Build WHERE clause with parameterized query if status_filter is not None: param_count += 1 - where_clause = "WHERE workspace=$1 AND status=$2" + where_clause += f" AND status=${param_count}" params["status"] = status_filter.value - else: - where_clause = "WHERE workspace=$1" - # Build ORDER BY clause using validated whitelist values + # Build ORDER BY clause order_clause = f"ORDER BY {sort_field} {sort_direction.upper()}" # Query for total count @@ -2856,7 +2636,7 @@ class PGDocStatusStorage(DocStatusStorage): count_result = await self.db.query(count_sql, list(params.values())) total_count = count_result["total"] if count_result else 0 - # Query for paginated data with parameterized LIMIT and OFFSET + # Query for paginated data data_sql = f""" SELECT * FROM LIGHTRAG_DOC_STATUS {where_clause} @@ -3387,8 +3167,7 @@ class PGGraphStorage(BaseGraphStorage): { "message": f"Error executing graph query: {query}", "wrapped": query, - "detail": repr(e), - "error_type": e.__class__.__name__, + "detail": str(e), } ) from e @@ -4854,19 +4633,19 @@ SQL_TEMPLATES = { """, "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content, COALESCE(doc_name, '') as file_path - FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id = ANY($2) + FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids}) """, "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, chunk_order_index, full_doc_id, file_path, COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id = ANY($2) + FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids}) """, "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id = ANY($2) + FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids}) """, "get_by_id_full_entities": """SELECT id, entity_names, count, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, @@ -4881,12 +4660,12 @@ SQL_TEMPLATES = { "get_by_ids_full_entities": """SELECT id, entity_names, count, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_FULL_ENTITIES 
WHERE workspace=$1 AND id = ANY($2) + FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id IN ({ids}) """, "get_by_ids_full_relations": """SELECT id, relation_pairs, count, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2) + FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id IN ({ids}) """, "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})", "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace)