feat: add options for PostGres connection

(cherry picked from commit 108cdbe133)
Authored by kevinnkansah on 2025-10-05 23:29:04 +02:00, committed by Raphaël MANSUY
parent fe05563ecb
commit 7ce46bacb6
2 changed files with 167 additions and 411 deletions


@ -23,13 +23,13 @@ WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
# WORKING_DIR=<absolute_path_for_working_dir>
### Tiktoken cache directory (Store cached files in this folder for offline deployment)
# TIKTOKEN_CACHE_DIR=/app/data/tiktoken
# TIKTOKEN_CACHE_DIR=./temp/tiktoken
### Ollama Emulating Model and Tag
# OLLAMA_EMULATING_MODEL_NAME=lightrag
OLLAMA_EMULATING_MODEL_TAG=latest
### Max nodes return from graph retrieval in webui
### Max nodes return from grap retrieval in webui
# MAX_GRAPH_NODES=1000
### Logging level
@ -50,32 +50,35 @@ OLLAMA_EMULATING_MODEL_TAG=latest
# JWT_ALGORITHM=HS256
### API-Key to access LightRAG Server API
### Use this key in HTTP requests with the 'X-API-Key' header
### Example: curl -H "X-API-Key: your-secure-api-key-here" http://localhost:9621/query
# LIGHTRAG_API_KEY=your-secure-api-key-here
# WHITELIST_PATHS=/health,/api/*
######################################################################################
### Query Configuration
###
### How to control the context length sent to LLM:
### How to control the context lenght sent to LLM:
### MAX_ENTITY_TOKENS + MAX_RELATION_TOKENS < MAX_TOTAL_TOKENS
### Chunk_Tokens = MAX_TOTAL_TOKENS - Actual_Entity_Tokens - Actual_Relation_Tokens
### Chunk_Tokens = MAX_TOTAL_TOKENS - Actual_Entity_Tokens - Actual_Reation_Tokens
######################################################################################
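As a quick illustration of the budget rule above (the limits are the commented defaults; the actual entity/relation token counts are made up):

# Illustration only: how the chunk budget falls out of the limits above.
MAX_ENTITY_TOKENS = 6000
MAX_RELATION_TOKENS = 8000
MAX_TOTAL_TOKENS = 30000
assert MAX_ENTITY_TOKENS + MAX_RELATION_TOKENS < MAX_TOTAL_TOKENS

actual_entity_tokens = 4200    # hypothetical usage for one query
actual_relation_tokens = 5300  # hypothetical usage for one query
chunk_tokens = MAX_TOTAL_TOKENS - actual_entity_tokens - actual_relation_tokens
print(chunk_tokens)  # 20500 tokens remain for text chunks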
# LLM response cache for query (Not valid for streaming response)
# LLM responde cache for query (Not valid for streaming response)
ENABLE_LLM_CACHE=true
# COSINE_THRESHOLD=0.2
### Number of entities or relations retrieved from KG
# TOP_K=40
### Maximum number or chunks for naive vector search
### Maxmium number or chunks for naive vector search
# CHUNK_TOP_K=20
### control the actual entities send to LLM
### control the actual enties send to LLM
# MAX_ENTITY_TOKENS=6000
### control the actual relations send to LLM
# MAX_RELATION_TOKENS=8000
### control the maximum tokens send to LLM (include entities, relations and chunks)
### control the maximum tokens send to LLM (include entities, raltions and chunks)
# MAX_TOTAL_TOKENS=30000
### maximum number of related chunks per source entity or relation
### The chunk picker uses this value to determine the total number of chunks selected from KG(knowledge graph)
### Higher values increase re-ranking time
# RELATED_CHUNK_NUMBER=5
### chunk selection strategies
### VECTOR: Pick KG chunks by vector similarity, delivered chunks to the LLM aligning more closely with naive retrieval
### WEIGHT: Pick KG chunks by entity and chunk weight, delivered more solely KG related chunks to the LLM
@ -90,7 +93,7 @@ ENABLE_LLM_CACHE=true
RERANK_BINDING=null
### Enable rerank by default in query params when RERANK_BINDING is not null
# RERANK_BY_DEFAULT=True
### rerank score chunk filter(set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enough)
### rerank score chunk filter(set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enought)
# MIN_RERANK_SCORE=0.0
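A minimal sketch of the score filter described above, not the project's actual rerank code:

# Chunks scored by a reranker; anything below MIN_RERANK_SCORE is dropped.
MIN_RERANK_SCORE = 0.6  # 0.0 would keep every chunk

reranked = [
    {"chunk": "chunk A", "score": 0.82},
    {"chunk": "chunk B", "score": 0.41},
]
kept = [c for c in reranked if c["score"] >= MIN_RERANK_SCORE]  # only "chunk A" survives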
### For local deployment with vLLM
@ -128,7 +131,7 @@ SUMMARY_LANGUAGE=English
# CHUNK_SIZE=1200
# CHUNK_OVERLAP_SIZE=100
### Number of summary segments or tokens to trigger LLM summary on entity/relation merge (at least 3 is recommended)
### Number of summary semgments or tokens to trigger LLM summary on entity/relation merge (at least 3 is recommented)
# FORCE_LLM_SUMMARY_ON_MERGE=8
### Max description token size to trigger LLM summary
# SUMMARY_MAX_TOKENS = 1200
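A rough sketch of the trigger implied by the two settings above; the exact condition in the code base may differ:

FORCE_LLM_SUMMARY_ON_MERGE = 8   # segment-count threshold
SUMMARY_MAX_TOKENS = 1200        # description-size threshold

def needs_llm_summary(fragment_count: int, description_tokens: int) -> bool:
    # Assumption: crossing either threshold triggers an LLM merge summary.
    return (
        fragment_count >= FORCE_LLM_SUMMARY_ON_MERGE
        or description_tokens > SUMMARY_MAX_TOKENS
    )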
@ -137,19 +140,6 @@ SUMMARY_LANGUAGE=English
### Maximum context size sent to LLM for description summary
# SUMMARY_CONTEXT_SIZE=12000
### control the maximum chunk_ids stored in vector and graph db
# MAX_SOURCE_IDS_PER_ENTITY=300
# MAX_SOURCE_IDS_PER_RELATION=300
### control chunk_ids limitation method: KEEP, FIFO (KEEP: Keep oldest, FIFO: First in first out)
# SOURCE_IDS_LIMIT_METHOD=KEEP
### Maximum number of file paths stored in entity/relation file_path field
# MAX_FILE_PATHS=30
### maximum number of related chunks per source entity or relation
### The chunk picker uses this value to determine the total number of chunks selected from KG(knowledge graph)
### Higher values increase re-ranking time
# RELATED_CHUNK_NUMBER=5
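The KEEP/FIFO limitation methods a few lines above can be pictured with a short sketch; the semantics are taken from the comment (KEEP retains the oldest ids, FIFO evicts them as new ones arrive), not from the storage code itself:

def limit_source_ids(chunk_ids: list[str], max_ids: int = 300, method: str = "KEEP") -> list[str]:
    # chunk_ids are assumed to be ordered oldest -> newest.
    if len(chunk_ids) <= max_ids:
        return chunk_ids
    if method == "FIFO":
        return chunk_ids[-max_ids:]  # drop the oldest entries as new ones arrive
    return chunk_ids[:max_ids]       # KEEP: hold on to the oldest entries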
###############################
### Concurrency Configuration
###############################
@ -189,7 +179,7 @@ LLM_BINDING_API_KEY=your_api_key
# OPENAI_LLM_TEMPERATURE=0.9
### Set the max_tokens to mitigate endless output of some LLM (less than LLM_TIMEOUT * llm_output_tokens/second, i.e. 9000 = 180s * 50 tokens/s)
### Typically, max_tokens does not include prompt content, though some models, such as Gemini Models, are exceptions
### For vLLM/SGLang deployed models, or most of OpenAI compatible API provider
### For vLLM/SGLang doployed models, or most of OpenAI compatible API provider
# OPENAI_LLM_MAX_TOKENS=9000
### For OpenAI o1-mini or newer models
OPENAI_LLM_MAX_COMPLETION_TOKENS=9000
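The sizing rule in the comment above is plain arithmetic; with an assumed 180 s LLM_TIMEOUT and 50 output tokens per second it yields the 9000 used here:

LLM_TIMEOUT = 180              # seconds, assumed
OUTPUT_TOKENS_PER_SECOND = 50  # assumed generation speed of the model
OPENAI_LLM_MAX_TOKENS = LLM_TIMEOUT * OUTPUT_TOKENS_PER_SECOND  # 9000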
@ -203,7 +193,7 @@ OPENAI_LLM_MAX_COMPLETION_TOKENS=9000
# OPENAI_LLM_REASONING_EFFORT=minimal
### OpenRouter Specific Parameters
# OPENAI_LLM_EXTRA_BODY='{"reasoning": {"enabled": false}}'
### Qwen3 Specific Parameters deploy by vLLM
### Qwen3 Specific Parameters depoly by vLLM
# OPENAI_LLM_EXTRA_BODY='{"chat_template_kwargs": {"enable_thinking": false}}'
### use the following command to see all support options for Ollama LLM
@ -257,8 +247,8 @@ OLLAMA_EMBEDDING_NUM_CTX=8192
### lightrag-server --embedding-binding ollama --help
####################################################################
### WORKSPACE sets workspace name for all storage types
### for the purpose of isolating data from LightRAG instances.
### WORKSPACE setting workspace name for all storage types
### in the purpose of isolating data from LightRAG instances.
### Valid workspace name constraints: a-z, A-Z, 0-9, and _
####################################################################
# WORKSPACE=space1
@ -313,16 +303,6 @@ POSTGRES_HNSW_M=16
POSTGRES_HNSW_EF=200
POSTGRES_IVFFLAT_LISTS=100
### PostgreSQL Connection Retry Configuration (Network Robustness)
### Number of retry attempts (1-10, default: 3)
### Initial retry backoff in seconds (0.1-5.0, default: 0.5)
### Maximum retry backoff in seconds (backoff-60.0, default: 5.0)
### Connection pool close timeout in seconds (1.0-30.0, default: 5.0)
# POSTGRES_CONNECTION_RETRIES=3
# POSTGRES_CONNECTION_RETRY_BACKOFF=0.5
# POSTGRES_CONNECTION_RETRY_BACKOFF_MAX=5.0
# POSTGRES_POOL_CLOSE_TIMEOUT=5.0
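These four settings are read and clamped by PostgreSQLDB in the Python changes later in this diff; condensed, the parsing looks like this:

import os

retries = max(1, min(10, int(os.environ.get("POSTGRES_CONNECTION_RETRIES", 3))))
backoff = max(0.1, min(5.0, float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF", 0.5))))
backoff_max = max(backoff, min(60.0, float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", 5.0))))
pool_close_timeout = max(1.0, min(30.0, float(os.environ.get("POSTGRES_POOL_CLOSE_TIMEOUT", 5.0))))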
### PostgreSQL SSL Configuration (Optional)
# POSTGRES_SSL_MODE=require
# POSTGRES_SSL_CERT=/path/to/client-cert.pem
@ -330,13 +310,10 @@ POSTGRES_IVFFLAT_LISTS=100
# POSTGRES_SSL_ROOT_CERT=/path/to/ca-cert.pem
# POSTGRES_SSL_CRL=/path/to/crl.pem
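A minimal sketch of turning these variables into an ssl.SSLContext for asyncpg, assuming verify-against-CA semantics; the _create_ssl_context method in the Python file below is the authoritative version, and a POSTGRES_SSL_KEY variable is assumed to accompany the client certificate:

import os
import ssl

def build_ssl_context() -> ssl.SSLContext | None:
    mode = os.environ.get("POSTGRES_SSL_MODE", "")
    if not mode or mode.lower() == "disable":
        return None
    ctx = ssl.create_default_context(
        ssl.Purpose.SERVER_AUTH, cafile=os.environ.get("POSTGRES_SSL_ROOT_CERT")
    )
    cert = os.environ.get("POSTGRES_SSL_CERT")
    if cert:
        ctx.load_cert_chain(certfile=cert, keyfile=os.environ.get("POSTGRES_SSL_KEY"))
    crl = os.environ.get("POSTGRES_SSL_CRL")
    if crl:
        ctx.load_verify_locations(cafile=crl)
        ctx.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
    return ctx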
### PostgreSQL Server Settings (for Supabase Supavisor)
### PostgreSQL Server Options (for Supabase Supavisor)
# Use this to pass extra options to the PostgreSQL connection string.
# For Supabase, you might need to set it like this:
# POSTGRES_SERVER_SETTINGS="options=reference%3D[project-ref]"
# Default is 100, set to 0 to disable
# POSTGRES_STATEMENT_CACHE_SIZE=100
# POSTGRES_SERVER_OPTIONS="options=reference%3D[project-ref]"
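As the initdb changes below show, this string is parsed like a query string and handed to asyncpg as server_settings; condensed:

raw = "options=reference%3D[project-ref]"  # value of POSTGRES_SERVER_OPTIONS
settings = {}
for pair in raw.split("&"):
    if "=" in pair:
        key, value = pair.split("=", 1)
        settings[key] = value
# settings == {"options": "reference%3D[project-ref]"} and is later passed as
# asyncpg.create_pool(..., server_settings=settings)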
### Neo4j Configuration
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io


@ -5,7 +5,7 @@ import re
import datetime
from datetime import timezone
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, TypeVar, Union, final
from typing import Any, Union, final
import numpy as np
import configparser
import ssl
@ -14,13 +14,10 @@ import itertools
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from tenacity import (
AsyncRetrying,
RetryCallState,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
wait_fixed,
)
from ..base import (
@ -51,8 +48,6 @@ from dotenv import load_dotenv
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
T = TypeVar("T")
class PostgreSQLDB:
def __init__(self, config: dict[str, Any], **kwargs: Any):
@ -82,55 +77,9 @@ class PostgreSQLDB:
# Server settings
self.server_settings = config.get("server_settings")
# Statement LRU cache size (keep as-is, allow None for optional configuration)
self.statement_cache_size = config.get("statement_cache_size")
# Connection retry configuration
self.connection_retry_attempts = max(
1, min(10, int(os.environ.get("POSTGRES_CONNECTION_RETRIES", 3)))
)
self.connection_retry_backoff = max(
0.1,
min(5.0, float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF", 0.5))),
)
self.connection_retry_backoff_max = max(
self.connection_retry_backoff,
min(
60.0,
float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", 5.0)),
),
)
self.pool_close_timeout = max(
1.0, min(30.0, float(os.environ.get("POSTGRES_POOL_CLOSE_TIMEOUT", 5.0)))
)
self._transient_exceptions = (
asyncio.TimeoutError,
TimeoutError,
ConnectionError,
OSError,
asyncpg.exceptions.InterfaceError,
asyncpg.exceptions.TooManyConnectionsError,
asyncpg.exceptions.CannotConnectNowError,
asyncpg.exceptions.PostgresConnectionError,
asyncpg.exceptions.ConnectionDoesNotExistError,
asyncpg.exceptions.ConnectionFailureError,
)
# Guard concurrent pool resets
self._pool_reconnect_lock = asyncio.Lock()
if self.user is None or self.password is None or self.database is None:
raise ValueError("Missing database user, password, or database")
logger.info(
"PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs",
self.connection_retry_attempts,
self.connection_retry_backoff,
self.connection_retry_backoff_max,
self.pool_close_timeout,
)
def _create_ssl_context(self) -> ssl.SSLContext | None:
"""Create SSL context based on configuration parameters."""
if not self.ssl_mode:
@ -202,87 +151,54 @@ class PostgreSQLDB:
return None
async def initdb(self):
# Prepare connection parameters
connection_params = {
"user": self.user,
"password": self.password,
"database": self.database,
"host": self.host,
"port": self.port,
"min_size": 1,
"max_size": self.max,
}
# Only add statement_cache_size if it's configured
if self.statement_cache_size is not None:
connection_params["statement_cache_size"] = int(
self.statement_cache_size
)
logger.info(
f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}"
)
# Add SSL configuration if provided
ssl_context = self._create_ssl_context()
if ssl_context is not None:
connection_params["ssl"] = ssl_context
logger.info("PostgreSQL, SSL configuration applied")
elif self.ssl_mode:
# Handle simple SSL modes without custom context
if self.ssl_mode.lower() in ["require", "prefer"]:
connection_params["ssl"] = True
elif self.ssl_mode.lower() == "disable":
connection_params["ssl"] = False
logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
# Add server settings if provided
if self.server_settings:
try:
settings = {}
# The format is expected to be a query string, e.g., "key1=value1&key2=value2"
pairs = self.server_settings.split("&")
for pair in pairs:
if "=" in pair:
key, value = pair.split("=", 1)
settings[key] = value
if settings:
connection_params["server_settings"] = settings
logger.info(f"PostgreSQL, Server settings applied: {settings}")
except Exception as e:
logger.warning(
f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}"
)
wait_strategy = (
wait_exponential(
multiplier=self.connection_retry_backoff,
min=self.connection_retry_backoff,
max=self.connection_retry_backoff_max,
)
if self.connection_retry_backoff > 0
else wait_fixed(0)
)
async def _create_pool_once() -> None:
pool = await asyncpg.create_pool(**connection_params) # type: ignore
try:
async with pool.acquire() as connection:
await self.configure_vector_extension(connection)
except Exception:
await pool.close()
raise
self.pool = pool
try:
async for attempt in AsyncRetrying(
stop=stop_after_attempt(self.connection_retry_attempts),
retry=retry_if_exception_type(self._transient_exceptions),
wait=wait_strategy,
before_sleep=self._before_sleep,
reraise=True,
):
with attempt:
await _create_pool_once()
# Prepare connection parameters
connection_params = {
"user": self.user,
"password": self.password,
"database": self.database,
"host": self.host,
"port": self.port,
"min_size": 1,
"max_size": self.max,
}
# Add SSL configuration if provided
ssl_context = self._create_ssl_context()
if ssl_context is not None:
connection_params["ssl"] = ssl_context
logger.info("PostgreSQL, SSL configuration applied")
elif self.ssl_mode:
# Handle simple SSL modes without custom context
if self.ssl_mode.lower() in ["require", "prefer"]:
connection_params["ssl"] = True
elif self.ssl_mode.lower() == "disable":
connection_params["ssl"] = False
logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
# Add server settings if provided
if self.server_settings:
try:
settings = {}
# The format is expected to be a query string, e.g., "key1=value1&key2=value2"
pairs = self.server_settings.split("&")
for pair in pairs:
if "=" in pair:
key, value = pair.split("=", 1)
settings[key] = value
if settings:
connection_params["server_settings"] = settings
logger.info(f"PostgreSQL, Server settings applied: {settings}")
except Exception as e:
logger.warning(
f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}"
)
self.pool = await asyncpg.create_pool(**connection_params) # type: ignore
# Ensure VECTOR extension is available
async with self.pool.acquire() as connection:
await self.configure_vector_extension(connection)
ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
logger.info(
@ -294,97 +210,12 @@ class PostgreSQLDB:
)
raise
async def _ensure_pool(self) -> None:
"""Ensure the connection pool is initialised."""
if self.pool is None:
async with self._pool_reconnect_lock:
if self.pool is None:
await self.initdb()
async def _reset_pool(self) -> None:
async with self._pool_reconnect_lock:
if self.pool is not None:
try:
await asyncio.wait_for(
self.pool.close(), timeout=self.pool_close_timeout
)
except asyncio.TimeoutError:
logger.error(
"PostgreSQL, Timed out closing connection pool after %.2fs",
self.pool_close_timeout,
)
except Exception as close_error: # pragma: no cover - defensive logging
logger.warning(
f"PostgreSQL, Failed to close existing connection pool cleanly: {close_error!r}"
)
self.pool = None
async def _before_sleep(self, retry_state: RetryCallState) -> None:
"""Hook invoked by tenacity before sleeping between retries."""
exc = retry_state.outcome.exception() if retry_state.outcome else None
logger.warning(
"PostgreSQL transient connection issue on attempt %s/%s: %r",
retry_state.attempt_number,
self.connection_retry_attempts,
exc,
)
await self._reset_pool()
async def _run_with_retry(
self,
operation: Callable[[asyncpg.Connection], Awaitable[T]],
*,
with_age: bool = False,
graph_name: str | None = None,
) -> T:
"""
Execute a database operation with automatic retry for transient failures.
Args:
operation: Async callable that receives an active connection.
with_age: Whether to configure Apache AGE on the connection.
graph_name: AGE graph name; required when with_age is True.
Returns:
The result returned by the operation.
Raises:
Exception: Propagates the last error if all retry attempts fail or a non-transient error occurs.
"""
wait_strategy = (
wait_exponential(
multiplier=self.connection_retry_backoff,
min=self.connection_retry_backoff,
max=self.connection_retry_backoff_max,
)
if self.connection_retry_backoff > 0
else wait_fixed(0)
)
async for attempt in AsyncRetrying(
stop=stop_after_attempt(self.connection_retry_attempts),
retry=retry_if_exception_type(self._transient_exceptions),
wait=wait_strategy,
before_sleep=self._before_sleep,
reraise=True,
):
with attempt:
await self._ensure_pool()
assert self.pool is not None
async with self.pool.acquire() as connection: # type: ignore[arg-type]
if with_age and graph_name:
await self.configure_age(connection, graph_name)
elif with_age and not graph_name:
raise ValueError("Graph name is required when with_age is True")
return await operation(connection)
@staticmethod
async def configure_vector_extension(connection: asyncpg.Connection) -> None:
"""Create VECTOR extension if it doesn't exist for vector similarity operations."""
try:
await connection.execute("CREATE EXTENSION IF NOT EXISTS vector") # type: ignore
logger.info("PostgreSQL, VECTOR extension enabled")
logger.info("VECTOR extension ensured for PostgreSQL")
except Exception as e:
logger.warning(f"Could not create VECTOR extension: {e}")
# Don't raise - let the system continue without vector extension
@ -394,7 +225,7 @@ class PostgreSQLDB:
"""Create AGE extension if it doesn't exist for graph operations."""
try:
await connection.execute("CREATE EXTENSION IF NOT EXISTS age") # type: ignore
logger.info("PostgreSQL, AGE extension enabled")
logger.info("AGE extension ensured for PostgreSQL")
except Exception as e:
logger.warning(f"Could not create AGE extension: {e}")
# Don't raise - let the system continue without AGE extension
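For orientation, the pool these helpers run against is built in initdb above; a standalone sketch of the same asyncpg setup with placeholder credentials:

import asyncio
import asyncpg

async def main() -> None:
    pool = await asyncpg.create_pool(
        user="postgres", password="secret", database="lightrag",  # placeholders
        host="localhost", port=5432, min_size=1, max_size=12,
    )
    async with pool.acquire() as connection:
        # Best effort, mirroring configure_vector_extension(): failures are only logged there.
        await connection.execute("CREATE EXTENSION IF NOT EXISTS vector")
    await pool.close()

asyncio.run(main())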
@ -1411,31 +1242,35 @@ class PostgreSQLDB:
with_age: bool = False,
graph_name: str | None = None,
) -> dict[str, Any] | None | list[dict[str, Any]]:
async def _operation(connection: asyncpg.Connection) -> Any:
prepared_params = tuple(params) if params else ()
if prepared_params:
rows = await connection.fetch(sql, *prepared_params)
else:
rows = await connection.fetch(sql)
async with self.pool.acquire() as connection: # type: ignore
if with_age and graph_name:
await self.configure_age(connection, graph_name) # type: ignore
elif with_age and not graph_name:
raise ValueError("Graph name is required when with_age is True")
if multirows:
if rows:
columns = [col for col in rows[0].keys()]
return [dict(zip(columns, row)) for row in rows]
return []
try:
if params:
rows = await connection.fetch(sql, *params)
else:
rows = await connection.fetch(sql)
if rows:
columns = rows[0].keys()
return dict(zip(columns, rows[0]))
return None
if multirows:
if rows:
columns = [col for col in rows[0].keys()]
data = [dict(zip(columns, row)) for row in rows]
else:
data = []
else:
if rows:
columns = rows[0].keys()
data = dict(zip(columns, rows[0]))
else:
data = None
try:
return await self._run_with_retry(
_operation, with_age=with_age, graph_name=graph_name
)
except Exception as e:
logger.error(f"PostgreSQL database, error:{e}")
raise
return data
except Exception as e:
logger.error(f"PostgreSQL database, error:{e}")
raise
async def execute(
self,
@ -1446,33 +1281,30 @@ class PostgreSQLDB:
with_age: bool = False,
graph_name: str | None = None,
):
async def _operation(connection: asyncpg.Connection) -> Any:
prepared_values = tuple(data.values()) if data else ()
try:
if not data:
return await connection.execute(sql)
return await connection.execute(sql, *prepared_values)
except (
asyncpg.exceptions.UniqueViolationError,
asyncpg.exceptions.DuplicateTableError,
asyncpg.exceptions.DuplicateObjectError,
asyncpg.exceptions.InvalidSchemaNameError,
) as e:
if ignore_if_exists:
logger.debug("PostgreSQL, ignoring duplicate during execute: %r", e)
return None
if upsert:
logger.info(
"PostgreSQL, duplicate detected but treated as upsert success: %r",
e,
)
return None
raise
try:
await self._run_with_retry(
_operation, with_age=with_age, graph_name=graph_name
)
async with self.pool.acquire() as connection: # type: ignore
if with_age and graph_name:
await self.configure_age(connection, graph_name)
elif with_age and not graph_name:
raise ValueError("Graph name is required when with_age is True")
if data is None:
await connection.execute(sql)
else:
await connection.execute(sql, *data.values())
except (
asyncpg.exceptions.UniqueViolationError,
asyncpg.exceptions.DuplicateTableError,
asyncpg.exceptions.DuplicateObjectError, # Catch "already exists" error
asyncpg.exceptions.InvalidSchemaNameError, # Also catch for AGE extension "already exists"
) as e:
if ignore_if_exists:
# If the flag is set, just ignore these specific errors
pass
elif upsert:
print("Key value duplicate, but upsert succeeded.")
else:
logger.error(f"Upsert error: {e}")
except Exception as e:
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
raise
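One point worth noting in the body above: data.values() is expanded positionally, so the $1..$n placeholders must follow the dict's insertion order. A hypothetical call, with the column list borrowed from the upsert_doc_full template near the end of this file:

async def demo(db: "PostgreSQLDB") -> None:
    await db.execute(
        "INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace) VALUES ($1, $2, $3, $4)",
        {"id": "doc-1", "content": "hello", "doc_name": "a.txt", "workspace": "space1"},
    )
    # data.values() -> ("doc-1", "hello", "a.txt", "space1") binds to $1..$4 in order.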
@ -1559,13 +1391,9 @@ class ClientManager:
),
# Server settings for Supabase
"server_settings": os.environ.get(
"POSTGRES_SERVER_SETTINGS",
"POSTGRES_SERVER_OPTIONS",
config.get("postgres", "server_options", fallback=None),
),
"statement_cache_size": os.environ.get(
"POSTGRES_STATEMENT_CACHE_SIZE",
config.get("postgres", "statement_cache_size", fallback=None),
),
}
@classmethod
@ -1813,33 +1641,12 @@ class PGKVStorage(BaseKVStorage):
# Query by id
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get data by ids"""
if not ids:
return []
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace]
params = {"workspace": self.workspace, "ids": ids}
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.workspace}
results = await self.db.query(sql, list(params.values()), multirows=True)
def _order_results(
rows: list[dict[str, Any]] | None,
) -> list[dict[str, Any] | None]:
"""Preserve the caller requested ordering for bulk id lookups."""
if not rows:
return [None for _ in ids]
id_map: dict[str, dict[str, Any]] = {}
for row in rows:
if row is None:
continue
row_id = row.get("id")
if row_id is not None:
id_map[str(row_id)] = row
ordered: list[dict[str, Any] | None] = []
for requested_id in ids:
ordered.append(id_map.get(str(requested_id)))
return ordered
if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
# Parse llm_cache_list JSON string back to list for each result
for result in results:
@ -1882,7 +1689,7 @@ class PGKVStorage(BaseKVStorage):
"update_time": create_time if update_time == 0 else update_time,
}
processed_results.append(processed_row)
return _order_results(processed_results)
return processed_results
# Special handling for FULL_ENTITIES namespace
if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
@ -1916,16 +1723,15 @@ class PGKVStorage(BaseKVStorage):
result["create_time"] = create_time
result["update_time"] = create_time if update_time == 0 else update_time
return _order_results(results)
return results if results else []
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content"""
if not keys:
return set()
table_name = namespace_to_table_name(self.namespace)
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.workspace, "ids": list(keys)}
sql = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace),
ids=",".join([f"'{id}'" for id in keys]),
)
params = {"workspace": self.workspace}
try:
res = await self.db.query(sql, list(params.values()), multirows=True)
if res:
@ -2375,23 +2181,7 @@ class PGVectorStorage(BaseVectorStorage):
try:
results = await self.db.query(query, list(params.values()), multirows=True)
if not results:
return []
# Preserve caller requested ordering while normalizing asyncpg rows to dicts.
id_map: dict[str, dict[str, Any]] = {}
for record in results:
if record is None:
continue
record_dict = dict(record)
row_id = record_dict.get("id")
if row_id is not None:
id_map[str(row_id)] = record_dict
ordered_results: list[dict[str, Any] | None] = []
for requested_id in ids:
ordered_results.append(id_map.get(str(requested_id)))
return ordered_results
return [dict(record) for record in results]
except Exception as e:
logger.error(
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
@ -2504,12 +2294,11 @@ class PGDocStatusStorage(DocStatusStorage):
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content"""
if not keys:
return set()
table_name = namespace_to_table_name(self.namespace)
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.workspace, "ids": list(keys)}
sql = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace),
ids=",".join([f"'{id}'" for id in keys]),
)
params = {"workspace": self.workspace}
try:
res = await self.db.query(sql, list(params.values()), multirows=True)
if res:
@ -2580,7 +2369,7 @@ class PGDocStatusStorage(DocStatusStorage):
if not results:
return []
processed_map: dict[str, dict[str, Any]] = {}
processed_results = []
for row in results:
# Parse chunks_list JSON string back to list
chunks_list = row.get("chunks_list", [])
@ -2602,25 +2391,23 @@ class PGDocStatusStorage(DocStatusStorage):
created_at = self._format_datetime_with_timezone(row["created_at"])
updated_at = self._format_datetime_with_timezone(row["updated_at"])
processed_map[str(row.get("id"))] = {
"content_length": row["content_length"],
"content_summary": row["content_summary"],
"status": row["status"],
"chunks_count": row["chunks_count"],
"created_at": created_at,
"updated_at": updated_at,
"file_path": row["file_path"],
"chunks_list": chunks_list,
"metadata": metadata,
"error_msg": row.get("error_msg"),
"track_id": row.get("track_id"),
}
processed_results.append(
{
"content_length": row["content_length"],
"content_summary": row["content_summary"],
"status": row["status"],
"chunks_count": row["chunks_count"],
"created_at": created_at,
"updated_at": updated_at,
"file_path": row["file_path"],
"chunks_list": chunks_list,
"metadata": metadata,
"error_msg": row.get("error_msg"),
"track_id": row.get("track_id"),
}
)
ordered_results: list[dict[str, Any] | None] = []
for requested_id in ids:
ordered_results.append(processed_map.get(str(requested_id)))
return ordered_results
return processed_results
async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
"""Get document by file path
@ -2822,33 +2609,26 @@ class PGDocStatusStorage(DocStatusStorage):
elif page_size > 200:
page_size = 200
# Whitelist validation for sort_field to prevent SQL injection
allowed_sort_fields = {"created_at", "updated_at", "id", "file_path"}
if sort_field not in allowed_sort_fields:
if sort_field not in ["created_at", "updated_at", "id", "file_path"]:
sort_field = "updated_at"
# Whitelist validation for sort_direction to prevent SQL injection
if sort_direction.lower() not in ["asc", "desc"]:
sort_direction = "desc"
else:
sort_direction = sort_direction.lower()
# Calculate offset
offset = (page - 1) * page_size
# Build parameterized query components
# Build WHERE clause
where_clause = "WHERE workspace=$1"
params = {"workspace": self.workspace}
param_count = 1
# Build WHERE clause with parameterized query
if status_filter is not None:
param_count += 1
where_clause = "WHERE workspace=$1 AND status=$2"
where_clause += f" AND status=${param_count}"
params["status"] = status_filter.value
else:
where_clause = "WHERE workspace=$1"
# Build ORDER BY clause using validated whitelist values
# Build ORDER BY clause
order_clause = f"ORDER BY {sort_field} {sort_direction.upper()}"
# Query for total count
@ -2856,7 +2636,7 @@ class PGDocStatusStorage(DocStatusStorage):
count_result = await self.db.query(count_sql, list(params.values()))
total_count = count_result["total"] if count_result else 0
# Query for paginated data with parameterized LIMIT and OFFSET
# Query for paginated data
data_sql = f"""
SELECT * FROM LIGHTRAG_DOC_STATUS
{where_clause}
@ -3387,8 +3167,7 @@ class PGGraphStorage(BaseGraphStorage):
{
"message": f"Error executing graph query: {query}",
"wrapped": query,
"detail": repr(e),
"error_type": e.__class__.__name__,
"detail": str(e),
}
) from e
@ -4854,19 +4633,19 @@ SQL_TEMPLATES = {
""",
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content,
COALESCE(doc_name, '') as file_path
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id = ANY($2)
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
""",
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
chunk_order_index, full_doc_id, file_path,
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id = ANY($2)
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
""",
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id = ANY($2)
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
""",
"get_by_id_full_entities": """SELECT id, entity_names, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
@ -4881,12 +4660,12 @@ SQL_TEMPLATES = {
"get_by_ids_full_entities": """SELECT id, entity_names, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id = ANY($2)
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id IN ({ids})
""",
"get_by_ids_full_relations": """SELECT id, relation_pairs, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2)
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id IN ({ids})
""",
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace)