feat: add options for PostGres connection
(cherry picked from commit 108cdbe133)
This commit is contained in:
parent
fe05563ecb
commit
7ce46bacb6
2 changed files with 167 additions and 411 deletions
65
env.example
65
env.example
|
|
@ -23,13 +23,13 @@ WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
|
||||||
# WORKING_DIR=<absolute_path_for_working_dir>
|
# WORKING_DIR=<absolute_path_for_working_dir>
|
||||||
|
|
||||||
### Tiktoken cache directory (Store cached files in this folder for offline deployment)
|
### Tiktoken cache directory (Store cached files in this folder for offline deployment)
|
||||||
# TIKTOKEN_CACHE_DIR=/app/data/tiktoken
|
# TIKTOKEN_CACHE_DIR=./temp/tiktoken
|
||||||
|
|
||||||
### Ollama Emulating Model and Tag
|
### Ollama Emulating Model and Tag
|
||||||
# OLLAMA_EMULATING_MODEL_NAME=lightrag
|
# OLLAMA_EMULATING_MODEL_NAME=lightrag
|
||||||
OLLAMA_EMULATING_MODEL_TAG=latest
|
OLLAMA_EMULATING_MODEL_TAG=latest
|
||||||
|
|
||||||
### Max nodes return from graph retrieval in webui
|
### Max nodes return from graph retrieval in webui
|
||||||
# MAX_GRAPH_NODES=1000
|
# MAX_GRAPH_NODES=1000
|
||||||
|
|
||||||
### Logging level
|
### Logging level
|
||||||
|
|
@ -50,32 +50,35 @@ OLLAMA_EMULATING_MODEL_TAG=latest
|
||||||
# JWT_ALGORITHM=HS256
|
# JWT_ALGORITHM=HS256
|
||||||
|
|
||||||
### API-Key to access LightRAG Server API
|
### API-Key to access LightRAG Server API
|
||||||
### Use this key in HTTP requests with the 'X-API-Key' header
|
|
||||||
### Example: curl -H "X-API-Key: your-secure-api-key-here" http://localhost:9621/query
|
|
||||||
# LIGHTRAG_API_KEY=your-secure-api-key-here
|
# LIGHTRAG_API_KEY=your-secure-api-key-here
|
||||||
# WHITELIST_PATHS=/health,/api/*
|
# WHITELIST_PATHS=/health,/api/*
|
||||||
|
|
||||||
######################################################################################
|
######################################################################################
|
||||||
### Query Configuration
|
### Query Configuration
|
||||||
###
|
###
|
||||||
### How to control the context length sent to LLM:
|
### How to control the context length sent to LLM:
|
||||||
### MAX_ENTITY_TOKENS + MAX_RELATION_TOKENS < MAX_TOTAL_TOKENS
|
### MAX_ENTITY_TOKENS + MAX_RELATION_TOKENS < MAX_TOTAL_TOKENS
|
||||||
### Chunk_Tokens = MAX_TOTAL_TOKENS - Actual_Entity_Tokens - Actual_Relation_Tokens
|
### Chunk_Tokens = MAX_TOTAL_TOKENS - Actual_Entity_Tokens - Actual_Relation_Tokens
|
||||||
######################################################################################
|
######################################################################################
|
||||||
# LLM response cache for query (Not valid for streaming response)
|
# LLM response cache for query (Not valid for streaming response)
|
||||||
ENABLE_LLM_CACHE=true
|
ENABLE_LLM_CACHE=true
|
||||||
# COSINE_THRESHOLD=0.2
|
# COSINE_THRESHOLD=0.2
|
||||||
### Number of entities or relations retrieved from KG
|
### Number of entities or relations retrieved from KG
|
||||||
# TOP_K=40
|
# TOP_K=40
|
||||||
### Maximum number of chunks for naive vector search
|
### Maximum number of chunks for naive vector search
|
||||||
# CHUNK_TOP_K=20
|
# CHUNK_TOP_K=20
|
||||||
### control the actual entities send to LLM
|
### control the actual entities send to LLM
|
||||||
# MAX_ENTITY_TOKENS=6000
|
# MAX_ENTITY_TOKENS=6000
|
||||||
### control the actual relations send to LLM
|
### control the actual relations send to LLM
|
||||||
# MAX_RELATION_TOKENS=8000
|
# MAX_RELATION_TOKENS=8000
|
||||||
### control the maximum tokens send to LLM (include entities, relations and chunks)
|
### control the maximum tokens send to LLM (include entities, relations and chunks)
|
||||||
# MAX_TOTAL_TOKENS=30000
|
# MAX_TOTAL_TOKENS=30000
|
||||||
|
|
||||||
|
### maximum number of related chunks per source entity or relation
|
||||||
|
### The chunk picker uses this value to determine the total number of chunks selected from KG(knowledge graph)
|
||||||
|
### Higher values increase re-ranking time
|
||||||
|
# RELATED_CHUNK_NUMBER=5
|
||||||
|
|
||||||
### chunk selection strategies
|
### chunk selection strategies
|
||||||
### VECTOR: Pick KG chunks by vector similarity, delivered chunks to the LLM aligning more closely with naive retrieval
|
### VECTOR: Pick KG chunks by vector similarity, delivered chunks to the LLM aligning more closely with naive retrieval
|
||||||
### WEIGHT: Pick KG chunks by entity and chunk weight, delivered more solely KG related chunks to the LLM
|
### WEIGHT: Pick KG chunks by entity and chunk weight, delivered more solely KG related chunks to the LLM
|
||||||
|
|
@ -90,7 +93,7 @@ ENABLE_LLM_CACHE=true
|
||||||
RERANK_BINDING=null
|
RERANK_BINDING=null
|
||||||
### Enable rerank by default in query params when RERANK_BINDING is not null
|
### Enable rerank by default in query params when RERANK_BINDING is not null
|
||||||
# RERANK_BY_DEFAULT=True
|
# RERANK_BY_DEFAULT=True
|
||||||
### rerank score chunk filter(set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enough)
|
### rerank score chunk filter(set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enough)
|
||||||
# MIN_RERANK_SCORE=0.0
|
# MIN_RERANK_SCORE=0.0
|
||||||
|
|
||||||
### For local deployment with vLLM
|
### For local deployment with vLLM
|
||||||
|
|
@ -128,7 +131,7 @@ SUMMARY_LANGUAGE=English
|
||||||
# CHUNK_SIZE=1200
|
# CHUNK_SIZE=1200
|
||||||
# CHUNK_OVERLAP_SIZE=100
|
# CHUNK_OVERLAP_SIZE=100
|
||||||
|
|
||||||
### Number of summary segments or tokens to trigger LLM summary on entity/relation merge (at least 3 is recommended)
|
### Number of summary segments or tokens to trigger LLM summary on entity/relation merge (at least 3 is recommended)
|
||||||
# FORCE_LLM_SUMMARY_ON_MERGE=8
|
# FORCE_LLM_SUMMARY_ON_MERGE=8
|
||||||
### Max description token size to trigger LLM summary
|
### Max description token size to trigger LLM summary
|
||||||
# SUMMARY_MAX_TOKENS = 1200
|
# SUMMARY_MAX_TOKENS = 1200
|
||||||
|
|
@ -137,19 +140,6 @@ SUMMARY_LANGUAGE=English
|
||||||
### Maximum context size sent to LLM for description summary
|
### Maximum context size sent to LLM for description summary
|
||||||
# SUMMARY_CONTEXT_SIZE=12000
|
# SUMMARY_CONTEXT_SIZE=12000
|
||||||
|
|
||||||
### control the maximum chunk_ids stored in vector and graph db
|
|
||||||
# MAX_SOURCE_IDS_PER_ENTITY=300
|
|
||||||
# MAX_SOURCE_IDS_PER_RELATION=300
|
|
||||||
### control chunk_ids limitation method: KEEP, FIFO (KEEP: Keep oldest, FIFO: First in first out)
|
|
||||||
# SOURCE_IDS_LIMIT_METHOD=KEEP
|
|
||||||
### Maximum number of file paths stored in entity/relation file_path field
|
|
||||||
# MAX_FILE_PATHS=30
|
|
||||||
|
|
||||||
### maximum number of related chunks per source entity or relation
|
|
||||||
### The chunk picker uses this value to determine the total number of chunks selected from KG(knowledge graph)
|
|
||||||
### Higher values increase re-ranking time
|
|
||||||
# RELATED_CHUNK_NUMBER=5
|
|
||||||
|
|
||||||
###############################
|
###############################
|
||||||
### Concurrency Configuration
|
### Concurrency Configuration
|
||||||
###############################
|
###############################
|
||||||
|
|
@ -189,7 +179,7 @@ LLM_BINDING_API_KEY=your_api_key
|
||||||
# OPENAI_LLM_TEMPERATURE=0.9
|
# OPENAI_LLM_TEMPERATURE=0.9
|
||||||
### Set the max_tokens to mitigate endless output of some LLM (less than LLM_TIMEOUT * llm_output_tokens/second, i.e. 9000 = 180s * 50 tokens/s)
|
### Set the max_tokens to mitigate endless output of some LLM (less than LLM_TIMEOUT * llm_output_tokens/second, i.e. 9000 = 180s * 50 tokens/s)
|
||||||
### Typically, max_tokens does not include prompt content, though some models, such as Gemini Models, are exceptions
|
### Typically, max_tokens does not include prompt content, though some models, such as Gemini Models, are exceptions
|
||||||
### For vLLM/SGLang deployed models, or most of OpenAI compatible API provider
|
### For vLLM/SGLang deployed models, or most of OpenAI compatible API provider
|
||||||
# OPENAI_LLM_MAX_TOKENS=9000
|
# OPENAI_LLM_MAX_TOKENS=9000
|
||||||
### For OpenAI o1-mini or newer models
|
### For OpenAI o1-mini or newer models
|
||||||
OPENAI_LLM_MAX_COMPLETION_TOKENS=9000
|
OPENAI_LLM_MAX_COMPLETION_TOKENS=9000
|
||||||
|
|
@ -203,7 +193,7 @@ OPENAI_LLM_MAX_COMPLETION_TOKENS=9000
|
||||||
# OPENAI_LLM_REASONING_EFFORT=minimal
|
# OPENAI_LLM_REASONING_EFFORT=minimal
|
||||||
### OpenRouter Specific Parameters
|
### OpenRouter Specific Parameters
|
||||||
# OPENAI_LLM_EXTRA_BODY='{"reasoning": {"enabled": false}}'
|
# OPENAI_LLM_EXTRA_BODY='{"reasoning": {"enabled": false}}'
|
||||||
### Qwen3 Specific Parameters deploy by vLLM
|
### Qwen3 Specific Parameters deploy by vLLM
|
||||||
# OPENAI_LLM_EXTRA_BODY='{"chat_template_kwargs": {"enable_thinking": false}}'
|
# OPENAI_LLM_EXTRA_BODY='{"chat_template_kwargs": {"enable_thinking": false}}'
|
||||||
|
|
||||||
### use the following command to see all supported options for Ollama LLM
|
### use the following command to see all supported options for Ollama LLM
|
||||||
|
|
@ -257,8 +247,8 @@ OLLAMA_EMBEDDING_NUM_CTX=8192
|
||||||
### lightrag-server --embedding-binding ollama --help
|
### lightrag-server --embedding-binding ollama --help
|
||||||
|
|
||||||
####################################################################
|
####################################################################
|
||||||
### WORKSPACE sets workspace name for all storage types
|
### WORKSPACE sets workspace name for all storage types
|
||||||
### for the purpose of isolating data from LightRAG instances.
|
### for the purpose of isolating data from LightRAG instances.
|
||||||
### Valid workspace name constraints: a-z, A-Z, 0-9, and _
|
### Valid workspace name constraints: a-z, A-Z, 0-9, and _
|
||||||
####################################################################
|
####################################################################
|
||||||
# WORKSPACE=space1
|
# WORKSPACE=space1
|
||||||
|
|
@ -313,16 +303,6 @@ POSTGRES_HNSW_M=16
|
||||||
POSTGRES_HNSW_EF=200
|
POSTGRES_HNSW_EF=200
|
||||||
POSTGRES_IVFFLAT_LISTS=100
|
POSTGRES_IVFFLAT_LISTS=100
|
||||||
|
|
||||||
### PostgreSQL Connection Retry Configuration (Network Robustness)
|
|
||||||
### Number of retry attempts (1-10, default: 3)
|
|
||||||
### Initial retry backoff in seconds (0.1-5.0, default: 0.5)
|
|
||||||
### Maximum retry backoff in seconds (backoff-60.0, default: 5.0)
|
|
||||||
### Connection pool close timeout in seconds (1.0-30.0, default: 5.0)
|
|
||||||
# POSTGRES_CONNECTION_RETRIES=3
|
|
||||||
# POSTGRES_CONNECTION_RETRY_BACKOFF=0.5
|
|
||||||
# POSTGRES_CONNECTION_RETRY_BACKOFF_MAX=5.0
|
|
||||||
# POSTGRES_POOL_CLOSE_TIMEOUT=5.0
|
|
||||||
|
|
||||||
### PostgreSQL SSL Configuration (Optional)
|
### PostgreSQL SSL Configuration (Optional)
|
||||||
# POSTGRES_SSL_MODE=require
|
# POSTGRES_SSL_MODE=require
|
||||||
# POSTGRES_SSL_CERT=/path/to/client-cert.pem
|
# POSTGRES_SSL_CERT=/path/to/client-cert.pem
|
||||||
|
|
@ -330,13 +310,10 @@ POSTGRES_IVFFLAT_LISTS=100
|
||||||
# POSTGRES_SSL_ROOT_CERT=/path/to/ca-cert.pem
|
# POSTGRES_SSL_ROOT_CERT=/path/to/ca-cert.pem
|
||||||
# POSTGRES_SSL_CRL=/path/to/crl.pem
|
# POSTGRES_SSL_CRL=/path/to/crl.pem
|
||||||
|
|
||||||
### PostgreSQL Server Settings (for Supabase Supavisor)
|
### PostgreSQL Server Options (for Supabase Supavisor)
|
||||||
# Use this to pass extra options to the PostgreSQL connection string.
|
# Use this to pass extra options to the PostgreSQL connection string.
|
||||||
# For Supabase, you might need to set it like this:
|
# For Supabase, you might need to set it like this:
|
||||||
# POSTGRES_SERVER_SETTINGS="options=reference%3D[project-ref]"
|
# POSTGRES_SERVER_OPTIONS="options=reference%3D[project-ref]"
|
||||||
|
|
||||||
# Default is 100; set to 0 to disable
|
|
||||||
# POSTGRES_STATEMENT_CACHE_SIZE=100
|
|
||||||
|
|
||||||
### Neo4j Configuration
|
### Neo4j Configuration
|
||||||
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
|
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import re
|
||||||
import datetime
|
import datetime
|
||||||
from datetime import timezone
|
from datetime import timezone
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Awaitable, Callable, TypeVar, Union, final
|
from typing import Any, Union, final
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import configparser
|
import configparser
|
||||||
import ssl
|
import ssl
|
||||||
|
|
@ -14,13 +14,10 @@ import itertools
|
||||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
AsyncRetrying,
|
|
||||||
RetryCallState,
|
|
||||||
retry,
|
retry,
|
||||||
retry_if_exception_type,
|
retry_if_exception_type,
|
||||||
stop_after_attempt,
|
stop_after_attempt,
|
||||||
wait_exponential,
|
wait_exponential,
|
||||||
wait_fixed,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..base import (
|
from ..base import (
|
||||||
|
|
@ -51,8 +48,6 @@ from dotenv import load_dotenv
|
||||||
# the OS environment variables take precedence over the .env file
|
# the OS environment variables take precedence over the .env file
|
||||||
load_dotenv(dotenv_path=".env", override=False)
|
load_dotenv(dotenv_path=".env", override=False)
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class PostgreSQLDB:
|
class PostgreSQLDB:
|
||||||
def __init__(self, config: dict[str, Any], **kwargs: Any):
|
def __init__(self, config: dict[str, Any], **kwargs: Any):
|
||||||
|
|
@ -82,55 +77,9 @@ class PostgreSQLDB:
|
||||||
# Server settings
|
# Server settings
|
||||||
self.server_settings = config.get("server_settings")
|
self.server_settings = config.get("server_settings")
|
||||||
|
|
||||||
# Statement LRU cache size (keep as-is, allow None for optional configuration)
|
|
||||||
self.statement_cache_size = config.get("statement_cache_size")
|
|
||||||
|
|
||||||
# Connection retry configuration
|
|
||||||
self.connection_retry_attempts = max(
|
|
||||||
1, min(10, int(os.environ.get("POSTGRES_CONNECTION_RETRIES", 3)))
|
|
||||||
)
|
|
||||||
self.connection_retry_backoff = max(
|
|
||||||
0.1,
|
|
||||||
min(5.0, float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF", 0.5))),
|
|
||||||
)
|
|
||||||
self.connection_retry_backoff_max = max(
|
|
||||||
self.connection_retry_backoff,
|
|
||||||
min(
|
|
||||||
60.0,
|
|
||||||
float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", 5.0)),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.pool_close_timeout = max(
|
|
||||||
1.0, min(30.0, float(os.environ.get("POSTGRES_POOL_CLOSE_TIMEOUT", 5.0)))
|
|
||||||
)
|
|
||||||
|
|
||||||
self._transient_exceptions = (
|
|
||||||
asyncio.TimeoutError,
|
|
||||||
TimeoutError,
|
|
||||||
ConnectionError,
|
|
||||||
OSError,
|
|
||||||
asyncpg.exceptions.InterfaceError,
|
|
||||||
asyncpg.exceptions.TooManyConnectionsError,
|
|
||||||
asyncpg.exceptions.CannotConnectNowError,
|
|
||||||
asyncpg.exceptions.PostgresConnectionError,
|
|
||||||
asyncpg.exceptions.ConnectionDoesNotExistError,
|
|
||||||
asyncpg.exceptions.ConnectionFailureError,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Guard concurrent pool resets
|
|
||||||
self._pool_reconnect_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
if self.user is None or self.password is None or self.database is None:
|
if self.user is None or self.password is None or self.database is None:
|
||||||
raise ValueError("Missing database user, password, or database")
|
raise ValueError("Missing database user, password, or database")
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs",
|
|
||||||
self.connection_retry_attempts,
|
|
||||||
self.connection_retry_backoff,
|
|
||||||
self.connection_retry_backoff_max,
|
|
||||||
self.pool_close_timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_ssl_context(self) -> ssl.SSLContext | None:
|
def _create_ssl_context(self) -> ssl.SSLContext | None:
|
||||||
"""Create SSL context based on configuration parameters."""
|
"""Create SSL context based on configuration parameters."""
|
||||||
if not self.ssl_mode:
|
if not self.ssl_mode:
|
||||||
|
|
@ -202,87 +151,54 @@ class PostgreSQLDB:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def initdb(self):
|
async def initdb(self):
|
||||||
# Prepare connection parameters
|
|
||||||
connection_params = {
|
|
||||||
"user": self.user,
|
|
||||||
"password": self.password,
|
|
||||||
"database": self.database,
|
|
||||||
"host": self.host,
|
|
||||||
"port": self.port,
|
|
||||||
"min_size": 1,
|
|
||||||
"max_size": self.max,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Only add statement_cache_size if it's configured
|
|
||||||
if self.statement_cache_size is not None:
|
|
||||||
connection_params["statement_cache_size"] = int(
|
|
||||||
self.statement_cache_size
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add SSL configuration if provided
|
|
||||||
ssl_context = self._create_ssl_context()
|
|
||||||
if ssl_context is not None:
|
|
||||||
connection_params["ssl"] = ssl_context
|
|
||||||
logger.info("PostgreSQL, SSL configuration applied")
|
|
||||||
elif self.ssl_mode:
|
|
||||||
# Handle simple SSL modes without custom context
|
|
||||||
if self.ssl_mode.lower() in ["require", "prefer"]:
|
|
||||||
connection_params["ssl"] = True
|
|
||||||
elif self.ssl_mode.lower() == "disable":
|
|
||||||
connection_params["ssl"] = False
|
|
||||||
logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
|
|
||||||
|
|
||||||
# Add server settings if provided
|
|
||||||
if self.server_settings:
|
|
||||||
try:
|
|
||||||
settings = {}
|
|
||||||
# The format is expected to be a query string, e.g., "key1=value1&key2=value2"
|
|
||||||
pairs = self.server_settings.split("&")
|
|
||||||
for pair in pairs:
|
|
||||||
if "=" in pair:
|
|
||||||
key, value = pair.split("=", 1)
|
|
||||||
settings[key] = value
|
|
||||||
if settings:
|
|
||||||
connection_params["server_settings"] = settings
|
|
||||||
logger.info(f"PostgreSQL, Server settings applied: {settings}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(
|
|
||||||
f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
wait_strategy = (
|
|
||||||
wait_exponential(
|
|
||||||
multiplier=self.connection_retry_backoff,
|
|
||||||
min=self.connection_retry_backoff,
|
|
||||||
max=self.connection_retry_backoff_max,
|
|
||||||
)
|
|
||||||
if self.connection_retry_backoff > 0
|
|
||||||
else wait_fixed(0)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _create_pool_once() -> None:
|
|
||||||
pool = await asyncpg.create_pool(**connection_params) # type: ignore
|
|
||||||
try:
|
|
||||||
async with pool.acquire() as connection:
|
|
||||||
await self.configure_vector_extension(connection)
|
|
||||||
except Exception:
|
|
||||||
await pool.close()
|
|
||||||
raise
|
|
||||||
self.pool = pool
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for attempt in AsyncRetrying(
|
# Prepare connection parameters
|
||||||
stop=stop_after_attempt(self.connection_retry_attempts),
|
connection_params = {
|
||||||
retry=retry_if_exception_type(self._transient_exceptions),
|
"user": self.user,
|
||||||
wait=wait_strategy,
|
"password": self.password,
|
||||||
before_sleep=self._before_sleep,
|
"database": self.database,
|
||||||
reraise=True,
|
"host": self.host,
|
||||||
):
|
"port": self.port,
|
||||||
with attempt:
|
"min_size": 1,
|
||||||
await _create_pool_once()
|
"max_size": self.max,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add SSL configuration if provided
|
||||||
|
ssl_context = self._create_ssl_context()
|
||||||
|
if ssl_context is not None:
|
||||||
|
connection_params["ssl"] = ssl_context
|
||||||
|
logger.info("PostgreSQL, SSL configuration applied")
|
||||||
|
elif self.ssl_mode:
|
||||||
|
# Handle simple SSL modes without custom context
|
||||||
|
if self.ssl_mode.lower() in ["require", "prefer"]:
|
||||||
|
connection_params["ssl"] = True
|
||||||
|
elif self.ssl_mode.lower() == "disable":
|
||||||
|
connection_params["ssl"] = False
|
||||||
|
logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
|
||||||
|
|
||||||
|
# Add server settings if provided
|
||||||
|
if self.server_settings:
|
||||||
|
try:
|
||||||
|
settings = {}
|
||||||
|
# The format is expected to be a query string, e.g., "key1=value1&key2=value2"
|
||||||
|
pairs = self.server_settings.split("&")
|
||||||
|
for pair in pairs:
|
||||||
|
if "=" in pair:
|
||||||
|
key, value = pair.split("=", 1)
|
||||||
|
settings[key] = value
|
||||||
|
if settings:
|
||||||
|
connection_params["server_settings"] = settings
|
||||||
|
logger.info(f"PostgreSQL, Server settings applied: {settings}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pool = await asyncpg.create_pool(**connection_params) # type: ignore
|
||||||
|
|
||||||
|
# Ensure VECTOR extension is available
|
||||||
|
async with self.pool.acquire() as connection:
|
||||||
|
await self.configure_vector_extension(connection)
|
||||||
|
|
||||||
ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
|
ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -294,97 +210,12 @@ class PostgreSQLDB:
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _ensure_pool(self) -> None:
|
|
||||||
"""Ensure the connection pool is initialised."""
|
|
||||||
if self.pool is None:
|
|
||||||
async with self._pool_reconnect_lock:
|
|
||||||
if self.pool is None:
|
|
||||||
await self.initdb()
|
|
||||||
|
|
||||||
async def _reset_pool(self) -> None:
|
|
||||||
async with self._pool_reconnect_lock:
|
|
||||||
if self.pool is not None:
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(
|
|
||||||
self.pool.close(), timeout=self.pool_close_timeout
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.error(
|
|
||||||
"PostgreSQL, Timed out closing connection pool after %.2fs",
|
|
||||||
self.pool_close_timeout,
|
|
||||||
)
|
|
||||||
except Exception as close_error: # pragma: no cover - defensive logging
|
|
||||||
logger.warning(
|
|
||||||
f"PostgreSQL, Failed to close existing connection pool cleanly: {close_error!r}"
|
|
||||||
)
|
|
||||||
self.pool = None
|
|
||||||
|
|
||||||
async def _before_sleep(self, retry_state: RetryCallState) -> None:
|
|
||||||
"""Hook invoked by tenacity before sleeping between retries."""
|
|
||||||
exc = retry_state.outcome.exception() if retry_state.outcome else None
|
|
||||||
logger.warning(
|
|
||||||
"PostgreSQL transient connection issue on attempt %s/%s: %r",
|
|
||||||
retry_state.attempt_number,
|
|
||||||
self.connection_retry_attempts,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
await self._reset_pool()
|
|
||||||
|
|
||||||
async def _run_with_retry(
|
|
||||||
self,
|
|
||||||
operation: Callable[[asyncpg.Connection], Awaitable[T]],
|
|
||||||
*,
|
|
||||||
with_age: bool = False,
|
|
||||||
graph_name: str | None = None,
|
|
||||||
) -> T:
|
|
||||||
"""
|
|
||||||
Execute a database operation with automatic retry for transient failures.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
operation: Async callable that receives an active connection.
|
|
||||||
with_age: Whether to configure Apache AGE on the connection.
|
|
||||||
graph_name: AGE graph name; required when with_age is True.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The result returned by the operation.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: Propagates the last error if all retry attempts fail or a non-transient error occurs.
|
|
||||||
"""
|
|
||||||
wait_strategy = (
|
|
||||||
wait_exponential(
|
|
||||||
multiplier=self.connection_retry_backoff,
|
|
||||||
min=self.connection_retry_backoff,
|
|
||||||
max=self.connection_retry_backoff_max,
|
|
||||||
)
|
|
||||||
if self.connection_retry_backoff > 0
|
|
||||||
else wait_fixed(0)
|
|
||||||
)
|
|
||||||
|
|
||||||
async for attempt in AsyncRetrying(
|
|
||||||
stop=stop_after_attempt(self.connection_retry_attempts),
|
|
||||||
retry=retry_if_exception_type(self._transient_exceptions),
|
|
||||||
wait=wait_strategy,
|
|
||||||
before_sleep=self._before_sleep,
|
|
||||||
reraise=True,
|
|
||||||
):
|
|
||||||
with attempt:
|
|
||||||
await self._ensure_pool()
|
|
||||||
assert self.pool is not None
|
|
||||||
async with self.pool.acquire() as connection: # type: ignore[arg-type]
|
|
||||||
if with_age and graph_name:
|
|
||||||
await self.configure_age(connection, graph_name)
|
|
||||||
elif with_age and not graph_name:
|
|
||||||
raise ValueError("Graph name is required when with_age is True")
|
|
||||||
|
|
||||||
return await operation(connection)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def configure_vector_extension(connection: asyncpg.Connection) -> None:
|
async def configure_vector_extension(connection: asyncpg.Connection) -> None:
|
||||||
"""Create VECTOR extension if it doesn't exist for vector similarity operations."""
|
"""Create VECTOR extension if it doesn't exist for vector similarity operations."""
|
||||||
try:
|
try:
|
||||||
await connection.execute("CREATE EXTENSION IF NOT EXISTS vector") # type: ignore
|
await connection.execute("CREATE EXTENSION IF NOT EXISTS vector") # type: ignore
|
||||||
logger.info("PostgreSQL, VECTOR extension enabled")
|
logger.info("VECTOR extension ensured for PostgreSQL")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not create VECTOR extension: {e}")
|
logger.warning(f"Could not create VECTOR extension: {e}")
|
||||||
# Don't raise - let the system continue without vector extension
|
# Don't raise - let the system continue without vector extension
|
||||||
|
|
@ -394,7 +225,7 @@ class PostgreSQLDB:
|
||||||
"""Create AGE extension if it doesn't exist for graph operations."""
|
"""Create AGE extension if it doesn't exist for graph operations."""
|
||||||
try:
|
try:
|
||||||
await connection.execute("CREATE EXTENSION IF NOT EXISTS age") # type: ignore
|
await connection.execute("CREATE EXTENSION IF NOT EXISTS age") # type: ignore
|
||||||
logger.info("PostgreSQL, AGE extension enabled")
|
logger.info("AGE extension ensured for PostgreSQL")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not create AGE extension: {e}")
|
logger.warning(f"Could not create AGE extension: {e}")
|
||||||
# Don't raise - let the system continue without AGE extension
|
# Don't raise - let the system continue without AGE extension
|
||||||
|
|
@ -1411,31 +1242,35 @@ class PostgreSQLDB:
|
||||||
with_age: bool = False,
|
with_age: bool = False,
|
||||||
graph_name: str | None = None,
|
graph_name: str | None = None,
|
||||||
) -> dict[str, Any] | None | list[dict[str, Any]]:
|
) -> dict[str, Any] | None | list[dict[str, Any]]:
|
||||||
async def _operation(connection: asyncpg.Connection) -> Any:
|
async with self.pool.acquire() as connection: # type: ignore
|
||||||
prepared_params = tuple(params) if params else ()
|
if with_age and graph_name:
|
||||||
if prepared_params:
|
await self.configure_age(connection, graph_name) # type: ignore
|
||||||
rows = await connection.fetch(sql, *prepared_params)
|
elif with_age and not graph_name:
|
||||||
else:
|
raise ValueError("Graph name is required when with_age is True")
|
||||||
rows = await connection.fetch(sql)
|
|
||||||
|
|
||||||
if multirows:
|
try:
|
||||||
if rows:
|
if params:
|
||||||
columns = [col for col in rows[0].keys()]
|
rows = await connection.fetch(sql, *params)
|
||||||
return [dict(zip(columns, row)) for row in rows]
|
else:
|
||||||
return []
|
rows = await connection.fetch(sql)
|
||||||
|
|
||||||
if rows:
|
if multirows:
|
||||||
columns = rows[0].keys()
|
if rows:
|
||||||
return dict(zip(columns, rows[0]))
|
columns = [col for col in rows[0].keys()]
|
||||||
return None
|
data = [dict(zip(columns, row)) for row in rows]
|
||||||
|
else:
|
||||||
|
data = []
|
||||||
|
else:
|
||||||
|
if rows:
|
||||||
|
columns = rows[0].keys()
|
||||||
|
data = dict(zip(columns, rows[0]))
|
||||||
|
else:
|
||||||
|
data = None
|
||||||
|
|
||||||
try:
|
return data
|
||||||
return await self._run_with_retry(
|
except Exception as e:
|
||||||
_operation, with_age=with_age, graph_name=graph_name
|
logger.error(f"PostgreSQL database, error:{e}")
|
||||||
)
|
raise
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"PostgreSQL database, error:{e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
|
|
@ -1446,33 +1281,30 @@ class PostgreSQLDB:
|
||||||
with_age: bool = False,
|
with_age: bool = False,
|
||||||
graph_name: str | None = None,
|
graph_name: str | None = None,
|
||||||
):
|
):
|
||||||
async def _operation(connection: asyncpg.Connection) -> Any:
|
|
||||||
prepared_values = tuple(data.values()) if data else ()
|
|
||||||
try:
|
|
||||||
if not data:
|
|
||||||
return await connection.execute(sql)
|
|
||||||
return await connection.execute(sql, *prepared_values)
|
|
||||||
except (
|
|
||||||
asyncpg.exceptions.UniqueViolationError,
|
|
||||||
asyncpg.exceptions.DuplicateTableError,
|
|
||||||
asyncpg.exceptions.DuplicateObjectError,
|
|
||||||
asyncpg.exceptions.InvalidSchemaNameError,
|
|
||||||
) as e:
|
|
||||||
if ignore_if_exists:
|
|
||||||
logger.debug("PostgreSQL, ignoring duplicate during execute: %r", e)
|
|
||||||
return None
|
|
||||||
if upsert:
|
|
||||||
logger.info(
|
|
||||||
"PostgreSQL, duplicate detected but treated as upsert success: %r",
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
raise
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._run_with_retry(
|
async with self.pool.acquire() as connection: # type: ignore
|
||||||
_operation, with_age=with_age, graph_name=graph_name
|
if with_age and graph_name:
|
||||||
)
|
await self.configure_age(connection, graph_name)
|
||||||
|
elif with_age and not graph_name:
|
||||||
|
raise ValueError("Graph name is required when with_age is True")
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
await connection.execute(sql)
|
||||||
|
else:
|
||||||
|
await connection.execute(sql, *data.values())
|
||||||
|
except (
|
||||||
|
asyncpg.exceptions.UniqueViolationError,
|
||||||
|
asyncpg.exceptions.DuplicateTableError,
|
||||||
|
asyncpg.exceptions.DuplicateObjectError, # Catch "already exists" error
|
||||||
|
asyncpg.exceptions.InvalidSchemaNameError, # Also catch for AGE extension "already exists"
|
||||||
|
) as e:
|
||||||
|
if ignore_if_exists:
|
||||||
|
# If the flag is set, just ignore these specific errors
|
||||||
|
pass
|
||||||
|
elif upsert:
|
||||||
|
print("Key value duplicate, but upsert succeeded.")
|
||||||
|
else:
|
||||||
|
logger.error(f"Upsert error: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
|
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
|
||||||
raise
|
raise
|
||||||
|
|
@ -1559,13 +1391,9 @@ class ClientManager:
|
||||||
),
|
),
|
||||||
# Server settings for Supabase
|
# Server settings for Supabase
|
||||||
"server_settings": os.environ.get(
|
"server_settings": os.environ.get(
|
||||||
"POSTGRES_SERVER_SETTINGS",
|
"POSTGRES_SERVER_OPTIONS",
|
||||||
config.get("postgres", "server_options", fallback=None),
|
config.get("postgres", "server_options", fallback=None),
|
||||||
),
|
),
|
||||||
"statement_cache_size": os.environ.get(
|
|
||||||
"POSTGRES_STATEMENT_CACHE_SIZE",
|
|
||||||
config.get("postgres", "statement_cache_size", fallback=None),
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -1813,33 +1641,12 @@ class PGKVStorage(BaseKVStorage):
|
||||||
# Query by id
|
# Query by id
|
||||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||||
"""Get data by ids"""
|
"""Get data by ids"""
|
||||||
if not ids:
|
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||||
return []
|
ids=",".join([f"'{id}'" for id in ids])
|
||||||
|
)
|
||||||
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace]
|
params = {"workspace": self.workspace}
|
||||||
params = {"workspace": self.workspace, "ids": ids}
|
|
||||||
results = await self.db.query(sql, list(params.values()), multirows=True)
|
results = await self.db.query(sql, list(params.values()), multirows=True)
|
||||||
|
|
||||||
def _order_results(
|
|
||||||
rows: list[dict[str, Any]] | None,
|
|
||||||
) -> list[dict[str, Any] | None]:
|
|
||||||
"""Preserve the caller requested ordering for bulk id lookups."""
|
|
||||||
if not rows:
|
|
||||||
return [None for _ in ids]
|
|
||||||
|
|
||||||
id_map: dict[str, dict[str, Any]] = {}
|
|
||||||
for row in rows:
|
|
||||||
if row is None:
|
|
||||||
continue
|
|
||||||
row_id = row.get("id")
|
|
||||||
if row_id is not None:
|
|
||||||
id_map[str(row_id)] = row
|
|
||||||
|
|
||||||
ordered: list[dict[str, Any] | None] = []
|
|
||||||
for requested_id in ids:
|
|
||||||
ordered.append(id_map.get(str(requested_id)))
|
|
||||||
return ordered
|
|
||||||
|
|
||||||
if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||||
# Parse llm_cache_list JSON string back to list for each result
|
# Parse llm_cache_list JSON string back to list for each result
|
||||||
for result in results:
|
for result in results:
|
||||||
|
|
@ -1882,7 +1689,7 @@ class PGKVStorage(BaseKVStorage):
|
||||||
"update_time": create_time if update_time == 0 else update_time,
|
"update_time": create_time if update_time == 0 else update_time,
|
||||||
}
|
}
|
||||||
processed_results.append(processed_row)
|
processed_results.append(processed_row)
|
||||||
return _order_results(processed_results)
|
return processed_results
|
||||||
|
|
||||||
# Special handling for FULL_ENTITIES namespace
|
# Special handling for FULL_ENTITIES namespace
|
||||||
if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
|
if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
|
||||||
|
|
@ -1916,16 +1723,15 @@ class PGKVStorage(BaseKVStorage):
|
||||||
result["create_time"] = create_time
|
result["create_time"] = create_time
|
||||||
result["update_time"] = create_time if update_time == 0 else update_time
|
result["update_time"] = create_time if update_time == 0 else update_time
|
||||||
|
|
||||||
return _order_results(results)
|
return results if results else []
|
||||||
|
|
||||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
"""Filter out duplicated content"""
|
"""Filter out duplicated content"""
|
||||||
if not keys:
|
sql = SQL_TEMPLATES["filter_keys"].format(
|
||||||
return set()
|
table_name=namespace_to_table_name(self.namespace),
|
||||||
|
ids=",".join([f"'{id}'" for id in keys]),
|
||||||
table_name = namespace_to_table_name(self.namespace)
|
)
|
||||||
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
|
params = {"workspace": self.workspace}
|
||||||
params = {"workspace": self.workspace, "ids": list(keys)}
|
|
||||||
try:
|
try:
|
||||||
res = await self.db.query(sql, list(params.values()), multirows=True)
|
res = await self.db.query(sql, list(params.values()), multirows=True)
|
||||||
if res:
|
if res:
|
||||||
|
|
@ -2375,23 +2181,7 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results = await self.db.query(query, list(params.values()), multirows=True)
|
results = await self.db.query(query, list(params.values()), multirows=True)
|
||||||
if not results:
|
return [dict(record) for record in results]
|
||||||
return []
|
|
||||||
|
|
||||||
# Preserve caller requested ordering while normalizing asyncpg rows to dicts.
|
|
||||||
id_map: dict[str, dict[str, Any]] = {}
|
|
||||||
for record in results:
|
|
||||||
if record is None:
|
|
||||||
continue
|
|
||||||
record_dict = dict(record)
|
|
||||||
row_id = record_dict.get("id")
|
|
||||||
if row_id is not None:
|
|
||||||
id_map[str(row_id)] = record_dict
|
|
||||||
|
|
||||||
ordered_results: list[dict[str, Any] | None] = []
|
|
||||||
for requested_id in ids:
|
|
||||||
ordered_results.append(id_map.get(str(requested_id)))
|
|
||||||
return ordered_results
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
|
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
|
||||||
|
|
@ -2504,12 +2294,11 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
|
|
||||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
"""Filter out duplicated content"""
|
"""Filter out duplicated content"""
|
||||||
if not keys:
|
sql = SQL_TEMPLATES["filter_keys"].format(
|
||||||
return set()
|
table_name=namespace_to_table_name(self.namespace),
|
||||||
|
ids=",".join([f"'{id}'" for id in keys]),
|
||||||
table_name = namespace_to_table_name(self.namespace)
|
)
|
||||||
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
|
params = {"workspace": self.workspace}
|
||||||
params = {"workspace": self.workspace, "ids": list(keys)}
|
|
||||||
try:
|
try:
|
||||||
res = await self.db.query(sql, list(params.values()), multirows=True)
|
res = await self.db.query(sql, list(params.values()), multirows=True)
|
||||||
if res:
|
if res:
|
||||||
|
|
@ -2580,7 +2369,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
if not results:
|
if not results:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
processed_map: dict[str, dict[str, Any]] = {}
|
processed_results = []
|
||||||
for row in results:
|
for row in results:
|
||||||
# Parse chunks_list JSON string back to list
|
# Parse chunks_list JSON string back to list
|
||||||
chunks_list = row.get("chunks_list", [])
|
chunks_list = row.get("chunks_list", [])
|
||||||
|
|
@ -2602,25 +2391,23 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
created_at = self._format_datetime_with_timezone(row["created_at"])
|
created_at = self._format_datetime_with_timezone(row["created_at"])
|
||||||
updated_at = self._format_datetime_with_timezone(row["updated_at"])
|
updated_at = self._format_datetime_with_timezone(row["updated_at"])
|
||||||
|
|
||||||
processed_map[str(row.get("id"))] = {
|
processed_results.append(
|
||||||
"content_length": row["content_length"],
|
{
|
||||||
"content_summary": row["content_summary"],
|
"content_length": row["content_length"],
|
||||||
"status": row["status"],
|
"content_summary": row["content_summary"],
|
||||||
"chunks_count": row["chunks_count"],
|
"status": row["status"],
|
||||||
"created_at": created_at,
|
"chunks_count": row["chunks_count"],
|
||||||
"updated_at": updated_at,
|
"created_at": created_at,
|
||||||
"file_path": row["file_path"],
|
"updated_at": updated_at,
|
||||||
"chunks_list": chunks_list,
|
"file_path": row["file_path"],
|
||||||
"metadata": metadata,
|
"chunks_list": chunks_list,
|
||||||
"error_msg": row.get("error_msg"),
|
"metadata": metadata,
|
||||||
"track_id": row.get("track_id"),
|
"error_msg": row.get("error_msg"),
|
||||||
}
|
"track_id": row.get("track_id"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
ordered_results: list[dict[str, Any] | None] = []
|
return processed_results
|
||||||
for requested_id in ids:
|
|
||||||
ordered_results.append(processed_map.get(str(requested_id)))
|
|
||||||
|
|
||||||
return ordered_results
|
|
||||||
|
|
||||||
async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
|
async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
|
||||||
"""Get document by file path
|
"""Get document by file path
|
||||||
|
|
@ -2822,33 +2609,26 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
elif page_size > 200:
|
elif page_size > 200:
|
||||||
page_size = 200
|
page_size = 200
|
||||||
|
|
||||||
# Whitelist validation for sort_field to prevent SQL injection
|
if sort_field not in ["created_at", "updated_at", "id", "file_path"]:
|
||||||
allowed_sort_fields = {"created_at", "updated_at", "id", "file_path"}
|
|
||||||
if sort_field not in allowed_sort_fields:
|
|
||||||
sort_field = "updated_at"
|
sort_field = "updated_at"
|
||||||
|
|
||||||
# Whitelist validation for sort_direction to prevent SQL injection
|
|
||||||
if sort_direction.lower() not in ["asc", "desc"]:
|
if sort_direction.lower() not in ["asc", "desc"]:
|
||||||
sort_direction = "desc"
|
sort_direction = "desc"
|
||||||
else:
|
|
||||||
sort_direction = sort_direction.lower()
|
|
||||||
|
|
||||||
# Calculate offset
|
# Calculate offset
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
# Build parameterized query components
|
# Build WHERE clause
|
||||||
|
where_clause = "WHERE workspace=$1"
|
||||||
params = {"workspace": self.workspace}
|
params = {"workspace": self.workspace}
|
||||||
param_count = 1
|
param_count = 1
|
||||||
|
|
||||||
# Build WHERE clause with parameterized query
|
|
||||||
if status_filter is not None:
|
if status_filter is not None:
|
||||||
param_count += 1
|
param_count += 1
|
||||||
where_clause = "WHERE workspace=$1 AND status=$2"
|
where_clause += f" AND status=${param_count}"
|
||||||
params["status"] = status_filter.value
|
params["status"] = status_filter.value
|
||||||
else:
|
|
||||||
where_clause = "WHERE workspace=$1"
|
|
||||||
|
|
||||||
# Build ORDER BY clause using validated whitelist values
|
# Build ORDER BY clause
|
||||||
order_clause = f"ORDER BY {sort_field} {sort_direction.upper()}"
|
order_clause = f"ORDER BY {sort_field} {sort_direction.upper()}"
|
||||||
|
|
||||||
# Query for total count
|
# Query for total count
|
||||||
|
|
@ -2856,7 +2636,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
count_result = await self.db.query(count_sql, list(params.values()))
|
count_result = await self.db.query(count_sql, list(params.values()))
|
||||||
total_count = count_result["total"] if count_result else 0
|
total_count = count_result["total"] if count_result else 0
|
||||||
|
|
||||||
# Query for paginated data with parameterized LIMIT and OFFSET
|
# Query for paginated data
|
||||||
data_sql = f"""
|
data_sql = f"""
|
||||||
SELECT * FROM LIGHTRAG_DOC_STATUS
|
SELECT * FROM LIGHTRAG_DOC_STATUS
|
||||||
{where_clause}
|
{where_clause}
|
||||||
|
|
@ -3387,8 +3167,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||||
{
|
{
|
||||||
"message": f"Error executing graph query: {query}",
|
"message": f"Error executing graph query: {query}",
|
||||||
"wrapped": query,
|
"wrapped": query,
|
||||||
"detail": repr(e),
|
"detail": str(e),
|
||||||
"error_type": e.__class__.__name__,
|
|
||||||
}
|
}
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
@ -4854,19 +4633,19 @@ SQL_TEMPLATES = {
|
||||||
""",
|
""",
|
||||||
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content,
|
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content,
|
||||||
COALESCE(doc_name, '') as file_path
|
COALESCE(doc_name, '') as file_path
|
||||||
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id = ANY($2)
|
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
|
||||||
""",
|
""",
|
||||||
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
|
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
|
||||||
chunk_order_index, full_doc_id, file_path,
|
chunk_order_index, full_doc_id, file_path,
|
||||||
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
|
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
|
||||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id = ANY($2)
|
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
|
||||||
""",
|
""",
|
||||||
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
|
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
|
||||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id = ANY($2)
|
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
|
||||||
""",
|
""",
|
||||||
"get_by_id_full_entities": """SELECT id, entity_names, count,
|
"get_by_id_full_entities": """SELECT id, entity_names, count,
|
||||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
|
|
@ -4881,12 +4660,12 @@ SQL_TEMPLATES = {
|
||||||
"get_by_ids_full_entities": """SELECT id, entity_names, count,
|
"get_by_ids_full_entities": """SELECT id, entity_names, count,
|
||||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id = ANY($2)
|
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id IN ({ids})
|
||||||
""",
|
""",
|
||||||
"get_by_ids_full_relations": """SELECT id, relation_pairs, count,
|
"get_by_ids_full_relations": """SELECT id, relation_pairs, count,
|
||||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2)
|
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id IN ({ids})
|
||||||
""",
|
""",
|
||||||
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
|
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
|
||||||
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace)
|
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue