From 14cda93988dc110c11cf1fe8286ded7b5db7b8fb Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 29 Jun 2025 15:13:42 +0800
Subject: [PATCH] Fix LLM cache handling for Redis to address document
 deletion scenarios.

- Implements bulk scan for "extract" cache entries
- Maintains backward compatibility for normal IDs
---
 lightrag/kg/redis_impl.py | 53 +++++++++++++++++++++++++++++++++------
 1 file changed, 46 insertions(+), 7 deletions(-)

diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py
index 65c25bfc..1bc07de8 100644
--- a/lightrag/kg/redis_impl.py
+++ b/lightrag/kg/redis_impl.py
@@ -79,13 +79,52 @@ class RedisKVStorage(BaseKVStorage):
         await self.close()
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        async with self._get_redis_connection() as redis:
-            try:
-                data = await redis.get(f"{self.namespace}:{id}")
-                return json.loads(data) if data else None
-            except json.JSONDecodeError as e:
-                logger.error(f"JSON decode error for id {id}: {e}")
-                return None
+        if id == "default":
+            # Find all cache entries with cache_type == "extract"
+            async with self._get_redis_connection() as redis:
+                try:
+                    result = {}
+                    pattern = f"{self.namespace}:*"
+                    cursor = 0
+
+                    while True:
+                        cursor, keys = await redis.scan(cursor, match=pattern, count=100)
+
+                        if keys:
+                            # Batch get values for these keys
+                            pipe = redis.pipeline()
+                            for key in keys:
+                                pipe.get(key)
+                            values = await pipe.execute()
+
+                            # Check each value for cache_type == "extract"
+                            for key, value in zip(keys, values):
+                                if value:
+                                    try:
+                                        data = json.loads(value)
+                                        if isinstance(data, dict) and data.get("cache_type") == "extract":
+                                            # Extract cache key (remove namespace prefix)
+                                            cache_key = key.replace(f"{self.namespace}:", "")
+                                            result[cache_key] = data
+                                    except json.JSONDecodeError:
+                                        continue
+
+                        if cursor == 0:
+                            break
+
+                    return result if result else None
+                except Exception as e:
+                    logger.error(f"Error scanning Redis for extract cache entries: {e}")
+                    return None
+        else:
+            # Original behavior for non-"default" ids
+            async with self._get_redis_connection() as redis:
+                try:
+                    data = await redis.get(f"{self.namespace}:{id}")
+                    return json.loads(data) if data else None
+                except json.JSONDecodeError as e:
+                    logger.error(f"JSON decode error for id {id}: {e}")
+                    return None
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         async with self._get_redis_connection() as redis:
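
Reviewer note: below is a minimal usage sketch of the two code paths this
patch introduces, for illustration only. The purge_extract_cache() helper,
the delete() call, and the example cache key are assumptions, not APIs
confirmed by this diff.

    async def purge_extract_cache(storage) -> None:
        # With this patch, id == "default" fans out: get_by_id returns a
        # dict mapping cache keys (namespace prefix stripped) to entries
        # whose cache_type == "extract", or None when no such entries exist.
        extract_entries = await storage.get_by_id("default")
        if extract_entries:
            # Hypothetical bulk-delete call; the exact deletion API is an
            # assumption here.
            await storage.delete(list(extract_entries.keys()))

        # Any other id keeps the original semantics: a single JSON-decoded
        # dict, or None when the key is missing or fails to decode.
        single = await storage.get_by_id("some-cache-key")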