From 0bc702127f7f56338abed021117570d560266f01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?= Date: Thu, 4 Dec 2025 19:18:14 +0800 Subject: [PATCH] cherry-pick d8a9617c --- lightrag/kg/postgres_impl.py | 1314 +++++++++++++--------------------- 1 file changed, 503 insertions(+), 811 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 2a7c6158..8fc9d590 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -5,7 +5,7 @@ import re import datetime from datetime import timezone from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable, TypeVar, Union, final +from typing import Any, Union, final import numpy as np import configparser import ssl @@ -14,13 +14,10 @@ import itertools from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from tenacity import ( - AsyncRetrying, - RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential, - wait_fixed, ) from ..base import ( @@ -33,6 +30,7 @@ from ..base import ( ) from ..namespace import NameSpace, is_namespace from ..utils import logger +from ..constants import GRAPH_FIELD_SEP from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock import pipmaster as pm @@ -50,8 +48,6 @@ from dotenv import load_dotenv # the OS environment variables take precedence over the .env file load_dotenv(dotenv_path=".env", override=False) -T = TypeVar("T") - class PostgreSQLDB: def __init__(self, config: dict[str, Any], **kwargs: Any): @@ -81,44 +77,9 @@ class PostgreSQLDB: # Server settings self.server_settings = config.get("server_settings") - # Statement LRU cache size (keep as-is, allow None for optional configuration) - self.statement_cache_size = config.get("statement_cache_size") - if self.user is None or self.password is None or self.database is None: raise ValueError("Missing database user, password, or database") - # Guard concurrent pool resets - self._pool_reconnect_lock = asyncio.Lock() - - self._transient_exceptions = ( - asyncio.TimeoutError, - TimeoutError, - ConnectionError, - OSError, - asyncpg.exceptions.InterfaceError, - asyncpg.exceptions.TooManyConnectionsError, - asyncpg.exceptions.CannotConnectNowError, - asyncpg.exceptions.PostgresConnectionError, - asyncpg.exceptions.ConnectionDoesNotExistError, - asyncpg.exceptions.ConnectionFailureError, - ) - - # Connection retry configuration - self.connection_retry_attempts = config["connection_retry_attempts"] - self.connection_retry_backoff = config["connection_retry_backoff"] - self.connection_retry_backoff_max = max( - self.connection_retry_backoff, - config["connection_retry_backoff_max"], - ) - self.pool_close_timeout = config["pool_close_timeout"] - logger.info( - "PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs", - self.connection_retry_attempts, - self.connection_retry_backoff, - self.connection_retry_backoff_max, - self.pool_close_timeout, - ) - def _create_ssl_context(self) -> ssl.SSLContext | None: """Create SSL context based on configuration parameters.""" if not self.ssl_mode: @@ -190,85 +151,55 @@ class PostgreSQLDB: return None async def initdb(self): - # Prepare connection parameters - connection_params = { - "user": self.user, - "password": self.password, - "database": self.database, - "host": self.host, - "port": self.port, - "min_size": 1, - "max_size": self.max, - } - - # Only add statement_cache_size if it's configured - if 
self.statement_cache_size is not None: - connection_params["statement_cache_size"] = int(self.statement_cache_size) - logger.info( - f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}" - ) - - # Add SSL configuration if provided - ssl_context = self._create_ssl_context() - if ssl_context is not None: - connection_params["ssl"] = ssl_context - logger.info("PostgreSQL, SSL configuration applied") - elif self.ssl_mode: - # Handle simple SSL modes without custom context - if self.ssl_mode.lower() in ["require", "prefer"]: - connection_params["ssl"] = True - elif self.ssl_mode.lower() == "disable": - connection_params["ssl"] = False - logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}") - - # Add server settings if provided - if self.server_settings: - try: - settings = {} - # The format is expected to be a query string, e.g., "key1=value1&key2=value2" - pairs = self.server_settings.split("&") - for pair in pairs: - if "=" in pair: - key, value = pair.split("=", 1) - settings[key] = value - if settings: - connection_params["server_settings"] = settings - logger.info(f"PostgreSQL, Server settings applied: {settings}") - except Exception as e: - logger.warning( - f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}" - ) - - wait_strategy = ( - wait_exponential( - multiplier=self.connection_retry_backoff, - min=self.connection_retry_backoff, - max=self.connection_retry_backoff_max, - ) - if self.connection_retry_backoff > 0 - else wait_fixed(0) - ) - - async def _create_pool_once() -> None: - pool = await asyncpg.create_pool(**connection_params) # type: ignore - try: - async with pool.acquire() as connection: - await self.configure_vector_extension(connection) - except Exception: - await pool.close() - raise - self.pool = pool - try: - async for attempt in AsyncRetrying( - stop=stop_after_attempt(self.connection_retry_attempts), - retry=retry_if_exception_type(self._transient_exceptions), - wait=wait_strategy, - before_sleep=self._before_sleep, - reraise=True, - ): - with attempt: - await _create_pool_once() + # Prepare connection parameters + connection_params = { + "user": self.user, + "password": self.password, + "database": self.database, + "host": self.host, + "port": self.port, + "min_size": 1, + "max_size": self.max, + "statement_cache_size": 0, + } + + # Add SSL configuration if provided + ssl_context = self._create_ssl_context() + if ssl_context is not None: + connection_params["ssl"] = ssl_context + logger.info("PostgreSQL, SSL configuration applied") + elif self.ssl_mode: + # Handle simple SSL modes without custom context + if self.ssl_mode.lower() in ["require", "prefer"]: + connection_params["ssl"] = True + elif self.ssl_mode.lower() == "disable": + connection_params["ssl"] = False + logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}") + + # Add server settings if provided + if self.server_settings: + try: + settings = {} + # The format is expected to be a query string, e.g., "key1=value1&key2=value2" + pairs = self.server_settings.split("&") + for pair in pairs: + if "=" in pair: + key, value = pair.split("=", 1) + settings[key] = value + if settings: + connection_params["server_settings"] = settings + logger.info(f"PostgreSQL, Server settings applied: {settings}") + except Exception as e: + logger.warning( + f"PostgreSQL, Failed to parse server_settings: {self.server_settings}, error: {e}" + ) + + self.pool = await asyncpg.create_pool(**connection_params) # type: ignore + + # Ensure VECTOR extension is available + 
async with self.pool.acquire() as connection: + await self.configure_vector_extension(connection) ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL" logger.info( @@ -280,97 +211,12 @@ class PostgreSQLDB: ) raise - async def _ensure_pool(self) -> None: - """Ensure the connection pool is initialised.""" - if self.pool is None: - async with self._pool_reconnect_lock: - if self.pool is None: - await self.initdb() - - async def _reset_pool(self) -> None: - async with self._pool_reconnect_lock: - if self.pool is not None: - try: - await asyncio.wait_for( - self.pool.close(), timeout=self.pool_close_timeout - ) - except asyncio.TimeoutError: - logger.error( - "PostgreSQL, Timed out closing connection pool after %.2fs", - self.pool_close_timeout, - ) - except Exception as close_error: # pragma: no cover - defensive logging - logger.warning( - f"PostgreSQL, Failed to close existing connection pool cleanly: {close_error!r}" - ) - self.pool = None - - async def _before_sleep(self, retry_state: RetryCallState) -> None: - """Hook invoked by tenacity before sleeping between retries.""" - exc = retry_state.outcome.exception() if retry_state.outcome else None - logger.warning( - "PostgreSQL transient connection issue on attempt %s/%s: %r", - retry_state.attempt_number, - self.connection_retry_attempts, - exc, - ) - await self._reset_pool() - - async def _run_with_retry( - self, - operation: Callable[[asyncpg.Connection], Awaitable[T]], - *, - with_age: bool = False, - graph_name: str | None = None, - ) -> T: - """ - Execute a database operation with automatic retry for transient failures. - - Args: - operation: Async callable that receives an active connection. - with_age: Whether to configure Apache AGE on the connection. - graph_name: AGE graph name; required when with_age is True. - - Returns: - The result returned by the operation. - - Raises: - Exception: Propagates the last error if all retry attempts fail or a non-transient error occurs. 
- """ - wait_strategy = ( - wait_exponential( - multiplier=self.connection_retry_backoff, - min=self.connection_retry_backoff, - max=self.connection_retry_backoff_max, - ) - if self.connection_retry_backoff > 0 - else wait_fixed(0) - ) - - async for attempt in AsyncRetrying( - stop=stop_after_attempt(self.connection_retry_attempts), - retry=retry_if_exception_type(self._transient_exceptions), - wait=wait_strategy, - before_sleep=self._before_sleep, - reraise=True, - ): - with attempt: - await self._ensure_pool() - assert self.pool is not None - async with self.pool.acquire() as connection: # type: ignore[arg-type] - if with_age and graph_name: - await self.configure_age(connection, graph_name) - elif with_age and not graph_name: - raise ValueError("Graph name is required when with_age is True") - - return await operation(connection) - @staticmethod async def configure_vector_extension(connection: asyncpg.Connection) -> None: """Create VECTOR extension if it doesn't exist for vector similarity operations.""" try: await connection.execute("CREATE EXTENSION IF NOT EXISTS vector") # type: ignore - logger.info("PostgreSQL, VECTOR extension enabled") + logger.info("VECTOR extension ensured for PostgreSQL") except Exception as e: logger.warning(f"Could not create VECTOR extension: {e}") # Don't raise - let the system continue without vector extension @@ -380,7 +226,7 @@ class PostgreSQLDB: """Create AGE extension if it doesn't exist for graph operations.""" try: await connection.execute("CREATE EXTENSION IF NOT EXISTS age") # type: ignore - logger.info("PostgreSQL, AGE extension enabled") + logger.info("AGE extension ensured for PostgreSQL") except Exception as e: logger.warning(f"Could not create AGE extension: {e}") # Don't raise - let the system continue without AGE extension @@ -536,74 +382,49 @@ class PostgreSQLDB: "LIGHTRAG_DOC_STATUS": ["created_at", "updated_at"], } - try: - # Optimization: Batch check all columns in one query instead of 8 separate queries - table_names_lower = [t.lower() for t in tables_to_migrate.keys()] - all_column_names = list( - set(col for cols in tables_to_migrate.values() for col in cols) - ) + for table_name, columns in tables_to_migrate.items(): + for column_name in columns: + try: + # Check if column exists + check_column_sql = f""" + SELECT column_name, data_type + FROM information_schema.columns + WHERE table_name = '{table_name.lower()}' + AND column_name = '{column_name}' + """ - check_all_columns_sql = """ - SELECT table_name, column_name, data_type - FROM information_schema.columns - WHERE table_name = ANY($1) - AND column_name = ANY($2) - """ - - all_columns_result = await self.query( - check_all_columns_sql, - [table_names_lower, all_column_names], - multirows=True, - ) - - # Build lookup dict: (table_name, column_name) -> data_type - column_types = {} - if all_columns_result: - column_types = { - (row["table_name"].upper(), row["column_name"]): row["data_type"] - for row in all_columns_result - } - - # Now iterate and migrate only what's needed - for table_name, columns in tables_to_migrate.items(): - for column_name in columns: - try: - data_type = column_types.get((table_name, column_name)) - - if not data_type: - logger.warning( - f"Column {table_name}.{column_name} does not exist, skipping migration" - ) - continue - - # Check column type - if data_type == "timestamp without time zone": - logger.debug( - f"Column {table_name}.{column_name} is already witimezone-free, no migration needed" - ) - continue - - # Execute migration, explicitly 
specifying UTC timezone for interpreting original data - logger.info( - f"Migrating {table_name}.{column_name} from {data_type} to TIMESTAMP(0) type" - ) - migration_sql = f""" - ALTER TABLE {table_name} - ALTER COLUMN {column_name} TYPE TIMESTAMP(0), - ALTER COLUMN {column_name} SET DEFAULT CURRENT_TIMESTAMP - """ - - await self.execute(migration_sql) - logger.info( - f"Successfully migrated {table_name}.{column_name} to timezone-free type" - ) - except Exception as e: - # Log error but don't interrupt the process + column_info = await self.query(check_column_sql) + if not column_info: logger.warning( - f"Failed to migrate {table_name}.{column_name}: {e}" + f"Column {table_name}.{column_name} does not exist, skipping migration" ) - except Exception as e: - logger.error(f"Failed to batch check timestamp columns: {e}") + continue + + # Check column type + data_type = column_info.get("data_type") + if data_type == "timestamp without time zone": + logger.debug( + f"Column {table_name}.{column_name} is already witimezone-free, no migration needed" + ) + continue + + # Execute migration, explicitly specifying UTC timezone for interpreting original data + logger.info( + f"Migrating {table_name}.{column_name} from {data_type} to TIMESTAMP(0) type" + ) + migration_sql = f""" + ALTER TABLE {table_name} + ALTER COLUMN {column_name} TYPE TIMESTAMP(0), + ALTER COLUMN {column_name} SET DEFAULT CURRENT_TIMESTAMP + """ + + await self.execute(migration_sql) + logger.info( + f"Successfully migrated {table_name}.{column_name} to timezone-free type" + ) + except Exception as e: + # Log error but don't interrupt the process + logger.warning(f"Failed to migrate {table_name}.{column_name}: {e}") async def _migrate_doc_chunks_to_vdb_chunks(self): """ @@ -980,89 +801,73 @@ class PostgreSQLDB: }, ] - try: - # Optimization: Batch check all columns in one query instead of 5 separate queries - unique_tables = list(set(m["table"].lower() for m in field_migrations)) - unique_columns = list(set(m["column"] for m in field_migrations)) - - check_all_columns_sql = """ - SELECT table_name, column_name, data_type, character_maximum_length, is_nullable - FROM information_schema.columns - WHERE table_name = ANY($1) - AND column_name = ANY($2) - """ - - all_columns_result = await self.query( - check_all_columns_sql, [unique_tables, unique_columns], multirows=True - ) - - # Build lookup dict: (table_name, column_name) -> column_info - column_info_map = {} - if all_columns_result: - column_info_map = { - (row["table_name"].upper(), row["column_name"]): row - for row in all_columns_result + for migration in field_migrations: + try: + # Check current column definition + check_column_sql = """ + SELECT column_name, data_type, character_maximum_length, is_nullable + FROM information_schema.columns + WHERE table_name = $1 AND column_name = $2 + """ + params = { + "table_name": migration["table"].lower(), + "column_name": migration["column"], } + column_info = await self.query( + check_column_sql, + list(params.values()), + ) - # Now iterate and migrate only what's needed - for migration in field_migrations: - try: - column_info = column_info_map.get( - (migration["table"], migration["column"]) - ) - - if not column_info: - logger.warning( - f"Column {migration['table']}.{migration['column']} does not exist, skipping migration" - ) - continue - - current_type = column_info.get("data_type", "").lower() - current_length = column_info.get("character_maximum_length") - - # Check if migration is needed - needs_migration = False - - if 
migration["column"] == "entity_name" and current_length == 255: - needs_migration = True - elif ( - migration["column"] in ["source_id", "target_id"] - and current_length == 256 - ): - needs_migration = True - elif ( - migration["column"] == "file_path" - and current_type == "character varying" - ): - needs_migration = True - - if needs_migration: - logger.info( - f"Migrating {migration['table']}.{migration['column']}: {migration['description']}" - ) - - # Execute the migration - alter_sql = f""" - ALTER TABLE {migration["table"]} - ALTER COLUMN {migration["column"]} TYPE {migration["new_type"]} - """ - - await self.execute(alter_sql) - logger.info( - f"Successfully migrated {migration['table']}.{migration['column']}" - ) - else: - logger.debug( - f"Column {migration['table']}.{migration['column']} already has correct type, no migration needed" - ) - - except Exception as e: - # Log error but don't interrupt the process + if not column_info: logger.warning( - f"Failed to migrate {migration['table']}.{migration['column']}: {e}" + f"Column {migration['table']}.{migration['column']} does not exist, skipping migration" ) - except Exception as e: - logger.error(f"Failed to batch check field lengths: {e}") + continue + + current_type = column_info.get("data_type", "").lower() + current_length = column_info.get("character_maximum_length") + + # Check if migration is needed + needs_migration = False + + if migration["column"] == "entity_name" and current_length == 255: + needs_migration = True + elif ( + migration["column"] in ["source_id", "target_id"] + and current_length == 256 + ): + needs_migration = True + elif ( + migration["column"] == "file_path" + and current_type == "character varying" + ): + needs_migration = True + + if needs_migration: + logger.info( + f"Migrating {migration['table']}.{migration['column']}: {migration['description']}" + ) + + # Execute the migration + alter_sql = f""" + ALTER TABLE {migration["table"]} + ALTER COLUMN {migration["column"]} TYPE {migration["new_type"]} + """ + + await self.execute(alter_sql) + logger.info( + f"Successfully migrated {migration['table']}.{migration['column']}" + ) + else: + logger.debug( + f"Column {migration['table']}.{migration['column']} already has correct type, no migration needed" + ) + + except Exception as e: + # Log error but don't interrupt the process + logger.warning( + f"Failed to migrate {migration['table']}.{migration['column']}: {e}" + ) async def check_tables(self): # First create all tables @@ -1082,59 +887,47 @@ class PostgreSQLDB: ) raise e - # Batch check all indexes at once (optimization: single query instead of N queries) - try: - table_names = list(TABLES.keys()) - table_names_lower = [t.lower() for t in table_names] - - # Get all existing indexes for our tables in one query - check_all_indexes_sql = """ - SELECT indexname, tablename - FROM pg_indexes - WHERE tablename = ANY($1) - """ - existing_indexes_result = await self.query( - check_all_indexes_sql, [table_names_lower], multirows=True - ) - - # Build a set of existing index names for fast lookup - existing_indexes = set() - if existing_indexes_result: - existing_indexes = {row["indexname"] for row in existing_indexes_result} - - # Create missing indexes - for k in table_names: - # Create index for id column if missing + # Create index for id column in each table + try: index_name = f"idx_{k.lower()}_id" - if index_name not in existing_indexes: - try: - create_index_sql = f"CREATE INDEX {index_name} ON {k}(id)" - logger.info( - f"PostgreSQL, Creating index 
{index_name} on table {k}" - ) - await self.execute(create_index_sql) - except Exception as e: - logger.error( - f"PostgreSQL, Failed to create index {index_name}, Got: {e}" - ) + check_index_sql = f""" + SELECT 1 FROM pg_indexes + WHERE indexname = '{index_name}' + AND tablename = '{k.lower()}' + """ + index_exists = await self.query(check_index_sql) - # Create composite index for (workspace, id) if missing + if not index_exists: + create_index_sql = f"CREATE INDEX {index_name} ON {k}(id)" + logger.info(f"PostgreSQL, Creating index {index_name} on table {k}") + await self.execute(create_index_sql) + except Exception as e: + logger.error( + f"PostgreSQL, Failed to create index on table {k}, Got: {e}" + ) + + # Create composite index for (workspace, id) columns in each table + try: composite_index_name = f"idx_{k.lower()}_workspace_id" - if composite_index_name not in existing_indexes: - try: - create_composite_index_sql = ( - f"CREATE INDEX {composite_index_name} ON {k}(workspace, id)" - ) - logger.info( - f"PostgreSQL, Creating composite index {composite_index_name} on table {k}" - ) - await self.execute(create_composite_index_sql) - except Exception as e: - logger.error( - f"PostgreSQL, Failed to create composite index {composite_index_name}, Got: {e}" - ) - except Exception as e: - logger.error(f"PostgreSQL, Failed to batch check/create indexes: {e}") + check_composite_index_sql = f""" + SELECT 1 FROM pg_indexes + WHERE indexname = '{composite_index_name}' + AND tablename = '{k.lower()}' + """ + composite_index_exists = await self.query(check_composite_index_sql) + + if not composite_index_exists: + create_composite_index_sql = ( + f"CREATE INDEX {composite_index_name} ON {k}(workspace, id)" + ) + logger.info( + f"PostgreSQL, Creating composite index {composite_index_name} on table {k}" + ) + await self.execute(create_composite_index_sql) + except Exception as e: + logger.error( + f"PostgreSQL, Failed to create composite index on table {k}, Got: {e}" + ) # Create vector indexs if self.vector_index_type: @@ -1450,31 +1243,35 @@ class PostgreSQLDB: with_age: bool = False, graph_name: str | None = None, ) -> dict[str, Any] | None | list[dict[str, Any]]: - async def _operation(connection: asyncpg.Connection) -> Any: - prepared_params = tuple(params) if params else () - if prepared_params: - rows = await connection.fetch(sql, *prepared_params) - else: - rows = await connection.fetch(sql) + async with self.pool.acquire() as connection: # type: ignore + if with_age and graph_name: + await self.configure_age(connection, graph_name) # type: ignore + elif with_age and not graph_name: + raise ValueError("Graph name is required when with_age is True") - if multirows: - if rows: - columns = [col for col in rows[0].keys()] - return [dict(zip(columns, row)) for row in rows] - return [] + try: + if params: + rows = await connection.fetch(sql, *params) + else: + rows = await connection.fetch(sql) - if rows: - columns = rows[0].keys() - return dict(zip(columns, rows[0])) - return None + if multirows: + if rows: + columns = [col for col in rows[0].keys()] + data = [dict(zip(columns, row)) for row in rows] + else: + data = [] + else: + if rows: + columns = rows[0].keys() + data = dict(zip(columns, rows[0])) + else: + data = None - try: - return await self._run_with_retry( - _operation, with_age=with_age, graph_name=graph_name - ) - except Exception as e: - logger.error(f"PostgreSQL database, error:{e}") - raise + return data + except Exception as e: + logger.error(f"PostgreSQL database, error:{e}") + raise 
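For reference, a minimal standalone sketch of the Record-to-dict conversion that the rewritten `query` method above performs when fetching multiple rows. This is not part of the patch; the DSN, SQL, and workspace value below are illustrative placeholders, and the helper simply mirrors the multirows branch (zip column names from the first asyncpg Record with each row's values).

# Illustrative sketch only, assuming a reachable PostgreSQL instance; placeholder DSN/SQL.
import asyncio
import asyncpg

async def fetch_as_dicts(dsn: str, sql: str, *params):
    # Open a single connection instead of a pool, purely for illustration.
    conn = await asyncpg.connect(dsn)
    try:
        rows = await conn.fetch(sql, *params)
        if not rows:
            return []
        # asyncpg returns Record objects; pair column names with values,
        # the same way PostgreSQLDB.query builds its multirows result.
        columns = list(rows[0].keys())
        return [dict(zip(columns, row)) for row in rows]
    finally:
        await conn.close()

# Example usage (placeholder connection string and workspace):
# asyncio.run(fetch_as_dicts(
#     "postgresql://user:pass@localhost:5432/lightrag",
#     "SELECT id FROM LIGHTRAG_DOC_FULL WHERE workspace=$1",
#     "default",
# ))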
async def execute( self, @@ -1485,33 +1282,30 @@ class PostgreSQLDB: with_age: bool = False, graph_name: str | None = None, ): - async def _operation(connection: asyncpg.Connection) -> Any: - prepared_values = tuple(data.values()) if data else () - try: - if not data: - return await connection.execute(sql) - return await connection.execute(sql, *prepared_values) - except ( - asyncpg.exceptions.UniqueViolationError, - asyncpg.exceptions.DuplicateTableError, - asyncpg.exceptions.DuplicateObjectError, - asyncpg.exceptions.InvalidSchemaNameError, - ) as e: - if ignore_if_exists: - logger.debug("PostgreSQL, ignoring duplicate during execute: %r", e) - return None - if upsert: - logger.info( - "PostgreSQL, duplicate detected but treated as upsert success: %r", - e, - ) - return None - raise - try: - await self._run_with_retry( - _operation, with_age=with_age, graph_name=graph_name - ) + async with self.pool.acquire() as connection: # type: ignore + if with_age and graph_name: + await self.configure_age(connection, graph_name) + elif with_age and not graph_name: + raise ValueError("Graph name is required when with_age is True") + + if data is None: + await connection.execute(sql) + else: + await connection.execute(sql, *data.values()) + except ( + asyncpg.exceptions.UniqueViolationError, + asyncpg.exceptions.DuplicateTableError, + asyncpg.exceptions.DuplicateObjectError, # Catch "already exists" error + asyncpg.exceptions.InvalidSchemaNameError, # Also catch for AGE extension "already exists" + ) as e: + if ignore_if_exists: + # If the flag is set, just ignore these specific errors + pass + elif upsert: + print("Key value duplicate, but upsert succeeded.") + else: + logger.error(f"Upsert error: {e}") except Exception as e: logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}") raise @@ -1598,56 +1392,9 @@ class ClientManager: ), # Server settings for Supabase "server_settings": os.environ.get( - "POSTGRES_SERVER_SETTINGS", + "POSTGRES_SERVER_OPTIONS", config.get("postgres", "server_options", fallback=None), ), - "statement_cache_size": os.environ.get( - "POSTGRES_STATEMENT_CACHE_SIZE", - config.get("postgres", "statement_cache_size", fallback=None), - ), - # Connection retry configuration - "connection_retry_attempts": min( - 10, - int( - os.environ.get( - "POSTGRES_CONNECTION_RETRIES", - config.get("postgres", "connection_retries", fallback=3), - ) - ), - ), - "connection_retry_backoff": min( - 5.0, - float( - os.environ.get( - "POSTGRES_CONNECTION_RETRY_BACKOFF", - config.get( - "postgres", "connection_retry_backoff", fallback=0.5 - ), - ) - ), - ), - "connection_retry_backoff_max": min( - 60.0, - float( - os.environ.get( - "POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", - config.get( - "postgres", - "connection_retry_backoff_max", - fallback=5.0, - ), - ) - ), - ), - "pool_close_timeout": min( - 30.0, - float( - os.environ.get( - "POSTGRES_POOL_CLOSE_TIMEOUT", - config.get("postgres", "pool_close_timeout", fallback=5.0), - ) - ), - ), } @classmethod @@ -1708,6 +1455,113 @@ class PGKVStorage(BaseKVStorage): self.db = None ################ QUERY METHODS ################ + async def get_all(self) -> dict[str, Any]: + """Get all data from storage + + Returns: + Dictionary containing all stored data + """ + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error( + f"[{self.workspace}] Unknown namespace for get_all: {self.namespace}" + ) + return {} + + sql = f"SELECT * FROM {table_name} WHERE workspace=$1" + params = {"workspace": self.workspace} + + 
try: + results = await self.db.query(sql, list(params.values()), multirows=True) + + # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results + if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): + processed_results = {} + for row in results: + create_time = row.get("create_time", 0) + update_time = row.get("update_time", 0) + # Map field names and add cache_type for compatibility + processed_row = { + **row, + "return": row.get("return_value", ""), + "cache_type": row.get("original_prompt", "unknow"), + "original_prompt": row.get("original_prompt", ""), + "chunk_id": row.get("chunk_id"), + "mode": row.get("mode", "default"), + "create_time": create_time, + "update_time": create_time if update_time == 0 else update_time, + } + processed_results[row["id"]] = processed_row + return processed_results + + # For text_chunks namespace, parse llm_cache_list JSON string back to list + if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): + processed_results = {} + for row in results: + llm_cache_list = row.get("llm_cache_list", []) + if isinstance(llm_cache_list, str): + try: + llm_cache_list = json.loads(llm_cache_list) + except json.JSONDecodeError: + llm_cache_list = [] + row["llm_cache_list"] = llm_cache_list + create_time = row.get("create_time", 0) + update_time = row.get("update_time", 0) + row["create_time"] = create_time + row["update_time"] = ( + create_time if update_time == 0 else update_time + ) + processed_results[row["id"]] = row + return processed_results + + # For FULL_ENTITIES namespace, parse entity_names JSON string back to list + if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES): + processed_results = {} + for row in results: + entity_names = row.get("entity_names", []) + if isinstance(entity_names, str): + try: + entity_names = json.loads(entity_names) + except json.JSONDecodeError: + entity_names = [] + row["entity_names"] = entity_names + create_time = row.get("create_time", 0) + update_time = row.get("update_time", 0) + row["create_time"] = create_time + row["update_time"] = ( + create_time if update_time == 0 else update_time + ) + processed_results[row["id"]] = row + return processed_results + + # For FULL_RELATIONS namespace, parse relation_pairs JSON string back to list + if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_RELATIONS): + processed_results = {} + for row in results: + relation_pairs = row.get("relation_pairs", []) + if isinstance(relation_pairs, str): + try: + relation_pairs = json.loads(relation_pairs) + except json.JSONDecodeError: + relation_pairs = [] + row["relation_pairs"] = relation_pairs + create_time = row.get("create_time", 0) + update_time = row.get("update_time", 0) + row["create_time"] = create_time + row["update_time"] = ( + create_time if update_time == 0 else update_time + ) + processed_results[row["id"]] = row + return processed_results + + # For other namespaces, return as-is + return {row["id"]: row for row in results} + except Exception as e: + logger.error( + f"[{self.workspace}] Error retrieving all data from {self.namespace}: {e}" + ) + return {} + async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get data by id.""" sql = SQL_TEMPLATES["get_by_id_" + self.namespace] @@ -1783,70 +1637,17 @@ class PGKVStorage(BaseKVStorage): response["create_time"] = create_time response["update_time"] = create_time if update_time == 0 else update_time - # Special handling for ENTITY_CHUNKS namespace - if response and is_namespace(self.namespace, 
NameSpace.KV_STORE_ENTITY_CHUNKS): - # Parse chunk_ids JSON string back to list - chunk_ids = response.get("chunk_ids", []) - if isinstance(chunk_ids, str): - try: - chunk_ids = json.loads(chunk_ids) - except json.JSONDecodeError: - chunk_ids = [] - response["chunk_ids"] = chunk_ids - create_time = response.get("create_time", 0) - update_time = response.get("update_time", 0) - response["create_time"] = create_time - response["update_time"] = create_time if update_time == 0 else update_time - - # Special handling for RELATION_CHUNKS namespace - if response and is_namespace( - self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS - ): - # Parse chunk_ids JSON string back to list - chunk_ids = response.get("chunk_ids", []) - if isinstance(chunk_ids, str): - try: - chunk_ids = json.loads(chunk_ids) - except json.JSONDecodeError: - chunk_ids = [] - response["chunk_ids"] = chunk_ids - create_time = response.get("create_time", 0) - update_time = response.get("update_time", 0) - response["create_time"] = create_time - response["update_time"] = create_time if update_time == 0 else update_time - return response if response else None # Query by id async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: """Get data by ids""" - if not ids: - return [] - - sql = SQL_TEMPLATES["get_by_ids_" + self.namespace] - params = {"workspace": self.workspace, "ids": ids} + sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( + ids=",".join([f"'{id}'" for id in ids]) + ) + params = {"workspace": self.workspace} results = await self.db.query(sql, list(params.values()), multirows=True) - def _order_results( - rows: list[dict[str, Any]] | None, - ) -> list[dict[str, Any] | None]: - """Preserve the caller requested ordering for bulk id lookups.""" - if not rows: - return [None for _ in ids] - - id_map: dict[str, dict[str, Any]] = {} - for row in rows: - if row is None: - continue - row_id = row.get("id") - if row_id is not None: - id_map[str(row_id)] = row - - ordered: list[dict[str, Any] | None] = [] - for requested_id in ids: - ordered.append(id_map.get(str(requested_id))) - return ordered - if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): # Parse llm_cache_list JSON string back to list for each result for result in results: @@ -1889,7 +1690,7 @@ class PGKVStorage(BaseKVStorage): "update_time": create_time if update_time == 0 else update_time, } processed_results.append(processed_row) - return _order_results(processed_results) + return processed_results # Special handling for FULL_ENTITIES namespace if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES): @@ -1923,48 +1724,15 @@ class PGKVStorage(BaseKVStorage): result["create_time"] = create_time result["update_time"] = create_time if update_time == 0 else update_time - # Special handling for ENTITY_CHUNKS namespace - if results and is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS): - for result in results: - # Parse chunk_ids JSON string back to list - chunk_ids = result.get("chunk_ids", []) - if isinstance(chunk_ids, str): - try: - chunk_ids = json.loads(chunk_ids) - except json.JSONDecodeError: - chunk_ids = [] - result["chunk_ids"] = chunk_ids - create_time = result.get("create_time", 0) - update_time = result.get("update_time", 0) - result["create_time"] = create_time - result["update_time"] = create_time if update_time == 0 else update_time - - # Special handling for RELATION_CHUNKS namespace - if results and is_namespace(self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS): - for 
result in results: - # Parse chunk_ids JSON string back to list - chunk_ids = result.get("chunk_ids", []) - if isinstance(chunk_ids, str): - try: - chunk_ids = json.loads(chunk_ids) - except json.JSONDecodeError: - chunk_ids = [] - result["chunk_ids"] = chunk_ids - create_time = result.get("create_time", 0) - update_time = result.get("update_time", 0) - result["create_time"] = create_time - result["update_time"] = create_time if update_time == 0 else update_time - - return _order_results(results) + return results if results else [] async def filter_keys(self, keys: set[str]) -> set[str]: """Filter out duplicated content""" - if not keys: - return set() - - table_name = namespace_to_table_name(self.namespace) - sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" - params = {"workspace": self.workspace, "ids": list(keys)} + sql = SQL_TEMPLATES["filter_keys"].format( + table_name=namespace_to_table_name(self.namespace), + ids=",".join([f"'{id}'" for id in keys]), + ) + params = {"workspace": self.workspace} try: res = await self.db.query(sql, list(params.values()), multirows=True) if res: @@ -2059,61 +1827,11 @@ class PGKVStorage(BaseKVStorage): "update_time": current_time, } await self.db.execute(upsert_sql, _data) - elif is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS): - # Get current UTC time and convert to naive datetime for database storage - current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None) - for k, v in data.items(): - upsert_sql = SQL_TEMPLATES["upsert_entity_chunks"] - _data = { - "workspace": self.workspace, - "id": k, - "chunk_ids": json.dumps(v["chunk_ids"]), - "count": v["count"], - "create_time": current_time, - "update_time": current_time, - } - await self.db.execute(upsert_sql, _data) - elif is_namespace(self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS): - # Get current UTC time and convert to naive datetime for database storage - current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None) - for k, v in data.items(): - upsert_sql = SQL_TEMPLATES["upsert_relation_chunks"] - _data = { - "workspace": self.workspace, - "id": k, - "chunk_ids": json.dumps(v["chunk_ids"]), - "count": v["count"], - "create_time": current_time, - "update_time": current_time, - } - await self.db.execute(upsert_sql, _data) async def index_done_callback(self) -> None: # PG handles persistence automatically pass - async def is_empty(self) -> bool: - """Check if the storage is empty for the current workspace and namespace - - Returns: - bool: True if storage is empty, False otherwise - """ - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error( - f"[{self.workspace}] Unknown namespace for is_empty check: {self.namespace}" - ) - return True - - sql = f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE workspace=$1 LIMIT 1) as has_data" - - try: - result = await self.db.query(sql, [self.workspace]) - return not result.get("has_data", False) if result else True - except Exception as e: - logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}") - return True - async def delete(self, ids: list[str]) -> None: """Delete specific records from storage by their IDs @@ -2464,23 +2182,7 @@ class PGVectorStorage(BaseVectorStorage): try: results = await self.db.query(query, list(params.values()), multirows=True) - if not results: - return [] - - # Preserve caller requested ordering while normalizing asyncpg rows to dicts. 
- id_map: dict[str, dict[str, Any]] = {} - for record in results: - if record is None: - continue - record_dict = dict(record) - row_id = record_dict.get("id") - if row_id is not None: - id_map[str(row_id)] = record_dict - - ordered_results: list[dict[str, Any] | None] = [] - for requested_id in ids: - ordered_results.append(id_map.get(str(requested_id))) - return ordered_results + return [dict(record) for record in results] except Exception as e: logger.error( f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}" @@ -2593,12 +2295,11 @@ class PGDocStatusStorage(DocStatusStorage): async def filter_keys(self, keys: set[str]) -> set[str]: """Filter out duplicated content""" - if not keys: - return set() - - table_name = namespace_to_table_name(self.namespace) - sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" - params = {"workspace": self.workspace, "ids": list(keys)} + sql = SQL_TEMPLATES["filter_keys"].format( + table_name=namespace_to_table_name(self.namespace), + ids=",".join([f"'{id}'" for id in keys]), + ) + params = {"workspace": self.workspace} try: res = await self.db.query(sql, list(params.values()), multirows=True) if res: @@ -2669,7 +2370,7 @@ class PGDocStatusStorage(DocStatusStorage): if not results: return [] - processed_map: dict[str, dict[str, Any]] = {} + processed_results = [] for row in results: # Parse chunks_list JSON string back to list chunks_list = row.get("chunks_list", []) @@ -2691,25 +2392,23 @@ class PGDocStatusStorage(DocStatusStorage): created_at = self._format_datetime_with_timezone(row["created_at"]) updated_at = self._format_datetime_with_timezone(row["updated_at"]) - processed_map[str(row.get("id"))] = { - "content_length": row["content_length"], - "content_summary": row["content_summary"], - "status": row["status"], - "chunks_count": row["chunks_count"], - "created_at": created_at, - "updated_at": updated_at, - "file_path": row["file_path"], - "chunks_list": chunks_list, - "metadata": metadata, - "error_msg": row.get("error_msg"), - "track_id": row.get("track_id"), - } + processed_results.append( + { + "content_length": row["content_length"], + "content_summary": row["content_summary"], + "status": row["status"], + "chunks_count": row["chunks_count"], + "created_at": created_at, + "updated_at": updated_at, + "file_path": row["file_path"], + "chunks_list": chunks_list, + "metadata": metadata, + "error_msg": row.get("error_msg"), + "track_id": row.get("track_id"), + } + ) - ordered_results: list[dict[str, Any] | None] = [] - for requested_id in ids: - ordered_results.append(processed_map.get(str(requested_id))) - - return ordered_results + return processed_results async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]: """Get document by file path @@ -2911,33 +2610,26 @@ class PGDocStatusStorage(DocStatusStorage): elif page_size > 200: page_size = 200 - # Whitelist validation for sort_field to prevent SQL injection - allowed_sort_fields = {"created_at", "updated_at", "id", "file_path"} - if sort_field not in allowed_sort_fields: + if sort_field not in ["created_at", "updated_at", "id", "file_path"]: sort_field = "updated_at" - # Whitelist validation for sort_direction to prevent SQL injection if sort_direction.lower() not in ["asc", "desc"]: sort_direction = "desc" - else: - sort_direction = sort_direction.lower() # Calculate offset offset = (page - 1) * page_size - # Build parameterized query components + # Build WHERE clause + where_clause = "WHERE workspace=$1" params = {"workspace": 
self.workspace} param_count = 1 - # Build WHERE clause with parameterized query if status_filter is not None: param_count += 1 - where_clause = "WHERE workspace=$1 AND status=$2" + where_clause += f" AND status=${param_count}" params["status"] = status_filter.value - else: - where_clause = "WHERE workspace=$1" - # Build ORDER BY clause using validated whitelist values + # Build ORDER BY clause order_clause = f"ORDER BY {sort_field} {sort_direction.upper()}" # Query for total count @@ -2945,7 +2637,7 @@ class PGDocStatusStorage(DocStatusStorage): count_result = await self.db.query(count_sql, list(params.values())) total_count = count_result["total"] if count_result else 0 - # Query for paginated data with parameterized LIMIT and OFFSET + # Query for paginated data data_sql = f""" SELECT * FROM LIGHTRAG_DOC_STATUS {where_clause} @@ -3029,28 +2721,6 @@ class PGDocStatusStorage(DocStatusStorage): # PG handles persistence automatically pass - async def is_empty(self) -> bool: - """Check if the storage is empty for the current workspace and namespace - - Returns: - bool: True if storage is empty, False otherwise - """ - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error( - f"[{self.workspace}] Unknown namespace for is_empty check: {self.namespace}" - ) - return True - - sql = f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE workspace=$1 LIMIT 1) as has_data" - - try: - result = await self.db.query(sql, [self.workspace]) - return not result.get("has_data", False) if result else True - except Exception as e: - logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}") - return True - async def delete(self, ids: list[str]) -> None: """Delete specific records from storage by their IDs @@ -3498,8 +3168,7 @@ class PGGraphStorage(BaseGraphStorage): { "message": f"Error executing graph query: {query}", "wrapped": query, - "detail": repr(e), - "error_type": e.__class__.__name__, + "detail": str(e), } ) from e @@ -4174,6 +3843,102 @@ class PGGraphStorage(BaseGraphStorage): labels.append(result["label"]) return labels + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """ + Retrieves nodes from the graph that are associated with a given list of chunk IDs. + This method uses a Cypher query with UNWIND to efficiently find all nodes + where the `source_id` property contains any of the specified chunk IDs. + """ + # The string representation of the list for the cypher query + chunk_ids_str = json.dumps(chunk_ids) + + query = f""" + SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND {chunk_ids_str} AS chunk_id + MATCH (n:base) + WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}') + RETURN n + $$) AS (n agtype); + """ + results = await self._query(query) + + # Build result list + nodes = [] + for result in results: + if result["n"]: + node_dict = result["n"]["properties"] + + # Process string result, parse it to JSON dictionary + if isinstance(node_dict, str): + try: + node_dict = json.loads(node_dict) + except json.JSONDecodeError: + logger.warning( + f"[{self.workspace}] Failed to parse node string in batch: {node_dict}" + ) + + node_dict["id"] = node_dict["entity_id"] + nodes.append(node_dict) + + return nodes + + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """ + Retrieves edges from the graph that are associated with a given list of chunk IDs. 
+ This method uses a Cypher query with UNWIND to efficiently find all edges + where the `source_id` property contains any of the specified chunk IDs. + """ + chunk_ids_str = json.dumps(chunk_ids) + + query = f""" + SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND {chunk_ids_str} AS chunk_id + MATCH ()-[r]-() + WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}') + RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target + $$) AS (edge agtype, source agtype, target agtype); + """ + results = await self._query(query) + edges = [] + if results: + for item in results: + edge_agtype = item["edge"]["properties"] + # Process string result, parse it to JSON dictionary + if isinstance(edge_agtype, str): + try: + edge_agtype = json.loads(edge_agtype) + except json.JSONDecodeError: + logger.warning( + f"[{self.workspace}] Failed to parse edge string in batch: {edge_agtype}" + ) + + source_agtype = item["source"]["properties"] + # Process string result, parse it to JSON dictionary + if isinstance(source_agtype, str): + try: + source_agtype = json.loads(source_agtype) + except json.JSONDecodeError: + logger.warning( + f"[{self.workspace}] Failed to parse node string in batch: {source_agtype}" + ) + + target_agtype = item["target"]["properties"] + # Process string result, parse it to JSON dictionary + if isinstance(target_agtype, str): + try: + target_agtype = json.loads(target_agtype) + except json.JSONDecodeError: + logger.warning( + f"[{self.workspace}] Failed to parse node string in batch: {target_agtype}" + ) + + if edge_agtype and source_agtype and target_agtype: + edge_properties = edge_agtype + edge_properties["source"] = source_agtype["entity_id"] + edge_properties["target"] = target_agtype["entity_id"] + edges.append(edge_properties) + return edges + async def _bfs_subgraph( self, node_label: str, max_depth: int, max_nodes: int ) -> KnowledgeGraph: @@ -4516,19 +4281,16 @@ class PGGraphStorage(BaseGraphStorage): Returns: A list of all nodes, where each node is a dictionary of its properties """ - # Use native SQL to avoid Cypher wrapper overhead - # Original: SELECT * FROM cypher(...) 
with MATCH (n:base) - # Optimized: Direct table access for better performance - query = f""" - SELECT properties - FROM {self.graph_name}.base - """ + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n:base) + RETURN n + $$) AS (n agtype)""" results = await self._query(query) nodes = [] for result in results: - if result.get("properties"): - node_dict = result["properties"] + if result["n"]: + node_dict = result["n"]["properties"] # Process string result, parse it to JSON dictionary if isinstance(node_dict, str): @@ -4538,7 +4300,6 @@ class PGGraphStorage(BaseGraphStorage): logger.warning( f"[{self.workspace}] Failed to parse node string: {node_dict}" ) - continue # Add node id (entity_id) to the dictionary for easier access node_dict["id"] = node_dict.get("entity_id") @@ -4550,21 +4311,12 @@ class PGGraphStorage(BaseGraphStorage): Returns: A list of all edges, where each edge is a dictionary of its properties - (If 2 directional edges exist between the same pair of nodes, deduplication must be handled by the caller) - """ - # Use native SQL to avoid Cartesian product (N×N) in Cypher MATCH - # Original Cypher: MATCH (a:base)-[r]-(b:base) creates ~50 billion row combinations - # Optimized: Start from edges table, join to nodes only to get entity_id - # Performance: O(E) instead of O(N²), ~50,000x faster for large graphs - query = f""" - SELECT DISTINCT - (ag_catalog.agtype_access_operator(VARIADIC ARRAY[a.properties, '"entity_id"'::agtype]))::text AS source, - (ag_catalog.agtype_access_operator(VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]))::text AS target, - r.properties - FROM {self.graph_name}."DIRECTED" r - JOIN {self.graph_name}.base a ON r.start_id = a.id - JOIN {self.graph_name}.base b ON r.end_id = b.id + (The edge is bidirectional; deduplication must be handled by the caller) """ + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (a:base)-[r]-(b:base) + RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties + $$) AS (source text, target text, properties agtype)""" results = await self._query(query) edges = [] @@ -4719,8 +4471,6 @@ NAMESPACE_TABLE_MAP = { NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", NameSpace.KV_STORE_FULL_ENTITIES: "LIGHTRAG_FULL_ENTITIES", NameSpace.KV_STORE_FULL_RELATIONS: "LIGHTRAG_FULL_RELATIONS", - NameSpace.KV_STORE_ENTITY_CHUNKS: "LIGHTRAG_ENTITY_CHUNKS", - NameSpace.KV_STORE_RELATION_CHUNKS: "LIGHTRAG_RELATION_CHUNKS", NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE", NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_VDB_CHUNKS", NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY", @@ -4861,28 +4611,6 @@ TABLES = { CONSTRAINT LIGHTRAG_FULL_RELATIONS_PK PRIMARY KEY (workspace, id) )""" }, - "LIGHTRAG_ENTITY_CHUNKS": { - "ddl": """CREATE TABLE LIGHTRAG_ENTITY_CHUNKS ( - id VARCHAR(512), - workspace VARCHAR(255), - chunk_ids JSONB, - count INTEGER, - create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, - CONSTRAINT LIGHTRAG_ENTITY_CHUNKS_PK PRIMARY KEY (workspace, id) - )""" - }, - "LIGHTRAG_RELATION_CHUNKS": { - "ddl": """CREATE TABLE LIGHTRAG_RELATION_CHUNKS ( - id VARCHAR(512), - workspace VARCHAR(255), - chunk_ids JSONB, - count INTEGER, - create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, - CONSTRAINT LIGHTRAG_RELATION_CHUNKS_PK PRIMARY KEY (workspace, id) - )""" - }, } @@ -4906,19 +4634,19 @@ SQL_TEMPLATES = { """, "get_by_ids_full_docs": """SELECT id, 
COALESCE(content, '') as content, COALESCE(doc_name, '') as file_path - FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id = ANY($2) + FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids}) """, "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, chunk_order_index, full_doc_id, file_path, COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id = ANY($2) + FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids}) """, "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id = ANY($2) + FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids}) """, "get_by_id_full_entities": """SELECT id, entity_names, count, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, @@ -4933,32 +4661,12 @@ SQL_TEMPLATES = { "get_by_ids_full_entities": """SELECT id, entity_names, count, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id = ANY($2) + FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id IN ({ids}) """, "get_by_ids_full_relations": """SELECT id, relation_pairs, count, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2) - """, - "get_by_id_entity_chunks": """SELECT id, chunk_ids, count, - EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, - EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_ENTITY_CHUNKS WHERE workspace=$1 AND id=$2 - """, - "get_by_id_relation_chunks": """SELECT id, chunk_ids, count, - EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, - EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_RELATION_CHUNKS WHERE workspace=$1 AND id=$2 - """, - "get_by_ids_entity_chunks": """SELECT id, chunk_ids, count, - EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, - EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_ENTITY_CHUNKS WHERE workspace=$1 AND id = ANY($2) - """, - "get_by_ids_relation_chunks": """SELECT id, chunk_ids, count, - EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, - EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_RELATION_CHUNKS WHERE workspace=$1 AND id = ANY($2) + FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id IN ({ids}) """, "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})", "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace) @@ -5007,22 +4715,6 @@ SQL_TEMPLATES = { count=EXCLUDED.count, update_time = EXCLUDED.update_time """, - "upsert_entity_chunks": """INSERT INTO LIGHTRAG_ENTITY_CHUNKS (workspace, id, chunk_ids, count, - create_time, update_time) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (workspace,id) DO UPDATE - SET chunk_ids=EXCLUDED.chunk_ids, - count=EXCLUDED.count, - update_time = EXCLUDED.update_time - """, - "upsert_relation_chunks": """INSERT INTO LIGHTRAG_RELATION_CHUNKS (workspace, id, chunk_ids, count, - create_time, update_time) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (workspace,id) DO UPDATE - SET 
chunk_ids=EXCLUDED.chunk_ids, - count=EXCLUDED.count, - update_time = EXCLUDED.update_time - """, # SQL for VectorStorage "upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens, chunk_order_index, full_doc_id, content, content_vector, file_path,