diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 66f0dd6c..06079d71 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -225,14 +225,14 @@ class PostgreSQLDB:
             pass
 
     async def _migrate_llm_cache_add_columns(self):
-        """Add chunk_id and cache_type columns to LIGHTRAG_LLM_CACHE table if they don't exist"""
+        """Add chunk_id, cache_type, and queryparam columns to LIGHTRAG_LLM_CACHE table if they don't exist"""
         try:
-            # Check if both columns exist
+            # Check if all columns exist
            check_columns_sql = """
             SELECT column_name
             FROM information_schema.columns
             WHERE table_name = 'lightrag_llm_cache'
-            AND column_name IN ('chunk_id', 'cache_type')
+            AND column_name IN ('chunk_id', 'cache_type', 'queryparam')
             """
 
             existing_columns = await self.query(check_columns_sql, multirows=True)
@@ -289,6 +289,22 @@ class PostgreSQLDB:
                     "cache_type column already exists in LIGHTRAG_LLM_CACHE table"
                 )
 
+            # Add missing queryparam column
+            if "queryparam" not in existing_column_names:
+                logger.info("Adding queryparam column to LIGHTRAG_LLM_CACHE table")
+                add_queryparam_sql = """
+                ALTER TABLE LIGHTRAG_LLM_CACHE
+                ADD COLUMN queryparam JSONB NULL
+                """
+                await self.execute(add_queryparam_sql)
+                logger.info(
+                    "Successfully added queryparam column to LIGHTRAG_LLM_CACHE table"
+                )
+            else:
+                logger.info(
+                    "queryparam column already exists in LIGHTRAG_LLM_CACHE table"
+                )
+
         except Exception as e:
             logger.warning(f"Failed to add columns to LIGHTRAG_LLM_CACHE: {e}")
 
@@ -1379,6 +1395,13 @@ class PGKVStorage(BaseKVStorage):
         ):
             create_time = response.get("create_time", 0)
             update_time = response.get("update_time", 0)
+            # Parse queryparam JSON string back to dict
+            queryparam = response.get("queryparam")
+            if isinstance(queryparam, str):
+                try:
+                    queryparam = json.loads(queryparam)
+                except json.JSONDecodeError:
+                    queryparam = None
             # Map field names and add cache_type for compatibility
             response = {
                 **response,
@@ -1387,6 +1410,7 @@
                 "original_prompt": response.get("original_prompt", ""),
                 "chunk_id": response.get("chunk_id"),
                 "mode": response.get("mode", "default"),
+                "queryparam": queryparam,
                 "create_time": create_time,
                 "update_time": create_time if update_time == 0 else update_time,
             }
@@ -1455,6 +1479,13 @@
         for row in results:
             create_time = row.get("create_time", 0)
             update_time = row.get("update_time", 0)
+            # Parse queryparam JSON string back to dict
+            queryparam = row.get("queryparam")
+            if isinstance(queryparam, str):
+                try:
+                    queryparam = json.loads(queryparam)
+                except json.JSONDecodeError:
+                    queryparam = None
             # Map field names and add cache_type for compatibility
             processed_row = {
                 **row,
@@ -1463,6 +1494,7 @@
                 "original_prompt": row.get("original_prompt", ""),
                 "chunk_id": row.get("chunk_id"),
                 "mode": row.get("mode", "default"),
+                "queryparam": queryparam,
                 "create_time": create_time,
                 "update_time": create_time if update_time == 0 else update_time,
             }
@@ -1570,6 +1602,9 @@
                     "cache_type": v.get(
                         "cache_type", "extract"
                     ),  # Get cache_type from data
+                    "queryparam": json.dumps(v.get("queryparam"))
+                    if v.get("queryparam")
+                    else None,
                 }
 
                 await self.db.execute(upsert_sql, _data)
@@ -4054,6 +4089,7 @@ TABLES = {
                     return_value TEXT,
                     chunk_id VARCHAR(255) NULL,
                     cache_type VARCHAR(32),
+                    queryparam JSONB NULL,
                     create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id)
@@ -4114,7 +4150,7 @@ SQL_TEMPLATES = {
                                 EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                 FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
                                 """,
-    "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
+    "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, queryparam,
                                 EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                 EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                 FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
@@ -4132,7 +4168,7 @@ SQL_TEMPLATES = {
                                 EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                 FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
                                 """,
-    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
+    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, queryparam,
                                 EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                                 EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                                 FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
@@ -4163,14 +4199,15 @@ SQL_TEMPLATES = {
                                       ON CONFLICT (workspace,id) DO UPDATE
                                       SET content = $2, update_time = CURRENT_TIMESTAMP
                                       """,
-    "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type)
-                                      VALUES ($1, $2, $3, $4, $5, $6, $7)
+    "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type,queryparam)
+                                      VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                                       ON CONFLICT (workspace,mode,id) DO UPDATE
                                       SET original_prompt = EXCLUDED.original_prompt,
                                       return_value=EXCLUDED.return_value,
                                       mode=EXCLUDED.mode,
                                       chunk_id=EXCLUDED.chunk_id,
                                       cache_type=EXCLUDED.cache_type,
+                                      queryparam=EXCLUDED.queryparam,
                                       update_time = CURRENT_TIMESTAMP
                                       """,
     "upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
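The postgres_impl.py changes above store queryparam in a JSONB column that is written as a JSON string and read back as one. As a quick sanity check, here is a minimal roundtrip sketch assuming asyncpg (consistent with the $1-style placeholders above); the DSN and row values are hypothetical, and only the table and column names come from the diff:

import asyncio
import json

import asyncpg


async def main() -> None:
    # Hypothetical DSN; adjust for a real database.
    conn = await asyncpg.connect("postgresql://postgres:postgres@localhost:5432/lightrag")
    try:
        # asyncpg's default JSONB codec exchanges Python *strings*, which is
        # why the storage layer json.dumps() on write and json.loads() on read.
        # Assumes the columns not listed here are nullable.
        await conn.execute(
            """INSERT INTO LIGHTRAG_LLM_CACHE (workspace, mode, id, queryparam)
               VALUES ($1, $2, $3, $4)""",
            "demo",
            "mix",
            "cache-key-1",
            json.dumps({"top_k": 60, "enable_rerank": True}),
        )
        row = await conn.fetchrow(
            "SELECT queryparam FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2",
            "demo",
            "cache-key-1",
        )
        print(json.loads(row["queryparam"]))  # {'top_k': 60, 'enable_rerank': True}
    finally:
        await conn.close()


asyncio.run(main())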
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 254dfdac..1d398ad3 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1727,7 +1727,20 @@ async def kg_query(
     use_model_func = partial(use_model_func, _priority=5)
 
     # Handle cache
-    args_hash = compute_args_hash(query_param.mode, query)
+    args_hash = compute_args_hash(
+        query_param.mode,
+        query,
+        query_param.response_type,
+        query_param.top_k,
+        query_param.chunk_top_k,
+        query_param.max_entity_tokens,
+        query_param.max_relation_tokens,
+        query_param.max_total_tokens,
+        query_param.hl_keywords or [],
+        query_param.ll_keywords or [],
+        query_param.user_prompt or "",
+        query_param.enable_rerank,
+    )
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
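Folding every tuning knob into compute_args_hash means two otherwise identical queries that differ only in, say, top_k no longer share a cache entry. compute_args_hash itself lives in lightrag/utils.py and is unchanged by this diff; the stand-in below is only an illustration of the keying behavior, not LightRAG's implementation:

import hashlib


def args_hash_sketch(*args) -> str:
    # Illustrative stand-in: hash the string form of every argument, so any
    # changed parameter (mode, query text, top_k, token budgets, keywords,
    # rerank flag) produces a different cache key.
    joined = "".join(str(a) for a in args)
    return hashlib.md5(joined.encode("utf-8")).hexdigest()


# Same mode and query text, different top_k -> distinct cache keys:
assert args_hash_sketch("mix", "what is RAG?", 60) != args_hash_sketch("mix", "what is RAG?", 40)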
@@ -1826,7 +1839,20 @@
     )
 
     if hashing_kv.global_config.get("enable_llm_cache"):
-        # Save to cache
+        # Save to cache with query parameters
+        queryparam_dict = {
+            "mode": query_param.mode,
+            "response_type": query_param.response_type,
+            "top_k": query_param.top_k,
+            "chunk_top_k": query_param.chunk_top_k,
+            "max_entity_tokens": query_param.max_entity_tokens,
+            "max_relation_tokens": query_param.max_relation_tokens,
+            "max_total_tokens": query_param.max_total_tokens,
+            "hl_keywords": query_param.hl_keywords or [],
+            "ll_keywords": query_param.ll_keywords or [],
+            "user_prompt": query_param.user_prompt or "",
+            "enable_rerank": query_param.enable_rerank,
+        }
         await save_to_cache(
             hashing_kv,
             CacheData(
@@ -1835,6 +1861,7 @@ async def kg_query(
                 prompt=query,
                 mode=query_param.mode,
                 cache_type="query",
+                queryparam=queryparam_dict,
             ),
         )
 
@@ -1886,7 +1913,20 @@ async def extract_keywords_only(
     """
 
     # 1. Handle cache if needed - add cache type for keywords
-    args_hash = compute_args_hash(param.mode, text)
+    args_hash = compute_args_hash(
+        param.mode,
+        text,
+        param.response_type,
+        param.top_k,
+        param.chunk_top_k,
+        param.max_entity_tokens,
+        param.max_relation_tokens,
+        param.max_total_tokens,
+        param.hl_keywords or [],
+        param.ll_keywords or [],
+        param.user_prompt or "",
+        param.enable_rerank,
+    )
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, text, param.mode, cache_type="keywords"
     )
@@ -1963,6 +2003,20 @@ async def extract_keywords_only(
         "low_level_keywords": ll_keywords,
     }
     if hashing_kv.global_config.get("enable_llm_cache"):
+        # Save to cache with query parameters
+        queryparam_dict = {
+            "mode": param.mode,
+            "response_type": param.response_type,
+            "top_k": param.top_k,
+            "chunk_top_k": param.chunk_top_k,
+            "max_entity_tokens": param.max_entity_tokens,
+            "max_relation_tokens": param.max_relation_tokens,
+            "max_total_tokens": param.max_total_tokens,
+            "hl_keywords": param.hl_keywords or [],
+            "ll_keywords": param.ll_keywords or [],
+            "user_prompt": param.user_prompt or "",
+            "enable_rerank": param.enable_rerank,
+        }
         await save_to_cache(
             hashing_kv,
             CacheData(
@@ -1971,6 +2025,7 @@
                 prompt=text,
                 mode=param.mode,
                 cache_type="keywords",
+                queryparam=queryparam_dict,
             ),
         )
 
@@ -2945,7 +3000,20 @@ async def naive_query(
     use_model_func = partial(use_model_func, _priority=5)
 
     # Handle cache
-    args_hash = compute_args_hash(query_param.mode, query)
+    args_hash = compute_args_hash(
+        query_param.mode,
+        query,
+        query_param.response_type,
+        query_param.top_k,
+        query_param.chunk_top_k,
+        query_param.max_entity_tokens,
+        query_param.max_relation_tokens,
+        query_param.max_total_tokens,
+        query_param.hl_keywords or [],
+        query_param.ll_keywords or [],
+        query_param.user_prompt or "",
+        query_param.enable_rerank,
+    )
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
@@ -3092,7 +3160,20 @@
     )
 
     if hashing_kv.global_config.get("enable_llm_cache"):
-        # Save to cache
+        # Save to cache with query parameters
+        queryparam_dict = {
+            "mode": query_param.mode,
+            "response_type": query_param.response_type,
+            "top_k": query_param.top_k,
+            "chunk_top_k": query_param.chunk_top_k,
+            "max_entity_tokens": query_param.max_entity_tokens,
+            "max_relation_tokens": query_param.max_relation_tokens,
+            "max_total_tokens": query_param.max_total_tokens,
+            "hl_keywords": query_param.hl_keywords or [],
+            "ll_keywords": query_param.ll_keywords or [],
+            "user_prompt": query_param.user_prompt or "",
+            "enable_rerank": query_param.enable_rerank,
+        }
         await save_to_cache(
             hashing_kv,
             CacheData(
@@ -3101,6 +3182,7 @@
                 prompt=query,
                 mode=query_param.mode,
                 cache_type="query",
+                queryparam=queryparam_dict,
            ),
         )
 
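The operate.py hunks above repeat the same queryparam_dict literal in kg_query, extract_keywords_only, and naive_query. If that duplication ever needs consolidating, a hypothetical helper (not part of this diff) could derive the dict from any QueryParam-like object; the field list below is taken verbatim from the hunks above:

def build_queryparam_dict(param) -> dict:
    # param is a QueryParam-like object; keys mirror the literals in the diff.
    return {
        "mode": param.mode,
        "response_type": param.response_type,
        "top_k": param.top_k,
        "chunk_top_k": param.chunk_top_k,
        "max_entity_tokens": param.max_entity_tokens,
        "max_relation_tokens": param.max_relation_tokens,
        "max_total_tokens": param.max_total_tokens,
        "hl_keywords": param.hl_keywords or [],
        "ll_keywords": param.ll_keywords or [],
        "user_prompt": param.user_prompt or "",
        "enable_rerank": param.enable_rerank,
    }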
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 9e818d6b..1eb8c98d 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -793,6 +793,7 @@ class CacheData:
     mode: str = "default"
     cache_type: str = "query"
     chunk_id: str | None = None
+    queryparam: dict | None = None
 
 
 async def save_to_cache(hashing_kv, cache_data: CacheData):
@@ -830,6 +831,9 @@
         "cache_type": cache_data.cache_type,
         "chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
         "original_prompt": cache_data.prompt,
+        "queryparam": cache_data.queryparam
+        if cache_data.queryparam is not None
+        else None,
     }
 
     logger.info(f" == LLM cache == saving: {flattened_key}")
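End to end, queryparam is serialized with json.dumps on upsert and recovered with json.loads on read, with malformed payloads degrading to None. A self-contained sketch of that contract, mirroring the PGKVStorage logic above rather than importing it:

import json


def encode_queryparam(qp: dict | None) -> str | None:
    # Write path: dicts are serialized; falsy values are stored as SQL NULL.
    return json.dumps(qp) if qp else None


def decode_queryparam(raw: str | dict | None) -> dict | None:
    # Read path: JSON strings are parsed back; unparseable payloads become None.
    if isinstance(raw, str):
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return None
    return raw


assert decode_queryparam(encode_queryparam({"top_k": 60})) == {"top_k": 60}
assert decode_queryparam(encode_queryparam(None)) is None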