Refac: Enhance KG rebuild stability by incorporating create_time into the LLM cache

yangdx 2025-07-03 17:08:29 +08:00
parent a9e10ae810
commit 6c2ae40d7d
5 changed files with 172 additions and 35 deletions
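
Every KV backend in this commit applies the same timestamp contract: create_time is written once, when a record first appears, and update_time is refreshed on every upsert; reads fall back to 0 for records written before this change. A minimal, backend-agnostic sketch of the write side (a plain dict stands in for the real storage classes; upsert_with_timestamps is an illustrative name, not part of the codebase):

    import time

    def upsert_with_timestamps(store: dict, data: dict) -> None:
        """Stamp create_time once per record; refresh update_time on every write."""
        now = int(time.time())  # Unix timestamp, as in the storage backends
        for key, value in data.items():
            if key in store:
                # Existing record: leave its original create_time untouched
                value["update_time"] = now
            else:
                # New record: set both timestamps
                value["create_time"] = now
                value["update_time"] = now
            store[key] = value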

@@ -78,22 +78,49 @@ class JsonKVStorage(BaseKVStorage):
Dictionary containing all stored data
"""
async with self._storage_lock:
return dict(self._data)
result = {}
for key, value in self._data.items():
if value:
# Create a copy to avoid modifying the original data
data = dict(value)
# Ensure time fields are present, provide default values for old data
data.setdefault("create_time", 0)
data.setdefault("update_time", 0)
result[key] = data
else:
result[key] = value
return result
async def get_by_id(self, id: str) -> dict[str, Any] | None:
async with self._storage_lock:
return self._data.get(id)
result = self._data.get(id)
if result:
# Create a copy to avoid modifying the original data
result = dict(result)
# Ensure time fields are present, provide default values for old data
result.setdefault("create_time", 0)
result.setdefault("update_time", 0)
# Ensure _id field contains the clean ID
result["_id"] = id
return result
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
async with self._storage_lock:
return [
(
{k: v for k, v in self._data[id].items()}
if self._data.get(id, None)
else None
)
for id in ids
]
results = []
for id in ids:
data = self._data.get(id, None)
if data:
# Create a copy to avoid modifying the original data
result = {k: v for k, v in data.items()}
# Ensure time fields are present, provide default values for old data
result.setdefault("create_time", 0)
result.setdefault("update_time", 0)
# Ensure _id field contains the clean ID
result["_id"] = id
results.append(result)
else:
results.append(None)
return results
async def filter_keys(self, keys: set[str]) -> set[str]:
async with self._storage_lock:
@@ -107,13 +134,29 @@ class JsonKVStorage(BaseKVStorage):
"""
if not data:
return
import time
current_time = int(time.time()) # Get current Unix timestamp
logger.debug(f"Inserting {len(data)} records to {self.namespace}")
async with self._storage_lock:
# For text_chunks namespace, ensure llm_cache_list field exists
if "text_chunks" in self.namespace:
for chunk_id, chunk_data in data.items():
if "llm_cache_list" not in chunk_data:
chunk_data["llm_cache_list"] = []
# Add timestamps to data based on whether key exists
for k, v in data.items():
# For text_chunks namespace, ensure llm_cache_list field exists
if "text_chunks" in self.namespace:
if "llm_cache_list" not in v:
v["llm_cache_list"] = []
# Add timestamps based on whether key exists
if k in self._data: # Key exists, only update update_time
v["update_time"] = current_time
else: # New key, set both create_time and update_time
v["create_time"] = current_time
v["update_time"] = current_time
v["_id"] = k
self._data.update(data)
await set_all_update_flags(self.namespace)
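
On the read side, get_all, get_by_id and get_by_ids now return copies of the stored dicts and backfill missing timestamps, so records written before this commit never surface without the fields the rebuild code sorts on. A small sketch of that normalization, again against a plain-dict stand-in (read_with_defaults is a hypothetical helper):

    from typing import Any

    def read_with_defaults(store: dict[str, dict[str, Any]], key: str) -> dict[str, Any] | None:
        record = store.get(key)
        if record is None:
            return None
        record = dict(record)                # copy, so the stored value is never mutated
        record.setdefault("create_time", 0)  # legacy records fall back to 0
        record.setdefault("update_time", 0)
        record["_id"] = key                  # mirror the clean ID into the result
        return record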

@@ -98,11 +98,21 @@ class MongoKVStorage(BaseKVStorage):
async def get_by_id(self, id: str) -> dict[str, Any] | None:
# Unified handling for flattened keys
return await self._data.find_one({"_id": id})
doc = await self._data.find_one({"_id": id})
if doc:
# Ensure time fields are present, provide default values for old data
doc.setdefault("create_time", 0)
doc.setdefault("update_time", 0)
return doc
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
cursor = self._data.find({"_id": {"$in": ids}})
return await cursor.to_list()
docs = await cursor.to_list()
# Ensure time fields are present for all documents
for doc in docs:
doc.setdefault("create_time", 0)
doc.setdefault("update_time", 0)
return docs
async def filter_keys(self, keys: set[str]) -> set[str]:
cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
@@ -119,6 +129,9 @@ class MongoKVStorage(BaseKVStorage):
result = {}
async for doc in cursor:
doc_id = doc.pop("_id")
# Ensure time fields are present for all documents
doc.setdefault("create_time", 0)
doc.setdefault("update_time", 0)
result[doc_id] = doc
return result
@@ -132,6 +145,8 @@ class MongoKVStorage(BaseKVStorage):
from pymongo import UpdateOne
operations = []
current_time = int(time.time()) # Get current Unix timestamp
for k, v in data.items():
# For text_chunks namespace, ensure llm_cache_list field exists
if self.namespace.endswith("text_chunks"):
@@ -139,7 +154,20 @@ class MongoKVStorage(BaseKVStorage):
v["llm_cache_list"] = []
v["_id"] = k # Use flattened key as _id
operations.append(UpdateOne({"_id": k}, {"$set": v}, upsert=True))
v["update_time"] = current_time # Always update update_time
operations.append(
UpdateOne(
{"_id": k},
{
"$set": v, # Update all fields including update_time
"$setOnInsert": {
"create_time": current_time
}, # Set create_time only on insert
},
upsert=True,
)
)
if operations:
await self._data.bulk_write(operations)
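
The Mongo variant pushes the same rule into the database in a single bulk round trip: $set carries the payload, including the refreshed update_time, while $setOnInsert writes create_time only when the upsert actually inserts a new document. A standalone sketch of that pattern with pymongo (connection string, database and collection names are placeholders):

    import time
    from pymongo import MongoClient, UpdateOne

    client = MongoClient("mongodb://localhost:27017")  # placeholder connection string
    collection = client["lightrag"]["kv_store"]        # placeholder database/collection

    now = int(time.time())
    data = {"chunk-1": {"content": "...", "update_time": now}}

    operations = [
        UpdateOne(
            {"_id": key},
            {
                "$set": value,                         # always refresh payload and update_time
                "$setOnInsert": {"create_time": now},  # written only if the document is new
            },
            upsert=True,
        )
        for key, value in data.items()
    ]
    if operations:
        collection.bulk_write(operations)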

@@ -752,6 +752,8 @@ class PGKVStorage(BaseKVStorage):
"original_prompt": row.get("original_prompt", ""),
"chunk_id": row.get("chunk_id"),
"mode": row.get("mode", "default"),
"create_time": row.get("create_time", 0),
"update_time": row.get("update_time", 0),
}
processed_results[row["id"]] = processed_row
return processed_results
@@ -767,6 +769,8 @@ class PGKVStorage(BaseKVStorage):
except json.JSONDecodeError:
llm_cache_list = []
row["llm_cache_list"] = llm_cache_list
row["create_time"] = row.get("create_time", 0)
row["update_time"] = row.get("update_time", 0)
processed_results[row["id"]] = row
return processed_results
@@ -791,6 +795,8 @@ class PGKVStorage(BaseKVStorage):
except json.JSONDecodeError:
llm_cache_list = []
response["llm_cache_list"] = llm_cache_list
response["create_time"] = response.get("create_time", 0)
response["update_time"] = response.get("update_time", 0)
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
if response and is_namespace(
@@ -804,6 +810,8 @@ class PGKVStorage(BaseKVStorage):
"original_prompt": response.get("original_prompt", ""),
"chunk_id": response.get("chunk_id"),
"mode": response.get("mode", "default"),
"create_time": response.get("create_time", 0),
"update_time": response.get("update_time", 0),
}
return response if response else None
@@ -827,6 +835,8 @@ class PGKVStorage(BaseKVStorage):
except json.JSONDecodeError:
llm_cache_list = []
result["llm_cache_list"] = llm_cache_list
result["create_time"] = result.get("create_time", 0)
result["update_time"] = result.get("update_time", 0)
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
if results and is_namespace(
@@ -842,6 +852,8 @@ class PGKVStorage(BaseKVStorage):
"original_prompt": row.get("original_prompt", ""),
"chunk_id": row.get("chunk_id"),
"mode": row.get("mode", "default"),
"create_time": row.get("create_time", 0),
"update_time": row.get("update_time", 0),
}
processed_results.append(processed_row)
return processed_results
@@ -2941,10 +2953,12 @@ SQL_TEMPLATES = {
""",
"get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
chunk_order_index, full_doc_id, file_path,
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
create_time, update_time
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
""",
"get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type
"get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
create_time, update_time
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
""",
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id
@@ -2955,10 +2969,12 @@ SQL_TEMPLATES = {
""",
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
chunk_order_index, full_doc_id, file_path,
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
create_time, update_time
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
""",
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
create_time, update_time
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
""",
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
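
On Postgres the timestamps come straight from the extended SELECT templates, and the row post-processing above falls back to 0 when a value is missing. A hedged sketch of running the updated cache query with asyncpg (the DSN and workspace value are placeholders; the real class goes through its own connection pool rather than a one-off connection, and fetch_cache_row is an illustrative helper):

    import asyncpg

    GET_BY_ID_LLM_RESPONSE_CACHE = """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
        create_time, update_time
        FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
    """

    async def fetch_cache_row(dsn: str, workspace: str, cache_id: str) -> dict | None:
        conn = await asyncpg.connect(dsn)
        try:
            row = await conn.fetchrow(GET_BY_ID_LLM_RESPONSE_CACHE, workspace, cache_id)
        finally:
            await conn.close()
        if row is None:
            return None
        result = dict(row)
        result["create_time"] = result.get("create_time") or 0  # NULL -> 0 for old rows
        result["update_time"] = result.get("update_time") or 0
        return result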

@@ -132,7 +132,13 @@ class RedisKVStorage(BaseKVStorage):
async with self._get_redis_connection() as redis:
try:
data = await redis.get(f"{self.namespace}:{id}")
return json.loads(data) if data else None
if data:
result = json.loads(data)
# Ensure time fields are present, provide default values for old data
result.setdefault("create_time", 0)
result.setdefault("update_time", 0)
return result
return None
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for id {id}: {e}")
return None
@@ -144,7 +150,19 @@ class RedisKVStorage(BaseKVStorage):
for id in ids:
pipe.get(f"{self.namespace}:{id}")
results = await pipe.execute()
return [json.loads(result) if result else None for result in results]
processed_results = []
for result in results:
if result:
data = json.loads(result)
# Ensure time fields are present for all documents
data.setdefault("create_time", 0)
data.setdefault("update_time", 0)
processed_results.append(data)
else:
processed_results.append(None)
return processed_results
except json.JSONDecodeError as e:
logger.error(f"JSON decode error in batch get: {e}")
return [None] * len(ids)
@@ -176,7 +194,11 @@ class RedisKVStorage(BaseKVStorage):
# Extract the ID part (after namespace:)
key_id = key.split(":", 1)[1]
try:
result[key_id] = json.loads(value)
data = json.loads(value)
# Ensure time fields are present for all documents
data.setdefault("create_time", 0)
data.setdefault("update_time", 0)
result[key_id] = data
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for key {key}: {e}")
continue
@@ -200,21 +222,41 @@ class RedisKVStorage(BaseKVStorage):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if not data:
return
import time
current_time = int(time.time()) # Get current Unix timestamp
async with self._get_redis_connection() as redis:
try:
# For text_chunks namespace, ensure llm_cache_list field exists
if "text_chunks" in self.namespace:
for chunk_id, chunk_data in data.items():
if "llm_cache_list" not in chunk_data:
chunk_data["llm_cache_list"] = []
# Check which keys already exist to determine create vs update
pipe = redis.pipeline()
for k in data.keys():
pipe.exists(f"{self.namespace}:{k}")
exists_results = await pipe.execute()
# Add timestamps to data
for i, (k, v) in enumerate(data.items()):
# For text_chunks namespace, ensure llm_cache_list field exists
if "text_chunks" in self.namespace:
if "llm_cache_list" not in v:
v["llm_cache_list"] = []
# Add timestamps based on whether key exists
if exists_results[i]: # Key exists, only update update_time
v["update_time"] = current_time
else: # New key, set both create_time and update_time
v["create_time"] = current_time
v["update_time"] = current_time
v["_id"] = k
# Store the data
pipe = redis.pipeline()
for k, v in data.items():
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
await pipe.execute()
for k in data:
data[k]["_id"] = k
except (TypeError, ValueError) as e:  # json has no JSONEncodeError; json.dumps failures raise TypeError/ValueError
logger.error(f"JSON encode error during upsert: {e}")
raise
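
Redis has no conditional-insert operator comparable to Mongo's $setOnInsert, so the upsert above first checks which keys exist in one pipeline round trip, stamps timestamps client-side, then writes everything back in a second pipeline. A self-contained sketch of that two-pipeline pattern with redis-py's asyncio client (URL and namespace are placeholders; upsert here is an illustrative free function, not the storage method):

    import json
    import time
    import redis.asyncio as redis

    async def upsert(namespace: str, data: dict[str, dict]) -> None:
        r = redis.from_url("redis://localhost:6379")  # placeholder URL
        now = int(time.time())

        # Pipeline 1: one EXISTS per key to tell inserts from updates
        pipe = r.pipeline()
        for key in data:
            pipe.exists(f"{namespace}:{key}")
        exists = await pipe.execute()

        # Pipeline 2: stamp timestamps and write the batch back
        pipe = r.pipeline()
        for found, (key, value) in zip(exists, data.items()):
            value["update_time"] = now
            if not found:                     # new key: also set create_time
                value["create_time"] = now
            pipe.set(f"{namespace}:{key}", json.dumps(value))
        await pipe.execute()
        await r.aclose()                      # redis-py >= 5; older clients use close()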

@@ -273,8 +273,6 @@ async def _rebuild_knowledge_from_chunks(
all_referenced_chunk_ids.update(chunk_ids)
for chunk_ids in relationships_to_rebuild.values():
all_referenced_chunk_ids.update(chunk_ids)
# sort all_referenced_chunk_ids to get a stable order in merge stage
all_referenced_chunk_ids = sorted(all_referenced_chunk_ids)
status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions"
logger.info(status_message)
@@ -464,12 +462,22 @@ async def _get_cached_extraction_results(
):
chunk_id = cache_entry["chunk_id"]
extraction_result = cache_entry["return"]
create_time = cache_entry.get(
"create_time", 0
) # Get creation time, default to 0
valid_entries += 1
# Support multiple LLM caches per chunk
if chunk_id not in cached_results:
cached_results[chunk_id] = []
cached_results[chunk_id].append(extraction_result)
# Store tuple with extraction result and creation time for sorting
cached_results[chunk_id].append((extraction_result, create_time))
# Sort extraction results by create_time for each chunk
for chunk_id in cached_results:
# Sort by create_time (x[1]), then extract only extraction_result (x[0])
cached_results[chunk_id].sort(key=lambda x: x[1])
cached_results[chunk_id] = [item[0] for item in cached_results[chunk_id]]
logger.info(
f"Found {valid_entries} valid cache entries, {len(cached_results)} chunks with results"