From 0ac858d3e2cdd2a4214bee0dc47bdc1795b23450 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 18 Nov 2025 21:58:36 +0800 Subject: [PATCH] fix(postgres): allow vchordrq.epsilon config when probes is empty Previously, configure_vchordrq would fail silently when probes was empty (the default), preventing epsilon from being configured. Now each parameter is handled independently with conditional execution, and configuration errors fail-fast instead of being swallowed. This fixes the documented epsilon setting being impossible to use in the default configuration. (cherry picked from commit 3096f844fb814dd5b812378e759ae767beafde9f) --- lightrag/kg/postgres_impl.py | 709 +++++++++++++++++++---------------- 1 file changed, 381 insertions(+), 328 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index a6d3ff04..ba5ec6d7 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -33,7 +33,6 @@ from ..base import ( ) from ..namespace import NameSpace, is_namespace from ..utils import logger -from ..constants import GRAPH_FIELD_SEP from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock import pipmaster as pm @@ -78,6 +77,9 @@ class PostgreSQLDB: self.hnsw_m = config.get("hnsw_m") self.hnsw_ef = config.get("hnsw_ef") self.ivfflat_lists = config.get("ivfflat_lists") + self.vchordrq_build_options = config.get("vchordrq_build_options") + self.vchordrq_probes = config.get("vchordrq_probes") + self.vchordrq_epsilon = config.get("vchordrq_epsilon") # Server settings self.server_settings = config.get("server_settings") @@ -85,24 +87,11 @@ class PostgreSQLDB: # Statement LRU cache size (keep as-is, allow None for optional configuration) self.statement_cache_size = config.get("statement_cache_size") - # Connection retry configuration - self.connection_retry_attempts = max( - 1, min(10, int(os.environ.get("POSTGRES_CONNECTION_RETRIES", 3))) - ) - self.connection_retry_backoff = max( - 0.1, - min(5.0, float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF", 0.5))), - ) - self.connection_retry_backoff_max = max( - self.connection_retry_backoff, - min( - 60.0, - float(os.environ.get("POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", 5.0)), - ), - ) - self.pool_close_timeout = max( - 1.0, min(30.0, float(os.environ.get("POSTGRES_POOL_CLOSE_TIMEOUT", 5.0))) - ) + if self.user is None or self.password is None or self.database is None: + raise ValueError("Missing database user, password, or database") + + # Guard concurrent pool resets + self._pool_reconnect_lock = asyncio.Lock() self._transient_exceptions = ( asyncio.TimeoutError, @@ -117,12 +106,14 @@ class PostgreSQLDB: asyncpg.exceptions.ConnectionFailureError, ) - # Guard concurrent pool resets - self._pool_reconnect_lock = asyncio.Lock() - - if self.user is None or self.password is None or self.database is None: - raise ValueError("Missing database user, password, or database") - + # Connection retry configuration + self.connection_retry_attempts = config["connection_retry_attempts"] + self.connection_retry_backoff = config["connection_retry_backoff"] + self.connection_retry_backoff_max = max( + self.connection_retry_backoff, + config["connection_retry_backoff_max"], + ) + self.pool_close_timeout = config["pool_close_timeout"] logger.info( "PostgreSQL, Retry config: attempts=%s, backoff=%.1fs, backoff_max=%.1fs, pool_close_timeout=%.1fs", self.connection_retry_attempts, @@ -215,9 +206,7 @@ class PostgreSQLDB: # Only add statement_cache_size if it's configured if 
self.statement_cache_size is not None: - connection_params["statement_cache_size"] = int( - self.statement_cache_size - ) + connection_params["statement_cache_size"] = int(self.statement_cache_size) logger.info( f"PostgreSQL, statement LRU cache size set as: {self.statement_cache_size}" ) @@ -376,7 +365,8 @@ class PostgreSQLDB: await self.configure_age(connection, graph_name) elif with_age and not graph_name: raise ValueError("Graph name is required when with_age is True") - + if self.vector_index_type == "VCHORDRQ": + await self.configure_vchordrq(connection) return await operation(connection) @staticmethod @@ -422,6 +412,29 @@ class PostgreSQLDB: ): pass + async def configure_vchordrq(self, connection: asyncpg.Connection) -> None: + """Configure VCHORDRQ extension for vector similarity search. + + Raises: + asyncpg.exceptions.UndefinedObjectError: If VCHORDRQ extension is not installed + asyncpg.exceptions.InvalidParameterValueError: If parameter value is invalid + + Note: + This method does not catch exceptions. Configuration errors will fail-fast, + while transient connection errors will be retried by _run_with_retry. + """ + # Handle probes parameter - only set if non-empty value is provided + if self.vchordrq_probes and str(self.vchordrq_probes).strip(): + await connection.execute(f"SET vchordrq.probes TO '{self.vchordrq_probes}'") + logger.debug(f"PostgreSQL, VCHORDRQ probes set to: {self.vchordrq_probes}") + + # Handle epsilon parameter independently - check for None to allow 0.0 as valid value + if self.vchordrq_epsilon is not None: + await connection.execute(f"SET vchordrq.epsilon TO {self.vchordrq_epsilon}") + logger.debug( + f"PostgreSQL, VCHORDRQ epsilon set to: {self.vchordrq_epsilon}" + ) + async def _migrate_llm_cache_schema(self): """Migrate LLM cache schema: add new columns and remove deprecated mode field""" try: @@ -1156,19 +1169,12 @@ class PostgreSQLDB: f"PostgreSQL, Create vector indexs, type: {self.vector_index_type}" ) try: - if self.vector_index_type == "HNSW": - await self._create_hnsw_vector_indexes() - elif self.vector_index_type == "IVFFLAT": - await self._create_ivfflat_vector_indexes() - elif self.vector_index_type == "FLAT": - logger.warning( - "FLAT index type is not supported by pgvector. Skipping vector index creation. " - "Please use 'HNSW' or 'IVFFLAT' instead." - ) + if self.vector_index_type in ["HNSW", "IVFFLAT", "VCHORDRQ"]: + await self._create_vector_indexes() else: logger.warning( "Doesn't support this vector index type: {self.vector_index_type}. 
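For readers skimming the hunk above, this is a minimal illustrative sketch (not part of the patch) of the per-connection behaviour configure_vchordrq now has: the two GUCs are applied independently, an empty probes value is skipped, epsilon uses an explicit None check so 0.0 stays usable, and nothing is caught, so configuration errors surface immediately. It assumes asyncpg; the function name and arguments are hypothetical.

```python
# Sketch only: mirrors the conditional SET logic added in configure_vchordrq,
# with hypothetical argument names; assumes an already-open asyncpg connection.
import asyncpg


async def apply_vchordrq_settings(
    conn: asyncpg.Connection,
    probes: str | None = None,     # e.g. "10" or "" (empty means: keep server default)
    epsilon: float | None = None,  # e.g. 1.9; None means: keep server default
) -> None:
    # Each parameter is handled independently, so an empty `probes`
    # no longer prevents `epsilon` from being configured.
    if probes and str(probes).strip():
        await conn.execute(f"SET vchordrq.probes TO '{probes}'")
    # Explicit None check keeps 0.0 as a valid epsilon value.
    if epsilon is not None:
        await conn.execute(f"SET vchordrq.epsilon TO {epsilon}")
    # No try/except here: an invalid value or missing extension fails fast,
    # while transient connection errors are retried by the caller.
```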
" - "Supported types: HNSW, IVFFLAT" + "Supported types: HNSW, IVFFLAT, VCHORDRQ" ) except Exception as e: logger.error( @@ -1375,21 +1381,39 @@ class PostgreSQLDB: except Exception as e: logger.warning(f"Failed to create index {index['name']}: {e}") - async def _create_hnsw_vector_indexes(self): + async def _create_vector_indexes(self): vdb_tables = [ "LIGHTRAG_VDB_CHUNKS", "LIGHTRAG_VDB_ENTITY", "LIGHTRAG_VDB_RELATION", ] - embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024)) + create_sql = { + "HNSW": f""" + CREATE INDEX {{vector_index_name}} + ON {{k}} USING hnsw (content_vector vector_cosine_ops) + WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef}) + """, + "IVFFLAT": f""" + CREATE INDEX {{vector_index_name}} + ON {{k}} USING ivfflat (content_vector vector_cosine_ops) + WITH (lists = {self.ivfflat_lists}) + """, + "VCHORDRQ": f""" + CREATE INDEX {{vector_index_name}} + ON {{k}} USING vchordrq (content_vector vector_cosine_ops) + {f'WITH (options = $${self.vchordrq_build_options}$$)' if self.vchordrq_build_options else ''} + """, + } + embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024)) for k in vdb_tables: - vector_index_name = f"idx_{k.lower()}_hnsw_cosine" + vector_index_name = ( + f"idx_{k.lower()}_{self.vector_index_type.lower()}_cosine" + ) check_vector_index_sql = f""" SELECT 1 FROM pg_indexes - WHERE indexname = '{vector_index_name}' - AND tablename = '{k.lower()}' + WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}' """ try: vector_index_exists = await self.query(check_vector_index_sql) @@ -1398,64 +1422,24 @@ class PostgreSQLDB: alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})" await self.execute(alter_sql) logger.debug(f"Ensured vector dimension for {k}") - - create_vector_index_sql = f""" - CREATE INDEX {vector_index_name} - ON {k} USING hnsw (content_vector vector_cosine_ops) - WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef}) - """ - logger.info(f"Creating hnsw index {vector_index_name} on table {k}") - await self.execute(create_vector_index_sql) + logger.info( + f"Creating {self.vector_index_type} index {vector_index_name} on table {k}" + ) + await self.execute( + create_sql[self.vector_index_type].format( + vector_index_name=vector_index_name, k=k + ) + ) logger.info( f"Successfully created vector index {vector_index_name} on table {k}" ) else: logger.info( - f"HNSW vector index {vector_index_name} already exists on table {k}" + f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}" ) except Exception as e: logger.error(f"Failed to create vector index on table {k}, Got: {e}") - async def _create_ivfflat_vector_indexes(self): - vdb_tables = [ - "LIGHTRAG_VDB_CHUNKS", - "LIGHTRAG_VDB_ENTITY", - "LIGHTRAG_VDB_RELATION", - ] - - embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024)) - - for k in vdb_tables: - index_name = f"idx_{k.lower()}_ivfflat_cosine" - check_index_sql = f""" - SELECT 1 FROM pg_indexes - WHERE indexname = '{index_name}' AND tablename = '{k.lower()}' - """ - try: - exists = await self.query(check_index_sql) - if not exists: - # Only set vector dimension when index doesn't exist - alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})" - await self.execute(alter_sql) - logger.debug(f"Ensured vector dimension for {k}") - - create_sql = f""" - CREATE INDEX {index_name} - ON {k} USING ivfflat (content_vector vector_cosine_ops) - WITH (lists = {self.ivfflat_lists}) - """ - logger.info(f"Creating ivfflat 
index {index_name} on table {k}") - await self.execute(create_sql) - logger.info( - f"Successfully created ivfflat index {index_name} on table {k}" - ) - else: - logger.info( - f"Ivfflat vector index {index_name} already exists on table {k}" - ) - except Exception as e: - logger.error(f"Failed to create ivfflat index on {k}: {e}") - async def query( self, sql: str, @@ -1610,6 +1594,20 @@ class ClientManager: config.get("postgres", "ivfflat_lists", fallback="100"), ) ), + "vchordrq_build_options": os.environ.get( + "POSTGRES_VCHORDRQ_BUILD_OPTIONS", + config.get("postgres", "vchordrq_build_options", fallback=""), + ), + "vchordrq_probes": os.environ.get( + "POSTGRES_VCHORDRQ_PROBES", + config.get("postgres", "vchordrq_probes", fallback=""), + ), + "vchordrq_epsilon": float( + os.environ.get( + "POSTGRES_VCHORDRQ_EPSILON", + config.get("postgres", "vchordrq_epsilon", fallback="1.9"), + ) + ), # Server settings for Supabase "server_settings": os.environ.get( "POSTGRES_SERVER_SETTINGS", @@ -1619,6 +1617,49 @@ class ClientManager: "POSTGRES_STATEMENT_CACHE_SIZE", config.get("postgres", "statement_cache_size", fallback=None), ), + # Connection retry configuration + "connection_retry_attempts": min( + 10, + int( + os.environ.get( + "POSTGRES_CONNECTION_RETRIES", + config.get("postgres", "connection_retries", fallback=3), + ) + ), + ), + "connection_retry_backoff": min( + 5.0, + float( + os.environ.get( + "POSTGRES_CONNECTION_RETRY_BACKOFF", + config.get( + "postgres", "connection_retry_backoff", fallback=0.5 + ), + ) + ), + ), + "connection_retry_backoff_max": min( + 60.0, + float( + os.environ.get( + "POSTGRES_CONNECTION_RETRY_BACKOFF_MAX", + config.get( + "postgres", + "connection_retry_backoff_max", + fallback=5.0, + ), + ) + ), + ), + "pool_close_timeout": min( + 30.0, + float( + os.environ.get( + "POSTGRES_POOL_CLOSE_TIMEOUT", + config.get("postgres", "pool_close_timeout", fallback=5.0), + ) + ), + ), } @classmethod @@ -1679,113 +1720,6 @@ class PGKVStorage(BaseKVStorage): self.db = None ################ QUERY METHODS ################ - async def get_all(self) -> dict[str, Any]: - """Get all data from storage - - Returns: - Dictionary containing all stored data - """ - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error( - f"[{self.workspace}] Unknown namespace for get_all: {self.namespace}" - ) - return {} - - sql = f"SELECT * FROM {table_name} WHERE workspace=$1" - params = {"workspace": self.workspace} - - try: - results = await self.db.query(sql, list(params.values()), multirows=True) - - # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - processed_results = {} - for row in results: - create_time = row.get("create_time", 0) - update_time = row.get("update_time", 0) - # Map field names and add cache_type for compatibility - processed_row = { - **row, - "return": row.get("return_value", ""), - "cache_type": row.get("original_prompt", "unknow"), - "original_prompt": row.get("original_prompt", ""), - "chunk_id": row.get("chunk_id"), - "mode": row.get("mode", "default"), - "create_time": create_time, - "update_time": create_time if update_time == 0 else update_time, - } - processed_results[row["id"]] = processed_row - return processed_results - - # For text_chunks namespace, parse llm_cache_list JSON string back to list - if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): - processed_results = {} - for row in results: - 
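The new ClientManager entries above all resolve values the same way: environment variable first, then config.ini, then a built-in default, finally capped with min(). A small standalone helper (hypothetical name, not part of the patch) makes that precedence explicit:

```python
# Sketch only: env var > config.ini > default, with an upper cap via min(),
# matching the pattern used by the new retry/timeout settings.
import os
from configparser import ConfigParser

config = ConfigParser()
config.read("config.ini")


def resolve_capped_float(env_key: str, ini_key: str, default: float, cap: float) -> float:
    raw = os.environ.get(env_key, config.get("postgres", ini_key, fallback=default))
    return min(cap, float(raw))


pool_close_timeout = resolve_capped_float(
    "POSTGRES_POOL_CLOSE_TIMEOUT", "pool_close_timeout", default=5.0, cap=30.0
)
print(pool_close_timeout)
```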
llm_cache_list = row.get("llm_cache_list", []) - if isinstance(llm_cache_list, str): - try: - llm_cache_list = json.loads(llm_cache_list) - except json.JSONDecodeError: - llm_cache_list = [] - row["llm_cache_list"] = llm_cache_list - create_time = row.get("create_time", 0) - update_time = row.get("update_time", 0) - row["create_time"] = create_time - row["update_time"] = ( - create_time if update_time == 0 else update_time - ) - processed_results[row["id"]] = row - return processed_results - - # For FULL_ENTITIES namespace, parse entity_names JSON string back to list - if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES): - processed_results = {} - for row in results: - entity_names = row.get("entity_names", []) - if isinstance(entity_names, str): - try: - entity_names = json.loads(entity_names) - except json.JSONDecodeError: - entity_names = [] - row["entity_names"] = entity_names - create_time = row.get("create_time", 0) - update_time = row.get("update_time", 0) - row["create_time"] = create_time - row["update_time"] = ( - create_time if update_time == 0 else update_time - ) - processed_results[row["id"]] = row - return processed_results - - # For FULL_RELATIONS namespace, parse relation_pairs JSON string back to list - if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_RELATIONS): - processed_results = {} - for row in results: - relation_pairs = row.get("relation_pairs", []) - if isinstance(relation_pairs, str): - try: - relation_pairs = json.loads(relation_pairs) - except json.JSONDecodeError: - relation_pairs = [] - row["relation_pairs"] = relation_pairs - create_time = row.get("create_time", 0) - update_time = row.get("update_time", 0) - row["create_time"] = create_time - row["update_time"] = ( - create_time if update_time == 0 else update_time - ) - processed_results[row["id"]] = row - return processed_results - - # For other namespaces, return as-is - return {row["id"]: row for row in results} - except Exception as e: - logger.error( - f"[{self.workspace}] Error retrieving all data from {self.namespace}: {e}" - ) - return {} - async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get data by id.""" sql = SQL_TEMPLATES["get_by_id_" + self.namespace] @@ -1861,6 +1795,38 @@ class PGKVStorage(BaseKVStorage): response["create_time"] = create_time response["update_time"] = create_time if update_time == 0 else update_time + # Special handling for ENTITY_CHUNKS namespace + if response and is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS): + # Parse chunk_ids JSON string back to list + chunk_ids = response.get("chunk_ids", []) + if isinstance(chunk_ids, str): + try: + chunk_ids = json.loads(chunk_ids) + except json.JSONDecodeError: + chunk_ids = [] + response["chunk_ids"] = chunk_ids + create_time = response.get("create_time", 0) + update_time = response.get("update_time", 0) + response["create_time"] = create_time + response["update_time"] = create_time if update_time == 0 else update_time + + # Special handling for RELATION_CHUNKS namespace + if response and is_namespace( + self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS + ): + # Parse chunk_ids JSON string back to list + chunk_ids = response.get("chunk_ids", []) + if isinstance(chunk_ids, str): + try: + chunk_ids = json.loads(chunk_ids) + except json.JSONDecodeError: + chunk_ids = [] + response["chunk_ids"] = chunk_ids + create_time = response.get("create_time", 0) + update_time = response.get("update_time", 0) + response["create_time"] = create_time + response["update_time"] = create_time if 
update_time == 0 else update_time + return response if response else None # Query by id @@ -1868,7 +1834,7 @@ class PGKVStorage(BaseKVStorage): """Get data by ids""" if not ids: return [] - + sql = SQL_TEMPLATES["get_by_ids_" + self.namespace] params = {"workspace": self.workspace, "ids": ids} results = await self.db.query(sql, list(params.values()), multirows=True) @@ -1969,13 +1935,45 @@ class PGKVStorage(BaseKVStorage): result["create_time"] = create_time result["update_time"] = create_time if update_time == 0 else update_time + # Special handling for ENTITY_CHUNKS namespace + if results and is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS): + for result in results: + # Parse chunk_ids JSON string back to list + chunk_ids = result.get("chunk_ids", []) + if isinstance(chunk_ids, str): + try: + chunk_ids = json.loads(chunk_ids) + except json.JSONDecodeError: + chunk_ids = [] + result["chunk_ids"] = chunk_ids + create_time = result.get("create_time", 0) + update_time = result.get("update_time", 0) + result["create_time"] = create_time + result["update_time"] = create_time if update_time == 0 else update_time + + # Special handling for RELATION_CHUNKS namespace + if results and is_namespace(self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS): + for result in results: + # Parse chunk_ids JSON string back to list + chunk_ids = result.get("chunk_ids", []) + if isinstance(chunk_ids, str): + try: + chunk_ids = json.loads(chunk_ids) + except json.JSONDecodeError: + chunk_ids = [] + result["chunk_ids"] = chunk_ids + create_time = result.get("create_time", 0) + update_time = result.get("update_time", 0) + result["create_time"] = create_time + result["update_time"] = create_time if update_time == 0 else update_time + return _order_results(results) async def filter_keys(self, keys: set[str]) -> set[str]: """Filter out duplicated content""" if not keys: return set() - + table_name = namespace_to_table_name(self.namespace) sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" params = {"workspace": self.workspace, "ids": list(keys)} @@ -2073,11 +2071,61 @@ class PGKVStorage(BaseKVStorage): "update_time": current_time, } await self.db.execute(upsert_sql, _data) + elif is_namespace(self.namespace, NameSpace.KV_STORE_ENTITY_CHUNKS): + # Get current UTC time and convert to naive datetime for database storage + current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None) + for k, v in data.items(): + upsert_sql = SQL_TEMPLATES["upsert_entity_chunks"] + _data = { + "workspace": self.workspace, + "id": k, + "chunk_ids": json.dumps(v["chunk_ids"]), + "count": v["count"], + "create_time": current_time, + "update_time": current_time, + } + await self.db.execute(upsert_sql, _data) + elif is_namespace(self.namespace, NameSpace.KV_STORE_RELATION_CHUNKS): + # Get current UTC time and convert to naive datetime for database storage + current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None) + for k, v in data.items(): + upsert_sql = SQL_TEMPLATES["upsert_relation_chunks"] + _data = { + "workspace": self.workspace, + "id": k, + "chunk_ids": json.dumps(v["chunk_ids"]), + "count": v["count"], + "create_time": current_time, + "update_time": current_time, + } + await self.db.execute(upsert_sql, _data) async def index_done_callback(self) -> None: # PG handles persistence automatically pass + async def is_empty(self) -> bool: + """Check if the storage is empty for the current workspace and namespace + + Returns: + bool: True if storage is empty, False otherwise + """ + 
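Because chunk_ids lives in a JSONB column, the new ENTITY_CHUNKS / RELATION_CHUNKS handling above serialises the list on write and defensively parses it on read. A standalone sketch of that round trip (helper names are illustrative, not from the patch):

```python
# Sketch only: dump chunk_ids on upsert, parse it back on read, and fall back
# to an empty list if the stored value is malformed.
import json
from typing import Any


def encode_chunk_record(chunk_ids: list[str]) -> dict[str, Any]:
    return {"chunk_ids": json.dumps(chunk_ids), "count": len(chunk_ids)}


def decode_chunk_record(row: dict[str, Any]) -> dict[str, Any]:
    chunk_ids = row.get("chunk_ids", [])
    if isinstance(chunk_ids, str):
        try:
            chunk_ids = json.loads(chunk_ids)
        except json.JSONDecodeError:
            chunk_ids = []
    row["chunk_ids"] = chunk_ids
    # An update_time of 0 means "never updated", so fall back to create_time.
    create_time = row.get("create_time", 0)
    update_time = row.get("update_time", 0)
    row["create_time"] = create_time
    row["update_time"] = create_time if update_time == 0 else update_time
    return row


print(decode_chunk_record(encode_chunk_record(["chunk-1", "chunk-2"])))
```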
table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error( + f"[{self.workspace}] Unknown namespace for is_empty check: {self.namespace}" + ) + return True + + sql = f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE workspace=$1 LIMIT 1) as has_data" + + try: + result = await self.db.query(sql, [self.workspace]) + return not result.get("has_data", False) if result else True + except Exception as e: + logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}") + return True + async def delete(self, ids: list[str]) -> None: """Delete specific records from storage by their IDs @@ -2559,7 +2607,7 @@ class PGDocStatusStorage(DocStatusStorage): """Filter out duplicated content""" if not keys: return set() - + table_name = namespace_to_table_name(self.namespace) sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" params = {"workspace": self.workspace, "ids": list(keys)} @@ -2993,6 +3041,28 @@ class PGDocStatusStorage(DocStatusStorage): # PG handles persistence automatically pass + async def is_empty(self) -> bool: + """Check if the storage is empty for the current workspace and namespace + + Returns: + bool: True if storage is empty, False otherwise + """ + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error( + f"[{self.workspace}] Unknown namespace for is_empty check: {self.namespace}" + ) + return True + + sql = f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE workspace=$1 LIMIT 1) as has_data" + + try: + result = await self.db.query(sql, [self.workspace]) + return not result.get("has_data", False) if result else True + except Exception as e: + logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}") + return True + async def delete(self, ids: list[str]) -> None: """Delete specific records from storage by their IDs @@ -3510,17 +3580,13 @@ class PGGraphStorage(BaseGraphStorage): async def get_node(self, node_id: str) -> dict[str, str] | None: """Get node by its label identifier, return only node properties""" - label = self._normalize_node_id(node_id) - - result = await self.get_nodes_batch(node_ids=[label]) + result = await self.get_nodes_batch(node_ids=[node_id]) if result and node_id in result: return result[node_id] return None async def node_degree(self, node_id: str) -> int: - label = self._normalize_node_id(node_id) - - result = await self.node_degrees_batch(node_ids=[label]) + result = await self.node_degrees_batch(node_ids=[node_id]) if result and node_id in result: return result[node_id] @@ -3533,12 +3599,11 @@ class PGGraphStorage(BaseGraphStorage): self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: """Get edge properties between two nodes""" - src_label = self._normalize_node_id(source_node_id) - tgt_label = self._normalize_node_id(target_node_id) - - result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}]) - if result and (src_label, tgt_label) in result: - return result[(src_label, tgt_label)] + result = await self.get_edges_batch( + [{"src": source_node_id, "tgt": target_node_id}] + ) + if result and (source_node_id, target_node_id) in result: + return result[(source_node_id, target_node_id)] return None async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: @@ -3736,13 +3801,17 @@ class PGGraphStorage(BaseGraphStorage): if not node_ids: return {} - seen = set() - unique_ids = [] + seen: set[str] = set() + unique_ids: list[str] = [] + lookup: dict[str, str] = {} + requested: set[str] = 
set() for nid in node_ids: - nid_norm = self._normalize_node_id(nid) - if nid_norm not in seen: - seen.add(nid_norm) - unique_ids.append(nid_norm) + if nid not in seen: + seen.add(nid) + unique_ids.append(nid) + requested.add(nid) + lookup[nid] = nid + lookup[self._normalize_node_id(nid)] = nid # Build result dictionary nodes_dict = {} @@ -3781,10 +3850,18 @@ class PGGraphStorage(BaseGraphStorage): node_dict = json.loads(node_dict) except json.JSONDecodeError: logger.warning( - f"Failed to parse node string in batch: {node_dict}" + f"[{self.workspace}] Failed to parse node string in batch: {node_dict}" ) - nodes_dict[result["node_id"]] = node_dict + node_key = result["node_id"] + original_key = lookup.get(node_key) + if original_key is None: + logger.warning( + f"[{self.workspace}] Node {node_key} not found in lookup map" + ) + original_key = node_key + if original_key in requested: + nodes_dict[original_key] = node_dict return nodes_dict @@ -3807,13 +3884,17 @@ class PGGraphStorage(BaseGraphStorage): if not node_ids: return {} - seen = set() + seen: set[str] = set() unique_ids: list[str] = [] + lookup: dict[str, str] = {} + requested: set[str] = set() for nid in node_ids: - n = self._normalize_node_id(nid) - if n not in seen: - seen.add(n) - unique_ids.append(n) + if nid not in seen: + seen.add(nid) + unique_ids.append(nid) + requested.add(nid) + lookup[nid] = nid + lookup[self._normalize_node_id(nid)] = nid out_degrees = {} in_degrees = {} @@ -3865,8 +3946,16 @@ class PGGraphStorage(BaseGraphStorage): node_id = row["node_id"] if not node_id: continue - out_degrees[node_id] = int(row.get("out_degree", 0) or 0) - in_degrees[node_id] = int(row.get("in_degree", 0) or 0) + node_key = node_id + original_key = lookup.get(node_key) + if original_key is None: + logger.warning( + f"[{self.workspace}] Node {node_key} not found in lookup map" + ) + original_key = node_key + if original_key in requested: + out_degrees[original_key] = int(row.get("out_degree", 0) or 0) + in_degrees[original_key] = int(row.get("in_degree", 0) or 0) degrees_dict = {} for node_id in node_ids: @@ -3995,7 +4084,7 @@ class PGGraphStorage(BaseGraphStorage): edge_props = json.loads(edge_props) except json.JSONDecodeError: logger.warning( - f"Failed to parse edge properties string: {edge_props}" + f"[{self.workspace}]Failed to parse edge properties string: {edge_props}" ) continue @@ -4011,7 +4100,7 @@ class PGGraphStorage(BaseGraphStorage): edge_props = json.loads(edge_props) except json.JSONDecodeError: logger.warning( - f"Failed to parse edge properties string: {edge_props}" + f"[{self.workspace}] Failed to parse edge properties string: {edge_props}" ) continue @@ -4116,102 +4205,6 @@ class PGGraphStorage(BaseGraphStorage): labels.append(result["label"]) return labels - async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - """ - Retrieves nodes from the graph that are associated with a given list of chunk IDs. - This method uses a Cypher query with UNWIND to efficiently find all nodes - where the `source_id` property contains any of the specified chunk IDs. 
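The batch rewrites above stop normalising the caller's ids up front; instead they keep a lookup that maps both the raw and the normalised form back to the requested key, then discard rows the caller never asked for. A compact, self-contained illustration of that remapping (the normalize argument is only a stand-in for _normalize_node_id):

```python
# Sketch only: remap rows keyed by a possibly-normalised node id back to the
# exact ids the caller requested, dropping anything that was not asked for.
from typing import Callable


def remap_batch_results(
    node_ids: list[str],
    db_rows: dict[str, dict],
    normalize: Callable[[str], str],
) -> dict[str, dict]:
    requested = set(node_ids)
    lookup: dict[str, str] = {}
    for nid in node_ids:
        lookup[nid] = nid             # exact form
        lookup[normalize(nid)] = nid  # normalised form points back to the original
    out: dict[str, dict] = {}
    for key, props in db_rows.items():
        original = lookup.get(key, key)
        if original in requested:
            out[original] = props
    return out


# Stand-in normaliser for demonstration purposes only.
print(remap_batch_results(["Alice"], {"alice": {"degree": 3}}, normalize=str.lower))
```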
- """ - # The string representation of the list for the cypher query - chunk_ids_str = json.dumps(chunk_ids) - - query = f""" - SELECT * FROM cypher('{self.graph_name}', $$ - UNWIND {chunk_ids_str} AS chunk_id - MATCH (n:base) - WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}') - RETURN n - $$) AS (n agtype); - """ - results = await self._query(query) - - # Build result list - nodes = [] - for result in results: - if result["n"]: - node_dict = result["n"]["properties"] - - # Process string result, parse it to JSON dictionary - if isinstance(node_dict, str): - try: - node_dict = json.loads(node_dict) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse node string in batch: {node_dict}" - ) - - node_dict["id"] = node_dict["entity_id"] - nodes.append(node_dict) - - return nodes - - async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - """ - Retrieves edges from the graph that are associated with a given list of chunk IDs. - This method uses a Cypher query with UNWIND to efficiently find all edges - where the `source_id` property contains any of the specified chunk IDs. - """ - chunk_ids_str = json.dumps(chunk_ids) - - query = f""" - SELECT * FROM cypher('{self.graph_name}', $$ - UNWIND {chunk_ids_str} AS chunk_id - MATCH ()-[r]-() - WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}') - RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target - $$) AS (edge agtype, source agtype, target agtype); - """ - results = await self._query(query) - edges = [] - if results: - for item in results: - edge_agtype = item["edge"]["properties"] - # Process string result, parse it to JSON dictionary - if isinstance(edge_agtype, str): - try: - edge_agtype = json.loads(edge_agtype) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse edge string in batch: {edge_agtype}" - ) - - source_agtype = item["source"]["properties"] - # Process string result, parse it to JSON dictionary - if isinstance(source_agtype, str): - try: - source_agtype = json.loads(source_agtype) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse node string in batch: {source_agtype}" - ) - - target_agtype = item["target"]["properties"] - # Process string result, parse it to JSON dictionary - if isinstance(target_agtype, str): - try: - target_agtype = json.loads(target_agtype) - except json.JSONDecodeError: - logger.warning( - f"[{self.workspace}] Failed to parse node string in batch: {target_agtype}" - ) - - if edge_agtype and source_agtype and target_agtype: - edge_properties = edge_agtype - edge_properties["source"] = source_agtype["entity_id"] - edge_properties["target"] = target_agtype["entity_id"] - edges.append(edge_properties) - return edges - async def _bfs_subgraph( self, node_label: str, max_depth: int, max_nodes: int ) -> KnowledgeGraph: @@ -4757,6 +4750,8 @@ NAMESPACE_TABLE_MAP = { NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", NameSpace.KV_STORE_FULL_ENTITIES: "LIGHTRAG_FULL_ENTITIES", NameSpace.KV_STORE_FULL_RELATIONS: "LIGHTRAG_FULL_RELATIONS", + NameSpace.KV_STORE_ENTITY_CHUNKS: "LIGHTRAG_ENTITY_CHUNKS", + NameSpace.KV_STORE_RELATION_CHUNKS: "LIGHTRAG_RELATION_CHUNKS", NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE", NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_VDB_CHUNKS", NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY", @@ -4897,6 +4892,28 @@ TABLES = { CONSTRAINT LIGHTRAG_FULL_RELATIONS_PK PRIMARY 
KEY (workspace, id) )""" }, + "LIGHTRAG_ENTITY_CHUNKS": { + "ddl": """CREATE TABLE LIGHTRAG_ENTITY_CHUNKS ( + id VARCHAR(512), + workspace VARCHAR(255), + chunk_ids JSONB, + count INTEGER, + create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT LIGHTRAG_ENTITY_CHUNKS_PK PRIMARY KEY (workspace, id) + )""" + }, + "LIGHTRAG_RELATION_CHUNKS": { + "ddl": """CREATE TABLE LIGHTRAG_RELATION_CHUNKS ( + id VARCHAR(512), + workspace VARCHAR(255), + chunk_ids JSONB, + count INTEGER, + create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT LIGHTRAG_RELATION_CHUNKS_PK PRIMARY KEY (workspace, id) + )""" + }, } @@ -4954,6 +4971,26 @@ SQL_TEMPLATES = { EXTRACT(EPOCH FROM update_time)::BIGINT as update_time FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2) """, + "get_by_id_entity_chunks": """SELECT id, chunk_ids, count, + EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, + EXTRACT(EPOCH FROM update_time)::BIGINT as update_time + FROM LIGHTRAG_ENTITY_CHUNKS WHERE workspace=$1 AND id=$2 + """, + "get_by_id_relation_chunks": """SELECT id, chunk_ids, count, + EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, + EXTRACT(EPOCH FROM update_time)::BIGINT as update_time + FROM LIGHTRAG_RELATION_CHUNKS WHERE workspace=$1 AND id=$2 + """, + "get_by_ids_entity_chunks": """SELECT id, chunk_ids, count, + EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, + EXTRACT(EPOCH FROM update_time)::BIGINT as update_time + FROM LIGHTRAG_ENTITY_CHUNKS WHERE workspace=$1 AND id = ANY($2) + """, + "get_by_ids_relation_chunks": """SELECT id, chunk_ids, count, + EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, + EXTRACT(EPOCH FROM update_time)::BIGINT as update_time + FROM LIGHTRAG_RELATION_CHUNKS WHERE workspace=$1 AND id = ANY($2) + """, "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})", "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace) VALUES ($1, $2, $3, $4) @@ -5001,6 +5038,22 @@ SQL_TEMPLATES = { count=EXCLUDED.count, update_time = EXCLUDED.update_time """, + "upsert_entity_chunks": """INSERT INTO LIGHTRAG_ENTITY_CHUNKS (workspace, id, chunk_ids, count, + create_time, update_time) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (workspace,id) DO UPDATE + SET chunk_ids=EXCLUDED.chunk_ids, + count=EXCLUDED.count, + update_time = EXCLUDED.update_time + """, + "upsert_relation_chunks": """INSERT INTO LIGHTRAG_RELATION_CHUNKS (workspace, id, chunk_ids, count, + create_time, update_time) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (workspace,id) DO UPDATE + SET chunk_ids=EXCLUDED.chunk_ids, + count=EXCLUDED.count, + update_time = EXCLUDED.update_time + """, # SQL for VectorStorage "upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens, chunk_order_index, full_doc_id, content, content_vector, file_path,