From 99643f01dec35a8a4d498fef07bcd138d8ffad22 Mon Sep 17 00:00:00 2001 From: SJ Date: Wed, 13 Aug 2025 02:08:13 -0500 Subject: [PATCH 01/15] Enhancement: support AWS Bedrock as an LLM binding #1733 --- lightrag/api/config.py | 72 +++-------- lightrag/api/lightrag_server.py | 222 +++++++++++++++++--------------- lightrag/llm/bedrock.py | 197 +++++++++++++++++++++++----- 3 files changed, 299 insertions(+), 192 deletions(-) diff --git a/lightrag/api/config.py b/lightrag/api/config.py index 56e506cc..0a16e89b 100644 --- a/lightrag/api/config.py +++ b/lightrag/api/config.py @@ -77,9 +77,7 @@ def parse_args() -> argparse.Namespace: argparse.Namespace: Parsed arguments """ - parser = argparse.ArgumentParser( - description="LightRAG FastAPI Server with separate working and input directories" - ) + parser = argparse.ArgumentParser(description="LightRAG FastAPI Server with separate working and input directories") # Server configuration parser.add_argument( @@ -209,14 +207,14 @@ def parse_args() -> argparse.Namespace: "--llm-binding", type=str, default=get_env_value("LLM_BINDING", "ollama"), - choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"], + choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai", "aws_bedrock"], help="LLM binding type (default: from env or ollama)", ) parser.add_argument( "--embedding-binding", type=str, default=get_env_value("EMBEDDING_BINDING", "ollama"), - choices=["lollms", "ollama", "openai", "azure_openai"], + choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"], help="Embedding binding type (default: from env or ollama)", ) @@ -272,18 +270,10 @@ def parse_args() -> argparse.Namespace: args.input_dir = os.path.abspath(args.input_dir) # Inject storage configuration from environment variables - args.kv_storage = get_env_value( - "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE - ) - args.doc_status_storage = get_env_value( - "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE - ) - args.graph_storage = get_env_value( - "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE - ) - args.vector_storage = get_env_value( - "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE - ) + args.kv_storage = get_env_value("LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE) + args.doc_status_storage = get_env_value("LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE) + args.graph_storage = get_env_value("LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE) + args.vector_storage = get_env_value("LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE) # Get MAX_PARALLEL_INSERT from environment args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int) @@ -299,12 +289,8 @@ def parse_args() -> argparse.Namespace: # Ollama ctx_num args.ollama_num_ctx = get_env_value("OLLAMA_NUM_CTX", 32768, int) - args.llm_binding_host = get_env_value( - "LLM_BINDING_HOST", get_default_host(args.llm_binding) - ) - args.embedding_binding_host = get_env_value( - "EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding) - ) + args.llm_binding_host = get_env_value("LLM_BINDING_HOST", get_default_host(args.llm_binding)) + args.embedding_binding_host = get_env_value("EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)) args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None) args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "") @@ -318,9 +304,7 @@ def parse_args() -> 
argparse.Namespace: args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int) # Inject LLM cache configuration - args.enable_llm_cache_for_extract = get_env_value( - "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool - ) + args.enable_llm_cache_for_extract = get_env_value("ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool) args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool) # Handle Ollama LLM temperature with priority cascade when llm-binding is ollama @@ -370,40 +354,24 @@ def parse_args() -> argparse.Namespace: args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None) # Min rerank score configuration - args.min_rerank_score = get_env_value( - "MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float - ) + args.min_rerank_score = get_env_value("MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float) # Query configuration args.history_turns = get_env_value("HISTORY_TURNS", DEFAULT_HISTORY_TURNS, int) args.top_k = get_env_value("TOP_K", DEFAULT_TOP_K, int) args.chunk_top_k = get_env_value("CHUNK_TOP_K", DEFAULT_CHUNK_TOP_K, int) - args.max_entity_tokens = get_env_value( - "MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int - ) - args.max_relation_tokens = get_env_value( - "MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int - ) - args.max_total_tokens = get_env_value( - "MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int - ) - args.cosine_threshold = get_env_value( - "COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float - ) - args.related_chunk_number = get_env_value( - "RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int - ) + args.max_entity_tokens = get_env_value("MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int) + args.max_relation_tokens = get_env_value("MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int) + args.max_total_tokens = get_env_value("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int) + args.cosine_threshold = get_env_value("COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float) + args.related_chunk_number = get_env_value("RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int) # Add missing environment variables for health endpoint args.force_llm_summary_on_merge = get_env_value( "FORCE_LLM_SUMMARY_ON_MERGE", DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, int ) - args.embedding_func_max_async = get_env_value( - "EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int - ) - args.embedding_batch_num = get_env_value( - "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int - ) + args.embedding_func_max_async = get_env_value("EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int) + args.embedding_batch_num = get_env_value("EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int) ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag @@ -417,9 +385,7 @@ def update_uvicorn_mode_config(): original_workers = global_args.workers global_args.workers = 1 # Log warning directly here - logging.warning( - f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1" - ) + logging.warning(f"In uvicorn mode, workers parameter was set to {original_workers}. 
Forcing workers=1") global_args = parse_args() diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 699f59ac..8ee23826 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -106,6 +106,7 @@ def create_app(args): "openai", "openai-ollama", "azure_openai", + "aws_bedrock", ]: raise Exception("llm binding not supported") @@ -114,6 +115,7 @@ def create_app(args): "ollama", "openai", "azure_openai", + "aws_bedrock", "jina", ]: raise Exception("embedding binding not supported") @@ -128,9 +130,7 @@ def create_app(args): # Add SSL validation if args.ssl: if not args.ssl_certfile or not args.ssl_keyfile: - raise Exception( - "SSL certificate and key files must be provided when SSL is enabled" - ) + raise Exception("SSL certificate and key files must be provided when SSL is enabled") if not os.path.exists(args.ssl_certfile): raise Exception(f"SSL certificate file not found: {args.ssl_certfile}") if not os.path.exists(args.ssl_keyfile): @@ -188,10 +188,11 @@ def create_app(args): # Initialize FastAPI app_kwargs = { "title": "LightRAG Server API", - "description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation" - + "(With authentication)" - if api_key - else "", + "description": ( + "Providing API for LightRAG core, Web UI and Ollama Model Emulation" + "(With authentication)" + if api_key + else "" + ), "version": __api_version__, "openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL "docs_url": "/docs", # Explicitly set docs URL @@ -244,9 +245,9 @@ def create_app(args): azure_openai_complete_if_cache, azure_openai_embed, ) + if args.llm_binding == "aws_bedrock" or args.embedding_binding == "aws_bedrock": + from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed if args.llm_binding_host == "openai-ollama" or args.embedding_binding == "ollama": - from lightrag.llm.openai import openai_complete_if_cache - from lightrag.llm.ollama import ollama_embed from lightrag.llm.binding_options import OllamaEmbeddingOptions if args.embedding_binding == "jina": from lightrag.llm.jina import jina_embed @@ -312,41 +313,80 @@ def create_app(args): **kwargs, ) + async def bedrock_model_complete( + prompt, + system_prompt=None, + history_messages=None, + keyword_extraction=False, + **kwargs, + ) -> str: + keyword_extraction = kwargs.pop("keyword_extraction", None) + if keyword_extraction: + kwargs["response_format"] = GPTKeywordExtractionFormat + if history_messages is None: + history_messages = [] + + # Use global temperature for Bedrock + kwargs["temperature"] = args.temperature + + return await bedrock_complete_if_cache( + args.llm_model, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + embedding_func = EmbeddingFunc( embedding_dim=args.embedding_dim, - func=lambda texts: lollms_embed( - texts, - embed_model=args.embedding_model, - host=args.embedding_binding_host, - api_key=args.embedding_binding_api_key, - ) - if args.embedding_binding == "lollms" - else ollama_embed( - texts, - embed_model=args.embedding_model, - host=args.embedding_binding_host, - api_key=args.embedding_binding_api_key, - options=OllamaEmbeddingOptions.options_dict(args), - ) - if args.embedding_binding == "ollama" - else azure_openai_embed( - texts, - model=args.embedding_model, # no host is used for openai, - api_key=args.embedding_binding_api_key, - ) - if args.embedding_binding == "azure_openai" - else jina_embed( - texts, - dimensions=args.embedding_dim, - 
base_url=args.embedding_binding_host, - api_key=args.embedding_binding_api_key, - ) - if args.embedding_binding == "jina" - else openai_embed( - texts, - model=args.embedding_model, - base_url=args.embedding_binding_host, - api_key=args.embedding_binding_api_key, + func=lambda texts: ( + lollms_embed( + texts, + embed_model=args.embedding_model, + host=args.embedding_binding_host, + api_key=args.embedding_binding_api_key, + ) + if args.embedding_binding == "lollms" + else ( + ollama_embed( + texts, + embed_model=args.embedding_model, + host=args.embedding_binding_host, + api_key=args.embedding_binding_api_key, + options=OllamaEmbeddingOptions.options_dict(args), + ) + if args.embedding_binding == "ollama" + else ( + azure_openai_embed( + texts, + model=args.embedding_model, # no host is used for openai, + api_key=args.embedding_binding_api_key, + ) + if args.embedding_binding == "azure_openai" + else ( + bedrock_embed( + texts, + model=args.embedding_model, + ) + if args.embedding_binding == "aws_bedrock" + else ( + jina_embed( + texts, + dimensions=args.embedding_dim, + base_url=args.embedding_binding_host, + api_key=args.embedding_binding_api_key, + ) + if args.embedding_binding == "jina" + else openai_embed( + texts, + model=args.embedding_model, + base_url=args.embedding_binding_host, + api_key=args.embedding_binding_api_key, + ) + ) + ) + ) + ) ), ) @@ -355,9 +395,7 @@ def create_app(args): if args.rerank_binding_api_key and args.rerank_binding_host: from lightrag.rerank import custom_rerank - async def server_rerank_func( - query: str, documents: list, top_n: int = None, **kwargs - ): + async def server_rerank_func(query: str, documents: list, top_n: int = None, **kwargs): """Server rerank function with configuration from environment variables""" return await custom_rerank( query=query, @@ -370,9 +408,7 @@ def create_app(args): ) rerank_model_func = server_rerank_func - logger.info( - f"Rerank model configured: {args.rerank_model} (can be enabled per query)" - ) + logger.info(f"Rerank model configured: {args.rerank_model} (can be enabled per query)") else: logger.info( "Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking." 
@@ -381,41 +417,43 @@ def create_app(args): # Create ollama_server_infos from command line arguments from lightrag.api.config import OllamaServerInfos - ollama_server_infos = OllamaServerInfos( - name=args.simulated_model_name, tag=args.simulated_model_tag - ) + ollama_server_infos = OllamaServerInfos(name=args.simulated_model_name, tag=args.simulated_model_tag) # Initialize RAG - if args.llm_binding in ["lollms", "ollama", "openai"]: + if args.llm_binding in ["lollms", "ollama", "openai", "aws_bedrock"]: rag = LightRAG( working_dir=args.working_dir, workspace=args.workspace, - llm_model_func=lollms_model_complete - if args.llm_binding == "lollms" - else ollama_model_complete - if args.llm_binding == "ollama" - else openai_alike_model_complete, + llm_model_func=( + lollms_model_complete + if args.llm_binding == "lollms" + else ( + ollama_model_complete + if args.llm_binding == "ollama" + else bedrock_model_complete if args.llm_binding == "aws_bedrock" else openai_alike_model_complete + ) + ), llm_model_name=args.llm_model, llm_model_max_async=args.max_async, summary_max_tokens=args.max_tokens, chunk_token_size=int(args.chunk_size), chunk_overlap_token_size=int(args.chunk_overlap_size), - llm_model_kwargs={ - "host": args.llm_binding_host, - "timeout": args.timeout, - "options": OllamaLLMOptions.options_dict(args), - "api_key": args.llm_binding_api_key, - } - if args.llm_binding == "lollms" or args.llm_binding == "ollama" - else {}, + llm_model_kwargs=( + { + "host": args.llm_binding_host, + "timeout": args.timeout, + "options": OllamaLLMOptions.options_dict(args), + "api_key": args.llm_binding_api_key, + } + if args.llm_binding == "lollms" or args.llm_binding == "ollama" + else {} + ), embedding_func=embedding_func, kv_storage=args.kv_storage, graph_storage=args.graph_storage, vector_storage=args.vector_storage, doc_status_storage=args.doc_status_storage, - vector_db_storage_cls_kwargs={ - "cosine_better_than_threshold": args.cosine_threshold - }, + vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold}, enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, enable_llm_cache=args.enable_llm_cache, rerank_model_func=rerank_model_func, @@ -442,9 +480,7 @@ def create_app(args): graph_storage=args.graph_storage, vector_storage=args.vector_storage, doc_status_storage=args.doc_status_storage, - vector_db_storage_cls_kwargs={ - "cosine_better_than_threshold": args.cosine_threshold - }, + vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold}, enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, enable_llm_cache=args.enable_llm_cache, rerank_model_func=rerank_model_func, @@ -480,9 +516,7 @@ def create_app(args): if not auth_handler.accounts: # Authentication not configured, return guest token - guest_token = auth_handler.create_token( - username="guest", role="guest", metadata={"auth_mode": "disabled"} - ) + guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"}) return { "auth_configured": False, "access_token": guest_token, @@ -508,9 +542,7 @@ def create_app(args): async def login(form_data: OAuth2PasswordRequestForm = Depends()): if not auth_handler.accounts: # Authentication not configured, return guest token - guest_token = auth_handler.create_token( - username="guest", role="guest", metadata={"auth_mode": "disabled"} - ) + guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"}) return { 
"access_token": guest_token, "token_type": "bearer", @@ -523,14 +555,10 @@ def create_app(args): } username = form_data.username if auth_handler.accounts.get(username) != form_data.password: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials" - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials") # Regular user login - user_token = auth_handler.create_token( - username=username, role="user", metadata={"auth_mode": "enabled"} - ) + user_token = auth_handler.create_token(username=username, role="user", metadata={"auth_mode": "enabled"}) return { "access_token": user_token, "token_type": "bearer", @@ -579,12 +607,8 @@ def create_app(args): "max_graph_nodes": args.max_graph_nodes, # Rerank configuration (based on whether rerank model is configured) "enable_rerank": rerank_model_func is not None, - "rerank_model": args.rerank_model - if rerank_model_func is not None - else None, - "rerank_binding_host": args.rerank_binding_host - if rerank_model_func is not None - else None, + "rerank_model": args.rerank_model if rerank_model_func is not None else None, + "rerank_binding_host": args.rerank_binding_host if rerank_model_func is not None else None, # Environment variable status (requested configuration) "summary_language": args.summary_language, "force_llm_summary_on_merge": args.force_llm_summary_on_merge, @@ -614,17 +638,11 @@ def create_app(args): response = await super().get_response(path, scope) if path.endswith(".html"): - response.headers["Cache-Control"] = ( - "no-cache, no-store, must-revalidate" - ) + response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate" response.headers["Pragma"] = "no-cache" response.headers["Expires"] = "0" - elif ( - "/assets/" in path - ): # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename - response.headers["Cache-Control"] = ( - "public, max-age=31536000, immutable" - ) + elif "/assets/" in path: # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename + response.headers["Cache-Control"] = "public, max-age=31536000, immutable" # Add other rules here if needed for non-HTML, non-asset files # Ensure correct Content-Type @@ -640,9 +658,7 @@ def create_app(args): static_dir.mkdir(exist_ok=True) app.mount( "/webui", - SmartStaticFiles( - directory=static_dir, html=True, check_dir=True - ), # Use SmartStaticFiles + SmartStaticFiles(directory=static_dir, html=True, check_dir=True), # Use SmartStaticFiles name="webui", ) @@ -798,9 +814,7 @@ def main(): } ) - print( - f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}" - ) + print(f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}") uvicorn.run(**uvicorn_config) diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py index 1640abbb..51c4b895 100644 --- a/lightrag/llm/bedrock.py +++ b/lightrag/llm/bedrock.py @@ -15,11 +15,25 @@ from tenacity import ( retry_if_exception_type, ) +import sys + +if sys.version_info < (3, 9): + from typing import AsyncIterator +else: + from collections.abc import AsyncIterator +from typing import Union + class BedrockError(Exception): """Generic error for issues related to Amazon Bedrock""" +def _set_env_if_present(key: str, value): + """Set environment variable only if a non-empty value is provided.""" + if value is not None and value != "": + os.environ[key] = value + + @retry( stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=60), @@ -34,17 
+48,35 @@ async def bedrock_complete_if_cache( aws_secret_access_key=None, aws_session_token=None, **kwargs, -) -> str: - os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( - "AWS_ACCESS_KEY_ID", aws_access_key_id - ) - os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( - "AWS_SECRET_ACCESS_KEY", aws_secret_access_key - ) - os.environ["AWS_SESSION_TOKEN"] = os.environ.get( - "AWS_SESSION_TOKEN", aws_session_token - ) +) -> Union[str, AsyncIterator[str]]: + # Respect existing env; only set if a non-empty value is available + access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id + secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key + session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token + _set_env_if_present("AWS_ACCESS_KEY_ID", access_key) + _set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key) + _set_env_if_present("AWS_SESSION_TOKEN", session_token) + # Region handling: prefer env, else kwarg (optional) + region = os.environ.get("AWS_REGION") or kwargs.pop("aws_region", None) kwargs.pop("hashing_kv", None) + # Capture stream flag (if provided) and remove from kwargs since it's not a Bedrock API parameter + # We'll use this to determine whether to call converse_stream or converse + stream = bool(kwargs.pop("stream", False)) + # Remove unsupported args for Bedrock Converse API + for k in [ + "response_format", + "tools", + "tool_choice", + "seed", + "presence_penalty", + "frequency_penalty", + "n", + "logprobs", + "top_logprobs", + "max_completion_tokens", + ]: + kwargs.pop(k, None) # Fix message history format messages = [] for history_message in history_messages: @@ -68,30 +100,126 @@ async def bedrock_complete_if_cache( "top_p": "topP", "stop_sequences": "stopSequences", } - if inference_params := list( - set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"]) - ): + if inference_params := list(set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])): args["inferenceConfig"] = {} for param in inference_params: - args["inferenceConfig"][inference_params_map.get(param, param)] = ( - kwargs.pop(param) - ) + args["inferenceConfig"][inference_params_map.get(param, param)] = kwargs.pop(param) - # Call model via Converse API + # Import logging for error handling + import logging + + # For streaming responses, we need a different approach to keep the connection open + if stream: + # Create a session that will be used throughout the streaming process + session = aioboto3.Session() + client = None + + # Define the generator function that will manage the client lifecycle + async def stream_generator(): + nonlocal client + + # Create the client outside the generator to ensure it stays open + client = await session.client("bedrock-runtime", region_name=region).__aenter__() + event_stream = None + iteration_started = False + + try: + # Make the API call + response = await client.converse_stream(**args, **kwargs) + event_stream = response.get("stream") + iteration_started = True + + # Process the stream + async for event in event_stream: + # Validate event structure + if not event or not isinstance(event, dict): + continue + + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"].get("delta", {}) + text = delta.get("text") + if text: + yield text + # Handle other event types that might indicate stream end + elif "messageStop" in event: + break + + except Exception as e: + # Log the specific error for debugging + logging.error(f"Bedrock streaming error: {e}") + + 
# Try to clean up resources if possible + if ( + iteration_started + and event_stream + and hasattr(event_stream, "aclose") + and callable(getattr(event_stream, "aclose", None)) + ): + try: + await event_stream.aclose() + except Exception as close_error: + logging.warning(f"Failed to close Bedrock event stream: {close_error}") + + raise BedrockError(f"Streaming error: {e}") + + finally: + # Clean up the event stream + if ( + iteration_started + and event_stream + and hasattr(event_stream, "aclose") + and callable(getattr(event_stream, "aclose", None)) + ): + try: + await event_stream.aclose() + except Exception as close_error: + logging.warning(f"Failed to close Bedrock event stream in finally block: {close_error}") + + # Clean up the client + if client: + try: + await client.__aexit__(None, None, None) + except Exception as client_close_error: + logging.warning(f"Failed to close Bedrock client: {client_close_error}") + + # Return the generator that manages its own lifecycle + return stream_generator() + + # For non-streaming responses, use the standard async context manager pattern session = aioboto3.Session() - async with session.client("bedrock-runtime") as bedrock_async_client: + async with session.client("bedrock-runtime", region_name=region) as bedrock_async_client: try: + # Use converse for non-streaming responses response = await bedrock_async_client.converse(**args, **kwargs) - except Exception as e: - raise BedrockError(e) - return response["output"]["message"]["content"][0]["text"] + # Validate response structure + if ( + not response + or "output" not in response + or "message" not in response["output"] + or "content" not in response["output"]["message"] + or not response["output"]["message"]["content"] + ): + raise BedrockError("Invalid response structure from Bedrock API") + + content = response["output"]["message"]["content"][0]["text"] + + if not content or content.strip() == "": + raise BedrockError("Received empty content from Bedrock API") + + return content + + except Exception as e: + if isinstance(e, BedrockError): + raise + else: + raise BedrockError(f"Bedrock API error: {e}") # Generic Bedrock completion function async def bedrock_complete( prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs -) -> str: +) -> Union[str, AsyncIterator[str]]: kwargs.pop("keyword_extraction", None) model_name = kwargs["hashing_kv"].global_config["llm_model_name"] result = await bedrock_complete_if_cache( @@ -117,18 +245,19 @@ async def bedrock_embed( aws_secret_access_key=None, aws_session_token=None, ) -> np.ndarray: - os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( - "AWS_ACCESS_KEY_ID", aws_access_key_id - ) - os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( - "AWS_SECRET_ACCESS_KEY", aws_secret_access_key - ) - os.environ["AWS_SESSION_TOKEN"] = os.environ.get( - "AWS_SESSION_TOKEN", aws_session_token - ) + # Respect existing env; only set if a non-empty value is available + access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id + secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key + session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token + _set_env_if_present("AWS_ACCESS_KEY_ID", access_key) + _set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key) + _set_env_if_present("AWS_SESSION_TOKEN", session_token) + + # Region handling: prefer env + region = os.environ.get("AWS_REGION") session = aioboto3.Session() - async with session.client("bedrock-runtime") as bedrock_async_client: + async 
with session.client("bedrock-runtime", region_name=region) as bedrock_async_client: if (model_provider := model.split(".")[0]) == "amazon": embed_texts = [] for text in texts: @@ -156,9 +285,7 @@ async def bedrock_embed( embed_texts.append(response_body["embedding"]) elif model_provider == "cohere": - body = json.dumps( - {"texts": texts, "input_type": "search_document", "truncate": "NONE"} - ) + body = json.dumps({"texts": texts, "input_type": "search_document", "truncate": "NONE"}) response = await bedrock_async_client.invoke_model( model=model, From f7ca9ae16a26a81d9b4b188f7a4a1e071b6419a1 Mon Sep 17 00:00:00 2001 From: SJ Date: Fri, 15 Aug 2025 22:21:34 +0000 Subject: [PATCH 02/15] Ruff formatted --- lightrag/api/config.py | 77 +++++++++++++++++++++++++-------- lightrag/api/lightrag_server.py | 75 ++++++++++++++++++++++++-------- lightrag/llm/bedrock.py | 36 +++++++++++---- 3 files changed, 142 insertions(+), 46 deletions(-) diff --git a/lightrag/api/config.py b/lightrag/api/config.py index 0a16e89b..01d0dd75 100644 --- a/lightrag/api/config.py +++ b/lightrag/api/config.py @@ -77,7 +77,9 @@ def parse_args() -> argparse.Namespace: argparse.Namespace: Parsed arguments """ - parser = argparse.ArgumentParser(description="LightRAG FastAPI Server with separate working and input directories") + parser = argparse.ArgumentParser( + description="LightRAG FastAPI Server with separate working and input directories" + ) # Server configuration parser.add_argument( @@ -207,7 +209,14 @@ def parse_args() -> argparse.Namespace: "--llm-binding", type=str, default=get_env_value("LLM_BINDING", "ollama"), - choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai", "aws_bedrock"], + choices=[ + "lollms", + "ollama", + "openai", + "openai-ollama", + "azure_openai", + "aws_bedrock", + ], help="LLM binding type (default: from env or ollama)", ) parser.add_argument( @@ -270,10 +279,18 @@ def parse_args() -> argparse.Namespace: args.input_dir = os.path.abspath(args.input_dir) # Inject storage configuration from environment variables - args.kv_storage = get_env_value("LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE) - args.doc_status_storage = get_env_value("LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE) - args.graph_storage = get_env_value("LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE) - args.vector_storage = get_env_value("LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE) + args.kv_storage = get_env_value( + "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE + ) + args.doc_status_storage = get_env_value( + "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE + ) + args.graph_storage = get_env_value( + "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE + ) + args.vector_storage = get_env_value( + "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE + ) # Get MAX_PARALLEL_INSERT from environment args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int) @@ -289,8 +306,12 @@ def parse_args() -> argparse.Namespace: # Ollama ctx_num args.ollama_num_ctx = get_env_value("OLLAMA_NUM_CTX", 32768, int) - args.llm_binding_host = get_env_value("LLM_BINDING_HOST", get_default_host(args.llm_binding)) - args.embedding_binding_host = get_env_value("EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)) + args.llm_binding_host = get_env_value( + "LLM_BINDING_HOST", get_default_host(args.llm_binding) + ) + args.embedding_binding_host = get_env_value( + 
"EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding) + ) args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None) args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "") @@ -304,7 +325,9 @@ def parse_args() -> argparse.Namespace: args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int) # Inject LLM cache configuration - args.enable_llm_cache_for_extract = get_env_value("ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool) + args.enable_llm_cache_for_extract = get_env_value( + "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool + ) args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool) # Handle Ollama LLM temperature with priority cascade when llm-binding is ollama @@ -354,24 +377,40 @@ def parse_args() -> argparse.Namespace: args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None) # Min rerank score configuration - args.min_rerank_score = get_env_value("MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float) + args.min_rerank_score = get_env_value( + "MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float + ) # Query configuration args.history_turns = get_env_value("HISTORY_TURNS", DEFAULT_HISTORY_TURNS, int) args.top_k = get_env_value("TOP_K", DEFAULT_TOP_K, int) args.chunk_top_k = get_env_value("CHUNK_TOP_K", DEFAULT_CHUNK_TOP_K, int) - args.max_entity_tokens = get_env_value("MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int) - args.max_relation_tokens = get_env_value("MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int) - args.max_total_tokens = get_env_value("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int) - args.cosine_threshold = get_env_value("COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float) - args.related_chunk_number = get_env_value("RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int) + args.max_entity_tokens = get_env_value( + "MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int + ) + args.max_relation_tokens = get_env_value( + "MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int + ) + args.max_total_tokens = get_env_value( + "MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int + ) + args.cosine_threshold = get_env_value( + "COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float + ) + args.related_chunk_number = get_env_value( + "RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int + ) # Add missing environment variables for health endpoint args.force_llm_summary_on_merge = get_env_value( "FORCE_LLM_SUMMARY_ON_MERGE", DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, int ) - args.embedding_func_max_async = get_env_value("EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int) - args.embedding_batch_num = get_env_value("EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int) + args.embedding_func_max_async = get_env_value( + "EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int + ) + args.embedding_batch_num = get_env_value( + "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int + ) ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag @@ -385,7 +424,9 @@ def update_uvicorn_mode_config(): original_workers = global_args.workers global_args.workers = 1 # Log warning directly here - logging.warning(f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1") + logging.warning( + f"In uvicorn mode, workers parameter was set to {original_workers}. 
Forcing workers=1" + ) global_args = parse_args() diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 8ee23826..92349861 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -130,7 +130,9 @@ def create_app(args): # Add SSL validation if args.ssl: if not args.ssl_certfile or not args.ssl_keyfile: - raise Exception("SSL certificate and key files must be provided when SSL is enabled") + raise Exception( + "SSL certificate and key files must be provided when SSL is enabled" + ) if not os.path.exists(args.ssl_certfile): raise Exception(f"SSL certificate file not found: {args.ssl_certfile}") if not os.path.exists(args.ssl_keyfile): @@ -189,7 +191,8 @@ def create_app(args): app_kwargs = { "title": "LightRAG Server API", "description": ( - "Providing API for LightRAG core, Web UI and Ollama Model Emulation" + "(With authentication)" + "Providing API for LightRAG core, Web UI and Ollama Model Emulation" + + "(With authentication)" if api_key else "" ), @@ -395,7 +398,9 @@ def create_app(args): if args.rerank_binding_api_key and args.rerank_binding_host: from lightrag.rerank import custom_rerank - async def server_rerank_func(query: str, documents: list, top_n: int = None, **kwargs): + async def server_rerank_func( + query: str, documents: list, top_n: int = None, **kwargs + ): """Server rerank function with configuration from environment variables""" return await custom_rerank( query=query, @@ -408,7 +413,9 @@ def create_app(args): ) rerank_model_func = server_rerank_func - logger.info(f"Rerank model configured: {args.rerank_model} (can be enabled per query)") + logger.info( + f"Rerank model configured: {args.rerank_model} (can be enabled per query)" + ) else: logger.info( "Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking." 
@@ -417,7 +424,9 @@ def create_app(args): # Create ollama_server_infos from command line arguments from lightrag.api.config import OllamaServerInfos - ollama_server_infos = OllamaServerInfos(name=args.simulated_model_name, tag=args.simulated_model_tag) + ollama_server_infos = OllamaServerInfos( + name=args.simulated_model_name, tag=args.simulated_model_tag + ) # Initialize RAG if args.llm_binding in ["lollms", "ollama", "openai", "aws_bedrock"]: @@ -430,7 +439,9 @@ def create_app(args): else ( ollama_model_complete if args.llm_binding == "ollama" - else bedrock_model_complete if args.llm_binding == "aws_bedrock" else openai_alike_model_complete + else bedrock_model_complete + if args.llm_binding == "aws_bedrock" + else openai_alike_model_complete ) ), llm_model_name=args.llm_model, @@ -453,7 +464,9 @@ def create_app(args): graph_storage=args.graph_storage, vector_storage=args.vector_storage, doc_status_storage=args.doc_status_storage, - vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold}, + vector_db_storage_cls_kwargs={ + "cosine_better_than_threshold": args.cosine_threshold + }, enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, enable_llm_cache=args.enable_llm_cache, rerank_model_func=rerank_model_func, @@ -480,7 +493,9 @@ def create_app(args): graph_storage=args.graph_storage, vector_storage=args.vector_storage, doc_status_storage=args.doc_status_storage, - vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold}, + vector_db_storage_cls_kwargs={ + "cosine_better_than_threshold": args.cosine_threshold + }, enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, enable_llm_cache=args.enable_llm_cache, rerank_model_func=rerank_model_func, @@ -516,7 +531,9 @@ def create_app(args): if not auth_handler.accounts: # Authentication not configured, return guest token - guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"}) + guest_token = auth_handler.create_token( + username="guest", role="guest", metadata={"auth_mode": "disabled"} + ) return { "auth_configured": False, "access_token": guest_token, @@ -542,7 +559,9 @@ def create_app(args): async def login(form_data: OAuth2PasswordRequestForm = Depends()): if not auth_handler.accounts: # Authentication not configured, return guest token - guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"}) + guest_token = auth_handler.create_token( + username="guest", role="guest", metadata={"auth_mode": "disabled"} + ) return { "access_token": guest_token, "token_type": "bearer", @@ -555,10 +574,14 @@ def create_app(args): } username = form_data.username if auth_handler.accounts.get(username) != form_data.password: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials" + ) # Regular user login - user_token = auth_handler.create_token(username=username, role="user", metadata={"auth_mode": "enabled"}) + user_token = auth_handler.create_token( + username=username, role="user", metadata={"auth_mode": "enabled"} + ) return { "access_token": user_token, "token_type": "bearer", @@ -607,8 +630,12 @@ def create_app(args): "max_graph_nodes": args.max_graph_nodes, # Rerank configuration (based on whether rerank model is configured) "enable_rerank": rerank_model_func is not None, - "rerank_model": args.rerank_model if 
rerank_model_func is not None else None, - "rerank_binding_host": args.rerank_binding_host if rerank_model_func is not None else None, + "rerank_model": args.rerank_model + if rerank_model_func is not None + else None, + "rerank_binding_host": args.rerank_binding_host + if rerank_model_func is not None + else None, # Environment variable status (requested configuration) "summary_language": args.summary_language, "force_llm_summary_on_merge": args.force_llm_summary_on_merge, @@ -638,11 +665,17 @@ def create_app(args): response = await super().get_response(path, scope) if path.endswith(".html"): - response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate" + response.headers["Cache-Control"] = ( + "no-cache, no-store, must-revalidate" + ) response.headers["Pragma"] = "no-cache" response.headers["Expires"] = "0" - elif "/assets/" in path: # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename - response.headers["Cache-Control"] = "public, max-age=31536000, immutable" + elif ( + "/assets/" in path + ): # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename + response.headers["Cache-Control"] = ( + "public, max-age=31536000, immutable" + ) # Add other rules here if needed for non-HTML, non-asset files # Ensure correct Content-Type @@ -658,7 +691,9 @@ def create_app(args): static_dir.mkdir(exist_ok=True) app.mount( "/webui", - SmartStaticFiles(directory=static_dir, html=True, check_dir=True), # Use SmartStaticFiles + SmartStaticFiles( + directory=static_dir, html=True, check_dir=True + ), # Use SmartStaticFiles name="webui", ) @@ -814,7 +849,9 @@ def main(): } ) - print(f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}") + print( + f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}" + ) uvicorn.run(**uvicorn_config) diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py index 51c4b895..69d00e2d 100644 --- a/lightrag/llm/bedrock.py +++ b/lightrag/llm/bedrock.py @@ -100,10 +100,14 @@ async def bedrock_complete_if_cache( "top_p": "topP", "stop_sequences": "stopSequences", } - if inference_params := list(set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])): + if inference_params := list( + set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"]) + ): args["inferenceConfig"] = {} for param in inference_params: - args["inferenceConfig"][inference_params_map.get(param, param)] = kwargs.pop(param) + args["inferenceConfig"][inference_params_map.get(param, param)] = ( + kwargs.pop(param) + ) # Import logging for error handling import logging @@ -119,7 +123,9 @@ async def bedrock_complete_if_cache( nonlocal client # Create the client outside the generator to ensure it stays open - client = await session.client("bedrock-runtime", region_name=region).__aenter__() + client = await session.client( + "bedrock-runtime", region_name=region + ).__aenter__() event_stream = None iteration_started = False @@ -158,7 +164,9 @@ async def bedrock_complete_if_cache( try: await event_stream.aclose() except Exception as close_error: - logging.warning(f"Failed to close Bedrock event stream: {close_error}") + logging.warning( + f"Failed to close Bedrock event stream: {close_error}" + ) raise BedrockError(f"Streaming error: {e}") @@ -173,21 +181,27 @@ async def bedrock_complete_if_cache( try: await event_stream.aclose() except Exception as close_error: - logging.warning(f"Failed to close Bedrock event stream in finally block: {close_error}") + 
logging.warning( + f"Failed to close Bedrock event stream in finally block: {close_error}" + ) # Clean up the client if client: try: await client.__aexit__(None, None, None) except Exception as client_close_error: - logging.warning(f"Failed to close Bedrock client: {client_close_error}") + logging.warning( + f"Failed to close Bedrock client: {client_close_error}" + ) # Return the generator that manages its own lifecycle return stream_generator() # For non-streaming responses, use the standard async context manager pattern session = aioboto3.Session() - async with session.client("bedrock-runtime", region_name=region) as bedrock_async_client: + async with session.client( + "bedrock-runtime", region_name=region + ) as bedrock_async_client: try: # Use converse for non-streaming responses response = await bedrock_async_client.converse(**args, **kwargs) @@ -257,7 +271,9 @@ async def bedrock_embed( region = os.environ.get("AWS_REGION") session = aioboto3.Session() - async with session.client("bedrock-runtime", region_name=region) as bedrock_async_client: + async with session.client( + "bedrock-runtime", region_name=region + ) as bedrock_async_client: if (model_provider := model.split(".")[0]) == "amazon": embed_texts = [] for text in texts: @@ -285,7 +301,9 @@ async def bedrock_embed( embed_texts.append(response_body["embedding"]) elif model_provider == "cohere": - body = json.dumps({"texts": texts, "input_type": "search_document", "truncate": "NONE"}) + body = json.dumps( + {"texts": texts, "input_type": "search_document", "truncate": "NONE"} + ) response = await bedrock_async_client.invoke_model( model=model, From 5d00c4c7a8ec1a2245830d0b23b688428d25b704 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 16 Aug 2025 13:19:20 +0800 Subject: [PATCH 03/15] feat: move processed files to __enqueued__ directory after processing with filename conflicts handling --- lightrag/api/routers/document_routes.py | 57 +++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index eaec6fbf..7ad21f45 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -746,6 +746,39 @@ class DocumentManager: return any(filename.lower().endswith(ext) for ext in self.supported_extensions) +def get_unique_filename_in_enqueued(target_dir: Path, original_name: str) -> str: + """Generate a unique filename in the target directory by adding numeric suffixes if needed + + Args: + target_dir: Target directory path + original_name: Original filename + + Returns: + str: Unique filename (may have numeric suffix added) + """ + from pathlib import Path + import time + + original_path = Path(original_name) + base_name = original_path.stem + extension = original_path.suffix + + # Try original name first + if not (target_dir / original_name).exists(): + return original_name + + # Try with numeric suffixes 001-999 + for i in range(1, 1000): + suffix = f"{i:03d}" + new_name = f"{base_name}_{suffix}{extension}" + if not (target_dir / new_name).exists(): + return new_name + + # Fallback with timestamp if all 999 slots are taken + timestamp = int(time.time()) + return f"{base_name}_{timestamp}{extension}" + + async def pipeline_enqueue_file( rag: LightRAG, file_path: Path, track_id: str = None ) -> tuple[bool, str]: @@ -939,6 +972,30 @@ async def pipeline_enqueue_file( ) logger.info(f"Successfully fetched and enqueued file: {file_path.name}") + + # Move file to __enqueued__ directory after enqueuing + try: + 
enqueued_dir = file_path.parent / "__enqueued__" + enqueued_dir.mkdir(exist_ok=True) + + # Generate unique filename to avoid conflicts + unique_filename = get_unique_filename_in_enqueued( + enqueued_dir, file_path.name + ) + target_path = enqueued_dir / unique_filename + + # Move the file + file_path.rename(target_path) + logger.info( + f"Moved file to enqueued directory: {file_path.name} -> {unique_filename}" + ) + + except Exception as move_error: + logger.error( + f"Failed to move file {file_path.name} to __enqueued__ directory: {move_error}" + ) + # Don't affect the main function's success status + return True, track_id else: logger.error(f"No content could be extracted from file: {file_path.name}") From 5591ef3ac83b82a629f4b10429c822f225c68a76 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 16 Aug 2025 17:22:08 +0800 Subject: [PATCH 04/15] Fix document filtering logic and improve logging for ignored docs --- lightrag/lightrag.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index cf2aaf19..bef965fa 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1077,19 +1077,21 @@ class LightRAG: # 4. Filter out already processed documents # Get docs ids all_new_doc_ids = set(new_docs.keys()) - # Exclude IDs of documents that are already in progress + # Exclude IDs of documents that are already enqueued unique_new_doc_ids = await self.doc_status.filter_keys(all_new_doc_ids) - # Log ignored document IDs - ignored_ids = [ - doc_id for doc_id in unique_new_doc_ids if doc_id not in new_docs - ] + # Log ignored document IDs (documents that were filtered out because they already exist) + ignored_ids = list(all_new_doc_ids - unique_new_doc_ids) if ignored_ids: - logger.warning( - f"Ignoring {len(ignored_ids)} document IDs not found in new_docs" - ) for doc_id in ignored_ids: - logger.warning(f"Ignored document ID: {doc_id}") + file_path = new_docs.get(doc_id, {}).get("file_path", "unknown_source") + logger.warning( + f"Ignoring document ID (already exists): {doc_id} ({file_path})" + ) + if len(ignored_ids) > 3: + logger.warning( + f"Ignoring a total of {len(ignored_ids)} document IDs that already exist in storage" + ) # Filter new_docs to only include documents with unique IDs new_docs = { @@ -1099,7 +1101,7 @@ } if not new_docs: - logger.info("No new unique documents were found.") + logger.warning("No new unique documents were found.") return # 5. Store document content in full_docs and status in doc_status From e1310c526246f02a293a595c43cb04bf580f2741 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 16 Aug 2025 17:23:01 +0800 Subject: [PATCH 05/15] Optimize document processing pipeline by removing duplicate step --- lightrag/lightrag.py | 60 ++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 36 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index bef965fa..8be61205 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -971,11 +971,10 @@ class LightRAG: """ Pipeline for Processing Documents - 1. Validate ids if provided or generate MD5 hash IDs - 2. Remove duplicate contents - 3. Generate document initial status - 4. Filter out already processed documents - 5. Enqueue document in status + 1. Validate ids if provided or generate MD5 hash IDs and remove duplicate contents + 2. Generate document initial status + 3. Filter out already processed documents + 4. 
Enqueue document in status Args: input: Single document string or list of document strings @@ -1008,7 +1007,7 @@ class LightRAG: # If no file paths provided, use placeholder file_paths = ["unknown_source"] * len(input) - # 1. Validate ids if provided or generate MD5 hash IDs + # 1. Validate ids if provided or generate MD5 hash IDs and remove duplicate contents if ids is not None: # Check if the number of IDs matches the number of documents if len(ids) != len(input): @@ -1018,22 +1017,25 @@ class LightRAG: if len(ids) != len(set(ids)): raise ValueError("IDs must be unique") - # Generate contents dict of IDs provided by user and documents + # Generate contents dict and remove duplicates in one pass + unique_contents = {} + for id_, doc, path in zip(ids, input, file_paths): + cleaned_content = clean_text(doc) + if cleaned_content not in unique_contents: + unique_contents[cleaned_content] = (id_, path) + + # Reconstruct contents with unique content contents = { - id_: {"content": doc, "file_path": path} - for id_, doc, path in zip(ids, input, file_paths) + id_: {"content": content, "file_path": file_path} + for content, (id_, file_path) in unique_contents.items() } else: - # Clean input text and remove duplicates - cleaned_input = [ - (clean_text(doc), path) for doc, path in zip(input, file_paths) - ] + # Clean input text and remove duplicates in one pass unique_content_with_paths = {} - - # Keep track of unique content and their paths - for content, path in cleaned_input: - if content not in unique_content_with_paths: - unique_content_with_paths[content] = path + for doc, path in zip(input, file_paths): + cleaned_content = clean_text(doc) + if cleaned_content not in unique_content_with_paths: + unique_content_with_paths[cleaned_content] = path # Generate contents dict of MD5 hash IDs and documents with paths contents = { @@ -1044,21 +1046,7 @@ class LightRAG: for content, path in unique_content_with_paths.items() } - # 2. Remove duplicate contents - unique_contents = {} - for id_, content_data in contents.items(): - content = content_data["content"] - file_path = content_data["file_path"] - if content not in unique_contents: - unique_contents[content] = (id_, file_path) - - # Reconstruct contents with unique content - contents = { - id_: {"content": content, "file_path": file_path} - for content, (id_, file_path) in unique_contents.items() - } - - # 3. Generate document initial status (without content) + # 2. Generate document initial status (without content) new_docs: dict[str, Any] = { id_: { "status": DocStatus.PENDING, @@ -1074,7 +1062,7 @@ class LightRAG: for id_, content_data in contents.items() } - # 4. Filter out already processed documents + # 3. Filter out already processed documents # Get docs ids all_new_doc_ids = set(new_docs.keys()) # Exclude IDs of documents that are already enqueued @@ -1104,8 +1092,8 @@ class LightRAG: logger.warning("No new unique documents were found.") return - # 5. Store document content in full_docs and status in doc_status - # Store full document content separately + # 4. 
Store document content in full_docs and status in doc_status + # Store full document content separately full_docs_data = { doc_id: {"content": contents[doc_id]["content"]} for doc_id in new_docs.keys() From ca4c18baaa9093aab54714587630fefd7af1197c Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 16 Aug 2025 22:29:46 +0800 Subject: [PATCH 06/15] Preserve failed documents during data consistency validation for manual review --- lightrag/lightrag.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 8be61205..b659c70f 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1114,17 +1114,37 @@ class LightRAG: pipeline_status: dict, pipeline_status_lock: asyncio.Lock, ) -> dict[str, DocProcessingStatus]: - """Validate and fix document data consistency by deleting inconsistent entries""" + """Validate and fix document data consistency by deleting inconsistent entries, but preserve failed documents""" inconsistent_docs = [] + failed_docs_to_preserve = [] # Check each document's data consistency for doc_id, status_doc in to_process_docs.items(): # Check if corresponding content exists in full_docs content_data = await self.full_docs.get_by_id(doc_id) if not content_data: - inconsistent_docs.append(doc_id) + # Check if this is a failed document that should be preserved + if ( + hasattr(status_doc, "status") + and status_doc.status == DocStatus.FAILED + ): + failed_docs_to_preserve.append(doc_id) + else: + inconsistent_docs.append(doc_id) - # Delete inconsistent document entries one by one + # Log information about failed documents that will be preserved + if failed_docs_to_preserve: + async with pipeline_status_lock: + preserve_message = f"Preserving {len(failed_docs_to_preserve)} failed document entries for manual review" + logger.info(preserve_message) + pipeline_status["latest_message"] = preserve_message + pipeline_status["history_messages"].append(preserve_message) + + # Remove failed documents from processing list but keep them in doc_status + for doc_id in failed_docs_to_preserve: + to_process_docs.pop(doc_id, None) + + # Delete inconsistent document entries(excluding failed documents) if inconsistent_docs: async with pipeline_status_lock: summary_message = ( @@ -1146,7 +1166,9 @@ class LightRAG: # Log successful deletion async with pipeline_status_lock: - log_message = f"Deleted entry: {doc_id} ({file_path})" + log_message = ( + f"Deleted inconsistent entry: {doc_id} ({file_path})" + ) logger.info(log_message) pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) @@ -1164,7 +1186,7 @@ class LightRAG: # Final summary log async with pipeline_status_lock: - final_message = f"Data consistency cleanup completed: successfully deleted {successful_deletions} entries" + final_message = f"Data consistency cleanup completed: successfully deleted {successful_deletions} inconsistent entries, preserved {len(failed_docs_to_preserve)} failed documents" logger.info(final_message) pipeline_status["latest_message"] = final_message pipeline_status["history_messages"].append(final_message) From f5b0c3d38c6c7dbbabe1c079d3a7fda5202caa90 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 16 Aug 2025 23:08:52 +0800 Subject: [PATCH 07/15] feat: Recording file extraction error status to document pipeline - Add apipeline_enqueue_error_documents function to LightRAG class for recording file processing errors in doc_status storage - Enhance pipeline_enqueue_file with 
detailed error handling for all file processing stages: * File access errors (permissions, not found) * UTF-8 encoding errors * Format-specific processing errors (PDF, DOCX, PPTX, XLSX) * Content validation errors * Unsupported file type errors This implementation ensures all file extraction failures are properly tracked and recorded in the doc_status storage system, providing better visibility into document processing issues and enabling improved error monitoring and debugging capabilities. --- lightrag/api/routers/document_routes.py | 589 ++++++++++++++++-------- lightrag/lightrag.py | 77 ++++ 2 files changed, 481 insertions(+), 185 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 7ad21f45..7a6e5973 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -792,225 +792,444 @@ async def pipeline_enqueue_file( tuple: (success: bool, track_id: str) """ + # Generate track_id if not provided + if track_id is None: + track_id = generate_track_id("unknown") + try: content = "" ext = file_path.suffix.lower() + file_size = 0 + + # Get file size for error reporting + try: + file_size = file_path.stat().st_size + except Exception: + file_size = 0 file = None - async with aiofiles.open(file_path, "rb") as f: - file = await f.read() + try: + async with aiofiles.open(file_path, "rb") as f: + file = await f.read() + except PermissionError as e: + error_files = [ + { + "file_path": str(file_path.name), + "error_description": "Permission denied - cannot read file", + "original_error": str(e), + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents(error_files, track_id) + logger.error(f"Permission denied reading file: {file_path.name}") + return False, track_id + except FileNotFoundError as e: + error_files = [ + { + "file_path": str(file_path.name), + "error_description": "File not found", + "original_error": str(e), + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents(error_files, track_id) + logger.error(f"File not found: {file_path.name}") + return False, track_id + except Exception as e: + error_files = [ + { + "file_path": str(file_path.name), + "error_description": "File reading error", + "original_error": str(e), + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents(error_files, track_id) + logger.error(f"Error reading file {file_path.name}: {str(e)}") + return False, track_id # Process based on file type - match ext: - case ( - ".txt" - | ".md" - | ".html" - | ".htm" - | ".tex" - | ".json" - | ".xml" - | ".yaml" - | ".yml" - | ".rtf" - | ".odt" - | ".epub" - | ".csv" - | ".log" - | ".conf" - | ".ini" - | ".properties" - | ".sql" - | ".bat" - | ".sh" - | ".c" - | ".cpp" - | ".py" - | ".java" - | ".js" - | ".ts" - | ".swift" - | ".go" - | ".rb" - | ".php" - | ".css" - | ".scss" - | ".less" - ): - try: - # Try to decode as UTF-8 - content = file.decode("utf-8") + try: + match ext: + case ( + ".txt" + | ".md" + | ".html" + | ".htm" + | ".tex" + | ".json" + | ".xml" + | ".yaml" + | ".yml" + | ".rtf" + | ".odt" + | ".epub" + | ".csv" + | ".log" + | ".conf" + | ".ini" + | ".properties" + | ".sql" + | ".bat" + | ".sh" + | ".c" + | ".cpp" + | ".py" + | ".java" + | ".js" + | ".ts" + | ".swift" + | ".go" + | ".rb" + | ".php" + | ".css" + | ".scss" + | ".less" + ): + try: + # Try to decode as UTF-8 + content = file.decode("utf-8") - # Validate content - if not content or len(content.strip()) == 0: - logger.error(f"Empty 
content in file: {file_path.name}") - return False, "" - - # Check if content looks like binary data string representation - if content.startswith("b'") or content.startswith('b"'): - logger.error( - f"File {file_path.name} appears to contain binary data representation instead of text" - ) - return False, "" - - except UnicodeDecodeError: - logger.error( - f"File {file_path.name} is not valid UTF-8 encoded text. Please convert it to UTF-8 before processing." - ) - return False, "" - case ".pdf": - if global_args.document_loading_engine == "DOCLING": - if not pm.is_installed("docling"): # type: ignore - pm.install("docling") - from docling.document_converter import DocumentConverter # type: ignore - - converter = DocumentConverter() - result = converter.convert(file_path) - content = result.document.export_to_markdown() - else: - if not pm.is_installed("pypdf2"): # type: ignore - pm.install("pypdf2") - from PyPDF2 import PdfReader # type: ignore - from io import BytesIO - - pdf_file = BytesIO(file) - reader = PdfReader(pdf_file) - for page in reader.pages: - content += page.extract_text() + "\n" - case ".docx": - if global_args.document_loading_engine == "DOCLING": - if not pm.is_installed("docling"): # type: ignore - pm.install("docling") - from docling.document_converter import DocumentConverter # type: ignore - - converter = DocumentConverter() - result = converter.convert(file_path) - content = result.document.export_to_markdown() - else: - if not pm.is_installed("python-docx"): # type: ignore - try: - pm.install("python-docx") - except Exception: - pm.install("docx") - from docx import Document # type: ignore - from io import BytesIO - - docx_file = BytesIO(file) - doc = Document(docx_file) - content = "\n".join( - [paragraph.text for paragraph in doc.paragraphs] - ) - case ".pptx": - if global_args.document_loading_engine == "DOCLING": - if not pm.is_installed("docling"): # type: ignore - pm.install("docling") - from docling.document_converter import DocumentConverter # type: ignore - - converter = DocumentConverter() - result = converter.convert(file_path) - content = result.document.export_to_markdown() - else: - if not pm.is_installed("python-pptx"): # type: ignore - pm.install("pptx") - from pptx import Presentation # type: ignore - from io import BytesIO - - pptx_file = BytesIO(file) - prs = Presentation(pptx_file) - for slide in prs.slides: - for shape in slide.shapes: - if hasattr(shape, "text"): - content += shape.text + "\n" - case ".xlsx": - if global_args.document_loading_engine == "DOCLING": - if not pm.is_installed("docling"): # type: ignore - pm.install("docling") - from docling.document_converter import DocumentConverter # type: ignore - - converter = DocumentConverter() - result = converter.convert(file_path) - content = result.document.export_to_markdown() - else: - if not pm.is_installed("openpyxl"): # type: ignore - pm.install("openpyxl") - from openpyxl import load_workbook # type: ignore - from io import BytesIO - - xlsx_file = BytesIO(file) - wb = load_workbook(xlsx_file) - for sheet in wb: - content += f"Sheet: {sheet.title}\n" - for row in sheet.iter_rows(values_only=True): - content += ( - "\t".join( - str(cell) if cell is not None else "" - for cell in row - ) - + "\n" + # Validate content + if not content or len(content.strip()) == 0: + error_files = [ + { + "file_path": str(file_path.name), + "error_description": "Empty file content", + "original_error": "File contains no content or only whitespace", + "file_size": file_size, + } + ] + await 
rag.apipeline_enqueue_error_documents( + error_files, track_id ) - content += "\n" - case _: - logger.error( - f"Unsupported file type: {file_path.name} (extension {ext})" - ) - return False, "" + logger.error(f"Empty content in file: {file_path.name}") + return False, track_id + + # Check if content looks like binary data string representation + if content.startswith("b'") or content.startswith('b"'): + error_files = [ + { + "file_path": str(file_path.name), + "error_description": "Binary data in text file", + "original_error": "File appears to contain binary data representation instead of text", + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents( + error_files, track_id + ) + logger.error( + f"File {file_path.name} appears to contain binary data representation instead of text" + ) + return False, track_id + + except UnicodeDecodeError as e: + error_files = [ + { + "file_path": str(file_path.name), + "error_description": "UTF-8 encoding error", + "original_error": f"File is not valid UTF-8 encoded text: {str(e)}", + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents( + error_files, track_id + ) + logger.error( + f"File {file_path.name} is not valid UTF-8 encoded text. Please convert it to UTF-8 before processing." + ) + return False, track_id + + case ".pdf": + try: + if global_args.document_loading_engine == "DOCLING": + if not pm.is_installed("docling"): # type: ignore + pm.install("docling") + from docling.document_converter import DocumentConverter # type: ignore + + converter = DocumentConverter() + result = converter.convert(file_path) + content = result.document.export_to_markdown() + else: + if not pm.is_installed("pypdf2"): # type: ignore + pm.install("pypdf2") + from PyPDF2 import PdfReader # type: ignore + from io import BytesIO + + pdf_file = BytesIO(file) + reader = PdfReader(pdf_file) + for page in reader.pages: + content += page.extract_text() + "\n" + except Exception as e: + error_files = [ + { + "file_path": str(file_path.name), + "error_description": "PDF processing error", + "original_error": f"Failed to extract text from PDF: {str(e)}", + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents( + error_files, track_id + ) + logger.error(f"Error processing PDF {file_path.name}: {str(e)}") + return False, track_id + + case ".docx": + try: + if global_args.document_loading_engine == "DOCLING": + if not pm.is_installed("docling"): # type: ignore + pm.install("docling") + from docling.document_converter import DocumentConverter # type: ignore + + converter = DocumentConverter() + result = converter.convert(file_path) + content = result.document.export_to_markdown() + else: + if not pm.is_installed("python-docx"): # type: ignore + try: + pm.install("python-docx") + except Exception: + pm.install("docx") + from docx import Document # type: ignore + from io import BytesIO + + docx_file = BytesIO(file) + doc = Document(docx_file) + content = "\n".join( + [paragraph.text for paragraph in doc.paragraphs] + ) + except Exception as e: + error_files = [ + { + "file_path": str(file_path.name), + "error_description": "DOCX processing error", + "original_error": f"Failed to extract text from DOCX: {str(e)}", + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents( + error_files, track_id + ) + logger.error( + f"Error processing DOCX {file_path.name}: {str(e)}" + ) + return False, track_id + + case ".pptx": + try: + if global_args.document_loading_engine == "DOCLING": + if not 
pm.is_installed("docling"): # type: ignore + pm.install("docling") + from docling.document_converter import DocumentConverter # type: ignore + + converter = DocumentConverter() + result = converter.convert(file_path) + content = result.document.export_to_markdown() + else: + if not pm.is_installed("python-pptx"): # type: ignore + pm.install("pptx") + from pptx import Presentation # type: ignore + from io import BytesIO + + pptx_file = BytesIO(file) + prs = Presentation(pptx_file) + for slide in prs.slides: + for shape in slide.shapes: + if hasattr(shape, "text"): + content += shape.text + "\n" + except Exception as e: + error_files = [ + { + "file_path": str(file_path.name), + "error_description": "PPTX processing error", + "original_error": f"Failed to extract text from PPTX: {str(e)}", + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents( + error_files, track_id + ) + logger.error( + f"Error processing PPTX {file_path.name}: {str(e)}" + ) + return False, track_id + + case ".xlsx": + try: + if global_args.document_loading_engine == "DOCLING": + if not pm.is_installed("docling"): # type: ignore + pm.install("docling") + from docling.document_converter import DocumentConverter # type: ignore + + converter = DocumentConverter() + result = converter.convert(file_path) + content = result.document.export_to_markdown() + else: + if not pm.is_installed("openpyxl"): # type: ignore + pm.install("openpyxl") + from openpyxl import load_workbook # type: ignore + from io import BytesIO + + xlsx_file = BytesIO(file) + wb = load_workbook(xlsx_file) + for sheet in wb: + content += f"Sheet: {sheet.title}\n" + for row in sheet.iter_rows(values_only=True): + content += ( + "\t".join( + str(cell) if cell is not None else "" + for cell in row + ) + + "\n" + ) + content += "\n" + except Exception as e: + error_files = [ + { + "file_path": str(file_path.name), + "error_description": "XLSX processing error", + "original_error": f"Failed to extract text from XLSX: {str(e)}", + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents( + error_files, track_id + ) + logger.error( + f"Error processing XLSX {file_path.name}: {str(e)}" + ) + return False, track_id + + case _: + error_files = [ + { + "file_path": str(file_path.name), + "error_description": f"Unsupported file type: {ext}", + "original_error": f"File extension {ext} is not supported", + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents(error_files, track_id) + logger.error( + f"Unsupported file type: {file_path.name} (extension {ext})" + ) + return False, track_id + + except Exception as e: + error_files = [ + { + "file_path": str(file_path.name), + "error_description": "File format processing error", + "original_error": f"Unexpected error during file processing: {str(e)}", + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents(error_files, track_id) + logger.error(f"Unexpected error processing file {file_path.name}: {str(e)}") + return False, track_id # Insert into the RAG queue if content: # Check if content contains only whitespace characters if not content.strip(): + error_files = [ + { + "file_path": str(file_path.name), + "error_description": "File contains only whitespace", + "original_error": "File content contains only whitespace characters", + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents(error_files, track_id) logger.warning( f"File contains only whitespace characters. 
file_paths={file_path.name}" ) + return False, track_id - # Generate track_id if not provided - if track_id is None: - track_id = generate_track_id("unkown") - - await rag.apipeline_enqueue_documents( - content, file_paths=file_path.name, track_id=track_id - ) - - logger.info(f"Successfully fetched and enqueued file: {file_path.name}") - - # Move file to __enqueued__ directory after enqueuing try: - enqueued_dir = file_path.parent / "__enqueued__" - enqueued_dir.mkdir(exist_ok=True) - - # Generate unique filename to avoid conflicts - unique_filename = get_unique_filename_in_enqueued( - enqueued_dir, file_path.name - ) - target_path = enqueued_dir / unique_filename - - # Move the file - file_path.rename(target_path) - logger.info( - f"Moved file to enqueued directory: {file_path.name} -> {unique_filename}" + await rag.apipeline_enqueue_documents( + content, file_paths=file_path.name, track_id=track_id ) - except Exception as move_error: - logger.error( - f"Failed to move file {file_path.name} to __enqueued__ directory: {move_error}" - ) - # Don't affect the main function's success status + logger.info(f"Successfully fetched and enqueued file: {file_path.name}") - return True, track_id + # Move file to __enqueued__ directory after enqueuing + try: + enqueued_dir = file_path.parent / "__enqueued__" + enqueued_dir.mkdir(exist_ok=True) + + # Generate unique filename to avoid conflicts + unique_filename = get_unique_filename_in_enqueued( + enqueued_dir, file_path.name + ) + target_path = enqueued_dir / unique_filename + + # Move the file + file_path.rename(target_path) + logger.info( + f"Moved file to enqueued directory: {file_path.name} -> {unique_filename}" + ) + + except Exception as move_error: + logger.error( + f"Failed to move file {file_path.name} to __enqueued__ directory: {move_error}" + ) + # Don't affect the main function's success status + + return True, track_id + + except Exception as e: + error_files = [ + { + "file_path": str(file_path.name), + "error_description": "Document enqueue error", + "original_error": f"Failed to enqueue document: {str(e)}", + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents(error_files, track_id) + logger.error(f"Error enqueueing document {file_path.name}: {str(e)}") + return False, track_id else: + error_files = [ + { + "file_path": str(file_path.name), + "error_description": "No content extracted", + "original_error": "No content could be extracted from file", + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents(error_files, track_id) logger.error(f"No content could be extracted from file: {file_path.name}") - return False, "" + return False, track_id except Exception as e: + # Catch-all for any unexpected errors + try: + file_size = file_path.stat().st_size if file_path.exists() else 0 + except Exception: + file_size = 0 + + error_files = [ + { + "file_path": str(file_path.name), + "error_description": "Unexpected processing error", + "original_error": f"Unexpected error: {str(e)}", + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents(error_files, track_id) logger.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}") logger.error(traceback.format_exc()) + return False, track_id finally: if file_path.name.startswith(temp_prefix): try: file_path.unlink() except Exception as e: logger.error(f"Error deleting file {file_path}: {str(e)}") - return False, "" async def pipeline_index_file(rag: LightRAG, file_path: Path, track_id: str = None): diff --git 
a/lightrag/lightrag.py b/lightrag/lightrag.py
index b659c70f..be5f2687 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1108,6 +1108,83 @@ class LightRAG:

         return track_id

+    async def apipeline_enqueue_error_documents(
+        self,
+        error_files: list[dict[str, Any]],
+        track_id: str | None = None,
+    ) -> None:
+        """
+        Record file extraction errors in doc_status storage.
+
+        This function creates error document entries in the doc_status storage for files
+        that failed during the extraction process. Each error entry contains information
+        about the failure to help with debugging and monitoring.
+
+        Args:
+            error_files: List of dictionaries containing error information for each failed file.
+                Each dictionary should contain:
+                - file_path: Original file name/path
+                - error_description: Brief error description (for content_summary)
+                - original_error: Full error message (for error_msg)
+                - file_size: File size in bytes (for content_length, 0 if unknown)
+            track_id: Optional tracking ID for grouping related operations
+
+        Returns:
+            None
+        """
+        if not error_files:
+            logger.debug("No error files to record")
+            return
+
+        # Generate track_id if not provided
+        if track_id is None or track_id.strip() == "":
+            track_id = generate_track_id("error")
+
+        error_docs: dict[str, Any] = {}
+        current_time = datetime.now(timezone.utc).isoformat()
+
+        for error_file in error_files:
+            file_path = error_file.get("file_path", "unknown_file")
+            error_description = error_file.get(
+                "error_description", "File extraction failed"
+            )
+            original_error = error_file.get("original_error", "Unknown error")
+            file_size = error_file.get("file_size", 0)
+
+            # Generate unique doc_id with "error-" prefix
+            doc_id_content = f"{file_path}-{error_description}"
+            doc_id = compute_mdhash_id(doc_id_content, prefix="error-")
+
+            error_docs[doc_id] = {
+                "status": DocStatus.FAILED,
+                "content_summary": error_description,
+                "content_length": file_size,
+                "error_msg": original_error,
+                "chunks_count": 0,  # No chunks for failed files
+                "created_at": current_time,
+                "updated_at": current_time,
+                "file_path": file_path,
+                "track_id": track_id,
+                "metadata": {
+                    "error_type": "file_extraction_error",
+                },
+            }
+
+        # Store error documents in doc_status
+        if error_docs:
+            await self.doc_status.upsert(error_docs)
+            logger.info(
+                f"Recorded {len(error_docs)} file extraction errors in doc_status"
+            )
+
+            # Log each error for debugging
+            for doc_id, error_doc in error_docs.items():
+                logger.error(
+                    f"File extraction error recorded - ID: {doc_id}, "
+                    f"File: {error_doc['file_path']}, "
+                    f"Error: {error_doc['content_summary']}"
+                )
+
     async def _validate_and_fix_document_consistency(
         self,
         to_process_docs: dict[str, DocProcessingStatus],
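For reference, a minimal sketch of how a caller records an extraction failure through this new API (the file name, size, and track_id here are illustrative values, not taken from the patches):

    # Assumes `rag` is an initialized LightRAG instance.
    error_files = [
        {
            "file_path": "report.pdf",                    # original file name
            "error_description": "PDF processing error",  # stored as content_summary
            "original_error": "Failed to extract text from PDF: EOF marker not found",
            "file_size": 104857,                          # bytes; 0 if unknown
        }
    ]
    await rag.apipeline_enqueue_error_documents(error_files, track_id="upload_demo_001")

Each entry is upserted into doc_status as a DocStatus.FAILED record with an "error-" prefixed doc_id, so failed files show up alongside successfully enqueued documents and, per PATCH 06, survive the data consistency cleanup.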
From cceb46b3209a805d251f18fda1d44105ba1366ac Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 16 Aug 2025 23:46:33 +0800
Subject: [PATCH 08/15] fix: subdirectories are no longer processed during file scans
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

• Change rglob to glob for file scanning
• Simplify error logging messages
---
 lightrag/api/routers/document_routes.py | 16 +++++++++-------
 lightrag/lightrag.py                    |  8 +-------
 2 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py
index 7a6e5973..e3477759 100644
--- a/lightrag/api/routers/document_routes.py
+++ b/lightrag/api/routers/document_routes.py
@@ -734,7 +734,7 @@ class DocumentManager:
         new_files = []
         for ext in self.supported_extensions:
             logger.debug(f"Scanning for {ext} files in {self.input_dir}")
-            for file_path in self.input_dir.rglob(f"*{ext}"):
+            for file_path in self.input_dir.glob(f"*{ext}"):
                 if file_path not in self.indexed_files:
                     new_files.append(file_path)
         return new_files
@@ -1122,12 +1122,14 @@ async def pipeline_enqueue_file(
                 {
                     "file_path": str(file_path.name),
                     "error_description": "File format processing error",
-                    "original_error": f"Unexpected error during file processing: {str(e)}",
+                    "original_error": f"Unexpected error during file extraction: {str(e)}",
                     "file_size": file_size,
                 }
             ]
             await rag.apipeline_enqueue_error_documents(error_files, track_id)
-            logger.error(f"Unexpected error processing file {file_path.name}: {str(e)}")
+            logger.error(
+                f"Unexpected error extracting {file_path.name}: {str(e)}"
+            )
             return False, track_id

         # Insert into the RAG queue
@@ -1144,7 +1146,7 @@ async def pipeline_enqueue_file(
                 ]
                 await rag.apipeline_enqueue_error_documents(error_files, track_id)
                 logger.warning(
-                    f"File contains only whitespace characters. file_paths={file_path.name}"
+                    f"File contains only whitespace characters: {file_path.name}"
                 )
                 return False, track_id

@@ -1168,7 +1170,7 @@ async def pipeline_enqueue_file(

                     # Move the file
                     file_path.rename(target_path)
-                    logger.info(
+                    logger.debug(
                         f"Moved file to enqueued directory: {file_path.name} -> {unique_filename}"
                     )

@@ -1202,7 +1204,7 @@ async def pipeline_enqueue_file(
                 }
             ]
             await rag.apipeline_enqueue_error_documents(error_files, track_id)
-            logger.error(f"No content could be extracted from file: {file_path.name}")
+            logger.error(f"No content extracted from file: {file_path.name}")
             return False, track_id

     except Exception as e:
@@ -1221,7 +1223,7 @@ async def pipeline_enqueue_file(
             }
         ]
         await rag.apipeline_enqueue_error_documents(error_files, track_id)
-        logger.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}")
+        logger.error(f"Error enqueuing file {file_path.name}: {str(e)}")
         logger.error(traceback.format_exc())
         return False, track_id
     finally:

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index be5f2687..d2a8ff46 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1173,16 +1173,10 @@ class LightRAG:
         # Store error documents in doc_status
         if error_docs:
             await self.doc_status.upsert(error_docs)
-            logger.info(
-                f"Recorded {len(error_docs)} file extraction errors in doc_status"
-            )
-
             # Log each error for debugging
             for doc_id, error_doc in error_docs.items():
                 logger.error(
-                    f"File extraction error recorded - ID: {doc_id}, "
-                    f"File: {error_doc['file_path']}, "
-                    f"Error: {error_doc['content_summary']}"
+                    f"File processing error - ID: {doc_id}, File: {error_doc['file_path']}"
                 )

     async def _validate_and_fix_document_consistency(

From 45365ff6efa3b33da355284c60ffa7ade5fcfbf1 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 16 Aug 2025 23:53:01 +0800
Subject: [PATCH 09/15] Bump api version to 0202

---
 lightrag/api/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightrag/api/__init__.py b/lightrag/api/__init__.py
index a2433058..700baa24 100644
--- a/lightrag/api/__init__.py
+++ b/lightrag/api/__init__.py
@@ -1 +1 @@
-__api_version__ = "0201"
+__api_version__ = "0202"

From e064534941dcca1fbf415d7ce8d2dc52e5bf3b5c Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 17 Aug 2025 01:33:39 +0800
Subject: [PATCH 10/15] feat(ui): enhance ClearDocumentsDialog with loading spinner and timeout protection

- Add loading spinner animation during document clearing operation
- Implement 30-second timeout protection to prevent hanging operations
- Disable all interactive controls during clearing to prevent duplicate requests
- Add comprehensive error handling with automatic state reset
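The timeout protection in the diff below boils down to one pattern: arm a timer before the async call, and clear it in a finally block so a completed (or failed) request can never fire the timeout handler. A minimal sketch in TypeScript, with illustrative names that are not part of the component:

    // Arms a watchdog timer around an async action; whichever finishes first wins.
    async function runWithTimeout(
      action: () => Promise<void>,   // e.g. the clearDocuments() API call
      onTimeout: () => void,         // e.g. show a toast and reset dialog state
      ms: number                     // e.g. CLEAR_TIMEOUT = 30000
    ): Promise<void> {
      const timer = setTimeout(onTimeout, ms)
      try {
        await action()
      } finally {
        clearTimeout(timer)          // settled calls never trigger onTimeout
      }
    }

One React-specific caveat: the timeout callback should not consult component state captured at call time (such as an isClearing flag), because that value is frozen in the closure. The component therefore holds the timer in a ref and resets state directly.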
---
 .../documents/ClearDocumentsDialog.tsx        | 67 +++++++++++++++++--
 lightrag_webui/src/locales/ar.json            |  2 +
 lightrag_webui/src/locales/en.json            |  2 +
 lightrag_webui/src/locales/fr.json            |  2 +
 lightrag_webui/src/locales/zh.json            |  2 +
 lightrag_webui/src/locales/zh_TW.json         |  2 +
 6 files changed, 70 insertions(+), 7 deletions(-)

diff --git a/lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx b/lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx
index b8438011..e42fbbfa 100644
--- a/lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx
+++ b/lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx
@@ -1,4 +1,4 @@
-import { useState, useCallback, useEffect } from 'react'
+import { useState, useCallback, useEffect, useRef } from 'react'
 import Button from '@/components/ui/Button'
 import {
   Dialog,
@@ -15,7 +15,7 @@ import { toast } from 'sonner'
 import { errorMessage } from '@/lib/utils'
 import { clearDocuments, clearCache } from '@/api/lightrag'
-import { EraserIcon, AlertTriangleIcon } from 'lucide-react'
+import { EraserIcon, AlertTriangleIcon, Loader2Icon } from 'lucide-react'
 import { useTranslation } from 'react-i18next'

 // Simple Label component
@@ -43,18 +43,51 @@ export default function ClearDocumentsDialog({ onDocumentsCleared }: ClearDocume
   const [open, setOpen] = useState(false)
   const [confirmText, setConfirmText] = useState('')
   const [clearCacheOption, setClearCacheOption] = useState(false)
+  const [isClearing, setIsClearing] = useState(false)
+  const timeoutRef = useRef<NodeJS.Timeout | null>(null)

   const isConfirmEnabled = confirmText.toLowerCase() === 'yes'

+  // Timeout constant (30 seconds)
+  const CLEAR_TIMEOUT = 30000
+
+  // Reset state when the dialog closes
   useEffect(() => {
     if (!open) {
       setConfirmText('')
       setClearCacheOption(false)
+      setIsClearing(false)
+
+      // Clear the timeout timer
+      if (timeoutRef.current) {
+        clearTimeout(timeoutRef.current)
+        timeoutRef.current = null
+      }
     }
   }, [open])

+  // Cleanup when the component unmounts
+  useEffect(() => {
+    return () => {
+      // Clear the timeout timer on unmount
+      if (timeoutRef.current) {
+        clearTimeout(timeoutRef.current)
+      }
+    }
+  }, [])
+
   const handleClear = useCallback(async () => {
-    if (!isConfirmEnabled) return
+    if (!isConfirmEnabled || isClearing) return
+
+    setIsClearing(true)
+
+    // Set timeout protection; the timer is cleared in finally once the call settles,
+    // so it only fires while the request is still pending
+    timeoutRef.current = setTimeout(() => {
+      toast.error(t('documentPanel.clearDocuments.timeout'))
+      setIsClearing(false)
+      setConfirmText('') // Reset the confirmation text after a timeout
+    }, CLEAR_TIMEOUT)

     try {
       const result = await clearDocuments()
@@ -86,8 +119,15 @@ export default function ClearDocumentsDialog({ onDocumentsCleared }: ClearDocume
     } catch (err) {
       toast.error(t('documentPanel.clearDocuments.error', { error: errorMessage(err) }))
       setConfirmText('')
+    } finally {
+      // Clear the timeout timer
+      if (timeoutRef.current) {
+        clearTimeout(timeoutRef.current)
+        timeoutRef.current = null
+      }
+      setIsClearing(false)
     }
-  }, [isConfirmEnabled, clearCacheOption, setOpen, t, onDocumentsCleared])
+  }, [isConfirmEnabled, isClearing, clearCacheOption, setOpen, t, onDocumentsCleared, CLEAR_TIMEOUT])

   return (
@@ -125,6 +165,7 @@ export default function ClearDocumentsDialog({ onDocumentsCleared }: ClearDocume
               onChange={(e: React.ChangeEvent<HTMLInputElement>) => setConfirmText(e.target.value)}
               placeholder={t('documentPanel.clearDocuments.confirmPlaceholder')}
               className="w-full"
+              disabled={isClearing}
             />
@@ -133,6 +174,7 @@ export default function ClearDocumentsDialog({ 
onDocumentsCleared }: ClearDocume id="clear-cache" checked={clearCacheOption} onCheckedChange={(checked: boolean | 'indeterminate') => setClearCacheOption(checked === true)} + disabled={isClearing} />