fix: include all query parameters in LLM cache hash key generation

- Add the missing query parameters (top_k, enable_rerank, max_tokens, etc.) to cache key generation in the kg_query, naive_query, and extract_keywords_only functions
- Add queryparam field to CacheData structure and PostgreSQL storage for debugging
- Update PostgreSQL schema with automatic migration for queryparam JSONB column
- Prevent incorrect cache hits between queries with different parameters

Fixes an issue where different query parameters incorrectly shared the same cached results.
This commit is contained in:
yangdx 2025-08-05 18:03:10 +08:00
parent cb75e6631e
commit 0463963520
3 changed files with 135 additions and 12 deletions

View file

@ -225,14 +225,14 @@ class PostgreSQLDB:
pass
async def _migrate_llm_cache_add_columns(self):
"""Add chunk_id and cache_type columns to LIGHTRAG_LLM_CACHE table if they don't exist"""
"""Add chunk_id, cache_type, and queryparam columns to LIGHTRAG_LLM_CACHE table if they don't exist"""
try:
# Check if both columns exist
# Check if all columns exist
check_columns_sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'lightrag_llm_cache'
AND column_name IN ('chunk_id', 'cache_type')
AND column_name IN ('chunk_id', 'cache_type', 'queryparam')
"""
existing_columns = await self.query(check_columns_sql, multirows=True)
@ -289,6 +289,22 @@ class PostgreSQLDB:
"cache_type column already exists in LIGHTRAG_LLM_CACHE table"
)
# Add missing queryparam column
if "queryparam" not in existing_column_names:
logger.info("Adding queryparam column to LIGHTRAG_LLM_CACHE table")
add_queryparam_sql = """
ALTER TABLE LIGHTRAG_LLM_CACHE
ADD COLUMN queryparam JSONB NULL
"""
await self.execute(add_queryparam_sql)
logger.info(
"Successfully added queryparam column to LIGHTRAG_LLM_CACHE table"
)
else:
logger.info(
"queryparam column already exists in LIGHTRAG_LLM_CACHE table"
)
except Exception as e:
logger.warning(f"Failed to add columns to LIGHTRAG_LLM_CACHE: {e}")
@ -1379,6 +1395,13 @@ class PGKVStorage(BaseKVStorage):
):
create_time = response.get("create_time", 0)
update_time = response.get("update_time", 0)
# Parse queryparam JSON string back to dict
queryparam = response.get("queryparam")
if isinstance(queryparam, str):
try:
queryparam = json.loads(queryparam)
except json.JSONDecodeError:
queryparam = None
# Map field names and add cache_type for compatibility
response = {
**response,
@ -1387,6 +1410,7 @@ class PGKVStorage(BaseKVStorage):
"original_prompt": response.get("original_prompt", ""),
"chunk_id": response.get("chunk_id"),
"mode": response.get("mode", "default"),
"queryparam": queryparam,
"create_time": create_time,
"update_time": create_time if update_time == 0 else update_time,
}
@ -1455,6 +1479,13 @@ class PGKVStorage(BaseKVStorage):
for row in results:
create_time = row.get("create_time", 0)
update_time = row.get("update_time", 0)
# Parse queryparam JSON string back to dict
queryparam = row.get("queryparam")
if isinstance(queryparam, str):
try:
queryparam = json.loads(queryparam)
except json.JSONDecodeError:
queryparam = None
# Map field names and add cache_type for compatibility
processed_row = {
**row,
@ -1463,6 +1494,7 @@ class PGKVStorage(BaseKVStorage):
"original_prompt": row.get("original_prompt", ""),
"chunk_id": row.get("chunk_id"),
"mode": row.get("mode", "default"),
"queryparam": queryparam,
"create_time": create_time,
"update_time": create_time if update_time == 0 else update_time,
}
@ -1570,6 +1602,9 @@ class PGKVStorage(BaseKVStorage):
"cache_type": v.get(
"cache_type", "extract"
), # Get cache_type from data
"queryparam": json.dumps(v.get("queryparam"))
if v.get("queryparam")
else None,
}
await self.db.execute(upsert_sql, _data)
@ -4054,6 +4089,7 @@ TABLES = {
return_value TEXT,
chunk_id VARCHAR(255) NULL,
cache_type VARCHAR(32),
queryparam JSONB NULL,
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id)
@ -4114,7 +4150,7 @@ SQL_TEMPLATES = {
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
""",
"get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
"get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, queryparam,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
@ -4132,7 +4168,7 @@ SQL_TEMPLATES = {
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
""",
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, queryparam,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
@ -4163,14 +4199,15 @@ SQL_TEMPLATES = {
ON CONFLICT (workspace,id) DO UPDATE
SET content = $2, update_time = CURRENT_TIMESTAMP
""",
"upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type)
VALUES ($1, $2, $3, $4, $5, $6, $7)
"upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type,queryparam)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (workspace,mode,id) DO UPDATE
SET original_prompt = EXCLUDED.original_prompt,
return_value=EXCLUDED.return_value,
mode=EXCLUDED.mode,
chunk_id=EXCLUDED.chunk_id,
cache_type=EXCLUDED.cache_type,
queryparam=EXCLUDED.queryparam,
update_time = CURRENT_TIMESTAMP
""",
"upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,

View file

@ -1727,7 +1727,20 @@ async def kg_query(
use_model_func = partial(use_model_func, _priority=5)
# Handle cache
args_hash = compute_args_hash(query_param.mode, query)
args_hash = compute_args_hash(
query_param.mode,
query,
query_param.response_type,
query_param.top_k,
query_param.chunk_top_k,
query_param.max_entity_tokens,
query_param.max_relation_tokens,
query_param.max_total_tokens,
query_param.hl_keywords or [],
query_param.ll_keywords or [],
query_param.user_prompt or "",
query_param.enable_rerank,
)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
@ -1826,7 +1839,20 @@ async def kg_query(
)
if hashing_kv.global_config.get("enable_llm_cache"):
# Save to cache
# Save to cache with query parameters
queryparam_dict = {
"mode": query_param.mode,
"response_type": query_param.response_type,
"top_k": query_param.top_k,
"chunk_top_k": query_param.chunk_top_k,
"max_entity_tokens": query_param.max_entity_tokens,
"max_relation_tokens": query_param.max_relation_tokens,
"max_total_tokens": query_param.max_total_tokens,
"hl_keywords": query_param.hl_keywords or [],
"ll_keywords": query_param.ll_keywords or [],
"user_prompt": query_param.user_prompt or "",
"enable_rerank": query_param.enable_rerank,
}
await save_to_cache(
hashing_kv,
CacheData(
@ -1835,6 +1861,7 @@ async def kg_query(
prompt=query,
mode=query_param.mode,
cache_type="query",
queryparam=queryparam_dict,
),
)
@ -1886,7 +1913,20 @@ async def extract_keywords_only(
"""
# 1. Handle cache if needed - add cache type for keywords
args_hash = compute_args_hash(param.mode, text)
args_hash = compute_args_hash(
param.mode,
text,
param.response_type,
param.top_k,
param.chunk_top_k,
param.max_entity_tokens,
param.max_relation_tokens,
param.max_total_tokens,
param.hl_keywords or [],
param.ll_keywords or [],
param.user_prompt or "",
param.enable_rerank,
)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, text, param.mode, cache_type="keywords"
)
@ -1963,6 +2003,20 @@ async def extract_keywords_only(
"low_level_keywords": ll_keywords,
}
if hashing_kv.global_config.get("enable_llm_cache"):
# Save to cache with query parameters
queryparam_dict = {
"mode": param.mode,
"response_type": param.response_type,
"top_k": param.top_k,
"chunk_top_k": param.chunk_top_k,
"max_entity_tokens": param.max_entity_tokens,
"max_relation_tokens": param.max_relation_tokens,
"max_total_tokens": param.max_total_tokens,
"hl_keywords": param.hl_keywords or [],
"ll_keywords": param.ll_keywords or [],
"user_prompt": param.user_prompt or "",
"enable_rerank": param.enable_rerank,
}
await save_to_cache(
hashing_kv,
CacheData(
@ -1971,6 +2025,7 @@ async def extract_keywords_only(
prompt=text,
mode=param.mode,
cache_type="keywords",
queryparam=queryparam_dict,
),
)
@ -2945,7 +3000,20 @@ async def naive_query(
use_model_func = partial(use_model_func, _priority=5)
# Handle cache
args_hash = compute_args_hash(query_param.mode, query)
args_hash = compute_args_hash(
query_param.mode,
query,
query_param.response_type,
query_param.top_k,
query_param.chunk_top_k,
query_param.max_entity_tokens,
query_param.max_relation_tokens,
query_param.max_total_tokens,
query_param.hl_keywords or [],
query_param.ll_keywords or [],
query_param.user_prompt or "",
query_param.enable_rerank,
)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
@ -3092,7 +3160,20 @@ async def naive_query(
)
if hashing_kv.global_config.get("enable_llm_cache"):
# Save to cache
# Save to cache with query parameters
queryparam_dict = {
"mode": query_param.mode,
"response_type": query_param.response_type,
"top_k": query_param.top_k,
"chunk_top_k": query_param.chunk_top_k,
"max_entity_tokens": query_param.max_entity_tokens,
"max_relation_tokens": query_param.max_relation_tokens,
"max_total_tokens": query_param.max_total_tokens,
"hl_keywords": query_param.hl_keywords or [],
"ll_keywords": query_param.ll_keywords or [],
"user_prompt": query_param.user_prompt or "",
"enable_rerank": query_param.enable_rerank,
}
await save_to_cache(
hashing_kv,
CacheData(
@ -3101,6 +3182,7 @@ async def naive_query(
prompt=query,
mode=query_param.mode,
cache_type="query",
queryparam=queryparam_dict,
),
)

View file

@ -793,6 +793,7 @@ class CacheData:
mode: str = "default"
cache_type: str = "query"
chunk_id: str | None = None
queryparam: dict | None = None
async def save_to_cache(hashing_kv, cache_data: CacheData):
@ -830,6 +831,9 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
"cache_type": cache_data.cache_type,
"chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
"original_prompt": cache_data.prompt,
"queryparam": cache_data.queryparam
if cache_data.queryparam is not None
else None,
}
logger.info(f" == LLM cache == saving: {flattened_key}")