diff --git a/env.example b/env.example index f98648d2..98e4790b 100644 --- a/env.example +++ b/env.example @@ -58,6 +58,8 @@ SUMMARY_LANGUAGE=English # FORCE_LLM_SUMMARY_ON_MERGE=6 ### Max tokens for entity/relations description after merge # MAX_TOKEN_SUMMARY=500 +### Maximum number of entity extraction attempts for ambiguous content +# MAX_GLEANING=1 ### Number of parallel processing documents(Less than MAX_ASYNC/2 is recommended) # MAX_PARALLEL_INSERT=2 @@ -112,15 +114,6 @@ EMBEDDING_BINDING_HOST=http://localhost:11434 # LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage -### TiDB Configuration (Deprecated) -# TIDB_HOST=localhost -# TIDB_PORT=4000 -# TIDB_USER=your_username -# TIDB_PASSWORD='your_password' -# TIDB_DATABASE=your_database -### separating all data from difference Lightrag instances(deprecating) -# TIDB_WORKSPACE=default - ### PostgreSQL Configuration POSTGRES_HOST=localhost POSTGRES_PORT=5432 @@ -128,7 +121,7 @@ POSTGRES_USER=your_username POSTGRES_PASSWORD='your_password' POSTGRES_DATABASE=your_database POSTGRES_MAX_CONNECTIONS=12 -### separating all data from difference Lightrag instances(deprecating) +### separating all data from difference Lightrag instances # POSTGRES_WORKSPACE=default ### Neo4j Configuration @@ -144,14 +137,15 @@ NEO4J_PASSWORD='your_password' # AGE_POSTGRES_PORT=8529 # AGE Graph Name(apply to PostgreSQL and independent AGM) -### AGE_GRAPH_NAME is precated +### AGE_GRAPH_NAME is deprecated # AGE_GRAPH_NAME=lightrag ### MongoDB Configuration MONGO_URI=mongodb://root:root@localhost:27017/ MONGO_DATABASE=LightRAG ### separating all data from difference Lightrag instances(deprecating) -# MONGODB_GRAPH=false +### separating all data from difference Lightrag instances +# MONGODB_WORKSPACE=default ### Milvus Configuration MILVUS_URI=http://localhost:19530 diff --git a/examples/raganything_example.py b/examples/raganything_example.py index 4933b3d7..f61274a8 100644 --- a/examples/raganything_example.py +++ b/examples/raganything_example.py @@ -11,9 +11,74 @@ This example shows how to: import os import argparse import asyncio +import logging +import logging.config +from pathlib import Path + +# Add project root directory to Python path +import sys + +sys.path.append(str(Path(__file__).parent.parent)) + from lightrag.llm.openai import openai_complete_if_cache, openai_embed -from lightrag.utils import EmbeddingFunc -from raganything.raganything import RAGAnything +from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug +from raganything import RAGAnything, RAGAnythingConfig + + +def configure_logging(): + """Configure logging for the application""" + # Get log directory path from environment variable or use current directory + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath(os.path.join(log_dir, "raganything_example.log")) + + print(f"\nRAGAnything example log file: {log_file_path}\n") + os.makedirs(os.path.dirname(log_dir), exist_ok=True) + + # Get log file max size and backup count from environment variables + log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB + log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups + + logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": "%(levelname)s: %(message)s", + }, + "detailed": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + }, + }, + "handlers": { + "console": { + "formatter": "default", + "class": 
"logging.StreamHandler", + "stream": "ext://sys.stderr", + }, + "file": { + "formatter": "detailed", + "class": "logging.handlers.RotatingFileHandler", + "filename": log_file_path, + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, + "encoding": "utf-8", + }, + }, + "loggers": { + "lightrag": { + "handlers": ["console", "file"], + "level": "INFO", + "propagate": False, + }, + }, + } + ) + + # Set the logger level to INFO + logger.setLevel(logging.INFO) + # Enable verbose debug if needed + set_verbose_debug(os.getenv("VERBOSE", "false").lower() == "true") async def process_with_rag( @@ -31,15 +96,21 @@ async def process_with_rag( output_dir: Output directory for RAG results api_key: OpenAI API key base_url: Optional base URL for API + working_dir: Working directory for RAG storage """ try: - # Initialize RAGAnything - rag = RAGAnything( - working_dir=working_dir, - llm_model_func=lambda prompt, - system_prompt=None, - history_messages=[], - **kwargs: openai_complete_if_cache( + # Create RAGAnything configuration + config = RAGAnythingConfig( + working_dir=working_dir or "./rag_storage", + mineru_parse_method="auto", + enable_image_processing=True, + enable_table_processing=True, + enable_equation_processing=True, + ) + + # Define LLM model function + def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs): + return openai_complete_if_cache( "gpt-4o-mini", prompt, system_prompt=system_prompt, @@ -47,81 +118,123 @@ async def process_with_rag( api_key=api_key, base_url=base_url, **kwargs, - ), - vision_model_func=lambda prompt, - system_prompt=None, - history_messages=[], - image_data=None, - **kwargs: openai_complete_if_cache( - "gpt-4o", - "", - system_prompt=None, - history_messages=[], - messages=[ - {"role": "system", "content": system_prompt} - if system_prompt - else None, - { - "role": "user", - "content": [ - {"type": "text", "text": prompt}, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_data}" - }, - }, - ], - } - if image_data - else {"role": "user", "content": prompt}, - ], - api_key=api_key, - base_url=base_url, - **kwargs, ) - if image_data - else openai_complete_if_cache( - "gpt-4o-mini", - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - api_key=api_key, - base_url=base_url, - **kwargs, - ), - embedding_func=EmbeddingFunc( - embedding_dim=3072, - max_token_size=8192, - func=lambda texts: openai_embed( - texts, - model="text-embedding-3-large", + + # Define vision model function for image processing + def vision_model_func( + prompt, system_prompt=None, history_messages=[], image_data=None, **kwargs + ): + if image_data: + return openai_complete_if_cache( + "gpt-4o", + "", + system_prompt=None, + history_messages=[], + messages=[ + {"role": "system", "content": system_prompt} + if system_prompt + else None, + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_data}" + }, + }, + ], + } + if image_data + else {"role": "user", "content": prompt}, + ], api_key=api_key, base_url=base_url, - ), + **kwargs, + ) + else: + return llm_model_func(prompt, system_prompt, history_messages, **kwargs) + + # Define embedding function + embedding_func = EmbeddingFunc( + embedding_dim=3072, + max_token_size=8192, + func=lambda texts: openai_embed( + texts, + model="text-embedding-3-large", + api_key=api_key, + base_url=base_url, ), ) + # Initialize RAGAnything with new dataclass structure + 
rag = RAGAnything( + config=config, + llm_model_func=llm_model_func, + vision_model_func=vision_model_func, + embedding_func=embedding_func, + ) + # Process document await rag.process_document_complete( file_path=file_path, output_dir=output_dir, parse_method="auto" ) - # Example queries - queries = [ + # Example queries - demonstrating different query approaches + logger.info("\nQuerying processed document:") + + # 1. Pure text queries using aquery() + text_queries = [ "What is the main content of the document?", - "Describe the images and figures in the document", - "Tell me about the experimental results and data tables", + "What are the key topics discussed?", ] - print("\nQuerying processed document:") - for query in queries: - print(f"\nQuery: {query}") - result = await rag.query_with_multimodal(query, mode="hybrid") - print(f"Answer: {result}") + for query in text_queries: + logger.info(f"\n[Text Query]: {query}") + result = await rag.aquery(query, mode="hybrid") + logger.info(f"Answer: {result}") + + # 2. Multimodal query with specific multimodal content using aquery_with_multimodal() + logger.info( + "\n[Multimodal Query]: Analyzing performance data in context of document" + ) + multimodal_result = await rag.aquery_with_multimodal( + "Compare this performance data with any similar results mentioned in the document", + multimodal_content=[ + { + "type": "table", + "table_data": """Method,Accuracy,Processing_Time + RAGAnything,95.2%,120ms + Traditional_RAG,87.3%,180ms + Baseline,82.1%,200ms""", + "table_caption": "Performance comparison results", + } + ], + mode="hybrid", + ) + logger.info(f"Answer: {multimodal_result}") + + # 3. Another multimodal query with equation content + logger.info("\n[Multimodal Query]: Mathematical formula analysis") + equation_result = await rag.aquery_with_multimodal( + "Explain this formula and relate it to any mathematical concepts in the document", + multimodal_content=[ + { + "type": "equation", + "latex": "F1 = 2 \\cdot \\frac{precision \\cdot recall}{precision + recall}", + "equation_caption": "F1-score calculation formula", + } + ], + mode="hybrid", + ) + logger.info(f"Answer: {equation_result}") except Exception as e: - print(f"Error processing with RAG: {str(e)}") + logger.error(f"Error processing with RAG: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) def main(): @@ -135,12 +248,20 @@ def main(): "--output", "-o", default="./output", help="Output directory path" ) parser.add_argument( - "--api-key", required=True, help="OpenAI API key for RAG processing" + "--api-key", + default=os.getenv("OPENAI_API_KEY"), + help="OpenAI API key (defaults to OPENAI_API_KEY env var)", ) parser.add_argument("--base-url", help="Optional base URL for API") args = parser.parse_args() + # Check if API key is provided + if not args.api_key: + logger.error("Error: OpenAI API key is required") + logger.error("Set OPENAI_API_KEY environment variable or use --api-key option") + return + # Create output directory if specified if args.output: os.makedirs(args.output, exist_ok=True) @@ -154,4 +275,12 @@ def main(): if __name__ == "__main__": + # Configure logging first + configure_logging() + + print("RAGAnything Example") + print("=" * 30) + print("Processing document with multimodal RAG pipeline") + print("=" * 30) + main() diff --git a/examples/unofficial-sample/copy_llm_cache_to_another_storage.py b/examples/unofficial-sample/copy_llm_cache_to_another_storage.py index 60fa6192..1671b5d5 100644 --- 
a/examples/unofficial-sample/copy_llm_cache_to_another_storage.py +++ b/examples/unofficial-sample/copy_llm_cache_to_another_storage.py @@ -52,18 +52,23 @@ async def copy_from_postgres_to_json(): embedding_func=None, ) + # Get all cache data using the new flattened structure + all_data = await from_llm_response_cache.get_all() + + # Convert flattened data to hierarchical structure for JsonKVStorage kv = {} - for c_id in await from_llm_response_cache.all_keys(): - print(f"Copying {c_id}") - workspace = c_id["workspace"] - mode = c_id["mode"] - _id = c_id["id"] - postgres_db.workspace = workspace - obj = await from_llm_response_cache.get_by_mode_and_id(mode, _id) - if mode not in kv: - kv[mode] = {} - kv[mode][_id] = obj[_id] - print(f"Object {obj}") + for flattened_key, cache_entry in all_data.items(): + # Parse flattened key: {mode}:{cache_type}:{hash} + parts = flattened_key.split(":", 2) + if len(parts) == 3: + mode, cache_type, hash_value = parts + if mode not in kv: + kv[mode] = {} + kv[mode][hash_value] = cache_entry + print(f"Copying {flattened_key} -> {mode}[{hash_value}]") + else: + print(f"Skipping invalid key format: {flattened_key}") + await to_llm_response_cache.upsert(kv) await to_llm_response_cache.index_done_callback() print("Mission accomplished!") @@ -85,13 +90,24 @@ async def copy_from_json_to_postgres(): db=postgres_db, ) - for mode in await from_llm_response_cache.all_keys(): - print(f"Copying {mode}") - caches = await from_llm_response_cache.get_by_id(mode) - for k, v in caches.items(): - item = {mode: {k: v}} - print(f"\tCopying {item}") - await to_llm_response_cache.upsert(item) + # Get all cache data from JsonKVStorage (hierarchical structure) + all_data = await from_llm_response_cache.get_all() + + # Convert hierarchical data to flattened structure for PGKVStorage + flattened_data = {} + for mode, mode_data in all_data.items(): + print(f"Processing mode: {mode}") + for hash_value, cache_entry in mode_data.items(): + # Determine cache_type from cache entry or use default + cache_type = cache_entry.get("cache_type", "extract") + # Create flattened key: {mode}:{cache_type}:{hash} + flattened_key = f"{mode}:{cache_type}:{hash_value}" + flattened_data[flattened_key] = cache_entry + print(f"\tConverting {mode}[{hash_value}] -> {flattened_key}") + + # Upsert the flattened data + await to_llm_response_cache.upsert(flattened_data) + print("Mission accomplished!") if __name__ == "__main__": diff --git a/lightrag/api/__init__.py b/lightrag/api/__init__.py index 2663a87a..952266f6 100644 --- a/lightrag/api/__init__.py +++ b/lightrag/api/__init__.py @@ -1 +1 @@ -__api_version__ = "0176" +__api_version__ = "0178" diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index fba5c3e8..c25f7241 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -62,6 +62,51 @@ router = APIRouter( temp_prefix = "__tmp__" +def sanitize_filename(filename: str, input_dir: Path) -> str: + """ + Sanitize uploaded filename to prevent Path Traversal attacks. 
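The two conversion loops in copy_llm_cache_to_another_storage.py above are inverses of each other around the new flattened cache-key layout. A small self-contained sketch of that round trip (pure Python; the key scheme is taken from the script, while the `flatten`/`unflatten` helpers are illustrative only and not part of LightRAG's API):

```python
# Round trip between the legacy nested cache layout and the new flattened
# "{mode}:{cache_type}:{hash}" keys, mirroring the two copy functions above.

def flatten(nested: dict, default_cache_type: str = "extract") -> dict:
    flat = {}
    for mode, entries in nested.items():
        for hash_value, entry in entries.items():
            cache_type = entry.get("cache_type", default_cache_type)
            flat[f"{mode}:{cache_type}:{hash_value}"] = entry
    return flat


def unflatten(flat: dict) -> dict:
    nested = {}
    for key, entry in flat.items():
        parts = key.split(":", 2)
        if len(parts) != 3:
            continue  # not a cache entry in the flattened format
        mode, _cache_type, hash_value = parts
        nested.setdefault(mode, {})[hash_value] = entry
    return nested


sample = {"default": {"abc123": {"return": "cached answer", "cache_type": "extract"}}}
assert unflatten(flatten(sample)) == sample
```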
+ + Args: + filename: The original filename from the upload + input_dir: The target input directory + + Returns: + str: Sanitized filename that is safe to use + + Raises: + HTTPException: If the filename is unsafe or invalid + """ + # Basic validation + if not filename or not filename.strip(): + raise HTTPException(status_code=400, detail="Filename cannot be empty") + + # Remove path separators and traversal sequences + clean_name = filename.replace("/", "").replace("\\", "") + clean_name = clean_name.replace("..", "") + + # Remove control characters and null bytes + clean_name = "".join(c for c in clean_name if ord(c) >= 32 and c != "\x7f") + + # Remove leading/trailing whitespace and dots + clean_name = clean_name.strip().strip(".") + + # Check if anything is left after sanitization + if not clean_name: + raise HTTPException( + status_code=400, detail="Invalid filename after sanitization" + ) + + # Verify the final path stays within the input directory + try: + final_path = (input_dir / clean_name).resolve() + if not final_path.is_relative_to(input_dir.resolve()): + raise HTTPException(status_code=400, detail="Unsafe filename detected") + except (OSError, ValueError): + raise HTTPException(status_code=400, detail="Invalid filename") + + return clean_name + + class ScanResponse(BaseModel): """Response model for document scanning operation @@ -783,7 +828,7 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): try: new_files = doc_manager.scan_directory_for_new_files() total_files = len(new_files) - logger.info(f"Found {total_files} new files to index.") + logger.info(f"Found {total_files} files to index.") if not new_files: return @@ -816,8 +861,13 @@ async def background_delete_documents( successful_deletions = [] failed_deletions = [] - # Set pipeline status to busy for deletion + # Double-check pipeline status before proceeding async with pipeline_status_lock: + if pipeline_status.get("busy", False): + logger.warning("Error: Unexpected pipeline busy state, aborting deletion.") + return # Abort deletion operation + + # Set pipeline status to busy for deletion pipeline_status.update( { "busy": True, @@ -926,13 +976,26 @@ async def background_delete_documents( async with pipeline_status_lock: pipeline_status["history_messages"].append(error_msg) finally: - # Final summary + # Final summary and check for pending requests async with pipeline_status_lock: pipeline_status["busy"] = False completion_msg = f"Deletion completed: {len(successful_deletions)} successful, {len(failed_deletions)} failed" pipeline_status["latest_message"] = completion_msg pipeline_status["history_messages"].append(completion_msg) + # Check if there are pending document indexing requests + has_pending_request = pipeline_status.get("request_pending", False) + + # If there are pending requests, start document processing pipeline + if has_pending_request: + try: + logger.info( + "Processing pending document indexing requests after deletion" + ) + await rag.apipeline_process_enqueue_documents() + except Exception as e: + logger.error(f"Error processing pending documents after deletion: {e}") + def create_document_routes( rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None @@ -986,18 +1049,21 @@ def create_document_routes( HTTPException: If the file type is not supported (400) or other errors occur (500). 
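As a quick illustration of the new `sanitize_filename` helper defined above, the expected behaviour on a few typical inputs is sketched below. The outputs are inferred from the implementation shown here (worth confirming in a unit test) and the snippet assumes the API extras such as FastAPI are installed:

```python
from pathlib import Path

from fastapi import HTTPException
from lightrag.api.routers.document_routes import sanitize_filename

input_dir = Path("./inputs")

# Path separators and ".." sequences are stripped, leaving only a safe base name
assert sanitize_filename("..\\..\\secret.txt", input_dir) == "secret.txt"
assert sanitize_filename("../../etc/passwd", input_dir) == "etcpasswd"

# Ordinary names pass through unchanged
assert sanitize_filename("report.pdf", input_dir) == "report.pdf"

# Names that are empty after cleaning are rejected with HTTP 400
try:
    sanitize_filename("..", input_dir)
    raise AssertionError("expected HTTPException")
except HTTPException as exc:
    assert exc.status_code == 400
```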
""" try: - if not doc_manager.is_supported_file(file.filename): + # Sanitize filename to prevent Path Traversal attacks + safe_filename = sanitize_filename(file.filename, doc_manager.input_dir) + + if not doc_manager.is_supported_file(safe_filename): raise HTTPException( status_code=400, detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", ) - file_path = doc_manager.input_dir / file.filename + file_path = doc_manager.input_dir / safe_filename # Check if file already exists if file_path.exists(): return InsertResponse( status="duplicated", - message=f"File '{file.filename}' already exists in the input directory.", + message=f"File '{safe_filename}' already exists in the input directory.", ) with open(file_path, "wb") as buffer: @@ -1008,7 +1074,7 @@ def create_document_routes( return InsertResponse( status="success", - message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.", + message=f"File '{safe_filename}' uploaded successfully. Processing will continue in background.", ) except Exception as e: logger.error(f"Error /documents/upload: {file.filename}: {str(e)}") diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py index f77184e6..64c36a05 100644 --- a/lightrag/api/routers/ollama_api.py +++ b/lightrag/api/routers/ollama_api.py @@ -234,7 +234,7 @@ class OllamaAPI: @self.router.get("/version", dependencies=[Depends(combined_auth)]) async def get_version(): """Get Ollama version information""" - return OllamaVersionResponse(version="0.5.4") + return OllamaVersionResponse(version="0.9.3") @self.router.get("/tags", dependencies=[Depends(combined_auth)]) async def get_tags(): @@ -244,9 +244,9 @@ class OllamaAPI: { "name": self.ollama_server_infos.LIGHTRAG_MODEL, "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "size": self.ollama_server_infos.LIGHTRAG_SIZE, "digest": self.ollama_server_infos.LIGHTRAG_DIGEST, - "modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "details": { "parent_model": "", "format": "gguf", @@ -337,7 +337,10 @@ class OllamaAPI: data = { "model": self.ollama_server_infos.LIGHTRAG_MODEL, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": "", "done": True, + "done_reason": "stop", + "context": [], "total_duration": total_time, "load_duration": 0, "prompt_eval_count": prompt_tokens, @@ -377,6 +380,7 @@ class OllamaAPI: "model": self.ollama_server_infos.LIGHTRAG_MODEL, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "response": f"\n\nError: {error_msg}", + "error": f"\n\nError: {error_msg}", "done": False, } yield f"{json.dumps(error_data, ensure_ascii=False)}\n" @@ -385,6 +389,7 @@ class OllamaAPI: final_data = { "model": self.ollama_server_infos.LIGHTRAG_MODEL, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": "", "done": True, } yield f"{json.dumps(final_data, ensure_ascii=False)}\n" @@ -399,7 +404,10 @@ class OllamaAPI: data = { "model": self.ollama_server_infos.LIGHTRAG_MODEL, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": "", "done": True, + "done_reason": "stop", + "context": [], "total_duration": total_time, "load_duration": 0, "prompt_eval_count": prompt_tokens, @@ -444,6 +452,8 @@ class OllamaAPI: "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, "response": str(response_text), "done": True, + "done_reason": "stop", + "context": [], "total_duration": total_time, "load_duration": 0, 
"prompt_eval_count": prompt_tokens, @@ -557,6 +567,12 @@ class OllamaAPI: data = { "model": self.ollama_server_infos.LIGHTRAG_MODEL, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": "", + "images": None, + }, + "done_reason": "stop", "done": True, "total_duration": total_time, "load_duration": 0, @@ -605,6 +621,7 @@ class OllamaAPI: "content": f"\n\nError: {error_msg}", "images": None, }, + "error": f"\n\nError: {error_msg}", "done": False, } yield f"{json.dumps(error_data, ensure_ascii=False)}\n" @@ -613,6 +630,11 @@ class OllamaAPI: final_data = { "model": self.ollama_server_infos.LIGHTRAG_MODEL, "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": "", + "images": None, + }, "done": True, } yield f"{json.dumps(final_data, ensure_ascii=False)}\n" @@ -633,6 +655,7 @@ class OllamaAPI: "content": "", "images": None, }, + "done_reason": "stop", "done": True, "total_duration": total_time, "load_duration": 0, @@ -697,6 +720,7 @@ class OllamaAPI: "content": str(response_text), "images": None, }, + "done_reason": "stop", "done": True, "total_duration": total_time, "load_duration": 0, diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index b23dbd32..69aa32d8 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -183,6 +183,9 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): if isinstance(response, str): # If it's a string, send it all at once yield f"{json.dumps({'response': response})}\n" + elif response is None: + # Handle None response (e.g., when only_need_context=True but no context found) + yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n" else: # If it's an async generator, send chunks one by one try: diff --git a/lightrag/base.py b/lightrag/base.py index 12d142c1..7820b4da 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -297,6 +297,8 @@ class BaseKVStorage(StorageNameSpace, ABC): @dataclass class BaseGraphStorage(StorageNameSpace, ABC): + """All operations related to edges in graph should be undirected.""" + embedding_func: EmbeddingFunc @abstractmethod @@ -468,17 +470,6 @@ class BaseGraphStorage(StorageNameSpace, ABC): list[dict]: A list of nodes, where each node is a dictionary of its properties. An empty list if no matching nodes are found. """ - # Default implementation iterates through all nodes, which is inefficient. - # This method should be overridden by subclasses for better performance. 
- all_nodes = [] - all_labels = await self.get_all_labels() - for label in all_labels: - node = await self.get_node(label) - if node and "source_id" in node: - source_ids = set(node["source_id"].split(GRAPH_FIELD_SEP)) - if not source_ids.isdisjoint(chunk_ids): - all_nodes.append(node) - return all_nodes @abstractmethod async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: @@ -643,6 +634,8 @@ class DocProcessingStatus: """ISO format timestamp when document was last updated""" chunks_count: int | None = None """Number of chunks after splitting, used for processing""" + chunks_list: list[str] | None = field(default_factory=list) + """List of chunk IDs associated with this document, used for deletion""" error: str | None = None """Error message if failed""" metadata: dict[str, Any] = field(default_factory=dict) diff --git a/lightrag/constants.py b/lightrag/constants.py index f8345994..82451a36 100644 --- a/lightrag/constants.py +++ b/lightrag/constants.py @@ -7,6 +7,7 @@ consistency and makes maintenance easier. """ # Default values for environment variables +DEFAULT_MAX_GLEANING = 1 DEFAULT_MAX_TOKEN_SUMMARY = 500 DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 6 DEFAULT_WOKERS = 2 diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py index 3398b135..b2a93e82 100644 --- a/lightrag/kg/__init__.py +++ b/lightrag/kg/__init__.py @@ -26,11 +26,11 @@ STORAGE_IMPLEMENTATIONS = { "implementations": [ "NanoVectorDBStorage", "MilvusVectorDBStorage", - "ChromaVectorDBStorage", "PGVectorStorage", "FaissVectorDBStorage", "QdrantVectorDBStorage", "MongoVectorDBStorage", + # "ChromaVectorDBStorage", # "TiDBVectorDBStorage", ], "required_methods": ["query", "upsert"], @@ -38,6 +38,7 @@ STORAGE_IMPLEMENTATIONS = { "DOC_STATUS_STORAGE": { "implementations": [ "JsonDocStatusStorage", + "RedisDocStatusStorage", "PGDocStatusStorage", "MongoDocStatusStorage", ], @@ -81,6 +82,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { "MongoVectorDBStorage": [], # Document Status Storage Implementations "JsonDocStatusStorage": [], + "RedisDocStatusStorage": ["REDIS_URI"], "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], "MongoDocStatusStorage": [], } @@ -98,6 +100,7 @@ STORAGES = { "MongoGraphStorage": ".kg.mongo_impl", "MongoVectorDBStorage": ".kg.mongo_impl", "RedisKVStorage": ".kg.redis_impl", + "RedisDocStatusStorage": ".kg.redis_impl", "ChromaVectorDBStorage": ".kg.chroma_impl", # "TiDBKVStorage": ".kg.tidb_impl", # "TiDBVectorDBStorage": ".kg.tidb_impl", diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/deprecated/chroma_impl.py similarity index 98% rename from lightrag/kg/chroma_impl.py rename to lightrag/kg/deprecated/chroma_impl.py index c3927a19..ebdd4593 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/deprecated/chroma_impl.py @@ -109,7 +109,7 @@ class ChromaVectorDBStorage(BaseVectorStorage): raise async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"Inserting {len(data)} to {self.namespace}") if not data: return @@ -234,7 +234,6 @@ class ChromaVectorDBStorage(BaseVectorStorage): ids: List of vector IDs to be deleted """ try: - logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") self._collection.delete(ids=ids) logger.debug( f"Successfully deleted {len(ids)} vectors from {self.namespace}" diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/deprecated/gremlin_impl.py similarity index 100% rename from lightrag/kg/gremlin_impl.py rename 
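With `RedisDocStatusStorage` registered above, switching document-status tracking to Redis only requires pointing the existing selection variables at it. A sketch of the configuration (the Redis URI value is an assumption about a typical local default):

```python
import os

# Selecting the newly registered Redis-backed document status storage.
# REDIS_URI is the only variable it requires per STORAGE_ENV_REQUIREMENTS;
# the URI below is an assumed local default.
os.environ["LIGHTRAG_DOC_STATUS_STORAGE"] = "RedisDocStatusStorage"
os.environ["REDIS_URI"] = "redis://localhost:6379"
```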
to lightrag/kg/deprecated/gremlin_impl.py diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/deprecated/tidb_impl.py similarity index 99% rename from lightrag/kg/tidb_impl.py rename to lightrag/kg/deprecated/tidb_impl.py index 9b9d17a9..d60bb1f6 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/deprecated/tidb_impl.py @@ -257,7 +257,7 @@ class TiDBKVStorage(BaseKVStorage): ################ INSERT full_doc AND chunks ################ async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"Inserting {len(data)} to {self.namespace}") if not data: return left_data = {k: v for k, v in data.items() if k not in self._data} @@ -454,11 +454,9 @@ class TiDBVectorDBStorage(BaseVectorStorage): ###### INSERT entities And relationships ###### async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") if not data: return - - logger.info(f"Inserting {len(data)} vectors to {self.namespace}") + logger.debug(f"Inserting {len(data)} vectors to {self.namespace}") # Get current time as UNIX timestamp import time @@ -522,11 +520,6 @@ class TiDBVectorDBStorage(BaseVectorStorage): } await self.db.execute(SQL_TEMPLATES["upsert_relationship"], param) - async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: - SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] - params = {"workspace": self.db.workspace, "status": status} - return await self.db.query(SQL, params, multirows=True) - async def delete(self, ids: list[str]) -> None: """Delete vectors with specified IDs from the storage. diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index f2afde2e..cb19497a 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -17,14 +17,13 @@ from .shared_storage import ( set_all_update_flags, ) -import faiss # type: ignore - USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1" FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu" - if not pm.is_installed(FAISS_PACKAGE): pm.install(FAISS_PACKAGE) +import faiss # type: ignore + @final @dataclass diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index f8387ad8..ab6ab390 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -118,6 +118,10 @@ class JsonDocStatusStorage(DocStatusStorage): return logger.debug(f"Inserting {len(data)} records to {self.namespace}") async with self._storage_lock: + # Ensure chunks_list field exists for new documents + for doc_id, doc_data in data.items(): + if "chunks_list" not in doc_data: + doc_data["chunks_list"] = [] self._data.update(data) await set_all_update_flags(self.namespace) diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index fa819d4a..98835f8c 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -42,19 +42,14 @@ class JsonKVStorage(BaseKVStorage): if need_init: loaded_data = load_json(self._file_name) or {} async with self._storage_lock: - self._data.update(loaded_data) - - # Calculate data count based on namespace - if self.namespace.endswith("cache"): - # For cache namespaces, sum the cache entries across all cache types - data_count = sum( - len(first_level_dict) - for first_level_dict in loaded_data.values() - if isinstance(first_level_dict, dict) + # Migrate legacy cache structure if needed + if self.namespace.endswith("_cache"): + loaded_data = await self._migrate_legacy_cache_structure( + 
loaded_data ) - else: - # For non-cache namespaces, use the original count method - data_count = len(loaded_data) + + self._data.update(loaded_data) + data_count = len(loaded_data) logger.info( f"Process {os.getpid()} KV load {self.namespace} with {data_count} records" @@ -67,17 +62,8 @@ class JsonKVStorage(BaseKVStorage): dict(self._data) if hasattr(self._data, "_getvalue") else self._data ) - # Calculate data count based on namespace - if self.namespace.endswith("cache"): - # # For cache namespaces, sum the cache entries across all cache types - data_count = sum( - len(first_level_dict) - for first_level_dict in data_dict.values() - if isinstance(first_level_dict, dict) - ) - else: - # For non-cache namespaces, use the original count method - data_count = len(data_dict) + # Calculate data count - all data is now flattened + data_count = len(data_dict) logger.debug( f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}" @@ -92,22 +78,49 @@ class JsonKVStorage(BaseKVStorage): Dictionary containing all stored data """ async with self._storage_lock: - return dict(self._data) + result = {} + for key, value in self._data.items(): + if value: + # Create a copy to avoid modifying the original data + data = dict(value) + # Ensure time fields are present, provide default values for old data + data.setdefault("create_time", 0) + data.setdefault("update_time", 0) + result[key] = data + else: + result[key] = value + return result async def get_by_id(self, id: str) -> dict[str, Any] | None: async with self._storage_lock: - return self._data.get(id) + result = self._data.get(id) + if result: + # Create a copy to avoid modifying the original data + result = dict(result) + # Ensure time fields are present, provide default values for old data + result.setdefault("create_time", 0) + result.setdefault("update_time", 0) + # Ensure _id field contains the clean ID + result["_id"] = id + return result async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async with self._storage_lock: - return [ - ( - {k: v for k, v in self._data[id].items()} - if self._data.get(id, None) - else None - ) - for id in ids - ] + results = [] + for id in ids: + data = self._data.get(id, None) + if data: + # Create a copy to avoid modifying the original data + result = {k: v for k, v in data.items()} + # Ensure time fields are present, provide default values for old data + result.setdefault("create_time", 0) + result.setdefault("update_time", 0) + # Ensure _id field contains the clean ID + result["_id"] = id + results.append(result) + else: + results.append(None) + return results async def filter_keys(self, keys: set[str]) -> set[str]: async with self._storage_lock: @@ -121,8 +134,29 @@ class JsonKVStorage(BaseKVStorage): """ if not data: return + + import time + + current_time = int(time.time()) # Get current Unix timestamp + logger.debug(f"Inserting {len(data)} records to {self.namespace}") async with self._storage_lock: + # Add timestamps to data based on whether key exists + for k, v in data.items(): + # For text_chunks namespace, ensure llm_cache_list field exists + if "text_chunks" in self.namespace: + if "llm_cache_list" not in v: + v["llm_cache_list"] = [] + + # Add timestamps based on whether key exists + if k in self._data: # Key exists, only update update_time + v["update_time"] = current_time + else: # New key, set both create_time and update_time + v["create_time"] = current_time + v["update_time"] = current_time + + v["_id"] = k + self._data.update(data) await 
set_all_update_flags(self.namespace) @@ -150,14 +184,14 @@ class JsonKVStorage(BaseKVStorage): await set_all_update_flags(self.namespace) async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: - """Delete specific records from storage by by cache mode + """Delete specific records from storage by cache mode Importance notes for in-memory storage: 1. Changes will be persisted to disk during the next index_done_callback 2. update flags to notify other processes that data persistence is needed Args: - ids (list[str]): List of cache mode to be drop from storage + modes (list[str]): List of cache modes to be dropped from storage Returns: True: if the cache drop successfully @@ -167,9 +201,29 @@ class JsonKVStorage(BaseKVStorage): return False try: - await self.delete(modes) + async with self._storage_lock: + keys_to_delete = [] + modes_set = set(modes) # Convert to set for efficient lookup + + for key in list(self._data.keys()): + # Parse flattened cache key: mode:cache_type:hash + parts = key.split(":", 2) + if len(parts) == 3 and parts[0] in modes_set: + keys_to_delete.append(key) + + # Batch delete + for key in keys_to_delete: + self._data.pop(key, None) + + if keys_to_delete: + await set_all_update_flags(self.namespace) + logger.info( + f"Dropped {len(keys_to_delete)} cache entries for modes: {modes}" + ) + return True - except Exception: + except Exception as e: + logger.error(f"Error dropping cache by modes: {e}") return False # async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool: @@ -245,9 +299,58 @@ class JsonKVStorage(BaseKVStorage): logger.error(f"Error dropping {self.namespace}: {e}") return {"status": "error", "message": str(e)} + async def _migrate_legacy_cache_structure(self, data: dict) -> dict: + """Migrate legacy nested cache structure to flattened structure + + Args: + data: Original data dictionary that may contain legacy structure + + Returns: + Migrated data dictionary with flattened cache keys + """ + from lightrag.utils import generate_cache_key + + # Early return if data is empty + if not data: + return data + + # Check first entry to see if it's already in new format + first_key = next(iter(data.keys())) + if ":" in first_key and len(first_key.split(":")) == 3: + # Already in flattened format, return as-is + return data + + migrated_data = {} + migration_count = 0 + + for key, value in data.items(): + # Check if this is a legacy nested cache structure + if isinstance(value, dict) and all( + isinstance(v, dict) and "return" in v for v in value.values() + ): + # This looks like a legacy cache mode with nested structure + mode = key + for cache_hash, cache_entry in value.items(): + cache_type = cache_entry.get("cache_type", "extract") + flattened_key = generate_cache_key(mode, cache_type, cache_hash) + migrated_data[flattened_key] = cache_entry + migration_count += 1 + else: + # Keep non-cache data or already flattened cache data as-is + migrated_data[key] = value + + if migration_count > 0: + logger.info( + f"Migrated {migration_count} legacy cache entries to flattened structure" + ) + # Persist migrated data immediately + write_json(migrated_data, self._file_name) + + return migrated_data + async def finalize(self): """Finalize storage resources Persistence cache data to disk before exiting """ - if self.namespace.endswith("cache"): + if self.namespace.endswith("_cache"): await self.index_done_callback() diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 23e178bc..6cffae88 100644 --- 
a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"): pm.install("pymilvus") import configparser -from pymilvus import MilvusClient # type: ignore +from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema # type: ignore config = configparser.ConfigParser() config.read("config.ini", "utf-8") @@ -24,16 +24,605 @@ config.read("config.ini", "utf-8") @final @dataclass class MilvusVectorDBStorage(BaseVectorStorage): - @staticmethod - def create_collection_if_not_exist( - client: MilvusClient, collection_name: str, **kwargs - ): - if client.has_collection(collection_name): - return - client.create_collection( - collection_name, max_length=64, id_type="string", **kwargs + def _create_schema_for_namespace(self) -> CollectionSchema: + """Create schema based on the current instance's namespace""" + + # Get vector dimension from embedding_func + dimension = self.embedding_func.embedding_dim + + # Base fields (common to all collections) + base_fields = [ + FieldSchema( + name="id", dtype=DataType.VARCHAR, max_length=64, is_primary=True + ), + FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension), + FieldSchema(name="created_at", dtype=DataType.INT64), + ] + + # Determine specific fields based on namespace + if "entities" in self.namespace.lower(): + specific_fields = [ + FieldSchema( + name="entity_name", + dtype=DataType.VARCHAR, + max_length=256, + nullable=True, + ), + FieldSchema( + name="entity_type", + dtype=DataType.VARCHAR, + max_length=64, + nullable=True, + ), + FieldSchema( + name="file_path", + dtype=DataType.VARCHAR, + max_length=512, + nullable=True, + ), + ] + description = "LightRAG entities vector storage" + + elif "relationships" in self.namespace.lower(): + specific_fields = [ + FieldSchema( + name="src_id", dtype=DataType.VARCHAR, max_length=256, nullable=True + ), + FieldSchema( + name="tgt_id", dtype=DataType.VARCHAR, max_length=256, nullable=True + ), + FieldSchema(name="weight", dtype=DataType.DOUBLE, nullable=True), + FieldSchema( + name="file_path", + dtype=DataType.VARCHAR, + max_length=512, + nullable=True, + ), + ] + description = "LightRAG relationships vector storage" + + elif "chunks" in self.namespace.lower(): + specific_fields = [ + FieldSchema( + name="full_doc_id", + dtype=DataType.VARCHAR, + max_length=64, + nullable=True, + ), + FieldSchema( + name="file_path", + dtype=DataType.VARCHAR, + max_length=512, + nullable=True, + ), + ] + description = "LightRAG chunks vector storage" + + else: + # Default generic schema (backward compatibility) + specific_fields = [ + FieldSchema( + name="file_path", + dtype=DataType.VARCHAR, + max_length=512, + nullable=True, + ), + ] + description = "LightRAG generic vector storage" + + # Merge all fields + all_fields = base_fields + specific_fields + + return CollectionSchema( + fields=all_fields, + description=description, + enable_dynamic_field=True, # Support dynamic fields ) + def _get_index_params(self): + """Get IndexParams in a version-compatible way""" + try: + # Try to use client's prepare_index_params method (most common) + if hasattr(self._client, "prepare_index_params"): + return self._client.prepare_index_params() + except Exception: + pass + + try: + # Try to import IndexParams from different possible locations + from pymilvus.client.prepare import IndexParams + + return IndexParams() + except ImportError: + pass + + try: + from pymilvus.client.types import IndexParams + + return IndexParams() + except ImportError: + pass + + 
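For reference, the namespace-specific schemas produced by `_create_schema_for_namespace` above boil down to the following extra scalar fields. This is a compact summary for readers, not code used by the implementation:

```python
# Extra scalar fields per namespace, as declared in _create_schema_for_namespace.
# Every collection also carries the base fields: id (VarChar, primary key),
# vector (FloatVector, dim = embedding_func.embedding_dim) and created_at (Int64).
NAMESPACE_EXTRA_FIELDS = {
    "entities": ["entity_name", "entity_type", "file_path"],
    "relationships": ["src_id", "tgt_id", "weight", "file_path"],
    "chunks": ["full_doc_id", "file_path"],
    "other (generic)": ["file_path"],
}
```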
try: + from pymilvus import IndexParams + + return IndexParams() + except ImportError: + pass + + # If all else fails, return None to use fallback method + return None + + def _create_vector_index_fallback(self): + """Fallback method to create vector index using direct API""" + try: + self._client.create_index( + collection_name=self.namespace, + field_name="vector", + index_params={ + "index_type": "HNSW", + "metric_type": "COSINE", + "params": {"M": 16, "efConstruction": 256}, + }, + ) + logger.debug("Created vector index using fallback method") + except Exception as e: + logger.warning(f"Failed to create vector index using fallback method: {e}") + + def _create_scalar_index_fallback(self, field_name: str, index_type: str): + """Fallback method to create scalar index using direct API""" + # Skip unsupported index types + if index_type == "SORTED": + logger.info( + f"Skipping SORTED index for {field_name} (not supported in this Milvus version)" + ) + return + + try: + self._client.create_index( + collection_name=self.namespace, + field_name=field_name, + index_params={"index_type": index_type}, + ) + logger.debug(f"Created {field_name} index using fallback method") + except Exception as e: + logger.info( + f"Could not create {field_name} index using fallback method: {e}" + ) + + def _create_indexes_after_collection(self): + """Create indexes after collection is created""" + try: + # Try to get IndexParams in a version-compatible way + IndexParamsClass = self._get_index_params() + + if IndexParamsClass is not None: + # Use IndexParams approach if available + try: + # Create vector index first (required for most operations) + vector_index = IndexParamsClass + vector_index.add_index( + field_name="vector", + index_type="HNSW", + metric_type="COSINE", + params={"M": 16, "efConstruction": 256}, + ) + self._client.create_index( + collection_name=self.namespace, index_params=vector_index + ) + logger.debug("Created vector index using IndexParams") + except Exception as e: + logger.debug(f"IndexParams method failed for vector index: {e}") + self._create_vector_index_fallback() + + # Create scalar indexes based on namespace + if "entities" in self.namespace.lower(): + # Create indexes for entity fields + try: + entity_name_index = self._get_index_params() + entity_name_index.add_index( + field_name="entity_name", index_type="INVERTED" + ) + self._client.create_index( + collection_name=self.namespace, + index_params=entity_name_index, + ) + except Exception as e: + logger.debug(f"IndexParams method failed for entity_name: {e}") + self._create_scalar_index_fallback("entity_name", "INVERTED") + + try: + entity_type_index = self._get_index_params() + entity_type_index.add_index( + field_name="entity_type", index_type="INVERTED" + ) + self._client.create_index( + collection_name=self.namespace, + index_params=entity_type_index, + ) + except Exception as e: + logger.debug(f"IndexParams method failed for entity_type: {e}") + self._create_scalar_index_fallback("entity_type", "INVERTED") + + elif "relationships" in self.namespace.lower(): + # Create indexes for relationship fields + try: + src_id_index = self._get_index_params() + src_id_index.add_index( + field_name="src_id", index_type="INVERTED" + ) + self._client.create_index( + collection_name=self.namespace, index_params=src_id_index + ) + except Exception as e: + logger.debug(f"IndexParams method failed for src_id: {e}") + self._create_scalar_index_fallback("src_id", "INVERTED") + + try: + tgt_id_index = self._get_index_params() + 
tgt_id_index.add_index( + field_name="tgt_id", index_type="INVERTED" + ) + self._client.create_index( + collection_name=self.namespace, index_params=tgt_id_index + ) + except Exception as e: + logger.debug(f"IndexParams method failed for tgt_id: {e}") + self._create_scalar_index_fallback("tgt_id", "INVERTED") + + elif "chunks" in self.namespace.lower(): + # Create indexes for chunk fields + try: + doc_id_index = self._get_index_params() + doc_id_index.add_index( + field_name="full_doc_id", index_type="INVERTED" + ) + self._client.create_index( + collection_name=self.namespace, index_params=doc_id_index + ) + except Exception as e: + logger.debug(f"IndexParams method failed for full_doc_id: {e}") + self._create_scalar_index_fallback("full_doc_id", "INVERTED") + + # No common indexes needed + + else: + # Fallback to direct API calls if IndexParams is not available + logger.info( + f"IndexParams not available, using fallback methods for {self.namespace}" + ) + + # Create vector index using fallback + self._create_vector_index_fallback() + + # Create scalar indexes using fallback + if "entities" in self.namespace.lower(): + self._create_scalar_index_fallback("entity_name", "INVERTED") + self._create_scalar_index_fallback("entity_type", "INVERTED") + elif "relationships" in self.namespace.lower(): + self._create_scalar_index_fallback("src_id", "INVERTED") + self._create_scalar_index_fallback("tgt_id", "INVERTED") + elif "chunks" in self.namespace.lower(): + self._create_scalar_index_fallback("full_doc_id", "INVERTED") + + logger.info(f"Created indexes for collection: {self.namespace}") + + except Exception as e: + logger.warning(f"Failed to create some indexes for {self.namespace}: {e}") + + def _get_required_fields_for_namespace(self) -> dict: + """Get required core field definitions for current namespace""" + + # Base fields (common to all types) + base_fields = { + "id": {"type": "VarChar", "is_primary": True}, + "vector": {"type": "FloatVector"}, + "created_at": {"type": "Int64"}, + } + + # Add specific fields based on namespace + if "entities" in self.namespace.lower(): + specific_fields = { + "entity_name": {"type": "VarChar"}, + "entity_type": {"type": "VarChar"}, + "file_path": {"type": "VarChar"}, + } + elif "relationships" in self.namespace.lower(): + specific_fields = { + "src_id": {"type": "VarChar"}, + "tgt_id": {"type": "VarChar"}, + "weight": {"type": "Double"}, + "file_path": {"type": "VarChar"}, + } + elif "chunks" in self.namespace.lower(): + specific_fields = { + "full_doc_id": {"type": "VarChar"}, + "file_path": {"type": "VarChar"}, + } + else: + specific_fields = { + "file_path": {"type": "VarChar"}, + } + + return {**base_fields, **specific_fields} + + def _is_field_compatible(self, existing_field: dict, expected_config: dict) -> bool: + """Check compatibility of a single field""" + field_name = existing_field.get("name", "unknown") + existing_type = existing_field.get("type") + expected_type = expected_config.get("type") + + logger.debug( + f"Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}" + ) + + # Convert DataType enum values to string names if needed + original_existing_type = existing_type + if hasattr(existing_type, "name"): + existing_type = existing_type.name + logger.debug( + f"Converted enum to name: {original_existing_type} -> {existing_type}" + ) + elif isinstance(existing_type, int): + # Map common Milvus internal type codes to type names for backward compatibility + type_mapping = { + 
21: "VarChar", + 101: "FloatVector", + 5: "Int64", + 9: "Double", + } + mapped_type = type_mapping.get(existing_type, str(existing_type)) + logger.debug(f"Mapped numeric type: {existing_type} -> {mapped_type}") + existing_type = mapped_type + + # Normalize type names for comparison + type_aliases = { + "VARCHAR": "VarChar", + "String": "VarChar", + "FLOAT_VECTOR": "FloatVector", + "INT64": "Int64", + "BigInt": "Int64", + "DOUBLE": "Double", + "Float": "Double", + } + + original_existing = existing_type + original_expected = expected_type + existing_type = type_aliases.get(existing_type, existing_type) + expected_type = type_aliases.get(expected_type, expected_type) + + if original_existing != existing_type or original_expected != expected_type: + logger.debug( + f"Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}" + ) + + # Basic type compatibility check + type_compatible = existing_type == expected_type + logger.debug( + f"Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}" + ) + + if not type_compatible: + logger.warning( + f"Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}" + ) + return False + + # Primary key check - be more flexible about primary key detection + if expected_config.get("is_primary"): + # Check multiple possible field names for primary key status + is_primary = ( + existing_field.get("is_primary_key", False) + or existing_field.get("is_primary", False) + or existing_field.get("primary_key", False) + ) + logger.debug( + f"Primary key check for '{field_name}': expected=True, actual={is_primary}" + ) + logger.debug(f"Raw field data for '{field_name}': {existing_field}") + + # For ID field, be more lenient - if it's the ID field, assume it should be primary + if field_name == "id" and not is_primary: + logger.info( + f"ID field '{field_name}' not marked as primary in existing collection, but treating as compatible" + ) + # Don't fail for ID field primary key mismatch + elif not is_primary: + logger.warning( + f"Primary key mismatch for field '{field_name}': expected primary key, but field is not primary" + ) + return False + + logger.debug(f"Field '{field_name}' is compatible") + return True + + def _check_vector_dimension(self, collection_info: dict): + """Check vector dimension compatibility""" + current_dimension = self.embedding_func.embedding_dim + + # Find vector field dimension + for field in collection_info.get("fields", []): + if field.get("name") == "vector": + field_type = field.get("type") + if field_type in ["FloatVector", "FLOAT_VECTOR"]: + existing_dimension = field.get("params", {}).get("dim") + + if existing_dimension != current_dimension: + raise ValueError( + f"Vector dimension mismatch for collection '{self.namespace}': " + f"existing={existing_dimension}, current={current_dimension}" + ) + + logger.debug(f"Vector dimension check passed: {current_dimension}") + return + + # If no vector field found, this might be an old collection created with simple schema + logger.warning( + f"Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema." 
+ ) + logger.warning("Consider recreating the collection for optimal performance.") + return + + def _check_schema_compatibility(self, collection_info: dict): + """Check schema field compatibility""" + existing_fields = { + field["name"]: field for field in collection_info.get("fields", []) + } + + # Check if this is an old collection created with simple schema + has_vector_field = any( + field.get("name") == "vector" for field in collection_info.get("fields", []) + ) + + if not has_vector_field: + logger.warning( + f"Collection {self.namespace} appears to be created with old simple schema (no vector field)" + ) + logger.warning( + "This collection will work but may have suboptimal performance" + ) + logger.warning("Consider recreating the collection for optimal performance") + return + + # For collections with vector field, check basic compatibility + # Only check for critical incompatibilities, not missing optional fields + critical_fields = {"id": {"type": "VarChar", "is_primary": True}} + + incompatible_fields = [] + + for field_name, expected_config in critical_fields.items(): + if field_name in existing_fields: + existing_field = existing_fields[field_name] + if not self._is_field_compatible(existing_field, expected_config): + incompatible_fields.append( + f"{field_name}: expected {expected_config['type']}, " + f"got {existing_field.get('type')}" + ) + + if incompatible_fields: + raise ValueError( + f"Critical schema incompatibility in collection '{self.namespace}': {incompatible_fields}" + ) + + # Get all expected fields for informational purposes + expected_fields = self._get_required_fields_for_namespace() + missing_fields = [ + field for field in expected_fields if field not in existing_fields + ] + + if missing_fields: + logger.info( + f"Collection {self.namespace} missing optional fields: {missing_fields}" + ) + logger.info( + "These fields would be available in a newly created collection for better performance" + ) + + logger.debug(f"Schema compatibility check passed for {self.namespace}") + + def _validate_collection_compatibility(self): + """Validate existing collection's dimension and schema compatibility""" + try: + collection_info = self._client.describe_collection(self.namespace) + + # 1. Check vector dimension + self._check_vector_dimension(collection_info) + + # 2. Check schema compatibility + self._check_schema_compatibility(collection_info) + + logger.info(f"Collection {self.namespace} compatibility validation passed") + + except Exception as e: + logger.error( + f"Collection compatibility validation failed for {self.namespace}: {e}" + ) + raise + + def _create_collection_if_not_exist(self): + """Create collection if not exists and check existing collection compatibility""" + + try: + # First, list all collections to see what actually exists + try: + all_collections = self._client.list_collections() + logger.debug(f"All collections in database: {all_collections}") + except Exception as list_error: + logger.warning(f"Could not list collections: {list_error}") + all_collections = [] + + # Check if our specific collection exists + collection_exists = self._client.has_collection(self.namespace) + logger.info( + f"Collection '{self.namespace}' exists check: {collection_exists}" + ) + + if collection_exists: + # Double-check by trying to describe the collection + try: + self._client.describe_collection(self.namespace) + logger.info( + f"Collection '{self.namespace}' confirmed to exist, validating compatibility..." 
+ ) + self._validate_collection_compatibility() + return + except Exception as describe_error: + logger.warning( + f"Collection '{self.namespace}' exists but cannot be described: {describe_error}" + ) + logger.info( + "Treating as if collection doesn't exist and creating new one..." + ) + # Fall through to creation logic + + # Collection doesn't exist, create new collection + logger.info(f"Creating new collection: {self.namespace}") + schema = self._create_schema_for_namespace() + + # Create collection with schema only first + self._client.create_collection( + collection_name=self.namespace, schema=schema + ) + + # Then create indexes + self._create_indexes_after_collection() + + logger.info(f"Successfully created Milvus collection: {self.namespace}") + + except Exception as e: + logger.error( + f"Error in _create_collection_if_not_exist for {self.namespace}: {e}" + ) + + # If there's any error, try to force create the collection + logger.info(f"Attempting to force create collection {self.namespace}...") + try: + # Try to drop the collection first if it exists in a bad state + try: + if self._client.has_collection(self.namespace): + logger.info( + f"Dropping potentially corrupted collection {self.namespace}" + ) + self._client.drop_collection(self.namespace) + except Exception as drop_error: + logger.warning( + f"Could not drop collection {self.namespace}: {drop_error}" + ) + + # Create fresh collection + schema = self._create_schema_for_namespace() + self._client.create_collection( + collection_name=self.namespace, schema=schema + ) + self._create_indexes_after_collection() + logger.info(f"Successfully force-created collection {self.namespace}") + + except Exception as create_error: + logger.error( + f"Failed to force-create collection {self.namespace}: {create_error}" + ) + raise + def __post_init__(self): kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") @@ -43,6 +632,10 @@ class MilvusVectorDBStorage(BaseVectorStorage): ) self.cosine_better_than_threshold = cosine_threshold + # Ensure created_at is in meta_fields + if "created_at" not in self.meta_fields: + self.meta_fields.add("created_at") + self._client = MilvusClient( uri=os.environ.get( "MILVUS_URI", @@ -68,14 +661,12 @@ class MilvusVectorDBStorage(BaseVectorStorage): ), ) self._max_batch_size = self.global_config["embedding_batch_num"] - MilvusVectorDBStorage.create_collection_if_not_exist( - self._client, - self.namespace, - dimension=self.embedding_func.embedding_dim, - ) + + # Create collection and check compatibility + self._create_collection_if_not_exist() async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"Inserting {len(data)} to {self.namespace}") if not data: return @@ -112,23 +703,25 @@ class MilvusVectorDBStorage(BaseVectorStorage): embedding = await self.embedding_func( [query], _priority=5 ) # higher priority for query + + # Include all meta_fields (created_at is now always included) + output_fields = list(self.meta_fields) + results = self._client.search( collection_name=self.namespace, data=embedding, limit=top_k, - output_fields=list(self.meta_fields) + ["created_at"], + output_fields=output_fields, search_params={ "metric_type": "COSINE", "params": {"radius": self.cosine_better_than_threshold}, }, ) - print(results) return [ { **dp["entity"], "id": dp["id"], "distance": dp["distance"], - # created_at is requested in output_fields, so it should be 
a top-level key in the result dict (dp) "created_at": dp.get("created_at"), } for dp in results[0] @@ -232,20 +825,19 @@ class MilvusVectorDBStorage(BaseVectorStorage): The vector data if found, or None if not found """ try: + # Include all meta_fields (created_at is now always included) plus id + output_fields = list(self.meta_fields) + ["id"] + # Query Milvus for a specific ID result = self._client.query( collection_name=self.namespace, filter=f'id == "{id}"', - output_fields=list(self.meta_fields) + ["id", "created_at"], + output_fields=output_fields, ) if not result or len(result) == 0: return None - # Ensure the result contains created_at field - if "created_at" not in result[0]: - result[0]["created_at"] = None - return result[0] except Exception as e: logger.error(f"Error retrieving vector data for ID {id}: {e}") @@ -264,6 +856,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): return [] try: + # Include all meta_fields (created_at is now always included) plus id + output_fields = list(self.meta_fields) + ["id"] + # Prepare the ID filter expression id_list = '", "'.join(ids) filter_expr = f'id in ["{id_list}"]' @@ -272,14 +867,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): result = self._client.query( collection_name=self.namespace, filter=filter_expr, - output_fields=list(self.meta_fields) + ["id", "created_at"], + output_fields=output_fields, ) - # Ensure each result contains created_at field - for item in result: - if "created_at" not in item: - item["created_at"] = None - return result or [] except Exception as e: logger.error(f"Error retrieving vector data for IDs {ids}: {e}") @@ -301,11 +891,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): self._client.drop_collection(self.namespace) # Recreate the collection - MilvusVectorDBStorage.create_collection_if_not_exist( - self._client, - self.namespace, - dimension=self.embedding_func.embedding_dim, - ) + self._create_collection_if_not_exist() logger.info( f"Process {os.getpid()} drop Milvus collection {self.namespace}" diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index fbea463b..2ac3aff2 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1,4 +1,5 @@ import os +import time from dataclasses import dataclass, field import numpy as np import configparser @@ -14,7 +15,6 @@ from ..base import ( DocStatus, DocStatusStorage, ) -from ..namespace import NameSpace, is_namespace from ..utils import logger, compute_mdhash_id from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..constants import GRAPH_FIELD_SEP @@ -35,6 +35,7 @@ config.read("config.ini", "utf-8") # Get maximum number of graph nodes from environment variable, default is 1000 MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) +GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional") class ClientManager: @@ -96,11 +97,22 @@ class MongoKVStorage(BaseKVStorage): self._data = None async def get_by_id(self, id: str) -> dict[str, Any] | None: - return await self._data.find_one({"_id": id}) + # Unified handling for flattened keys + doc = await self._data.find_one({"_id": id}) + if doc: + # Ensure time fields are present, provide default values for old data + doc.setdefault("create_time", 0) + doc.setdefault("update_time", 0) + return doc async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: cursor = self._data.find({"_id": {"$in": ids}}) - return await cursor.to_list() + docs = await cursor.to_list() + # Ensure time fields are present for all documents + for doc in docs: 
+ doc.setdefault("create_time", 0) + doc.setdefault("update_time", 0) + return docs async def filter_keys(self, keys: set[str]) -> set[str]: cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1}) @@ -117,47 +129,53 @@ class MongoKVStorage(BaseKVStorage): result = {} async for doc in cursor: doc_id = doc.pop("_id") + # Ensure time fields are present for all documents + doc.setdefault("create_time", 0) + doc.setdefault("update_time", 0) result[doc_id] = doc return result async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"Inserting {len(data)} to {self.namespace}") if not data: return - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - update_tasks: list[Any] = [] - for mode, items in data.items(): - for k, v in items.items(): - key = f"{mode}_{k}" - data[mode][k]["_id"] = f"{mode}_{k}" - update_tasks.append( - self._data.update_one( - {"_id": key}, {"$setOnInsert": v}, upsert=True - ) - ) - await asyncio.gather(*update_tasks) - else: - update_tasks = [] - for k, v in data.items(): - data[k]["_id"] = k - update_tasks.append( - self._data.update_one({"_id": k}, {"$set": v}, upsert=True) - ) - await asyncio.gather(*update_tasks) + # Unified handling for all namespaces with flattened keys + # Use bulk_write for better performance + from pymongo import UpdateOne - async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - res = {} - v = await self._data.find_one({"_id": mode + "_" + id}) - if v: - res[id] = v - logger.debug(f"llm_response_cache find one by:{id}") - return res - else: - return None - else: - return None + operations = [] + current_time = int(time.time()) # Get current Unix timestamp + + for k, v in data.items(): + # For text_chunks namespace, ensure llm_cache_list field exists + if self.namespace.endswith("text_chunks"): + if "llm_cache_list" not in v: + v["llm_cache_list"] = [] + + # Create a copy of v for $set operation, excluding create_time to avoid conflicts + v_for_set = v.copy() + v_for_set["_id"] = k # Use flattened key as _id + v_for_set["update_time"] = current_time # Always update update_time + + # Remove create_time from $set to avoid conflict with $setOnInsert + v_for_set.pop("create_time", None) + + operations.append( + UpdateOne( + {"_id": k}, + { + "$set": v_for_set, # Update all fields except create_time + "$setOnInsert": { + "create_time": current_time + }, # Set create_time only on insert + }, + upsert=True, + ) + ) + + if operations: + await self._data.bulk_write(operations) async def index_done_callback(self) -> None: # Mongo handles persistence automatically @@ -197,8 +215,8 @@ class MongoKVStorage(BaseKVStorage): return False try: - # Build regex pattern to match documents with the specified modes - pattern = f"^({'|'.join(modes)})_" + # Build regex pattern to match flattened key format: mode:cache_type:hash + pattern = f"^({'|'.join(modes)}):" result = await self._data.delete_many({"_id": {"$regex": pattern}}) logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}") return True @@ -262,11 +280,14 @@ class MongoDocStatusStorage(DocStatusStorage): return data - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"Inserting {len(data)} to {self.namespace}") if not data: return update_tasks: list[Any] = [] for k, v in 
data.items(): + # Ensure chunks_list field exists and is an array + if "chunks_list" not in v: + v["chunks_list"] = [] data[k]["_id"] = k update_tasks.append( self._data.update_one({"_id": k}, {"$set": v}, upsert=True) @@ -299,6 +320,7 @@ class MongoDocStatusStorage(DocStatusStorage): updated_at=doc.get("updated_at"), chunks_count=doc.get("chunks_count", -1), file_path=doc.get("file_path", doc["_id"]), + chunks_list=doc.get("chunks_list", []), ) for doc in result } @@ -417,11 +439,21 @@ class MongoGraphStorage(BaseGraphStorage): async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: """ - Check if there's a direct single-hop edge from source_node_id to target_node_id. + Check if there's a direct single-hop edge between source_node_id and target_node_id. """ - # Direct check if the target_node appears among the edges array. doc = await self.edge_collection.find_one( - {"source_node_id": source_node_id, "target_node_id": target_node_id}, + { + "$or": [ + { + "source_node_id": source_node_id, + "target_node_id": target_node_id, + }, + { + "source_node_id": target_node_id, + "target_node_id": source_node_id, + }, + ] + }, {"_id": 1}, ) return doc is not None @@ -651,7 +683,7 @@ class MongoGraphStorage(BaseGraphStorage): self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: """ - Upsert an edge from source_node_id -> target_node_id with optional 'relation'. + Upsert an edge between source_node_id and target_node_id with optional 'relation'. If an edge with the same target exists, we remove it and re-insert with updated data. """ # Ensure source node exists @@ -663,8 +695,22 @@ class MongoGraphStorage(BaseGraphStorage): GRAPH_FIELD_SEP ) + edge_data["source_node_id"] = source_node_id + edge_data["target_node_id"] = target_node_id + await self.edge_collection.update_one( - {"source_node_id": source_node_id, "target_node_id": target_node_id}, + { + "$or": [ + { + "source_node_id": source_node_id, + "target_node_id": target_node_id, + }, + { + "source_node_id": target_node_id, + "target_node_id": source_node_id, + }, + ] + }, update_doc, upsert=True, ) @@ -678,7 +724,7 @@ class MongoGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: """ 1) Remove node's doc entirely. - 2) Remove inbound edges from any doc that references node_id. + 2) Remove inbound & outbound edges from any doc that references node_id. """ # Remove all edges await self.edge_collection.delete_many( @@ -709,141 +755,369 @@ class MongoGraphStorage(BaseGraphStorage): labels.append(doc["_id"]) return labels + def _construct_graph_node( + self, node_id, node_data: dict[str, str] + ) -> KnowledgeGraphNode: + return KnowledgeGraphNode( + id=node_id, + labels=[node_id], + properties={ + k: v + for k, v in node_data.items() + if k + not in [ + "_id", + "connected_edges", + "source_ids", + "edge_count", + ] + }, + ) + + def _construct_graph_edge(self, edge_id: str, edge: dict[str, str]): + return KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relationship", ""), + source=edge["source_node_id"], + target=edge["target_node_id"], + properties={ + k: v + for k, v in edge.items() + if k + not in [ + "_id", + "source_node_id", + "target_node_id", + "relationship", + "source_ids", + ] + }, + ) + + async def get_knowledge_graph_all_by_degree( + self, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES + ) -> KnowledgeGraph: + """ + It's possible that the node with one or multiple relationships is retrieved, + while its neighbor is not. 
Then this node might seem like disconnected in UI. + """ + + total_node_count = await self.collection.count_documents({}) + result = KnowledgeGraph() + seen_edges = set() + + result.is_truncated = total_node_count > max_nodes + if result.is_truncated: + # Get all node_ids ranked by degree if max_nodes exceeds total node count + pipeline = [ + {"$project": {"source_node_id": 1, "_id": 0}}, + {"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}}, + { + "$unionWith": { + "coll": self._edge_collection_name, + "pipeline": [ + {"$project": {"target_node_id": 1, "_id": 0}}, + { + "$group": { + "_id": "$target_node_id", + "degree": {"$sum": 1}, + } + }, + ], + } + }, + {"$group": {"_id": "$_id", "degree": {"$sum": "$degree"}}}, + {"$sort": {"degree": -1}}, + {"$limit": max_nodes}, + ] + cursor = await self.edge_collection.aggregate(pipeline, allowDiskUse=True) + + node_ids = [] + async for doc in cursor: + node_id = str(doc["_id"]) + node_ids.append(node_id) + + cursor = self.collection.find({"_id": {"$in": node_ids}}, {"source_ids": 0}) + async for doc in cursor: + result.nodes.append(self._construct_graph_node(doc["_id"], doc)) + + # As node count reaches the limit, only need to fetch the edges that directly connect to these nodes + edge_cursor = self.edge_collection.find( + { + "$and": [ + {"source_node_id": {"$in": node_ids}}, + {"target_node_id": {"$in": node_ids}}, + ] + } + ) + else: + # All nodes and edges are needed + cursor = self.collection.find({}, {"source_ids": 0}) + + async for doc in cursor: + node_id = str(doc["_id"]) + result.nodes.append(self._construct_graph_node(doc["_id"], doc)) + + edge_cursor = self.edge_collection.find({}) + + async for edge in edge_cursor: + edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}" + if edge_id not in seen_edges: + seen_edges.add(edge_id) + result.edges.append(self._construct_graph_edge(edge_id, edge)) + + return result + + async def _bidirectional_bfs_nodes( + self, + node_labels: list[str], + seen_nodes: set[str], + result: KnowledgeGraph, + depth: int = 0, + max_depth: int = 3, + max_nodes: int = MAX_GRAPH_NODES, + ) -> KnowledgeGraph: + if depth > max_depth or len(result.nodes) > max_nodes: + return result + + cursor = self.collection.find({"_id": {"$in": node_labels}}) + + async for node in cursor: + node_id = node["_id"] + if node_id not in seen_nodes: + seen_nodes.add(node_id) + result.nodes.append(self._construct_graph_node(node_id, node)) + if len(result.nodes) > max_nodes: + return result + + # Collect neighbors + # Get both inbound and outbound one hop nodes + cursor = self.edge_collection.find( + { + "$or": [ + {"source_node_id": {"$in": node_labels}}, + {"target_node_id": {"$in": node_labels}}, + ] + } + ) + + neighbor_nodes = [] + async for edge in cursor: + if edge["source_node_id"] not in seen_nodes: + neighbor_nodes.append(edge["source_node_id"]) + if edge["target_node_id"] not in seen_nodes: + neighbor_nodes.append(edge["target_node_id"]) + + if neighbor_nodes: + result = await self._bidirectional_bfs_nodes( + neighbor_nodes, seen_nodes, result, depth + 1, max_depth, max_nodes + ) + + return result + + async def get_knowledge_subgraph_bidirectional_bfs( + self, + node_label: str, + depth=0, + max_depth: int = 3, + max_nodes: int = MAX_GRAPH_NODES, + ) -> KnowledgeGraph: + seen_nodes = set() + seen_edges = set() + result = KnowledgeGraph() + + result = await self._bidirectional_bfs_nodes( + [node_label], seen_nodes, result, depth, max_depth, max_nodes + ) + + # Get all edges from seen_nodes + all_node_ids = 
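When truncation kicks in above, node selection is driven purely by degree: each id's appearances on the source side are counted, the same count from the target side is unioned in, and the two are summed. Stripped of the storage class, the aggregation looks roughly like this sketch (Motor API; "graph_edges" is a placeholder collection name):

from motor.motor_asyncio import AsyncIOMotorClient

async def top_node_ids_by_degree(max_nodes: int = 1000) -> list[str]:
    edges = AsyncIOMotorClient("mongodb://localhost:27017")["demo"]["graph_edges"]
    pipeline = [
        {"$project": {"source_node_id": 1, "_id": 0}},
        {"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}},
        {
            "$unionWith": {
                "coll": "graph_edges",
                "pipeline": [
                    {"$project": {"target_node_id": 1, "_id": 0}},
                    {"$group": {"_id": "$target_node_id", "degree": {"$sum": 1}}},
                ],
            }
        },
        # Merge the per-direction counts and keep the best-connected ids first.
        {"$group": {"_id": "$_id", "degree": {"$sum": "$degree"}}},
        {"$sort": {"degree": -1}},
        {"$limit": max_nodes},
    ]
    cursor = edges.aggregate(pipeline, allowDiskUse=True)
    return [str(doc["_id"]) async for doc in cursor]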
list(seen_nodes) + cursor = self.edge_collection.find( + { + "$and": [ + {"source_node_id": {"$in": all_node_ids}}, + {"target_node_id": {"$in": all_node_ids}}, + ] + } + ) + + async for edge in cursor: + edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}" + if edge_id not in seen_edges: + result.edges.append(self._construct_graph_edge(edge_id, edge)) + seen_edges.add(edge_id) + + return result + + async def get_knowledge_subgraph_in_out_bound_bfs( + self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES + ) -> KnowledgeGraph: + seen_nodes = set() + seen_edges = set() + result = KnowledgeGraph() + project_doc = { + "source_ids": 0, + "created_at": 0, + "entity_type": 0, + "file_path": 0, + } + + # Verify if starting node exists + start_node = await self.collection.find_one({"_id": node_label}) + if not start_node: + logger.warning(f"Starting node with label {node_label} does not exist!") + return result + + seen_nodes.add(node_label) + result.nodes.append(self._construct_graph_node(node_label, start_node)) + + if max_depth == 0: + return result + + # In MongoDB, depth = 0 means one-hop + max_depth = max_depth - 1 + + pipeline = [ + {"$match": {"_id": node_label}}, + {"$project": project_doc}, + { + "$graphLookup": { + "from": self._edge_collection_name, + "startWith": "$_id", + "connectFromField": "target_node_id", + "connectToField": "source_node_id", + "maxDepth": max_depth, + "depthField": "depth", + "as": "connected_edges", + }, + }, + { + "$unionWith": { + "coll": self._collection_name, + "pipeline": [ + {"$match": {"_id": node_label}}, + {"$project": project_doc}, + { + "$graphLookup": { + "from": self._edge_collection_name, + "startWith": "$_id", + "connectFromField": "source_node_id", + "connectToField": "target_node_id", + "maxDepth": max_depth, + "depthField": "depth", + "as": "connected_edges", + } + }, + ], + } + }, + ] + + cursor = await self.collection.aggregate(pipeline, allowDiskUse=True) + node_edges = [] + + # Two records for node_label are returned capturing outbound and inbound connected_edges + async for doc in cursor: + if doc.get("connected_edges", []): + node_edges.extend(doc.get("connected_edges")) + + # Sort the connected edges by depth ascending and weight descending + # And stores the source_node_id and target_node_id in sequence to retrieve the neighbouring nodes + node_edges = sorted( + node_edges, + key=lambda x: (x["depth"], -x["weight"]), + ) + + # As order matters, we need to use another list to store the node_id + # And only take the first max_nodes ones + node_ids = [] + for edge in node_edges: + if len(node_ids) < max_nodes and edge["source_node_id"] not in seen_nodes: + node_ids.append(edge["source_node_id"]) + seen_nodes.add(edge["source_node_id"]) + + if len(node_ids) < max_nodes and edge["target_node_id"] not in seen_nodes: + node_ids.append(edge["target_node_id"]) + seen_nodes.add(edge["target_node_id"]) + + # Filter out all the node whose id is same as node_label so that we do not check existence next step + cursor = self.collection.find({"_id": {"$in": node_ids}}) + + async for doc in cursor: + result.nodes.append(self._construct_graph_node(str(doc["_id"]), doc)) + + for edge in node_edges: + if ( + edge["source_node_id"] not in seen_nodes + or edge["target_node_id"] not in seen_nodes + ): + continue + + edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}" + if edge_id not in seen_edges: + result.edges.append(self._construct_graph_edge(edge_id, edge)) + seen_edges.add(edge_id) + + return result + async 
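The in/out-bound traversal above leans on $graphLookup, which repeatedly joins the edge collection against itself up to maxDepth hops and annotates each matched edge with a depth field. A pared-down, outbound-only version of the same stage, with placeholder collection names:

from motor.motor_asyncio import AsyncIOMotorClient

async def outbound_edges(start_node: str, max_depth: int = 2) -> list[dict]:
    db = AsyncIOMotorClient("mongodb://localhost:27017")["demo"]
    pipeline = [
        {"$match": {"_id": start_node}},
        {
            "$graphLookup": {
                "from": "graph_edges",            # edge collection (placeholder)
                "startWith": "$_id",              # seed the walk with the node id
                "connectFromField": "target_node_id",
                "connectToField": "source_node_id",
                "maxDepth": max_depth,            # maxDepth=0 still returns one-hop edges
                "depthField": "depth",
                "as": "connected_edges",
            }
        },
    ]
    cursor = db["graph_nodes"].aggregate(pipeline)
    docs = [doc async for doc in cursor]
    return docs[0]["connected_edges"] if docs else []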
def get_knowledge_graph( self, node_label: str, - max_depth: int = 5, + max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES, ) -> KnowledgeGraph: """ - Get complete connected subgraph for specified node (including the starting node itself) + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Args: - node_label: Label of the nodes to start from - max_depth: Maximum depth of traversal (default: 5) + node_label: Label of the starting node, * means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maxiumu nodes to return, Defaults to 1000 Returns: - KnowledgeGraph object containing nodes and edges of the subgraph + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit + + If a graph is like this and starting from B: + A → B ← C ← F, B -> E, C → D + + Outbound BFS: + B → E + + Inbound BFS: + A → B + C → B + F → C + + Bidirectional BFS: + A → B + B → E + F → C + C → B + C → D """ - label = node_label result = KnowledgeGraph() - seen_nodes = set() - seen_edges = set() - node_edges = [] + start = time.perf_counter() try: # Optimize pipeline to avoid memory issues with large datasets - if label == "*": - # For getting all nodes, use a simpler pipeline to avoid memory issues - pipeline = [ - {"$limit": max_nodes}, # Limit early to reduce memory usage - { - "$graphLookup": { - "from": self._edge_collection_name, - "startWith": "$_id", - "connectFromField": "target_node_id", - "connectToField": "source_node_id", - "maxDepth": max_depth, - "depthField": "depth", - "as": "connected_edges", - }, - }, - ] - - # Check if we need to set truncation flag - all_node_count = await self.collection.count_documents({}) - result.is_truncated = all_node_count > max_nodes - else: - # Verify if starting node exists - start_node = await self.collection.find_one({"_id": label}) - if not start_node: - logger.warning(f"Starting node with label {label} does not exist!") - return result - - # For specific node queries, use the original pipeline but optimized - pipeline = [ - {"$match": {"_id": label}}, - { - "$graphLookup": { - "from": self._edge_collection_name, - "startWith": "$_id", - "connectFromField": "target_node_id", - "connectToField": "source_node_id", - "maxDepth": max_depth, - "depthField": "depth", - "as": "connected_edges", - }, - }, - {"$addFields": {"edge_count": {"$size": "$connected_edges"}}}, - {"$sort": {"edge_count": -1}}, - {"$limit": max_nodes}, - ] - - cursor = await self.collection.aggregate(pipeline, allowDiskUse=True) - nodes_processed = 0 - - async for doc in cursor: - # Add the start node - node_id = str(doc["_id"]) - result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[node_id], - properties={ - k: v - for k, v in doc.items() - if k - not in [ - "_id", - "connected_edges", - "edge_count", - ] - }, - ) + if node_label == "*": + result = await self.get_knowledge_graph_all_by_degree( + max_depth, max_nodes + ) + elif GRAPH_BFS_MODE == "in_out_bound": + result = await self.get_knowledge_subgraph_in_out_bound_bfs( + node_label, max_depth, max_nodes + ) + else: + result = await self.get_knowledge_subgraph_bidirectional_bfs( + node_label, 0, max_depth, max_nodes ) - seen_nodes.add(node_id) - if doc.get("connected_edges", []): - node_edges.extend(doc.get("connected_edges")) - nodes_processed += 1 - - # Additional safety check to prevent memory issues - if nodes_processed >= max_nodes: - result.is_truncated = True - break - - for 
edge in node_edges: - if ( - edge["source_node_id"] not in seen_nodes - or edge["target_node_id"] not in seen_nodes - ): - continue - - edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}" - if edge_id not in seen_edges: - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type=edge.get("relationship", ""), - source=edge["source_node_id"], - target=edge["target_node_id"], - properties={ - k: v - for k, v in edge.items() - if k - not in [ - "_id", - "source_node_id", - "target_node_id", - "relationship", - ] - }, - ) - ) - seen_edges.add(edge_id) + duration = time.perf_counter() - start logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}" + f"Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}" ) except PyMongoError as e: @@ -856,13 +1130,8 @@ class MongoGraphStorage(BaseGraphStorage): try: simple_cursor = self.collection.find({}).limit(max_nodes) async for doc in simple_cursor: - node_id = str(doc["_id"]) result.nodes.append( - KnowledgeGraphNode( - id=node_id, - labels=[node_id], - properties={k: v for k, v in doc.items() if k != "_id"}, - ) + self._construct_graph_node(str(doc["_id"]), doc) ) result.is_truncated = True logger.info( @@ -1023,13 +1292,11 @@ class MongoVectorDBStorage(BaseVectorStorage): logger.debug("vector index already exist") async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"Inserting {len(data)} to {self.namespace}") if not data: return # Add current time as Unix timestamp - import time - current_time = int(time.time()) list_data = [ @@ -1114,7 +1381,7 @@ class MongoVectorDBStorage(BaseVectorStorage): Args: ids: List of vector IDs to be deleted """ - logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") + logger.debug(f"Deleting {len(ids)} vectors from {self.namespace}") if not ids: return diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index dc161359..bb7233b4 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -106,7 +106,9 @@ class NetworkXStorage(BaseGraphStorage): async def edge_degree(self, src_id: str, tgt_id: str) -> int: graph = await self._get_graph() - return graph.degree(src_id) + graph.degree(tgt_id) + src_degree = graph.degree(src_id) if graph.has_node(src_id) else 0 + tgt_degree = graph.degree(tgt_id) if graph.has_node(tgt_id) else 0 + return src_degree + tgt_degree async def get_edge( self, source_node_id: str, target_node_id: str diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 0ddc7948..d8447664 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -136,6 +136,52 @@ class PostgreSQLDB: except Exception as e: logger.warning(f"Failed to add chunk_id column to LIGHTRAG_LLM_CACHE: {e}") + async def _migrate_llm_cache_add_cache_type(self): + """Add cache_type column to LIGHTRAG_LLM_CACHE table if it doesn't exist""" + try: + # Check if cache_type column exists + check_column_sql = """ + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'lightrag_llm_cache' + AND column_name = 'cache_type' + """ + + column_info = await self.query(check_column_sql) + if not column_info: + logger.info("Adding cache_type column to LIGHTRAG_LLM_CACHE table") + add_column_sql = """ + ALTER TABLE LIGHTRAG_LLM_CACHE + ADD COLUMN 
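The NetworkX fix above matters because a degree lookup on an id that is not in the graph raises instead of returning 0, so a dangling edge endpoint would previously crash edge_degree. A small illustration of the guarded behaviour:

import networkx as nx

def safe_edge_degree(graph: nx.Graph, src_id: str, tgt_id: str) -> int:
    # Nodes missing from the graph contribute 0 instead of raising on lookup.
    src_degree = graph.degree(src_id) if graph.has_node(src_id) else 0
    tgt_degree = graph.degree(tgt_id) if graph.has_node(tgt_id) else 0
    return src_degree + tgt_degree

g = nx.Graph()
g.add_edge("A", "B")
print(safe_edge_degree(g, "A", "B"))        # 2
print(safe_edge_degree(g, "A", "missing"))  # 1, no exception for the absent node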
cache_type VARCHAR(32) NULL + """ + await self.execute(add_column_sql) + logger.info( + "Successfully added cache_type column to LIGHTRAG_LLM_CACHE table" + ) + + # Migrate existing data: extract cache_type from flattened keys + logger.info( + "Migrating existing LLM cache data to populate cache_type field" + ) + update_sql = """ + UPDATE LIGHTRAG_LLM_CACHE + SET cache_type = CASE + WHEN id LIKE '%:%:%' THEN split_part(id, ':', 2) + ELSE 'extract' + END + WHERE cache_type IS NULL + """ + await self.execute(update_sql) + logger.info("Successfully migrated existing LLM cache data") + else: + logger.info( + "cache_type column already exists in LIGHTRAG_LLM_CACHE table" + ) + except Exception as e: + logger.warning( + f"Failed to add cache_type column to LIGHTRAG_LLM_CACHE: {e}" + ) + async def _migrate_timestamp_columns(self): """Migrate timestamp columns in tables to timezone-aware types, assuming original data is in UTC time""" # Tables and columns that need migration @@ -189,6 +235,239 @@ class PostgreSQLDB: # Log error but don't interrupt the process logger.warning(f"Failed to migrate {table_name}.{column_name}: {e}") + async def _migrate_doc_chunks_to_vdb_chunks(self): + """ + Migrate data from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS if specific conditions are met. + This migration is intended for users who are upgrading and have an older table structure + where LIGHTRAG_DOC_CHUNKS contained a `content_vector` column. + + """ + try: + # 1. Check if the new table LIGHTRAG_VDB_CHUNKS is empty + vdb_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_VDB_CHUNKS" + vdb_chunks_count_result = await self.query(vdb_chunks_count_sql) + if vdb_chunks_count_result and vdb_chunks_count_result["count"] > 0: + logger.info( + "Skipping migration: LIGHTRAG_VDB_CHUNKS already contains data." + ) + return + + # 2. Check if `content_vector` column exists in the old table + check_column_sql = """ + SELECT 1 FROM information_schema.columns + WHERE table_name = 'lightrag_doc_chunks' AND column_name = 'content_vector' + """ + column_exists = await self.query(check_column_sql) + if not column_exists: + logger.info( + "Skipping migration: `content_vector` not found in LIGHTRAG_DOC_CHUNKS" + ) + return + + # 3. Check if the old table LIGHTRAG_DOC_CHUNKS has data + doc_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_DOC_CHUNKS" + doc_chunks_count_result = await self.query(doc_chunks_count_sql) + if not doc_chunks_count_result or doc_chunks_count_result["count"] == 0: + logger.info("Skipping migration: LIGHTRAG_DOC_CHUNKS is empty.") + return + + # 4. Perform the migration + logger.info( + "Starting data migration from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS..." 
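The backfill above relies on the flattened cache id having the shape mode:cache_type:hash, which is why split_part(id, ':', 2) recovers the cache type and anything without colons is treated as legacy 'extract' data. The same decision expressed as a small, dependency-free helper for reference:

def parse_cache_id(cache_id: str) -> dict:
    """Split a flattened 'mode:cache_type:hash' id; legacy ids without colons default to 'extract'."""
    parts = cache_id.split(":")
    if len(parts) >= 3:
        return {"mode": parts[0], "cache_type": parts[1], "hash": parts[2]}
    return {"mode": "default", "cache_type": "extract", "hash": cache_id}

print(parse_cache_id("default:extract:9f86d081884c7d65"))  # new flattened format
print(parse_cache_id("9f86d081884c7d65"))                  # pre-migration format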
+ ) + migration_sql = """ + INSERT INTO LIGHTRAG_VDB_CHUNKS ( + id, workspace, full_doc_id, chunk_order_index, tokens, content, + content_vector, file_path, create_time, update_time + ) + SELECT + id, workspace, full_doc_id, chunk_order_index, tokens, content, + content_vector, file_path, create_time, update_time + FROM LIGHTRAG_DOC_CHUNKS + ON CONFLICT (workspace, id) DO NOTHING; + """ + await self.execute(migration_sql) + logger.info("Data migration to LIGHTRAG_VDB_CHUNKS completed successfully.") + + except Exception as e: + logger.error(f"Failed during data migration to LIGHTRAG_VDB_CHUNKS: {e}") + # Do not re-raise, to allow the application to start + + async def _check_llm_cache_needs_migration(self): + """Check if LLM cache data needs migration by examining the first record""" + try: + # Only query the first record to determine format + check_sql = """ + SELECT id FROM LIGHTRAG_LLM_CACHE + ORDER BY create_time ASC + LIMIT 1 + """ + result = await self.query(check_sql) + + if result and result.get("id"): + # If id doesn't contain colon, it's old format + return ":" not in result["id"] + + return False # No data or already new format + except Exception as e: + logger.warning(f"Failed to check LLM cache migration status: {e}") + return False + + async def _migrate_llm_cache_to_flattened_keys(self): + """Migrate LLM cache to flattened key format, recalculating hash values""" + try: + # Get all old format data + old_data_sql = """ + SELECT id, mode, original_prompt, return_value, chunk_id, + create_time, update_time + FROM LIGHTRAG_LLM_CACHE + WHERE id NOT LIKE '%:%' + """ + + old_records = await self.query(old_data_sql, multirows=True) + + if not old_records: + logger.info("No old format LLM cache data found, skipping migration") + return + + logger.info( + f"Found {len(old_records)} old format cache records, starting migration..." 
+ ) + + # Import hash calculation function + from ..utils import compute_args_hash + + migrated_count = 0 + + # Migrate data in batches + for record in old_records: + try: + # Recalculate hash using correct method + new_hash = compute_args_hash( + record["mode"], record["original_prompt"] + ) + + # Determine cache_type based on mode + cache_type = "extract" if record["mode"] == "default" else "unknown" + + # Generate new flattened key + new_key = f"{record['mode']}:{cache_type}:{new_hash}" + + # Insert new format data with cache_type field + insert_sql = """ + INSERT INTO LIGHTRAG_LLM_CACHE + (workspace, id, mode, original_prompt, return_value, chunk_id, cache_type, create_time, update_time) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ON CONFLICT (workspace, mode, id) DO NOTHING + """ + + await self.execute( + insert_sql, + { + "workspace": self.workspace, + "id": new_key, + "mode": record["mode"], + "original_prompt": record["original_prompt"], + "return_value": record["return_value"], + "chunk_id": record["chunk_id"], + "cache_type": cache_type, # Add cache_type field + "create_time": record["create_time"], + "update_time": record["update_time"], + }, + ) + + # Delete old data + delete_sql = """ + DELETE FROM LIGHTRAG_LLM_CACHE + WHERE workspace=$1 AND mode=$2 AND id=$3 + """ + await self.execute( + delete_sql, + { + "workspace": self.workspace, + "mode": record["mode"], + "id": record["id"], # Old id + }, + ) + + migrated_count += 1 + + except Exception as e: + logger.warning( + f"Failed to migrate cache record {record['id']}: {e}" + ) + continue + + logger.info( + f"Successfully migrated {migrated_count} cache records to flattened format" + ) + + except Exception as e: + logger.error(f"LLM cache migration failed: {e}") + # Don't raise exception, allow system to continue startup + + async def _migrate_doc_status_add_chunks_list(self): + """Add chunks_list column to LIGHTRAG_DOC_STATUS table if it doesn't exist""" + try: + # Check if chunks_list column exists + check_column_sql = """ + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'lightrag_doc_status' + AND column_name = 'chunks_list' + """ + + column_info = await self.query(check_column_sql) + if not column_info: + logger.info("Adding chunks_list column to LIGHTRAG_DOC_STATUS table") + add_column_sql = """ + ALTER TABLE LIGHTRAG_DOC_STATUS + ADD COLUMN chunks_list JSONB NULL DEFAULT '[]'::jsonb + """ + await self.execute(add_column_sql) + logger.info( + "Successfully added chunks_list column to LIGHTRAG_DOC_STATUS table" + ) + else: + logger.info( + "chunks_list column already exists in LIGHTRAG_DOC_STATUS table" + ) + except Exception as e: + logger.warning( + f"Failed to add chunks_list column to LIGHTRAG_DOC_STATUS: {e}" + ) + + async def _migrate_text_chunks_add_llm_cache_list(self): + """Add llm_cache_list column to LIGHTRAG_DOC_CHUNKS table if it doesn't exist""" + try: + # Check if llm_cache_list column exists + check_column_sql = """ + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'lightrag_doc_chunks' + AND column_name = 'llm_cache_list' + """ + + column_info = await self.query(check_column_sql) + if not column_info: + logger.info("Adding llm_cache_list column to LIGHTRAG_DOC_CHUNKS table") + add_column_sql = """ + ALTER TABLE LIGHTRAG_DOC_CHUNKS + ADD COLUMN llm_cache_list JSONB NULL DEFAULT '[]'::jsonb + """ + await self.execute(add_column_sql) + logger.info( + "Successfully added llm_cache_list column to LIGHTRAG_DOC_CHUNKS table" + ) + else: + logger.info( + 
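Conversely, the migration loop above rebuilds each legacy row's key from scratch: it re-hashes mode plus original_prompt and prefixes the result with mode and a derived cache_type. A compact sketch of that key construction, assuming compute_args_hash is importable from lightrag.utils as in the hunk above:

from lightrag.utils import compute_args_hash

def flattened_key_for(record: dict) -> str:
    # Legacy rows in "default" mode are assumed to be extraction caches; others are unknown.
    cache_type = "extract" if record["mode"] == "default" else "unknown"
    new_hash = compute_args_hash(record["mode"], record["original_prompt"])
    return f"{record['mode']}:{cache_type}:{new_hash}"

print(flattened_key_for({"mode": "default", "original_prompt": "Extract entities from ..."}))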
"llm_cache_list column already exists in LIGHTRAG_DOC_CHUNKS table" + ) + except Exception as e: + logger.warning( + f"Failed to add llm_cache_list column to LIGHTRAG_DOC_CHUNKS: {e}" + ) + async def check_tables(self): # First create all tables for k, v in TABLES.items(): @@ -240,6 +519,44 @@ class PostgreSQLDB: logger.error(f"PostgreSQL, Failed to migrate LLM cache chunk_id field: {e}") # Don't throw an exception, allow the initialization process to continue + # Migrate LLM cache table to add cache_type field if needed + try: + await self._migrate_llm_cache_add_cache_type() + except Exception as e: + logger.error( + f"PostgreSQL, Failed to migrate LLM cache cache_type field: {e}" + ) + # Don't throw an exception, allow the initialization process to continue + + # Finally, attempt to migrate old doc chunks data if needed + try: + await self._migrate_doc_chunks_to_vdb_chunks() + except Exception as e: + logger.error(f"PostgreSQL, Failed to migrate doc_chunks to vdb_chunks: {e}") + + # Check and migrate LLM cache to flattened keys if needed + try: + if await self._check_llm_cache_needs_migration(): + await self._migrate_llm_cache_to_flattened_keys() + except Exception as e: + logger.error(f"PostgreSQL, LLM cache migration failed: {e}") + + # Migrate doc status to add chunks_list field if needed + try: + await self._migrate_doc_status_add_chunks_list() + except Exception as e: + logger.error( + f"PostgreSQL, Failed to migrate doc status chunks_list field: {e}" + ) + + # Migrate text chunks to add llm_cache_list field if needed + try: + await self._migrate_text_chunks_add_llm_cache_list() + except Exception as e: + logger.error( + f"PostgreSQL, Failed to migrate text chunks llm_cache_list field: {e}" + ) + async def query( self, sql: str, @@ -423,74 +740,139 @@ class PGKVStorage(BaseKVStorage): try: results = await self.db.query(sql, params, multirows=True) + # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - result_dict = {} + processed_results = {} for row in results: - mode = row["mode"] - if mode not in result_dict: - result_dict[mode] = {} - result_dict[mode][row["id"]] = row - return result_dict - else: - return {row["id"]: row for row in results} + create_time = row.get("create_time", 0) + update_time = row.get("update_time", 0) + # Map field names and add cache_type for compatibility + processed_row = { + **row, + "return": row.get("return_value", ""), + "cache_type": row.get("original_prompt", "unknow"), + "original_prompt": row.get("original_prompt", ""), + "chunk_id": row.get("chunk_id"), + "mode": row.get("mode", "default"), + "create_time": create_time, + "update_time": create_time if update_time == 0 else update_time, + } + processed_results[row["id"]] = processed_row + return processed_results + + # For text_chunks namespace, parse llm_cache_list JSON string back to list + if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): + processed_results = {} + for row in results: + llm_cache_list = row.get("llm_cache_list", []) + if isinstance(llm_cache_list, str): + try: + llm_cache_list = json.loads(llm_cache_list) + except json.JSONDecodeError: + llm_cache_list = [] + row["llm_cache_list"] = llm_cache_list + create_time = row.get("create_time", 0) + update_time = row.get("update_time", 0) + row["create_time"] = create_time + row["update_time"] = ( + create_time if update_time == 0 else update_time + ) + processed_results[row["id"]] = row + return 
processed_results + + # For other namespaces, return as-is + return {row["id"]: row for row in results} except Exception as e: logger.error(f"Error retrieving all data from {self.namespace}: {e}") return {} async def get_by_id(self, id: str) -> dict[str, Any] | None: - """Get doc_full data by id.""" + """Get data by id.""" sql = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - array_res = await self.db.query(sql, params, multirows=True) - res = {} - for row in array_res: - res[row["id"]] = row - return res if res else None - else: - response = await self.db.query(sql, params) - return response if response else None + response = await self.db.query(sql, params) - async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: - """Specifically for llm_response_cache.""" - sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] - params = {"workspace": self.db.workspace, "mode": mode, "id": id} - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - array_res = await self.db.query(sql, params, multirows=True) - res = {} - for row in array_res: - res[row["id"]] = row - return res - else: - return None + if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): + # Parse llm_cache_list JSON string back to list + llm_cache_list = response.get("llm_cache_list", []) + if isinstance(llm_cache_list, str): + try: + llm_cache_list = json.loads(llm_cache_list) + except json.JSONDecodeError: + llm_cache_list = [] + response["llm_cache_list"] = llm_cache_list + create_time = response.get("create_time", 0) + update_time = response.get("update_time", 0) + response["create_time"] = create_time + response["update_time"] = create_time if update_time == 0 else update_time + + # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results + if response and is_namespace( + self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ): + create_time = response.get("create_time", 0) + update_time = response.get("update_time", 0) + # Map field names and add cache_type for compatibility + response = { + **response, + "return": response.get("return_value", ""), + "cache_type": response.get("cache_type"), + "original_prompt": response.get("original_prompt", ""), + "chunk_id": response.get("chunk_id"), + "mode": response.get("mode", "default"), + "create_time": create_time, + "update_time": create_time if update_time == 0 else update_time, + } + + return response if response else None # Query by id async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - """Get doc_chunks data by id""" + """Get data by ids""" sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) ) params = {"workspace": self.db.workspace} - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - array_res = await self.db.query(sql, params, multirows=True) - modes = set() - dict_res: dict[str, dict] = {} - for row in array_res: - modes.add(row["mode"]) - for mode in modes: - if mode not in dict_res: - dict_res[mode] = {} - for row in array_res: - dict_res[row["mode"]][row["id"]] = row - return [{k: v} for k, v in dict_res.items()] - else: - return await self.db.query(sql, params, multirows=True) + results = await self.db.query(sql, params, multirows=True) - async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: - """Specifically for 
llm_response_cache.""" - SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] - params = {"workspace": self.db.workspace, "status": status} - return await self.db.query(SQL, params, multirows=True) + if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): + # Parse llm_cache_list JSON string back to list for each result + for result in results: + llm_cache_list = result.get("llm_cache_list", []) + if isinstance(llm_cache_list, str): + try: + llm_cache_list = json.loads(llm_cache_list) + except json.JSONDecodeError: + llm_cache_list = [] + result["llm_cache_list"] = llm_cache_list + create_time = result.get("create_time", 0) + update_time = result.get("update_time", 0) + result["create_time"] = create_time + result["update_time"] = create_time if update_time == 0 else update_time + + # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results + if results and is_namespace( + self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ): + processed_results = [] + for row in results: + create_time = row.get("create_time", 0) + update_time = row.get("update_time", 0) + # Map field names and add cache_type for compatibility + processed_row = { + **row, + "return": row.get("return_value", ""), + "cache_type": row.get("cache_type"), + "original_prompt": row.get("original_prompt", ""), + "chunk_id": row.get("chunk_id"), + "mode": row.get("mode", "default"), + "create_time": create_time, + "update_time": create_time if update_time == 0 else update_time, + } + processed_results.append(processed_row) + return processed_results + + return results if results else [] async def filter_keys(self, keys: set[str]) -> set[str]: """Filter out duplicated content""" @@ -520,7 +902,22 @@ class PGKVStorage(BaseKVStorage): return if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): - pass + current_time = datetime.datetime.now(timezone.utc) + for k, v in data.items(): + upsert_sql = SQL_TEMPLATES["upsert_text_chunk"] + _data = { + "workspace": self.db.workspace, + "id": k, + "tokens": v["tokens"], + "chunk_order_index": v["chunk_order_index"], + "full_doc_id": v["full_doc_id"], + "content": v["content"], + "file_path": v["file_path"], + "llm_cache_list": json.dumps(v.get("llm_cache_list", [])), + "create_time": current_time, + "update_time": current_time, + } + await self.db.execute(upsert_sql, _data) elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): for k, v in data.items(): upsert_sql = SQL_TEMPLATES["upsert_doc_full"] @@ -531,19 +928,21 @@ class PGKVStorage(BaseKVStorage): } await self.db.execute(upsert_sql, _data) elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - for mode, items in data.items(): - for k, v in items.items(): - upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] - _data = { - "workspace": self.db.workspace, - "id": k, - "original_prompt": v["original_prompt"], - "return_value": v["return"], - "mode": mode, - "chunk_id": v.get("chunk_id"), - } + for k, v in data.items(): + upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] + _data = { + "workspace": self.db.workspace, + "id": k, # Use flattened key as id + "original_prompt": v["original_prompt"], + "return_value": v["return"], + "mode": v.get("mode", "default"), # Get mode from data + "chunk_id": v.get("chunk_id"), + "cache_type": v.get( + "cache_type", "extract" + ), # Get cache_type from data + } - await self.db.execute(upsert_sql, _data) + await self.db.execute(upsert_sql, _data) async def index_done_callback(self) -> None: # PG 
handles persistence automatically @@ -949,8 +1348,8 @@ class PGDocStatusStorage(DocStatusStorage): else: exist_keys = [] new_keys = set([s for s in keys if s not in exist_keys]) - print(f"keys: {keys}") - print(f"new_keys: {new_keys}") + # print(f"keys: {keys}") + # print(f"new_keys: {new_keys}") return new_keys except Exception as e: logger.error( @@ -965,6 +1364,14 @@ class PGDocStatusStorage(DocStatusStorage): if result is None or result == []: return None else: + # Parse chunks_list JSON string back to list + chunks_list = result[0].get("chunks_list", []) + if isinstance(chunks_list, str): + try: + chunks_list = json.loads(chunks_list) + except json.JSONDecodeError: + chunks_list = [] + return dict( content=result[0]["content"], content_length=result[0]["content_length"], @@ -974,6 +1381,7 @@ class PGDocStatusStorage(DocStatusStorage): created_at=result[0]["created_at"], updated_at=result[0]["updated_at"], file_path=result[0]["file_path"], + chunks_list=chunks_list, ) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: @@ -988,19 +1396,32 @@ class PGDocStatusStorage(DocStatusStorage): if not results: return [] - return [ - { - "content": row["content"], - "content_length": row["content_length"], - "content_summary": row["content_summary"], - "status": row["status"], - "chunks_count": row["chunks_count"], - "created_at": row["created_at"], - "updated_at": row["updated_at"], - "file_path": row["file_path"], - } - for row in results - ] + + processed_results = [] + for row in results: + # Parse chunks_list JSON string back to list + chunks_list = row.get("chunks_list", []) + if isinstance(chunks_list, str): + try: + chunks_list = json.loads(chunks_list) + except json.JSONDecodeError: + chunks_list = [] + + processed_results.append( + { + "content": row["content"], + "content_length": row["content_length"], + "content_summary": row["content_summary"], + "status": row["status"], + "chunks_count": row["chunks_count"], + "created_at": row["created_at"], + "updated_at": row["updated_at"], + "file_path": row["file_path"], + "chunks_list": chunks_list, + } + ) + + return processed_results async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" @@ -1021,8 +1442,18 @@ class PGDocStatusStorage(DocStatusStorage): sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" params = {"workspace": self.db.workspace, "status": status.value} result = await self.db.query(sql, params, True) - docs_by_status = { - element["id"]: DocProcessingStatus( + + docs_by_status = {} + for element in result: + # Parse chunks_list JSON string back to list + chunks_list = element.get("chunks_list", []) + if isinstance(chunks_list, str): + try: + chunks_list = json.loads(chunks_list) + except json.JSONDecodeError: + chunks_list = [] + + docs_by_status[element["id"]] = DocProcessingStatus( content=element["content"], content_summary=element["content_summary"], content_length=element["content_length"], @@ -1031,9 +1462,9 @@ class PGDocStatusStorage(DocStatusStorage): updated_at=element["updated_at"], chunks_count=element["chunks_count"], file_path=element["file_path"], + chunks_list=chunks_list, ) - for element in result - } + return docs_by_status async def index_done_callback(self) -> None: @@ -1097,10 +1528,10 @@ class PGDocStatusStorage(DocStatusStorage): logger.warning(f"Unable to parse datetime string: {dt_str}") return None - # Modified SQL to include created_at and updated_at in both INSERT and UPDATE operations - # Both fields are updated 
from the input data in both INSERT and UPDATE cases - sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path,created_at,updated_at) - values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10) + # Modified SQL to include created_at, updated_at, and chunks_list in both INSERT and UPDATE operations + # All fields are updated from the input data in both INSERT and UPDATE cases + sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path,chunks_list,created_at,updated_at) + values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11) on conflict(id,workspace) do update set content = EXCLUDED.content, content_summary = EXCLUDED.content_summary, @@ -1108,6 +1539,7 @@ class PGDocStatusStorage(DocStatusStorage): chunks_count = EXCLUDED.chunks_count, status = EXCLUDED.status, file_path = EXCLUDED.file_path, + chunks_list = EXCLUDED.chunks_list, created_at = EXCLUDED.created_at, updated_at = EXCLUDED.updated_at""" for k, v in data.items(): @@ -1115,7 +1547,7 @@ class PGDocStatusStorage(DocStatusStorage): created_at = parse_datetime(v.get("created_at")) updated_at = parse_datetime(v.get("updated_at")) - # chunks_count is optional + # chunks_count and chunks_list are optional await self.db.execute( sql, { @@ -1127,6 +1559,7 @@ class PGDocStatusStorage(DocStatusStorage): "chunks_count": v["chunks_count"] if "chunks_count" in v else -1, "status": v["status"], "file_path": v["file_path"], + "chunks_list": json.dumps(v.get("chunks_list", [])), "created_at": created_at, # Use the converted datetime object "updated_at": updated_at, # Use the converted datetime object }, @@ -2409,7 +2842,7 @@ class PGGraphStorage(BaseGraphStorage): NAMESPACE_TABLE_MAP = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", - NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS", + NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_VDB_CHUNKS", NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY", NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION", NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS", @@ -2438,6 +2871,21 @@ TABLES = { }, "LIGHTRAG_DOC_CHUNKS": { "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS ( + id VARCHAR(255), + workspace VARCHAR(255), + full_doc_id VARCHAR(256), + chunk_order_index INTEGER, + tokens INTEGER, + content TEXT, + file_path VARCHAR(256), + llm_cache_list JSONB NULL DEFAULT '[]'::jsonb, + create_time TIMESTAMP(0) WITH TIME ZONE, + update_time TIMESTAMP(0) WITH TIME ZONE, + CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id) + )""" + }, + "LIGHTRAG_VDB_CHUNKS": { + "ddl": """CREATE TABLE LIGHTRAG_VDB_CHUNKS ( id VARCHAR(255), workspace VARCHAR(255), full_doc_id VARCHAR(256), @@ -2448,7 +2896,7 @@ TABLES = { file_path VARCHAR(256), create_time TIMESTAMP(0) WITH TIME ZONE, update_time TIMESTAMP(0) WITH TIME ZONE, - CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id) + CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id) )""" }, "LIGHTRAG_VDB_ENTITY": { @@ -2503,6 +2951,7 @@ TABLES = { chunks_count int4 NULL, status varchar(64) NULL, file_path TEXT NULL, + chunks_list JSONB NULL DEFAULT '[]'::jsonb, created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL, updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL, CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id) @@ -2517,24 +2966,30 @@ SQL_TEMPLATES = { FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2 """, "get_by_id_text_chunks": """SELECT id, tokens, 
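Both new JSONB columns (chunks_list on doc status, llm_cache_list on chunks) are written with json.dumps and defensively re-parsed on read, since the driver may hand them back as strings for older rows. A minimal round-trip helper in the same spirit, with no LightRAG dependencies:

import json
from typing import Any

def dump_json_list(value=None) -> str:
    """Serialize a Python list for a JSONB parameter such as chunks_list or llm_cache_list."""
    return json.dumps(value or [])

def load_json_list(value: Any) -> list:
    """Accept an already-decoded list or a JSON string; fall back to [] on malformed data."""
    if isinstance(value, list):
        return value
    if isinstance(value, str):
        try:
            parsed = json.loads(value)
            return parsed if isinstance(parsed, list) else []
        except json.JSONDecodeError:
            return []
    return []

assert load_json_list(dump_json_list(["chunk-1", "chunk-2"])) == ["chunk-1", "chunk-2"]
assert load_json_list(None) == []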
COALESCE(content, '') as content, - chunk_order_index, full_doc_id, file_path + chunk_order_index, full_doc_id, file_path, + COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list, + create_time, update_time FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2 """, - "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id - FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 + "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, + create_time, update_time + FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2 """, - "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id + "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3 """, "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids}) """, "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, - chunk_order_index, full_doc_id, file_path + chunk_order_index, full_doc_id, file_path, + COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list, + create_time, update_time FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids}) """, - "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id - FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode= IN ({ids}) + "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, + create_time, update_time + FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids}) """, "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})", "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace) @@ -2542,16 +2997,31 @@ SQL_TEMPLATES = { ON CONFLICT (workspace,id) DO UPDATE SET content = $2, update_time = CURRENT_TIMESTAMP """, - "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id) - VALUES ($1, $2, $3, $4, $5, $6) + "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type) + VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (workspace,mode,id) DO UPDATE SET original_prompt = EXCLUDED.original_prompt, return_value=EXCLUDED.return_value, mode=EXCLUDED.mode, chunk_id=EXCLUDED.chunk_id, + cache_type=EXCLUDED.cache_type, update_time = CURRENT_TIMESTAMP """, - "upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, + "upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, + chunk_order_index, full_doc_id, content, file_path, llm_cache_list, + create_time, update_time) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ON CONFLICT (workspace,id) DO UPDATE + SET tokens=EXCLUDED.tokens, + chunk_order_index=EXCLUDED.chunk_order_index, + full_doc_id=EXCLUDED.full_doc_id, + content = EXCLUDED.content, + file_path=EXCLUDED.file_path, + llm_cache_list=EXCLUDED.llm_cache_list, + update_time = EXCLUDED.update_time + """, + # SQL for VectorStorage + "upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens, chunk_order_index, full_doc_id, content, content_vector, file_path, create_time, update_time) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) @@ -2564,7 +3034,6 @@ SQL_TEMPLATES = { 
file_path=EXCLUDED.file_path, update_time = EXCLUDED.update_time """, - # SQL for VectorStorage "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector, chunk_ids, file_path, create_time, update_time) VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9) @@ -2591,7 +3060,7 @@ SQL_TEMPLATES = { "relationships": """ WITH relevant_chunks AS ( SELECT id as chunk_id - FROM LIGHTRAG_DOC_CHUNKS + FROM LIGHTRAG_VDB_CHUNKS WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) ) SELECT source_id as src_id, target_id as tgt_id, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at @@ -2608,7 +3077,7 @@ SQL_TEMPLATES = { "entities": """ WITH relevant_chunks AS ( SELECT id as chunk_id - FROM LIGHTRAG_DOC_CHUNKS + FROM LIGHTRAG_VDB_CHUNKS WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) ) SELECT entity_name, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM @@ -2625,13 +3094,13 @@ SQL_TEMPLATES = { "chunks": """ WITH relevant_chunks AS ( SELECT id as chunk_id - FROM LIGHTRAG_DOC_CHUNKS + FROM LIGHTRAG_VDB_CHUNKS WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) ) SELECT id, content, file_path, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM ( SELECT id, content, file_path, create_time, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance - FROM LIGHTRAG_DOC_CHUNKS + FROM LIGHTRAG_VDB_CHUNKS WHERE workspace=$1 AND id IN (SELECT chunk_id FROM relevant_chunks) ) as chunk_distances diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 885a23ca..dada278a 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -85,7 +85,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): ) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"Inserting {len(data)} to {self.namespace}") if not data: return diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 65c25bfc..dba228ca 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -1,9 +1,10 @@ import os -from typing import Any, final +from typing import Any, final, Union from dataclasses import dataclass import pipmaster as pm import configparser from contextlib import asynccontextmanager +import threading if not pm.is_installed("redis"): pm.install("redis") @@ -13,7 +14,12 @@ from redis.asyncio import Redis, ConnectionPool # type: ignore from redis.exceptions import RedisError, ConnectionError # type: ignore from lightrag.utils import logger -from lightrag.base import BaseKVStorage +from lightrag.base import ( + BaseKVStorage, + DocStatusStorage, + DocStatus, + DocProcessingStatus, +) import json @@ -26,6 +32,41 @@ SOCKET_TIMEOUT = 5.0 SOCKET_CONNECT_TIMEOUT = 3.0 +class RedisConnectionManager: + """Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI""" + + _pools = {} + _lock = threading.Lock() + + @classmethod + def get_pool(cls, redis_url: str) -> ConnectionPool: + """Get or create a connection pool for the given Redis URL""" + if redis_url not in cls._pools: + with cls._lock: + if redis_url not in cls._pools: + cls._pools[redis_url] = ConnectionPool.from_url( + redis_url, + max_connections=MAX_CONNECTIONS, + decode_responses=True, + socket_timeout=SOCKET_TIMEOUT, + socket_connect_timeout=SOCKET_CONNECT_TIMEOUT, + ) + logger.info(f"Created shared Redis connection pool for {redis_url}") + return cls._pools[redis_url] + + @classmethod + def 
close_all_pools(cls): + """Close all connection pools (for cleanup)""" + with cls._lock: + for url, pool in cls._pools.items(): + try: + pool.disconnect() + logger.info(f"Closed Redis connection pool for {url}") + except Exception as e: + logger.error(f"Error closing Redis pool for {url}: {e}") + cls._pools.clear() + + @final @dataclass class RedisKVStorage(BaseKVStorage): @@ -33,19 +74,28 @@ class RedisKVStorage(BaseKVStorage): redis_url = os.environ.get( "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379") ) - # Create a connection pool with limits - self._pool = ConnectionPool.from_url( - redis_url, - max_connections=MAX_CONNECTIONS, - decode_responses=True, - socket_timeout=SOCKET_TIMEOUT, - socket_connect_timeout=SOCKET_CONNECT_TIMEOUT, - ) + # Use shared connection pool + self._pool = RedisConnectionManager.get_pool(redis_url) self._redis = Redis(connection_pool=self._pool) logger.info( - f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections" + f"Initialized Redis KV storage for {self.namespace} using shared connection pool" ) + async def initialize(self): + """Initialize Redis connection and migrate legacy cache structure if needed""" + # Test connection + try: + async with self._get_redis_connection() as redis: + await redis.ping() + logger.info(f"Connected to Redis for namespace {self.namespace}") + except Exception as e: + logger.error(f"Failed to connect to Redis: {e}") + raise + + # Migrate legacy cache structure if this is a cache namespace + if self.namespace.endswith("_cache"): + await self._migrate_legacy_cache_structure() + @asynccontextmanager async def _get_redis_connection(self): """Safe context manager for Redis operations.""" @@ -82,7 +132,13 @@ class RedisKVStorage(BaseKVStorage): async with self._get_redis_connection() as redis: try: data = await redis.get(f"{self.namespace}:{id}") - return json.loads(data) if data else None + if data: + result = json.loads(data) + # Ensure time fields are present, provide default values for old data + result.setdefault("create_time", 0) + result.setdefault("update_time", 0) + return result + return None except json.JSONDecodeError as e: logger.error(f"JSON decode error for id {id}: {e}") return None @@ -94,35 +150,113 @@ class RedisKVStorage(BaseKVStorage): for id in ids: pipe.get(f"{self.namespace}:{id}") results = await pipe.execute() - return [json.loads(result) if result else None for result in results] + + processed_results = [] + for result in results: + if result: + data = json.loads(result) + # Ensure time fields are present for all documents + data.setdefault("create_time", 0) + data.setdefault("update_time", 0) + processed_results.append(data) + else: + processed_results.append(None) + + return processed_results except json.JSONDecodeError as e: logger.error(f"JSON decode error in batch get: {e}") return [None] * len(ids) + async def get_all(self) -> dict[str, Any]: + """Get all data from storage + + Returns: + Dictionary containing all stored data + """ + async with self._get_redis_connection() as redis: + try: + # Get all keys for this namespace + keys = await redis.keys(f"{self.namespace}:*") + + if not keys: + return {} + + # Get all values in batch + pipe = redis.pipeline() + for key in keys: + pipe.get(key) + values = await pipe.execute() + + # Build result dictionary + result = {} + for key, value in zip(keys, values): + if value: + # Extract the ID part (after namespace:) + key_id = key.split(":", 1)[1] + try: + data = json.loads(value) + # Ensure 
time fields are present for all documents + data.setdefault("create_time", 0) + data.setdefault("update_time", 0) + result[key_id] = data + except json.JSONDecodeError as e: + logger.error(f"JSON decode error for key {key}: {e}") + continue + + return result + except Exception as e: + logger.error(f"Error getting all data from Redis: {e}") + return {} + async def filter_keys(self, keys: set[str]) -> set[str]: async with self._get_redis_connection() as redis: pipe = redis.pipeline() - for key in keys: + keys_list = list(keys) # Convert set to list for indexing + for key in keys_list: pipe.exists(f"{self.namespace}:{key}") results = await pipe.execute() - existing_ids = {keys[i] for i, exists in enumerate(results) if exists} + existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists} return set(keys) - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if not data: return - logger.info(f"Inserting {len(data)} items to {self.namespace}") + import time + + current_time = int(time.time()) # Get current Unix timestamp + async with self._get_redis_connection() as redis: try: + # Check which keys already exist to determine create vs update + pipe = redis.pipeline() + for k in data.keys(): + pipe.exists(f"{self.namespace}:{k}") + exists_results = await pipe.execute() + + # Add timestamps to data + for i, (k, v) in enumerate(data.items()): + # For text_chunks namespace, ensure llm_cache_list field exists + if "text_chunks" in self.namespace: + if "llm_cache_list" not in v: + v["llm_cache_list"] = [] + + # Add timestamps based on whether key exists + if exists_results[i]: # Key exists, only update update_time + v["update_time"] = current_time + else: # New key, set both create_time and update_time + v["create_time"] = current_time + v["update_time"] = current_time + + v["_id"] = k + + # Store the data pipe = redis.pipeline() for k, v in data.items(): pipe.set(f"{self.namespace}:{k}", json.dumps(v)) await pipe.execute() - for k in data: - data[k]["_id"] = k except json.JSONEncodeError as e: logger.error(f"JSON encode error during upsert: {e}") raise @@ -148,13 +282,13 @@ class RedisKVStorage(BaseKVStorage): ) async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: - """Delete specific records from storage by by cache mode + """Delete specific records from storage by cache mode Importance notes for Redis storage: 1. 
This will immediately delete the specified cache modes from Redis Args: - modes (list[str]): List of cache mode to be drop from storage + modes (list[str]): List of cache modes to be dropped from storage Returns: True: if the cache drop successfully @@ -164,9 +298,47 @@ class RedisKVStorage(BaseKVStorage): return False try: - await self.delete(modes) + async with self._get_redis_connection() as redis: + keys_to_delete = [] + + # Find matching keys for each mode using SCAN + for mode in modes: + # Use correct pattern to match flattened cache key format {namespace}:{mode}:{cache_type}:{hash} + pattern = f"{self.namespace}:{mode}:*" + cursor = 0 + mode_keys = [] + + while True: + cursor, keys = await redis.scan( + cursor, match=pattern, count=1000 + ) + if keys: + mode_keys.extend(keys) + + if cursor == 0: + break + + keys_to_delete.extend(mode_keys) + logger.info( + f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'" + ) + + if keys_to_delete: + # Batch delete + pipe = redis.pipeline() + for key in keys_to_delete: + pipe.delete(key) + results = await pipe.execute() + deleted_count = sum(results) + logger.info( + f"Dropped {deleted_count} cache entries for modes: {modes}" + ) + else: + logger.warning(f"No cache entries found for modes: {modes}") + return True - except Exception: + except Exception as e: + logger.error(f"Error dropping cache by modes in Redis: {e}") return False async def drop(self) -> dict[str, str]: @@ -177,24 +349,370 @@ class RedisKVStorage(BaseKVStorage): """ async with self._get_redis_connection() as redis: try: - keys = await redis.keys(f"{self.namespace}:*") + # Use SCAN to find all keys with the namespace prefix + pattern = f"{self.namespace}:*" + cursor = 0 + deleted_count = 0 - if keys: - pipe = redis.pipeline() - for key in keys: - pipe.delete(key) - results = await pipe.execute() - deleted_count = sum(results) + while True: + cursor, keys = await redis.scan(cursor, match=pattern, count=1000) + if keys: + # Delete keys in batches + pipe = redis.pipeline() + for key in keys: + pipe.delete(key) + results = await pipe.execute() + deleted_count += sum(results) - logger.info(f"Dropped {deleted_count} keys from {self.namespace}") - return { - "status": "success", - "message": f"{deleted_count} keys dropped", - } - else: - logger.info(f"No keys found to drop in {self.namespace}") - return {"status": "success", "message": "no keys to drop"} + if cursor == 0: + break + + logger.info(f"Dropped {deleted_count} keys from {self.namespace}") + return { + "status": "success", + "message": f"{deleted_count} keys dropped", + } except Exception as e: logger.error(f"Error dropping keys from {self.namespace}: {e}") return {"status": "error", "message": str(e)} + + async def _migrate_legacy_cache_structure(self): + """Migrate legacy nested cache structure to flattened structure for Redis + + Redis already stores data in a flattened way, but we need to check for + legacy keys that might contain nested JSON structures and migrate them. + + Early exit if any flattened key is found (indicating migration already done). 
+ """ + from lightrag.utils import generate_cache_key + + async with self._get_redis_connection() as redis: + # Get all keys for this namespace + keys = await redis.keys(f"{self.namespace}:*") + + if not keys: + return + + # Check if we have any flattened keys already - if so, skip migration + has_flattened_keys = False + keys_to_migrate = [] + + for key in keys: + # Extract the ID part (after namespace:) + key_id = key.split(":", 1)[1] + + # Check if already in flattened format (contains exactly 2 colons for mode:cache_type:hash) + if ":" in key_id and len(key_id.split(":")) == 3: + has_flattened_keys = True + break # Early exit - migration already done + + # Get the data to check if it's a legacy nested structure + data = await redis.get(key) + if data: + try: + parsed_data = json.loads(data) + # Check if this looks like a legacy cache mode with nested structure + if isinstance(parsed_data, dict) and all( + isinstance(v, dict) and "return" in v + for v in parsed_data.values() + ): + keys_to_migrate.append((key, key_id, parsed_data)) + except json.JSONDecodeError: + continue + + # If we found any flattened keys, assume migration is already done + if has_flattened_keys: + logger.debug( + f"Found flattened cache keys in {self.namespace}, skipping migration" + ) + return + + if not keys_to_migrate: + return + + # Perform migration + pipe = redis.pipeline() + migration_count = 0 + + for old_key, mode, nested_data in keys_to_migrate: + # Delete the old key + pipe.delete(old_key) + + # Create new flattened keys + for cache_hash, cache_entry in nested_data.items(): + cache_type = cache_entry.get("cache_type", "extract") + flattened_key = generate_cache_key(mode, cache_type, cache_hash) + full_key = f"{self.namespace}:{flattened_key}" + pipe.set(full_key, json.dumps(cache_entry)) + migration_count += 1 + + await pipe.execute() + + if migration_count > 0: + logger.info( + f"Migrated {migration_count} legacy cache entries to flattened structure in Redis" + ) + + +@final +@dataclass +class RedisDocStatusStorage(DocStatusStorage): + """Redis implementation of document status storage""" + + def __post_init__(self): + redis_url = os.environ.get( + "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379") + ) + # Use shared connection pool + self._pool = RedisConnectionManager.get_pool(redis_url) + self._redis = Redis(connection_pool=self._pool) + logger.info( + f"Initialized Redis doc status storage for {self.namespace} using shared connection pool" + ) + + async def initialize(self): + """Initialize Redis connection""" + try: + async with self._get_redis_connection() as redis: + await redis.ping() + logger.info( + f"Connected to Redis for doc status namespace {self.namespace}" + ) + except Exception as e: + logger.error(f"Failed to connect to Redis for doc status: {e}") + raise + + @asynccontextmanager + async def _get_redis_connection(self): + """Safe context manager for Redis operations.""" + try: + yield self._redis + except ConnectionError as e: + logger.error(f"Redis connection error in doc status {self.namespace}: {e}") + raise + except RedisError as e: + logger.error(f"Redis operation error in doc status {self.namespace}: {e}") + raise + except Exception as e: + logger.error( + f"Unexpected error in Redis doc status operation for {self.namespace}: {e}" + ) + raise + + async def close(self): + """Close the Redis connection.""" + if hasattr(self, "_redis") and self._redis: + await self._redis.close() + logger.debug(f"Closed Redis connection for doc status {self.namespace}") + + 
async def __aenter__(self): + """Support for async context manager.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Ensure Redis resources are cleaned up when exiting context.""" + await self.close() + + async def filter_keys(self, keys: set[str]) -> set[str]: + """Return keys that should be processed (not in storage or not successfully processed)""" + async with self._get_redis_connection() as redis: + pipe = redis.pipeline() + keys_list = list(keys) + for key in keys_list: + pipe.exists(f"{self.namespace}:{key}") + results = await pipe.execute() + + existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists} + return set(keys) - existing_ids + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + result: list[dict[str, Any]] = [] + async with self._get_redis_connection() as redis: + try: + pipe = redis.pipeline() + for id in ids: + pipe.get(f"{self.namespace}:{id}") + results = await pipe.execute() + + for result_data in results: + if result_data: + try: + result.append(json.loads(result_data)) + except json.JSONDecodeError as e: + logger.error(f"JSON decode error in get_by_ids: {e}") + continue + except Exception as e: + logger.error(f"Error in get_by_ids: {e}") + return result + + async def get_status_counts(self) -> dict[str, int]: + """Get counts of documents in each status""" + counts = {status.value: 0 for status in DocStatus} + async with self._get_redis_connection() as redis: + try: + # Use SCAN to iterate through all keys in the namespace + cursor = 0 + while True: + cursor, keys = await redis.scan( + cursor, match=f"{self.namespace}:*", count=1000 + ) + if keys: + # Get all values in batch + pipe = redis.pipeline() + for key in keys: + pipe.get(key) + values = await pipe.execute() + + # Count statuses + for value in values: + if value: + try: + doc_data = json.loads(value) + status = doc_data.get("status") + if status in counts: + counts[status] += 1 + except json.JSONDecodeError: + continue + + if cursor == 0: + break + except Exception as e: + logger.error(f"Error getting status counts: {e}") + + return counts + + async def get_docs_by_status( + self, status: DocStatus + ) -> dict[str, DocProcessingStatus]: + """Get all documents with a specific status""" + result = {} + async with self._get_redis_connection() as redis: + try: + # Use SCAN to iterate through all keys in the namespace + cursor = 0 + while True: + cursor, keys = await redis.scan( + cursor, match=f"{self.namespace}:*", count=1000 + ) + if keys: + # Get all values in batch + pipe = redis.pipeline() + for key in keys: + pipe.get(key) + values = await pipe.execute() + + # Filter by status and create DocProcessingStatus objects + for key, value in zip(keys, values): + if value: + try: + doc_data = json.loads(value) + if doc_data.get("status") == status.value: + # Extract document ID from key + doc_id = key.split(":", 1)[1] + + # Make a copy of the data to avoid modifying the original + data = doc_data.copy() + # If content is missing, use content_summary as content + if ( + "content" not in data + and "content_summary" in data + ): + data["content"] = data["content_summary"] + # If file_path is not in data, use document id as file path + if "file_path" not in data: + data["file_path"] = "no-file-path" + + result[doc_id] = DocProcessingStatus(**data) + except (json.JSONDecodeError, KeyError) as e: + logger.error( + f"Error processing document {key}: {e}" + ) + continue + + if cursor == 0: + break + except Exception as e: + logger.error(f"Error getting 
docs by status: {e}") + + return result + + async def index_done_callback(self) -> None: + """Redis handles persistence automatically""" + pass + + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + """Insert or update document status data""" + if not data: + return + + logger.debug(f"Inserting {len(data)} records to {self.namespace}") + async with self._get_redis_connection() as redis: + try: + # Ensure chunks_list field exists for new documents + for doc_id, doc_data in data.items(): + if "chunks_list" not in doc_data: + doc_data["chunks_list"] = [] + + pipe = redis.pipeline() + for k, v in data.items(): + pipe.set(f"{self.namespace}:{k}", json.dumps(v)) + await pipe.execute() + except json.JSONEncodeError as e: + logger.error(f"JSON encode error during upsert: {e}") + raise + + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async with self._get_redis_connection() as redis: + try: + data = await redis.get(f"{self.namespace}:{id}") + return json.loads(data) if data else None + except json.JSONDecodeError as e: + logger.error(f"JSON decode error for id {id}: {e}") + return None + + async def delete(self, doc_ids: list[str]) -> None: + """Delete specific records from storage by their IDs""" + if not doc_ids: + return + + async with self._get_redis_connection() as redis: + pipe = redis.pipeline() + for doc_id in doc_ids: + pipe.delete(f"{self.namespace}:{doc_id}") + + results = await pipe.execute() + deleted_count = sum(results) + logger.info( + f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}" + ) + + async def drop(self) -> dict[str, str]: + """Drop all document status data from storage and clean up resources""" + try: + async with self._get_redis_connection() as redis: + # Use SCAN to find all keys with the namespace prefix + pattern = f"{self.namespace}:*" + cursor = 0 + deleted_count = 0 + + while True: + cursor, keys = await redis.scan(cursor, match=pattern, count=1000) + if keys: + # Delete keys in batches + pipe = redis.pipeline() + for key in keys: + pipe.delete(key) + results = await pipe.execute() + deleted_count += sum(results) + + if cursor == 0: + break + + logger.info( + f"Dropped {deleted_count} doc status keys from {self.namespace}" + ) + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error(f"Error dropping doc status {self.namespace}: {e}") + return {"status": "error", "message": str(e)} diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 748b9ef8..2ab9f89a 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -22,6 +22,7 @@ from typing import ( Dict, ) from lightrag.constants import ( + DEFAULT_MAX_GLEANING, DEFAULT_MAX_TOKEN_SUMMARY, DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, ) @@ -124,7 +125,9 @@ class LightRAG: # Entity extraction # --- - entity_extract_max_gleaning: int = field(default=1) + entity_extract_max_gleaning: int = field( + default=get_env_value("MAX_GLEANING", DEFAULT_MAX_GLEANING, int) + ) """Maximum number of entity extraction attempts for ambiguous content.""" summary_to_max_tokens: int = field( @@ -346,6 +349,7 @@ class LightRAG: # Fix global_config now global_config = asdict(self) + _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) logger.debug(f"LightRAG init with param:\n {_print_config}\n") @@ -394,13 +398,13 @@ class LightRAG: embedding_func=self.embedding_func, ) - # TODO: deprecating, text_chunks is redundant with chunks_vdb self.text_chunks: BaseKVStorage = 
self.key_string_value_json_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS ), embedding_func=self.embedding_func, ) + self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore namespace=make_namespace( self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION @@ -949,6 +953,7 @@ class LightRAG: **dp, "full_doc_id": doc_id, "file_path": file_path, # Add file path to each chunk + "llm_cache_list": [], # Initialize empty LLM cache list for each chunk } for dp in self.chunking_func( self.tokenizer, @@ -960,14 +965,17 @@ class LightRAG: ) } - # Process document (text chunks and full docs) in parallel - # Create tasks with references for potential cancellation + # Process document in two stages + # Stage 1: Process text chunks and docs (parallel execution) doc_status_task = asyncio.create_task( self.doc_status.upsert( { doc_id: { "status": DocStatus.PROCESSING, "chunks_count": len(chunks), + "chunks_list": list( + chunks.keys() + ), # Save chunks list "content": status_doc.content, "content_summary": status_doc.content_summary, "content_length": status_doc.content_length, @@ -983,11 +991,6 @@ class LightRAG: chunks_vdb_task = asyncio.create_task( self.chunks_vdb.upsert(chunks) ) - entity_relation_task = asyncio.create_task( - self._process_entity_relation_graph( - chunks, pipeline_status, pipeline_status_lock - ) - ) full_docs_task = asyncio.create_task( self.full_docs.upsert( {doc_id: {"content": status_doc.content}} @@ -996,14 +999,26 @@ class LightRAG: text_chunks_task = asyncio.create_task( self.text_chunks.upsert(chunks) ) - tasks = [ + + # First stage tasks (parallel execution) + first_stage_tasks = [ doc_status_task, chunks_vdb_task, - entity_relation_task, full_docs_task, text_chunks_task, ] - await asyncio.gather(*tasks) + entity_relation_task = None + + # Execute first stage tasks + await asyncio.gather(*first_stage_tasks) + + # Stage 2: Process entity relation graph (after text_chunks are saved) + entity_relation_task = asyncio.create_task( + self._process_entity_relation_graph( + chunks, pipeline_status, pipeline_status_lock + ) + ) + await entity_relation_task file_extraction_stage_ok = True except Exception as e: @@ -1018,14 +1033,14 @@ class LightRAG: ) pipeline_status["history_messages"].append(error_msg) - # Cancel other tasks as they are no longer meaningful - for task in [ - chunks_vdb_task, - entity_relation_task, - full_docs_task, - text_chunks_task, - ]: - if not task.done(): + # Cancel tasks that are not yet completed + all_tasks = first_stage_tasks + ( + [entity_relation_task] + if entity_relation_task + else [] + ) + for task in all_tasks: + if task and not task.done(): task.cancel() # Persistent llm cache @@ -1075,6 +1090,9 @@ class LightRAG: doc_id: { "status": DocStatus.PROCESSED, "chunks_count": len(chunks), + "chunks_list": list( + chunks.keys() + ), # Preserve chunks_list "content": status_doc.content, "content_summary": status_doc.content_summary, "content_length": status_doc.content_length, @@ -1193,6 +1211,7 @@ class LightRAG: pipeline_status=pipeline_status, pipeline_status_lock=pipeline_status_lock, llm_response_cache=self.llm_response_cache, + text_chunks_storage=self.text_chunks, ) return chunk_results except Exception as e: @@ -1723,28 +1742,10 @@ class LightRAG: file_path="", ) - # 2.
Get all chunks related to this document - try: - all_chunks = await self.text_chunks.get_all() - related_chunks = { - chunk_id: chunk_data - for chunk_id, chunk_data in all_chunks.items() - if isinstance(chunk_data, dict) - and chunk_data.get("full_doc_id") == doc_id - } + # 2. Get chunk IDs from document status + chunk_ids = set(doc_status_data.get("chunks_list", [])) - # Update pipeline status after getting chunks count - async with pipeline_status_lock: - log_message = f"Retrieved {len(related_chunks)} of {len(all_chunks)} related chunks" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - - except Exception as e: - logger.error(f"Failed to retrieve chunks for document {doc_id}: {e}") - raise Exception(f"Failed to retrieve document chunks: {e}") from e - - if not related_chunks: + if not chunk_ids: logger.warning(f"No chunks found for document {doc_id}") # Mark that deletion operations have started deletion_operations_started = True @@ -1775,7 +1776,6 @@ class LightRAG: file_path=file_path, ) - chunk_ids = set(related_chunks.keys()) # Mark that deletion operations have started deletion_operations_started = True @@ -1799,26 +1799,12 @@ class LightRAG: ) ) - # Update pipeline status after getting affected_nodes - async with pipeline_status_lock: - log_message = f"Found {len(affected_nodes)} affected entities" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - affected_edges = ( await self.chunk_entity_relation_graph.get_edges_by_chunk_ids( list(chunk_ids) ) ) - # Update pipeline status after getting affected_edges - async with pipeline_status_lock: - log_message = f"Found {len(affected_edges)} affected relations" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - except Exception as e: logger.error(f"Failed to analyze affected graph elements: {e}") raise Exception(f"Failed to analyze graph dependencies: {e}") from e @@ -1836,6 +1822,14 @@ class LightRAG: elif remaining_sources != sources: entities_to_rebuild[node_label] = remaining_sources + async with pipeline_status_lock: + log_message = ( + f"Found {len(entities_to_rebuild)} affected entities" + ) + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + # Process relationships for edge_data in affected_edges: src = edge_data.get("source") @@ -1857,6 +1851,14 @@ class LightRAG: elif remaining_sources != sources: relationships_to_rebuild[edge_tuple] = remaining_sources + async with pipeline_status_lock: + log_message = ( + f"Found {len(relationships_to_rebuild)} affected relations" + ) + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + except Exception as e: logger.error(f"Failed to process graph analysis results: {e}") raise Exception(f"Failed to process graph dependencies: {e}") from e @@ -1940,17 +1942,13 @@ class LightRAG: knowledge_graph_inst=self.chunk_entity_relation_graph, entities_vdb=self.entities_vdb, relationships_vdb=self.relationships_vdb, - text_chunks=self.text_chunks, + text_chunks_storage=self.text_chunks, llm_response_cache=self.llm_response_cache, global_config=asdict(self), + pipeline_status=pipeline_status, + pipeline_status_lock=pipeline_status_lock, ) - async with pipeline_status_lock: - log_message 
= f"Successfully rebuilt {len(entities_to_rebuild)} entities and {len(relationships_to_rebuild)} relations" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - except Exception as e: logger.error(f"Failed to rebuild knowledge from chunks: {e}") raise Exception( diff --git a/lightrag/operate.py b/lightrag/operate.py index 77568161..60425148 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -25,6 +25,7 @@ from .utils import ( CacheData, get_conversation_turns, use_llm_func_with_cache, + update_chunk_cache_list, ) from .base import ( BaseGraphStorage, @@ -103,8 +104,6 @@ async def _handle_entity_relation_summary( entity_or_relation_name: str, description: str, global_config: dict, - pipeline_status: dict = None, - pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, ) -> str: """Handle entity relation summary @@ -247,9 +246,11 @@ async def _rebuild_knowledge_from_chunks( knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, - text_chunks: BaseKVStorage, + text_chunks_storage: BaseKVStorage, llm_response_cache: BaseKVStorage, global_config: dict[str, str], + pipeline_status: dict | None = None, + pipeline_status_lock=None, ) -> None: """Rebuild entity and relationship descriptions from cached extraction results @@ -259,9 +260,12 @@ async def _rebuild_knowledge_from_chunks( Args: entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids + text_chunks_data: Pre-loaded chunk data dict {chunk_id: chunk_data} """ if not entities_to_rebuild and not relationships_to_rebuild: return + rebuilt_entities_count = 0 + rebuilt_relationships_count = 0 # Get all referenced chunk IDs all_referenced_chunk_ids = set() @@ -270,36 +274,74 @@ async def _rebuild_knowledge_from_chunks( for chunk_ids in relationships_to_rebuild.values(): all_referenced_chunk_ids.update(chunk_ids) - logger.debug( - f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions" - ) + status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions" + logger.info(status_message) + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) - # Get cached extraction results for these chunks + # Get cached extraction results for these chunks using storage + # cached_results: chunk_id -> [list of extraction result from LLM cache sorted by created_at] cached_results = await _get_cached_extraction_results( - llm_response_cache, all_referenced_chunk_ids + llm_response_cache, + all_referenced_chunk_ids, + text_chunks_storage=text_chunks_storage, ) if not cached_results: - logger.warning("No cached extraction results found, cannot rebuild") + status_message = "No cached extraction results found, cannot rebuild" + logger.warning(status_message) + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) return # Process cached results to get entities and relationships for each chunk chunk_entities = {} # chunk_id -> {entity_name: [entity_data]} chunk_relationships = {} # chunk_id -> {(src, tgt): 
[relationship_data]} - for chunk_id, extraction_result in cached_results.items(): + for chunk_id, extraction_results in cached_results.items(): try: - entities, relationships = await _parse_extraction_result( - text_chunks=text_chunks, - extraction_result=extraction_result, - chunk_id=chunk_id, - ) - chunk_entities[chunk_id] = entities - chunk_relationships[chunk_id] = relationships + # Handle multiple extraction results per chunk + chunk_entities[chunk_id] = defaultdict(list) + chunk_relationships[chunk_id] = defaultdict(list) + + # process multiple LLM extraction results for a single chunk_id + for extraction_result in extraction_results: + entities, relationships = await _parse_extraction_result( + text_chunks_storage=text_chunks_storage, + extraction_result=extraction_result, + chunk_id=chunk_id, + ) + + # Merge entities and relationships from this extraction result + # Only keep the first occurrence of each entity_name in the same chunk_id + for entity_name, entity_list in entities.items(): + if ( + entity_name not in chunk_entities[chunk_id] + or len(chunk_entities[chunk_id][entity_name]) == 0 + ): + chunk_entities[chunk_id][entity_name].extend(entity_list) + + # Only keep the first occurrence of each rel_key in the same chunk_id + for rel_key, rel_list in relationships.items(): + if ( + rel_key not in chunk_relationships[chunk_id] + or len(chunk_relationships[chunk_id][rel_key]) == 0 + ): + chunk_relationships[chunk_id][rel_key].extend(rel_list) + except Exception as e: - logger.error( + status_message = ( f"Failed to parse cached extraction result for chunk {chunk_id}: {e}" ) + logger.info(status_message) # Per requirement, change to info + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) continue # Rebuild entities @@ -314,11 +356,22 @@ async def _rebuild_knowledge_from_chunks( llm_response_cache=llm_response_cache, global_config=global_config, ) - logger.debug( - f"Rebuilt entity {entity_name} from {len(chunk_ids)} cached extractions" + rebuilt_entities_count += 1 + status_message = ( + f"Rebuilt entity: {entity_name} from {len(chunk_ids)} chunks" ) + logger.info(status_message) + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) except Exception as e: - logger.error(f"Failed to rebuild entity {entity_name}: {e}") + status_message = f"Failed to rebuild entity {entity_name}: {e}" + logger.info(status_message) # Per requirement, change to info + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) # Rebuild relationships for (src, tgt), chunk_ids in relationships_to_rebuild.items(): @@ -333,53 +386,112 @@ async def _rebuild_knowledge_from_chunks( llm_response_cache=llm_response_cache, global_config=global_config, ) - logger.debug( - f"Rebuilt relationship {src}-{tgt} from {len(chunk_ids)} cached extractions" + rebuilt_relationships_count += 1 + status_message = ( + f"Rebuilt relationship: {src}->{tgt} from {len(chunk_ids)} chunks" ) + logger.info(status_message) + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + 
pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) except Exception as e: - logger.error(f"Failed to rebuild relationship {src}-{tgt}: {e}") + status_message = f"Failed to rebuild relationship {src}->{tgt}: {e}" + logger.info(status_message) + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) - logger.debug("Completed rebuilding knowledge from cached extractions") + status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships." + logger.info(status_message) + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) async def _get_cached_extraction_results( - llm_response_cache: BaseKVStorage, chunk_ids: set[str] -) -> dict[str, str]: + llm_response_cache: BaseKVStorage, + chunk_ids: set[str], + text_chunks_storage: BaseKVStorage, +) -> dict[str, list[str]]: """Get cached extraction results for specific chunk IDs Args: + llm_response_cache: LLM response cache storage chunk_ids: Set of chunk IDs to get cached results for + text_chunks_data: Pre-loaded chunk data (optional, for performance) + text_chunks_storage: Text chunks storage (fallback if text_chunks_data is None) Returns: - Dict mapping chunk_id -> extraction_result_text + Dict mapping chunk_id -> list of extraction_result_text """ cached_results = {} - # Get all cached data for "default" mode (entity extraction cache) - default_cache = await llm_response_cache.get_by_id("default") or {} + # Collect all LLM cache IDs from chunks + all_cache_ids = set() - for cache_key, cache_entry in default_cache.items(): + # Read from storage + chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids)) + for chunk_id, chunk_data in zip(chunk_ids, chunk_data_list): + if chunk_data and isinstance(chunk_data, dict): + llm_cache_list = chunk_data.get("llm_cache_list", []) + if llm_cache_list: + all_cache_ids.update(llm_cache_list) + else: + logger.warning( + f"Chunk {chunk_id} data is invalid or None: {type(chunk_data)}" + ) + + if not all_cache_ids: + logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs") + return cached_results + + # Batch get LLM cache entries + cache_data_list = await llm_response_cache.get_by_ids(list(all_cache_ids)) + + # Process cache entries and group by chunk_id + valid_entries = 0 + for cache_id, cache_entry in zip(all_cache_ids, cache_data_list): if ( - isinstance(cache_entry, dict) + cache_entry is not None + and isinstance(cache_entry, dict) and cache_entry.get("cache_type") == "extract" and cache_entry.get("chunk_id") in chunk_ids ): chunk_id = cache_entry["chunk_id"] extraction_result = cache_entry["return"] - cached_results[chunk_id] = extraction_result + create_time = cache_entry.get( + "create_time", 0 + ) # Get creation time, default to 0 + valid_entries += 1 - logger.debug( - f"Found {len(cached_results)} cached extraction results for {len(chunk_ids)} chunk IDs" + # Support multiple LLM caches per chunk + if chunk_id not in cached_results: + cached_results[chunk_id] = [] + # Store tuple with extraction result and creation time for sorting + cached_results[chunk_id].append((extraction_result, create_time)) + + # Sort extraction results by 
create_time for each chunk + for chunk_id in cached_results: + # Sort by create_time (x[1]), then extract only extraction_result (x[0]) + cached_results[chunk_id].sort(key=lambda x: x[1]) + cached_results[chunk_id] = [item[0] for item in cached_results[chunk_id]] + + logger.info( + f"Found {valid_entries} valid cache entries, {len(cached_results)} chunks with results" ) return cached_results async def _parse_extraction_result( - text_chunks: BaseKVStorage, extraction_result: str, chunk_id: str + text_chunks_storage: BaseKVStorage, extraction_result: str, chunk_id: str ) -> tuple[dict, dict]: """Parse cached extraction result using the same logic as extract_entities Args: + text_chunks_storage: Text chunks storage to get chunk data extraction_result: The cached LLM extraction result chunk_id: The chunk ID for source tracking @@ -387,8 +499,8 @@ async def _parse_extraction_result( Tuple of (entities_dict, relationships_dict) """ - # Get chunk data for file_path - chunk_data = await text_chunks.get_by_id(chunk_id) + # Get chunk data for file_path from storage + chunk_data = await text_chunks_storage.get_by_id(chunk_id) file_path = ( chunk_data.get("file_path", "unknown_source") if chunk_data @@ -761,8 +873,6 @@ async def _merge_nodes_then_upsert( entity_name, description, global_config, - pipeline_status, - pipeline_status_lock, llm_response_cache, ) else: @@ -925,8 +1035,6 @@ async def _merge_edges_then_upsert( f"({src_id}, {tgt_id})", description, global_config, - pipeline_status, - pipeline_status_lock, llm_response_cache, ) else: @@ -1102,6 +1210,7 @@ async def extract_entities( pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, + text_chunks_storage: BaseKVStorage | None = None, ) -> list: use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] @@ -1208,6 +1317,9 @@ async def extract_entities( # Get file path from chunk data or use default file_path = chunk_dp.get("file_path", "unknown_source") + # Create cache keys collector for batch processing + cache_keys_collector = [] + # Get initial extraction hint_prompt = entity_extract_prompt.format( **{**context_base, "input_text": content} @@ -1219,7 +1331,10 @@ async def extract_entities( llm_response_cache=llm_response_cache, cache_type="extract", chunk_id=chunk_key, + cache_keys_collector=cache_keys_collector, ) + + # Store LLM cache reference in chunk (will be handled by use_llm_func_with_cache) history = pack_user_ass_to_openai_messages(hint_prompt, final_result) # Process initial extraction with file path @@ -1236,6 +1351,7 @@ async def extract_entities( history_messages=history, cache_type="extract", chunk_id=chunk_key, + cache_keys_collector=cache_keys_collector, ) history += pack_user_ass_to_openai_messages(continue_prompt, glean_result) @@ -1266,11 +1382,21 @@ async def extract_entities( llm_response_cache=llm_response_cache, history_messages=history, cache_type="extract", + cache_keys_collector=cache_keys_collector, ) if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() if if_loop_result != "yes": break + # Batch update chunk's llm_cache_list with all collected cache keys + if cache_keys_collector and text_chunks_storage: + await update_chunk_cache_list( + chunk_key, + text_chunks_storage, + cache_keys_collector, + "entity_extraction", + ) + processed_chunks += 1 entities_count = len(maybe_nodes) relations_count = len(maybe_edges) @@ -1343,7 +1469,7 @@ async def kg_query( use_model_func 
= partial(use_model_func, _priority=5) # Handle cache - args_hash = compute_args_hash(query_param.mode, query, cache_type="query") + args_hash = compute_args_hash(query_param.mode, query) cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, query, query_param.mode, cache_type="query" ) @@ -1390,7 +1516,7 @@ async def kg_query( ) if query_param.only_need_context: - return context + return context if context is not None else PROMPTS["fail_response"] if context is None: return PROMPTS["fail_response"] @@ -1502,7 +1628,7 @@ async def extract_keywords_only( """ # 1. Handle cache if needed - add cache type for keywords - args_hash = compute_args_hash(param.mode, text, cache_type="keywords") + args_hash = compute_args_hash(param.mode, text) cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, text, param.mode, cache_type="keywords" ) @@ -1647,7 +1773,7 @@ async def _get_vector_context( f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})" ) logger.info( - f"Vector query: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}" + f"Query chunks: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}" ) if not maybe_trun_chunks: @@ -1871,7 +1997,7 @@ async def _get_node_data( ) logger.info( - f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks" + f"Local query: {len(node_datas)} entities, {len(use_relations)} relations, {len(use_text_units)} chunks" ) # build prompt @@ -2180,7 +2306,7 @@ async def _get_edge_data( ), ) logger.info( - f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks" + f"Global query: {len(use_entities)} entities, {len(edge_datas)} relations, {len(use_text_units)} chunks" ) relations_context = [] @@ -2369,7 +2495,7 @@ async def naive_query( use_model_func = partial(use_model_func, _priority=5) # Handle cache - args_hash = compute_args_hash(query_param.mode, query, cache_type="query") + args_hash = compute_args_hash(query_param.mode, query) cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, query, query_param.mode, cache_type="query" ) @@ -2485,7 +2611,7 @@ async def kg_query_with_keywords( # Apply higher priority (5) to query relation LLM function use_model_func = partial(use_model_func, _priority=5) - args_hash = compute_args_hash(query_param.mode, query, cache_type="query") + args_hash = compute_args_hash(query_param.mode, query) cached_response, quantized, min_val, max_val = await handle_cache( hashing_kv, args_hash, query, query_param.mode, cache_type="query" ) diff --git a/lightrag/utils.py b/lightrag/utils.py index 06b7a468..c6e2def9 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -14,7 +14,6 @@ from functools import wraps from hashlib import md5 from typing import Any, Protocol, Callable, TYPE_CHECKING, List import numpy as np -from lightrag.prompt import PROMPTS from dotenv import load_dotenv from lightrag.constants import ( DEFAULT_LOG_MAX_BYTES, @@ -278,11 +277,10 @@ def convert_response_to_json(response: str) -> dict[str, Any]: raise e from None -def compute_args_hash(*args: Any, cache_type: str | None = None) -> str: +def compute_args_hash(*args: Any) -> str: """Compute a hash for the given arguments.
Args: *args: Arguments to hash - cache_type: Type of cache (e.g., 'keywords', 'query', 'extract') Returns: str: Hash string """ @@ -290,13 +288,40 @@ def compute_args_hash(*args: Any, cache_type: str | None = None) -> str: # Convert all arguments to strings and join them args_str = "".join([str(arg) for arg in args]) - if cache_type: - args_str = f"{cache_type}:{args_str}" # Compute MD5 hash return hashlib.md5(args_str.encode()).hexdigest() +def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str: + """Generate a flattened cache key in the format {mode}:{cache_type}:{hash} + + Args: + mode: Cache mode (e.g., 'default', 'local', 'global') + cache_type: Type of cache (e.g., 'extract', 'query', 'keywords') + hash_value: Hash value from compute_args_hash + + Returns: + str: Flattened cache key + """ + return f"{mode}:{cache_type}:{hash_value}" + + +def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None: + """Parse a flattened cache key back into its components + + Args: + cache_key: Flattened cache key in format {mode}:{cache_type}:{hash} + + Returns: + tuple[str, str, str] | None: (mode, cache_type, hash) or None if invalid format + """ + parts = cache_key.split(":", 2) + if len(parts) == 3: + return parts[0], parts[1], parts[2] + return None + + def compute_mdhash_id(content: str, prefix: str = "") -> str: """ Compute a unique ID for a given content string. @@ -783,131 +808,6 @@ def process_combine_contexts(*context_lists): return combined_data -async def get_best_cached_response( - hashing_kv, - current_embedding, - similarity_threshold=0.95, - mode="default", - use_llm_check=False, - llm_func=None, - original_prompt=None, - cache_type=None, -) -> str | None: - logger.debug( - f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}" - ) - mode_cache = await hashing_kv.get_by_id(mode) - if not mode_cache: - return None - - best_similarity = -1 - best_response = None - best_prompt = None - best_cache_id = None - - # Only iterate through cache entries for this mode - for cache_id, cache_data in mode_cache.items(): - # Skip if cache_type doesn't match - if cache_type and cache_data.get("cache_type") != cache_type: - continue - - # Check if cache data is valid - if cache_data["embedding"] is None: - continue - - try: - # Safely convert cached embedding - cached_quantized = np.frombuffer( - bytes.fromhex(cache_data["embedding"]), dtype=np.uint8 - ).reshape(cache_data["embedding_shape"]) - - # Ensure min_val and max_val are valid float values - embedding_min = cache_data.get("embedding_min") - embedding_max = cache_data.get("embedding_max") - - if ( - embedding_min is None - or embedding_max is None - or embedding_min >= embedding_max - ): - logger.warning( - f"Invalid embedding min/max values: min={embedding_min}, max={embedding_max}" - ) - continue - - cached_embedding = dequantize_embedding( - cached_quantized, - embedding_min, - embedding_max, - ) - except Exception as e: - logger.warning(f"Error processing cached embedding: {str(e)}") - continue - - similarity = cosine_similarity(current_embedding, cached_embedding) - if similarity > best_similarity: - best_similarity = similarity - best_response = cache_data["return"] - best_prompt = cache_data["original_prompt"] - best_cache_id = cache_id - - if best_similarity > similarity_threshold: - # If LLM check is enabled and all required parameters are provided - if ( - use_llm_check - and llm_func - and original_prompt - and best_prompt - and best_response is not None - ): - 
compare_prompt = PROMPTS["similarity_check"].format( - original_prompt=original_prompt, cached_prompt=best_prompt - ) - - try: - llm_result = await llm_func(compare_prompt) - llm_result = llm_result.strip() - llm_similarity = float(llm_result) - - # Replace vector similarity with LLM similarity score - best_similarity = llm_similarity - if best_similarity < similarity_threshold: - log_data = { - "event": "cache_rejected_by_llm", - "type": cache_type, - "mode": mode, - "original_question": original_prompt[:100] + "..." - if len(original_prompt) > 100 - else original_prompt, - "cached_question": best_prompt[:100] + "..." - if len(best_prompt) > 100 - else best_prompt, - "similarity_score": round(best_similarity, 4), - "threshold": similarity_threshold, - } - logger.debug(json.dumps(log_data, ensure_ascii=False)) - logger.info(f"Cache rejected by LLM(mode:{mode} tpye:{cache_type})") - return None - except Exception as e: # Catch all possible exceptions - logger.warning(f"LLM similarity check failed: {e}") - return None # Return None directly when LLM check fails - - prompt_display = ( - best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt - ) - log_data = { - "event": "cache_hit", - "type": cache_type, - "mode": mode, - "similarity": round(best_similarity, 4), - "cache_id": best_cache_id, - "original_prompt": prompt_display, - } - logger.debug(json.dumps(log_data, ensure_ascii=False)) - return best_response - return None - - def cosine_similarity(v1, v2): """Calculate cosine similarity between two vectors""" dot_product = np.dot(v1, v2) @@ -957,7 +857,7 @@ async def handle_cache( mode="default", cache_type=None, ): - """Generic cache handling function""" + """Generic cache handling function with flattened cache keys""" if hashing_kv is None: return None, None, None, None @@ -968,15 +868,14 @@ async def handle_cache( if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"): return None, None, None, None - if exists_func(hashing_kv, "get_by_mode_and_id"): - mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} - else: - mode_cache = await hashing_kv.get_by_id(mode) or {} - if args_hash in mode_cache: - logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})") - return mode_cache[args_hash]["return"], None, None, None + # Use flattened cache key format: {mode}:{cache_type}:{hash} + flattened_key = generate_cache_key(mode, cache_type, args_hash) + cache_entry = await hashing_kv.get_by_id(flattened_key) + if cache_entry: + logger.debug(f"Flattened cache hit(key:{flattened_key})") + return cache_entry["return"], None, None, None - logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})") + logger.debug(f"Cache missed(mode:{mode} type:{cache_type})") return None, None, None, None @@ -994,7 +893,7 @@ class CacheData: async def save_to_cache(hashing_kv, cache_data: CacheData): - """Save data to cache, with improved handling for streaming responses and duplicate content. + """Save data to cache using flattened key structure. 
Args: hashing_kv: The key-value storage for caching @@ -1009,26 +908,21 @@ async def save_to_cache(hashing_kv, cache_data: CacheData): logger.debug("Streaming response detected, skipping cache") return - # Get existing cache data - if exists_func(hashing_kv, "get_by_mode_and_id"): - mode_cache = ( - await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash) - or {} - ) - else: - mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {} + # Use flattened cache key format: {mode}:{cache_type}:{hash} + flattened_key = generate_cache_key( + cache_data.mode, cache_data.cache_type, cache_data.args_hash + ) # Check if we already have identical content cached - if cache_data.args_hash in mode_cache: - existing_content = mode_cache[cache_data.args_hash].get("return") + existing_cache = await hashing_kv.get_by_id(flattened_key) + if existing_cache: + existing_content = existing_cache.get("return") if existing_content == cache_data.content: - logger.info( - f"Cache content unchanged for {cache_data.args_hash}, skipping update" - ) + logger.info(f"Cache content unchanged for {flattened_key}, skipping update") return - # Update cache with new content - mode_cache[cache_data.args_hash] = { + # Create cache entry with flattened structure + cache_entry = { "return": cache_data.content, "cache_type": cache_data.cache_type, "chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None, @@ -1043,10 +937,10 @@ async def save_to_cache(hashing_kv, cache_data: CacheData): "original_prompt": cache_data.prompt, } - logger.info(f" == LLM cache == saving {cache_data.mode}: {cache_data.args_hash}") + logger.info(f" == LLM cache == saving: {flattened_key}") - # Only upsert if there's actual new content - await hashing_kv.upsert({cache_data.mode: mode_cache}) + # Save using flattened key + await hashing_kv.upsert({flattened_key: cache_entry}) def safe_unicode_decode(content): @@ -1529,6 +1423,48 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any return import_class +async def update_chunk_cache_list( + chunk_id: str, + text_chunks_storage: "BaseKVStorage", + cache_keys: list[str], + cache_scenario: str = "batch_update", +) -> None: + """Update chunk's llm_cache_list with the given cache keys + + Args: + chunk_id: Chunk identifier + text_chunks_storage: Text chunks storage instance + cache_keys: List of cache keys to add to the list + cache_scenario: Description of the cache scenario for logging + """ + if not cache_keys: + return + + try: + chunk_data = await text_chunks_storage.get_by_id(chunk_id) + if chunk_data: + # Ensure llm_cache_list exists + if "llm_cache_list" not in chunk_data: + chunk_data["llm_cache_list"] = [] + + # Add cache keys to the list if not already present + existing_keys = set(chunk_data["llm_cache_list"]) + new_keys = [key for key in cache_keys if key not in existing_keys] + + if new_keys: + chunk_data["llm_cache_list"].extend(new_keys) + + # Update the chunk in storage + await text_chunks_storage.upsert({chunk_id: chunk_data}) + logger.debug( + f"Updated chunk {chunk_id} with {len(new_keys)} cache keys ({cache_scenario})" + ) + except Exception as e: + logger.warning( + f"Failed to update chunk {chunk_id} with cache references on {cache_scenario}: {e}" + ) + + async def use_llm_func_with_cache( input_text: str, use_llm_func: callable, @@ -1537,6 +1473,7 @@ async def use_llm_func_with_cache( history_messages: list[dict[str, str]] = None, cache_type: str = "extract", chunk_id: str | None = None, + cache_keys_collector: list = 
None, ) -> str: """Call LLM function with cache support @@ -1551,6 +1488,8 @@ async def use_llm_func_with_cache( history_messages: History messages list cache_type: Type of cache chunk_id: Chunk identifier to store in cache + text_chunks_storage: Text chunks storage to update llm_cache_list + cache_keys_collector: Optional list to collect cache keys for batch processing Returns: LLM response text @@ -1563,6 +1502,9 @@ async def use_llm_func_with_cache( _prompt = input_text arg_hash = compute_args_hash(_prompt) + # Generate cache key for this LLM call + cache_key = generate_cache_key("default", cache_type, arg_hash) + cached_return, _1, _2, _3 = await handle_cache( llm_response_cache, arg_hash, @@ -1573,6 +1515,11 @@ async def use_llm_func_with_cache( if cached_return: logger.debug(f"Found cache for {arg_hash}") statistic_data["llm_cache"] += 1 + + # Add cache key to collector if provided + if cache_keys_collector is not None: + cache_keys_collector.append(cache_key) + return cached_return statistic_data["llm_call"] += 1 @@ -1597,6 +1544,10 @@ async def use_llm_func_with_cache( ), ) + # Add cache key to collector if provided + if cache_keys_collector is not None: + cache_keys_collector.append(cache_key) + return res # When cache is disabled, directly call LLM diff --git a/lightrag/utils_graph.py b/lightrag/utils_graph.py index 5485d47c..c2ccd313 100644 --- a/lightrag/utils_graph.py +++ b/lightrag/utils_graph.py @@ -6,7 +6,7 @@ from typing import Any, cast from .base import DeletionResult from .kg.shared_storage import get_graph_db_lock -from .prompt import GRAPH_FIELD_SEP +from .constants import GRAPH_FIELD_SEP from .utils import compute_mdhash_id, logger from .base import StorageNameSpace diff --git a/reproduce/batch_eval.py b/reproduce/batch_eval.py index a85e1ede..424b4f54 100644 --- a/reproduce/batch_eval.py +++ b/reproduce/batch_eval.py @@ -57,6 +57,10 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path): "Winner": "[Answer 1 or Answer 2]", "Explanation": "[Provide explanation here]" }}, + "Diversity": {{ + "Winner": "[Answer 1 or Answer 2]", + "Explanation": "[Provide explanation here]" + }}, "Empowerment": {{ "Winner": "[Answer 1 or Answer 2]", "Explanation": "[Provide explanation here]" diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py index 3fd1abbc..62f658ff 100644 --- a/tests/test_graph_storage.py +++ b/tests/test_graph_storage.py @@ -8,6 +8,7 @@ 支持的图存储类型包括: - NetworkXStorage - Neo4JStorage +- MongoDBStorage - PGGraphStorage - MemgraphStorage """
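
Reviewer note: the cache refactor above replaces the nested per-mode cache dict with flattened keys of the form {mode}:{cache_type}:{hash}, and each chunk now records its extraction cache entries in an llm_cache_list field that _get_cached_extraction_results reads back during knowledge-graph rebuilds. The sketch below is illustrative only: it mirrors the generate_cache_key / parse_cache_key / compute_args_hash helpers from this diff with standalone copies, and the sample prompt text and chunk record are made-up placeholders, not LightRAG API calls.

from hashlib import md5


def compute_args_hash(*args) -> str:
    # Mirrors the updated lightrag.utils.compute_args_hash: plain MD5 over the joined args
    return md5("".join(str(a) for a in args).encode()).hexdigest()


def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str:
    # Flattened key format used by handle_cache/save_to_cache: {mode}:{cache_type}:{hash}
    return f"{mode}:{cache_type}:{hash_value}"


def parse_cache_key(cache_key: str):
    # Inverse of generate_cache_key; returns None for keys not in the flattened format
    parts = cache_key.split(":", 2)
    return (parts[0], parts[1], parts[2]) if len(parts) == 3 else None


# Entity-extraction calls are cached under mode="default", cache_type="extract";
# the chunk that triggered them keeps the resulting keys in llm_cache_list.
arg_hash = compute_args_hash("hypothetical extraction prompt for chunk-1")
cache_key = generate_cache_key("default", "extract", arg_hash)
chunk_record = {"content": "...", "llm_cache_list": [cache_key]}  # made-up chunk record

assert parse_cache_key(cache_key) == ("default", "extract", arg_hash)
print(chunk_record["llm_cache_list"])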