From ed46d375fb5cbb82eb624febccaf3109cb0fba9f Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 17 Nov 2025 07:14:02 +0800 Subject: [PATCH] Auto-initialize pipeline status in LightRAG.initialize_storages() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Remove manual initialize_pipeline_status calls • Auto-init in initialize_storages method • Update error messages for clarity • Warn on workspace conflicts (cherry picked from commit e22ac52ebc239e25e1d9f486bbdbbcb9f3a391de) --- lightrag/api/lightrag_server.py | 649 +++++++++++++++---- lightrag/exceptions.py | 26 +- lightrag/lightrag.py | 1052 +++++++++++++++++++++++-------- 3 files changed, 1335 insertions(+), 392 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 69689c12..b29e39b2 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -5,14 +5,16 @@ LightRAG FastAPI Server from fastapi import FastAPI, Depends, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse +from fastapi.openapi.docs import ( + get_swagger_ui_html, + get_swagger_ui_oauth2_redirect_html, +) import os import logging import logging.config -import signal import sys import uvicorn import pipmaster as pm -import inspect from fastapi.staticfiles import StaticFiles from fastapi.responses import RedirectResponse from pathlib import Path @@ -50,17 +52,12 @@ from lightrag.api.routers.document_routes import ( from lightrag.api.routers.query_routes import create_query_routes from lightrag.api.routers.graph_routes import create_graph_routes from lightrag.api.routers.ollama_api import OllamaAPI -from lightrag.api.routers.tenant_routes import create_tenant_routes -from lightrag.api.routers.admin_routes import create_admin_routes -from lightrag.services.tenant_service import TenantService -from lightrag.tenant_rag_manager import TenantRAGManager -from lightrag.api.middleware.tenant import TenantMiddleware -from lightrag.namespace import NameSpace from lightrag.utils import logger, set_verbose_debug from lightrag.kg.shared_storage import ( get_namespace_data, - initialize_pipeline_status, + get_default_workspace, + # set_default_workspace, cleanup_keyed_lock, finalize_share_data, ) @@ -84,24 +81,6 @@ config.read("config.ini") auth_configured = bool(auth_handler.accounts) -def setup_signal_handlers(): - """Setup signal handlers for graceful shutdown""" - - def signal_handler(sig, frame): - print(f"\n\nReceived signal {sig}, shutting down gracefully...") - print(f"Process ID: {os.getpid()}") - - # Release shared resources - finalize_share_data() - - # Exit with success status - sys.exit(0) - - # Register signal handlers - signal.signal(signal.SIGINT, signal_handler) # Ctrl+C - signal.signal(signal.SIGTERM, signal_handler) # kill command - - class LLMConfigCache: """Smart LLM and Embedding configuration cache class""" @@ -110,6 +89,8 @@ class LLMConfigCache: # Initialize configurations based on binding conditions self.openai_llm_options = None + self.gemini_llm_options = None + self.gemini_embedding_options = None self.ollama_llm_options = None self.ollama_embedding_options = None @@ -120,6 +101,12 @@ class LLMConfigCache: self.openai_llm_options = OpenAILLMOptions.options_dict(args) logger.info(f"OpenAI LLM Options: {self.openai_llm_options}") + if args.llm_binding == "gemini": + from lightrag.llm.binding_options import GeminiLLMOptions + + self.gemini_llm_options = GeminiLLMOptions.options_dict(args) 
+ logger.info(f"Gemini LLM Options: {self.gemini_llm_options}") + # Only initialize and log Ollama LLM options when using Ollama LLM binding if args.llm_binding == "ollama": try: @@ -150,8 +137,159 @@ class LLMConfigCache: ) self.ollama_embedding_options = {} + # Only initialize and log Gemini Embedding options when using Gemini Embedding binding + if args.embedding_binding == "gemini": + try: + from lightrag.llm.binding_options import GeminiEmbeddingOptions + + self.gemini_embedding_options = GeminiEmbeddingOptions.options_dict( + args + ) + logger.info( + f"Gemini Embedding Options: {self.gemini_embedding_options}" + ) + except ImportError: + logger.warning( + "GeminiEmbeddingOptions not available, using default configuration" + ) + self.gemini_embedding_options = {} + + +def check_frontend_build(): + """Check if frontend is built and optionally check if source is up-to-date + + Returns: + bool: True if frontend is outdated, False if up-to-date or production environment + """ + webui_dir = Path(__file__).parent / "webui" + index_html = webui_dir / "index.html" + + # 1. Check if build files exist (required) + if not index_html.exists(): + ASCIIColors.red("\n" + "=" * 80) + ASCIIColors.red("ERROR: Frontend Not Built") + ASCIIColors.red("=" * 80) + ASCIIColors.yellow("The WebUI frontend has not been built yet.") + ASCIIColors.yellow( + "Please build the frontend code first using the following commands:\n" + ) + ASCIIColors.cyan(" cd lightrag_webui") + ASCIIColors.cyan(" bun install --frozen-lockfile") + ASCIIColors.cyan(" bun run build") + ASCIIColors.cyan(" cd ..") + ASCIIColors.yellow("\nThen restart the service.\n") + ASCIIColors.cyan( + "Note: Make sure you have Bun installed. Visit https://bun.sh for installation." + ) + ASCIIColors.red("=" * 80 + "\n") + sys.exit(1) # Exit immediately + + # 2. 
Check if this is a development environment (source directory exists) + try: + source_dir = Path(__file__).parent.parent.parent / "lightrag_webui" + src_dir = source_dir / "src" + + # Determine if this is a development environment: source directory exists and contains src directory + if not source_dir.exists() or not src_dir.exists(): + # Production environment, skip source code check + logger.debug( + "Production environment detected, skipping source freshness check" + ) + return False + + # Development environment, perform source code timestamp check + logger.debug("Development environment detected, checking source freshness") + + # Source code file extensions (files to check) + source_extensions = { + ".ts", + ".tsx", + ".js", + ".jsx", + ".mjs", + ".cjs", # TypeScript/JavaScript + ".css", + ".scss", + ".sass", + ".less", # Style files + ".json", + ".jsonc", # Configuration/data files + ".html", + ".htm", # Template files + ".md", + ".mdx", # Markdown + } + + # Key configuration files (in lightrag_webui root directory) + key_files = [ + source_dir / "package.json", + source_dir / "bun.lock", + source_dir / "vite.config.ts", + source_dir / "tsconfig.json", + source_dir / "tailraid.config.js", + source_dir / "index.html", + ] + + # Get the latest modification time of source code + latest_source_time = 0 + + # Check source code files in src directory + for file_path in src_dir.rglob("*"): + if file_path.is_file(): + # Only check source code files, ignore temporary files and logs + if file_path.suffix.lower() in source_extensions: + mtime = file_path.stat().st_mtime + latest_source_time = max(latest_source_time, mtime) + + # Check key configuration files + for key_file in key_files: + if key_file.exists(): + mtime = key_file.stat().st_mtime + latest_source_time = max(latest_source_time, mtime) + + # Get build time + build_time = index_html.stat().st_mtime + + # Compare timestamps (5 second tolerance to avoid file system time precision issues) + if latest_source_time > build_time + 5: + ASCIIColors.yellow("\n" + "=" * 80) + ASCIIColors.yellow("WARNING: Frontend Source Code Has Been Updated") + ASCIIColors.yellow("=" * 80) + ASCIIColors.yellow( + "The frontend source code is newer than the current build." 
+ ) + ASCIIColors.yellow( + "This might happen after 'git pull' or manual code changes.\n" + ) + ASCIIColors.cyan( + "Recommended: Rebuild the frontend to use the latest changes:" + ) + ASCIIColors.cyan(" cd lightrag_webui") + ASCIIColors.cyan(" bun install --frozen-lockfile") + ASCIIColors.cyan(" bun run build") + ASCIIColors.cyan(" cd ..") + ASCIIColors.yellow("\nThe server will continue with the current build.") + ASCIIColors.yellow("=" * 80 + "\n") + return True # Frontend is outdated + else: + logger.info("Frontend build is up-to-date") + return False # Frontend is up-to-date + + except Exception as e: + # If check fails, log warning but don't affect startup + logger.warning(f"Failed to check frontend source freshness: {e}") + return False # Assume up-to-date on error + def create_app(args): + # Check frontend build first and get outdated status + is_frontend_outdated = check_frontend_build() + + # Create unified API version display with warning symbol if frontend is outdated + api_version_display = ( + f"{__api_version__}⚠️" if is_frontend_outdated else __api_version__ + ) + # Setup logging logger.setLevel(args.log_level) set_verbose_debug(args.verbose) @@ -166,6 +304,7 @@ def create_app(args): "openai", "azure_openai", "aws_bedrock", + "gemini", ]: raise Exception("llm binding not supported") @@ -176,13 +315,9 @@ def create_app(args): "azure_openai", "aws_bedrock", "jina", + "gemini", ]: raise Exception("embedding binding not supported") - - # Log the configured embeddings binding for debugging - logger.info(f"Configured embedding binding: {args.embedding_binding}") - logger.info(f"Configured embedding model: {args.embedding_model}") - logger.info(f"Configured embedding host: {args.embedding_binding_host}") # Set default hosts if not provided if args.llm_binding_host is None: @@ -216,12 +351,8 @@ def create_app(args): try: # Initialize database connections + # Note: initialize_storages() now auto-initializes pipeline_status for rag.workspace await rag.initialize_storages() - await initialize_pipeline_status() - - # Initialize tenant storage - if hasattr(tenant_storage, "initialize"): - await tenant_storage.initialize() # Data migration regardless of storage implementation await rag.check_and_migrate_data() @@ -234,25 +365,31 @@ def create_app(args): # Clean up database connections await rag.finalize_storages() - # Clean up tenant manager - if hasattr(rag_manager, "cleanup_all"): - await rag_manager.cleanup_all() - - # Clean up shared data - finalize_share_data() + if "LIGHTRAG_GUNICORN_MODE" not in os.environ: + # Only perform cleanup in Uvicorn single-process mode + logger.debug("Unvicorn Mode: finalizing shared storage...") + finalize_share_data() + else: + # In Gunicorn mode with preload_app=True, cleanup is handled by on_exit hooks + logger.debug( + "Gunicorn Mode: postpone shared storage finalization to master process" + ) # Initialize FastAPI + base_description = ( + "Providing API for LightRAG core, Web UI and Ollama Model Emulation" + ) + swagger_description = ( + base_description + + (" (API-Key Enabled)" if api_key else "") + + "\n\n[View ReDoc documentation](/redoc)" + ) app_kwargs = { "title": "LightRAG Server API", - "description": ( - "Providing API for LightRAG core, Web UI and Ollama Model Emulation" - + "(With authentication)" - if api_key - else "" - ), + "description": swagger_description, "version": __api_version__, "openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL - "docs_url": "/docs", # Explicitly set docs URL + "docs_url": None, # Disable 
default docs, we'll create custom endpoint "redoc_url": "/redoc", # Explicitly set redoc URL "lifespan": lifespan, } @@ -316,6 +453,28 @@ def create_app(args): # Create combined auth dependency for all endpoints combined_auth = get_combined_auth_dependency(api_key) + def get_workspace_from_request(request: Request) -> str | None: + """ + Extract workspace from HTTP request header or use default. + + This enables multi-workspace API support by checking the custom + 'LIGHTRAG-WORKSPACE' header. If not present, falls back to the + server's default workspace configuration. + + Args: + request: FastAPI Request object + + Returns: + Workspace identifier (may be empty string for global namespace) + """ + # Check custom header first + workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip() + + if not workspace: + workspace = None + + return workspace + # Create working directory if it doesn't exist Path(args.working_dir).mkdir(parents=True, exist_ok=True) @@ -394,6 +553,44 @@ def create_app(args): return optimized_azure_openai_model_complete + def create_optimized_gemini_llm_func( + config_cache: LLMConfigCache, args, llm_timeout: int + ): + """Create optimized Gemini LLM function with cached configuration""" + + async def optimized_gemini_model_complete( + prompt, + system_prompt=None, + history_messages=None, + keyword_extraction=False, + **kwargs, + ) -> str: + from lightrag.llm.gemini import gemini_complete_if_cache + + if history_messages is None: + history_messages = [] + + # Use pre-processed configuration to avoid repeated parsing + kwargs["timeout"] = llm_timeout + if ( + config_cache.gemini_llm_options is not None + and "generation_config" not in kwargs + ): + kwargs["generation_config"] = dict(config_cache.gemini_llm_options) + + return await gemini_complete_if_cache( + args.llm_model, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=args.llm_binding_api_key, + base_url=args.llm_binding_host, + keyword_extraction=keyword_extraction, + **kwargs, + ) + + return optimized_gemini_model_complete + def create_llm_model_func(binding: str): """ Create LLM model function based on binding type. @@ -415,6 +612,8 @@ def create_app(args): return create_optimized_azure_openai_llm_func( config_cache, args, llm_timeout ) + elif binding == "gemini": + return create_optimized_gemini_llm_func(config_cache, args, llm_timeout) else: # openai and compatible # Use optimized function with pre-processed configuration return create_optimized_openai_llm_func(config_cache, args, llm_timeout) @@ -441,34 +640,109 @@ def create_app(args): return {} def create_optimized_embedding_function( - config_cache: LLMConfigCache, binding, model, host, api_key, dimensions, args - ): + config_cache: LLMConfigCache, binding, model, host, api_key, args + ) -> EmbeddingFunc: """ - Create optimized embedding function with pre-processed configuration for applicable bindings. - Uses lazy imports for all bindings and avoids repeated configuration parsing. + Create optimized embedding function and return an EmbeddingFunc instance + with proper max_token_size inheritance from provider defaults. + + This function: + 1. Imports the provider embedding function + 2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc + 3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping) + 4. 
Returns a properly configured EmbeddingFunc instance """ - async def optimized_embedding_function(texts): + # Step 1: Import provider function and extract default attributes + provider_func = None + provider_max_token_size = None + provider_embedding_dim = None + + try: + if binding == "openai": + from lightrag.llm.openai import openai_embed + + provider_func = openai_embed + elif binding == "ollama": + from lightrag.llm.ollama import ollama_embed + + provider_func = ollama_embed + elif binding == "gemini": + from lightrag.llm.gemini import gemini_embed + + provider_func = gemini_embed + elif binding == "jina": + from lightrag.llm.jina import jina_embed + + provider_func = jina_embed + elif binding == "azure_openai": + from lightrag.llm.azure_openai import azure_openai_embed + + provider_func = azure_openai_embed + elif binding == "aws_bedrock": + from lightrag.llm.bedrock import bedrock_embed + + provider_func = bedrock_embed + elif binding == "lollms": + from lightrag.llm.lollms import lollms_embed + + provider_func = lollms_embed + + # Extract attributes if provider is an EmbeddingFunc + if provider_func and isinstance(provider_func, EmbeddingFunc): + provider_max_token_size = provider_func.max_token_size + provider_embedding_dim = provider_func.embedding_dim + logger.debug( + f"Extracted from {binding} provider: " + f"max_token_size={provider_max_token_size}, " + f"embedding_dim={provider_embedding_dim}" + ) + except ImportError as e: + logger.warning(f"Could not import provider function for {binding}: {e}") + + # Step 2: Apply priority (user config > provider default) + # For max_token_size: explicit env var > provider default > None + final_max_token_size = args.embedding_token_limit or provider_max_token_size + # For embedding_dim: user config (always has value) takes priority + # Only use provider default if user config is explicitly None (which shouldn't happen) + final_embedding_dim = ( + args.embedding_dim if args.embedding_dim else provider_embedding_dim + ) + + # Step 3: Create optimized embedding function (calls underlying function directly) + async def optimized_embedding_function(texts, embedding_dim=None): try: if binding == "lollms": from lightrag.llm.lollms import lollms_embed - return await lollms_embed( + # Get real function, skip EmbeddingFunc wrapper if present + actual_func = ( + lollms_embed.func + if isinstance(lollms_embed, EmbeddingFunc) + else lollms_embed + ) + return await actual_func( texts, embed_model=model, host=host, api_key=api_key ) elif binding == "ollama": from lightrag.llm.ollama import ollama_embed - # Use pre-processed configuration if available, otherwise fallback to dynamic parsing + # Get real function, skip EmbeddingFunc wrapper if present + actual_func = ( + ollama_embed.func + if isinstance(ollama_embed, EmbeddingFunc) + else ollama_embed + ) + + # Use pre-processed configuration if available if config_cache.ollama_embedding_options is not None: ollama_options = config_cache.ollama_embedding_options else: - # Fallback for cases where config cache wasn't initialized properly from lightrag.llm.binding_options import OllamaEmbeddingOptions ollama_options = OllamaEmbeddingOptions.options_dict(args) - return await ollama_embed( + return await actual_func( texts, embed_model=model, host=host, @@ -478,27 +752,93 @@ def create_app(args): elif binding == "azure_openai": from lightrag.llm.azure_openai import azure_openai_embed - return await azure_openai_embed(texts, model=model, api_key=api_key) + actual_func = ( + azure_openai_embed.func + if 
isinstance(azure_openai_embed, EmbeddingFunc) + else azure_openai_embed + ) + return await actual_func(texts, model=model, api_key=api_key) elif binding == "aws_bedrock": from lightrag.llm.bedrock import bedrock_embed - return await bedrock_embed(texts, model=model) + actual_func = ( + bedrock_embed.func + if isinstance(bedrock_embed, EmbeddingFunc) + else bedrock_embed + ) + return await actual_func(texts, model=model) elif binding == "jina": from lightrag.llm.jina import jina_embed - return await jina_embed( - texts, dimensions=dimensions, base_url=host, api_key=api_key + actual_func = ( + jina_embed.func + if isinstance(jina_embed, EmbeddingFunc) + else jina_embed + ) + return await actual_func( + texts, + embedding_dim=embedding_dim, + base_url=host, + api_key=api_key, + ) + elif binding == "gemini": + from lightrag.llm.gemini import gemini_embed + + actual_func = ( + gemini_embed.func + if isinstance(gemini_embed, EmbeddingFunc) + else gemini_embed + ) + + # Use pre-processed configuration if available + if config_cache.gemini_embedding_options is not None: + gemini_options = config_cache.gemini_embedding_options + else: + from lightrag.llm.binding_options import GeminiEmbeddingOptions + + gemini_options = GeminiEmbeddingOptions.options_dict(args) + + return await actual_func( + texts, + model=model, + base_url=host, + api_key=api_key, + embedding_dim=embedding_dim, + task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"), ) else: # openai and compatible from lightrag.llm.openai import openai_embed - return await openai_embed( - texts, model=model, base_url=host, api_key=api_key + actual_func = ( + openai_embed.func + if isinstance(openai_embed, EmbeddingFunc) + else openai_embed + ) + return await actual_func( + texts, + model=model, + base_url=host, + api_key=api_key, + embedding_dim=embedding_dim, ) except ImportError as e: raise Exception(f"Failed to import {binding} embedding: {e}") - return optimized_embedding_function + # Step 4: Wrap in EmbeddingFunc and return + embedding_func_instance = EmbeddingFunc( + embedding_dim=final_embedding_dim, + func=optimized_embedding_function, + max_token_size=final_max_token_size, + send_dimensions=False, # Will be set later based on binding requirements + ) + + # Log final embedding configuration + logger.info( + f"Embedding config: binding={binding} model={model} " + f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}" + ) + + return embedding_func_instance llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int) embedding_timeout = get_env_value( @@ -532,20 +872,63 @@ def create_app(args): **kwargs, ) - # Create embedding function with optimized configuration - embedding_func = EmbeddingFunc( - embedding_dim=args.embedding_dim, - func=create_optimized_embedding_function( - config_cache=config_cache, - binding=args.embedding_binding, - model=args.embedding_model, - host=args.embedding_binding_host, - api_key=args.embedding_binding_api_key, - dimensions=args.embedding_dim, - args=args, # Pass args object for fallback option generation - ), + # Create embedding function with optimized configuration and max_token_size inheritance + import inspect + + # Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size) + embedding_func = create_optimized_embedding_function( + config_cache=config_cache, + binding=args.embedding_binding, + model=args.embedding_model, + host=args.embedding_binding_host, + api_key=args.embedding_binding_api_key, + args=args, ) + # Get 
embedding_send_dim from centralized configuration + embedding_send_dim = args.embedding_send_dim + + # Check if the underlying function signature has embedding_dim parameter + sig = inspect.signature(embedding_func.func) + has_embedding_dim_param = "embedding_dim" in sig.parameters + + # Determine send_dimensions value based on binding type + # Jina and Gemini REQUIRE dimension parameter (forced to True) + # OpenAI and others: controlled by EMBEDDING_SEND_DIM environment variable + if args.embedding_binding in ["jina", "gemini"]: + # Jina and Gemini APIs require dimension parameter - always send it + send_dimensions = has_embedding_dim_param + dimension_control = f"forced by {args.embedding_binding.title()} API" + else: + # For OpenAI and other bindings, respect EMBEDDING_SEND_DIM setting + send_dimensions = embedding_send_dim and has_embedding_dim_param + if send_dimensions or not embedding_send_dim: + dimension_control = "by env var" + else: + dimension_control = "by not hasparam" + + # Set send_dimensions on the EmbeddingFunc instance + embedding_func.send_dimensions = send_dimensions + + logger.info( + f"Send embedding dimension: {send_dimensions} {dimension_control} " + f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, " + f"binding={args.embedding_binding})" + ) + + # Log max_token_size source + if embedding_func.max_token_size: + source = ( + "env variable" + if args.embedding_token_limit + else f"{args.embedding_binding} provider default" + ) + logger.info( + f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})" + ) + else: + logger.info("Embedding max_token_size: not set (90% token warning disabled)") + # Configure rerank function based on args.rerank_bindingparameter rerank_model_func = None if args.rerank_binding != "null": @@ -648,48 +1031,40 @@ def create_app(args): logger.error(f"Failed to initialize LightRAG: {e}") raise - # Initialize TenantService for multi-tenant support - tenant_storage = rag.key_string_value_json_storage_cls( - namespace=NameSpace.KV_STORE_TENANTS, - workspace=rag.workspace, - embedding_func=rag.embedding_func, - ) - tenant_service = TenantService(kv_storage=tenant_storage) - - # Initialize TenantRAGManager for managing per-tenant RAG instances with caching - # This enables efficient multi-tenant deployments by caching RAG instances - # Pass the main RAG instance as a template for tenant-specific instances - rag_manager = TenantRAGManager( - base_working_dir=args.working_dir, - tenant_service=tenant_service, - template_rag=rag, - max_cached_instances=int(os.getenv("MAX_CACHED_RAG_INSTANCES", "100")) - ) - - # Store rag_manager in app state for dependency injection - app.state.rag_manager = rag_manager - app.include_router(create_tenant_routes(tenant_service)) - app.include_router(create_admin_routes(tenant_service)) - - # Add membership management routes - from lightrag.api.routers import membership_routes - app.include_router(membership_routes.router) - + # Add routes app.include_router( create_document_routes( rag, doc_manager, api_key, - rag_manager, ) ) - app.include_router(create_query_routes(rag, api_key, args.top_k, rag_manager)) - app.include_router(create_graph_routes(rag, api_key, rag_manager)) + app.include_router(create_query_routes(rag, api_key, args.top_k)) + app.include_router(create_graph_routes(rag, api_key)) - # Add Ollama API routes with tenant-scoped RAG support - ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key, rag_manager=rag_manager) + # Add Ollama API routes + 
ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key) app.include_router(ollama_api.router, prefix="/api") + # Custom Swagger UI endpoint for offline support + @app.get("/docs", include_in_schema=False) + async def custom_swagger_ui_html(): + """Custom Swagger UI HTML with local static files""" + return get_swagger_ui_html( + openapi_url=app.openapi_url, + title=app.title + " - Swagger UI", + oauth2_redirect_url="/docs/oauth2-redirect", + swagger_js_url="/static/swagger-ui/swagger-ui-bundle.js", + swagger_css_url="/static/swagger-ui/swagger-ui.css", + swagger_favicon_url="/static/swagger-ui/favicon-32x32.png", + swagger_ui_parameters=app.swagger_ui_parameters, + ) + + @app.get("/docs/oauth2-redirect", include_in_schema=False) + async def swagger_ui_redirect(): + """OAuth2 redirect for Swagger UI""" + return get_swagger_ui_oauth2_redirect_html() + @app.get("/") async def redirect_to_webui(): """Redirect root path to /webui""" @@ -711,7 +1086,7 @@ def create_app(args): "auth_mode": "disabled", "message": "Authentication is disabled. Using guest access.", "core_version": core_version, - "api_version": __api_version__, + "api_version": api_version_display, "webui_title": webui_title, "webui_description": webui_description, } @@ -720,7 +1095,7 @@ def create_app(args): "auth_configured": True, "auth_mode": "enabled", "core_version": core_version, - "api_version": __api_version__, + "api_version": api_version_display, "webui_title": webui_title, "webui_description": webui_description, } @@ -738,7 +1113,7 @@ def create_app(args): "auth_mode": "disabled", "message": "Authentication is disabled. Using guest access.", "core_version": core_version, - "api_version": __api_version__, + "api_version": api_version_display, "webui_title": webui_title, "webui_description": webui_description, } @@ -747,26 +1122,30 @@ def create_app(args): raise HTTPException(status_code=401, detail="Incorrect credentials") # Regular user login - role = "admin" if username == "admin" else "user" - print(f"DEBUG: Login user={username}, role={role}") user_token = auth_handler.create_token( - username=username, role=role, metadata={"auth_mode": "enabled"} + username=username, role="user", metadata={"auth_mode": "enabled"} ) return { "access_token": user_token, "token_type": "bearer", "auth_mode": "enabled", "core_version": core_version, - "api_version": __api_version__, + "api_version": api_version_display, "webui_title": webui_title, "webui_description": webui_description, } @app.get("/health", dependencies=[Depends(combined_auth)]) - async def get_status(): + async def get_status(request: Request): """Get current system status""" try: - pipeline_status = await get_namespace_data("pipeline_status") + workspace = get_workspace_from_request(request) + default_workspace = get_default_workspace() + if workspace is None: + workspace = default_workspace + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=workspace + ) if not auth_configured: auth_mode = "disabled" @@ -797,7 +1176,7 @@ def create_app(args): "vector_storage": args.vector_storage, "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract, "enable_llm_cache": args.enable_llm_cache, - "workspace": args.workspace, + "workspace": default_workspace, "max_graph_nodes": args.max_graph_nodes, # Rerank configuration "enable_rerank": rerank_model_func is not None, @@ -821,7 +1200,7 @@ def create_app(args): "pipeline_busy": pipeline_status.get("busy", False), "keyed_locks": keyed_lock_info, "core_version": core_version, - "api_version": 
__api_version__, + "api_version": api_version_display, "webui_title": webui_title, "webui_description": webui_description, } @@ -834,7 +1213,9 @@ def create_app(args): async def get_response(self, path: str, scope): response = await super().get_response(path, scope) - if path.endswith(".html"): + is_html = path.endswith(".html") or response.media_type == "text/html" + + if is_html: response.headers["Cache-Control"] = ( "no-cache, no-store, must-revalidate" ) @@ -856,6 +1237,15 @@ def create_app(args): return response + # Mount Swagger UI static files for offline support + swagger_static_dir = Path(__file__).parent / "static" / "swagger-ui" + if swagger_static_dir.exists(): + app.mount( + "/static/swagger-ui", + StaticFiles(directory=swagger_static_dir), + name="swagger-ui-static", + ) + # Webui mount webui/index.html static_dir = Path(__file__).parent / "webui" static_dir.mkdir(exist_ok=True) @@ -867,9 +1257,6 @@ def create_app(args): name="webui", ) - # Add Tenant middleware - app.add_middleware(TenantMiddleware) - return app @@ -978,6 +1365,12 @@ def check_and_install_dependencies(): def main(): + # Explicitly initialize configuration for clarity + # (The proxy will auto-initialize anyway, but this makes intent clear) + from .config import initialize_config + + initialize_config() + # Check if running under Gunicorn if "GUNICORN_CMD_ARGS" in os.environ: # If started with Gunicorn, return directly as Gunicorn will call get_application @@ -1000,8 +1393,10 @@ def main(): update_uvicorn_mode_config() display_splash_screen(global_args) - # Setup signal handlers for graceful shutdown - setup_signal_handlers() + # Note: Signal handlers are NOT registered here because: + # - Uvicorn has built-in signal handling that properly calls lifespan shutdown + # - Custom signal handlers can interfere with uvicorn's graceful shutdown + # - Cleanup is handled by the lifespan context manager's finally block # Create application instance directly instead of using factory function app = create_app(global_args) diff --git a/lightrag/exceptions.py b/lightrag/exceptions.py index d57df1ac..54a52507 100644 --- a/lightrag/exceptions.py +++ b/lightrag/exceptions.py @@ -68,10 +68,7 @@ class StorageNotInitializedError(RuntimeError): f"{storage_type} not initialized. Please ensure proper initialization:\n" f"\n" f" rag = LightRAG(...)\n" - f" await rag.initialize_storages() # Required\n" - f" \n" - f" from lightrag.kg.shared_storage import initialize_pipeline_status\n" - f" await initialize_pipeline_status() # Required for pipeline operations\n" + f" await rag.initialize_storages() # Required - auto-initializes pipeline_status\n" f"\n" f"See: https://github.com/HKUDS/LightRAG#important-initialization-requirements" ) @@ -82,17 +79,20 @@ class PipelineNotInitializedError(KeyError): def __init__(self, namespace: str = ""): msg = ( - f"Pipeline namespace '{namespace}' not found. " - f"This usually means pipeline status was not initialized.\n" + f"Pipeline namespace '{namespace}' not found.\n" f"\n" - f"Please call 'await initialize_pipeline_status()' after initializing storages:\n" + f"Pipeline status should be auto-initialized by initialize_storages().\n" + f"If you see this error, please ensure:\n" f"\n" + f" 1. You called await rag.initialize_storages()\n" + f" 2. 
For multi-workspace setups, each LightRAG instance was properly initialized\n" + f"\n" + f"Standard initialization:\n" + f" rag = LightRAG(workspace='your_workspace')\n" + f" await rag.initialize_storages() # Auto-initializes pipeline_status\n" + f"\n" + f"If you need manual control (advanced):\n" f" from lightrag.kg.shared_storage import initialize_pipeline_status\n" - f" await initialize_pipeline_status()\n" - f"\n" - f"Full initialization sequence:\n" - f" rag = LightRAG(...)\n" - f" await rag.initialize_storages()\n" - f" await initialize_pipeline_status()" + f" await initialize_pipeline_status(workspace='your_workspace')" ) super().__init__(msg) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 8b0c54fb..4f22a305 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -3,6 +3,7 @@ from __future__ import annotations import traceback import asyncio import configparser +import inspect import os import time import warnings @@ -12,6 +13,7 @@ from functools import partial from typing import ( Any, AsyncIterator, + Awaitable, Callable, Iterator, cast, @@ -20,7 +22,10 @@ from typing import ( Optional, List, Dict, + Union, ) +from lightrag.prompt import PROMPTS +from lightrag.exceptions import PipelineCancelledException from lightrag.constants import ( DEFAULT_MAX_GLEANING, DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, @@ -39,10 +44,15 @@ from lightrag.constants import ( DEFAULT_MAX_ASYNC, DEFAULT_MAX_PARALLEL_INSERT, DEFAULT_MAX_GRAPH_NODES, + DEFAULT_MAX_SOURCE_IDS_PER_ENTITY, + DEFAULT_MAX_SOURCE_IDS_PER_RELATION, DEFAULT_ENTITY_TYPES, DEFAULT_SUMMARY_LANGUAGE, DEFAULT_LLM_TIMEOUT, DEFAULT_EMBEDDING_TIMEOUT, + DEFAULT_SOURCE_IDS_LIMIT_METHOD, + DEFAULT_MAX_FILE_PATHS, + DEFAULT_FILE_PATH_MORE_PLACEHOLDER, ) from lightrag.utils import get_env_value @@ -54,10 +64,10 @@ from lightrag.kg import ( from lightrag.kg.shared_storage import ( get_namespace_data, - get_pipeline_status_lock, - get_graph_db_lock, get_data_init_lock, - initialize_pipeline_status, + get_default_workspace, + set_default_workspace, + get_namespace_lock, ) from lightrag.base import ( @@ -81,7 +91,7 @@ from lightrag.operate import ( merge_nodes_and_edges, kg_query, naive_query, - _rebuild_knowledge_from_chunks, + rebuild_knowledge_from_chunks, ) from lightrag.constants import GRAPH_FIELD_SEP from lightrag.utils import ( @@ -98,6 +108,9 @@ from lightrag.utils import ( generate_track_id, convert_to_user_format, logger, + subtract_source_ids, + make_relation_chunk_key, + normalize_source_ids_limit_method, ) from lightrag.types import KnowledgeGraph from dotenv import load_dotenv @@ -234,11 +247,13 @@ class LightRAG: int, int, ], - List[Dict[str, Any]], + Union[List[Dict[str, Any]], Awaitable[List[Dict[str, Any]]]], ] = field(default_factory=lambda: chunking_by_token_size) """ Custom chunking function for splitting text into chunks before processing. + The function can be either synchronous or asynchronous. + The function should take the following parameters: - `tokenizer`: A Tokenizer instance to use for tokenization. @@ -248,7 +263,8 @@ class LightRAG: - `chunk_token_size`: The maximum number of tokens per chunk. - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks. - The function should return a list of dictionaries, where each dictionary contains the following keys: + The function should return a list of dictionaries (or an awaitable that resolves to a list), + where each dictionary contains the following keys: - `tokens`: The number of tokens in the chunk. 
- `content`: The text content of the chunk. @@ -261,6 +277,9 @@ class LightRAG: embedding_func: EmbeddingFunc | None = field(default=None) """Function for computing text embeddings. Must be set before use.""" + embedding_token_limit: int | None = field(default=None, init=False) + """Token limit for embedding model. Set automatically from embedding_func.max_token_size in __post_init__.""" + embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10))) """Batch size for embedding computations.""" @@ -360,6 +379,41 @@ class LightRAG: ) """Maximum number of graph nodes to return in knowledge graph queries.""" + max_source_ids_per_entity: int = field( + default=get_env_value( + "MAX_SOURCE_IDS_PER_ENTITY", DEFAULT_MAX_SOURCE_IDS_PER_ENTITY, int + ) + ) + """Maximum number of source (chunk) ids in entity Grpah + VDB.""" + + max_source_ids_per_relation: int = field( + default=get_env_value( + "MAX_SOURCE_IDS_PER_RELATION", + DEFAULT_MAX_SOURCE_IDS_PER_RELATION, + int, + ) + ) + """Maximum number of source (chunk) ids in relation Graph + VDB.""" + + source_ids_limit_method: str = field( + default_factory=lambda: normalize_source_ids_limit_method( + get_env_value( + "SOURCE_IDS_LIMIT_METHOD", + DEFAULT_SOURCE_IDS_LIMIT_METHOD, + str, + ) + ) + ) + """Strategy for enforcing source_id limits: IGNORE_NEW or FIFO.""" + + max_file_paths: int = field( + default=get_env_value("MAX_FILE_PATHS", DEFAULT_MAX_FILE_PATHS, int) + ) + """Maximum number of file paths to store in entity/relation file_path field.""" + + file_path_more_placeholder: str = field(default=DEFAULT_FILE_PATH_MORE_PLACEHOLDER) + """Placeholder text when file paths exceed max_file_paths limit.""" + addon_params: dict[str, Any] = field( default_factory=lambda: { "language": get_env_value( @@ -385,11 +439,6 @@ class LightRAG: _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) - @property - def pipeline_status_key(self) -> str: - """Get the namespaced pipeline status key for this instance.""" - return f"pipeline_status_{compute_mdhash_id(self.working_dir)}" - def __post_init__(self): from lightrag.kg.shared_storage import ( initialize_share_data, @@ -474,6 +523,16 @@ class LightRAG: logger.debug(f"LightRAG init with param:\n {_print_config}\n") # Init Embedding + # Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes) + embedding_max_token_size = None + if self.embedding_func and hasattr(self.embedding_func, "max_token_size"): + embedding_max_token_size = self.embedding_func.max_token_size + logger.debug( + f"Captured embedding max_token_size: {embedding_max_token_size}" + ) + self.embedding_token_limit = embedding_max_token_size + + # Step 2: Apply priority wrapper decorator self.embedding_func = priority_limit_async_func_call( self.embedding_func_max_async, llm_timeout=self.default_embedding_timeout, @@ -534,6 +593,18 @@ class LightRAG: embedding_func=self.embedding_func, ) + self.entity_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore + namespace=NameSpace.KV_STORE_ENTITY_CHUNKS, + workspace=self.workspace, + embedding_func=self.embedding_func, + ) + + self.relation_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore + namespace=NameSpace.KV_STORE_RELATION_CHUNKS, + workspace=self.workspace, + embedding_func=self.embedding_func, + ) + self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION, 
workspace=self.workspace, @@ -588,11 +659,29 @@ class LightRAG: async def initialize_storages(self): """Storage initialization must be called one by one to prevent deadlock""" if self._storages_status == StoragesStatus.CREATED: + # Set the first initialized workspace will set the default workspace + # Allows namespace operation without specifying workspace for backward compatibility + default_workspace = get_default_workspace() + if default_workspace is None: + set_default_workspace(self.workspace) + elif default_workspace != self.workspace: + logger.warning( + f"Creating LightRAG instance with workspace='{self.workspace}' " + f"but default workspace is already set to '{default_workspace}'." + ) + + # Auto-initialize pipeline_status for this workspace + from lightrag.kg.shared_storage import initialize_pipeline_status + + await initialize_pipeline_status(workspace=self.workspace) + for storage in ( self.full_docs, self.text_chunks, self.full_entities, self.full_relations, + self.entity_chunks, + self.relation_chunks, self.entities_vdb, self.relationships_vdb, self.chunks_vdb, @@ -615,6 +704,8 @@ class LightRAG: ("text_chunks", self.text_chunks), ("full_entities", self.full_entities), ("full_relations", self.full_relations), + ("entity_chunks", self.entity_chunks), + ("relation_chunks", self.relation_chunks), ("entities_vdb", self.entities_vdb), ("relationships_vdb", self.relationships_vdb), ("chunks_vdb", self.chunks_vdb), @@ -655,7 +746,7 @@ class LightRAG: async def check_and_migrate_data(self): """Check if data migration is needed and perform migration if necessary""" - async with get_data_init_lock(enable_logging=True): + async with get_data_init_lock(): try: # Check if migration is needed: # 1. chunk_entity_relation_graph has entities and relations (count > 0) @@ -670,6 +761,13 @@ class LightRAG: logger.debug("No entities found in graph, skipping migration check") return + try: + # Initialize chunk tracking storage after migration + await self._migrate_chunk_tracking_storage() + except Exception as e: + logger.error(f"Error during chunk_tracking migration: {e}") + raise e + # Check if full_entities and full_relations are empty # Get all processed documents to check their entity/relation data try: @@ -710,11 +808,11 @@ class LightRAG: except Exception as e: logger.error(f"Error during migration check: {e}") - # Don't raise the error, just log it to avoid breaking initialization + raise e except Exception as e: logger.error(f"Error in data migration check: {e}") - # Don't raise the error to avoid breaking initialization + raise e async def _migrate_entity_relation_data(self, processed_docs: dict): """Migrate existing entity and relation data to full_entities and full_relations storage""" @@ -813,6 +911,140 @@ class LightRAG: f"Data migration completed: migrated {migration_count} documents with entities/relations" ) + async def _migrate_chunk_tracking_storage(self) -> None: + """Ensure entity/relation chunk tracking KV stores exist and are seeded.""" + + if not self.entity_chunks or not self.relation_chunks: + return + + need_entity_migration = False + need_relation_migration = False + + try: + need_entity_migration = await self.entity_chunks.is_empty() + except Exception as exc: # pragma: no cover - defensive logging + logger.error(f"Failed to check entity chunks storage: {exc}") + raise exc + + try: + need_relation_migration = await self.relation_chunks.is_empty() + except Exception as exc: # pragma: no cover - defensive logging + logger.error(f"Failed to check relation chunks storage: 
{exc}") + raise exc + + if not need_entity_migration and not need_relation_migration: + return + + BATCH_SIZE = 500 # Process 500 records per batch + + if need_entity_migration: + try: + nodes = await self.chunk_entity_relation_graph.get_all_nodes() + except Exception as exc: + logger.error(f"Failed to fetch nodes for chunk migration: {exc}") + nodes = [] + + logger.info(f"Starting chunk_tracking data migration: {len(nodes)} nodes") + + # Process nodes in batches + total_nodes = len(nodes) + total_batches = (total_nodes + BATCH_SIZE - 1) // BATCH_SIZE + total_migrated = 0 + + for batch_idx in range(total_batches): + start_idx = batch_idx * BATCH_SIZE + end_idx = min((batch_idx + 1) * BATCH_SIZE, total_nodes) + batch_nodes = nodes[start_idx:end_idx] + + upsert_payload: dict[str, dict[str, object]] = {} + for node in batch_nodes: + entity_id = node.get("entity_id") or node.get("id") + if not entity_id: + continue + + raw_source = node.get("source_id") or "" + chunk_ids = [ + chunk_id + for chunk_id in raw_source.split(GRAPH_FIELD_SEP) + if chunk_id + ] + if not chunk_ids: + continue + + upsert_payload[entity_id] = { + "chunk_ids": chunk_ids, + "count": len(chunk_ids), + } + + if upsert_payload: + await self.entity_chunks.upsert(upsert_payload) + total_migrated += len(upsert_payload) + logger.info( + f"Processed entity batch {batch_idx + 1}/{total_batches}: {len(upsert_payload)} records (total: {total_migrated}/{total_nodes})" + ) + + if total_migrated > 0: + # Persist entity_chunks data to disk + await self.entity_chunks.index_done_callback() + logger.info( + f"Entity chunk_tracking migration completed: {total_migrated} records persisted" + ) + + if need_relation_migration: + try: + edges = await self.chunk_entity_relation_graph.get_all_edges() + except Exception as exc: + logger.error(f"Failed to fetch edges for chunk migration: {exc}") + edges = [] + + logger.info(f"Starting chunk_tracking data migration: {len(edges)} edges") + + # Process edges in batches + total_edges = len(edges) + total_batches = (total_edges + BATCH_SIZE - 1) // BATCH_SIZE + total_migrated = 0 + + for batch_idx in range(total_batches): + start_idx = batch_idx * BATCH_SIZE + end_idx = min((batch_idx + 1) * BATCH_SIZE, total_edges) + batch_edges = edges[start_idx:end_idx] + + upsert_payload: dict[str, dict[str, object]] = {} + for edge in batch_edges: + src = edge.get("source") or edge.get("src_id") or edge.get("src") + tgt = edge.get("target") or edge.get("tgt_id") or edge.get("tgt") + if not src or not tgt: + continue + + raw_source = edge.get("source_id") or "" + chunk_ids = [ + chunk_id + for chunk_id in raw_source.split(GRAPH_FIELD_SEP) + if chunk_id + ] + if not chunk_ids: + continue + + storage_key = make_relation_chunk_key(src, tgt) + upsert_payload[storage_key] = { + "chunk_ids": chunk_ids, + "count": len(chunk_ids), + } + + if upsert_payload: + await self.relation_chunks.upsert(upsert_payload) + total_migrated += len(upsert_payload) + logger.info( + f"Processed relation batch {batch_idx + 1}/{total_batches}: {len(upsert_payload)} records (total: {total_migrated}/{total_edges})" + ) + + if total_migrated > 0: + # Persist relation_chunks data to disk + await self.relation_chunks.index_done_callback() + logger.info( + f"Relation chunk_tracking migration completed: {total_migrated} records persisted" + ) + async def get_graph_labels(self): text = await self.chunk_entity_relation_graph.get_all_labels() return text @@ -912,9 +1144,8 @@ class LightRAG: ids: str | list[str] | None = None, file_paths: str | list[str] | 
None = None, track_id: str | None = None, - tenant_context: Optional[Any] = None, ) -> str: - """Async Insert documents with checkpoint support and optional tenant context + """Async Insert documents with checkpoint support Args: input: Single document string or list of document strings @@ -925,8 +1156,6 @@ class LightRAG: ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated file_paths: list of file paths corresponding to each document, used for citation track_id: tracking ID for monitoring processing status, if not provided, will be generated - tenant_context: Optional TenantContext for multi-tenant deployments. If provided, tenant_id and kb_id - from context will be propagated to storage operations for proper isolation. Returns: str: tracking ID for monitoring processing status @@ -935,11 +1164,6 @@ class LightRAG: if track_id is None: track_id = generate_track_id("insert") - # Store tenant context for propagation to storage operations - if tenant_context: - self._tenant_id = tenant_context.tenant_id - self._kb_id = tenant_context.kb_id - await self.apipeline_enqueue_documents(input, ids, file_paths, track_id) await self.apipeline_process_enqueue_documents( split_by_character, split_by_character_only @@ -1025,7 +1249,6 @@ class LightRAG: ids: list[str] | None = None, file_paths: str | list[str] | None = None, track_id: str | None = None, - external_ids: list[str] | None = None, ) -> str: """ Pipeline for Processing Documents @@ -1066,31 +1289,6 @@ class LightRAG: # If no file paths provided, use placeholder file_paths = ["unknown_source"] * len(input) - # If external_ids provided, check idempotency before further processing - # We iterate in-order and skip any document whose external_id already exists in doc_status - indices_to_process = list(range(len(input))) - if external_ids is not None: - if isinstance(external_ids, str): - external_ids = [external_ids] - if len(external_ids) != len(input): - raise ValueError("Number of external_ids must match the number of documents") - - # Call get_doc_by_external_id for each external_id and filter out existing documents - remaining_indices = [] - for i, ext_id in enumerate(external_ids): - if ext_id and str(ext_id).strip(): - try: - existing = await self.doc_status.get_doc_by_external_id(ext_id) - except Exception: - existing = None - if existing: - # Skip this index (idempotent: doc exists) - logger.info(f"Skipping document with external_id {ext_id} since it already exists") - continue - remaining_indices.append(i) - - indices_to_process = remaining_indices - # 1. 
Validate ids if provided or generate MD5 hash IDs and remove duplicate contents if ids is not None: # Check if the number of IDs matches the number of documents @@ -1116,13 +1314,7 @@ class LightRAG: else: # Clean input text and remove duplicates in one pass unique_content_with_paths = {} - # When ids isn't provided we compute md5 ids, but we only consider indices_to_process - for idx, (doc, path) in enumerate(zip(input, file_paths)): - if idx not in indices_to_process: - continue - cleaned_content = sanitize_text_for_encoding(doc) - if cleaned_content not in unique_content_with_paths: - unique_content_with_paths[cleaned_content] = path + for doc, path in zip(input, file_paths): cleaned_content = sanitize_text_for_encoding(doc) if cleaned_content not in unique_content_with_paths: unique_content_with_paths[cleaned_content] = path @@ -1136,29 +1328,6 @@ class LightRAG: for content, path in unique_content_with_paths.items() } - # If external_ids were provided, attach them to the corresponding new_docs (metadata) - external_map: dict[str, str] = {} - if external_ids is not None: - # Map external_ids to the computed doc ids in a best-effort manner using original indices - # We only have doc ids in 'contents' keys; we'll map in insertion order - # Build list of doc ids in same ordering of processed inputs - content_list = list(contents.items()) - # Map by index in indices_to_process - try: - for idx, ext_id in zip(indices_to_process, external_ids): - if ext_id and str(ext_id).strip(): - # Find the generated doc_id for this content - # Need to compute cleaned content to find key - cleaned = sanitize_text_for_encoding(input[idx]) - # find matching doc id - for doc_id, data in contents.items(): - if data.get("content") == cleaned: - external_map[doc_id] = ext_id - break - except Exception: - # If anything goes wrong mapping external ids, ignore external map - external_map = {} - # 2. 
Generate document initial status (without content) new_docs: dict[str, Any] = { id_: { @@ -1219,14 +1388,6 @@ class LightRAG: await self.full_docs.index_done_callback() # Store document status (without content) - # Attach external_id metadata to status entries when available - if external_map: - for doc_id, ext_id in external_map.items(): - if doc_id in new_docs: - doc_metadata = new_docs[doc_id].get("metadata", {}) - doc_metadata["external_id"] = ext_id - new_docs[doc_id]["metadata"] = doc_metadata - await self.doc_status.upsert(new_docs) logger.debug(f"Stored {len(new_docs)} new unique documents") @@ -1448,9 +1609,12 @@ class LightRAG: """ # Get pipeline status shared data and lock - await initialize_pipeline_status(self.pipeline_status_key) - pipeline_status = await get_namespace_data(self.pipeline_status_key) - pipeline_status_lock = get_pipeline_status_lock() + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=self.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=self.workspace + ) # Check if another process is already processing the queue async with pipeline_status_lock: @@ -1480,6 +1644,7 @@ class LightRAG: "batchs": 0, # Total number of files to be processed "cur_batch": 0, # Number of files already processed "request_pending": False, # Clear any previous request + "cancellation_requested": False, # Initialize cancellation flag "latest_message": "", } ) @@ -1496,6 +1661,22 @@ class LightRAG: try: # Process documents until no more documents or requests while True: + # Check for cancellation request at the start of main loop + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + # Clear pending request + pipeline_status["request_pending"] = False + # Celar cancellation flag + pipeline_status["cancellation_requested"] = False + + log_message = "Pipeline cancelled by user" + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + # Exit directly, skipping request_pending check + return + if not to_process_docs: log_message = "All enqueued documents have been processed" logger.info(log_message) @@ -1558,14 +1739,25 @@ class LightRAG: semaphore: asyncio.Semaphore, ) -> None: """Process single document""" + # Initialize variables at the start to prevent UnboundLocalError in error handling + file_path = "unknown_source" + current_file_number = 0 file_extraction_stage_ok = False + processing_start_time = int(time.time()) + first_stage_tasks = [] + entity_relation_task = None + async with semaphore: nonlocal processed_count - current_file_number = 0 # Initialize to prevent UnboundLocalError in error handling first_stage_tasks = [] entity_relation_task = None try: + # Check for cancellation before starting document processing + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException("User cancelled") + # Get file path from status document file_path = getattr( status_doc, "file_path", "unknown_source" @@ -1604,7 +1796,28 @@ class LightRAG: ) content = content_data["content"] - # Generate chunks from document + # Call chunking function, supporting both sync and async implementations + chunking_result = self.chunking_func( + self.tokenizer, + content, + split_by_character, + split_by_character_only, + self.chunk_overlap_token_size, + self.chunk_token_size, + ) + + # If result is awaitable, await to get actual result + if 
inspect.isawaitable(chunking_result): + chunking_result = await chunking_result + + # Validate return type + if not isinstance(chunking_result, (list, tuple)): + raise TypeError( + f"chunking_func must return a list or tuple of dicts, " + f"got {type(chunking_result)}" + ) + + # Build chunks dictionary chunks: dict[str, Any] = { compute_mdhash_id(dp["content"], prefix="chunk-"): { **dp, @@ -1612,14 +1825,7 @@ class LightRAG: "file_path": file_path, # Add file path to each chunk "llm_cache_list": [], # Initialize empty LLM cache list for each chunk } - for dp in self.chunking_func( - self.tokenizer, - content, - split_by_character, - split_by_character_only, - self.chunk_overlap_token_size, - self.chunk_token_size, - ) + for dp in chunking_result } if not chunks: @@ -1628,6 +1834,11 @@ class LightRAG: # Record processing start time processing_start_time = int(time.time()) + # Check for cancellation before entity extraction + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException("User cancelled") + # Process document in two stages # Stage 1: Process text chunks and docs (parallel execution) doc_status_task = asyncio.create_task( @@ -1678,20 +1889,33 @@ class LightRAG: chunks, pipeline_status, pipeline_status_lock ) ) - await entity_relation_task + chunk_results = await entity_relation_task file_extraction_stage_ok = True except Exception as e: - # Log error and update pipeline status - logger.error(traceback.format_exc()) - error_msg = f"Failed to extract document {current_file_number}/{total_files}: {file_path}" - logger.error(error_msg) - async with pipeline_status_lock: - pipeline_status["latest_message"] = error_msg - pipeline_status["history_messages"].append( - traceback.format_exc() - ) - pipeline_status["history_messages"].append(error_msg) + # Check if this is a user cancellation + if isinstance(e, PipelineCancelledException): + # User cancellation - log brief message only, no traceback + error_msg = f"User cancelled {current_file_number}/{total_files}: {file_path}" + logger.warning(error_msg) + async with pipeline_status_lock: + pipeline_status["latest_message"] = error_msg + pipeline_status["history_messages"].append( + error_msg + ) + else: + # Other exceptions - log with traceback + logger.error(traceback.format_exc()) + error_msg = f"Failed to extract document {current_file_number}/{total_files}: {file_path}" + logger.error(error_msg) + async with pipeline_status_lock: + pipeline_status["latest_message"] = error_msg + pipeline_status["history_messages"].append( + traceback.format_exc() + ) + pipeline_status["history_messages"].append( + error_msg + ) # Cancel tasks that are not yet completed all_tasks = first_stage_tasks + ( @@ -1738,8 +1962,16 @@ class LightRAG: # Concurrency is controlled by keyed lock for individual entities and relationships if file_extraction_stage_ok: try: - # Get chunk_results from entity_relation_task - chunk_results = await entity_relation_task + # Check for cancellation before merge + async with pipeline_status_lock: + if pipeline_status.get( + "cancellation_requested", False + ): + raise PipelineCancelledException( + "User cancelled" + ) + + # Use chunk_results from entity_relation_task await merge_nodes_and_edges( chunk_results=chunk_results, # result collected from entity_relation_task knowledge_graph_inst=self.chunk_entity_relation_graph, @@ -1752,6 +1984,8 @@ class LightRAG: pipeline_status=pipeline_status, pipeline_status_lock=pipeline_status_lock, 
llm_response_cache=self.llm_response_cache, + entity_chunks_storage=self.entity_chunks, + relation_chunks_storage=self.relation_chunks, current_file_number=current_file_number, total_files=total_files, file_path=file_path, @@ -1794,18 +2028,29 @@ class LightRAG: ) except Exception as e: - # Log error and update pipeline status - logger.error(traceback.format_exc()) - error_msg = f"Merging stage failed in document {current_file_number}/{total_files}: {file_path}" - logger.error(error_msg) - async with pipeline_status_lock: - pipeline_status["latest_message"] = error_msg - pipeline_status["history_messages"].append( - traceback.format_exc() - ) - pipeline_status["history_messages"].append( - error_msg - ) + # Check if this is a user cancellation + if isinstance(e, PipelineCancelledException): + # User cancellation - log brief message only, no traceback + error_msg = f"User cancelled during merge {current_file_number}/{total_files}: {file_path}" + logger.warning(error_msg) + async with pipeline_status_lock: + pipeline_status["latest_message"] = error_msg + pipeline_status["history_messages"].append( + error_msg + ) + else: + # Other exceptions - log with traceback + logger.error(traceback.format_exc()) + error_msg = f"Merging stage failed in document {current_file_number}/{total_files}: {file_path}" + logger.error(error_msg) + async with pipeline_status_lock: + pipeline_status["latest_message"] = error_msg + pipeline_status["history_messages"].append( + traceback.format_exc() + ) + pipeline_status["history_messages"].append( + error_msg + ) # Persistent llm cache with error handling if self.llm_response_cache: @@ -1855,7 +2100,19 @@ class LightRAG: ) # Wait for all document processing to complete - await asyncio.gather(*doc_tasks) + try: + await asyncio.gather(*doc_tasks) + except PipelineCancelledException: + # Cancel all remaining tasks + for task in doc_tasks: + if not task.done(): + task.cancel() + + # Wait for all tasks to complete cancellation + await asyncio.wait(doc_tasks, return_when=asyncio.ALL_COMPLETED) + + # Exit directly (document statuses already updated in process_document) + return # Check if there's a pending request to process more documents (with lock) has_pending_request = False @@ -1886,11 +2143,14 @@ class LightRAG: to_process_docs.update(pending_docs) finally: - log_message = "Enqueued document processing pipeline stoped" + log_message = "Enqueued document processing pipeline stopped" logger.info(log_message) - # Always reset busy status when done or if an exception occurs (with lock) + # Always reset busy status and cancellation flag when done or if an exception occurs (with lock) async with pipeline_status_lock: pipeline_status["busy"] = False + pipeline_status["cancellation_requested"] = ( + False # Always reset cancellation flag + ) pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) @@ -1926,6 +2186,8 @@ class LightRAG: self.text_chunks, self.full_entities, self.full_relations, + self.entity_chunks, + self.relation_chunks, self.llm_response_cache, self.entities_vdb, self.relationships_vdb, @@ -2151,7 +2413,6 @@ class LightRAG: query: str, param: QueryParam = QueryParam(), system_prompt: str | None = None, - tenant_context: Optional[Any] = None, ) -> str | AsyncIterator[str]: """ Perform a async query (backward compatibility wrapper). @@ -2164,19 +2425,12 @@ class LightRAG: param (QueryParam): Configuration parameters for query execution. 
If param.model_func is provided, it will be used instead of the global model. system_prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"]. - tenant_context: Optional TenantContext for multi-tenant deployments. If provided, query execution will - be scoped to the specified tenant and knowledge base for proper data isolation. Returns: str | AsyncIterator[str]: The LLM response content. - Non-streaming: Returns str - Streaming: Returns AsyncIterator[str] """ - # Store tenant context for propagation to storage operations - if tenant_context: - self._tenant_id = tenant_context.tenant_id - self._kb_id = tenant_context.kb_id - # Call the new aquery_llm function to get complete results result = await self.aquery_llm(query, param, system_prompt) @@ -2213,7 +2467,6 @@ class LightRAG: self, query: str, param: QueryParam = QueryParam(), - tenant_context: Optional[Any] = None, ) -> dict[str, Any]: """ Asynchronous data retrieval API: returns structured retrieval results without LLM generation. @@ -2224,8 +2477,6 @@ class LightRAG: Args: query: Query text for retrieval. param: Query parameters controlling retrieval behavior (same as aquery). - tenant_context: Optional TenantContext for multi-tenant deployments. If provided, query execution will - be scoped to the specified tenant and knowledge base for proper data isolation. Returns: dict[str, Any]: Structured data result in the following format: @@ -2323,11 +2574,6 @@ class LightRAG: actual data is nested under the 'data' field, with 'status' and 'message' fields at the top level. """ - # Store tenant context for propagation to storage operations - if tenant_context: - self._tenant_id = tenant_context.tenant_id - self._kb_id = tenant_context.kb_id - global_config = asdict(self) # Create a copy of param to avoid modifying the original @@ -2391,20 +2637,35 @@ class LightRAG: else: raise ValueError(f"Unknown mode {data_param.mode}") - # Extract raw_data from QueryResult - final_data = query_result.raw_data if query_result else {} - - # Log final result counts - adapt to new data format from convert_to_user_format - if final_data and "data" in final_data: - data_section = final_data["data"] - entities_count = len(data_section.get("entities", [])) - relationships_count = len(data_section.get("relationships", [])) - chunks_count = len(data_section.get("chunks", [])) - logger.debug( - f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks" - ) + if query_result is None: + no_result_message = "Query returned no results" + if data_param.mode == "naive": + no_result_message = "No relevant document chunks found." 
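
A minimal usage sketch of the structured result shape handled here (illustrative only; assumes an already-initialized LightRAG instance named `rag`, and that successful calls report status "success"):

    # Sketch: consume the {"status", "message", "data", "metadata"} payload
    # returned by aquery_data. `rag` and the "success" status value are assumptions.
    from lightrag import LightRAG, QueryParam

    async def show_retrieval(rag: LightRAG, question: str) -> None:
        result = await rag.aquery_data(question, param=QueryParam(mode="mix"))
        if result.get("status") != "success":
            # e.g. metadata == {"failure_reason": "no_results", "mode": "mix"}
            print("retrieval failed:", result.get("message"), result.get("metadata"))
            return
        data = result.get("data", {})
        print(
            f"{len(data.get('entities', []))} entities, "
            f"{len(data.get('relationships', []))} relationships, "
            f"{len(data.get('chunks', []))} chunks"
        )
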
+ final_data: dict[str, Any] = { + "status": "failure", + "message": no_result_message, + "data": {}, + "metadata": { + "failure_reason": "no_results", + "mode": data_param.mode, + }, + } + logger.info("[aquery_data] Query returned no results.") else: - logger.warning("[aquery_data] No data section found in query result") + # Extract raw_data from QueryResult + final_data = query_result.raw_data or {} + + # Log final result counts - adapt to new data format from convert_to_user_format + if final_data and "data" in final_data: + data_section = final_data["data"] + entities_count = len(data_section.get("entities", [])) + relationships_count = len(data_section.get("relationships", [])) + chunks_count = len(data_section.get("chunks", [])) + logger.debug( + f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks" + ) + else: + logger.warning("[aquery_data] No data section found in query result") await self._query_done() return final_data @@ -2507,16 +2768,19 @@ class LightRAG: "status": "failure", "message": "Query returned no results", "data": {}, - "metadata": {}, + "metadata": { + "failure_reason": "no_results", + "mode": param.mode, + }, "llm_response": { - "content": None, + "content": PROMPTS["fail_response"], "response_iterator": None, "is_streaming": False, }, } # Extract structured data from query result - raw_data = query_result.raw_data if query_result and query_result.raw_data else {} + raw_data = query_result.raw_data or {} raw_data["llm_response"] = { "content": query_result.content if not query_result.is_streaming @@ -2674,17 +2938,20 @@ class LightRAG: # Return the dictionary containing statuses only for the found document IDs return found_statuses - async def adelete_by_doc_id(self, doc_id: str, tenant_context: Optional[Any] = None) -> DeletionResult: - """Delete a document and all its related data, including chunks, graph elements, and cached entries. + async def adelete_by_doc_id( + self, doc_id: str, delete_llm_cache: bool = False + ) -> DeletionResult: + """Delete a document and all its related data, including chunks, graph elements. This method orchestrates a comprehensive deletion process for a given document ID. It ensures that not only the document itself but also all its derived and associated - data across different storage layers are removed. If entities or relationships are partially affected, it triggers. + data across different storage layers are removed or rebuiled. If entities or relationships + are partially affected, they will be rebuilded using LLM cached from remaining documents. Args: doc_id (str): The unique identifier of the document to be deleted. - tenant_context: Optional TenantContext for multi-tenant deployments. If provided, deletion will - be scoped to the specified tenant and knowledge base. + delete_llm_cache (bool): Whether to delete cached LLM extraction results + associated with the document. Defaults to False. Returns: DeletionResult: An object containing the outcome of the deletion process. @@ -2694,17 +2961,17 @@ class LightRAG: - `status_code` (int): HTTP status code (e.g., 200, 404, 500). - `file_path` (str | None): The file path of the deleted document, if available. 
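
A minimal usage sketch of the reworked deletion API (illustrative only; assumes an initialized LightRAG instance `rag` and an existing document id; passing delete_llm_cache=True additionally removes the cached LLM extraction entries collected for that document):

    # Sketch: delete one document and inspect the DeletionResult fields
    # documented above. `rag` and `doc_id` are assumptions here.
    from lightrag import LightRAG

    async def remove_document(rag: LightRAG, doc_id: str) -> None:
        result = await rag.adelete_by_doc_id(doc_id, delete_llm_cache=True)
        print(result.status, result.status_code, result.message, result.file_path)
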
""" - # Store tenant context for propagation to storage operations - if tenant_context: - self._tenant_id = tenant_context.tenant_id - self._kb_id = tenant_context.kb_id - deletion_operations_started = False original_exception = None + doc_llm_cache_ids: list[str] = [] # Get pipeline status shared data and lock for status updates - pipeline_status = await get_namespace_data("pipeline_status") - pipeline_status_lock = get_pipeline_status_lock() + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=self.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=self.workspace + ) async with pipeline_status_lock: log_message = f"Starting deletion process for document {doc_id}" @@ -2727,7 +2994,12 @@ class LightRAG: ) # Check document status and log warning for non-completed documents - doc_status = doc_status_data.get("status") + raw_status = doc_status_data.get("status") + try: + doc_status = DocStatus(raw_status) + except ValueError: + doc_status = raw_status + if doc_status != DocStatus.PROCESSED: if doc_status == DocStatus.PENDING: warning_msg = ( @@ -2737,12 +3009,23 @@ class LightRAG: warning_msg = ( f"Deleting {doc_id} {file_path}(previous status: PROCESSING)" ) + elif doc_status == DocStatus.PREPROCESSED: + warning_msg = ( + f"Deleting {doc_id} {file_path}(previous status: PREPROCESSED)" + ) elif doc_status == DocStatus.FAILED: warning_msg = ( f"Deleting {doc_id} {file_path}(previous status: FAILED)" ) else: - warning_msg = f"Deleting {doc_id} {file_path}(previous status: {doc_status.value})" + status_text = ( + doc_status.value + if isinstance(doc_status, DocStatus) + else str(doc_status) + ) + warning_msg = ( + f"Deleting {doc_id} {file_path}(previous status: {status_text})" + ) logger.info(warning_msg) # Update pipeline status for monitoring async with pipeline_status_lock: @@ -2785,11 +3068,64 @@ class LightRAG: # Mark that deletion operations have started deletion_operations_started = True + if delete_llm_cache and chunk_ids: + if not self.llm_response_cache: + logger.info( + "Skipping LLM cache collection for document %s because cache storage is unavailable", + doc_id, + ) + elif not self.text_chunks: + logger.info( + "Skipping LLM cache collection for document %s because text chunk storage is unavailable", + doc_id, + ) + else: + try: + chunk_data_list = await self.text_chunks.get_by_ids( + list(chunk_ids) + ) + seen_cache_ids: set[str] = set() + for chunk_data in chunk_data_list: + if not chunk_data or not isinstance(chunk_data, dict): + continue + cache_ids = chunk_data.get("llm_cache_list", []) + if not isinstance(cache_ids, list): + continue + for cache_id in cache_ids: + if ( + isinstance(cache_id, str) + and cache_id + and cache_id not in seen_cache_ids + ): + doc_llm_cache_ids.append(cache_id) + seen_cache_ids.add(cache_id) + if doc_llm_cache_ids: + logger.info( + "Collected %d LLM cache entries for document %s", + len(doc_llm_cache_ids), + doc_id, + ) + else: + logger.info( + "No LLM cache entries found for document %s", doc_id + ) + except Exception as cache_collect_error: + logger.error( + "Failed to collect LLM cache ids for document %s: %s", + doc_id, + cache_collect_error, + ) + raise Exception( + f"Failed to collect LLM cache ids for document {doc_id}: {cache_collect_error}" + ) from cache_collect_error + # 4. 
Analyze entities and relationships that will be affected entities_to_delete = set() - entities_to_rebuild = {} # entity_name -> remaining_chunk_ids + entities_to_rebuild = {} # entity_name -> remaining chunk id list relationships_to_delete = set() - relationships_to_rebuild = {} # (src, tgt) -> remaining_chunk_ids + relationships_to_rebuild = {} # (src, tgt) -> remaining chunk id list + entity_chunk_updates: dict[str, list[str]] = {} + relation_chunk_updates: dict[tuple[str, str], list[str]] = {} try: # Get affected entities and relations from full_entities and full_relations storage @@ -2845,14 +3181,44 @@ class LightRAG: # Process entities for node_data in affected_nodes: node_label = node_data.get("entity_id") - if node_label and "source_id" in node_data: - sources = set(node_data["source_id"].split(GRAPH_FIELD_SEP)) - remaining_sources = sources - chunk_ids + if not node_label: + continue - if not remaining_sources: - entities_to_delete.add(node_label) - elif remaining_sources != sources: - entities_to_rebuild[node_label] = remaining_sources + existing_sources: list[str] = [] + if self.entity_chunks: + stored_chunks = await self.entity_chunks.get_by_id(node_label) + if stored_chunks and isinstance(stored_chunks, dict): + existing_sources = [ + chunk_id + for chunk_id in stored_chunks.get("chunk_ids", []) + if chunk_id + ] + + if not existing_sources and node_data.get("source_id"): + existing_sources = [ + chunk_id + for chunk_id in node_data["source_id"].split( + GRAPH_FIELD_SEP + ) + if chunk_id + ] + + if not existing_sources: + # No chunk references means this entity should be deleted + entities_to_delete.add(node_label) + entity_chunk_updates[node_label] = [] + continue + + remaining_sources = subtract_source_ids(existing_sources, chunk_ids) + + if not remaining_sources: + entities_to_delete.add(node_label) + entity_chunk_updates[node_label] = [] + elif remaining_sources != existing_sources: + entities_to_rebuild[node_label] = remaining_sources + entity_chunk_updates[node_label] = remaining_sources + else: + logger.info(f"Untouch entity: {node_label}") async with pipeline_status_lock: log_message = f"Found {len(entities_to_rebuild)} affected entities" @@ -2862,24 +3228,58 @@ class LightRAG: # Process relationships for edge_data in affected_edges: + # source target is not in normalize order in graph db property src = edge_data.get("source") tgt = edge_data.get("target") - if src and tgt and "source_id" in edge_data: - edge_tuple = tuple(sorted((src, tgt))) - if ( - edge_tuple in relationships_to_delete - or edge_tuple in relationships_to_rebuild - ): - continue + if not src or not tgt or "source_id" not in edge_data: + continue - sources = set(edge_data["source_id"].split(GRAPH_FIELD_SEP)) - remaining_sources = sources - chunk_ids + edge_tuple = tuple(sorted((src, tgt))) + if ( + edge_tuple in relationships_to_delete + or edge_tuple in relationships_to_rebuild + ): + continue - if not remaining_sources: - relationships_to_delete.add(edge_tuple) - elif remaining_sources != sources: - relationships_to_rebuild[edge_tuple] = remaining_sources + existing_sources: list[str] = [] + if self.relation_chunks: + storage_key = make_relation_chunk_key(src, tgt) + stored_chunks = await self.relation_chunks.get_by_id( + storage_key + ) + if stored_chunks and isinstance(stored_chunks, dict): + existing_sources = [ + chunk_id + for chunk_id in stored_chunks.get("chunk_ids", []) + if chunk_id + ] + + if not existing_sources: + existing_sources = [ + chunk_id + for chunk_id in 
edge_data["source_id"].split( + GRAPH_FIELD_SEP + ) + if chunk_id + ] + + if not existing_sources: + # No chunk references means this relationship should be deleted + relationships_to_delete.add(edge_tuple) + relation_chunk_updates[edge_tuple] = [] + continue + + remaining_sources = subtract_source_ids(existing_sources, chunk_ids) + + if not remaining_sources: + relationships_to_delete.add(edge_tuple) + relation_chunk_updates[edge_tuple] = [] + elif remaining_sources != existing_sources: + relationships_to_rebuild[edge_tuple] = remaining_sources + relation_chunk_updates[edge_tuple] = remaining_sources + else: + logger.info(f"Untouch relation: {edge_tuple}") async with pipeline_status_lock: log_message = ( @@ -2889,60 +3289,147 @@ class LightRAG: pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) + current_time = int(time.time()) + + if entity_chunk_updates and self.entity_chunks: + entity_upsert_payload = {} + for entity_name, remaining in entity_chunk_updates.items(): + if not remaining: + # Empty entities are deleted alongside graph nodes later + continue + entity_upsert_payload[entity_name] = { + "chunk_ids": remaining, + "count": len(remaining), + "updated_at": current_time, + } + if entity_upsert_payload: + await self.entity_chunks.upsert(entity_upsert_payload) + + if relation_chunk_updates and self.relation_chunks: + relation_upsert_payload = {} + for edge_tuple, remaining in relation_chunk_updates.items(): + if not remaining: + # Empty relations are deleted alongside graph edges later + continue + storage_key = make_relation_chunk_key(*edge_tuple) + relation_upsert_payload[storage_key] = { + "chunk_ids": remaining, + "count": len(remaining), + "updated_at": current_time, + } + + if relation_upsert_payload: + await self.relation_chunks.upsert(relation_upsert_payload) + except Exception as e: logger.error(f"Failed to process graph analysis results: {e}") raise Exception(f"Failed to process graph dependencies: {e}") from e - # Use graph database lock to prevent dirty read - graph_db_lock = get_graph_db_lock(enable_logging=False) - async with graph_db_lock: - # 5. Delete chunks from storage - if chunk_ids: - try: - await self.chunks_vdb.delete(chunk_ids) - await self.text_chunks.delete(chunk_ids) + # Data integrity is ensured by allowing only one process to hold pipeline at a time(no graph db lock is needed anymore) - async with pipeline_status_lock: - log_message = f"Successfully deleted {len(chunk_ids)} chunks from storage" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + # 5. Delete chunks from storage + if chunk_ids: + try: + await self.chunks_vdb.delete(chunk_ids) + await self.text_chunks.delete(chunk_ids) - except Exception as e: - logger.error(f"Failed to delete chunks: {e}") - raise Exception(f"Failed to delete document chunks: {e}") from e + async with pipeline_status_lock: + log_message = ( + f"Successfully deleted {len(chunk_ids)} chunks from storage" + ) + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) - # 6. 
Delete entities that have no remaining sources - if entities_to_delete: - try: - # Delete from vector database - entity_vdb_ids = [ - compute_mdhash_id(entity, prefix="ent-") - for entity in entities_to_delete + except Exception as e: + logger.error(f"Failed to delete chunks: {e}") + raise Exception(f"Failed to delete document chunks: {e}") from e + + # 6. Delete relationships that have no remaining sources + if relationships_to_delete: + try: + # Delete from relation vdb + rel_ids_to_delete = [] + for src, tgt in relationships_to_delete: + rel_ids_to_delete.extend( + [ + compute_mdhash_id(src + tgt, prefix="rel-"), + compute_mdhash_id(tgt + src, prefix="rel-"), + ] + ) + await self.relationships_vdb.delete(rel_ids_to_delete) + + # Delete from graph + await self.chunk_entity_relation_graph.remove_edges( + list(relationships_to_delete) + ) + + # Delete from relation_chunks storage + if self.relation_chunks: + relation_storage_keys = [ + make_relation_chunk_key(src, tgt) + for src, tgt in relationships_to_delete ] - await self.entities_vdb.delete(entity_vdb_ids) + await self.relation_chunks.delete(relation_storage_keys) - # Delete from graph - await self.chunk_entity_relation_graph.remove_nodes( + async with pipeline_status_lock: + log_message = f"Successfully deleted {len(relationships_to_delete)} relations" + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + except Exception as e: + logger.error(f"Failed to delete relationships: {e}") + raise Exception(f"Failed to delete relationships: {e}") from e + + # 7. Delete entities that have no remaining sources + if entities_to_delete: + try: + # Batch get all edges for entities to avoid N+1 query problem + nodes_edges_dict = ( + await self.chunk_entity_relation_graph.get_nodes_edges_batch( list(entities_to_delete) ) + ) - async with pipeline_status_lock: - log_message = f"Successfully deleted {len(entities_to_delete)} entities" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + # Debug: Check and log all edges before deleting nodes + edges_to_delete = set() + edges_still_exist = 0 - except Exception as e: - logger.error(f"Failed to delete entities: {e}") - raise Exception(f"Failed to delete entities: {e}") from e + for entity, edges in nodes_edges_dict.items(): + if edges: + for src, tgt in edges: + # Normalize edge representation (sorted for consistency) + edge_tuple = tuple(sorted((src, tgt))) + edges_to_delete.add(edge_tuple) - # 7. 
Delete relationships that have no remaining sources - if relationships_to_delete: - try: - # Delete from vector database + if ( + src in entities_to_delete + and tgt in entities_to_delete + ): + logger.warning( + f"Edge still exists: {src} <-> {tgt}" + ) + elif src in entities_to_delete: + logger.warning( + f"Edge still exists: {src} --> {tgt}" + ) + else: + logger.warning( + f"Edge still exists: {src} <-- {tgt}" + ) + edges_still_exist += 1 + + if edges_still_exist: + logger.warning( + f"⚠️ {edges_still_exist} entities still has edges before deletion" + ) + + # Clean residual edges from VDB and storage before deleting nodes + if edges_to_delete: + # Delete from relationships_vdb rel_ids_to_delete = [] - for src, tgt in relationships_to_delete: + for src, tgt in edges_to_delete: rel_ids_to_delete.extend( [ compute_mdhash_id(src + tgt, prefix="rel-"), @@ -2951,28 +3438,53 @@ class LightRAG: ) await self.relationships_vdb.delete(rel_ids_to_delete) - # Delete from graph - await self.chunk_entity_relation_graph.remove_edges( - list(relationships_to_delete) + # Delete from relation_chunks storage + if self.relation_chunks: + relation_storage_keys = [ + make_relation_chunk_key(src, tgt) + for src, tgt in edges_to_delete + ] + await self.relation_chunks.delete(relation_storage_keys) + + logger.info( + f"Cleaned {len(edges_to_delete)} residual edges from VDB and chunk-tracking storage" ) - async with pipeline_status_lock: - log_message = f"Successfully deleted {len(relationships_to_delete)} relations" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + # Delete from graph (edges will be auto-deleted with nodes) + await self.chunk_entity_relation_graph.remove_nodes( + list(entities_to_delete) + ) - except Exception as e: - logger.error(f"Failed to delete relationships: {e}") - raise Exception(f"Failed to delete relationships: {e}") from e + # Delete from vector vdb + entity_vdb_ids = [ + compute_mdhash_id(entity, prefix="ent-") + for entity in entities_to_delete + ] + await self.entities_vdb.delete(entity_vdb_ids) - # Persist changes to graph database before releasing graph database lock - await self._insert_done() + # Delete from entity_chunks storage + if self.entity_chunks: + await self.entity_chunks.delete(list(entities_to_delete)) + + async with pipeline_status_lock: + log_message = ( + f"Successfully deleted {len(entities_to_delete)} entities" + ) + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + except Exception as e: + logger.error(f"Failed to delete entities: {e}") + raise Exception(f"Failed to delete entities: {e}") from e + + # Persist changes to graph database before entity and relationship rebuild + await self._insert_done() # 8. 
Rebuild entities and relationships from remaining chunks if entities_to_rebuild or relationships_to_rebuild: try: - await _rebuild_knowledge_from_chunks( + await rebuild_knowledge_from_chunks( entities_to_rebuild=entities_to_rebuild, relationships_to_rebuild=relationships_to_rebuild, knowledge_graph_inst=self.chunk_entity_relation_graph, @@ -2983,6 +3495,8 @@ class LightRAG: global_config=asdict(self), pipeline_status=pipeline_status, pipeline_status_lock=pipeline_status_lock, + entity_chunks_storage=self.entity_chunks, + relation_chunks_storage=self.relation_chunks, ) except Exception as e: @@ -3007,6 +3521,23 @@ class LightRAG: logger.error(f"Failed to delete document and status: {e}") raise Exception(f"Failed to delete document and status: {e}") from e + if delete_llm_cache and doc_llm_cache_ids and self.llm_response_cache: + try: + await self.llm_response_cache.delete(doc_llm_cache_ids) + cache_log_message = f"Successfully deleted {len(doc_llm_cache_ids)} LLM cache entries for document {doc_id}" + logger.info(cache_log_message) + async with pipeline_status_lock: + pipeline_status["latest_message"] = cache_log_message + pipeline_status["history_messages"].append(cache_log_message) + log_message = cache_log_message + except Exception as cache_delete_error: + log_message = f"Failed to delete LLM cache for document {doc_id}: {cache_delete_error}" + logger.error(log_message) + logger.error(traceback.format_exc()) + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + return DeletionResult( status="success", doc_id=doc_id, @@ -3171,16 +3702,22 @@ class LightRAG: ) async def aedit_entity( - self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True + self, + entity_name: str, + updated_data: dict[str, str], + allow_rename: bool = True, + allow_merge: bool = False, ) -> dict[str, Any]: """Asynchronously edit entity information. Updates entity information in the knowledge graph and re-embeds the entity in the vector database. + Also synchronizes entity_chunks_storage and relation_chunks_storage to track chunk references. Args: entity_name: Name of the entity to edit updated_data: Dictionary containing updated attributes, e.g. {"description": "new description", "entity_type": "new type"} allow_rename: Whether to allow entity renaming, defaults to True + allow_merge: Whether to merge into an existing entity when renaming to an existing name Returns: Dictionary containing updated entity information @@ -3194,14 +3731,21 @@ class LightRAG: entity_name, updated_data, allow_rename, + allow_merge, + self.entity_chunks, + self.relation_chunks, ) def edit_entity( - self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True + self, + entity_name: str, + updated_data: dict[str, str], + allow_rename: bool = True, + allow_merge: bool = False, ) -> dict[str, Any]: loop = always_get_an_event_loop() return loop.run_until_complete( - self.aedit_entity(entity_name, updated_data, allow_rename) + self.aedit_entity(entity_name, updated_data, allow_rename, allow_merge) ) async def aedit_relation( @@ -3210,6 +3754,7 @@ class LightRAG: """Asynchronously edit relation information. Updates relation (edge) information in the knowledge graph and re-embeds the relation in the vector database. + Also synchronizes the relation_chunks_storage to track which chunks reference this relation. 
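
A minimal usage sketch of the extended entity-editing API (illustrative only; the entity names are hypothetical, `rag` is an initialized LightRAG instance, and the rename-via-"entity_name" key follows the existing aedit_entity convention):

    # Sketch: rename an entity; with allow_merge=True the edit merges into an
    # existing entity of the target name instead of failing. Names are hypothetical.
    from lightrag import LightRAG

    async def rename_or_merge(rag: LightRAG) -> None:
        updated = await rag.aedit_entity(
            "Acme Corp",  # hypothetical current entity name
            {"entity_name": "Acme Corporation", "description": "Consolidated profile"},
            allow_rename=True,
            allow_merge=True,
        )
        print(updated)
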
         Args:
             source_entity: Name of the source entity
@@ -3228,6 +3773,7 @@ class LightRAG:
                 source_entity,
                 target_entity,
                 updated_data,
+                self.relation_chunks,
             )

     def edit_relation(
@@ -3339,6 +3885,8 @@ class LightRAG:
                 target_entity,
                 merge_strategy,
                 target_entity_data,
+                self.entity_chunks,
+                self.relation_chunks,
             )

     def merge_entities(