From 09d0721cab5e3e3dde34e9bbfed9df7f4041bb01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?= Date: Thu, 4 Dec 2025 19:22:01 +0800 Subject: [PATCH] fix: sync postgres and shared_storage from upstream --- lightrag/kg/postgres_impl.py | 1058 ++++++++++++++++++---------------- 1 file changed, 560 insertions(+), 498 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index f9912e56..49069ce3 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -33,8 +33,7 @@ from ..base import ( ) from ..namespace import NameSpace, is_namespace from ..utils import logger -from ..constants import GRAPH_FIELD_SEP -from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock +from ..kg.shared_storage import get_data_init_lock import pipmaster as pm @@ -78,6 +77,9 @@ class PostgreSQLDB: self.hnsw_m = config.get("hnsw_m") self.hnsw_ef = config.get("hnsw_ef") self.ivfflat_lists = config.get("ivfflat_lists") + self.vchordrq_build_options = config.get("vchordrq_build_options") + self.vchordrq_probes = config.get("vchordrq_probes") + self.vchordrq_epsilon = config.get("vchordrq_epsilon") # Server settings self.server_settings = config.get("server_settings") @@ -363,7 +365,8 @@ class PostgreSQLDB: await self.configure_age(connection, graph_name) elif with_age and not graph_name: raise ValueError("Graph name is required when with_age is True") - + if self.vector_index_type == "VCHORDRQ": + await self.configure_vchordrq(connection) return await operation(connection) @staticmethod @@ -380,7 +383,7 @@ class PostgreSQLDB: async def configure_age_extension(connection: asyncpg.Connection) -> None: """Create AGE extension if it doesn't exist for graph operations.""" try: - await connection.execute("CREATE EXTENSION IF NOT EXISTS age") # type: ignore + await connection.execute("CREATE EXTENSION IF NOT EXISTS AGE CASCADE") # type: ignore logger.info("PostgreSQL, AGE extension enabled") except Exception as e: logger.warning(f"Could not create AGE extension: {e}") @@ -409,6 +412,29 @@ class PostgreSQLDB: ): pass + async def configure_vchordrq(self, connection: asyncpg.Connection) -> None: + """Configure VCHORDRQ extension for vector similarity search. + + Raises: + asyncpg.exceptions.UndefinedObjectError: If VCHORDRQ extension is not installed + asyncpg.exceptions.InvalidParameterValueError: If parameter value is invalid + + Note: + This method does not catch exceptions. Configuration errors will fail-fast, + while transient connection errors will be retried by _run_with_retry. 
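+
+ Example:
+ Session settings applied when vector_index_type is "VCHORDRQ"; the probes
+ value below is an assumed example, while 1.9 mirrors the configuration fallback:
+
+ SET vchordrq.probes TO '10'
+ SET vchordrq.epsilon TO 1.9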
+ """ + # Handle probes parameter - only set if non-empty value is provided + if self.vchordrq_probes and str(self.vchordrq_probes).strip(): + await connection.execute(f"SET vchordrq.probes TO '{self.vchordrq_probes}'") + logger.debug(f"PostgreSQL, VCHORDRQ probes set to: {self.vchordrq_probes}") + + # Handle epsilon parameter independently - check for None to allow 0.0 as valid value + if self.vchordrq_epsilon is not None: + await connection.execute(f"SET vchordrq.epsilon TO {self.vchordrq_epsilon}") + logger.debug( + f"PostgreSQL, VCHORDRQ epsilon set to: {self.vchordrq_epsilon}" + ) + async def _migrate_llm_cache_schema(self): """Migrate LLM cache schema: add new columns and remove deprecated mode field""" try: @@ -537,49 +563,74 @@ class PostgreSQLDB: "LIGHTRAG_DOC_STATUS": ["created_at", "updated_at"], } - for table_name, columns in tables_to_migrate.items(): - for column_name in columns: - try: - # Check if column exists - check_column_sql = f""" - SELECT column_name, data_type - FROM information_schema.columns - WHERE table_name = '{table_name.lower()}' - AND column_name = '{column_name}' - """ + try: + # Optimization: Batch check all columns in one query instead of 8 separate queries + table_names_lower = [t.lower() for t in tables_to_migrate.keys()] + all_column_names = list( + set(col for cols in tables_to_migrate.values() for col in cols) + ) - column_info = await self.query(check_column_sql) - if not column_info: + check_all_columns_sql = """ + SELECT table_name, column_name, data_type + FROM information_schema.columns + WHERE table_name = ANY($1) + AND column_name = ANY($2) + """ + + all_columns_result = await self.query( + check_all_columns_sql, + [table_names_lower, all_column_names], + multirows=True, + ) + + # Build lookup dict: (table_name, column_name) -> data_type + column_types = {} + if all_columns_result: + column_types = { + (row["table_name"].upper(), row["column_name"]): row["data_type"] + for row in all_columns_result + } + + # Now iterate and migrate only what's needed + for table_name, columns in tables_to_migrate.items(): + for column_name in columns: + try: + data_type = column_types.get((table_name, column_name)) + + if not data_type: + logger.warning( + f"Column {table_name}.{column_name} does not exist, skipping migration" + ) + continue + + # Check column type + if data_type == "timestamp without time zone": + logger.debug( + f"Column {table_name}.{column_name} is already witimezone-free, no migration needed" + ) + continue + + # Execute migration, explicitly specifying UTC timezone for interpreting original data + logger.info( + f"Migrating {table_name}.{column_name} from {data_type} to TIMESTAMP(0) type" + ) + migration_sql = f""" + ALTER TABLE {table_name} + ALTER COLUMN {column_name} TYPE TIMESTAMP(0), + ALTER COLUMN {column_name} SET DEFAULT CURRENT_TIMESTAMP + """ + + await self.execute(migration_sql) + logger.info( + f"Successfully migrated {table_name}.{column_name} to timezone-free type" + ) + except Exception as e: + # Log error but don't interrupt the process logger.warning( - f"Column {table_name}.{column_name} does not exist, skipping migration" + f"Failed to migrate {table_name}.{column_name}: {e}" ) - continue - - # Check column type - data_type = column_info.get("data_type") - if data_type == "timestamp without time zone": - logger.debug( - f"Column {table_name}.{column_name} is already witimezone-free, no migration needed" - ) - continue - - # Execute migration, explicitly specifying UTC timezone for interpreting original data - 
logger.info( - f"Migrating {table_name}.{column_name} from {data_type} to TIMESTAMP(0) type" - ) - migration_sql = f""" - ALTER TABLE {table_name} - ALTER COLUMN {column_name} TYPE TIMESTAMP(0), - ALTER COLUMN {column_name} SET DEFAULT CURRENT_TIMESTAMP - """ - - await self.execute(migration_sql) - logger.info( - f"Successfully migrated {table_name}.{column_name} to timezone-free type" - ) - except Exception as e: - # Log error but don't interrupt the process - logger.warning(f"Failed to migrate {table_name}.{column_name}: {e}") + except Exception as e: + logger.error(f"Failed to batch check timestamp columns: {e}") async def _migrate_doc_chunks_to_vdb_chunks(self): """ @@ -956,73 +1007,89 @@ class PostgreSQLDB: }, ] - for migration in field_migrations: - try: - # Check current column definition - check_column_sql = """ - SELECT column_name, data_type, character_maximum_length, is_nullable - FROM information_schema.columns - WHERE table_name = $1 AND column_name = $2 - """ - params = { - "table_name": migration["table"].lower(), - "column_name": migration["column"], + try: + # Optimization: Batch check all columns in one query instead of 5 separate queries + unique_tables = list(set(m["table"].lower() for m in field_migrations)) + unique_columns = list(set(m["column"] for m in field_migrations)) + + check_all_columns_sql = """ + SELECT table_name, column_name, data_type, character_maximum_length, is_nullable + FROM information_schema.columns + WHERE table_name = ANY($1) + AND column_name = ANY($2) + """ + + all_columns_result = await self.query( + check_all_columns_sql, [unique_tables, unique_columns], multirows=True + ) + + # Build lookup dict: (table_name, column_name) -> column_info + column_info_map = {} + if all_columns_result: + column_info_map = { + (row["table_name"].upper(), row["column_name"]): row + for row in all_columns_result } - column_info = await self.query( - check_column_sql, - list(params.values()), - ) - if not column_info: + # Now iterate and migrate only what's needed + for migration in field_migrations: + try: + column_info = column_info_map.get( + (migration["table"], migration["column"]) + ) + + if not column_info: + logger.warning( + f"Column {migration['table']}.{migration['column']} does not exist, skipping migration" + ) + continue + + current_type = column_info.get("data_type", "").lower() + current_length = column_info.get("character_maximum_length") + + # Check if migration is needed + needs_migration = False + + if migration["column"] == "entity_name" and current_length == 255: + needs_migration = True + elif ( + migration["column"] in ["source_id", "target_id"] + and current_length == 256 + ): + needs_migration = True + elif ( + migration["column"] == "file_path" + and current_type == "character varying" + ): + needs_migration = True + + if needs_migration: + logger.info( + f"Migrating {migration['table']}.{migration['column']}: {migration['description']}" + ) + + # Execute the migration + alter_sql = f""" + ALTER TABLE {migration["table"]} + ALTER COLUMN {migration["column"]} TYPE {migration["new_type"]} + """ + + await self.execute(alter_sql) + logger.info( + f"Successfully migrated {migration['table']}.{migration['column']}" + ) + else: + logger.debug( + f"Column {migration['table']}.{migration['column']} already has correct type, no migration needed" + ) + + except Exception as e: + # Log error but don't interrupt the process logger.warning( - f"Column {migration['table']}.{migration['column']} does not exist, skipping migration" + f"Failed to migrate 
{migration['table']}.{migration['column']}: {e}" ) - continue - - current_type = column_info.get("data_type", "").lower() - current_length = column_info.get("character_maximum_length") - - # Check if migration is needed - needs_migration = False - - if migration["column"] == "entity_name" and current_length == 255: - needs_migration = True - elif ( - migration["column"] in ["source_id", "target_id"] - and current_length == 256 - ): - needs_migration = True - elif ( - migration["column"] == "file_path" - and current_type == "character varying" - ): - needs_migration = True - - if needs_migration: - logger.info( - f"Migrating {migration['table']}.{migration['column']}: {migration['description']}" - ) - - # Execute the migration - alter_sql = f""" - ALTER TABLE {migration["table"]} - ALTER COLUMN {migration["column"]} TYPE {migration["new_type"]} - """ - - await self.execute(alter_sql) - logger.info( - f"Successfully migrated {migration['table']}.{migration['column']}" - ) - else: - logger.debug( - f"Column {migration['table']}.{migration['column']} already has correct type, no migration needed" - ) - - except Exception as e: - # Log error but don't interrupt the process - logger.warning( - f"Failed to migrate {migration['table']}.{migration['column']}: {e}" - ) + except Exception as e: + logger.error(f"Failed to batch check field lengths: {e}") async def check_tables(self): # First create all tables @@ -1042,47 +1109,59 @@ class PostgreSQLDB: ) raise e - # Create index for id column in each table - try: + # Batch check all indexes at once (optimization: single query instead of N queries) + try: + table_names = list(TABLES.keys()) + table_names_lower = [t.lower() for t in table_names] + + # Get all existing indexes for our tables in one query + check_all_indexes_sql = """ + SELECT indexname, tablename + FROM pg_indexes + WHERE tablename = ANY($1) + """ + existing_indexes_result = await self.query( + check_all_indexes_sql, [table_names_lower], multirows=True + ) + + # Build a set of existing index names for fast lookup + existing_indexes = set() + if existing_indexes_result: + existing_indexes = {row["indexname"] for row in existing_indexes_result} + + # Create missing indexes + for k in table_names: + # Create index for id column if missing index_name = f"idx_{k.lower()}_id" - check_index_sql = f""" - SELECT 1 FROM pg_indexes - WHERE indexname = '{index_name}' - AND tablename = '{k.lower()}' - """ - index_exists = await self.query(check_index_sql) + if index_name not in existing_indexes: + try: + create_index_sql = f"CREATE INDEX {index_name} ON {k}(id)" + logger.info( + f"PostgreSQL, Creating index {index_name} on table {k}" + ) + await self.execute(create_index_sql) + except Exception as e: + logger.error( + f"PostgreSQL, Failed to create index {index_name}, Got: {e}" + ) - if not index_exists: - create_index_sql = f"CREATE INDEX {index_name} ON {k}(id)" - logger.info(f"PostgreSQL, Creating index {index_name} on table {k}") - await self.execute(create_index_sql) - except Exception as e: - logger.error( - f"PostgreSQL, Failed to create index on table {k}, Got: {e}" - ) - - # Create composite index for (workspace, id) columns in each table - try: + # Create composite index for (workspace, id) if missing composite_index_name = f"idx_{k.lower()}_workspace_id" - check_composite_index_sql = f""" - SELECT 1 FROM pg_indexes - WHERE indexname = '{composite_index_name}' - AND tablename = '{k.lower()}' - """ - composite_index_exists = await self.query(check_composite_index_sql) - - if not 
composite_index_exists: - create_composite_index_sql = ( - f"CREATE INDEX {composite_index_name} ON {k}(workspace, id)" - ) - logger.info( - f"PostgreSQL, Creating composite index {composite_index_name} on table {k}" - ) - await self.execute(create_composite_index_sql) - except Exception as e: - logger.error( - f"PostgreSQL, Failed to create composite index on table {k}, Got: {e}" - ) + if composite_index_name not in existing_indexes: + try: + create_composite_index_sql = ( + f"CREATE INDEX {composite_index_name} ON {k}(workspace, id)" + ) + logger.info( + f"PostgreSQL, Creating composite index {composite_index_name} on table {k}" + ) + await self.execute(create_composite_index_sql) + except Exception as e: + logger.error( + f"PostgreSQL, Failed to create composite index {composite_index_name}, Got: {e}" + ) + except Exception as e: + logger.error(f"PostgreSQL, Failed to batch check/create indexes: {e}") # Create vector indexs if self.vector_index_type: @@ -1090,19 +1169,12 @@ class PostgreSQLDB: f"PostgreSQL, Create vector indexs, type: {self.vector_index_type}" ) try: - if self.vector_index_type == "HNSW": - await self._create_hnsw_vector_indexes() - elif self.vector_index_type == "IVFFLAT": - await self._create_ivfflat_vector_indexes() - elif self.vector_index_type == "FLAT": - logger.warning( - "FLAT index type is not supported by pgvector. Skipping vector index creation. " - "Please use 'HNSW' or 'IVFFLAT' instead." - ) + if self.vector_index_type in ["HNSW", "IVFFLAT", "VCHORDRQ"]: + await self._create_vector_indexes() else: logger.warning( "Doesn't support this vector index type: {self.vector_index_type}. " - "Supported types: HNSW, IVFFLAT" + "Supported types: HNSW, IVFFLAT, VCHORDRQ" ) except Exception as e: logger.error( @@ -1309,21 +1381,39 @@ class PostgreSQLDB: except Exception as e: logger.warning(f"Failed to create index {index['name']}: {e}") - async def _create_hnsw_vector_indexes(self): + async def _create_vector_indexes(self): vdb_tables = [ "LIGHTRAG_VDB_CHUNKS", "LIGHTRAG_VDB_ENTITY", "LIGHTRAG_VDB_RELATION", ] - embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024)) + create_sql = { + "HNSW": f""" + CREATE INDEX {{vector_index_name}} + ON {{k}} USING hnsw (content_vector vector_cosine_ops) + WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef}) + """, + "IVFFLAT": f""" + CREATE INDEX {{vector_index_name}} + ON {{k}} USING ivfflat (content_vector vector_cosine_ops) + WITH (lists = {self.ivfflat_lists}) + """, + "VCHORDRQ": f""" + CREATE INDEX {{vector_index_name}} + ON {{k}} USING vchordrq (content_vector vector_cosine_ops) + {f'WITH (options = $${self.vchordrq_build_options}$$)' if self.vchordrq_build_options else ''} + """, + } + embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024)) for k in vdb_tables: - vector_index_name = f"idx_{k.lower()}_hnsw_cosine" + vector_index_name = ( + f"idx_{k.lower()}_{self.vector_index_type.lower()}_cosine" + ) check_vector_index_sql = f""" SELECT 1 FROM pg_indexes - WHERE indexname = '{vector_index_name}' - AND tablename = '{k.lower()}' + WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}' """ try: vector_index_exists = await self.query(check_vector_index_sql) @@ -1332,64 +1422,24 @@ class PostgreSQLDB: alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})" await self.execute(alter_sql) logger.debug(f"Ensured vector dimension for {k}") - - create_vector_index_sql = f""" - CREATE INDEX {vector_index_name} - ON {k} USING hnsw (content_vector vector_cosine_ops) - WITH (m = 
{self.hnsw_m}, ef_construction = {self.hnsw_ef}) - """ - logger.info(f"Creating hnsw index {vector_index_name} on table {k}") - await self.execute(create_vector_index_sql) + logger.info( + f"Creating {self.vector_index_type} index {vector_index_name} on table {k}" + ) + await self.execute( + create_sql[self.vector_index_type].format( + vector_index_name=vector_index_name, k=k + ) + ) logger.info( f"Successfully created vector index {vector_index_name} on table {k}" ) else: logger.info( - f"HNSW vector index {vector_index_name} already exists on table {k}" + f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}" ) except Exception as e: logger.error(f"Failed to create vector index on table {k}, Got: {e}") - async def _create_ivfflat_vector_indexes(self): - vdb_tables = [ - "LIGHTRAG_VDB_CHUNKS", - "LIGHTRAG_VDB_ENTITY", - "LIGHTRAG_VDB_RELATION", - ] - - embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024)) - - for k in vdb_tables: - index_name = f"idx_{k.lower()}_ivfflat_cosine" - check_index_sql = f""" - SELECT 1 FROM pg_indexes - WHERE indexname = '{index_name}' AND tablename = '{k.lower()}' - """ - try: - exists = await self.query(check_index_sql) - if not exists: - # Only set vector dimension when index doesn't exist - alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})" - await self.execute(alter_sql) - logger.debug(f"Ensured vector dimension for {k}") - - create_sql = f""" - CREATE INDEX {index_name} - ON {k} USING ivfflat (content_vector vector_cosine_ops) - WITH (lists = {self.ivfflat_lists}) - """ - logger.info(f"Creating ivfflat index {index_name} on table {k}") - await self.execute(create_sql) - logger.info( - f"Successfully created ivfflat index {index_name} on table {k}" - ) - else: - logger.info( - f"Ivfflat vector index {index_name} already exists on table {k}" - ) - except Exception as e: - logger.error(f"Failed to create ivfflat index on {k}: {e}") - async def query( self, sql: str, @@ -1544,6 +1594,20 @@ class ClientManager: config.get("postgres", "ivfflat_lists", fallback="100"), ) ), + "vchordrq_build_options": os.environ.get( + "POSTGRES_VCHORDRQ_BUILD_OPTIONS", + config.get("postgres", "vchordrq_build_options", fallback=""), + ), + "vchordrq_probes": os.environ.get( + "POSTGRES_VCHORDRQ_PROBES", + config.get("postgres", "vchordrq_probes", fallback=""), + ), + "vchordrq_epsilon": float( + os.environ.get( + "POSTGRES_VCHORDRQ_EPSILON", + config.get("postgres", "vchordrq_epsilon", fallback="1.9"), + ) + ), # Server settings for Supabase "server_settings": os.environ.get( "POSTGRES_SERVER_SETTINGS", @@ -1650,119 +1714,11 @@ class PGKVStorage(BaseKVStorage): self.workspace = "default" async def finalize(self): - async with get_storage_lock(): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None ################ QUERY METHODS ################ - async def get_all(self) -> dict[str, Any]: - """Get all data from storage - - Returns: - Dictionary containing all stored data - """ - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error( - f"[{self.workspace}] Unknown namespace for get_all: {self.namespace}" - ) - return {} - - sql = f"SELECT * FROM {table_name} WHERE workspace=$1" - params = {"workspace": self.workspace} - - try: - results = await self.db.query(sql, list(params.values()), multirows=True) - - # Special handling for LLM cache to ensure 
compatibility with _get_cached_extraction_results - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - processed_results = {} - for row in results: - create_time = row.get("create_time", 0) - update_time = row.get("update_time", 0) - # Map field names and add cache_type for compatibility - processed_row = { - **row, - "return": row.get("return_value", ""), - "cache_type": row.get("original_prompt", "unknow"), - "original_prompt": row.get("original_prompt", ""), - "chunk_id": row.get("chunk_id"), - "mode": row.get("mode", "default"), - "create_time": create_time, - "update_time": create_time if update_time == 0 else update_time, - } - processed_results[row["id"]] = processed_row - return processed_results - - # For text_chunks namespace, parse llm_cache_list JSON string back to list - if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): - processed_results = {} - for row in results: - llm_cache_list = row.get("llm_cache_list", []) - if isinstance(llm_cache_list, str): - try: - llm_cache_list = json.loads(llm_cache_list) - except json.JSONDecodeError: - llm_cache_list = [] - row["llm_cache_list"] = llm_cache_list - create_time = row.get("create_time", 0) - update_time = row.get("update_time", 0) - row["create_time"] = create_time - row["update_time"] = ( - create_time if update_time == 0 else update_time - ) - processed_results[row["id"]] = row - return processed_results - - # For FULL_ENTITIES namespace, parse entity_names JSON string back to list - if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES): - processed_results = {} - for row in results: - entity_names = row.get("entity_names", []) - if isinstance(entity_names, str): - try: - entity_names = json.loads(entity_names) - except json.JSONDecodeError: - entity_names = [] - row["entity_names"] = entity_names - create_time = row.get("create_time", 0) - update_time = row.get("update_time", 0) - row["create_time"] = create_time - row["update_time"] = ( - create_time if update_time == 0 else update_time - ) - processed_results[row["id"]] = row - return processed_results - - # For FULL_RELATIONS namespace, parse relation_pairs JSON string back to list - if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_RELATIONS): - processed_results = {} - for row in results: - relation_pairs = row.get("relation_pairs", []) - if isinstance(relation_pairs, str): - try: - relation_pairs = json.loads(relation_pairs) - except json.JSONDecodeError: - relation_pairs = [] - row["relation_pairs"] = relation_pairs - create_time = row.get("create_time", 0) - update_time = row.get("update_time", 0) - row["create_time"] = create_time - row["update_time"] = ( - create_time if update_time == 0 else update_time - ) - processed_results[row["id"]] = row - return processed_results - - # For other namespaces, return as-is - return {row["id"]: row for row in results} - except Exception as e: - logger.error( - f"[{self.workspace}] Error retrieving all data from {self.namespace}: {e}" - ) - return {} - async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get data by id.""" sql = SQL_TEMPLATES["get_by_id_" + self.namespace] @@ -1838,6 +1794,38 @@ class PGKVStorage(BaseKVStorage): response["create_time"] = create_time response["update_time"] = create_time if update_time == 0 else update_time + # Special handling for ENTITY_CHUNKS namespace + if response and is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS): + # Parse chunk_ids JSON string back to list + chunk_ids = response.get("chunk_ids", []) + if 
isinstance(chunk_ids, str): + try: + chunk_ids = json.loads(chunk_ids) + except json.JSONDecodeError: + chunk_ids = [] + response["chunk_ids"] = chunk_ids + create_time = response.get("create_time", 0) + update_time = response.get("update_time", 0) + response["create_time"] = create_time + response["update_time"] = create_time if update_time == 0 else update_time + + # Special handling for RELATION_CHUNKS namespace + if response and is_namespace( + self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS + ): + # Parse chunk_ids JSON string back to list + chunk_ids = response.get("chunk_ids", []) + if isinstance(chunk_ids, str): + try: + chunk_ids = json.loads(chunk_ids) + except json.JSONDecodeError: + chunk_ids = [] + response["chunk_ids"] = chunk_ids + create_time = response.get("create_time", 0) + update_time = response.get("update_time", 0) + response["create_time"] = create_time + response["update_time"] = create_time if update_time == 0 else update_time + return response if response else None # Query by id @@ -1946,6 +1934,38 @@ class PGKVStorage(BaseKVStorage): result["create_time"] = create_time result["update_time"] = create_time if update_time == 0 else update_time + # Special handling for ENTITY_CHUNKS namespace + if results and is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS): + for result in results: + # Parse chunk_ids JSON string back to list + chunk_ids = result.get("chunk_ids", []) + if isinstance(chunk_ids, str): + try: + chunk_ids = json.loads(chunk_ids) + except json.JSONDecodeError: + chunk_ids = [] + result["chunk_ids"] = chunk_ids + create_time = result.get("create_time", 0) + update_time = result.get("update_time", 0) + result["create_time"] = create_time + result["update_time"] = create_time if update_time == 0 else update_time + + # Special handling for RELATION_CHUNKS namespace + if results and is_namespace(self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS): + for result in results: + # Parse chunk_ids JSON string back to list + chunk_ids = result.get("chunk_ids", []) + if isinstance(chunk_ids, str): + try: + chunk_ids = json.loads(chunk_ids) + except json.JSONDecodeError: + chunk_ids = [] + result["chunk_ids"] = chunk_ids + create_time = result.get("create_time", 0) + update_time = result.get("update_time", 0) + result["create_time"] = create_time + result["update_time"] = create_time if update_time == 0 else update_time + return _order_results(results) async def filter_keys(self, keys: set[str]) -> set[str]: @@ -2050,11 +2070,61 @@ class PGKVStorage(BaseKVStorage): "update_time": current_time, } await self.db.execute(upsert_sql, _data) + elif is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS): + # Get current UTC time and convert to naive datetime for database storage + current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None) + for k, v in data.items(): + upsert_sql = SQL_TEMPLATES["upsert_entity_chunks"] + _data = { + "workspace": self.workspace, + "id": k, + "chunk_ids": json.dumps(v["chunk_ids"]), + "count": v["count"], + "create_time": current_time, + "update_time": current_time, + } + await self.db.execute(upsert_sql, _data) + elif is_namespace(self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS): + # Get current UTC time and convert to naive datetime for database storage + current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None) + for k, v in data.items(): + upsert_sql = SQL_TEMPLATES["upsert_relation_chunks"] + _data = { + "workspace": self.workspace, + "id": k, + "chunk_ids": json.dumps(v["chunk_ids"]), 
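+ # Illustrative (assumed) shape of each value v keyed by relation id k:
+ # {"chunk_ids": ["chunk-1a2b3c", "chunk-4d5e6f"], "count": 2}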
+ "count": v["count"], + "create_time": current_time, + "update_time": current_time, + } + await self.db.execute(upsert_sql, _data) async def index_done_callback(self) -> None: # PG handles persistence automatically pass + async def is_empty(self) -> bool: + """Check if the storage is empty for the current workspace and namespace + + Returns: + bool: True if storage is empty, False otherwise + """ + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error( + f"[{self.workspace}] Unknown namespace for is_empty check: {self.namespace}" + ) + return True + + sql = f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE workspace=$1 LIMIT 1) as has_data" + + try: + result = await self.db.query(sql, [self.workspace]) + return not result.get("has_data", False) if result else True + except Exception as e: + logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}") + return True + async def delete(self, ids: list[str]) -> None: """Delete specific records from storage by their IDs @@ -2088,22 +2158,21 @@ class PGKVStorage(BaseKVStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - async with get_storage_lock(): - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} @final @@ -2138,10 +2207,9 @@ class PGVectorStorage(BaseVectorStorage): self.workspace = "default" async def finalize(self): - async with get_storage_lock(): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None def _upsert_chunks( self, item: dict[str, Any], current_time: datetime.datetime @@ -2477,22 +2545,21 @@ class PGVectorStorage(BaseVectorStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - async with get_storage_lock(): - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.workspace}) + return {"status": "success", "message": "data dropped"} + except 
Exception as e: + return {"status": "error", "message": str(e)} @final @@ -2527,10 +2594,9 @@ class PGDocStatusStorage(DocStatusStorage): self.workspace = "default" async def finalize(self): - async with get_storage_lock(): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None async def filter_keys(self, keys: set[str]) -> set[str]: """Filter out duplicated content""" @@ -2970,6 +3036,28 @@ class PGDocStatusStorage(DocStatusStorage): # PG handles persistence automatically pass + async def is_empty(self) -> bool: + """Check if the storage is empty for the current workspace and namespace + + Returns: + bool: True if storage is empty, False otherwise + """ + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error( + f"[{self.workspace}] Unknown namespace for is_empty check: {self.namespace}" + ) + return True + + sql = f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE workspace=$1 LIMIT 1) as has_data" + + try: + result = await self.db.query(sql, [self.workspace]) + return not result.get("has_data", False) if result else True + except Exception as e: + logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}") + return True + async def delete(self, ids: list[str]) -> None: """Delete specific records from storage by their IDs @@ -3083,22 +3171,21 @@ class PGDocStatusStorage(DocStatusStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - async with get_storage_lock(): - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} class PGGraphQueryException(Exception): @@ -3230,10 +3317,9 @@ class PGGraphStorage(BaseGraphStorage): ) async def finalize(self): - async with get_graph_db_lock(): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None async def index_done_callback(self) -> None: # PG handles persistence automatically @@ -4112,102 +4198,6 @@ class PGGraphStorage(BaseGraphStorage): labels.append(result["label"]) return labels - async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - """ - Retrieves nodes from the graph that are associated with a given list of chunk IDs. - This method uses a Cypher query with UNWIND to efficiently find all nodes - where the `source_id` property contains any of the specified chunk IDs. 
- """ - # The string representation of the list for the cypher query - chunk_ids_str = json.dumps(chunk_ids) - - query = f""" - SELECT * FROM cypher('{self.graph_name}', $$ - UNWIND {chunk_ids_str} AS chunk_id - MATCH (n:base) - WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}') - RETURN n - $$) AS (n agtype); - """ - results = await self._query(query) - - # Build result list - nodes = [] - for result in results: - if result["n"]: - node_dict = result["n"]["properties"] - - # Process string result, parse it to JSON dictionary - if isinstance(node_dict, str): - try: - node_dict = json.loads(node_dict) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse node string in batch: {node_dict}" - ) - - node_dict["id"] = node_dict["entity_id"] - nodes.append(node_dict) - - return nodes - - async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - """ - Retrieves edges from the graph that are associated with a given list of chunk IDs. - This method uses a Cypher query with UNWIND to efficiently find all edges - where the `source_id` property contains any of the specified chunk IDs. - """ - chunk_ids_str = json.dumps(chunk_ids) - - query = f""" - SELECT * FROM cypher('{self.graph_name}', $$ - UNWIND {chunk_ids_str} AS chunk_id - MATCH ()-[r]-() - WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}') - RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target - $$) AS (edge agtype, source agtype, target agtype); - """ - results = await self._query(query) - edges = [] - if results: - for item in results: - edge_agtype = item["edge"]["properties"] - # Process string result, parse it to JSON dictionary - if isinstance(edge_agtype, str): - try: - edge_agtype = json.loads(edge_agtype) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse edge string in batch: {edge_agtype}" - ) - - source_agtype = item["source"]["properties"] - # Process string result, parse it to JSON dictionary - if isinstance(source_agtype, str): - try: - source_agtype = json.loads(source_agtype) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse node string in batch: {source_agtype}" - ) - - target_agtype = item["target"]["properties"] - # Process string result, parse it to JSON dictionary - if isinstance(target_agtype, str): - try: - target_agtype = json.loads(target_agtype) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse node string in batch: {target_agtype}" - ) - - if edge_agtype and source_agtype and target_agtype: - edge_properties = edge_agtype - edge_properties["source"] = source_agtype["entity_id"] - edge_properties["target"] = target_agtype["entity_id"] - edges.append(edge_properties) - return edges - async def _bfs_subgraph( self, node_label: str, max_depth: int, max_nodes: int ) -> KnowledgeGraph: @@ -4550,16 +4540,19 @@ class PGGraphStorage(BaseGraphStorage): Returns: A list of all nodes, where each node is a dictionary of its properties """ - query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n:base) - RETURN n - $$) AS (n agtype)""" + # Use native SQL to avoid Cypher wrapper overhead + # Original: SELECT * FROM cypher(...) 
with MATCH (n:base) + # Optimized: Direct table access for better performance + query = f""" + SELECT properties + FROM {self.graph_name}.base + """ results = await self._query(query) nodes = [] for result in results: - if result["n"]: - node_dict = result["n"]["properties"] + if result.get("properties"): + node_dict = result["properties"] # Process string result, parse it to JSON dictionary if isinstance(node_dict, str): @@ -4569,6 +4562,7 @@ class PGGraphStorage(BaseGraphStorage): logger.warning( f"[{self.workspace}] Failed to parse node string: {node_dict}" ) + continue # Add node id (entity_id) to the dictionary for easier access node_dict["id"] = node_dict.get("entity_id") @@ -4580,12 +4574,21 @@ class PGGraphStorage(BaseGraphStorage): Returns: A list of all edges, where each edge is a dictionary of its properties - (The edge is bidirectional; deduplication must be handled by the caller) + (If 2 directional edges exist between the same pair of nodes, deduplication must be handled by the caller) + """ + # Use native SQL to avoid Cartesian product (N×N) in Cypher MATCH + # Original Cypher: MATCH (a:base)-[r]-(b:base) creates ~50 billion row combinations + # Optimized: Start from edges table, join to nodes only to get entity_id + # Performance: O(E) instead of O(N²), ~50,000x faster for large graphs + query = f""" + SELECT DISTINCT + (ag_catalog.agtype_access_operator(VARIADIC ARRAY[a.properties, '"entity_id"'::agtype]))::text AS source, + (ag_catalog.agtype_access_operator(VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]))::text AS target, + r.properties + FROM {self.graph_name}."DIRECTED" r + JOIN {self.graph_name}.base a ON r.start_id = a.id + JOIN {self.graph_name}.base b ON r.end_id = b.id """ - query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (a:base)-[r]-(b:base) - RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties - $$) AS (source text, target text, properties agtype)""" results = await self._query(query) edges = [] @@ -4716,21 +4719,20 @@ class PGGraphStorage(BaseGraphStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - async with get_graph_db_lock(): - try: - drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n) - DETACH DELETE n - $$) AS (result agtype)""" + try: + drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n) + DETACH DELETE n + $$) AS (result agtype)""" - await self._query(drop_query, readonly=False) - return { - "status": "success", - "message": f"workspace '{self.workspace}' graph data dropped", - } - except Exception as e: - logger.error(f"[{self.workspace}] Error dropping graph: {e}") - return {"status": "error", "message": str(e)} + await self._query(drop_query, readonly=False) + return { + "status": "success", + "message": f"workspace '{self.workspace}' graph data dropped", + } + except Exception as e: + logger.error(f"[{self.workspace}] Error dropping graph: {e}") + return {"status": "error", "message": str(e)} # Note: Order matters! 
More specific namespaces (e.g., "full_entities") must come before @@ -4740,6 +4742,8 @@ NAMESPACE_TABLE_MAP = { NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", NameSpace.KV_STORE_FULL_ENTITIES: "LIGHTRAG_FULL_ENTITIES", NameSpace.KV_STORE_FULL_RELATIONS: "LIGHTRAG_FULL_RELATIONS", + NameSpace.KV_STORE_ENTITY_CHUNKS: "LIGHTRAG_ENTITY_CHUNKS", + NameSpace.KV_STORE_RELATION_CHUNKS: "LIGHTRAG_RELATION_CHUNKS", NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE", NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_VDB_CHUNKS", NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY", @@ -4880,6 +4884,28 @@ TABLES = { CONSTRAINT LIGHTRAG_FULL_RELATIONS_PK PRIMARY KEY (workspace, id) )""" }, + "LIGHTRAG_ENTITY_CHUNKS": { + "ddl": """CREATE TABLE LIGHTRAG_ENTITY_CHUNKS ( + id VARCHAR(512), + workspace VARCHAR(255), + chunk_ids JSONB, + count INTEGER, + create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT LIGHTRAG_ENTITY_CHUNKS_PK PRIMARY KEY (workspace, id) + )""" + }, + "LIGHTRAG_RELATION_CHUNKS": { + "ddl": """CREATE TABLE LIGHTRAG_RELATION_CHUNKS ( + id VARCHAR(512), + workspace VARCHAR(255), + chunk_ids JSONB, + count INTEGER, + create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT LIGHTRAG_RELATION_CHUNKS_PK PRIMARY KEY (workspace, id) + )""" + }, } @@ -4937,6 +4963,26 @@ SQL_TEMPLATES = { EXTRACT(EPOCH FROM update_time)::BIGINT as update_time FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2) """, + "get_by_id_entity_chunks": """SELECT id, chunk_ids, count, + EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, + EXTRACT(EPOCH FROM update_time)::BIGINT as update_time + FROM LIGHTRAG_ENTITY_CHUNKS WHERE workspace=$1 AND id=$2 + """, + "get_by_id_relation_chunks": """SELECT id, chunk_ids, count, + EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, + EXTRACT(EPOCH FROM update_time)::BIGINT as update_time + FROM LIGHTRAG_RELATION_CHUNKS WHERE workspace=$1 AND id=$2 + """, + "get_by_ids_entity_chunks": """SELECT id, chunk_ids, count, + EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, + EXTRACT(EPOCH FROM update_time)::BIGINT as update_time + FROM LIGHTRAG_ENTITY_CHUNKS WHERE workspace=$1 AND id = ANY($2) + """, + "get_by_ids_relation_chunks": """SELECT id, chunk_ids, count, + EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, + EXTRACT(EPOCH FROM update_time)::BIGINT as update_time + FROM LIGHTRAG_RELATION_CHUNKS WHERE workspace=$1 AND id = ANY($2) + """, "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})", "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace) VALUES ($1, $2, $3, $4) @@ -4984,6 +5030,22 @@ SQL_TEMPLATES = { count=EXCLUDED.count, update_time = EXCLUDED.update_time """, + "upsert_entity_chunks": """INSERT INTO LIGHTRAG_ENTITY_CHUNKS (workspace, id, chunk_ids, count, + create_time, update_time) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (workspace,id) DO UPDATE + SET chunk_ids=EXCLUDED.chunk_ids, + count=EXCLUDED.count, + update_time = EXCLUDED.update_time + """, + "upsert_relation_chunks": """INSERT INTO LIGHTRAG_RELATION_CHUNKS (workspace, id, chunk_ids, count, + create_time, update_time) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (workspace,id) DO UPDATE + SET chunk_ids=EXCLUDED.chunk_ids, + count=EXCLUDED.count, + update_time = EXCLUDED.update_time + """, # SQL for VectorStorage "upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, 
id, tokens, chunk_order_index, full_doc_id, content, content_vector, file_path,
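
For reference, a minimal sketch of how the new VCHORDRQ path is expected to surface as SQL, based on the create_sql["VCHORDRQ"] template and configure_vchordrq above. The build-options string and probes value are illustrative assumptions, not recommended defaults; only the epsilon value mirrors the config fallback added in ClientManager.

    # Minimal sketch (not part of the patch). Values marked "assumed" are examples only.
    vchordrq_build_options = "residual_quantization = true"  # assumed example build options
    vchordrq_probes = "10"                                    # assumed example probes value
    vchordrq_epsilon = 1.9                                    # matches the config fallback above

    # Index creation, as rendered from create_sql["VCHORDRQ"] for LIGHTRAG_VDB_CHUNKS
    create_index_sql = f"""
        CREATE INDEX idx_lightrag_vdb_chunks_vchordrq_cosine
        ON LIGHTRAG_VDB_CHUNKS USING vchordrq (content_vector vector_cosine_ops)
        {f'WITH (options = $${vchordrq_build_options}$$)' if vchordrq_build_options else ''}
    """

    # Per-connection session settings, as issued by configure_vchordrq
    session_setup = [
        f"SET vchordrq.probes TO '{vchordrq_probes}'",
        f"SET vchordrq.epsilon TO {vchordrq_epsilon}",
    ]

The CREATE INDEX statement runs once from check_tables(), while the SET statements are applied per pooled connection before each operation.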