diff --git a/env.example b/env.example index 19153ecc..98e4790b 100644 --- a/env.example +++ b/env.example @@ -114,15 +114,6 @@ EMBEDDING_BINDING_HOST=http://localhost:11434 # LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage -### TiDB Configuration (Deprecated) -# TIDB_HOST=localhost -# TIDB_PORT=4000 -# TIDB_USER=your_username -# TIDB_PASSWORD='your_password' -# TIDB_DATABASE=your_database -### separating all data from difference Lightrag instances(deprecating) -# TIDB_WORKSPACE=default - ### PostgreSQL Configuration POSTGRES_HOST=localhost POSTGRES_PORT=5432 @@ -130,7 +121,7 @@ POSTGRES_USER=your_username POSTGRES_PASSWORD='your_password' POSTGRES_DATABASE=your_database POSTGRES_MAX_CONNECTIONS=12 -### separating all data from difference Lightrag instances(deprecating) +### separating all data from difference Lightrag instances # POSTGRES_WORKSPACE=default ### Neo4j Configuration @@ -146,14 +137,15 @@ NEO4J_PASSWORD='your_password' # AGE_POSTGRES_PORT=8529 # AGE Graph Name(apply to PostgreSQL and independent AGM) -### AGE_GRAPH_NAME is precated +### AGE_GRAPH_NAME is deprecated # AGE_GRAPH_NAME=lightrag ### MongoDB Configuration MONGO_URI=mongodb://root:root@localhost:27017/ MONGO_DATABASE=LightRAG ### separating all data from difference Lightrag instances(deprecating) -# MONGODB_GRAPH=false +### separating all data from difference Lightrag instances +# MONGODB_WORKSPACE=default ### Milvus Configuration MILVUS_URI=http://localhost:19530 diff --git a/lightrag/base.py b/lightrag/base.py index 36c3ff59..7820b4da 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -634,6 +634,8 @@ class DocProcessingStatus: """ISO format timestamp when document was last updated""" chunks_count: int | None = None """Number of chunks after splitting, used for processing""" + chunks_list: list[str] | None = field(default_factory=list) + """List of chunk IDs associated with this document, used for deletion""" error: str | None = None """Error message if failed""" metadata: dict[str, Any] = field(default_factory=dict) diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index f8387ad8..ab6ab390 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -118,6 +118,10 @@ class JsonDocStatusStorage(DocStatusStorage): return logger.debug(f"Inserting {len(data)} records to {self.namespace}") async with self._storage_lock: + # Ensure chunks_list field exists for new documents + for doc_id, doc_data in data.items(): + if "chunks_list" not in doc_data: + doc_data["chunks_list"] = [] self._data.update(data) await set_all_update_flags(self.namespace) diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index d6e2cb70..98835f8c 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -78,22 +78,49 @@ class JsonKVStorage(BaseKVStorage): Dictionary containing all stored data """ async with self._storage_lock: - return dict(self._data) + result = {} + for key, value in self._data.items(): + if value: + # Create a copy to avoid modifying the original data + data = dict(value) + # Ensure time fields are present, provide default values for old data + data.setdefault("create_time", 0) + data.setdefault("update_time", 0) + result[key] = data + else: + result[key] = value + return result async def get_by_id(self, id: str) -> dict[str, Any] | None: async with self._storage_lock: - return self._data.get(id) + result = self._data.get(id) + if result: + # Create a copy to avoid modifying 
the original data + result = dict(result) + # Ensure time fields are present, provide default values for old data + result.setdefault("create_time", 0) + result.setdefault("update_time", 0) + # Ensure _id field contains the clean ID + result["_id"] = id + return result async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async with self._storage_lock: - return [ - ( - {k: v for k, v in self._data[id].items()} - if self._data.get(id, None) - else None - ) - for id in ids - ] + results = [] + for id in ids: + data = self._data.get(id, None) + if data: + # Create a copy to avoid modifying the original data + result = {k: v for k, v in data.items()} + # Ensure time fields are present, provide default values for old data + result.setdefault("create_time", 0) + result.setdefault("update_time", 0) + # Ensure _id field contains the clean ID + result["_id"] = id + results.append(result) + else: + results.append(None) + return results async def filter_keys(self, keys: set[str]) -> set[str]: async with self._storage_lock: @@ -107,8 +134,29 @@ class JsonKVStorage(BaseKVStorage): """ if not data: return + + import time + + current_time = int(time.time()) # Get current Unix timestamp + logger.debug(f"Inserting {len(data)} records to {self.namespace}") async with self._storage_lock: + # Add timestamps to data based on whether key exists + for k, v in data.items(): + # For text_chunks namespace, ensure llm_cache_list field exists + if "text_chunks" in self.namespace: + if "llm_cache_list" not in v: + v["llm_cache_list"] = [] + + # Add timestamps based on whether key exists + if k in self._data: # Key exists, only update update_time + v["update_time"] = current_time + else: # New key, set both create_time and update_time + v["create_time"] = current_time + v["update_time"] = current_time + + v["_id"] = k + self._data.update(data) await set_all_update_flags(self.namespace) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 38baff5c..11105e82 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -98,11 +98,21 @@ class MongoKVStorage(BaseKVStorage): async def get_by_id(self, id: str) -> dict[str, Any] | None: # Unified handling for flattened keys - return await self._data.find_one({"_id": id}) + doc = await self._data.find_one({"_id": id}) + if doc: + # Ensure time fields are present, provide default values for old data + doc.setdefault("create_time", 0) + doc.setdefault("update_time", 0) + return doc async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: cursor = self._data.find({"_id": {"$in": ids}}) - return await cursor.to_list() + docs = await cursor.to_list() + # Ensure time fields are present for all documents + for doc in docs: + doc.setdefault("create_time", 0) + doc.setdefault("update_time", 0) + return docs async def filter_keys(self, keys: set[str]) -> set[str]: cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1}) @@ -119,6 +129,9 @@ class MongoKVStorage(BaseKVStorage): result = {} async for doc in cursor: doc_id = doc.pop("_id") + # Ensure time fields are present for all documents + doc.setdefault("create_time", 0) + doc.setdefault("update_time", 0) result[doc_id] = doc return result @@ -132,9 +145,29 @@ class MongoKVStorage(BaseKVStorage): from pymongo import UpdateOne operations = [] + current_time = int(time.time()) # Get current Unix timestamp + for k, v in data.items(): + # For text_chunks namespace, ensure llm_cache_list field exists + if self.namespace.endswith("text_chunks"): + if "llm_cache_list" not in 
v: + v["llm_cache_list"] = [] + v["_id"] = k # Use flattened key as _id - operations.append(UpdateOne({"_id": k}, {"$set": v}, upsert=True)) + v["update_time"] = current_time # Always update update_time + + operations.append( + UpdateOne( + {"_id": k}, + { + "$set": v, # Update all fields including update_time + "$setOnInsert": { + "create_time": current_time + }, # Set create_time only on insert + }, + upsert=True, + ) + ) if operations: await self._data.bulk_write(operations) @@ -247,6 +280,9 @@ class MongoDocStatusStorage(DocStatusStorage): return update_tasks: list[Any] = [] for k, v in data.items(): + # Ensure chunks_list field exists and is an array + if "chunks_list" not in v: + v["chunks_list"] = [] data[k]["_id"] = k update_tasks.append( self._data.update_one({"_id": k}, {"$set": v}, upsert=True) @@ -279,6 +315,7 @@ class MongoDocStatusStorage(DocStatusStorage): updated_at=doc.get("updated_at"), chunks_count=doc.get("chunks_count", -1), file_path=doc.get("file_path", doc["_id"]), + chunks_list=doc.get("chunks_list", []), ) for doc in result } diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 28a86b6e..dc9f293c 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -136,6 +136,52 @@ class PostgreSQLDB: except Exception as e: logger.warning(f"Failed to add chunk_id column to LIGHTRAG_LLM_CACHE: {e}") + async def _migrate_llm_cache_add_cache_type(self): + """Add cache_type column to LIGHTRAG_LLM_CACHE table if it doesn't exist""" + try: + # Check if cache_type column exists + check_column_sql = """ + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'lightrag_llm_cache' + AND column_name = 'cache_type' + """ + + column_info = await self.query(check_column_sql) + if not column_info: + logger.info("Adding cache_type column to LIGHTRAG_LLM_CACHE table") + add_column_sql = """ + ALTER TABLE LIGHTRAG_LLM_CACHE + ADD COLUMN cache_type VARCHAR(32) NULL + """ + await self.execute(add_column_sql) + logger.info( + "Successfully added cache_type column to LIGHTRAG_LLM_CACHE table" + ) + + # Migrate existing data: extract cache_type from flattened keys + logger.info( + "Migrating existing LLM cache data to populate cache_type field" + ) + update_sql = """ + UPDATE LIGHTRAG_LLM_CACHE + SET cache_type = CASE + WHEN id LIKE '%:%:%' THEN split_part(id, ':', 2) + ELSE 'extract' + END + WHERE cache_type IS NULL + """ + await self.execute(update_sql) + logger.info("Successfully migrated existing LLM cache data") + else: + logger.info( + "cache_type column already exists in LIGHTRAG_LLM_CACHE table" + ) + except Exception as e: + logger.warning( + f"Failed to add cache_type column to LIGHTRAG_LLM_CACHE: {e}" + ) + async def _migrate_timestamp_columns(self): """Migrate timestamp columns in tables to timezone-aware types, assuming original data is in UTC time""" # Tables and columns that need migration @@ -301,15 +347,17 @@ class PostgreSQLDB: record["mode"], record["original_prompt"] ) + # Determine cache_type based on mode + cache_type = "extract" if record["mode"] == "default" else "unknown" + # Generate new flattened key - cache_type = "extract" # Default type new_key = f"{record['mode']}:{cache_type}:{new_hash}" - # Insert new format data + # Insert new format data with cache_type field insert_sql = """ INSERT INTO LIGHTRAG_LLM_CACHE - (workspace, id, mode, original_prompt, return_value, chunk_id, create_time, update_time) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + (workspace, id, mode, original_prompt, return_value, 
chunk_id, cache_type, create_time, update_time) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ON CONFLICT (workspace, mode, id) DO NOTHING """ @@ -322,6 +370,7 @@ class PostgreSQLDB: "original_prompt": record["original_prompt"], "return_value": record["return_value"], "chunk_id": record["chunk_id"], + "cache_type": cache_type, # Add cache_type field "create_time": record["create_time"], "update_time": record["update_time"], }, @@ -357,6 +406,68 @@ class PostgreSQLDB: logger.error(f"LLM cache migration failed: {e}") # Don't raise exception, allow system to continue startup + async def _migrate_doc_status_add_chunks_list(self): + """Add chunks_list column to LIGHTRAG_DOC_STATUS table if it doesn't exist""" + try: + # Check if chunks_list column exists + check_column_sql = """ + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'lightrag_doc_status' + AND column_name = 'chunks_list' + """ + + column_info = await self.query(check_column_sql) + if not column_info: + logger.info("Adding chunks_list column to LIGHTRAG_DOC_STATUS table") + add_column_sql = """ + ALTER TABLE LIGHTRAG_DOC_STATUS + ADD COLUMN chunks_list JSONB NULL DEFAULT '[]'::jsonb + """ + await self.execute(add_column_sql) + logger.info( + "Successfully added chunks_list column to LIGHTRAG_DOC_STATUS table" + ) + else: + logger.info( + "chunks_list column already exists in LIGHTRAG_DOC_STATUS table" + ) + except Exception as e: + logger.warning( + f"Failed to add chunks_list column to LIGHTRAG_DOC_STATUS: {e}" + ) + + async def _migrate_text_chunks_add_llm_cache_list(self): + """Add llm_cache_list column to LIGHTRAG_DOC_CHUNKS table if it doesn't exist""" + try: + # Check if llm_cache_list column exists + check_column_sql = """ + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'lightrag_doc_chunks' + AND column_name = 'llm_cache_list' + """ + + column_info = await self.query(check_column_sql) + if not column_info: + logger.info("Adding llm_cache_list column to LIGHTRAG_DOC_CHUNKS table") + add_column_sql = """ + ALTER TABLE LIGHTRAG_DOC_CHUNKS + ADD COLUMN llm_cache_list JSONB NULL DEFAULT '[]'::jsonb + """ + await self.execute(add_column_sql) + logger.info( + "Successfully added llm_cache_list column to LIGHTRAG_DOC_CHUNKS table" + ) + else: + logger.info( + "llm_cache_list column already exists in LIGHTRAG_DOC_CHUNKS table" + ) + except Exception as e: + logger.warning( + f"Failed to add llm_cache_list column to LIGHTRAG_DOC_CHUNKS: {e}" + ) + async def check_tables(self): # First create all tables for k, v in TABLES.items(): @@ -408,6 +519,15 @@ class PostgreSQLDB: logger.error(f"PostgreSQL, Failed to migrate LLM cache chunk_id field: {e}") # Don't throw an exception, allow the initialization process to continue + # Migrate LLM cache table to add cache_type field if needed + try: + await self._migrate_llm_cache_add_cache_type() + except Exception as e: + logger.error( + f"PostgreSQL, Failed to migrate LLM cache cache_type field: {e}" + ) + # Don't throw an exception, allow the initialization process to continue + # Finally, attempt to migrate old doc chunks data if needed try: await self._migrate_doc_chunks_to_vdb_chunks() @@ -421,6 +541,22 @@ class PostgreSQLDB: except Exception as e: logger.error(f"PostgreSQL, LLM cache migration failed: {e}") + # Migrate doc status to add chunks_list field if needed + try: + await self._migrate_doc_status_add_chunks_list() + except Exception as e: + logger.error( + f"PostgreSQL, Failed to migrate doc status chunks_list field: {e}" + ) + + # 
Migrate text chunks to add llm_cache_list field if needed
+        try:
+            await self._migrate_text_chunks_add_llm_cache_list()
+        except Exception as e:
+            logger.error(
+                f"PostgreSQL, Failed to migrate text chunks llm_cache_list field: {e}"
+            )
+
     async def query(
         self,
         sql: str,
@@ -608,24 +744,36 @@ class PGKVStorage(BaseKVStorage):
            if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
                processed_results = {}
                for row in results:
-                    # Parse flattened key to extract cache_type
-                    key_parts = row["id"].split(":")
-                    cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown"
-
                    # Map field names and add cache_type for compatibility
                    processed_row = {
                        **row,
-                        "return": row.get(
-                            "return_value", ""
-                        ),  # Map return_value to return
-                        "cache_type": cache_type,  # Add cache_type from key
+                        "return": row.get("return_value", ""),
+                        "cache_type": row.get("cache_type", "unknown"),
+                        "original_prompt": row.get("original_prompt", ""),
                        "chunk_id": row.get("chunk_id"),
                        "mode": row.get("mode", "default"),
+                        "create_time": row.get("create_time", 0),
+                        "update_time": row.get("update_time", 0),
                    }
                    processed_results[row["id"]] = processed_row
                return processed_results
+            # For text_chunks namespace, parse llm_cache_list JSON string back to list
+            if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
+                processed_results = {}
+                for row in results:
+                    llm_cache_list = row.get("llm_cache_list", [])
+                    if isinstance(llm_cache_list, str):
+                        try:
+                            llm_cache_list = json.loads(llm_cache_list)
+                        except json.JSONDecodeError:
+                            llm_cache_list = []
+                    row["llm_cache_list"] = llm_cache_list
+                    row["create_time"] = row.get("create_time", 0)
+                    row["update_time"] = row.get("update_time", 0)
+                    processed_results[row["id"]] = row
+                return processed_results
+
            # For other namespaces, return as-is
            return {row["id"]: row for row in results}
        except Exception as e:
@@ -637,6 +785,35 @@
        sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
        params = {"workspace": self.db.workspace, "id": id}
        response = await self.db.query(sql, params)
+
+        if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
+            # Parse llm_cache_list JSON string back to list
+            llm_cache_list = response.get("llm_cache_list", [])
+            if isinstance(llm_cache_list, str):
+                try:
+                    llm_cache_list = json.loads(llm_cache_list)
+                except json.JSONDecodeError:
+                    llm_cache_list = []
+            response["llm_cache_list"] = llm_cache_list
+            response["create_time"] = response.get("create_time", 0)
+            response["update_time"] = response.get("update_time", 0)
+
+        # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
+        if response and is_namespace(
+            self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
+        ):
+            # Map field names and add cache_type for compatibility
+            response = {
+                **response,
+                "return": response.get("return_value", ""),
+                "cache_type": response.get("cache_type"),
+                "original_prompt": response.get("original_prompt", ""),
+                "chunk_id": response.get("chunk_id"),
+                "mode": response.get("mode", "default"),
+                "create_time": response.get("create_time", 0),
+                "update_time": response.get("update_time", 0),
+            }
+
        return response if response else None
 
    # Query by id
@@ -646,13 +823,42 @@ class PGKVStorage(BaseKVStorage):
            ids=",".join([f"'{id}'" for id in ids])
        )
        params = {"workspace": self.db.workspace}
-        return await self.db.query(sql, params, multirows=True)
+        results = await self.db.query(sql, params, multirows=True)
 
-    async def get_by_status(self, status: str) -> 
Union[list[dict[str, Any]], None]: - """Specifically for llm_response_cache.""" - SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] - params = {"workspace": self.db.workspace, "status": status} - return await self.db.query(SQL, params, multirows=True) + if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): + # Parse llm_cache_list JSON string back to list for each result + for result in results: + llm_cache_list = result.get("llm_cache_list", []) + if isinstance(llm_cache_list, str): + try: + llm_cache_list = json.loads(llm_cache_list) + except json.JSONDecodeError: + llm_cache_list = [] + result["llm_cache_list"] = llm_cache_list + result["create_time"] = result.get("create_time", 0) + result["update_time"] = result.get("update_time", 0) + + # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results + if results and is_namespace( + self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ): + processed_results = [] + for row in results: + # Map field names and add cache_type for compatibility + processed_row = { + **row, + "return": row.get("return_value", ""), + "cache_type": row.get("cache_type"), + "original_prompt": row.get("original_prompt", ""), + "chunk_id": row.get("chunk_id"), + "mode": row.get("mode", "default"), + "create_time": row.get("create_time", 0), + "update_time": row.get("update_time", 0), + } + processed_results.append(processed_row) + return processed_results + + return results if results else [] async def filter_keys(self, keys: set[str]) -> set[str]: """Filter out duplicated content""" @@ -693,6 +899,7 @@ class PGKVStorage(BaseKVStorage): "full_doc_id": v["full_doc_id"], "content": v["content"], "file_path": v["file_path"], + "llm_cache_list": json.dumps(v.get("llm_cache_list", [])), "create_time": current_time, "update_time": current_time, } @@ -716,6 +923,9 @@ class PGKVStorage(BaseKVStorage): "return_value": v["return"], "mode": v.get("mode", "default"), # Get mode from data "chunk_id": v.get("chunk_id"), + "cache_type": v.get( + "cache_type", "extract" + ), # Get cache_type from data } await self.db.execute(upsert_sql, _data) @@ -1140,6 +1350,14 @@ class PGDocStatusStorage(DocStatusStorage): if result is None or result == []: return None else: + # Parse chunks_list JSON string back to list + chunks_list = result[0].get("chunks_list", []) + if isinstance(chunks_list, str): + try: + chunks_list = json.loads(chunks_list) + except json.JSONDecodeError: + chunks_list = [] + return dict( content=result[0]["content"], content_length=result[0]["content_length"], @@ -1149,6 +1367,7 @@ class PGDocStatusStorage(DocStatusStorage): created_at=result[0]["created_at"], updated_at=result[0]["updated_at"], file_path=result[0]["file_path"], + chunks_list=chunks_list, ) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: @@ -1163,19 +1382,32 @@ class PGDocStatusStorage(DocStatusStorage): if not results: return [] - return [ - { - "content": row["content"], - "content_length": row["content_length"], - "content_summary": row["content_summary"], - "status": row["status"], - "chunks_count": row["chunks_count"], - "created_at": row["created_at"], - "updated_at": row["updated_at"], - "file_path": row["file_path"], - } - for row in results - ] + + processed_results = [] + for row in results: + # Parse chunks_list JSON string back to list + chunks_list = row.get("chunks_list", []) + if isinstance(chunks_list, str): + try: + chunks_list = json.loads(chunks_list) + except json.JSONDecodeError: + chunks_list = [] + 
+ processed_results.append( + { + "content": row["content"], + "content_length": row["content_length"], + "content_summary": row["content_summary"], + "status": row["status"], + "chunks_count": row["chunks_count"], + "created_at": row["created_at"], + "updated_at": row["updated_at"], + "file_path": row["file_path"], + "chunks_list": chunks_list, + } + ) + + return processed_results async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" @@ -1196,8 +1428,18 @@ class PGDocStatusStorage(DocStatusStorage): sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" params = {"workspace": self.db.workspace, "status": status.value} result = await self.db.query(sql, params, True) - docs_by_status = { - element["id"]: DocProcessingStatus( + + docs_by_status = {} + for element in result: + # Parse chunks_list JSON string back to list + chunks_list = element.get("chunks_list", []) + if isinstance(chunks_list, str): + try: + chunks_list = json.loads(chunks_list) + except json.JSONDecodeError: + chunks_list = [] + + docs_by_status[element["id"]] = DocProcessingStatus( content=element["content"], content_summary=element["content_summary"], content_length=element["content_length"], @@ -1206,9 +1448,9 @@ class PGDocStatusStorage(DocStatusStorage): updated_at=element["updated_at"], chunks_count=element["chunks_count"], file_path=element["file_path"], + chunks_list=chunks_list, ) - for element in result - } + return docs_by_status async def index_done_callback(self) -> None: @@ -1272,10 +1514,10 @@ class PGDocStatusStorage(DocStatusStorage): logger.warning(f"Unable to parse datetime string: {dt_str}") return None - # Modified SQL to include created_at and updated_at in both INSERT and UPDATE operations - # Both fields are updated from the input data in both INSERT and UPDATE cases - sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path,created_at,updated_at) - values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10) + # Modified SQL to include created_at, updated_at, and chunks_list in both INSERT and UPDATE operations + # All fields are updated from the input data in both INSERT and UPDATE cases + sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path,chunks_list,created_at,updated_at) + values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11) on conflict(id,workspace) do update set content = EXCLUDED.content, content_summary = EXCLUDED.content_summary, @@ -1283,6 +1525,7 @@ class PGDocStatusStorage(DocStatusStorage): chunks_count = EXCLUDED.chunks_count, status = EXCLUDED.status, file_path = EXCLUDED.file_path, + chunks_list = EXCLUDED.chunks_list, created_at = EXCLUDED.created_at, updated_at = EXCLUDED.updated_at""" for k, v in data.items(): @@ -1290,7 +1533,7 @@ class PGDocStatusStorage(DocStatusStorage): created_at = parse_datetime(v.get("created_at")) updated_at = parse_datetime(v.get("updated_at")) - # chunks_count is optional + # chunks_count and chunks_list are optional await self.db.execute( sql, { @@ -1302,6 +1545,7 @@ class PGDocStatusStorage(DocStatusStorage): "chunks_count": v["chunks_count"] if "chunks_count" in v else -1, "status": v["status"], "file_path": v["file_path"], + "chunks_list": json.dumps(v.get("chunks_list", [])), "created_at": created_at, # Use the converted datetime object "updated_at": updated_at, # Use the converted datetime object }, @@ -2620,6 +2864,7 @@ TABLES = { tokens INTEGER, content TEXT, file_path 
VARCHAR(256), + llm_cache_list JSONB NULL DEFAULT '[]'::jsonb, create_time TIMESTAMP(0) WITH TIME ZONE, update_time TIMESTAMP(0) WITH TIME ZONE, CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id) @@ -2692,6 +2937,7 @@ TABLES = { chunks_count int4 NULL, status varchar(64) NULL, file_path TEXT NULL, + chunks_list JSONB NULL DEFAULT '[]'::jsonb, created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL, updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL, CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id) @@ -2706,24 +2952,30 @@ SQL_TEMPLATES = { FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2 """, "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, - chunk_order_index, full_doc_id, file_path + chunk_order_index, full_doc_id, file_path, + COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list, + create_time, update_time FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2 """, - "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id + "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, + create_time, update_time FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2 """, - "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id + "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3 """, "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids}) """, "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, - chunk_order_index, full_doc_id, file_path + chunk_order_index, full_doc_id, file_path, + COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list, + create_time, update_time FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids}) """, - "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id - FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode= IN ({ids}) + "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, + create_time, update_time + FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids}) """, "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})", "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace) @@ -2731,25 +2983,27 @@ SQL_TEMPLATES = { ON CONFLICT (workspace,id) DO UPDATE SET content = $2, update_time = CURRENT_TIMESTAMP """, - "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id) - VALUES ($1, $2, $3, $4, $5, $6) + "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type) + VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (workspace,mode,id) DO UPDATE SET original_prompt = EXCLUDED.original_prompt, return_value=EXCLUDED.return_value, mode=EXCLUDED.mode, chunk_id=EXCLUDED.chunk_id, + cache_type=EXCLUDED.cache_type, update_time = CURRENT_TIMESTAMP """, "upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, - chunk_order_index, full_doc_id, content, file_path, + chunk_order_index, full_doc_id, content, file_path, llm_cache_list, create_time, update_time) - VALUES ($1, $2, $3, $4, $5, $6, 
$7, $8, $9) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) ON CONFLICT (workspace,id) DO UPDATE SET tokens=EXCLUDED.tokens, chunk_order_index=EXCLUDED.chunk_order_index, full_doc_id=EXCLUDED.full_doc_id, content = EXCLUDED.content, file_path=EXCLUDED.file_path, + llm_cache_list=EXCLUDED.llm_cache_list, update_time = EXCLUDED.update_time """, # SQL for VectorStorage diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 5be9f0e6..dba228ca 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -132,7 +132,13 @@ class RedisKVStorage(BaseKVStorage): async with self._get_redis_connection() as redis: try: data = await redis.get(f"{self.namespace}:{id}") - return json.loads(data) if data else None + if data: + result = json.loads(data) + # Ensure time fields are present, provide default values for old data + result.setdefault("create_time", 0) + result.setdefault("update_time", 0) + return result + return None except json.JSONDecodeError as e: logger.error(f"JSON decode error for id {id}: {e}") return None @@ -144,7 +150,19 @@ class RedisKVStorage(BaseKVStorage): for id in ids: pipe.get(f"{self.namespace}:{id}") results = await pipe.execute() - return [json.loads(result) if result else None for result in results] + + processed_results = [] + for result in results: + if result: + data = json.loads(result) + # Ensure time fields are present for all documents + data.setdefault("create_time", 0) + data.setdefault("update_time", 0) + processed_results.append(data) + else: + processed_results.append(None) + + return processed_results except json.JSONDecodeError as e: logger.error(f"JSON decode error in batch get: {e}") return [None] * len(ids) @@ -176,7 +194,11 @@ class RedisKVStorage(BaseKVStorage): # Extract the ID part (after namespace:) key_id = key.split(":", 1)[1] try: - result[key_id] = json.loads(value) + data = json.loads(value) + # Ensure time fields are present for all documents + data.setdefault("create_time", 0) + data.setdefault("update_time", 0) + result[key_id] = data except json.JSONDecodeError as e: logger.error(f"JSON decode error for key {key}: {e}") continue @@ -200,15 +222,41 @@ class RedisKVStorage(BaseKVStorage): async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if not data: return + + import time + + current_time = int(time.time()) # Get current Unix timestamp + async with self._get_redis_connection() as redis: try: + # Check which keys already exist to determine create vs update + pipe = redis.pipeline() + for k in data.keys(): + pipe.exists(f"{self.namespace}:{k}") + exists_results = await pipe.execute() + + # Add timestamps to data + for i, (k, v) in enumerate(data.items()): + # For text_chunks namespace, ensure llm_cache_list field exists + if "text_chunks" in self.namespace: + if "llm_cache_list" not in v: + v["llm_cache_list"] = [] + + # Add timestamps based on whether key exists + if exists_results[i]: # Key exists, only update update_time + v["update_time"] = current_time + else: # New key, set both create_time and update_time + v["create_time"] = current_time + v["update_time"] = current_time + + v["_id"] = k + + # Store the data pipe = redis.pipeline() for k, v in data.items(): pipe.set(f"{self.namespace}:{k}", json.dumps(v)) await pipe.execute() - for k in data: - data[k]["_id"] = k except json.JSONEncodeError as e: logger.error(f"JSON encode error during upsert: {e}") raise @@ -601,6 +649,11 @@ class RedisDocStatusStorage(DocStatusStorage): logger.debug(f"Inserting {len(data)} records to {self.namespace}") 
async with self._get_redis_connection() as redis: try: + # Ensure chunks_list field exists for new documents + for doc_id, doc_data in data.items(): + if "chunks_list" not in doc_data: + doc_data["chunks_list"] = [] + pipe = redis.pipeline() for k, v in data.items(): pipe.set(f"{self.namespace}:{k}", json.dumps(v)) diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 06ec1cd5..d60bb1f6 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -520,11 +520,6 @@ class TiDBVectorDBStorage(BaseVectorStorage): } await self.db.execute(SQL_TEMPLATES["upsert_relationship"], param) - async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: - SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] - params = {"workspace": self.db.workspace, "status": status} - return await self.db.query(SQL, params, multirows=True) - async def delete(self, ids: list[str]) -> None: """Delete vectors with specified IDs from the storage. diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 132075d6..2ab9f89a 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -349,6 +349,7 @@ class LightRAG: # Fix global_config now global_config = asdict(self) + _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) logger.debug(f"LightRAG init with param:\n {_print_config}\n") @@ -952,6 +953,7 @@ class LightRAG: **dp, "full_doc_id": doc_id, "file_path": file_path, # Add file path to each chunk + "llm_cache_list": [], # Initialize empty LLM cache list for each chunk } for dp in self.chunking_func( self.tokenizer, @@ -963,14 +965,17 @@ class LightRAG: ) } - # Process document (text chunks and full docs) in parallel - # Create tasks with references for potential cancellation + # Process document in two stages + # Stage 1: Process text chunks and docs (parallel execution) doc_status_task = asyncio.create_task( self.doc_status.upsert( { doc_id: { "status": DocStatus.PROCESSING, "chunks_count": len(chunks), + "chunks_list": list( + chunks.keys() + ), # Save chunks list "content": status_doc.content, "content_summary": status_doc.content_summary, "content_length": status_doc.content_length, @@ -986,11 +991,6 @@ class LightRAG: chunks_vdb_task = asyncio.create_task( self.chunks_vdb.upsert(chunks) ) - entity_relation_task = asyncio.create_task( - self._process_entity_relation_graph( - chunks, pipeline_status, pipeline_status_lock - ) - ) full_docs_task = asyncio.create_task( self.full_docs.upsert( {doc_id: {"content": status_doc.content}} @@ -999,14 +999,26 @@ class LightRAG: text_chunks_task = asyncio.create_task( self.text_chunks.upsert(chunks) ) - tasks = [ + + # First stage tasks (parallel execution) + first_stage_tasks = [ doc_status_task, chunks_vdb_task, - entity_relation_task, full_docs_task, text_chunks_task, ] - await asyncio.gather(*tasks) + entity_relation_task = None + + # Execute first stage tasks + await asyncio.gather(*first_stage_tasks) + + # Stage 2: Process entity relation graph (after text_chunks are saved) + entity_relation_task = asyncio.create_task( + self._process_entity_relation_graph( + chunks, pipeline_status, pipeline_status_lock + ) + ) + await entity_relation_task file_extraction_stage_ok = True except Exception as e: @@ -1021,14 +1033,14 @@ class LightRAG: ) pipeline_status["history_messages"].append(error_msg) - # Cancel other tasks as they are no longer meaningful - for task in [ - chunks_vdb_task, - entity_relation_task, - full_docs_task, - text_chunks_task, - ]: - if not task.done(): + # Cancel tasks that are 
not yet completed + all_tasks = first_stage_tasks + ( + [entity_relation_task] + if entity_relation_task + else [] + ) + for task in all_tasks: + if task and not task.done(): task.cancel() # Persistent llm cache @@ -1078,6 +1090,9 @@ class LightRAG: doc_id: { "status": DocStatus.PROCESSED, "chunks_count": len(chunks), + "chunks_list": list( + chunks.keys() + ), # 保留 chunks_list "content": status_doc.content, "content_summary": status_doc.content_summary, "content_length": status_doc.content_length, @@ -1196,6 +1211,7 @@ class LightRAG: pipeline_status=pipeline_status, pipeline_status_lock=pipeline_status_lock, llm_response_cache=self.llm_response_cache, + text_chunks_storage=self.text_chunks, ) return chunk_results except Exception as e: @@ -1726,28 +1742,10 @@ class LightRAG: file_path="", ) - # 2. Get all chunks related to this document - try: - all_chunks = await self.text_chunks.get_all() - related_chunks = { - chunk_id: chunk_data - for chunk_id, chunk_data in all_chunks.items() - if isinstance(chunk_data, dict) - and chunk_data.get("full_doc_id") == doc_id - } + # 2. Get chunk IDs from document status + chunk_ids = set(doc_status_data.get("chunks_list", [])) - # Update pipeline status after getting chunks count - async with pipeline_status_lock: - log_message = f"Retrieved {len(related_chunks)} of {len(all_chunks)} related chunks" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - - except Exception as e: - logger.error(f"Failed to retrieve chunks for document {doc_id}: {e}") - raise Exception(f"Failed to retrieve document chunks: {e}") from e - - if not related_chunks: + if not chunk_ids: logger.warning(f"No chunks found for document {doc_id}") # Mark that deletion operations have started deletion_operations_started = True @@ -1778,7 +1776,6 @@ class LightRAG: file_path=file_path, ) - chunk_ids = set(related_chunks.keys()) # Mark that deletion operations have started deletion_operations_started = True @@ -1802,26 +1799,12 @@ class LightRAG: ) ) - # Update pipeline status after getting affected_nodes - async with pipeline_status_lock: - log_message = f"Found {len(affected_nodes)} affected entities" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - affected_edges = ( await self.chunk_entity_relation_graph.get_edges_by_chunk_ids( list(chunk_ids) ) ) - # Update pipeline status after getting affected_edges - async with pipeline_status_lock: - log_message = f"Found {len(affected_edges)} affected relations" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - except Exception as e: logger.error(f"Failed to analyze affected graph elements: {e}") raise Exception(f"Failed to analyze graph dependencies: {e}") from e @@ -1839,6 +1822,14 @@ class LightRAG: elif remaining_sources != sources: entities_to_rebuild[node_label] = remaining_sources + async with pipeline_status_lock: + log_message = ( + f"Found {len(entities_to_rebuild)} affected entities" + ) + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + # Process relationships for edge_data in affected_edges: src = edge_data.get("source") @@ -1860,6 +1851,14 @@ class LightRAG: elif remaining_sources != sources: relationships_to_rebuild[edge_tuple] = remaining_sources + async with pipeline_status_lock: + 
log_message = ( + f"Found {len(relationships_to_rebuild)} affected relations" + ) + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + except Exception as e: logger.error(f"Failed to process graph analysis results: {e}") raise Exception(f"Failed to process graph dependencies: {e}") from e @@ -1943,7 +1942,7 @@ class LightRAG: knowledge_graph_inst=self.chunk_entity_relation_graph, entities_vdb=self.entities_vdb, relationships_vdb=self.relationships_vdb, - text_chunks=self.text_chunks, + text_chunks_storage=self.text_chunks, llm_response_cache=self.llm_response_cache, global_config=asdict(self), pipeline_status=pipeline_status, diff --git a/lightrag/operate.py b/lightrag/operate.py index bd70ceed..60425148 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -25,6 +25,7 @@ from .utils import ( CacheData, get_conversation_turns, use_llm_func_with_cache, + update_chunk_cache_list, ) from .base import ( BaseGraphStorage, @@ -103,8 +104,6 @@ async def _handle_entity_relation_summary( entity_or_relation_name: str, description: str, global_config: dict, - pipeline_status: dict = None, - pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, ) -> str: """Handle entity relation summary @@ -247,7 +246,7 @@ async def _rebuild_knowledge_from_chunks( knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, - text_chunks: BaseKVStorage, + text_chunks_storage: BaseKVStorage, llm_response_cache: BaseKVStorage, global_config: dict[str, str], pipeline_status: dict | None = None, @@ -261,6 +260,7 @@ async def _rebuild_knowledge_from_chunks( Args: entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids + text_chunks_data: Pre-loaded chunk data dict {chunk_id: chunk_data} """ if not entities_to_rebuild and not relationships_to_rebuild: return @@ -281,9 +281,12 @@ async def _rebuild_knowledge_from_chunks( pipeline_status["latest_message"] = status_message pipeline_status["history_messages"].append(status_message) - # Get cached extraction results for these chunks + # Get cached extraction results for these chunks using storage + # cached_results: chunk_id -> [list of extraction result from LLM cache sorted by created_at] cached_results = await _get_cached_extraction_results( - llm_response_cache, all_referenced_chunk_ids + llm_response_cache, + all_referenced_chunk_ids, + text_chunks_storage=text_chunks_storage, ) if not cached_results: @@ -299,15 +302,37 @@ async def _rebuild_knowledge_from_chunks( chunk_entities = {} # chunk_id -> {entity_name: [entity_data]} chunk_relationships = {} # chunk_id -> {(src, tgt): [relationship_data]} - for chunk_id, extraction_result in cached_results.items(): + for chunk_id, extraction_results in cached_results.items(): try: - entities, relationships = await _parse_extraction_result( - text_chunks=text_chunks, - extraction_result=extraction_result, - chunk_id=chunk_id, - ) - chunk_entities[chunk_id] = entities - chunk_relationships[chunk_id] = relationships + # Handle multiple extraction results per chunk + chunk_entities[chunk_id] = defaultdict(list) + chunk_relationships[chunk_id] = defaultdict(list) + + # process multiple LLM extraction results for a single chunk_id + for extraction_result in extraction_results: + entities, relationships = await _parse_extraction_result( + text_chunks_storage=text_chunks_storage, 
+ extraction_result=extraction_result, + chunk_id=chunk_id, + ) + + # Merge entities and relationships from this extraction result + # Only keep the first occurrence of each entity_name in the same chunk_id + for entity_name, entity_list in entities.items(): + if ( + entity_name not in chunk_entities[chunk_id] + or len(chunk_entities[chunk_id][entity_name]) == 0 + ): + chunk_entities[chunk_id][entity_name].extend(entity_list) + + # Only keep the first occurrence of each rel_key in the same chunk_id + for rel_key, rel_list in relationships.items(): + if ( + rel_key not in chunk_relationships[chunk_id] + or len(chunk_relationships[chunk_id][rel_key]) == 0 + ): + chunk_relationships[chunk_id][rel_key].extend(rel_list) + except Exception as e: status_message = ( f"Failed to parse cached extraction result for chunk {chunk_id}: {e}" @@ -387,43 +412,86 @@ async def _rebuild_knowledge_from_chunks( async def _get_cached_extraction_results( - llm_response_cache: BaseKVStorage, chunk_ids: set[str] -) -> dict[str, str]: + llm_response_cache: BaseKVStorage, + chunk_ids: set[str], + text_chunks_storage: BaseKVStorage, +) -> dict[str, list[str]]: """Get cached extraction results for specific chunk IDs Args: + llm_response_cache: LLM response cache storage chunk_ids: Set of chunk IDs to get cached results for + text_chunks_data: Pre-loaded chunk data (optional, for performance) + text_chunks_storage: Text chunks storage (fallback if text_chunks_data is None) Returns: - Dict mapping chunk_id -> extraction_result_text + Dict mapping chunk_id -> list of extraction_result_text """ cached_results = {} - # Get all cached data (flattened cache structure) - all_cache = await llm_response_cache.get_all() + # Collect all LLM cache IDs from chunks + all_cache_ids = set() - for cache_key, cache_entry in all_cache.items(): + # Read from storage + chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids)) + for chunk_id, chunk_data in zip(chunk_ids, chunk_data_list): + if chunk_data and isinstance(chunk_data, dict): + llm_cache_list = chunk_data.get("llm_cache_list", []) + if llm_cache_list: + all_cache_ids.update(llm_cache_list) + else: + logger.warning( + f"Chunk {chunk_id} data is invalid or None: {type(chunk_data)}" + ) + + if not all_cache_ids: + logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs") + return cached_results + + # Batch get LLM cache entries + cache_data_list = await llm_response_cache.get_by_ids(list(all_cache_ids)) + + # Process cache entries and group by chunk_id + valid_entries = 0 + for cache_id, cache_entry in zip(all_cache_ids, cache_data_list): if ( - isinstance(cache_entry, dict) + cache_entry is not None + and isinstance(cache_entry, dict) and cache_entry.get("cache_type") == "extract" and cache_entry.get("chunk_id") in chunk_ids ): chunk_id = cache_entry["chunk_id"] extraction_result = cache_entry["return"] - cached_results[chunk_id] = extraction_result + create_time = cache_entry.get( + "create_time", 0 + ) # Get creation time, default to 0 + valid_entries += 1 - logger.debug( - f"Found {len(cached_results)} cached extraction results for {len(chunk_ids)} chunk IDs" + # Support multiple LLM caches per chunk + if chunk_id not in cached_results: + cached_results[chunk_id] = [] + # Store tuple with extraction result and creation time for sorting + cached_results[chunk_id].append((extraction_result, create_time)) + + # Sort extraction results by create_time for each chunk + for chunk_id in cached_results: + # Sort by create_time (x[1]), then extract only 
extraction_result (x[0]) + cached_results[chunk_id].sort(key=lambda x: x[1]) + cached_results[chunk_id] = [item[0] for item in cached_results[chunk_id]] + + logger.info( + f"Found {valid_entries} valid cache entries, {len(cached_results)} chunks with results" ) return cached_results async def _parse_extraction_result( - text_chunks: BaseKVStorage, extraction_result: str, chunk_id: str + text_chunks_storage: BaseKVStorage, extraction_result: str, chunk_id: str ) -> tuple[dict, dict]: """Parse cached extraction result using the same logic as extract_entities Args: + text_chunks_storage: Text chunks storage to get chunk data extraction_result: The cached LLM extraction result chunk_id: The chunk ID for source tracking @@ -431,8 +499,8 @@ async def _parse_extraction_result( Tuple of (entities_dict, relationships_dict) """ - # Get chunk data for file_path - chunk_data = await text_chunks.get_by_id(chunk_id) + # Get chunk data for file_path from storage + chunk_data = await text_chunks_storage.get_by_id(chunk_id) file_path = ( chunk_data.get("file_path", "unknown_source") if chunk_data @@ -805,8 +873,6 @@ async def _merge_nodes_then_upsert( entity_name, description, global_config, - pipeline_status, - pipeline_status_lock, llm_response_cache, ) else: @@ -969,8 +1035,6 @@ async def _merge_edges_then_upsert( f"({src_id}, {tgt_id})", description, global_config, - pipeline_status, - pipeline_status_lock, llm_response_cache, ) else: @@ -1146,6 +1210,7 @@ async def extract_entities( pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, + text_chunks_storage: BaseKVStorage | None = None, ) -> list: use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] @@ -1252,6 +1317,9 @@ async def extract_entities( # Get file path from chunk data or use default file_path = chunk_dp.get("file_path", "unknown_source") + # Create cache keys collector for batch processing + cache_keys_collector = [] + # Get initial extraction hint_prompt = entity_extract_prompt.format( **{**context_base, "input_text": content} @@ -1263,7 +1331,10 @@ async def extract_entities( llm_response_cache=llm_response_cache, cache_type="extract", chunk_id=chunk_key, + cache_keys_collector=cache_keys_collector, ) + + # Store LLM cache reference in chunk (will be handled by use_llm_func_with_cache) history = pack_user_ass_to_openai_messages(hint_prompt, final_result) # Process initial extraction with file path @@ -1280,6 +1351,7 @@ async def extract_entities( history_messages=history, cache_type="extract", chunk_id=chunk_key, + cache_keys_collector=cache_keys_collector, ) history += pack_user_ass_to_openai_messages(continue_prompt, glean_result) @@ -1310,11 +1382,21 @@ async def extract_entities( llm_response_cache=llm_response_cache, history_messages=history, cache_type="extract", + cache_keys_collector=cache_keys_collector, ) if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() if if_loop_result != "yes": break + # Batch update chunk's llm_cache_list with all collected cache keys + if cache_keys_collector and text_chunks_storage: + await update_chunk_cache_list( + chunk_key, + text_chunks_storage, + cache_keys_collector, + "entity_extraction", + ) + processed_chunks += 1 entities_count = len(maybe_nodes) relations_count = len(maybe_edges) diff --git a/lightrag/utils.py b/lightrag/utils.py index 6c40407b..c6e2def9 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1423,6 +1423,48 @@ def 
lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any return import_class +async def update_chunk_cache_list( + chunk_id: str, + text_chunks_storage: "BaseKVStorage", + cache_keys: list[str], + cache_scenario: str = "batch_update", +) -> None: + """Update chunk's llm_cache_list with the given cache keys + + Args: + chunk_id: Chunk identifier + text_chunks_storage: Text chunks storage instance + cache_keys: List of cache keys to add to the list + cache_scenario: Description of the cache scenario for logging + """ + if not cache_keys: + return + + try: + chunk_data = await text_chunks_storage.get_by_id(chunk_id) + if chunk_data: + # Ensure llm_cache_list exists + if "llm_cache_list" not in chunk_data: + chunk_data["llm_cache_list"] = [] + + # Add cache keys to the list if not already present + existing_keys = set(chunk_data["llm_cache_list"]) + new_keys = [key for key in cache_keys if key not in existing_keys] + + if new_keys: + chunk_data["llm_cache_list"].extend(new_keys) + + # Update the chunk in storage + await text_chunks_storage.upsert({chunk_id: chunk_data}) + logger.debug( + f"Updated chunk {chunk_id} with {len(new_keys)} cache keys ({cache_scenario})" + ) + except Exception as e: + logger.warning( + f"Failed to update chunk {chunk_id} with cache references on {cache_scenario}: {e}" + ) + + async def use_llm_func_with_cache( input_text: str, use_llm_func: callable, @@ -1431,6 +1473,7 @@ async def use_llm_func_with_cache( history_messages: list[dict[str, str]] = None, cache_type: str = "extract", chunk_id: str | None = None, + cache_keys_collector: list = None, ) -> str: """Call LLM function with cache support @@ -1445,6 +1488,8 @@ async def use_llm_func_with_cache( history_messages: History messages list cache_type: Type of cache chunk_id: Chunk identifier to store in cache + text_chunks_storage: Text chunks storage to update llm_cache_list + cache_keys_collector: Optional list to collect cache keys for batch processing Returns: LLM response text @@ -1457,6 +1502,9 @@ async def use_llm_func_with_cache( _prompt = input_text arg_hash = compute_args_hash(_prompt) + # Generate cache key for this LLM call + cache_key = generate_cache_key("default", cache_type, arg_hash) + cached_return, _1, _2, _3 = await handle_cache( llm_response_cache, arg_hash, @@ -1467,6 +1515,11 @@ async def use_llm_func_with_cache( if cached_return: logger.debug(f"Found cache for {arg_hash}") statistic_data["llm_cache"] += 1 + + # Add cache key to collector if provided + if cache_keys_collector is not None: + cache_keys_collector.append(cache_key) + return cached_return statistic_data["llm_call"] += 1 @@ -1491,6 +1544,10 @@ async def use_llm_func_with_cache( ), ) + # Add cache key to collector if provided + if cache_keys_collector is not None: + cache_keys_collector.append(cache_key) + return res # When cache is disabled, directly call LLM
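
The storage changes above all serve one traceability chain: each text chunk now records the LLM cache keys produced while extracting it (`llm_cache_list`, collected via `cache_keys_collector` and written back by `update_chunk_cache_list`), and each document status row records its chunk IDs (`chunks_list`). Deletion and knowledge-graph rebuilds can therefore walk doc → chunks → cache entries directly instead of scanning every store, which is what `adelete_by_doc_id` and `_get_cached_extraction_results` now rely on. Below is a minimal, storage-agnostic sketch of that flow, using plain dicts in place of the real KV storages; the `index_document` / `delete_document` helpers are illustrative only and not part of this PR's API.

```python
import time
from collections import defaultdict

# Illustrative in-memory stand-ins for the real KV storages (assumption: plain dicts).
doc_status: dict[str, dict] = {}   # doc_id    -> {"chunks_list": [...], ...}
text_chunks: dict[str, dict] = {}  # chunk_id  -> {"llm_cache_list": [...], ...}
llm_cache: dict[str, dict] = {}    # cache_key -> {"return": ..., "chunk_id": ..., "cache_type": ...}


def index_document(doc_id: str, chunk_texts: list[str]) -> None:
    """Mirror the indexing path: chunks carry llm_cache_list, doc status carries chunks_list."""
    chunk_ids = []
    for i, text in enumerate(chunk_texts):
        chunk_id = f"chunk-{doc_id}-{i}"
        chunk_ids.append(chunk_id)
        # Each extraction/gleaning pass appends its cache key (the cache_keys_collector role).
        cache_key = f"default:extract:{abs(hash(text)) & 0xFFFF:04x}"
        llm_cache[cache_key] = {
            "return": f"<entities/relations extracted from {chunk_id}>",
            "chunk_id": chunk_id,
            "cache_type": "extract",
            "create_time": int(time.time()),
        }
        text_chunks[chunk_id] = {
            "content": text,
            "full_doc_id": doc_id,
            "llm_cache_list": [cache_key],
        }
    doc_status[doc_id] = {
        "status": "processed",
        "chunks_count": len(chunk_ids),
        "chunks_list": chunk_ids,
    }


def cached_extractions_for(doc_id: str) -> dict[str, list[str]]:
    """Doc -> chunks -> cache entries, oldest first (what the rebuild path needs)."""
    grouped: dict[str, list[tuple[str, int]]] = defaultdict(list)
    for chunk_id in doc_status[doc_id]["chunks_list"]:
        for cache_key in text_chunks.get(chunk_id, {}).get("llm_cache_list", []):
            entry = llm_cache.get(cache_key)
            if entry and entry.get("cache_type") == "extract":
                grouped[chunk_id].append((entry["return"], entry.get("create_time", 0)))
    return {
        cid: [text for text, _ in sorted(items, key=lambda x: x[1])]
        for cid, items in grouped.items()
    }


def delete_document(doc_id: str) -> None:
    """Targeted cleanup: only the recorded IDs are touched, no full-store scan."""
    for chunk_id in doc_status[doc_id]["chunks_list"]:
        for cache_key in text_chunks.get(chunk_id, {}).get("llm_cache_list", []):
            llm_cache.pop(cache_key, None)
        text_chunks.pop(chunk_id, None)
    doc_status.pop(doc_id, None)
```

For example, `index_document("doc-1", ["chunk a", "chunk b"])` followed by `cached_extractions_for("doc-1")` reproduces, in miniature, what the rebuild path does after a partial deletion.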
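
A second convention introduced across the JSON, Redis, and Mongo KV backends is the timestamp contract: `create_time` is written only when a key is first inserted (Mongo's `$setOnInsert`, the Redis `EXISTS` pipeline check), `update_time` is refreshed on every upsert, and readers backfill `0` for rows written before this migration. A rough sketch of that contract over a plain dict store, assuming Unix-second timestamps as in the patch:

```python
import time
from typing import Any


def upsert_with_timestamps(
    store: dict[str, dict[str, Any]], data: dict[str, dict[str, Any]]
) -> None:
    """Set create_time only on first insert; always refresh update_time."""
    now = int(time.time())
    for key, value in data.items():
        if key not in store:
            value["create_time"] = now
        value["update_time"] = now
        store[key] = {**store.get(key, {}), **value}


def read_with_defaults(store: dict[str, dict[str, Any]], key: str) -> dict[str, Any] | None:
    """Return a copy and backfill 0 for pre-migration rows, as the patched readers do."""
    row = store.get(key)
    if row is None:
        return None
    row = dict(row)  # copy so callers cannot mutate the stored record
    row.setdefault("create_time", 0)
    row.setdefault("update_time", 0)
    return row
```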