fix: include all query parameters in LLM cache hash key generation
- Add missing query parameters (top_k, enable_rerank, max_tokens, etc.) to cache key generation in kg_query, naive_query, and extract_keywords_only functions - Add queryparam field to CacheData structure and PostgreSQL storage for debugging - Update PostgreSQL schema with automatic migration for queryparam JSONB column - Prevent incorrect cache hits between queries with different parameters Fixes issue where different query parameters incorrectly shared the same cached results.
This commit is contained in:
parent
cb75e6631e
commit
0463963520
3 changed files with 135 additions and 12 deletions
|
|
@ -225,14 +225,14 @@ class PostgreSQLDB:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _migrate_llm_cache_add_columns(self):
|
async def _migrate_llm_cache_add_columns(self):
|
||||||
"""Add chunk_id and cache_type columns to LIGHTRAG_LLM_CACHE table if they don't exist"""
|
"""Add chunk_id, cache_type, and queryparam columns to LIGHTRAG_LLM_CACHE table if they don't exist"""
|
||||||
try:
|
try:
|
||||||
# Check if both columns exist
|
# Check if all columns exist
|
||||||
check_columns_sql = """
|
check_columns_sql = """
|
||||||
SELECT column_name
|
SELECT column_name
|
||||||
FROM information_schema.columns
|
FROM information_schema.columns
|
||||||
WHERE table_name = 'lightrag_llm_cache'
|
WHERE table_name = 'lightrag_llm_cache'
|
||||||
AND column_name IN ('chunk_id', 'cache_type')
|
AND column_name IN ('chunk_id', 'cache_type', 'queryparam')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
existing_columns = await self.query(check_columns_sql, multirows=True)
|
existing_columns = await self.query(check_columns_sql, multirows=True)
|
||||||
|
|
@ -289,6 +289,22 @@ class PostgreSQLDB:
|
||||||
"cache_type column already exists in LIGHTRAG_LLM_CACHE table"
|
"cache_type column already exists in LIGHTRAG_LLM_CACHE table"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add missing queryparam column
|
||||||
|
if "queryparam" not in existing_column_names:
|
||||||
|
logger.info("Adding queryparam column to LIGHTRAG_LLM_CACHE table")
|
||||||
|
add_queryparam_sql = """
|
||||||
|
ALTER TABLE LIGHTRAG_LLM_CACHE
|
||||||
|
ADD COLUMN queryparam JSONB NULL
|
||||||
|
"""
|
||||||
|
await self.execute(add_queryparam_sql)
|
||||||
|
logger.info(
|
||||||
|
"Successfully added queryparam column to LIGHTRAG_LLM_CACHE table"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"queryparam column already exists in LIGHTRAG_LLM_CACHE table"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to add columns to LIGHTRAG_LLM_CACHE: {e}")
|
logger.warning(f"Failed to add columns to LIGHTRAG_LLM_CACHE: {e}")
|
||||||
|
|
||||||
|
|
@ -1379,6 +1395,13 @@ class PGKVStorage(BaseKVStorage):
|
||||||
):
|
):
|
||||||
create_time = response.get("create_time", 0)
|
create_time = response.get("create_time", 0)
|
||||||
update_time = response.get("update_time", 0)
|
update_time = response.get("update_time", 0)
|
||||||
|
# Parse queryparam JSON string back to dict
|
||||||
|
queryparam = response.get("queryparam")
|
||||||
|
if isinstance(queryparam, str):
|
||||||
|
try:
|
||||||
|
queryparam = json.loads(queryparam)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
queryparam = None
|
||||||
# Map field names and add cache_type for compatibility
|
# Map field names and add cache_type for compatibility
|
||||||
response = {
|
response = {
|
||||||
**response,
|
**response,
|
||||||
|
|
@ -1387,6 +1410,7 @@ class PGKVStorage(BaseKVStorage):
|
||||||
"original_prompt": response.get("original_prompt", ""),
|
"original_prompt": response.get("original_prompt", ""),
|
||||||
"chunk_id": response.get("chunk_id"),
|
"chunk_id": response.get("chunk_id"),
|
||||||
"mode": response.get("mode", "default"),
|
"mode": response.get("mode", "default"),
|
||||||
|
"queryparam": queryparam,
|
||||||
"create_time": create_time,
|
"create_time": create_time,
|
||||||
"update_time": create_time if update_time == 0 else update_time,
|
"update_time": create_time if update_time == 0 else update_time,
|
||||||
}
|
}
|
||||||
|
|
@ -1455,6 +1479,13 @@ class PGKVStorage(BaseKVStorage):
|
||||||
for row in results:
|
for row in results:
|
||||||
create_time = row.get("create_time", 0)
|
create_time = row.get("create_time", 0)
|
||||||
update_time = row.get("update_time", 0)
|
update_time = row.get("update_time", 0)
|
||||||
|
# Parse queryparam JSON string back to dict
|
||||||
|
queryparam = row.get("queryparam")
|
||||||
|
if isinstance(queryparam, str):
|
||||||
|
try:
|
||||||
|
queryparam = json.loads(queryparam)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
queryparam = None
|
||||||
# Map field names and add cache_type for compatibility
|
# Map field names and add cache_type for compatibility
|
||||||
processed_row = {
|
processed_row = {
|
||||||
**row,
|
**row,
|
||||||
|
|
@ -1463,6 +1494,7 @@ class PGKVStorage(BaseKVStorage):
|
||||||
"original_prompt": row.get("original_prompt", ""),
|
"original_prompt": row.get("original_prompt", ""),
|
||||||
"chunk_id": row.get("chunk_id"),
|
"chunk_id": row.get("chunk_id"),
|
||||||
"mode": row.get("mode", "default"),
|
"mode": row.get("mode", "default"),
|
||||||
|
"queryparam": queryparam,
|
||||||
"create_time": create_time,
|
"create_time": create_time,
|
||||||
"update_time": create_time if update_time == 0 else update_time,
|
"update_time": create_time if update_time == 0 else update_time,
|
||||||
}
|
}
|
||||||
|
|
@ -1570,6 +1602,9 @@ class PGKVStorage(BaseKVStorage):
|
||||||
"cache_type": v.get(
|
"cache_type": v.get(
|
||||||
"cache_type", "extract"
|
"cache_type", "extract"
|
||||||
), # Get cache_type from data
|
), # Get cache_type from data
|
||||||
|
"queryparam": json.dumps(v.get("queryparam"))
|
||||||
|
if v.get("queryparam")
|
||||||
|
else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.db.execute(upsert_sql, _data)
|
await self.db.execute(upsert_sql, _data)
|
||||||
|
|
@ -4054,6 +4089,7 @@ TABLES = {
|
||||||
return_value TEXT,
|
return_value TEXT,
|
||||||
chunk_id VARCHAR(255) NULL,
|
chunk_id VARCHAR(255) NULL,
|
||||||
cache_type VARCHAR(32),
|
cache_type VARCHAR(32),
|
||||||
|
queryparam JSONB NULL,
|
||||||
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id)
|
CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id)
|
||||||
|
|
@ -4114,7 +4150,7 @@ SQL_TEMPLATES = {
|
||||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
|
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
|
||||||
""",
|
""",
|
||||||
"get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
|
"get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, queryparam,
|
||||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
|
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
|
||||||
|
|
@ -4132,7 +4168,7 @@ SQL_TEMPLATES = {
|
||||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
|
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
|
||||||
""",
|
""",
|
||||||
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
|
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, queryparam,
|
||||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
|
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
|
||||||
|
|
@ -4163,14 +4199,15 @@ SQL_TEMPLATES = {
|
||||||
ON CONFLICT (workspace,id) DO UPDATE
|
ON CONFLICT (workspace,id) DO UPDATE
|
||||||
SET content = $2, update_time = CURRENT_TIMESTAMP
|
SET content = $2, update_time = CURRENT_TIMESTAMP
|
||||||
""",
|
""",
|
||||||
"upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type)
|
"upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type,queryparam)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
ON CONFLICT (workspace,mode,id) DO UPDATE
|
ON CONFLICT (workspace,mode,id) DO UPDATE
|
||||||
SET original_prompt = EXCLUDED.original_prompt,
|
SET original_prompt = EXCLUDED.original_prompt,
|
||||||
return_value=EXCLUDED.return_value,
|
return_value=EXCLUDED.return_value,
|
||||||
mode=EXCLUDED.mode,
|
mode=EXCLUDED.mode,
|
||||||
chunk_id=EXCLUDED.chunk_id,
|
chunk_id=EXCLUDED.chunk_id,
|
||||||
cache_type=EXCLUDED.cache_type,
|
cache_type=EXCLUDED.cache_type,
|
||||||
|
queryparam=EXCLUDED.queryparam,
|
||||||
update_time = CURRENT_TIMESTAMP
|
update_time = CURRENT_TIMESTAMP
|
||||||
""",
|
""",
|
||||||
"upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
|
"upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
|
||||||
|
|
|
||||||
|
|
@ -1727,7 +1727,20 @@ async def kg_query(
|
||||||
use_model_func = partial(use_model_func, _priority=5)
|
use_model_func = partial(use_model_func, _priority=5)
|
||||||
|
|
||||||
# Handle cache
|
# Handle cache
|
||||||
args_hash = compute_args_hash(query_param.mode, query)
|
args_hash = compute_args_hash(
|
||||||
|
query_param.mode,
|
||||||
|
query,
|
||||||
|
query_param.response_type,
|
||||||
|
query_param.top_k,
|
||||||
|
query_param.chunk_top_k,
|
||||||
|
query_param.max_entity_tokens,
|
||||||
|
query_param.max_relation_tokens,
|
||||||
|
query_param.max_total_tokens,
|
||||||
|
query_param.hl_keywords or [],
|
||||||
|
query_param.ll_keywords or [],
|
||||||
|
query_param.user_prompt or "",
|
||||||
|
query_param.enable_rerank,
|
||||||
|
)
|
||||||
cached_response, quantized, min_val, max_val = await handle_cache(
|
cached_response, quantized, min_val, max_val = await handle_cache(
|
||||||
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
||||||
)
|
)
|
||||||
|
|
@ -1826,7 +1839,20 @@ async def kg_query(
|
||||||
)
|
)
|
||||||
|
|
||||||
if hashing_kv.global_config.get("enable_llm_cache"):
|
if hashing_kv.global_config.get("enable_llm_cache"):
|
||||||
# Save to cache
|
# Save to cache with query parameters
|
||||||
|
queryparam_dict = {
|
||||||
|
"mode": query_param.mode,
|
||||||
|
"response_type": query_param.response_type,
|
||||||
|
"top_k": query_param.top_k,
|
||||||
|
"chunk_top_k": query_param.chunk_top_k,
|
||||||
|
"max_entity_tokens": query_param.max_entity_tokens,
|
||||||
|
"max_relation_tokens": query_param.max_relation_tokens,
|
||||||
|
"max_total_tokens": query_param.max_total_tokens,
|
||||||
|
"hl_keywords": query_param.hl_keywords or [],
|
||||||
|
"ll_keywords": query_param.ll_keywords or [],
|
||||||
|
"user_prompt": query_param.user_prompt or "",
|
||||||
|
"enable_rerank": query_param.enable_rerank,
|
||||||
|
}
|
||||||
await save_to_cache(
|
await save_to_cache(
|
||||||
hashing_kv,
|
hashing_kv,
|
||||||
CacheData(
|
CacheData(
|
||||||
|
|
@ -1835,6 +1861,7 @@ async def kg_query(
|
||||||
prompt=query,
|
prompt=query,
|
||||||
mode=query_param.mode,
|
mode=query_param.mode,
|
||||||
cache_type="query",
|
cache_type="query",
|
||||||
|
queryparam=queryparam_dict,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1886,7 +1913,20 @@ async def extract_keywords_only(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 1. Handle cache if needed - add cache type for keywords
|
# 1. Handle cache if needed - add cache type for keywords
|
||||||
args_hash = compute_args_hash(param.mode, text)
|
args_hash = compute_args_hash(
|
||||||
|
param.mode,
|
||||||
|
text,
|
||||||
|
param.response_type,
|
||||||
|
param.top_k,
|
||||||
|
param.chunk_top_k,
|
||||||
|
param.max_entity_tokens,
|
||||||
|
param.max_relation_tokens,
|
||||||
|
param.max_total_tokens,
|
||||||
|
param.hl_keywords or [],
|
||||||
|
param.ll_keywords or [],
|
||||||
|
param.user_prompt or "",
|
||||||
|
param.enable_rerank,
|
||||||
|
)
|
||||||
cached_response, quantized, min_val, max_val = await handle_cache(
|
cached_response, quantized, min_val, max_val = await handle_cache(
|
||||||
hashing_kv, args_hash, text, param.mode, cache_type="keywords"
|
hashing_kv, args_hash, text, param.mode, cache_type="keywords"
|
||||||
)
|
)
|
||||||
|
|
@ -1963,6 +2003,20 @@ async def extract_keywords_only(
|
||||||
"low_level_keywords": ll_keywords,
|
"low_level_keywords": ll_keywords,
|
||||||
}
|
}
|
||||||
if hashing_kv.global_config.get("enable_llm_cache"):
|
if hashing_kv.global_config.get("enable_llm_cache"):
|
||||||
|
# Save to cache with query parameters
|
||||||
|
queryparam_dict = {
|
||||||
|
"mode": param.mode,
|
||||||
|
"response_type": param.response_type,
|
||||||
|
"top_k": param.top_k,
|
||||||
|
"chunk_top_k": param.chunk_top_k,
|
||||||
|
"max_entity_tokens": param.max_entity_tokens,
|
||||||
|
"max_relation_tokens": param.max_relation_tokens,
|
||||||
|
"max_total_tokens": param.max_total_tokens,
|
||||||
|
"hl_keywords": param.hl_keywords or [],
|
||||||
|
"ll_keywords": param.ll_keywords or [],
|
||||||
|
"user_prompt": param.user_prompt or "",
|
||||||
|
"enable_rerank": param.enable_rerank,
|
||||||
|
}
|
||||||
await save_to_cache(
|
await save_to_cache(
|
||||||
hashing_kv,
|
hashing_kv,
|
||||||
CacheData(
|
CacheData(
|
||||||
|
|
@ -1971,6 +2025,7 @@ async def extract_keywords_only(
|
||||||
prompt=text,
|
prompt=text,
|
||||||
mode=param.mode,
|
mode=param.mode,
|
||||||
cache_type="keywords",
|
cache_type="keywords",
|
||||||
|
queryparam=queryparam_dict,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2945,7 +3000,20 @@ async def naive_query(
|
||||||
use_model_func = partial(use_model_func, _priority=5)
|
use_model_func = partial(use_model_func, _priority=5)
|
||||||
|
|
||||||
# Handle cache
|
# Handle cache
|
||||||
args_hash = compute_args_hash(query_param.mode, query)
|
args_hash = compute_args_hash(
|
||||||
|
query_param.mode,
|
||||||
|
query,
|
||||||
|
query_param.response_type,
|
||||||
|
query_param.top_k,
|
||||||
|
query_param.chunk_top_k,
|
||||||
|
query_param.max_entity_tokens,
|
||||||
|
query_param.max_relation_tokens,
|
||||||
|
query_param.max_total_tokens,
|
||||||
|
query_param.hl_keywords or [],
|
||||||
|
query_param.ll_keywords or [],
|
||||||
|
query_param.user_prompt or "",
|
||||||
|
query_param.enable_rerank,
|
||||||
|
)
|
||||||
cached_response, quantized, min_val, max_val = await handle_cache(
|
cached_response, quantized, min_val, max_val = await handle_cache(
|
||||||
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
||||||
)
|
)
|
||||||
|
|
@ -3092,7 +3160,20 @@ async def naive_query(
|
||||||
)
|
)
|
||||||
|
|
||||||
if hashing_kv.global_config.get("enable_llm_cache"):
|
if hashing_kv.global_config.get("enable_llm_cache"):
|
||||||
# Save to cache
|
# Save to cache with query parameters
|
||||||
|
queryparam_dict = {
|
||||||
|
"mode": query_param.mode,
|
||||||
|
"response_type": query_param.response_type,
|
||||||
|
"top_k": query_param.top_k,
|
||||||
|
"chunk_top_k": query_param.chunk_top_k,
|
||||||
|
"max_entity_tokens": query_param.max_entity_tokens,
|
||||||
|
"max_relation_tokens": query_param.max_relation_tokens,
|
||||||
|
"max_total_tokens": query_param.max_total_tokens,
|
||||||
|
"hl_keywords": query_param.hl_keywords or [],
|
||||||
|
"ll_keywords": query_param.ll_keywords or [],
|
||||||
|
"user_prompt": query_param.user_prompt or "",
|
||||||
|
"enable_rerank": query_param.enable_rerank,
|
||||||
|
}
|
||||||
await save_to_cache(
|
await save_to_cache(
|
||||||
hashing_kv,
|
hashing_kv,
|
||||||
CacheData(
|
CacheData(
|
||||||
|
|
@ -3101,6 +3182,7 @@ async def naive_query(
|
||||||
prompt=query,
|
prompt=query,
|
||||||
mode=query_param.mode,
|
mode=query_param.mode,
|
||||||
cache_type="query",
|
cache_type="query",
|
||||||
|
queryparam=queryparam_dict,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -793,6 +793,7 @@ class CacheData:
|
||||||
mode: str = "default"
|
mode: str = "default"
|
||||||
cache_type: str = "query"
|
cache_type: str = "query"
|
||||||
chunk_id: str | None = None
|
chunk_id: str | None = None
|
||||||
|
queryparam: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
async def save_to_cache(hashing_kv, cache_data: CacheData):
|
async def save_to_cache(hashing_kv, cache_data: CacheData):
|
||||||
|
|
@ -830,6 +831,9 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
|
||||||
"cache_type": cache_data.cache_type,
|
"cache_type": cache_data.cache_type,
|
||||||
"chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
|
"chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
|
||||||
"original_prompt": cache_data.prompt,
|
"original_prompt": cache_data.prompt,
|
||||||
|
"queryparam": cache_data.queryparam
|
||||||
|
if cache_data.queryparam is not None
|
||||||
|
else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f" == LLM cache == saving: {flattened_key}")
|
logger.info(f" == LLM cache == saving: {flattened_key}")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue