feat: Flatten LLM cache structure for improved recall efficiency
Refactored the LLM cache to a flat key-value (KV) structure, replacing the previous nested format in which the cache mode served as the top-level key and the individual cache entries were stored as JSON nested beneath it. Cache entries are now addressed directly by a flattened key of the form {mode}:{cache_type}:{hash}, which makes recall a single key lookup and significantly improves cache recall efficiency.
parent b32c3825cc
commit 271722405f

12 changed files with 836 additions and 375 deletions
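For illustration, a minimal sketch of how one cache entry is stored before and after the change (the mode, hash, and field values are hypothetical; the field names follow the cache entries visible in the diff below):

    # Legacy nested structure: the mode is the top-level key, entries are nested under it.
    legacy_cache = {
        "default": {
            "abc123": {
                "return": "...LLM response...",
                "cache_type": "extract",
                "original_prompt": "...",
            }
        }
    }

    # Flattened structure: one top-level key per entry, "{mode}:{cache_type}:{hash}".
    flat_cache = {
        "default:extract:abc123": {
            "return": "...LLM response...",
            "cache_type": "extract",
            "original_prompt": "...",
        }
    }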
@@ -52,18 +52,23 @@ async def copy_from_postgres_to_json():
         embedding_func=None,
     )
 
+    # Get all cache data using the new flattened structure
+    all_data = await from_llm_response_cache.get_all()
+
+    # Convert flattened data to hierarchical structure for JsonKVStorage
     kv = {}
-    for c_id in await from_llm_response_cache.all_keys():
-        print(f"Copying {c_id}")
-        workspace = c_id["workspace"]
-        mode = c_id["mode"]
-        _id = c_id["id"]
-        postgres_db.workspace = workspace
-        obj = await from_llm_response_cache.get_by_mode_and_id(mode, _id)
-        if mode not in kv:
-            kv[mode] = {}
-        kv[mode][_id] = obj[_id]
-        print(f"Object {obj}")
+    for flattened_key, cache_entry in all_data.items():
+        # Parse flattened key: {mode}:{cache_type}:{hash}
+        parts = flattened_key.split(":", 2)
+        if len(parts) == 3:
+            mode, cache_type, hash_value = parts
+            if mode not in kv:
+                kv[mode] = {}
+            kv[mode][hash_value] = cache_entry
+            print(f"Copying {flattened_key} -> {mode}[{hash_value}]")
+        else:
+            print(f"Skipping invalid key format: {flattened_key}")
 
     await to_llm_response_cache.upsert(kv)
     await to_llm_response_cache.index_done_callback()
     print("Mission accomplished!")
@@ -85,13 +90,24 @@ async def copy_from_json_to_postgres():
         db=postgres_db,
     )
 
-    for mode in await from_llm_response_cache.all_keys():
-        print(f"Copying {mode}")
-        caches = await from_llm_response_cache.get_by_id(mode)
-        for k, v in caches.items():
-            item = {mode: {k: v}}
-            print(f"\tCopying {item}")
-            await to_llm_response_cache.upsert(item)
+    # Get all cache data from JsonKVStorage (hierarchical structure)
+    all_data = await from_llm_response_cache.get_all()
+
+    # Convert hierarchical data to flattened structure for PGKVStorage
+    flattened_data = {}
+    for mode, mode_data in all_data.items():
+        print(f"Processing mode: {mode}")
+        for hash_value, cache_entry in mode_data.items():
+            # Determine cache_type from cache entry or use default
+            cache_type = cache_entry.get("cache_type", "extract")
+            # Create flattened key: {mode}:{cache_type}:{hash}
+            flattened_key = f"{mode}:{cache_type}:{hash_value}"
+            flattened_data[flattened_key] = cache_entry
+            print(f"\tConverting {mode}[{hash_value}] -> {flattened_key}")
+
+    # Upsert the flattened data
+    await to_llm_response_cache.upsert(flattened_data)
+    print("Mission accomplished!")
 
 
 if __name__ == "__main__":
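Both conversion scripts rely on the same flattened-key convention. A small, self-contained sketch of the round trip (the helper names are illustrative, not part of the codebase):

    def to_flattened_key(mode: str, cache_type: str, hash_value: str) -> str:
        # Build the flattened cache key used by the new storage layout.
        return f"{mode}:{cache_type}:{hash_value}"

    def from_flattened_key(key: str) -> tuple[str, str, str] | None:
        # Split a flattened key back into (mode, cache_type, hash); None if malformed.
        parts = key.split(":", 2)
        return (parts[0], parts[1], parts[2]) if len(parts) == 3 else None

    assert from_flattened_key(to_flattened_key("default", "extract", "abc123")) == (
        "default",
        "extract",
        "abc123",
    )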
@@ -37,6 +37,7 @@ STORAGE_IMPLEMENTATIONS = {
     "DOC_STATUS_STORAGE": {
         "implementations": [
             "JsonDocStatusStorage",
+            "RedisDocStatusStorage",
             "PGDocStatusStorage",
             "MongoDocStatusStorage",
         ],
@@ -79,6 +80,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
     "MongoVectorDBStorage": [],
     # Document Status Storage Implementations
     "JsonDocStatusStorage": [],
+    "RedisDocStatusStorage": ["REDIS_URI"],
     "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
     "MongoDocStatusStorage": [],
 }
@@ -96,6 +98,7 @@ STORAGES = {
     "MongoGraphStorage": ".kg.mongo_impl",
     "MongoVectorDBStorage": ".kg.mongo_impl",
     "RedisKVStorage": ".kg.redis_impl",
+    "RedisDocStatusStorage": ".kg.redis_impl",
     "ChromaVectorDBStorage": ".kg.chroma_impl",
     # "TiDBKVStorage": ".kg.tidb_impl",
     # "TiDBVectorDBStorage": ".kg.tidb_impl",
@@ -109,7 +109,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             raise
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
@@ -234,7 +234,6 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             ids: List of vector IDs to be deleted
         """
         try:
-            logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
             self._collection.delete(ids=ids)
             logger.debug(
                 f"Successfully deleted {len(ids)} vectors from {self.namespace}"
@@ -42,19 +42,14 @@ class JsonKVStorage(BaseKVStorage):
         if need_init:
             loaded_data = load_json(self._file_name) or {}
             async with self._storage_lock:
-                self._data.update(loaded_data)
-
-                # Calculate data count based on namespace
-                if self.namespace.endswith("cache"):
-                    # For cache namespaces, sum the cache entries across all cache types
-                    data_count = sum(
-                        len(first_level_dict)
-                        for first_level_dict in loaded_data.values()
-                        if isinstance(first_level_dict, dict)
-                    )
-                else:
-                    # For non-cache namespaces, use the original count method
-                    data_count = len(loaded_data)
+                # Migrate legacy cache structure if needed
+                if self.namespace.endswith("_cache"):
+                    loaded_data = await self._migrate_legacy_cache_structure(
+                        loaded_data
+                    )
+
+                self._data.update(loaded_data)
+                data_count = len(loaded_data)
 
                 logger.info(
                     f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
@@ -67,17 +62,8 @@ class JsonKVStorage(BaseKVStorage):
                 dict(self._data) if hasattr(self._data, "_getvalue") else self._data
             )
 
-            # Calculate data count based on namespace
-            if self.namespace.endswith("cache"):
-                # # For cache namespaces, sum the cache entries across all cache types
-                data_count = sum(
-                    len(first_level_dict)
-                    for first_level_dict in data_dict.values()
-                    if isinstance(first_level_dict, dict)
-                )
-            else:
-                # For non-cache namespaces, use the original count method
-                data_count = len(data_dict)
+            # Calculate data count - all data is now flattened
+            data_count = len(data_dict)
 
             logger.debug(
                 f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
@@ -150,14 +136,14 @@ class JsonKVStorage(BaseKVStorage):
             await set_all_update_flags(self.namespace)
 
     async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
-        """Delete specific records from storage by by cache mode
+        """Delete specific records from storage by cache mode
 
         Importance notes for in-memory storage:
         1. Changes will be persisted to disk during the next index_done_callback
         2. update flags to notify other processes that data persistence is needed
 
         Args:
-            ids (list[str]): List of cache mode to be drop from storage
+            modes (list[str]): List of cache modes to be dropped from storage
 
         Returns:
             True: if the cache drop successfully
@@ -167,9 +153,29 @@ class JsonKVStorage(BaseKVStorage):
             return False
 
         try:
-            await self.delete(modes)
+            async with self._storage_lock:
+                keys_to_delete = []
+                modes_set = set(modes)  # Convert to set for efficient lookup
+
+                for key in list(self._data.keys()):
+                    # Parse flattened cache key: mode:cache_type:hash
+                    parts = key.split(":", 2)
+                    if len(parts) == 3 and parts[0] in modes_set:
+                        keys_to_delete.append(key)
+
+                # Batch delete
+                for key in keys_to_delete:
+                    self._data.pop(key, None)
+
+                if keys_to_delete:
+                    await set_all_update_flags(self.namespace)
+                    logger.info(
+                        f"Dropped {len(keys_to_delete)} cache entries for modes: {modes}"
+                    )
+
             return True
-        except Exception:
+        except Exception as e:
+            logger.error(f"Error dropping cache by modes: {e}")
             return False
 
     # async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
@@ -245,9 +251,58 @@ class JsonKVStorage(BaseKVStorage):
             logger.error(f"Error dropping {self.namespace}: {e}")
             return {"status": "error", "message": str(e)}
 
+    async def _migrate_legacy_cache_structure(self, data: dict) -> dict:
+        """Migrate legacy nested cache structure to flattened structure
+
+        Args:
+            data: Original data dictionary that may contain legacy structure
+
+        Returns:
+            Migrated data dictionary with flattened cache keys
+        """
+        from lightrag.utils import generate_cache_key
+
+        # Early return if data is empty
+        if not data:
+            return data
+
+        # Check first entry to see if it's already in new format
+        first_key = next(iter(data.keys()))
+        if ":" in first_key and len(first_key.split(":")) == 3:
+            # Already in flattened format, return as-is
+            return data
+
+        migrated_data = {}
+        migration_count = 0
+
+        for key, value in data.items():
+            # Check if this is a legacy nested cache structure
+            if isinstance(value, dict) and all(
+                isinstance(v, dict) and "return" in v for v in value.values()
+            ):
+                # This looks like a legacy cache mode with nested structure
+                mode = key
+                for cache_hash, cache_entry in value.items():
+                    cache_type = cache_entry.get("cache_type", "extract")
+                    flattened_key = generate_cache_key(mode, cache_type, cache_hash)
+                    migrated_data[flattened_key] = cache_entry
+                    migration_count += 1
+            else:
+                # Keep non-cache data or already flattened cache data as-is
+                migrated_data[key] = value
+
+        if migration_count > 0:
+            logger.info(
+                f"Migrated {migration_count} legacy cache entries to flattened structure"
+            )
+            # Persist migrated data immediately
+            write_json(migrated_data, self._file_name)
+
+        return migrated_data
+
     async def finalize(self):
         """Finalize storage resources
         Persistence cache data to disk before exiting
         """
-        if self.namespace.endswith("cache"):
+        if self.namespace.endswith("_cache"):
             await self.index_done_callback()
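The migration helper above imports generate_cache_key from lightrag.utils. Its implementation is not shown in this diff, but judging from how the conversion script builds keys and how other call sites parse them with key.split(":", 2), it is assumed to reduce to roughly the following sketch (the real implementation may differ):

    def generate_cache_key(mode: str, cache_type: str, cache_hash: str) -> str:
        # Assumed behavior, inferred from this diff: "{mode}:{cache_type}:{hash}".
        return f"{mode}:{cache_type}:{cache_hash}"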
@@ -75,7 +75,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         )
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
@@ -15,7 +15,6 @@ from ..base import (
     DocStatus,
     DocStatusStorage,
 )
-from ..namespace import NameSpace, is_namespace
 from ..utils import logger, compute_mdhash_id
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 from ..constants import GRAPH_FIELD_SEP
@@ -98,17 +97,8 @@ class MongoKVStorage(BaseKVStorage):
         self._data = None
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        if id == "default":
-            # Find all documents with _id starting with "default_"
-            cursor = self._data.find({"_id": {"$regex": "^default_"}})
-            result = {}
-            async for doc in cursor:
-                # Use the complete _id as key
-                result[doc["_id"]] = doc
-            return result if result else None
-        else:
-            # Original behavior for non-"default" ids
-            return await self._data.find_one({"_id": id})
+        # Unified handling for flattened keys
+        return await self._data.find_one({"_id": id})
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         cursor = self._data.find({"_id": {"$in": ids}})
@@ -133,43 +123,21 @@ class MongoKVStorage(BaseKVStorage):
         return result
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
 
-        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
-            update_tasks: list[Any] = []
-            for mode, items in data.items():
-                for k, v in items.items():
-                    key = f"{mode}_{k}"
-                    data[mode][k]["_id"] = f"{mode}_{k}"
-                    update_tasks.append(
-                        self._data.update_one(
-                            {"_id": key}, {"$setOnInsert": v}, upsert=True
-                        )
-                    )
-            await asyncio.gather(*update_tasks)
-        else:
-            update_tasks = []
-            for k, v in data.items():
-                data[k]["_id"] = k
-                update_tasks.append(
-                    self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
-                )
-            await asyncio.gather(*update_tasks)
-
-    async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
-        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
-            res = {}
-            v = await self._data.find_one({"_id": mode + "_" + id})
-            if v:
-                res[id] = v
-                logger.debug(f"llm_response_cache find one by:{id}")
-                return res
-            else:
-                return None
-        else:
-            return None
+        # Unified handling for all namespaces with flattened keys
+        # Use bulk_write for better performance
+        from pymongo import UpdateOne
+
+        operations = []
+        for k, v in data.items():
+            v["_id"] = k  # Use flattened key as _id
+            operations.append(UpdateOne({"_id": k}, {"$set": v}, upsert=True))
+
+        if operations:
+            await self._data.bulk_write(operations)
 
     async def index_done_callback(self) -> None:
         # Mongo handles persistence automatically
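The MongoDB upsert now collects all documents into a single bulk_write call instead of issuing one update_one coroutine per document and awaiting them with asyncio.gather. A minimal sketch of the pattern, assuming an async (motor-style) collection object:

    from pymongo import UpdateOne

    async def bulk_upsert(collection, data: dict[str, dict]) -> None:
        # One round trip for the whole batch instead of one request per document.
        operations = [
            UpdateOne({"_id": key}, {"$set": {**value, "_id": key}}, upsert=True)
            for key, value in data.items()
        ]
        if operations:
            await collection.bulk_write(operations)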
@@ -209,8 +177,8 @@ class MongoKVStorage(BaseKVStorage):
             return False
 
         try:
-            # Build regex pattern to match documents with the specified modes
-            pattern = f"^({'|'.join(modes)})_"
+            # Build regex pattern to match flattened key format: mode:cache_type:hash
+            pattern = f"^({'|'.join(modes)}):"
             result = await self._data.delete_many({"_id": {"$regex": pattern}})
             logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
             return True
@@ -274,7 +242,7 @@ class MongoDocStatusStorage(DocStatusStorage):
         return data - existing_ids
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
         update_tasks: list[Any] = []
@@ -1282,7 +1250,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
             logger.debug("vector index already exist")
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
@@ -1371,7 +1339,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
         Args:
             ids: List of vector IDs to be deleted
         """
-        logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
+        logger.debug(f"Deleting {len(ids)} vectors from {self.namespace}")
         if not ids:
             return
@@ -247,6 +247,116 @@ class PostgreSQLDB:
             logger.error(f"Failed during data migration to LIGHTRAG_VDB_CHUNKS: {e}")
             # Do not re-raise, to allow the application to start
 
+    async def _check_llm_cache_needs_migration(self):
+        """Check if LLM cache data needs migration by examining the first record"""
+        try:
+            # Only query the first record to determine format
+            check_sql = """
+                SELECT id FROM LIGHTRAG_LLM_CACHE
+                ORDER BY create_time ASC
+                LIMIT 1
+            """
+            result = await self.query(check_sql)
+
+            if result and result.get("id"):
+                # If id doesn't contain colon, it's old format
+                return ":" not in result["id"]
+
+            return False  # No data or already new format
+        except Exception as e:
+            logger.warning(f"Failed to check LLM cache migration status: {e}")
+            return False
+
+    async def _migrate_llm_cache_to_flattened_keys(self):
+        """Migrate LLM cache to flattened key format, recalculating hash values"""
+        try:
+            # Get all old format data
+            old_data_sql = """
+                SELECT id, mode, original_prompt, return_value, chunk_id,
+                       create_time, update_time
+                FROM LIGHTRAG_LLM_CACHE
+                WHERE id NOT LIKE '%:%'
+            """
+
+            old_records = await self.query(old_data_sql, multirows=True)
+
+            if not old_records:
+                logger.info("No old format LLM cache data found, skipping migration")
+                return
+
+            logger.info(
+                f"Found {len(old_records)} old format cache records, starting migration..."
+            )
+
+            # Import hash calculation function
+            from ..utils import compute_args_hash
+
+            migrated_count = 0
+
+            # Migrate data in batches
+            for record in old_records:
+                try:
+                    # Recalculate hash using correct method
+                    new_hash = compute_args_hash(
+                        record["mode"], record["original_prompt"]
+                    )
+
+                    # Generate new flattened key
+                    cache_type = "extract"  # Default type
+                    new_key = f"{record['mode']}:{cache_type}:{new_hash}"
+
+                    # Insert new format data
+                    insert_sql = """
+                        INSERT INTO LIGHTRAG_LLM_CACHE
+                        (workspace, id, mode, original_prompt, return_value, chunk_id, create_time, update_time)
+                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
+                        ON CONFLICT (workspace, mode, id) DO NOTHING
+                    """
+
+                    await self.execute(
+                        insert_sql,
+                        {
+                            "workspace": self.workspace,
+                            "id": new_key,
+                            "mode": record["mode"],
+                            "original_prompt": record["original_prompt"],
+                            "return_value": record["return_value"],
+                            "chunk_id": record["chunk_id"],
+                            "create_time": record["create_time"],
+                            "update_time": record["update_time"],
+                        },
+                    )
+
+                    # Delete old data
+                    delete_sql = """
+                        DELETE FROM LIGHTRAG_LLM_CACHE
+                        WHERE workspace=$1 AND mode=$2 AND id=$3
+                    """
+                    await self.execute(
+                        delete_sql,
+                        {
+                            "workspace": self.workspace,
+                            "mode": record["mode"],
+                            "id": record["id"],  # Old id
+                        },
+                    )
+
+                    migrated_count += 1
+
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to migrate cache record {record['id']}: {e}"
+                    )
+                    continue
+
+            logger.info(
+                f"Successfully migrated {migrated_count} cache records to flattened format"
+            )
+
+        except Exception as e:
+            logger.error(f"LLM cache migration failed: {e}")
+            # Don't raise exception, allow system to continue startup
+
     async def check_tables(self):
         # First create all tables
         for k, v in TABLES.items():
@@ -304,6 +414,13 @@ class PostgreSQLDB:
         except Exception as e:
             logger.error(f"PostgreSQL, Failed to migrate doc_chunks to vdb_chunks: {e}")
 
+        # Check and migrate LLM cache to flattened keys if needed
+        try:
+            if await self._check_llm_cache_needs_migration():
+                await self._migrate_llm_cache_to_flattened_keys()
+        except Exception as e:
+            logger.error(f"PostgreSQL, LLM cache migration failed: {e}")
+
     async def query(
         self,
         sql: str,
@@ -486,77 +603,48 @@ class PGKVStorage(BaseKVStorage):
 
         try:
             results = await self.db.query(sql, params, multirows=True)
 
+            # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
             if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
-                result_dict = {}
+                processed_results = {}
                 for row in results:
-                    mode = row["mode"]
-                    if mode not in result_dict:
-                        result_dict[mode] = {}
-                    result_dict[mode][row["id"]] = row
-                return result_dict
-            else:
-                return {row["id"]: row for row in results}
+                    # Parse flattened key to extract cache_type
+                    key_parts = row["id"].split(":")
+                    cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown"
+
+                    # Map field names and add cache_type for compatibility
+                    processed_row = {
+                        **row,
+                        "return": row.get("return_value", ""),  # Map return_value to return
+                        "cache_type": cache_type,  # Add cache_type from key
+                        "original_prompt": row.get("original_prompt", ""),
+                        "chunk_id": row.get("chunk_id"),
+                        "mode": row.get("mode", "default")
+                    }
+                    processed_results[row["id"]] = processed_row
+                return processed_results
+
+            # For other namespaces, return as-is
+            return {row["id"]: row for row in results}
         except Exception as e:
             logger.error(f"Error retrieving all data from {self.namespace}: {e}")
             return {}
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        """Get doc_full data by id."""
+        """Get data by id."""
         sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
-        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
-            # For LLM cache, the id parameter actually represents the mode
-            params = {"workspace": self.db.workspace, "mode": id}
-            array_res = await self.db.query(sql, params, multirows=True)
-            res = {}
-            for row in array_res:
-                # Dynamically add cache_type field based on mode
-                row_with_cache_type = dict(row)
-                if id == "default":
-                    row_with_cache_type["cache_type"] = "extract"
-                else:
-                    row_with_cache_type["cache_type"] = "unknown"
-                res[row["id"]] = row_with_cache_type
-            return res if res else None
-        else:
-            params = {"workspace": self.db.workspace, "id": id}
-            response = await self.db.query(sql, params)
-            return response if response else None
-
-    async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
-        """Specifically for llm_response_cache."""
-        sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
-        params = {"workspace": self.db.workspace, "mode": mode, "id": id}
-        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
-            array_res = await self.db.query(sql, params, multirows=True)
-            res = {}
-            for row in array_res:
-                res[row["id"]] = row
-            return res
-        else:
-            return None
+        params = {"workspace": self.db.workspace, "id": id}
+        response = await self.db.query(sql, params)
+        return response if response else None
 
     # Query by id
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        """Get doc_chunks data by id"""
+        """Get data by ids"""
         sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
         )
         params = {"workspace": self.db.workspace}
-        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
-            array_res = await self.db.query(sql, params, multirows=True)
-            modes = set()
-            dict_res: dict[str, dict] = {}
-            for row in array_res:
-                modes.add(row["mode"])
-            for mode in modes:
-                if mode not in dict_res:
-                    dict_res[mode] = {}
-            for row in array_res:
-                dict_res[row["mode"]][row["id"]] = row
-            return [{k: v} for k, v in dict_res.items()]
-        else:
-            return await self.db.query(sql, params, multirows=True)
+        return await self.db.query(sql, params, multirows=True)
 
     async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         """Specifically for llm_response_cache."""
@@ -617,19 +705,18 @@ class PGKVStorage(BaseKVStorage):
                 }
                 await self.db.execute(upsert_sql, _data)
         elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
-            for mode, items in data.items():
-                for k, v in items.items():
-                    upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
-                    _data = {
-                        "workspace": self.db.workspace,
-                        "id": k,
-                        "original_prompt": v["original_prompt"],
-                        "return_value": v["return"],
-                        "mode": mode,
-                        "chunk_id": v.get("chunk_id"),
-                    }
-
-                    await self.db.execute(upsert_sql, _data)
+            for k, v in data.items():
+                upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
+                _data = {
+                    "workspace": self.db.workspace,
+                    "id": k,  # Use flattened key as id
+                    "original_prompt": v["original_prompt"],
+                    "return_value": v["return"],
+                    "mode": v.get("mode", "default"),  # Get mode from data
+                    "chunk_id": v.get("chunk_id"),
+                }
+
+                await self.db.execute(upsert_sql, _data)
 
     async def index_done_callback(self) -> None:
         # PG handles persistence automatically
@@ -1035,8 +1122,8 @@ class PGDocStatusStorage(DocStatusStorage):
             else:
                 exist_keys = []
             new_keys = set([s for s in keys if s not in exist_keys])
-            print(f"keys: {keys}")
-            print(f"new_keys: {new_keys}")
+            # print(f"keys: {keys}")
+            # print(f"new_keys: {new_keys}")
             return new_keys
         except Exception as e:
             logger.error(
@@ -2621,7 +2708,7 @@ SQL_TEMPLATES = {
         FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
     """,
     "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
-        FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2
+        FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
     """,
     "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
         FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
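With the template change above, get_by_id_llm_response_cache now filters on the flattened id rather than on the mode. A hypothetical illustration of the updated lookup (workspace and key values invented; the call shape follows PGKVStorage.get_by_id in this diff):

    sql = SQL_TEMPLATES["get_by_id_llm_response_cache"]
    # The flattened "{mode}:{cache_type}:{hash}" string is passed as the id parameter.
    row = await db.query(sql, {"workspace": "default", "id": "default:extract:abc123"})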
@@ -85,7 +85,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         )
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
@@ -1,9 +1,10 @@
 import os
-from typing import Any, final
+from typing import Any, final, Union
 from dataclasses import dataclass
 import pipmaster as pm
 import configparser
 from contextlib import asynccontextmanager
+import threading
 
 if not pm.is_installed("redis"):
     pm.install("redis")
@@ -13,7 +14,7 @@ from redis.asyncio import Redis, ConnectionPool  # type: ignore
 from redis.exceptions import RedisError, ConnectionError  # type: ignore
 from lightrag.utils import logger
 
-from lightrag.base import BaseKVStorage
+from lightrag.base import BaseKVStorage, DocStatusStorage, DocStatus, DocProcessingStatus
 import json
@ -26,6 +27,41 @@ SOCKET_TIMEOUT = 5.0
|
||||||
SOCKET_CONNECT_TIMEOUT = 3.0
|
SOCKET_CONNECT_TIMEOUT = 3.0
|
||||||
|
|
||||||
|
|
||||||
|
class RedisConnectionManager:
|
||||||
|
"""Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""
|
||||||
|
|
||||||
|
_pools = {}
|
||||||
|
_lock = threading.Lock()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_pool(cls, redis_url: str) -> ConnectionPool:
|
||||||
|
"""Get or create a connection pool for the given Redis URL"""
|
||||||
|
if redis_url not in cls._pools:
|
||||||
|
with cls._lock:
|
||||||
|
if redis_url not in cls._pools:
|
||||||
|
cls._pools[redis_url] = ConnectionPool.from_url(
|
||||||
|
redis_url,
|
||||||
|
max_connections=MAX_CONNECTIONS,
|
||||||
|
decode_responses=True,
|
||||||
|
socket_timeout=SOCKET_TIMEOUT,
|
||||||
|
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
|
||||||
|
)
|
||||||
|
logger.info(f"Created shared Redis connection pool for {redis_url}")
|
||||||
|
return cls._pools[redis_url]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def close_all_pools(cls):
|
||||||
|
"""Close all connection pools (for cleanup)"""
|
||||||
|
with cls._lock:
|
||||||
|
for url, pool in cls._pools.items():
|
||||||
|
try:
|
||||||
|
pool.disconnect()
|
||||||
|
logger.info(f"Closed Redis connection pool for {url}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing Redis pool for {url}: {e}")
|
||||||
|
cls._pools.clear()
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class RedisKVStorage(BaseKVStorage):
|
class RedisKVStorage(BaseKVStorage):
|
||||||
|
|
@ -33,19 +69,28 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
redis_url = os.environ.get(
|
redis_url = os.environ.get(
|
||||||
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
|
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
|
||||||
)
|
)
|
||||||
# Create a connection pool with limits
|
# Use shared connection pool
|
||||||
self._pool = ConnectionPool.from_url(
|
self._pool = RedisConnectionManager.get_pool(redis_url)
|
||||||
redis_url,
|
|
||||||
max_connections=MAX_CONNECTIONS,
|
|
||||||
decode_responses=True,
|
|
||||||
socket_timeout=SOCKET_TIMEOUT,
|
|
||||||
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
|
|
||||||
)
|
|
||||||
self._redis = Redis(connection_pool=self._pool)
|
self._redis = Redis(connection_pool=self._pool)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections"
|
f"Initialized Redis KV storage for {self.namespace} using shared connection pool"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""Initialize Redis connection and migrate legacy cache structure if needed"""
|
||||||
|
# Test connection
|
||||||
|
try:
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
await redis.ping()
|
||||||
|
logger.info(f"Connected to Redis for namespace {self.namespace}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to Redis: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Migrate legacy cache structure if this is a cache namespace
|
||||||
|
if self.namespace.endswith("_cache"):
|
||||||
|
await self._migrate_legacy_cache_structure()
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def _get_redis_connection(self):
|
async def _get_redis_connection(self):
|
||||||
"""Safe context manager for Redis operations."""
|
"""Safe context manager for Redis operations."""
|
||||||
|
|
@ -99,21 +144,57 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
logger.error(f"JSON decode error in batch get: {e}")
|
logger.error(f"JSON decode error in batch get: {e}")
|
||||||
return [None] * len(ids)
|
return [None] * len(ids)
|
||||||
|
|
||||||
|
async def get_all(self) -> dict[str, Any]:
|
||||||
|
"""Get all data from storage
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing all stored data
|
||||||
|
"""
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
try:
|
||||||
|
# Get all keys for this namespace
|
||||||
|
keys = await redis.keys(f"{self.namespace}:*")
|
||||||
|
|
||||||
|
if not keys:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Get all values in batch
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
for key in keys:
|
||||||
|
pipe.get(key)
|
||||||
|
values = await pipe.execute()
|
||||||
|
|
||||||
|
# Build result dictionary
|
||||||
|
result = {}
|
||||||
|
for key, value in zip(keys, values):
|
||||||
|
if value:
|
||||||
|
# Extract the ID part (after namespace:)
|
||||||
|
key_id = key.split(":", 1)[1]
|
||||||
|
try:
|
||||||
|
result[key_id] = json.loads(value)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"JSON decode error for key {key}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting all data from Redis: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
async with self._get_redis_connection() as redis:
|
async with self._get_redis_connection() as redis:
|
||||||
pipe = redis.pipeline()
|
pipe = redis.pipeline()
|
||||||
for key in keys:
|
keys_list = list(keys) # Convert set to list for indexing
|
||||||
|
for key in keys_list:
|
||||||
pipe.exists(f"{self.namespace}:{key}")
|
pipe.exists(f"{self.namespace}:{key}")
|
||||||
results = await pipe.execute()
|
results = await pipe.execute()
|
||||||
|
|
||||||
existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
|
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
|
||||||
return set(keys) - existing_ids
|
return set(keys) - existing_ids
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"Inserting {len(data)} items to {self.namespace}")
|
|
||||||
async with self._get_redis_connection() as redis:
|
async with self._get_redis_connection() as redis:
|
||||||
try:
|
try:
|
||||||
pipe = redis.pipeline()
|
pipe = redis.pipeline()
|
||||||
|
|
@ -148,13 +229,13 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
||||||
"""Delete specific records from storage by by cache mode
|
"""Delete specific records from storage by cache mode
|
||||||
|
|
||||||
Importance notes for Redis storage:
|
Importance notes for Redis storage:
|
||||||
1. This will immediately delete the specified cache modes from Redis
|
1. This will immediately delete the specified cache modes from Redis
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
modes (list[str]): List of cache mode to be drop from storage
|
modes (list[str]): List of cache modes to be dropped from storage
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True: if the cache drop successfully
|
True: if the cache drop successfully
|
||||||
|
|
@ -164,9 +245,43 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.delete(modes)
|
async with self._get_redis_connection() as redis:
|
||||||
|
keys_to_delete = []
|
||||||
|
|
||||||
|
# Find matching keys for each mode using SCAN
|
||||||
|
for mode in modes:
|
||||||
|
# Use correct pattern to match flattened cache key format {namespace}:{mode}:{cache_type}:{hash}
|
||||||
|
pattern = f"{self.namespace}:{mode}:*"
|
||||||
|
cursor = 0
|
||||||
|
mode_keys = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
|
||||||
|
if keys:
|
||||||
|
mode_keys.extend(keys)
|
||||||
|
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
keys_to_delete.extend(mode_keys)
|
||||||
|
logger.info(f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'")
|
||||||
|
|
||||||
|
if keys_to_delete:
|
||||||
|
# Batch delete
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
for key in keys_to_delete:
|
||||||
|
pipe.delete(key)
|
||||||
|
results = await pipe.execute()
|
||||||
|
deleted_count = sum(results)
|
||||||
|
logger.info(
|
||||||
|
f"Dropped {deleted_count} cache entries for modes: {modes}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"No cache entries found for modes: {modes}")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
logger.error(f"Error dropping cache by modes in Redis: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
|
|
@ -177,24 +292,350 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
"""
|
"""
|
||||||
async with self._get_redis_connection() as redis:
|
async with self._get_redis_connection() as redis:
|
||||||
try:
|
try:
|
||||||
keys = await redis.keys(f"{self.namespace}:*")
|
# Use SCAN to find all keys with the namespace prefix
|
||||||
|
pattern = f"{self.namespace}:*"
|
||||||
|
cursor = 0
|
||||||
|
deleted_count = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
|
||||||
|
if keys:
|
||||||
|
# Delete keys in batches
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
for key in keys:
|
||||||
|
pipe.delete(key)
|
||||||
|
results = await pipe.execute()
|
||||||
|
deleted_count += sum(results)
|
||||||
|
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
|
||||||
if keys:
|
logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
|
||||||
pipe = redis.pipeline()
|
return {
|
||||||
for key in keys:
|
"status": "success",
|
||||||
pipe.delete(key)
|
"message": f"{deleted_count} keys dropped",
|
||||||
results = await pipe.execute()
|
}
|
||||||
deleted_count = sum(results)
|
|
||||||
|
|
||||||
logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"message": f"{deleted_count} keys dropped",
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
logger.info(f"No keys found to drop in {self.namespace}")
|
|
||||||
return {"status": "success", "message": "no keys to drop"}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error dropping keys from {self.namespace}: {e}")
|
logger.error(f"Error dropping keys from {self.namespace}: {e}")
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
async def _migrate_legacy_cache_structure(self):
|
||||||
|
"""Migrate legacy nested cache structure to flattened structure for Redis
|
||||||
|
|
||||||
|
Redis already stores data in a flattened way, but we need to check for
|
||||||
|
legacy keys that might contain nested JSON structures and migrate them.
|
||||||
|
|
||||||
|
Early exit if any flattened key is found (indicating migration already done).
|
||||||
|
"""
|
||||||
|
from lightrag.utils import generate_cache_key
|
||||||
|
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
# Get all keys for this namespace
|
||||||
|
keys = await redis.keys(f"{self.namespace}:*")
|
||||||
|
|
||||||
|
if not keys:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if we have any flattened keys already - if so, skip migration
|
||||||
|
has_flattened_keys = False
|
||||||
|
keys_to_migrate = []
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
# Extract the ID part (after namespace:)
|
||||||
|
key_id = key.split(":", 1)[1]
|
||||||
|
|
||||||
|
# Check if already in flattened format (contains exactly 2 colons for mode:cache_type:hash)
|
||||||
|
if ":" in key_id and len(key_id.split(":")) == 3:
|
||||||
|
has_flattened_keys = True
|
||||||
|
break # Early exit - migration already done
|
||||||
|
|
||||||
|
# Get the data to check if it's a legacy nested structure
|
||||||
|
data = await redis.get(key)
|
||||||
|
if data:
|
||||||
|
try:
|
||||||
|
parsed_data = json.loads(data)
|
||||||
|
# Check if this looks like a legacy cache mode with nested structure
|
||||||
|
if isinstance(parsed_data, dict) and all(
|
||||||
|
isinstance(v, dict) and "return" in v
|
||||||
|
for v in parsed_data.values()
|
||||||
|
):
|
||||||
|
keys_to_migrate.append((key, key_id, parsed_data))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If we found any flattened keys, assume migration is already done
|
||||||
|
if has_flattened_keys:
|
||||||
|
logger.debug(
|
||||||
|
f"Found flattened cache keys in {self.namespace}, skipping migration"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not keys_to_migrate:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Perform migration
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
migration_count = 0
|
||||||
|
|
||||||
|
for old_key, mode, nested_data in keys_to_migrate:
|
||||||
|
# Delete the old key
|
||||||
|
pipe.delete(old_key)
|
||||||
|
|
||||||
|
# Create new flattened keys
|
||||||
|
for cache_hash, cache_entry in nested_data.items():
|
||||||
|
cache_type = cache_entry.get("cache_type", "extract")
|
||||||
|
flattened_key = generate_cache_key(mode, cache_type, cache_hash)
|
||||||
|
full_key = f"{self.namespace}:{flattened_key}"
|
||||||
|
pipe.set(full_key, json.dumps(cache_entry))
|
||||||
|
migration_count += 1
|
||||||
|
|
||||||
|
await pipe.execute()
|
||||||
|
|
||||||
|
if migration_count > 0:
|
||||||
|
logger.info(
|
||||||
|
f"Migrated {migration_count} legacy cache entries to flattened structure in Redis"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
|
@dataclass
|
||||||
|
class RedisDocStatusStorage(DocStatusStorage):
|
||||||
|
"""Redis implementation of document status storage"""
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
redis_url = os.environ.get(
|
||||||
|
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
|
||||||
|
)
|
||||||
|
# Use shared connection pool
|
||||||
|
self._pool = RedisConnectionManager.get_pool(redis_url)
|
||||||
|
self._redis = Redis(connection_pool=self._pool)
|
||||||
|
logger.info(
|
||||||
|
f"Initialized Redis doc status storage for {self.namespace} using shared connection pool"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""Initialize Redis connection"""
|
||||||
|
try:
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
await redis.ping()
|
||||||
|
logger.info(f"Connected to Redis for doc status namespace {self.namespace}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to Redis for doc status: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _get_redis_connection(self):
|
||||||
|
"""Safe context manager for Redis operations."""
|
||||||
|
try:
|
||||||
|
yield self._redis
|
||||||
|
except ConnectionError as e:
|
||||||
|
logger.error(f"Redis connection error in doc status {self.namespace}: {e}")
|
||||||
|
raise
|
||||||
|
except RedisError as e:
|
||||||
|
logger.error(f"Redis operation error in doc status {self.namespace}: {e}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Unexpected error in Redis doc status operation for {self.namespace}: {e}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close the Redis connection."""
|
||||||
|
if hasattr(self, "_redis") and self._redis:
|
||||||
|
await self._redis.close()
|
||||||
|
logger.debug(f"Closed Redis connection for doc status {self.namespace}")
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
"""Support for async context manager."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Ensure Redis resources are cleaned up when exiting context."""
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
|
"""Return keys that should be processed (not in storage or not successfully processed)"""
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
keys_list = list(keys)
|
||||||
|
for key in keys_list:
|
||||||
|
pipe.exists(f"{self.namespace}:{key}")
|
||||||
|
results = await pipe.execute()
|
||||||
|
|
||||||
|
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
|
||||||
|
return set(keys) - existing_ids
|
||||||
|
|
||||||
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||||
|
result: list[dict[str, Any]] = []
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
try:
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
for id in ids:
|
||||||
|
pipe.get(f"{self.namespace}:{id}")
|
||||||
|
results = await pipe.execute()
|
||||||
|
|
||||||
|
for result_data in results:
|
||||||
|
if result_data:
|
||||||
|
try:
|
||||||
|
result.append(json.loads(result_data))
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"JSON decode error in get_by_ids: {e}")
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in get_by_ids: {e}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_status_counts(self) -> dict[str, int]:
|
||||||
|
"""Get counts of documents in each status"""
|
||||||
|
counts = {status.value: 0 for status in DocStatus}
|
||||||
|
async with self._get_redis_connection() as redis:
|
||||||
|
try:
|
||||||
|
# Use SCAN to iterate through all keys in the namespace
|
||||||
|
cursor = 0
|
||||||
|
while True:
|
||||||
|
cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
|
||||||
|
if keys:
|
||||||
|
# Get all values in batch
|
||||||
|
pipe = redis.pipeline()
|
||||||
|
for key in keys:
|
||||||
|
pipe.get(key)
|
||||||
|
values = await pipe.execute()
|
||||||
|
|
||||||
|
# Count statuses
|
||||||
|
for value in values:
|
||||||
|
if value:
|
||||||
|
try:
|
||||||
|
doc_data = json.loads(value)
|
||||||
|
status = doc_data.get("status")
|
||||||
                                    if status in counts:
                                        counts[status] += 1
                                except json.JSONDecodeError:
                                    continue

                    if cursor == 0:
                        break
            except Exception as e:
                logger.error(f"Error getting status counts: {e}")

        return counts

    async def get_docs_by_status(
        self, status: DocStatus
    ) -> dict[str, DocProcessingStatus]:
        """Get all documents with a specific status"""
        result = {}
        async with self._get_redis_connection() as redis:
            try:
                # Use SCAN to iterate through all keys in the namespace
                cursor = 0
                while True:
                    cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
                    if keys:
                        # Get all values in batch
                        pipe = redis.pipeline()
                        for key in keys:
                            pipe.get(key)
                        values = await pipe.execute()

                        # Filter by status and create DocProcessingStatus objects
                        for key, value in zip(keys, values):
                            if value:
                                try:
                                    doc_data = json.loads(value)
                                    if doc_data.get("status") == status.value:
                                        # Extract document ID from key
                                        doc_id = key.split(":", 1)[1]

                                        # Make a copy of the data to avoid modifying the original
                                        data = doc_data.copy()
                                        # If content is missing, use content_summary as content
                                        if "content" not in data and "content_summary" in data:
                                            data["content"] = data["content_summary"]
                                        # If file_path is missing, fall back to a placeholder value
                                        if "file_path" not in data:
                                            data["file_path"] = "no-file-path"

                                        result[doc_id] = DocProcessingStatus(**data)
                                except (json.JSONDecodeError, KeyError) as e:
                                    logger.error(f"Error processing document {key}: {e}")
                                    continue

                    if cursor == 0:
                        break
            except Exception as e:
                logger.error(f"Error getting docs by status: {e}")

        return result

    async def index_done_callback(self) -> None:
        """Redis handles persistence automatically"""
        pass

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        """Insert or update document status data"""
        if not data:
            return

        logger.debug(f"Inserting {len(data)} records to {self.namespace}")
        async with self._get_redis_connection() as redis:
            try:
                pipe = redis.pipeline()
                for k, v in data.items():
                    pipe.set(f"{self.namespace}:{k}", json.dumps(v))
                await pipe.execute()
            except (TypeError, ValueError) as e:
                # json.dumps raises TypeError/ValueError for non-serializable values
                logger.error(f"JSON encode error during upsert: {e}")
                raise

    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
        async with self._get_redis_connection() as redis:
            try:
                data = await redis.get(f"{self.namespace}:{id}")
                return json.loads(data) if data else None
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error for id {id}: {e}")
                return None

    async def delete(self, doc_ids: list[str]) -> None:
        """Delete specific records from storage by their IDs"""
        if not doc_ids:
            return

        async with self._get_redis_connection() as redis:
            pipe = redis.pipeline()
            for doc_id in doc_ids:
                pipe.delete(f"{self.namespace}:{doc_id}")

            results = await pipe.execute()
            deleted_count = sum(results)
            logger.info(f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}")

    async def drop(self) -> dict[str, str]:
        """Drop all document status data from storage and clean up resources"""
        try:
            async with self._get_redis_connection() as redis:
                # Use SCAN to find all keys with the namespace prefix
                pattern = f"{self.namespace}:*"
                cursor = 0
                deleted_count = 0

                while True:
                    cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
                    if keys:
                        # Delete keys in batches
                        pipe = redis.pipeline()
                        for key in keys:
                            pipe.delete(key)
                        results = await pipe.execute()
                        deleted_count += sum(results)

                    if cursor == 0:
                        break

                logger.info(f"Dropped {deleted_count} doc status keys from {self.namespace}")
                return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping doc status {self.namespace}: {e}")
            return {"status": "error", "message": str(e)}

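For orientation, a minimal sketch of how the document-status records above end up laid out in Redis; the key names and the redis.asyncio usage here are illustrative assumptions, not part of this change:

# Illustrative sketch only: flat per-document keys, one JSON value each.
import asyncio
import json

from redis.asyncio import Redis


async def demo():
    r = Redis()  # assumes a local Redis instance
    namespace = "doc_status"  # stands in for self.namespace

    # upsert() above effectively does SET {namespace}:{doc_id} -> JSON blob
    await r.set(
        f"{namespace}:doc-1",
        json.dumps({"status": "processed", "content_summary": "..."}),
    )

    # get_by_id() is then a single GET plus json.loads
    raw = await r.get(f"{namespace}:doc-1")
    print(json.loads(raw))

    await r.close()


asyncio.run(demo())
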
@@ -257,7 +257,7 @@ class TiDBKVStorage(BaseKVStorage):

     ################ INSERT full_doc AND chunks ################
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
         left_data = {k: v for k, v in data.items() if k not in self._data}
@@ -454,11 +454,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):

     ###### INSERT entities And relationships ######
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
-        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")

         # Get current time as UNIX timestamp
         import time
@@ -399,10 +399,10 @@ async def _get_cached_extraction_results(
     """
     cached_results = {}

-    # Get all cached data for "default" mode (entity extraction cache)
-    default_cache = await llm_response_cache.get_by_id("default") or {}
+    # Get all cached data (flattened cache structure)
+    all_cache = await llm_response_cache.get_all()

-    for cache_key, cache_entry in default_cache.items():
+    for cache_key, cache_entry in all_cache.items():
         if (
             isinstance(cache_entry, dict)
             and cache_entry.get("cache_type") == "extract"
@@ -1387,7 +1387,7 @@ async def kg_query(
     use_model_func = partial(use_model_func, _priority=5)

     # Handle cache
-    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
+    args_hash = compute_args_hash(query_param.mode, query)
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
@@ -1546,7 +1546,7 @@ async def extract_keywords_only(
     """

     # 1. Handle cache if needed - add cache type for keywords
-    args_hash = compute_args_hash(param.mode, text, cache_type="keywords")
+    args_hash = compute_args_hash(param.mode, text)
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, text, param.mode, cache_type="keywords"
     )
@@ -2413,7 +2413,7 @@ async def naive_query(
     use_model_func = partial(use_model_func, _priority=5)

     # Handle cache
-    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
+    args_hash = compute_args_hash(query_param.mode, query)
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
@@ -2529,7 +2529,7 @@ async def kg_query_with_keywords(
     # Apply higher priority (5) to query relation LLM function
     use_model_func = partial(use_model_func, _priority=5)

-    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
+    args_hash = compute_args_hash(query_param.mode, query)
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
@@ -14,7 +14,6 @@ from functools import wraps
 from hashlib import md5
 from typing import Any, Protocol, Callable, TYPE_CHECKING, List
 import numpy as np
-from lightrag.prompt import PROMPTS
 from dotenv import load_dotenv
 from lightrag.constants import (
     DEFAULT_LOG_MAX_BYTES,
@@ -278,11 +277,10 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
         raise e from None


-def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
+def compute_args_hash(*args: Any) -> str:
     """Compute a hash for the given arguments.

     Args:
         *args: Arguments to hash
-        cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')

     Returns:
         str: Hash string
     """
@@ -290,13 +288,40 @@ def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:

     # Convert all arguments to strings and join them
     args_str = "".join([str(arg) for arg in args])
-    if cache_type:
-        args_str = f"{cache_type}:{args_str}"

     # Compute MD5 hash
     return hashlib.md5(args_str.encode()).hexdigest()


+def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str:
+    """Generate a flattened cache key in the format {mode}:{cache_type}:{hash}
+
+    Args:
+        mode: Cache mode (e.g., 'default', 'local', 'global')
+        cache_type: Type of cache (e.g., 'extract', 'query', 'keywords')
+        hash_value: Hash value from compute_args_hash
+
+    Returns:
+        str: Flattened cache key
+    """
+    return f"{mode}:{cache_type}:{hash_value}"
+
+
+def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None:
+    """Parse a flattened cache key back into its components
+
+    Args:
+        cache_key: Flattened cache key in format {mode}:{cache_type}:{hash}
+
+    Returns:
+        tuple[str, str, str] | None: (mode, cache_type, hash) or None if invalid format
+    """
+    parts = cache_key.split(":", 2)
+    if len(parts) == 3:
+        return parts[0], parts[1], parts[2]
+    return None
+
+
 def compute_mdhash_id(content: str, prefix: str = "") -> str:
     """
     Compute a unique ID for a given content string.
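As a quick sanity check of the two helpers introduced above, the flattened key round-trips cleanly; the hash string below is a hypothetical value, not one produced by this commit:

key = generate_cache_key("local", "query", "9e107d9d372bb6826bd81d3542a419d6")
# -> "local:query:9e107d9d372bb6826bd81d3542a419d6"
assert parse_cache_key(key) == ("local", "query", "9e107d9d372bb6826bd81d3542a419d6")
assert parse_cache_key("not-a-flattened-key") is None  # missing the two ':' separators
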
@@ -783,131 +808,6 @@ def process_combine_contexts(*context_lists):
     return combined_data


-async def get_best_cached_response(
-    hashing_kv,
-    current_embedding,
-    similarity_threshold=0.95,
-    mode="default",
-    use_llm_check=False,
-    llm_func=None,
-    original_prompt=None,
-    cache_type=None,
-) -> str | None:
-    logger.debug(
-        f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
-    )
-    mode_cache = await hashing_kv.get_by_id(mode)
-    if not mode_cache:
-        return None
-
-    best_similarity = -1
-    best_response = None
-    best_prompt = None
-    best_cache_id = None
-
-    # Only iterate through cache entries for this mode
-    for cache_id, cache_data in mode_cache.items():
-        # Skip if cache_type doesn't match
-        if cache_type and cache_data.get("cache_type") != cache_type:
-            continue
-
-        # Check if cache data is valid
-        if cache_data["embedding"] is None:
-            continue
-
-        try:
-            # Safely convert cached embedding
-            cached_quantized = np.frombuffer(
-                bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
-            ).reshape(cache_data["embedding_shape"])
-
-            # Ensure min_val and max_val are valid float values
-            embedding_min = cache_data.get("embedding_min")
-            embedding_max = cache_data.get("embedding_max")
-
-            if (
-                embedding_min is None
-                or embedding_max is None
-                or embedding_min >= embedding_max
-            ):
-                logger.warning(
-                    f"Invalid embedding min/max values: min={embedding_min}, max={embedding_max}"
-                )
-                continue
-
-            cached_embedding = dequantize_embedding(
-                cached_quantized,
-                embedding_min,
-                embedding_max,
-            )
-        except Exception as e:
-            logger.warning(f"Error processing cached embedding: {str(e)}")
-            continue
-
-        similarity = cosine_similarity(current_embedding, cached_embedding)
-        if similarity > best_similarity:
-            best_similarity = similarity
-            best_response = cache_data["return"]
-            best_prompt = cache_data["original_prompt"]
-            best_cache_id = cache_id
-
-    if best_similarity > similarity_threshold:
-        # If LLM check is enabled and all required parameters are provided
-        if (
-            use_llm_check
-            and llm_func
-            and original_prompt
-            and best_prompt
-            and best_response is not None
-        ):
-            compare_prompt = PROMPTS["similarity_check"].format(
-                original_prompt=original_prompt, cached_prompt=best_prompt
-            )
-
-            try:
-                llm_result = await llm_func(compare_prompt)
-                llm_result = llm_result.strip()
-                llm_similarity = float(llm_result)
-
-                # Replace vector similarity with LLM similarity score
-                best_similarity = llm_similarity
-                if best_similarity < similarity_threshold:
-                    log_data = {
-                        "event": "cache_rejected_by_llm",
-                        "type": cache_type,
-                        "mode": mode,
-                        "original_question": original_prompt[:100] + "..."
-                        if len(original_prompt) > 100
-                        else original_prompt,
-                        "cached_question": best_prompt[:100] + "..."
-                        if len(best_prompt) > 100
-                        else best_prompt,
-                        "similarity_score": round(best_similarity, 4),
-                        "threshold": similarity_threshold,
-                    }
-                    logger.debug(json.dumps(log_data, ensure_ascii=False))
-                    logger.info(f"Cache rejected by LLM(mode:{mode} tpye:{cache_type})")
-                    return None
-            except Exception as e:  # Catch all possible exceptions
-                logger.warning(f"LLM similarity check failed: {e}")
-                return None  # Return None directly when LLM check fails
-
-        prompt_display = (
-            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
-        )
-        log_data = {
-            "event": "cache_hit",
-            "type": cache_type,
-            "mode": mode,
-            "similarity": round(best_similarity, 4),
-            "cache_id": best_cache_id,
-            "original_prompt": prompt_display,
-        }
-        logger.debug(json.dumps(log_data, ensure_ascii=False))
-        return best_response
-    return None
-
-
 def cosine_similarity(v1, v2):
     """Calculate cosine similarity between two vectors"""
     dot_product = np.dot(v1, v2)
@@ -957,7 +857,7 @@ async def handle_cache(
     mode="default",
     cache_type=None,
 ):
-    """Generic cache handling function"""
+    """Generic cache handling function with flattened cache keys"""
     if hashing_kv is None:
         return None, None, None, None
@@ -968,15 +868,14 @@
     if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
         return None, None, None, None

-    if exists_func(hashing_kv, "get_by_mode_and_id"):
-        mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
-    else:
-        mode_cache = await hashing_kv.get_by_id(mode) or {}
-    if args_hash in mode_cache:
-        logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
-        return mode_cache[args_hash]["return"], None, None, None
+    # Use flattened cache key format: {mode}:{cache_type}:{hash}
+    flattened_key = generate_cache_key(mode, cache_type, args_hash)
+    cache_entry = await hashing_kv.get_by_id(flattened_key)
+    if cache_entry:
+        logger.debug(f"Flattened cache hit(key:{flattened_key})")
+        return cache_entry["return"], None, None, None

-    logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
+    logger.debug(f"Cache missed(mode:{mode} type:{cache_type})")
     return None, None, None, None
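To make the recall-efficiency point concrete: with the flattened layout the lookup in handle_cache is a single get_by_id on the composite key, instead of loading a whole per-mode JSON blob and indexing into it. A minimal dict-backed sketch, where DictKV is a stand-in for any KV backend and not a class from this commit:

class DictKV:
    """Toy stand-in for a key-value storage backend."""

    def __init__(self):
        self._data: dict[str, dict] = {}

    async def get_by_id(self, key: str):
        return self._data.get(key)

    async def upsert(self, data: dict[str, dict]):
        self._data.update(data)


# Old shape: {"local": {<args_hash>: {...}}}   -> fetch the whole mode bucket, then index into it.
# New shape: {"local:query:<args_hash>": {...}} -> one direct key lookup:
#     entry = await kv.get_by_id(generate_cache_key("local", "query", args_hash))
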
@@ -994,7 +893,7 @@ class CacheData:


 async def save_to_cache(hashing_kv, cache_data: CacheData):
-    """Save data to cache, with improved handling for streaming responses and duplicate content.
+    """Save data to cache using flattened key structure.

     Args:
         hashing_kv: The key-value storage for caching
@@ -1009,26 +908,21 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         logger.debug("Streaming response detected, skipping cache")
         return

-    # Get existing cache data
-    if exists_func(hashing_kv, "get_by_mode_and_id"):
-        mode_cache = (
-            await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
-            or {}
-        )
-    else:
-        mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
+    # Use flattened cache key format: {mode}:{cache_type}:{hash}
+    flattened_key = generate_cache_key(
+        cache_data.mode, cache_data.cache_type, cache_data.args_hash
+    )

     # Check if we already have identical content cached
-    if cache_data.args_hash in mode_cache:
-        existing_content = mode_cache[cache_data.args_hash].get("return")
+    existing_cache = await hashing_kv.get_by_id(flattened_key)
+    if existing_cache:
+        existing_content = existing_cache.get("return")
         if existing_content == cache_data.content:
-            logger.info(
-                f"Cache content unchanged for {cache_data.args_hash}, skipping update"
-            )
+            logger.info(f"Cache content unchanged for {flattened_key}, skipping update")
             return

-    # Update cache with new content
-    mode_cache[cache_data.args_hash] = {
+    # Create cache entry with flattened structure
+    cache_entry = {
         "return": cache_data.content,
         "cache_type": cache_data.cache_type,
         "chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
@@ -1043,10 +937,10 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         "original_prompt": cache_data.prompt,
     }

-    logger.info(f" == LLM cache == saving {cache_data.mode}: {cache_data.args_hash}")
+    logger.info(f" == LLM cache == saving: {flattened_key}")

-    # Only upsert if there's actual new content
-    await hashing_kv.upsert({cache_data.mode: mode_cache})
+    # Save using flattened key
+    await hashing_kv.upsert({flattened_key: cache_entry})


 def safe_unicode_decode(content):