fix: include all query parameters in LLM cache hash key generation
- Add missing query parameters (top_k, enable_rerank, max_tokens, etc.) to cache key generation in the kg_query, naive_query, and extract_keywords_only functions
- Add queryparam field to the CacheData structure and PostgreSQL storage for debugging
- Update PostgreSQL schema with automatic migration for the queryparam JSONB column
- Prevent incorrect cache hits between queries with different parameters

Fixes an issue where different query parameters incorrectly shared the same cached results.
This commit is contained in:
parent cb75e6631e
commit 0463963520

3 changed files with 135 additions and 12 deletions
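The heart of the change: every parameter that can alter the answer now feeds the cache key. A minimal sketch of the idea, assuming compute_args_hash digests the stringified argument tuple (the real helper in lightrag.utils may normalize its arguments differently):

    import hashlib

    def compute_args_hash(*args) -> str:
        # Any argument that differs produces a different digest.
        return hashlib.md5(str(args).encode("utf-8")).hexdigest()

    # Before the fix only (mode, query) fed the key, so a top_k=40 query and a
    # top_k=80 query collided. With the parameters included, the keys diverge:
    key_a = compute_args_hash("hybrid", "What is LightRAG?", "Multiple Paragraphs", 40)
    key_b = compute_args_hash("hybrid", "What is LightRAG?", "Multiple Paragraphs", 80)
    assert key_a != key_b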
@@ -225,14 +225,14 @@ class PostgreSQLDB:
             pass

     async def _migrate_llm_cache_add_columns(self):
-        """Add chunk_id and cache_type columns to LIGHTRAG_LLM_CACHE table if they don't exist"""
+        """Add chunk_id, cache_type, and queryparam columns to LIGHTRAG_LLM_CACHE table if they don't exist"""
         try:
-            # Check if both columns exist
+            # Check if all columns exist
             check_columns_sql = """
                 SELECT column_name
                 FROM information_schema.columns
                 WHERE table_name = 'lightrag_llm_cache'
-                AND column_name IN ('chunk_id', 'cache_type')
+                AND column_name IN ('chunk_id', 'cache_type', 'queryparam')
             """

             existing_columns = await self.query(check_columns_sql, multirows=True)
@@ -289,6 +289,22 @@ class PostgreSQLDB:
                     "cache_type column already exists in LIGHTRAG_LLM_CACHE table"
                 )

+            # Add missing queryparam column
+            if "queryparam" not in existing_column_names:
+                logger.info("Adding queryparam column to LIGHTRAG_LLM_CACHE table")
+                add_queryparam_sql = """
+                    ALTER TABLE LIGHTRAG_LLM_CACHE
+                    ADD COLUMN queryparam JSONB NULL
+                """
+                await self.execute(add_queryparam_sql)
+                logger.info(
+                    "Successfully added queryparam column to LIGHTRAG_LLM_CACHE table"
+                )
+            else:
+                logger.info(
+                    "queryparam column already exists in LIGHTRAG_LLM_CACHE table"
+                )
+
         except Exception as e:
             logger.warning(f"Failed to add columns to LIGHTRAG_LLM_CACHE: {e}")
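For reference, the same check-then-alter migration as a standalone sketch, assuming a bare asyncpg connection (the commit itself routes this through PostgreSQLDB.query()/execute()):

    import asyncpg

    async def ensure_queryparam_column(conn: asyncpg.Connection) -> None:
        rows = await conn.fetch(
            """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = 'lightrag_llm_cache'
              AND column_name IN ('chunk_id', 'cache_type', 'queryparam')
            """
        )
        existing = {r["column_name"] for r in rows}
        if "queryparam" not in existing:
            # JSONB NULL keeps old rows valid; no backfill is required.
            await conn.execute(
                "ALTER TABLE LIGHTRAG_LLM_CACHE ADD COLUMN queryparam JSONB NULL"
            )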
@@ -1379,6 +1395,13 @@ class PGKVStorage(BaseKVStorage):
         ):
             create_time = response.get("create_time", 0)
             update_time = response.get("update_time", 0)
+            # Parse queryparam JSON string back to dict
+            queryparam = response.get("queryparam")
+            if isinstance(queryparam, str):
+                try:
+                    queryparam = json.loads(queryparam)
+                except json.JSONDecodeError:
+                    queryparam = None
             # Map field names and add cache_type for compatibility
             response = {
                 **response,
@@ -1387,6 +1410,7 @@ class PGKVStorage(BaseKVStorage):
                 "original_prompt": response.get("original_prompt", ""),
                 "chunk_id": response.get("chunk_id"),
                 "mode": response.get("mode", "default"),
+                "queryparam": queryparam,
                 "create_time": create_time,
                 "update_time": create_time if update_time == 0 else update_time,
             }
@@ -1455,6 +1479,13 @@ class PGKVStorage(BaseKVStorage):
         for row in results:
             create_time = row.get("create_time", 0)
             update_time = row.get("update_time", 0)
+            # Parse queryparam JSON string back to dict
+            queryparam = row.get("queryparam")
+            if isinstance(queryparam, str):
+                try:
+                    queryparam = json.loads(queryparam)
+                except json.JSONDecodeError:
+                    queryparam = None
             # Map field names and add cache_type for compatibility
             processed_row = {
                 **row,
@@ -1463,6 +1494,7 @@ class PGKVStorage(BaseKVStorage):
                 "original_prompt": row.get("original_prompt", ""),
                 "chunk_id": row.get("chunk_id"),
                 "mode": row.get("mode", "default"),
+                "queryparam": queryparam,
                 "create_time": create_time,
                 "update_time": create_time if update_time == 0 else update_time,
             }
@@ -1570,6 +1602,9 @@ class PGKVStorage(BaseKVStorage):
                     "cache_type": v.get(
                         "cache_type", "extract"
                     ), # Get cache_type from data
+                    "queryparam": json.dumps(v.get("queryparam"))
+                    if v.get("queryparam")
+                    else None,
                 }

                 await self.db.execute(upsert_sql, _data)
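Write and read sides form a round trip: json.dumps into the JSONB column on upsert, json.loads back to a dict on fetch, with malformed payloads degrading to None instead of raising. A self-contained sketch of that contract:

    import json

    def encode_queryparam(queryparam: dict | None) -> str | None:
        # Mirrors the upsert path: empty/None dicts are stored as SQL NULL.
        return json.dumps(queryparam) if queryparam else None

    def decode_queryparam(raw) -> dict | None:
        # Mirrors the fetch path: bad JSON degrades to None, never raises.
        if isinstance(raw, str):
            try:
                return json.loads(raw)
            except json.JSONDecodeError:
                return None
        return raw  # already decoded (or None), depending on the driver

    params = {"mode": "hybrid", "top_k": 40, "enable_rerank": True}
    assert decode_queryparam(encode_queryparam(params)) == params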
@@ -4054,6 +4089,7 @@ TABLES = {
                     return_value TEXT,
                     chunk_id VARCHAR(255) NULL,
                     cache_type VARCHAR(32),
+                    queryparam JSONB NULL,
                     create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id)
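Because queryparam is JSONB, the stored parameters are directly queryable for the debugging purpose named in the commit message. A hypothetical inspection query (not part of the commit) using PostgreSQL's ->> operator:

    # Hypothetical debugging query: list recent cache entries produced with
    # reranking enabled, showing the top_k each one was cached under.
    # ->> extracts a JSON field as text; the cast narrows to boolean.
    DEBUG_CACHE_SQL = """
        SELECT id, mode, queryparam->>'top_k' AS top_k
        FROM LIGHTRAG_LLM_CACHE
        WHERE workspace = $1
          AND (queryparam->>'enable_rerank')::BOOLEAN IS TRUE
        ORDER BY update_time DESC
        LIMIT 20
    """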
@@ -4114,7 +4150,7 @@ SQL_TEMPLATES = {
         EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
         FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
     """,
-    "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
+    "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, queryparam,
         EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
         EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
         FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
@@ -4132,7 +4168,7 @@ SQL_TEMPLATES = {
         EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
         FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
     """,
-    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
+    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, queryparam,
         EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
         EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
         FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
@@ -4163,14 +4199,15 @@ SQL_TEMPLATES = {
         ON CONFLICT (workspace,id) DO UPDATE
         SET content = $2, update_time = CURRENT_TIMESTAMP
     """,
-    "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type)
-        VALUES ($1, $2, $3, $4, $5, $6, $7)
+    "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type,queryparam)
+        VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
         ON CONFLICT (workspace,mode,id) DO UPDATE
         SET original_prompt = EXCLUDED.original_prompt,
             return_value=EXCLUDED.return_value,
             mode=EXCLUDED.mode,
             chunk_id=EXCLUDED.chunk_id,
             cache_type=EXCLUDED.cache_type,
+            queryparam=EXCLUDED.queryparam,
             update_time = CURRENT_TIMESTAMP
     """,
     "upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
@@ -1727,7 +1727,20 @@ async def kg_query(
         use_model_func = partial(use_model_func, _priority=5)

     # Handle cache
-    args_hash = compute_args_hash(query_param.mode, query)
+    args_hash = compute_args_hash(
+        query_param.mode,
+        query,
+        query_param.response_type,
+        query_param.top_k,
+        query_param.chunk_top_k,
+        query_param.max_entity_tokens,
+        query_param.max_relation_tokens,
+        query_param.max_total_tokens,
+        query_param.hl_keywords or [],
+        query_param.ll_keywords or [],
+        query_param.user_prompt or "",
+        query_param.enable_rerank,
+    )
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
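This same twelve-argument tuple recurs verbatim in extract_keywords_only and naive_query below; a hypothetical helper factoring it out is sketched here. Note the fallbacks such as "query_param.hl_keywords or []", which keep None and an empty list hashing identically instead of splitting the cache.

    def cache_key_args(query_text: str, p) -> tuple:
        # Hypothetical helper mirroring the argument order of all three call sites.
        return (
            p.mode,
            query_text,
            p.response_type,
            p.top_k,
            p.chunk_top_k,
            p.max_entity_tokens,
            p.max_relation_tokens,
            p.max_total_tokens,
            p.hl_keywords or [],  # None and [] hash identically after this
            p.ll_keywords or [],
            p.user_prompt or "",
            p.enable_rerank,
        )

    # Hypothetical usage: args_hash = compute_args_hash(*cache_key_args(query, query_param))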
@@ -1826,7 +1839,20 @@ async def kg_query(
         )

     if hashing_kv.global_config.get("enable_llm_cache"):
-        # Save to cache
+        # Save to cache with query parameters
+        queryparam_dict = {
+            "mode": query_param.mode,
+            "response_type": query_param.response_type,
+            "top_k": query_param.top_k,
+            "chunk_top_k": query_param.chunk_top_k,
+            "max_entity_tokens": query_param.max_entity_tokens,
+            "max_relation_tokens": query_param.max_relation_tokens,
+            "max_total_tokens": query_param.max_total_tokens,
+            "hl_keywords": query_param.hl_keywords or [],
+            "ll_keywords": query_param.ll_keywords or [],
+            "user_prompt": query_param.user_prompt or "",
+            "enable_rerank": query_param.enable_rerank,
+        }
         await save_to_cache(
             hashing_kv,
             CacheData(
@@ -1835,6 +1861,7 @@ async def kg_query(
                 prompt=query,
                 mode=query_param.mode,
                 cache_type="query",
+                queryparam=queryparam_dict,
             ),
         )
@@ -1886,7 +1913,20 @@ async def extract_keywords_only(
     """

     # 1. Handle cache if needed - add cache type for keywords
-    args_hash = compute_args_hash(param.mode, text)
+    args_hash = compute_args_hash(
+        param.mode,
+        text,
+        param.response_type,
+        param.top_k,
+        param.chunk_top_k,
+        param.max_entity_tokens,
+        param.max_relation_tokens,
+        param.max_total_tokens,
+        param.hl_keywords or [],
+        param.ll_keywords or [],
+        param.user_prompt or "",
+        param.enable_rerank,
+    )
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, text, param.mode, cache_type="keywords"
     )
@@ -1963,6 +2003,20 @@ async def extract_keywords_only(
         "low_level_keywords": ll_keywords,
     }
     if hashing_kv.global_config.get("enable_llm_cache"):
+        # Save to cache with query parameters
+        queryparam_dict = {
+            "mode": param.mode,
+            "response_type": param.response_type,
+            "top_k": param.top_k,
+            "chunk_top_k": param.chunk_top_k,
+            "max_entity_tokens": param.max_entity_tokens,
+            "max_relation_tokens": param.max_relation_tokens,
+            "max_total_tokens": param.max_total_tokens,
+            "hl_keywords": param.hl_keywords or [],
+            "ll_keywords": param.ll_keywords or [],
+            "user_prompt": param.user_prompt or "",
+            "enable_rerank": param.enable_rerank,
+        }
         await save_to_cache(
             hashing_kv,
             CacheData(
@@ -1971,6 +2025,7 @@ async def extract_keywords_only(
                 prompt=text,
                 mode=param.mode,
                 cache_type="keywords",
+                queryparam=queryparam_dict,
             ),
         )
@@ -2945,7 +3000,20 @@ async def naive_query(
         use_model_func = partial(use_model_func, _priority=5)

     # Handle cache
-    args_hash = compute_args_hash(query_param.mode, query)
+    args_hash = compute_args_hash(
+        query_param.mode,
+        query,
+        query_param.response_type,
+        query_param.top_k,
+        query_param.chunk_top_k,
+        query_param.max_entity_tokens,
+        query_param.max_relation_tokens,
+        query_param.max_total_tokens,
+        query_param.hl_keywords or [],
+        query_param.ll_keywords or [],
+        query_param.user_prompt or "",
+        query_param.enable_rerank,
+    )
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
@@ -3092,7 +3160,20 @@ async def naive_query(
         )

     if hashing_kv.global_config.get("enable_llm_cache"):
-        # Save to cache
+        # Save to cache with query parameters
+        queryparam_dict = {
+            "mode": query_param.mode,
+            "response_type": query_param.response_type,
+            "top_k": query_param.top_k,
+            "chunk_top_k": query_param.chunk_top_k,
+            "max_entity_tokens": query_param.max_entity_tokens,
+            "max_relation_tokens": query_param.max_relation_tokens,
+            "max_total_tokens": query_param.max_total_tokens,
+            "hl_keywords": query_param.hl_keywords or [],
+            "ll_keywords": query_param.ll_keywords or [],
+            "user_prompt": query_param.user_prompt or "",
+            "enable_rerank": query_param.enable_rerank,
+        }
         await save_to_cache(
             hashing_kv,
             CacheData(
@@ -3101,6 +3182,7 @@ async def naive_query(
                 prompt=query,
                 mode=query_param.mode,
                 cache_type="query",
+                queryparam=queryparam_dict,
             ),
         )
@@ -793,6 +793,7 @@ class CacheData:
     mode: str = "default"
     cache_type: str = "query"
     chunk_id: str | None = None
+    queryparam: dict | None = None


 async def save_to_cache(hashing_kv, cache_data: CacheData):
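A simplified stand-in for the dataclass above shows why existing call sites keep working: queryparam defaults to None, so code that never passes it is unaffected. The leading args_hash/content/prompt fields are elided in the hunk and assumed here:

    from dataclasses import dataclass

    @dataclass
    class CacheDataSketch:
        # args_hash/content/prompt are assumed; the defaults below match the hunk.
        args_hash: str
        content: str
        prompt: str
        mode: str = "default"
        cache_type: str = "query"
        chunk_id: str | None = None
        queryparam: dict | None = None

    entry = CacheDataSketch(
        args_hash="abc123",
        content="cached LLM answer",
        prompt="What is LightRAG?",
        mode="hybrid",
        queryparam={"top_k": 40, "enable_rerank": True},
    )
    assert entry.chunk_id is None and entry.queryparam["top_k"] == 40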
@@ -830,6 +831,9 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         "cache_type": cache_data.cache_type,
         "chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
         "original_prompt": cache_data.prompt,
+        "queryparam": cache_data.queryparam
+        if cache_data.queryparam is not None
+        else None,
     }

     logger.info(f" == LLM cache == saving: {flattened_key}")