Refac: Enhance KG rebuild stability by incorporating create_time into the LLM cache
Parent: a9e10ae810
Commit: 6c2ae40d7d
5 changed files with 172 additions and 35 deletions
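The change follows one pattern across all five storage backends: writes stamp each record with Unix timestamps (create_time only when the key is new, update_time on every upsert), and reads fill in 0 for records written before these fields existed. A minimal, self-contained sketch of that rule, using hypothetical names (store, stamp_records, read_record) rather than the project's classes:

# Illustrative only; not project code.
import time

store: dict[str, dict] = {}

def stamp_records(records: dict[str, dict]) -> None:
    now = int(time.time())
    for key, value in records.items():
        if key in store:            # existing key: only refresh update_time
            value["update_time"] = now
        else:                       # new key: set both timestamps
            value["create_time"] = now
            value["update_time"] = now
    store.update(records)

def read_record(key: str) -> dict | None:
    value = store.get(key)
    if value is None:
        return None
    value = dict(value)                 # copy so callers cannot mutate the store
    value.setdefault("create_time", 0)  # default for records written before this change
    value.setdefault("update_time", 0)
    return value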
@@ -78,22 +78,49 @@ class JsonKVStorage(BaseKVStorage):
             Dictionary containing all stored data
         """
         async with self._storage_lock:
-            return dict(self._data)
+            result = {}
+            for key, value in self._data.items():
+                if value:
+                    # Create a copy to avoid modifying the original data
+                    data = dict(value)
+                    # Ensure time fields are present, provide default values for old data
+                    data.setdefault("create_time", 0)
+                    data.setdefault("update_time", 0)
+                    result[key] = data
+                else:
+                    result[key] = value
+            return result
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
         async with self._storage_lock:
-            return self._data.get(id)
+            result = self._data.get(id)
+            if result:
+                # Create a copy to avoid modifying the original data
+                result = dict(result)
+                # Ensure time fields are present, provide default values for old data
+                result.setdefault("create_time", 0)
+                result.setdefault("update_time", 0)
+                # Ensure _id field contains the clean ID
+                result["_id"] = id
+            return result
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         async with self._storage_lock:
-            return [
-                (
-                    {k: v for k, v in self._data[id].items()}
-                    if self._data.get(id, None)
-                    else None
-                )
-                for id in ids
-            ]
+            results = []
+            for id in ids:
+                data = self._data.get(id, None)
+                if data:
+                    # Create a copy to avoid modifying the original data
+                    result = {k: v for k, v in data.items()}
+                    # Ensure time fields are present, provide default values for old data
+                    result.setdefault("create_time", 0)
+                    result.setdefault("update_time", 0)
+                    # Ensure _id field contains the clean ID
+                    result["_id"] = id
+                    results.append(result)
+                else:
+                    results.append(None)
+            return results
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
         async with self._storage_lock:
@@ -107,13 +134,29 @@ class JsonKVStorage(BaseKVStorage):
         """
         if not data:
             return
+
+        import time
+
+        current_time = int(time.time())  # Get current Unix timestamp
+
         logger.debug(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
-            # For text_chunks namespace, ensure llm_cache_list field exists
-            if "text_chunks" in self.namespace:
-                for chunk_id, chunk_data in data.items():
-                    if "llm_cache_list" not in chunk_data:
-                        chunk_data["llm_cache_list"] = []
+            # Add timestamps to data based on whether key exists
+            for k, v in data.items():
+                # For text_chunks namespace, ensure llm_cache_list field exists
+                if "text_chunks" in self.namespace:
+                    if "llm_cache_list" not in v:
+                        v["llm_cache_list"] = []
+
+                # Add timestamps based on whether key exists
+                if k in self._data:  # Key exists, only update update_time
+                    v["update_time"] = current_time
+                else:  # New key, set both create_time and update_time
+                    v["create_time"] = current_time
+                    v["update_time"] = current_time
+
+                v["_id"] = k
+
             self._data.update(data)
             await set_all_update_flags(self.namespace)
@@ -98,11 +98,21 @@ class MongoKVStorage(BaseKVStorage):
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
         # Unified handling for flattened keys
-        return await self._data.find_one({"_id": id})
+        doc = await self._data.find_one({"_id": id})
+        if doc:
+            # Ensure time fields are present, provide default values for old data
+            doc.setdefault("create_time", 0)
+            doc.setdefault("update_time", 0)
+        return doc
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         cursor = self._data.find({"_id": {"$in": ids}})
-        return await cursor.to_list()
+        docs = await cursor.to_list()
+        # Ensure time fields are present for all documents
+        for doc in docs:
+            doc.setdefault("create_time", 0)
+            doc.setdefault("update_time", 0)
+        return docs
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
         cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
@@ -119,6 +129,9 @@ class MongoKVStorage(BaseKVStorage):
         result = {}
         async for doc in cursor:
             doc_id = doc.pop("_id")
+            # Ensure time fields are present for all documents
+            doc.setdefault("create_time", 0)
+            doc.setdefault("update_time", 0)
             result[doc_id] = doc
         return result
 
@@ -132,6 +145,8 @@ class MongoKVStorage(BaseKVStorage):
         from pymongo import UpdateOne
 
         operations = []
+        current_time = int(time.time())  # Get current Unix timestamp
+
         for k, v in data.items():
             # For text_chunks namespace, ensure llm_cache_list field exists
             if self.namespace.endswith("text_chunks"):
@@ -139,7 +154,20 @@ class MongoKVStorage(BaseKVStorage):
                     v["llm_cache_list"] = []
 
             v["_id"] = k  # Use flattened key as _id
-            operations.append(UpdateOne({"_id": k}, {"$set": v}, upsert=True))
+            v["update_time"] = current_time  # Always update update_time
+
+            operations.append(
+                UpdateOne(
+                    {"_id": k},
+                    {
+                        "$set": v,  # Update all fields including update_time
+                        "$setOnInsert": {
+                            "create_time": current_time
+                        },  # Set create_time only on insert
+                    },
+                    upsert=True,
+                )
+            )
 
         if operations:
             await self._data.bulk_write(operations)
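Unlike the JSON and Redis backends, which check key existence before stamping, the Mongo upsert pushes that decision into the database: $set always refreshes update_time, while $setOnInsert writes create_time only when the upsert actually inserts a new document, so an existing document keeps its original creation time without an extra read. A hedged sketch of the same split with plain pymongo (client, database, and collection names are made up):

import time
from pymongo import MongoClient, UpdateOne

coll = MongoClient()["demo_db"]["demo_kv"]   # hypothetical database/collection
now = int(time.time())

op = UpdateOne(
    {"_id": "chunk-1"},
    {
        "$set": {"content": "v2", "update_time": now},  # refreshed on every upsert
        "$setOnInsert": {"create_time": now},           # written only on first insert
    },
    upsert=True,
)
coll.bulk_write([op])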
@@ -752,6 +752,8 @@ class PGKVStorage(BaseKVStorage):
                     "original_prompt": row.get("original_prompt", ""),
                     "chunk_id": row.get("chunk_id"),
                     "mode": row.get("mode", "default"),
+                    "create_time": row.get("create_time", 0),
+                    "update_time": row.get("update_time", 0),
                 }
                 processed_results[row["id"]] = processed_row
         return processed_results
@@ -767,6 +769,8 @@ class PGKVStorage(BaseKVStorage):
                 except json.JSONDecodeError:
                     llm_cache_list = []
                 row["llm_cache_list"] = llm_cache_list
+                row["create_time"] = row.get("create_time", 0)
+                row["update_time"] = row.get("update_time", 0)
                 processed_results[row["id"]] = row
         return processed_results
 
@@ -791,6 +795,8 @@ class PGKVStorage(BaseKVStorage):
             except json.JSONDecodeError:
                 llm_cache_list = []
             response["llm_cache_list"] = llm_cache_list
+            response["create_time"] = response.get("create_time", 0)
+            response["update_time"] = response.get("update_time", 0)
 
         # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
         if response and is_namespace(
@@ -804,6 +810,8 @@ class PGKVStorage(BaseKVStorage):
                 "original_prompt": response.get("original_prompt", ""),
                 "chunk_id": response.get("chunk_id"),
                 "mode": response.get("mode", "default"),
+                "create_time": response.get("create_time", 0),
+                "update_time": response.get("update_time", 0),
             }
 
         return response if response else None
@@ -827,6 +835,8 @@ class PGKVStorage(BaseKVStorage):
                 except json.JSONDecodeError:
                     llm_cache_list = []
                 result["llm_cache_list"] = llm_cache_list
+                result["create_time"] = result.get("create_time", 0)
+                result["update_time"] = result.get("update_time", 0)
 
             # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
             if results and is_namespace(
@@ -842,6 +852,8 @@ class PGKVStorage(BaseKVStorage):
                         "original_prompt": row.get("original_prompt", ""),
                         "chunk_id": row.get("chunk_id"),
                         "mode": row.get("mode", "default"),
+                        "create_time": row.get("create_time", 0),
+                        "update_time": row.get("update_time", 0),
                     }
                     processed_results.append(processed_row)
         return processed_results
@@ -2941,10 +2953,12 @@ SQL_TEMPLATES = {
     """,
     "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
                                 chunk_order_index, full_doc_id, file_path,
-                                COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list
+                                COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
+                                create_time, update_time
                                 FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
                              """,
-    "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type
+    "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
+                                create_time, update_time
                                 FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
                              """,
     "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id
@@ -2955,10 +2969,12 @@ SQL_TEMPLATES = {
     """,
     "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
                                 chunk_order_index, full_doc_id, file_path,
-                                COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list
+                                COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
+                                create_time, update_time
                                 FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
                              """,
-    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type
+    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
+                                create_time, update_time
                                 FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
                              """,
     "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
@@ -132,7 +132,13 @@ class RedisKVStorage(BaseKVStorage):
         async with self._get_redis_connection() as redis:
             try:
                 data = await redis.get(f"{self.namespace}:{id}")
-                return json.loads(data) if data else None
+                if data:
+                    result = json.loads(data)
+                    # Ensure time fields are present, provide default values for old data
+                    result.setdefault("create_time", 0)
+                    result.setdefault("update_time", 0)
+                    return result
+                return None
             except json.JSONDecodeError as e:
                 logger.error(f"JSON decode error for id {id}: {e}")
                 return None
@@ -144,7 +150,19 @@ class RedisKVStorage(BaseKVStorage):
                 for id in ids:
                     pipe.get(f"{self.namespace}:{id}")
                 results = await pipe.execute()
-                return [json.loads(result) if result else None for result in results]
+
+                processed_results = []
+                for result in results:
+                    if result:
+                        data = json.loads(result)
+                        # Ensure time fields are present for all documents
+                        data.setdefault("create_time", 0)
+                        data.setdefault("update_time", 0)
+                        processed_results.append(data)
+                    else:
+                        processed_results.append(None)
+
+                return processed_results
             except json.JSONDecodeError as e:
                 logger.error(f"JSON decode error in batch get: {e}")
                 return [None] * len(ids)
@@ -176,7 +194,11 @@ class RedisKVStorage(BaseKVStorage):
                     # Extract the ID part (after namespace:)
                     key_id = key.split(":", 1)[1]
                     try:
-                        result[key_id] = json.loads(value)
+                        data = json.loads(value)
+                        # Ensure time fields are present for all documents
+                        data.setdefault("create_time", 0)
+                        data.setdefault("update_time", 0)
+                        result[key_id] = data
                     except json.JSONDecodeError as e:
                         logger.error(f"JSON decode error for key {key}: {e}")
                         continue
@@ -200,21 +222,41 @@ class RedisKVStorage(BaseKVStorage):
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if not data:
             return
+
+        import time
+
+        current_time = int(time.time())  # Get current Unix timestamp
+
         async with self._get_redis_connection() as redis:
             try:
-                # For text_chunks namespace, ensure llm_cache_list field exists
-                if "text_chunks" in self.namespace:
-                    for chunk_id, chunk_data in data.items():
-                        if "llm_cache_list" not in chunk_data:
-                            chunk_data["llm_cache_list"] = []
+                # Check which keys already exist to determine create vs update
+                pipe = redis.pipeline()
+                for k in data.keys():
+                    pipe.exists(f"{self.namespace}:{k}")
+                exists_results = await pipe.execute()
+
+                # Add timestamps to data
+                for i, (k, v) in enumerate(data.items()):
+                    # For text_chunks namespace, ensure llm_cache_list field exists
+                    if "text_chunks" in self.namespace:
+                        if "llm_cache_list" not in v:
+                            v["llm_cache_list"] = []
+
+                    # Add timestamps based on whether key exists
+                    if exists_results[i]:  # Key exists, only update update_time
+                        v["update_time"] = current_time
+                    else:  # New key, set both create_time and update_time
+                        v["create_time"] = current_time
+                        v["update_time"] = current_time
+
+                    v["_id"] = k
+
+                # Store the data
                 pipe = redis.pipeline()
                 for k, v in data.items():
                     pipe.set(f"{self.namespace}:{k}", json.dumps(v))
                 await pipe.execute()
 
-                for k in data:
-                    data[k]["_id"] = k
             except json.JSONEncodeError as e:
                 logger.error(f"JSON encode error during upsert: {e}")
                 raise
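Redis has no direct equivalent of $setOnInsert for JSON string values, so the hunk above batches EXISTS checks into one pipeline round trip, stamps each value from the results, then writes everything in a second pipeline. The same pattern in isolation, as a hedged sketch using redis.asyncio (function name, connection setup, and namespace are assumptions, not the project's module):

import json
import time
from redis.asyncio import Redis

async def stamped_upsert(redis: Redis, namespace: str, data: dict[str, dict]) -> None:
    now = int(time.time())

    pipe = redis.pipeline()
    for key in data:
        pipe.exists(f"{namespace}:{key}")       # queue one EXISTS per key
    exists_flags = await pipe.execute()         # single round trip

    pipe = redis.pipeline()
    for flag, (key, value) in zip(exists_flags, data.items()):
        if not flag:                            # new key: stamp creation time
            value["create_time"] = now
        value["update_time"] = now              # always refresh update_time
        pipe.set(f"{namespace}:{key}", json.dumps(value))
    await pipe.execute()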
@@ -273,8 +273,6 @@ async def _rebuild_knowledge_from_chunks(
         all_referenced_chunk_ids.update(chunk_ids)
     for chunk_ids in relationships_to_rebuild.values():
         all_referenced_chunk_ids.update(chunk_ids)
-    # sort all_referenced_chunk_ids to get a stable order in merge stage
-    all_referenced_chunk_ids = sorted(all_referenced_chunk_ids)
 
     status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions"
     logger.info(status_message)
@@ -464,12 +462,22 @@ async def _get_cached_extraction_results(
         ):
             chunk_id = cache_entry["chunk_id"]
             extraction_result = cache_entry["return"]
+            create_time = cache_entry.get(
+                "create_time", 0
+            )  # Get creation time, default to 0
             valid_entries += 1
 
             # Support multiple LLM caches per chunk
             if chunk_id not in cached_results:
                 cached_results[chunk_id] = []
-            cached_results[chunk_id].append(extraction_result)
+            # Store tuple with extraction result and creation time for sorting
+            cached_results[chunk_id].append((extraction_result, create_time))
+
+    # Sort extraction results by create_time for each chunk
+    for chunk_id in cached_results:
+        # Sort by create_time (x[1]), then extract only extraction_result (x[0])
+        cached_results[chunk_id].sort(key=lambda x: x[1])
+        cached_results[chunk_id] = [item[0] for item in cached_results[chunk_id]]
 
     logger.info(
         f"Found {valid_entries} valid cache entries, {len(cached_results)} chunks with results"
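The net effect of the last two hunks: instead of sorting chunk IDs for stability, the rebuild keeps (extraction_result, create_time) pairs per chunk and replays cached extractions in creation order. A tiny worked example of the sort-then-strip step (values are made up):

cached_results = {
    "chunk-1": [("second extraction", 1700000200), ("first extraction", 1700000100)],
}
for chunk_id in cached_results:
    cached_results[chunk_id].sort(key=lambda pair: pair[1])             # order by create_time
    cached_results[chunk_id] = [res for res, _ in cached_results[chunk_id]]

assert cached_results["chunk-1"] == ["first extraction", "second extraction"]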