From 96f23d59af07ac880cf79a5e34b1ea561d57097d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?=
Date: Thu, 4 Dec 2025 19:14:30 +0800
Subject: [PATCH] cherry-pick fc40a369

---
 env.example                     |   5 +-
 lightrag/api/lightrag_server.py | 108 ++++++++++------
 lightrag/llm/gemini.py          | 212 +++++++++++++++++++++++++-------
 lightrag/llm/openai.py          |   5 +-
 4 files changed, 245 insertions(+), 85 deletions(-)

diff --git a/env.example b/env.example
index 2c7faded..dd3389b2 100644
--- a/env.example
+++ b/env.example
@@ -177,9 +177,10 @@ LLM_BINDING_API_KEY=your_api_key
 ### Gemini example
 # LLM_BINDING=gemini
 # LLM_MODEL=gemini-flash-latest
-# LLM_BINDING_HOST=https://generativelanguage.googleapis.com
 # LLM_BINDING_API_KEY=your_gemini_api_key
-# GEMINI_LLM_MAX_OUTPUT_TOKENS=8192
+# LLM_BINDING_HOST=https://generativelanguage.googleapis.com
+# GEMINI_LLM_THINKING_CONFIG='{"thinking_budget": 0, "include_thoughts": false}'
+# GEMINI_LLM_MAX_OUTPUT_TOKENS=9000
 # GEMINI_LLM_TEMPERATURE=0.7
 
 ### OpenAI Compatible API Specific Parameters
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 89feca32..c9bb1a44 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -5,10 +5,13 @@ LightRAG FastAPI Server
 from fastapi import FastAPI, Depends, HTTPException, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse
+from fastapi.openapi.docs import (
+    get_swagger_ui_html,
+    get_swagger_ui_oauth2_redirect_html,
+)
 import os
 import logging
 import logging.config
-import signal
 import sys
 import uvicorn
 import pipmaster as pm
@@ -78,24 +81,6 @@ config.read("config.ini")
 auth_configured = bool(auth_handler.accounts)
 
 
-def setup_signal_handlers():
-    """Setup signal handlers for graceful shutdown"""
-
-    def signal_handler(sig, frame):
-        print(f"\n\nReceived signal {sig}, shutting down gracefully...")
-        print(f"Process ID: {os.getpid()}")
-
-        # Release shared resources
-        finalize_share_data()
-
-        # Exit with success status
-        sys.exit(0)
-
-    # Register signal handlers
-    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
-    signal.signal(signal.SIGTERM, signal_handler)  # kill command
-
-
 class LLMConfigCache:
     """Smart LLM and Embedding configuration cache class"""
 
@@ -153,7 +138,11 @@ class LLMConfigCache:
 
 
 def check_frontend_build():
-    """Check if frontend is built and optionally check if source is up-to-date"""
+    """Check if frontend is built and optionally check if source is up-to-date
+
+    Returns:
+        bool: True if frontend is outdated, False if up-to-date or production environment
+    """
     webui_dir = Path(__file__).parent / "webui"
     index_html = webui_dir / "index.html"
 
@@ -188,7 +177,7 @@ def check_frontend_build():
             logger.debug(
                 "Production environment detected, skipping source freshness check"
             )
-            return
+            return False
 
         # Development environment, perform source code timestamp check
        logger.debug("Development environment detected, checking source freshness")
@@ -219,7 +208,7 @@ def check_frontend_build():
             source_dir / "bun.lock",
             source_dir / "vite.config.ts",
             source_dir / "tsconfig.json",
-            source_dir / "tailwind.config.js",
+            source_dir / "tailwind.config.js",
             source_dir / "index.html",
         ]
 
@@ -263,17 +252,25 @@ def check_frontend_build():
            ASCIIColors.cyan("  cd ..")
            ASCIIColors.yellow("\nThe server will continue with the current build.")
            ASCIIColors.yellow("=" * 80 + "\n")
+            return True  # Frontend is outdated
         else:
            logger.info("Frontend build is up-to-date")
+            return False  # Frontend is up-to-date
 
     except Exception as e:
         # If check fails, log warning but don't affect startup
         logger.warning(f"Failed to check frontend source freshness: {e}")
+        return False  # Assume up-to-date on error
 
 
 def create_app(args):
-    # Check frontend build first
-    check_frontend_build()
+    # Check frontend build first and get outdated status
+    is_frontend_outdated = check_frontend_build()
+
+    # Create unified API version display with warning symbol if frontend is outdated
+    api_version_display = (
+        f"{__api_version__}⚠️" if is_frontend_outdated else __api_version__
+    )
 
     # Setup logging
     logger.setLevel(args.log_level)
@@ -349,8 +346,15 @@ def create_app(args):
         # Clean up database connections
         await rag.finalize_storages()
 
-        # Clean up shared data
-        finalize_share_data()
+        if "LIGHTRAG_GUNICORN_MODE" not in os.environ:
+            # Only perform cleanup in Uvicorn single-process mode
+            logger.debug("Uvicorn Mode: finalizing shared storage...")
+            finalize_share_data()
+        else:
+            # In Gunicorn mode with preload_app=True, cleanup is handled by on_exit hooks
+            logger.debug(
+                "Gunicorn Mode: postpone shared storage finalization to master process"
+            )
 
     # Initialize FastAPI
     base_description = (
@@ -366,7 +370,7 @@ def create_app(args):
         "description": swagger_description,
         "version": __api_version__,
         "openapi_url": "/openapi.json",  # Explicitly set OpenAPI schema URL
-        "docs_url": "/docs",  # Explicitly set docs URL
+        "docs_url": None,  # Disable default docs, we'll create custom endpoint
         "redoc_url": "/redoc",  # Explicitly set redoc URL
         "lifespan": lifespan,
     }
@@ -509,7 +513,7 @@ def create_app(args):
         return optimized_azure_openai_model_complete
 
     def create_optimized_gemini_llm_func(
-        config_cache: LLMConfigCache, args
+        config_cache: LLMConfigCache, args, llm_timeout: int
     ):
         """Create optimized Gemini LLM function with cached configuration"""
 
@@ -525,6 +529,8 @@ def create_app(args):
             if history_messages is None:
                 history_messages = []
 
+            # Use pre-processed configuration to avoid repeated parsing
+            kwargs["timeout"] = llm_timeout
             if (
                 config_cache.gemini_llm_options is not None
                 and "generation_config" not in kwargs
@@ -566,7 +572,7 @@ def create_app(args):
                 config_cache, args, llm_timeout
             )
         elif binding == "gemini":
-            return create_optimized_gemini_llm_func(config_cache, args)
+            return create_optimized_gemini_llm_func(config_cache, args, llm_timeout)
         else:  # openai and compatible
             # Use optimized function with pre-processed configuration
             return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
@@ -815,6 +821,25 @@ def create_app(args):
     ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key)
     app.include_router(ollama_api.router, prefix="/api")
 
+    # Custom Swagger UI endpoint for offline support
+    @app.get("/docs", include_in_schema=False)
+    async def custom_swagger_ui_html():
+        """Custom Swagger UI HTML with local static files"""
+        return get_swagger_ui_html(
+            openapi_url=app.openapi_url,
+            title=app.title + " - Swagger UI",
+            oauth2_redirect_url="/docs/oauth2-redirect",
+            swagger_js_url="/static/swagger-ui/swagger-ui-bundle.js",
+            swagger_css_url="/static/swagger-ui/swagger-ui.css",
+            swagger_favicon_url="/static/swagger-ui/favicon-32x32.png",
+            swagger_ui_parameters=app.swagger_ui_parameters,
+        )
+
+    @app.get("/docs/oauth2-redirect", include_in_schema=False)
+    async def swagger_ui_redirect():
+        """OAuth2 redirect for Swagger UI"""
+        return get_swagger_ui_oauth2_redirect_html()
+
     @app.get("/")
     async def redirect_to_webui():
         """Redirect root path to /webui"""
@@ -836,7 +861,7 @@ def create_app(args):
                 "auth_mode": "disabled",
"message": "Authentication is disabled. Using guest access.", "core_version": core_version, - "api_version": __api_version__, + "api_version": api_version_display, "webui_title": webui_title, "webui_description": webui_description, } @@ -845,7 +870,7 @@ def create_app(args): "auth_configured": True, "auth_mode": "enabled", "core_version": core_version, - "api_version": __api_version__, + "api_version": api_version_display, "webui_title": webui_title, "webui_description": webui_description, } @@ -863,7 +888,7 @@ def create_app(args): "auth_mode": "disabled", "message": "Authentication is disabled. Using guest access.", "core_version": core_version, - "api_version": __api_version__, + "api_version": api_version_display, "webui_title": webui_title, "webui_description": webui_description, } @@ -880,7 +905,7 @@ def create_app(args): "token_type": "bearer", "auth_mode": "enabled", "core_version": core_version, - "api_version": __api_version__, + "api_version": api_version_display, "webui_title": webui_title, "webui_description": webui_description, } @@ -944,7 +969,7 @@ def create_app(args): "pipeline_busy": pipeline_status.get("busy", False), "keyed_locks": keyed_lock_info, "core_version": core_version, - "api_version": __api_version__, + "api_version": api_version_display, "webui_title": webui_title, "webui_description": webui_description, } @@ -981,6 +1006,15 @@ def create_app(args): return response + # Mount Swagger UI static files for offline support + swagger_static_dir = Path(__file__).parent / "static" / "swagger-ui" + if swagger_static_dir.exists(): + app.mount( + "/static/swagger-ui", + StaticFiles(directory=swagger_static_dir), + name="swagger-ui-static", + ) + # Webui mount webui/index.html static_dir = Path(__file__).parent / "webui" static_dir.mkdir(exist_ok=True) @@ -1122,8 +1156,10 @@ def main(): update_uvicorn_mode_config() display_splash_screen(global_args) - # Setup signal handlers for graceful shutdown - setup_signal_handlers() + # Note: Signal handlers are NOT registered here because: + # - Uvicorn has built-in signal handling that properly calls lifespan shutdown + # - Custom signal handlers can interfere with uvicorn's graceful shutdown + # - Cleanup is handled by the lifespan context manager's finally block # Create application instance directly instead of using factory function app = create_app(global_args) diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py index f3991403..f06ec6b3 100644 --- a/lightrag/llm/gemini.py +++ b/lightrag/llm/gemini.py @@ -33,24 +33,33 @@ LOG = logging.getLogger(__name__) @lru_cache(maxsize=8) -def _get_gemini_client(api_key: str, base_url: str | None) -> genai.Client: +def _get_gemini_client( + api_key: str, base_url: str | None, timeout: int | None = None +) -> genai.Client: """ Create (or fetch cached) Gemini client. Args: api_key: Google Gemini API key. base_url: Optional custom API endpoint. + timeout: Optional request timeout in milliseconds. Returns: genai.Client: Configured Gemini client instance. 
""" client_kwargs: dict[str, Any] = {"api_key": api_key} - if base_url and base_url != DEFAULT_GEMINI_ENDPOINT: + if base_url and base_url != DEFAULT_GEMINI_ENDPOINT or timeout is not None: try: - client_kwargs["http_options"] = types.HttpOptions(api_endpoint=base_url) + http_options_kwargs = {} + if base_url and base_url != DEFAULT_GEMINI_ENDPOINT: + http_options_kwargs["api_endpoint"] = base_url + if timeout is not None: + http_options_kwargs["timeout"] = timeout + + client_kwargs["http_options"] = types.HttpOptions(**http_options_kwargs) except Exception as exc: # pragma: no cover - defensive - LOG.warning("Failed to apply custom Gemini endpoint %s: %s", base_url, exc) + LOG.warning("Failed to apply custom Gemini http_options: %s", exc) try: return genai.Client(**client_kwargs) @@ -114,28 +123,44 @@ def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> s return "\n".join(history_lines) -def _extract_response_text(response: Any) -> str: +def _extract_response_text( + response: Any, extract_thoughts: bool = False +) -> tuple[str, str]: """ - Extract text content from Gemini response, avoiding warnings about non-text parts. + Extract text content from Gemini response, separating regular content from thoughts. - Always extracts text manually from parts to avoid triggering warnings when - non-text parts (like 'thought_signature') are present in the response. + Args: + response: Gemini API response object + extract_thoughts: Whether to extract thought content separately + + Returns: + Tuple of (regular_text, thought_text) """ candidates = getattr(response, "candidates", None) if not candidates: - return "" + return ("", "") + + regular_parts: list[str] = [] + thought_parts: list[str] = [] - parts: list[str] = [] for candidate in candidates: if not getattr(candidate, "content", None): continue - for part in getattr(candidate.content, "parts", []): - # Only extract text parts to avoid non-text content like thought_signature + # Use 'or []' to handle None values from parts attribute + for part in getattr(candidate.content, "parts", None) or []: text = getattr(part, "text", None) - if text: - parts.append(text) + if not text: + continue - return "\n".join(parts) + # Check if this part is thought content using the 'thought' attribute + is_thought = getattr(part, "thought", False) + + if is_thought and extract_thoughts: + thought_parts.append(text) + elif not is_thought: + regular_parts.append(text) + + return ("\n".join(regular_parts), "\n".join(thought_parts)) async def gemini_complete_if_cache( @@ -143,22 +168,58 @@ async def gemini_complete_if_cache( prompt: str, system_prompt: str | None = None, history_messages: list[dict[str, Any]] | None = None, - *, - api_key: str | None = None, + enable_cot: bool = False, base_url: str | None = None, - generation_config: dict[str, Any] | None = None, - keyword_extraction: bool = False, + api_key: str | None = None, token_tracker: Any | None = None, - hashing_kv: Any | None = None, # noqa: ARG001 - present for interface parity stream: bool | None = None, - enable_cot: bool = False, # noqa: ARG001 - not supported by Gemini currently - timeout: float | None = None, # noqa: ARG001 - handled by caller if needed + keyword_extraction: bool = False, + generation_config: dict[str, Any] | None = None, + timeout: int | None = None, **_: Any, ) -> str | AsyncIterator[str]: + """ + Complete a prompt using Gemini's API with Chain of Thought (COT) support. 
+
+    This function supports automatic integration of reasoning content from Gemini models
+    that provide Chain of Thought capabilities via the thinking_config API feature.
+
+    COT Integration:
+    - When enable_cot=True: Thought content is wrapped in <think>...</think> tags
+    - When enable_cot=False: Thought content is filtered out, only regular content returned
+    - Thought content is identified by the 'thought' attribute on response parts
+    - Requires thinking_config to be enabled in generation_config for API to return thoughts
+
+    Args:
+        model: The Gemini model to use.
+        prompt: The prompt to complete.
+        system_prompt: Optional system prompt to include.
+        history_messages: Optional list of previous messages in the conversation.
+        api_key: Optional Gemini API key. If None, uses environment variable.
+        base_url: Optional custom API endpoint.
+        generation_config: Optional generation configuration dict.
+        keyword_extraction: Whether to use JSON response format.
+        token_tracker: Optional token usage tracker for monitoring API usage.
+        stream: Whether to stream the response.
+        hashing_kv: Storage interface (for interface parity with other bindings).
+        enable_cot: Whether to include Chain of Thought content in the response.
+        timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API).
+        **_: Additional keyword arguments (ignored).
+
+    Returns:
+        The completed text (with COT content if enable_cot=True) or an async iterator
+        of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
+
+    Raises:
+        RuntimeError: If the response from Gemini is empty.
+        ValueError: If API key is not provided or configured.
+    """
     loop = asyncio.get_running_loop()
     key = _ensure_api_key(api_key)
-    client = _get_gemini_client(key, base_url)
+    # Convert timeout from seconds to milliseconds for Gemini API
+    timeout_ms = timeout * 1000 if timeout else None
+    client = _get_gemini_client(key, base_url, timeout_ms)
 
     history_block = _format_history_messages(history_messages)
     prompt_sections = []
@@ -188,6 +249,11 @@ async def gemini_complete_if_cache(
     usage_container: dict[str, Any] = {}
 
     def _stream_model() -> None:
+        # COT state tracking for streaming
+        cot_active = False
+        cot_started = False
+        initial_content_seen = False
+
         try:
             stream_kwargs = dict(request_kwargs)
             stream_iterator = client.models.generate_content_stream(**stream_kwargs)
@@ -195,19 +261,61 @@ async def gemini_complete_if_cache(
                 usage = getattr(chunk, "usage_metadata", None)
                 if usage is not None:
                     usage_container["usage"] = usage
-                # Always use manual extraction to avoid warnings about non-text parts
-                text_piece = _extract_response_text(chunk)
-                if text_piece:
-                    loop.call_soon_threadsafe(queue.put_nowait, text_piece)
+
+                # Extract both regular and thought content
+                regular_text, thought_text = _extract_response_text(
+                    chunk, extract_thoughts=True
+                )
+
+                if enable_cot:
+                    # Process regular content
+                    if regular_text:
+                        if not initial_content_seen:
+                            initial_content_seen = True
+
+                        # Close COT section if it was active
+                        if cot_active:
+                            loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+                            cot_active = False
+
+                        # Send regular content
+                        loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+
+                    # Process thought content
+                    if thought_text:
+                        if not initial_content_seen and not cot_started:
+                            # Start COT section
+                            loop.call_soon_threadsafe(queue.put_nowait, "<think>")
+                            cot_active = True
+                            cot_started = True
+
+                        # Send thought content if COT is active
+                        if cot_active:
+                            loop.call_soon_threadsafe(
+                                queue.put_nowait, thought_text
+                            )
+                else:
+                    # COT disabled - only send regular content
+                    if regular_text:
+                        loop.call_soon_threadsafe(queue.put_nowait, regular_text)
+
+            # Ensure COT is properly closed if still active
+            if cot_active:
+                loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+
             loop.call_soon_threadsafe(queue.put_nowait, None)
         except Exception as exc:  # pragma: no cover - surface runtime issues
+            # Try to close COT tag before reporting error
+            if cot_active:
+                try:
+                    loop.call_soon_threadsafe(queue.put_nowait, "</think>")
+                except Exception:
+                    pass
             loop.call_soon_threadsafe(queue.put_nowait, exc)
 
     loop.run_in_executor(None, _stream_model)
 
     async def _async_stream() -> AsyncIterator[str]:
-        accumulated = ""
-        emitted = ""
         try:
             while True:
                 item = await queue.get()
@@ -220,16 +328,9 @@ async def gemini_complete_if_cache(
                 if "\\u" in chunk_text:
                     chunk_text = safe_unicode_decode(chunk_text.encode("utf-8"))
 
-                accumulated += chunk_text
-                sanitized = remove_think_tags(accumulated)
-                if sanitized.startswith(emitted):
-                    delta = sanitized[len(emitted) :]
-                else:
-                    delta = sanitized
-                emitted = sanitized
-
-                if delta:
-                    yield delta
+                # Yield the chunk directly without filtering
+                # COT filtering is already handled in _stream_model()
+                yield chunk_text
         finally:
             usage = usage_container.get("usage")
             if token_tracker and usage:
@@ -247,14 +348,33 @@ async def gemini_complete_if_cache(
 
     response = await asyncio.to_thread(_call_model)
 
-    text = _extract_response_text(response)
-    if not text:
+    # Extract both regular text and thought text
+    regular_text, thought_text = _extract_response_text(response, extract_thoughts=True)
+
+    # Apply COT filtering logic based on enable_cot parameter
+    if enable_cot:
+        # Include thought content wrapped in <think> tags
+        if thought_text and thought_text.strip():
+            if not regular_text or regular_text.strip() == "":
+                # Only thought content available
+                final_text = f"<think>{thought_text}</think>"
+            else:
+                # Both content types present: prepend thought to regular content
+                final_text = f"<think>{thought_text}</think>{regular_text}"
+        else:
+            # No thought content, use regular content only
+            final_text = regular_text or ""
+    else:
+        # Filter out thought content, return only regular content
+        final_text = regular_text or ""
+
+    if not final_text:
         raise RuntimeError("Gemini response did not contain any text content.")
 
-    if "\\u" in text:
-        text = safe_unicode_decode(text.encode("utf-8"))
+    if "\\u" in final_text:
+        final_text = safe_unicode_decode(final_text.encode("utf-8"))
 
-    text = remove_think_tags(text)
+    final_text = remove_think_tags(final_text)
 
     usage = getattr(response, "usage_metadata", None)
     if token_tracker and usage:
@@ -266,8 +386,8 @@ async def gemini_complete_if_cache(
             }
         )
 
-    logger.debug("Gemini response length: %s", len(text))
-    return text
+    logger.debug("Gemini response length: %s", len(final_text))
+    return final_text
 
 
 async def gemini_model_complete(
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 2cdbb72b..511a3a62 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -138,6 +138,9 @@ async def openai_complete_if_cache(
     base_url: str | None = None,
     api_key: str | None = None,
     token_tracker: Any | None = None,
+    keyword_extraction: bool = False,  # Will be removed from kwargs before passing to OpenAI
+    stream: bool | None = None,
+    timeout: int | None = None,
     **kwargs: Any,
 ) -> str:
     """Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
@@ -172,9 +175,9 @@ async def openai_complete_if_cache(
         - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
These will be passed to the client constructor but will be overridden by explicit parameters (api_key, base_url). - - hashing_kv: Will be removed from kwargs before passing to OpenAI. - keyword_extraction: Will be removed from kwargs before passing to OpenAI. - stream: Whether to stream the response. Default is False. + - timeout: Request timeout in seconds. Default is None. Returns: The completed text (with integrated COT content if available) or an async iterator