From af3b2cf1184c5c4db2f79ba0218509f14a7f03ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?= Date: Thu, 4 Dec 2025 19:14:29 +0800 Subject: [PATCH] cherry-pick 0b3d3150 --- README.md | 2 + env.example | 88 ++---- lightrag/api/config.py | 141 +-------- lightrag/api/lightrag_server.py | 517 ++++++-------------------- lightrag/llm/binding_options.py | 28 +- lightrag/llm/gemini.py | 357 +++------------------- pyproject.toml | 84 ++---- requirements-offline-llm.txt | 22 +- requirements-offline.txt | 9 +- 9 files changed, 209 insertions(+), 1039 deletions(-) diff --git a/README.md b/README.md index ebdfa3e4..68bb6a81 100644 --- a/README.md +++ b/README.md @@ -149,6 +149,8 @@ cp env.example .env # Update the .env with your LLM and embedding configuration docker compose up ``` +> Tip: When targeting Google Gemini, set `LLM_BINDING=gemini`, choose a model such as `LLM_MODEL=gemini-flash-latest`, and provide your Gemini key via `LLM_BINDING_API_KEY` (or `GEMINI_API_KEY`). The server now understands this binding out of the box. + > Historical versions of LightRAG docker images can be found here: [LightRAG Docker Images]( https://github.com/HKUDS/LightRAG/pkgs/container/lightrag) ### Install LightRAG Core diff --git a/env.example b/env.example index 8fca2b5e..2c7faded 100644 --- a/env.example +++ b/env.example @@ -50,8 +50,6 @@ OLLAMA_EMULATING_MODEL_TAG=latest # JWT_ALGORITHM=HS256 ### API-Key to access LightRAG Server API -### Use this key in HTTP requests with the 'X-API-Key' header -### Example: curl -H "X-API-Key: your-secure-api-key-here" http://localhost:9621/query # LIGHTRAG_API_KEY=your-secure-api-key-here # WHITELIST_PATHS=/health,/api/* @@ -76,6 +74,11 @@ ENABLE_LLM_CACHE=true ### control the maximum tokens sent to LLM (includes entities, relations and chunks) # MAX_TOTAL_TOKENS=30000 +### maximum number of related chunks per source entity or relation +### The chunk picker uses this value to determine the total number of chunks selected from the KG (knowledge graph) +### Higher values increase re-ranking time +# RELATED_CHUNK_NUMBER=5 + ### chunk selection strategies ### VECTOR: Pick KG chunks by vector similarity, delivering chunks to the LLM that align more closely with naive retrieval ### WEIGHT: Pick KG chunks by entity and chunk weight, delivering chunks to the LLM that are more purely KG-related @@ -121,9 +124,6 @@ ENABLE_LLM_CACHE_FOR_EXTRACT=true ### Document processing output language: English, Chinese, French, German ...
SUMMARY_LANGUAGE=English -### PDF decryption password for protected PDF files -# PDF_DECRYPT_PASSWORD=your_pdf_password_here - ### Entity types that the LLM will attempt to recognize # ENTITY_TYPES='["Person", "Creature", "Organization", "Location", "Event", "Concept", "Method", "Content", "Data", "Artifact", "NaturalObject"]' @@ -140,22 +140,6 @@ SUMMARY_LANGUAGE=English ### Maximum context size sent to LLM for description summary # SUMMARY_CONTEXT_SIZE=12000 -### control the maximum chunk_ids stored in vector and graph db -# MAX_SOURCE_IDS_PER_ENTITY=300 -# MAX_SOURCE_IDS_PER_RELATION=300 -### control chunk_ids limitation method: FIFO, KEEP -### FIFO: First in first out -### KEEP: Keep oldest (less merge action and faster) -# SOURCE_IDS_LIMIT_METHOD=FIFO - -# Maximum number of file paths stored in entity/relation file_path field (For displayed only, does not affect query performance) -# MAX_FILE_PATHS=100 - -### maximum number of related chunks per source entity or relation -### The chunk picker uses this value to determine the total number of chunks selected from KG(knowledge graph) -### Higher values increase re-ranking time -# RELATED_CHUNK_NUMBER=5 - ############################### ### Concurrency Configuration ############################### ### Max concurrent requests to the LLM (for both query and document processing) MAX_ASYNC=4 ### Number of parallel processing documents (between 2~10, MAX_ASYNC/3 is recommended) MAX_PARALLEL_INSERT=2 ### Num of chunks sent to Embedding in a single request # EMBEDDING_BATCH_NUM=10 -########################################################################### +########################################################### ### LLM Configuration -### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock -### LLM_BINDING_HOST: host only for Ollama, endpoint for other LLM service +### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock, gemini +########################################################### ### LLM request timeout setting for all LLMs (0 means no timeout for Ollama) # LLM_TIMEOUT=180 @@ -191,6 +174,14 @@ LLM_BINDING_API_KEY=your_api_key # LLM_BINDING_API_KEY=your_api_key # LLM_BINDING=openai +### Gemini example +# LLM_BINDING=gemini +# LLM_MODEL=gemini-flash-latest +# LLM_BINDING_HOST=https://generativelanguage.googleapis.com +# LLM_BINDING_API_KEY=your_gemini_api_key +# GEMINI_LLM_MAX_OUTPUT_TOKENS=8192 +# GEMINI_LLM_TEMPERATURE=0.7 + ### OpenAI Compatible API Specific Parameters ### Increased temperature values may mitigate infinite inference loops in certain LLMs, such as Qwen3-30B.
# OPENAI_LLM_TEMPERATURE=0.9 @@ -214,7 +205,6 @@ OPENAI_LLM_MAX_COMPLETION_TOKENS=9000 # OPENAI_LLM_EXTRA_BODY='{"chat_template_kwargs": {"enable_thinking": false}}' ### use the following command to see all supported options for Ollama LLM -### If LightRAG deployed in Docker uses host.docker.internal instead of localhost in LLM_BINDING_HOST ### lightrag-server --llm-binding ollama --help ### Ollama Server Specific Parameters ### OLLAMA_LLM_NUM_CTX must be provided, and should be at least MAX_TOTAL_TOKENS + 2000 OLLAMA_LLM_NUM_CTX=32768 @@ -227,24 +217,16 @@ OLLAMA_LLM_NUM_CTX=32768 ### Bedrock Specific Parameters # BEDROCK_LLM_TEMPERATURE=1.0 -####################################################################################### +#################################################################################### ### Embedding Configuration (Should not be changed after the first file processed) ### EMBEDDING_BINDING: ollama, openai, azure_openai, jina, lollms, aws_bedrock -### EMBEDDING_BINDING_HOST: host only for Ollama, endpoint for other Embedding service -####################################################################################### +#################################################################################### # EMBEDDING_TIMEOUT=30 - -### Control whether to send embedding_dim parameter to embedding API -### Set to 'true' to enable dynamic dimension adjustment (only works for OpenAI and Jina) -### Set to 'false' (default) to disable sending dimension parameter -### Note: This is automatically ignored for backends that don't support dimension parameter -# EMBEDDING_SEND_DIM=false - EMBEDDING_BINDING=ollama EMBEDDING_MODEL=bge-m3:latest EMBEDDING_DIM=1024 EMBEDDING_BINDING_API_KEY=your_api_key -# If LightRAG deployed in Docker uses host.docker.internal instead of localhost +# If LightRAG is deployed in Docker, use host.docker.internal instead of localhost to reach an embedding service running on the host EMBEDDING_BINDING_HOST=http://localhost:11434 ### OpenAI compatible (VoyageAI embedding openai compatible) @@ -401,35 +383,3 @@ MEMGRAPH_USERNAME= MEMGRAPH_PASSWORD= MEMGRAPH_DATABASE=memgraph # MEMGRAPH_WORKSPACE=forced_workspace_name - -############################ -### Evaluation Configuration -############################ -### RAGAS evaluation models (used for RAG quality assessment) -### ⚠️ IMPORTANT: Both LLM and Embedding endpoints MUST be OpenAI-compatible -### Default uses OpenAI models for evaluation - -### LLM Configuration for Evaluation -# EVAL_LLM_MODEL=gpt-4o-mini -### API key for LLM evaluation (fallback to OPENAI_API_KEY if not set) -# EVAL_LLM_BINDING_API_KEY=your_api_key -### Custom OpenAI-compatible endpoint for LLM evaluation (optional) -# EVAL_LLM_BINDING_HOST=https://api.openai.com/v1 - -### Embedding Configuration for Evaluation -# EVAL_EMBEDDING_MODEL=text-embedding-3-large -### API key for embeddings (fallback: EVAL_LLM_BINDING_API_KEY -> OPENAI_API_KEY) -# EVAL_EMBEDDING_BINDING_API_KEY=your_embedding_api_key -### Custom OpenAI-compatible endpoint for embeddings (fallback: EVAL_LLM_BINDING_HOST) -# EVAL_EMBEDDING_BINDING_HOST=https://api.openai.com/v1 - -### Performance Tuning -### Number of concurrent test case evaluations -### Lower values reduce API rate limit issues but increase evaluation time -# EVAL_MAX_CONCURRENT=2 -### TOP_K query parameter of LightRAG (default: 10) -### Number of entities or relations retrieved from KG -# EVAL_QUERY_TOP_K=10 -### LLM request retry and timeout settings for evaluation -# EVAL_LLM_MAX_RETRIES=5 -# EVAL_LLM_TIMEOUT=180 diff --git
a/lightrag/api/config.py b/lightrag/api/config.py index 4d8ab1e1..ff5e65b1 100644 --- a/lightrag/api/config.py +++ b/lightrag/api/config.py @@ -8,7 +8,6 @@ import logging from dotenv import load_dotenv from lightrag.utils import get_env_value from lightrag.llm.binding_options import ( - GeminiEmbeddingOptions, GeminiLLMOptions, OllamaEmbeddingOptions, OllamaLLMOptions, @@ -239,15 +238,7 @@ def parse_args() -> argparse.Namespace: "--embedding-binding", type=str, default=get_env_value("EMBEDDING_BINDING", "ollama"), - choices=[ - "lollms", - "ollama", - "openai", - "azure_openai", - "aws_bedrock", - "jina", - "gemini", - ], + choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"], help="Embedding binding type (default: from env or ollama)", ) parser.add_argument( @@ -258,14 +249,6 @@ def parse_args() -> argparse.Namespace: help=f"Rerank binding type (default: from env or {DEFAULT_RERANK_BINDING})", ) - # Document loading engine configuration - parser.add_argument( - "--docling", - action="store_true", - default=False, - help="Enable DOCLING document loading engine (default: from env or DEFAULT)", - ) - # Conditionally add binding options defined in binding_options module # This will add command line arguments for all binding options (e.g., --ollama-embedding-num_ctx) # and corresponding environment variables (e.g., OLLAMA_EMBEDDING_NUM_CTX) @@ -282,19 +265,12 @@ def parse_args() -> argparse.Namespace: if "--embedding-binding" in sys.argv: try: idx = sys.argv.index("--embedding-binding") - if idx + 1 < len(sys.argv): - if sys.argv[idx + 1] == "ollama": - OllamaEmbeddingOptions.add_args(parser) - elif sys.argv[idx + 1] == "gemini": - GeminiEmbeddingOptions.add_args(parser) + if idx + 1 < len(sys.argv) and sys.argv[idx + 1] == "ollama": + OllamaEmbeddingOptions.add_args(parser) except IndexError: pass - else: - env_embedding_binding = os.environ.get("EMBEDDING_BINDING") - if env_embedding_binding == "ollama": - OllamaEmbeddingOptions.add_args(parser) - elif env_embedding_binding == "gemini": - GeminiEmbeddingOptions.add_args(parser) + elif os.environ.get("EMBEDDING_BINDING") == "ollama": + OllamaEmbeddingOptions.add_args(parser) # Add OpenAI LLM options when llm-binding is openai or azure_openai if "--llm-binding" in sys.argv: @@ -365,13 +341,8 @@ def parse_args() -> argparse.Namespace: # Inject model configuration args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest") - # EMBEDDING_MODEL defaults to None - each binding will use its own default model - # e.g., OpenAI uses "text-embedding-3-small", Jina uses "jina-embeddings-v4" - args.embedding_model = get_env_value("EMBEDDING_MODEL", None, special_none=True) - # EMBEDDING_DIM defaults to None - each binding will use its own default dimension - # Value is inherited from provider defaults via wrap_embedding_func_with_attrs decorator - args.embedding_dim = get_env_value("EMBEDDING_DIM", None, int, special_none=True) - args.embedding_send_dim = get_env_value("EMBEDDING_SEND_DIM", False, bool) + args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest") + args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int) # Inject chunk configuration args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int) @@ -383,16 +354,8 @@ def parse_args() -> argparse.Namespace: ) args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool) - # Set document_loading_engine from --docling flag - if args.docling: - args.document_loading_engine = "DOCLING" - else: - args.document_loading_engine = get_env_value( - 
"DOCUMENT_LOADING_ENGINE", "DEFAULT" - ) - - # PDF decryption password - args.pdf_decrypt_password = get_env_value("PDF_DECRYPT_PASSWORD", None) + # Select Document loading tool (DOCLING, DEFAULT) + args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT") # Add environment variables that were previously read directly args.cors_origins = get_env_value("CORS_ORIGINS", "*") @@ -449,11 +412,6 @@ def parse_args() -> argparse.Namespace: "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int ) - # Embedding token limit configuration - args.embedding_token_limit = get_env_value( - "EMBEDDING_TOKEN_LIMIT", None, int, special_none=True - ) - ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag @@ -471,83 +429,4 @@ def update_uvicorn_mode_config(): ) -# Global configuration with lazy initialization -_global_args = None -_initialized = False - - -def initialize_config(args=None, force=False): - """Initialize global configuration - - This function allows explicit initialization of the configuration, - which is useful for programmatic usage, testing, or embedding LightRAG - in other applications. - - Args: - args: Pre-parsed argparse.Namespace or None to parse from sys.argv - force: Force re-initialization even if already initialized - - Returns: - argparse.Namespace: The configured arguments - - Example: - # Use parsed command line arguments (default) - initialize_config() - - # Use custom configuration programmatically - custom_args = argparse.Namespace( - host='localhost', - port=8080, - working_dir='./custom_rag', - # ... other config - ) - initialize_config(custom_args) - """ - global _global_args, _initialized - - if _initialized and not force: - return _global_args - - _global_args = args if args is not None else parse_args() - _initialized = True - return _global_args - - -def get_config(): - """Get global configuration, auto-initializing if needed - - Returns: - argparse.Namespace: The configured arguments - """ - if not _initialized: - initialize_config() - return _global_args - - -class _GlobalArgsProxy: - """Proxy object that auto-initializes configuration on first access - - This maintains backward compatibility with existing code while - allowing programmatic control over initialization timing. 
- """ - - def __getattr__(self, name): - if not _initialized: - initialize_config() - return getattr(_global_args, name) - - def __setattr__(self, name, value): - if not _initialized: - initialize_config() - setattr(_global_args, name, value) - - def __repr__(self): - if not _initialized: - return "" - return repr(_global_args) - - -# Create proxy instance for backward compatibility -# Existing code like `from config import global_args` continues to work -# The proxy will auto-initialize on first attribute access -global_args = _GlobalArgsProxy() +global_args = parse_args() diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 0f6324b4..89feca32 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -5,16 +5,14 @@ LightRAG FastAPI Server from fastapi import FastAPI, Depends, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse -from fastapi.openapi.docs import ( - get_swagger_ui_html, - get_swagger_ui_oauth2_redirect_html, -) import os import logging import logging.config +import signal import sys import uvicorn import pipmaster as pm +import inspect from fastapi.staticfiles import StaticFiles from fastapi.responses import RedirectResponse from pathlib import Path @@ -56,8 +54,7 @@ from lightrag.api.routers.ollama_api import OllamaAPI from lightrag.utils import logger, set_verbose_debug from lightrag.kg.shared_storage import ( get_namespace_data, - get_default_workspace, - # set_default_workspace, + initialize_pipeline_status, cleanup_keyed_lock, finalize_share_data, ) @@ -81,6 +78,24 @@ config.read("config.ini") auth_configured = bool(auth_handler.accounts) +def setup_signal_handlers(): + """Setup signal handlers for graceful shutdown""" + + def signal_handler(sig, frame): + print(f"\n\nReceived signal {sig}, shutting down gracefully...") + print(f"Process ID: {os.getpid()}") + + # Release shared resources + finalize_share_data() + + # Exit with success status + sys.exit(0) + + # Register signal handlers + signal.signal(signal.SIGINT, signal_handler) # Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # kill command + + class LLMConfigCache: """Smart LLM and Embedding configuration cache class""" @@ -90,7 +105,6 @@ class LLMConfigCache: # Initialize configurations based on binding conditions self.openai_llm_options = None self.gemini_llm_options = None - self.gemini_embedding_options = None self.ollama_llm_options = None self.ollama_embedding_options = None @@ -137,44 +151,20 @@ class LLMConfigCache: ) self.ollama_embedding_options = {} - # Only initialize and log Gemini Embedding options when using Gemini Embedding binding - if args.embedding_binding == "gemini": - try: - from lightrag.llm.binding_options import GeminiEmbeddingOptions - - self.gemini_embedding_options = GeminiEmbeddingOptions.options_dict( - args - ) - logger.info( - f"Gemini Embedding Options: {self.gemini_embedding_options}" - ) - except ImportError: - logger.warning( - "GeminiEmbeddingOptions not available, using default configuration" - ) - self.gemini_embedding_options = {} - def check_frontend_build(): - """Check if frontend is built and optionally check if source is up-to-date - - Returns: - tuple: (assets_exist: bool, is_outdated: bool) - - assets_exist: True if WebUI build files exist - - is_outdated: True if source is newer than build (only in dev environment) - """ + """Check if frontend is built and optionally check if source is up-to-date""" webui_dir = Path(__file__).parent / 
"webui" index_html = webui_dir / "index.html" - # 1. Check if build files exist + # 1. Check if build files exist (required) if not index_html.exists(): - ASCIIColors.yellow("\n" + "=" * 80) - ASCIIColors.yellow("WARNING: Frontend Not Built") - ASCIIColors.yellow("=" * 80) + ASCIIColors.red("\n" + "=" * 80) + ASCIIColors.red("ERROR: Frontend Not Built") + ASCIIColors.red("=" * 80) ASCIIColors.yellow("The WebUI frontend has not been built yet.") - ASCIIColors.yellow("The API server will start without the WebUI interface.") ASCIIColors.yellow( - "\nTo enable WebUI, build the frontend using these commands:\n" + "Please build the frontend code first using the following commands:\n" ) ASCIIColors.cyan(" cd lightrag_webui") ASCIIColors.cyan(" bun install --frozen-lockfile") @@ -184,8 +174,8 @@ def check_frontend_build(): ASCIIColors.cyan( "Note: Make sure you have Bun installed. Visit https://bun.sh for installation." ) - ASCIIColors.yellow("=" * 80 + "\n") - return (False, False) # Assets don't exist, not outdated + ASCIIColors.red("=" * 80 + "\n") + sys.exit(1) # Exit immediately # 2. Check if this is a development environment (source directory exists) try: @@ -198,7 +188,7 @@ def check_frontend_build(): logger.debug( "Production environment detected, skipping source freshness check" ) - return (True, False) # Assets exist, not outdated (prod environment) + return # Development environment, perform source code timestamp check logger.debug("Development environment detected, checking source freshness") @@ -229,7 +219,7 @@ def check_frontend_build(): source_dir / "bun.lock", source_dir / "vite.config.ts", source_dir / "tsconfig.json", - source_dir / "tailraid.config.js", + source_dir / "tailwind.config.js", source_dir / "index.html", ] @@ -273,25 +263,17 @@ def check_frontend_build(): ASCIIColors.cyan(" cd ..") ASCIIColors.yellow("\nThe server will continue with the current build.") ASCIIColors.yellow("=" * 80 + "\n") - return (True, True) # Assets exist, outdated else: logger.info("Frontend build is up-to-date") - return (True, False) # Assets exist, up-to-date except Exception as e: # If check fails, log warning but don't affect startup logger.warning(f"Failed to check frontend source freshness: {e}") - return (True, False) # Assume assets exist and up-to-date on error def create_app(args): - # Check frontend build first and get status - webui_assets_exist, is_frontend_outdated = check_frontend_build() - - # Create unified API version display with warning symbol if frontend is outdated - api_version_display = ( - f"{__api_version__}⚠️" if is_frontend_outdated else __api_version__ - ) + # Check frontend build first + check_frontend_build() # Setup logging logger.setLevel(args.log_level) @@ -318,7 +300,6 @@ def create_app(args): "azure_openai", "aws_bedrock", "jina", - "gemini", ]: raise Exception("embedding binding not supported") @@ -354,8 +335,8 @@ def create_app(args): try: # Initialize database connections - # Note: initialize_storages() now auto-initializes pipeline_status for rag.workspace await rag.initialize_storages() + await initialize_pipeline_status() # Data migration regardless of storage implementation await rag.check_and_migrate_data() @@ -368,15 +349,8 @@ def create_app(args): # Clean up database connections await rag.finalize_storages() - if "LIGHTRAG_GUNICORN_MODE" not in os.environ: - # Only perform cleanup in Uvicorn single-process mode - logger.debug("Unvicorn Mode: finalizing shared storage...") - finalize_share_data() - else: - # In Gunicorn mode with preload_app=True, 
cleanup is handled by on_exit hooks - logger.debug( - "Gunicorn Mode: postpone shared storage finalization to master process" - ) + # Clean up shared data + finalize_share_data() # Initialize FastAPI base_description = ( @@ -392,7 +366,7 @@ def create_app(args): "description": swagger_description, "version": __api_version__, "openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL - "docs_url": None, # Disable default docs, we'll create custom endpoint + "docs_url": "/docs", # Explicitly set docs URL "redoc_url": "/redoc", # Explicitly set redoc URL "lifespan": lifespan, } @@ -456,28 +430,6 @@ def create_app(args): # Create combined auth dependency for all endpoints combined_auth = get_combined_auth_dependency(api_key) - def get_workspace_from_request(request: Request) -> str | None: - """ - Extract workspace from HTTP request header or use default. - - This enables multi-workspace API support by checking the custom - 'LIGHTRAG-WORKSPACE' header. If not present, falls back to the - server's default workspace configuration. - - Args: - request: FastAPI Request object - - Returns: - Workspace identifier (may be empty string for global namespace) - """ - # Check custom header first - workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip() - - if not workspace: - workspace = None - - return workspace - # Create working directory if it doesn't exist Path(args.working_dir).mkdir(parents=True, exist_ok=True) @@ -557,7 +509,7 @@ def create_app(args): return optimized_azure_openai_model_complete def create_optimized_gemini_llm_func( - config_cache: LLMConfigCache, args, llm_timeout: int + config_cache: LLMConfigCache, args ): """Create optimized Gemini LLM function with cached configuration""" @@ -573,8 +525,6 @@ def create_app(args): if history_messages is None: history_messages = [] - # Use pre-processed configuration to avoid repeated parsing - kwargs["timeout"] = llm_timeout if ( config_cache.gemini_llm_options is not None and "generation_config" not in kwargs @@ -616,7 +566,7 @@ def create_app(args): config_cache, args, llm_timeout ) elif binding == "gemini": - return create_optimized_gemini_llm_func(config_cache, args, llm_timeout) + return create_optimized_gemini_llm_func(config_cache, args) else: # openai and compatible # Use optimized function with pre-processed configuration return create_optimized_openai_llm_func(config_cache, args, llm_timeout) @@ -643,120 +593,34 @@ def create_app(args): return {} def create_optimized_embedding_function( - config_cache: LLMConfigCache, binding, model, host, api_key, args - ) -> EmbeddingFunc: + config_cache: LLMConfigCache, binding, model, host, api_key, dimensions, args + ): """ - Create optimized embedding function and return an EmbeddingFunc instance - with proper max_token_size inheritance from provider defaults. - - This function: - 1. Imports the provider embedding function - 2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc - 3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping) - 4. 
Returns a properly configured EmbeddingFunc instance - - Configuration Rules: - - When EMBEDDING_MODEL is not set: Uses provider's default model and dimension - (e.g., jina-embeddings-v4 with 2048 dims, text-embedding-3-small with 1536 dims) - - When EMBEDDING_MODEL is set to a custom model: User MUST also set EMBEDDING_DIM - to match the custom model's dimension (e.g., for jina-embeddings-v3, set EMBEDDING_DIM=1024) - - Note: The embedding_dim parameter is automatically injected by EmbeddingFunc wrapper - when send_dimensions=True (enabled for Jina and Gemini bindings). This wrapper calls - the underlying provider function directly (.func) to avoid double-wrapping, so we must - explicitly pass embedding_dim to the provider's underlying function. + Create optimized embedding function with pre-processed configuration for applicable bindings. + Uses lazy imports for all bindings and avoids repeated configuration parsing. """ - # Step 1: Import provider function and extract default attributes - provider_func = None - provider_max_token_size = None - provider_embedding_dim = None - - try: - if binding == "openai": - from lightrag.llm.openai import openai_embed - - provider_func = openai_embed - elif binding == "ollama": - from lightrag.llm.ollama import ollama_embed - - provider_func = ollama_embed - elif binding == "gemini": - from lightrag.llm.gemini import gemini_embed - - provider_func = gemini_embed - elif binding == "jina": - from lightrag.llm.jina import jina_embed - - provider_func = jina_embed - elif binding == "azure_openai": - from lightrag.llm.azure_openai import azure_openai_embed - - provider_func = azure_openai_embed - elif binding == "aws_bedrock": - from lightrag.llm.bedrock import bedrock_embed - - provider_func = bedrock_embed - elif binding == "lollms": - from lightrag.llm.lollms import lollms_embed - - provider_func = lollms_embed - - # Extract attributes if provider is an EmbeddingFunc - if provider_func and isinstance(provider_func, EmbeddingFunc): - provider_max_token_size = provider_func.max_token_size - provider_embedding_dim = provider_func.embedding_dim - logger.debug( - f"Extracted from {binding} provider: " - f"max_token_size={provider_max_token_size}, " - f"embedding_dim={provider_embedding_dim}" - ) - except ImportError as e: - logger.warning(f"Could not import provider function for {binding}: {e}") - - # Step 2: Apply priority (user config > provider default) - # For max_token_size: explicit env var > provider default > None - final_max_token_size = args.embedding_token_limit or provider_max_token_size - # For embedding_dim: user config (always has value) takes priority - # Only use provider default if user config is explicitly None (which shouldn't happen) - final_embedding_dim = ( - args.embedding_dim if args.embedding_dim else provider_embedding_dim - ) - - # Step 3: Create optimized embedding function (calls underlying function directly) - async def optimized_embedding_function(texts, embedding_dim=None): + async def optimized_embedding_function(texts): try: if binding == "lollms": from lightrag.llm.lollms import lollms_embed - # Get real function, skip EmbeddingFunc wrapper if present - actual_func = ( - lollms_embed.func - if isinstance(lollms_embed, EmbeddingFunc) - else lollms_embed - ) - return await actual_func( + return await lollms_embed( texts, embed_model=model, host=host, api_key=api_key ) elif binding == "ollama": from lightrag.llm.ollama import ollama_embed - # Get real function, skip EmbeddingFunc wrapper if present - actual_func = ( - 
ollama_embed.func - if isinstance(ollama_embed, EmbeddingFunc) - else ollama_embed - ) - - # Use pre-processed configuration if available + # Use pre-processed configuration if available, otherwise fallback to dynamic parsing if config_cache.ollama_embedding_options is not None: ollama_options = config_cache.ollama_embedding_options else: + # Fallback for cases where config cache wasn't initialized properly from lightrag.llm.binding_options import OllamaEmbeddingOptions ollama_options = OllamaEmbeddingOptions.options_dict(args) - return await actual_func( + return await ollama_embed( texts, embed_model=model, host=host, @@ -766,94 +630,27 @@ def create_app(args): elif binding == "azure_openai": from lightrag.llm.azure_openai import azure_openai_embed - actual_func = ( - azure_openai_embed.func - if isinstance(azure_openai_embed, EmbeddingFunc) - else azure_openai_embed - ) - return await actual_func(texts, model=model, api_key=api_key) + return await azure_openai_embed(texts, model=model, api_key=api_key) elif binding == "aws_bedrock": from lightrag.llm.bedrock import bedrock_embed - actual_func = ( - bedrock_embed.func - if isinstance(bedrock_embed, EmbeddingFunc) - else bedrock_embed - ) - return await actual_func(texts, model=model) + return await bedrock_embed(texts, model=model) elif binding == "jina": from lightrag.llm.jina import jina_embed - actual_func = ( - jina_embed.func - if isinstance(jina_embed, EmbeddingFunc) - else jina_embed - ) - return await actual_func( - texts, - model=model, - embedding_dim=embedding_dim, - base_url=host, - api_key=api_key, - ) - elif binding == "gemini": - from lightrag.llm.gemini import gemini_embed - - actual_func = ( - gemini_embed.func - if isinstance(gemini_embed, EmbeddingFunc) - else gemini_embed - ) - - # Use pre-processed configuration if available - if config_cache.gemini_embedding_options is not None: - gemini_options = config_cache.gemini_embedding_options - else: - from lightrag.llm.binding_options import GeminiEmbeddingOptions - - gemini_options = GeminiEmbeddingOptions.options_dict(args) - - return await actual_func( - texts, - model=model, - base_url=host, - api_key=api_key, - embedding_dim=embedding_dim, - task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"), + return await jina_embed( + texts, dimensions=dimensions, base_url=host, api_key=api_key ) else: # openai and compatible from lightrag.llm.openai import openai_embed - actual_func = ( - openai_embed.func - if isinstance(openai_embed, EmbeddingFunc) - else openai_embed - ) - return await actual_func( - texts, - model=model, - base_url=host, - api_key=api_key, - embedding_dim=embedding_dim, + return await openai_embed( + texts, model=model, base_url=host, api_key=api_key ) except ImportError as e: raise Exception(f"Failed to import {binding} embedding: {e}") - # Step 4: Wrap in EmbeddingFunc and return - embedding_func_instance = EmbeddingFunc( - embedding_dim=final_embedding_dim, - func=optimized_embedding_function, - max_token_size=final_max_token_size, - send_dimensions=False, # Will be set later based on binding requirements - ) - - # Log final embedding configuration - logger.info( - f"Embedding config: binding={binding} model={model} " - f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}" - ) - - return embedding_func_instance + return optimized_embedding_function llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int) embedding_timeout = get_env_value( @@ -887,63 +684,20 @@ def create_app(args): **kwargs, ) - # Create 
embedding function with optimized configuration and max_token_size inheritance - import inspect - - # Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size) - embedding_func = create_optimized_embedding_function( - config_cache=config_cache, - binding=args.embedding_binding, - model=args.embedding_model, - host=args.embedding_binding_host, - api_key=args.embedding_binding_api_key, - args=args, + # Create embedding function with optimized configuration + embedding_func = EmbeddingFunc( + embedding_dim=args.embedding_dim, + func=create_optimized_embedding_function( + config_cache=config_cache, + binding=args.embedding_binding, + model=args.embedding_model, + host=args.embedding_binding_host, + api_key=args.embedding_binding_api_key, + dimensions=args.embedding_dim, + args=args, # Pass args object for fallback option generation + ), ) - # Get embedding_send_dim from centralized configuration - embedding_send_dim = args.embedding_send_dim - - # Check if the underlying function signature has embedding_dim parameter - sig = inspect.signature(embedding_func.func) - has_embedding_dim_param = "embedding_dim" in sig.parameters - - # Determine send_dimensions value based on binding type - # Jina and Gemini REQUIRE dimension parameter (forced to True) - # OpenAI and others: controlled by EMBEDDING_SEND_DIM environment variable - if args.embedding_binding in ["jina", "gemini"]: - # Jina and Gemini APIs require dimension parameter - always send it - send_dimensions = has_embedding_dim_param - dimension_control = f"forced by {args.embedding_binding.title()} API" - else: - # For OpenAI and other bindings, respect EMBEDDING_SEND_DIM setting - send_dimensions = embedding_send_dim and has_embedding_dim_param - if send_dimensions or not embedding_send_dim: - dimension_control = "by env var" - else: - dimension_control = "by not hasparam" - - # Set send_dimensions on the EmbeddingFunc instance - embedding_func.send_dimensions = send_dimensions - - logger.info( - f"Send embedding dimension: {send_dimensions} {dimension_control} " - f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, " - f"binding={args.embedding_binding})" - ) - - # Log max_token_size source - if embedding_func.max_token_size: - source = ( - "env variable" - if args.embedding_token_limit - else f"{args.embedding_binding} provider default" - ) - logger.info( - f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})" - ) - else: - logger.info("Embedding max_token_size: not set (90% token warning disabled)") - # Configure rerank function based on args.rerank_binding parameter rerank_model_func = None if args.rerank_binding != "null": @@ -1061,32 +815,10 @@ def create_app(args): ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key) app.include_router(ollama_api.router, prefix="/api") - # Custom Swagger UI endpoint for offline support - @app.get("/docs", include_in_schema=False) - async def custom_swagger_ui_html(): - """Custom Swagger UI HTML with local static files""" - return get_swagger_ui_html( - openapi_url=app.openapi_url, - title=app.title + " - Swagger UI", - oauth2_redirect_url="/docs/oauth2-redirect", - swagger_js_url="/static/swagger-ui/swagger-ui-bundle.js", - swagger_css_url="/static/swagger-ui/swagger-ui.css", - swagger_favicon_url="/static/swagger-ui/favicon-32x32.png", - swagger_ui_parameters=app.swagger_ui_parameters, - ) - - @app.get("/docs/oauth2-redirect", include_in_schema=False) - async def swagger_ui_redirect(): - """OAuth2 redirect
for Swagger UI""" - return get_swagger_ui_oauth2_redirect_html() - @app.get("/") async def redirect_to_webui(): - """Redirect root path based on WebUI availability""" - if webui_assets_exist: - return RedirectResponse(url="/webui") - else: - return RedirectResponse(url="/docs") + """Redirect root path to /webui""" + return RedirectResponse(url="/webui") @app.get("/auth-status") async def get_auth_status(): @@ -1104,7 +836,7 @@ def create_app(args): "auth_mode": "disabled", "message": "Authentication is disabled. Using guest access.", "core_version": core_version, - "api_version": api_version_display, + "api_version": __api_version__, "webui_title": webui_title, "webui_description": webui_description, } @@ -1113,7 +845,7 @@ def create_app(args): "auth_configured": True, "auth_mode": "enabled", "core_version": core_version, - "api_version": api_version_display, + "api_version": __api_version__, "webui_title": webui_title, "webui_description": webui_description, } @@ -1131,7 +863,7 @@ def create_app(args): "auth_mode": "disabled", "message": "Authentication is disabled. Using guest access.", "core_version": core_version, - "api_version": api_version_display, + "api_version": __api_version__, "webui_title": webui_title, "webui_description": webui_description, } @@ -1148,54 +880,16 @@ def create_app(args): "token_type": "bearer", "auth_mode": "enabled", "core_version": core_version, - "api_version": api_version_display, + "api_version": __api_version__, "webui_title": webui_title, "webui_description": webui_description, } - @app.get( - "/health", - dependencies=[Depends(combined_auth)], - summary="Get system health and configuration status", - description="Returns comprehensive system status including WebUI availability, configuration, and operational metrics", - response_description="System health status with configuration details", - responses={ - 200: { - "description": "Successful response with system status", - "content": { - "application/json": { - "example": { - "status": "healthy", - "webui_available": True, - "working_directory": "/path/to/working/dir", - "input_directory": "/path/to/input/dir", - "configuration": { - "llm_binding": "openai", - "llm_model": "gpt-4", - "embedding_binding": "openai", - "embedding_model": "text-embedding-ada-002", - "workspace": "default", - }, - "auth_mode": "enabled", - "pipeline_busy": False, - "core_version": "0.0.1", - "api_version": "0.0.1", - } - } - }, - } - }, - ) - async def get_status(request: Request): - """Get current system status including WebUI availability""" + @app.get("/health", dependencies=[Depends(combined_auth)]) + async def get_status(): + """Get current system status""" try: - workspace = get_workspace_from_request(request) - default_workspace = get_default_workspace() - if workspace is None: - workspace = default_workspace - pipeline_status = await get_namespace_data( - "pipeline_status", workspace=workspace - ) + pipeline_status = await get_namespace_data("pipeline_status") if not auth_configured: auth_mode = "disabled" @@ -1207,7 +901,6 @@ def create_app(args): return { "status": "healthy", - "webui_available": webui_assets_exist, "working_directory": str(args.working_dir), "input_directory": str(args.input_dir), "configuration": { @@ -1227,7 +920,7 @@ def create_app(args): "vector_storage": args.vector_storage, "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract, "enable_llm_cache": args.enable_llm_cache, - "workspace": default_workspace, + "workspace": args.workspace, "max_graph_nodes": args.max_graph_nodes, # 
Rerank configuration "enable_rerank": rerank_model_func is not None, @@ -1251,7 +944,7 @@ def create_app(args): "pipeline_busy": pipeline_status.get("busy", False), "keyed_locks": keyed_lock_info, "core_version": core_version, - "api_version": api_version_display, + "api_version": __api_version__, "webui_title": webui_title, "webui_description": webui_description, } @@ -1288,36 +981,16 @@ def create_app(args): return response - # Mount Swagger UI static files for offline support - swagger_static_dir = Path(__file__).parent / "static" / "swagger-ui" - if swagger_static_dir.exists(): - app.mount( - "/static/swagger-ui", - StaticFiles(directory=swagger_static_dir), - name="swagger-ui-static", - ) - - # Conditionally mount WebUI only if assets exist - if webui_assets_exist: - static_dir = Path(__file__).parent / "webui" - static_dir.mkdir(exist_ok=True) - app.mount( - "/webui", - SmartStaticFiles( - directory=static_dir, html=True, check_dir=True - ), # Use SmartStaticFiles - name="webui", - ) - logger.info("WebUI assets mounted at /webui") - else: - logger.info("WebUI assets not available, /webui route not mounted") - - # Add redirect for /webui when assets are not available - @app.get("/webui") - @app.get("/webui/") - async def webui_redirect_to_docs(): - """Redirect /webui to /docs when WebUI is not available""" - return RedirectResponse(url="/docs") + # Webui mount webui/index.html + static_dir = Path(__file__).parent / "webui" + static_dir.mkdir(exist_ok=True) + app.mount( + "/webui", + SmartStaticFiles( + directory=static_dir, html=True, check_dir=True + ), # Use SmartStaticFiles + name="webui", + ) return app @@ -1427,12 +1100,6 @@ def check_and_install_dependencies(): def main(): - # Explicitly initialize configuration for clarity - # (The proxy will auto-initialize anyway, but this makes intent clear) - from .config import initialize_config - - initialize_config() - # Check if running under Gunicorn if "GUNICORN_CMD_ARGS" in os.environ: # If started with Gunicorn, return directly as Gunicorn will call get_application @@ -1455,10 +1122,8 @@ def main(): update_uvicorn_mode_config() display_splash_screen(global_args) - # Note: Signal handlers are NOT registered here because: - # - Uvicorn has built-in signal handling that properly calls lifespan shutdown - # - Custom signal handlers can interfere with uvicorn's graceful shutdown - # - Cleanup is handled by the lifespan context manager's finally block + # Setup signal handlers for graceful shutdown + setup_signal_handlers() # Create application instance directly instead of using factory function app = create_app(global_args) diff --git a/lightrag/llm/binding_options.py b/lightrag/llm/binding_options.py index 1cb52a81..e2f94649 100644 --- a/lightrag/llm/binding_options.py +++ b/lightrag/llm/binding_options.py @@ -228,7 +228,7 @@ class BindingOptions: argdef = { "argname": f"{args_prefix}-{field.name}", "env_name": f"{env_var_prefix}{field.name.upper()}", - "type": _resolve_optional_type(field.type), + "type": _resolve_optional_type(field.type), "default": default_value, "help": f"{cls._binding_name} -- " + help.get(field.name, ""), } @@ -472,9 +472,6 @@ class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions): _binding_name: ClassVar[str] = "ollama_llm" -# ============================================================================= -# Binding Options for Gemini -# ============================================================================= @dataclass class GeminiLLMOptions(BindingOptions): """Options for Google Gemini models.""" @@ 
-489,9 +486,9 @@ class GeminiLLMOptions(BindingOptions): presence_penalty: float = 0.0 frequency_penalty: float = 0.0 stop_sequences: List[str] = field(default_factory=list) - seed: int | None = None - thinking_config: dict | None = None + response_mime_type: str | None = None safety_settings: dict | None = None + system_instruction: str | None = None _help: ClassVar[dict[str, str]] = { "temperature": "Controls randomness (0.0-2.0, higher = more creative)", @@ -501,23 +498,10 @@ class GeminiLLMOptions(BindingOptions): "candidate_count": "Number of candidates returned per request", "presence_penalty": "Penalty for token presence (-2.0 to 2.0)", "frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)", - "stop_sequences": "Stop sequences (JSON array of strings, e.g., '[\"END\"]')", - "seed": "Random seed for reproducible generation (leave empty for random)", - "thinking_config": "Thinking configuration (JSON dict, e.g., '{\"thinking_budget\": 1024}' or '{\"include_thoughts\": true}')", + "stop_sequences": 'Stop sequences (JSON array of strings, e.g., \'["END"]\')', + "response_mime_type": "Desired MIME type for the response (e.g., application/json)", "safety_settings": "JSON object with Gemini safety settings overrides", - } - - -@dataclass -class GeminiEmbeddingOptions(BindingOptions): - """Options for Google Gemini embedding models.""" - - _binding_name: ClassVar[str] = "gemini_embedding" - - task_type: str = "RETRIEVAL_DOCUMENT" - - _help: ClassVar[dict[str, str]] = { - "task_type": "Task type for embedding optimization (RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING, CODE_RETRIEVAL_QUERY, QUESTION_ANSWERING, FACT_VERIFICATION)", + "system_instruction": "Default system instruction applied to every request", } diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py index 5372307e..14a1b238 100644 --- a/lightrag/llm/gemini.py +++ b/lightrag/llm/gemini.py @@ -16,20 +16,7 @@ from collections.abc import AsyncIterator from functools import lru_cache from typing import Any -import numpy as np -from tenacity import ( - retry, - stop_after_attempt, - wait_exponential, - retry_if_exception_type, -) - -from lightrag.utils import ( - logger, - remove_think_tags, - safe_unicode_decode, - wrap_embedding_func_with_attrs, -) +from lightrag.utils import logger, remove_think_tags, safe_unicode_decode import pipmaster as pm @@ -46,33 +33,24 @@ LOG = logging.getLogger(__name__) @lru_cache(maxsize=8) -def _get_gemini_client( - api_key: str, base_url: str | None, timeout: int | None = None -) -> genai.Client: +def _get_gemini_client(api_key: str, base_url: str | None) -> genai.Client: """ Create (or fetch cached) Gemini client. Args: api_key: Google Gemini API key. base_url: Optional custom API endpoint. - timeout: Optional request timeout in milliseconds. Returns: genai.Client: Configured Gemini client instance. 
""" client_kwargs: dict[str, Any] = {"api_key": api_key} - if base_url and base_url != DEFAULT_GEMINI_ENDPOINT or timeout is not None: + if base_url and base_url != DEFAULT_GEMINI_ENDPOINT: try: - http_options_kwargs = {} - if base_url and base_url != DEFAULT_GEMINI_ENDPOINT: - http_options_kwargs["api_endpoint"] = base_url - if timeout is not None: - http_options_kwargs["timeout"] = timeout - - client_kwargs["http_options"] = types.HttpOptions(**http_options_kwargs) + client_kwargs["http_options"] = types.HttpOptions(api_endpoint=base_url) except Exception as exc: # pragma: no cover - defensive - LOG.warning("Failed to apply custom Gemini http_options: %s", exc) + LOG.warning("Failed to apply custom Gemini endpoint %s: %s", base_url, exc) try: return genai.Client(**client_kwargs) @@ -136,44 +114,24 @@ def _format_history_messages(history_messages: list[dict[str, Any]] | None) -> s return "\n".join(history_lines) -def _extract_response_text( - response: Any, extract_thoughts: bool = False -) -> tuple[str, str]: - """ - Extract text content from Gemini response, separating regular content from thoughts. +def _extract_response_text(response: Any) -> str: + if getattr(response, "text", None): + return response.text - Args: - response: Gemini API response object - extract_thoughts: Whether to extract thought content separately - - Returns: - Tuple of (regular_text, thought_text) - """ candidates = getattr(response, "candidates", None) if not candidates: - return ("", "") - - regular_parts: list[str] = [] - thought_parts: list[str] = [] + return "" + parts: list[str] = [] for candidate in candidates: if not getattr(candidate, "content", None): continue - # Use 'or []' to handle None values from parts attribute - for part in getattr(candidate.content, "parts", None) or []: + for part in getattr(candidate.content, "parts", []): text = getattr(part, "text", None) - if not text: - continue + if text: + parts.append(text) - # Check if this part is thought content using the 'thought' attribute - is_thought = getattr(part, "thought", False) - - if is_thought and extract_thoughts: - thought_parts.append(text) - elif not is_thought: - regular_parts.append(text) - - return ("\n".join(regular_parts), "\n".join(thought_parts)) + return "\n".join(parts) async def gemini_complete_if_cache( @@ -181,58 +139,22 @@ async def gemini_complete_if_cache( prompt: str, system_prompt: str | None = None, history_messages: list[dict[str, Any]] | None = None, - enable_cot: bool = False, - base_url: str | None = None, + *, api_key: str | None = None, - token_tracker: Any | None = None, - stream: bool | None = None, - keyword_extraction: bool = False, + base_url: str | None = None, generation_config: dict[str, Any] | None = None, - timeout: int | None = None, + keyword_extraction: bool = False, + token_tracker: Any | None = None, + hashing_kv: Any | None = None, # noqa: ARG001 - present for interface parity + stream: bool | None = None, + enable_cot: bool = False, # noqa: ARG001 - not supported by Gemini currently + timeout: float | None = None, # noqa: ARG001 - handled by caller if needed **_: Any, ) -> str | AsyncIterator[str]: - """ - Complete a prompt using Gemini's API with Chain of Thought (COT) support. - - This function supports automatic integration of reasoning content from Gemini models - that provide Chain of Thought capabilities via the thinking_config API feature. - - COT Integration: - - When enable_cot=True: Thought content is wrapped in ... 
tags - - When enable_cot=False: Thought content is filtered out, only regular content returned - - Thought content is identified by the 'thought' attribute on response parts - - Requires thinking_config to be enabled in generation_config for API to return thoughts - - Args: - model: The Gemini model to use. - prompt: The prompt to complete. - system_prompt: Optional system prompt to include. - history_messages: Optional list of previous messages in the conversation. - api_key: Optional Gemini API key. If None, uses environment variable. - base_url: Optional custom API endpoint. - generation_config: Optional generation configuration dict. - keyword_extraction: Whether to use JSON response format. - token_tracker: Optional token usage tracker for monitoring API usage. - stream: Whether to stream the response. - hashing_kv: Storage interface (for interface parity with other bindings). - enable_cot: Whether to include Chain of Thought content in the response. - timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API). - **_: Additional keyword arguments (ignored). - - Returns: - The completed text (with COT content if enable_cot=True) or an async iterator - of text chunks if streaming. COT content is wrapped in ... tags. - - Raises: - RuntimeError: If the response from Gemini is empty. - ValueError: If API key is not provided or configured. - """ loop = asyncio.get_running_loop() key = _ensure_api_key(api_key) - # Convert timeout from seconds to milliseconds for Gemini API - timeout_ms = timeout * 1000 if timeout else None - client = _get_gemini_client(key, base_url, timeout_ms) + client = _get_gemini_client(key, base_url) history_block = _format_history_messages(history_messages) prompt_sections = [] @@ -262,11 +184,6 @@ async def gemini_complete_if_cache( usage_container: dict[str, Any] = {} def _stream_model() -> None: - # COT state tracking for streaming - cot_active = False - cot_started = False - initial_content_seen = False - try: stream_kwargs = dict(request_kwargs) stream_iterator = client.models.generate_content_stream(**stream_kwargs) @@ -274,61 +191,18 @@ async def gemini_complete_if_cache( usage = getattr(chunk, "usage_metadata", None) if usage is not None: usage_container["usage"] = usage - - # Extract both regular and thought content - regular_text, thought_text = _extract_response_text( - chunk, extract_thoughts=True - ) - - if enable_cot: - # Process regular content - if regular_text: - if not initial_content_seen: - initial_content_seen = True - - # Close COT section if it was active - if cot_active: - loop.call_soon_threadsafe(queue.put_nowait, "") - cot_active = False - - # Send regular content - loop.call_soon_threadsafe(queue.put_nowait, regular_text) - - # Process thought content - if thought_text: - if not initial_content_seen and not cot_started: - # Start COT section - loop.call_soon_threadsafe(queue.put_nowait, "") - cot_active = True - cot_started = True - - # Send thought content if COT is active - if cot_active: - loop.call_soon_threadsafe( - queue.put_nowait, thought_text - ) - else: - # COT disabled - only send regular content - if regular_text: - loop.call_soon_threadsafe(queue.put_nowait, regular_text) - - # Ensure COT is properly closed if still active - if cot_active: - loop.call_soon_threadsafe(queue.put_nowait, "") - + text_piece = getattr(chunk, "text", None) or _extract_response_text(chunk) + if text_piece: + loop.call_soon_threadsafe(queue.put_nowait, text_piece) loop.call_soon_threadsafe(queue.put_nowait, None) except 
Exception as exc: # pragma: no cover - surface runtime issues - # Try to close COT tag before reporting error - if cot_active: - try: - loop.call_soon_threadsafe(queue.put_nowait, "") - except Exception: - pass loop.call_soon_threadsafe(queue.put_nowait, exc) loop.run_in_executor(None, _stream_model) async def _async_stream() -> AsyncIterator[str]: + accumulated = "" + emitted = "" try: while True: item = await queue.get() @@ -341,9 +215,16 @@ async def gemini_complete_if_cache( if "\\u" in chunk_text: chunk_text = safe_unicode_decode(chunk_text.encode("utf-8")) - # Yield the chunk directly without filtering - # COT filtering is already handled in _stream_model() - yield chunk_text + accumulated += chunk_text + sanitized = remove_think_tags(accumulated) + if sanitized.startswith(emitted): + delta = sanitized[len(emitted) :] + else: + delta = sanitized + emitted = sanitized + + if delta: + yield delta finally: usage = usage_container.get("usage") if token_tracker and usage: @@ -361,33 +242,14 @@ async def gemini_complete_if_cache( response = await asyncio.to_thread(_call_model) - # Extract both regular text and thought text - regular_text, thought_text = _extract_response_text(response, extract_thoughts=True) - - # Apply COT filtering logic based on enable_cot parameter - if enable_cot: - # Include thought content wrapped in tags - if thought_text and thought_text.strip(): - if not regular_text or regular_text.strip() == "": - # Only thought content available - final_text = f"{thought_text}" - else: - # Both content types present: prepend thought to regular content - final_text = f"{thought_text}{regular_text}" - else: - # No thought content, use regular content only - final_text = regular_text or "" - else: - # Filter out thought content, return only regular content - final_text = regular_text or "" - - if not final_text: + text = _extract_response_text(response) + if not text: raise RuntimeError("Gemini response did not contain any text content.") - if "\\u" in final_text: - final_text = safe_unicode_decode(final_text.encode("utf-8")) + if "\\u" in text: + text = safe_unicode_decode(text.encode("utf-8")) - final_text = remove_think_tags(final_text) + text = remove_think_tags(text) usage = getattr(response, "usage_metadata", None) if token_tracker and usage: @@ -399,8 +261,8 @@ async def gemini_complete_if_cache( } ) - logger.debug("Gemini response length: %s", len(final_text)) - return final_text + logger.debug("Gemini response length: %s", len(text)) + return text async def gemini_model_complete( @@ -429,136 +291,7 @@ async def gemini_model_complete( ) -@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=2048) -@retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=60), - retry=( - retry_if_exception_type(Exception) # Gemini uses generic exceptions - ), -) -async def gemini_embed( - texts: list[str], - model: str = "gemini-embedding-001", - base_url: str | None = None, - api_key: str | None = None, - embedding_dim: int | None = None, - task_type: str = "RETRIEVAL_DOCUMENT", - timeout: int | None = None, - token_tracker: Any | None = None, -) -> np.ndarray: - """Generate embeddings for a list of texts using Gemini's API. - - This function uses Google's Gemini embedding model to generate text embeddings. - It supports dynamic dimension control and automatic normalization for dimensions - less than 3072. - - Args: - texts: List of texts to embed. - model: The Gemini embedding model to use. Default is "gemini-embedding-001". 
- base_url: Optional custom API endpoint. - api_key: Optional Gemini API key. If None, uses environment variables. - embedding_dim: Optional embedding dimension for dynamic dimension reduction. - **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper. - Do NOT manually pass this parameter when calling the function directly. - The dimension is controlled by the @wrap_embedding_func_with_attrs decorator - or the EMBEDDING_DIM environment variable. - Supported range: 128-3072. Recommended values: 768, 1536, 3072. - task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT". - Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING, - RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY, - QUESTION_ANSWERING, FACT_VERIFICATION. - timeout: Request timeout in seconds (will be converted to milliseconds for Gemini API). - token_tracker: Optional token usage tracker for monitoring API usage. - - Returns: - A numpy array of embeddings, one per input text. For dimensions < 3072, - the embeddings are L2-normalized to ensure optimal semantic similarity performance. - - Raises: - ValueError: If API key is not provided or configured. - RuntimeError: If the response from Gemini is invalid or empty. - - Note: - - For dimension 3072: Embeddings are already normalized by the API - - For dimensions < 3072: Embeddings are L2-normalized after retrieval - - Normalization ensures accurate semantic similarity via cosine distance - """ - loop = asyncio.get_running_loop() - - key = _ensure_api_key(api_key) - # Convert timeout from seconds to milliseconds for Gemini API - timeout_ms = timeout * 1000 if timeout else None - client = _get_gemini_client(key, base_url, timeout_ms) - - # Prepare embedding configuration - config_kwargs: dict[str, Any] = {} - - # Add task_type to config - if task_type: - config_kwargs["task_type"] = task_type - - # Add output_dimensionality if embedding_dim is provided - if embedding_dim is not None: - config_kwargs["output_dimensionality"] = embedding_dim - - # Create config object if we have parameters - config_obj = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None - - def _call_embed() -> Any: - """Call Gemini embedding API in executor thread.""" - request_kwargs: dict[str, Any] = { - "model": model, - "contents": texts, - } - if config_obj is not None: - request_kwargs["config"] = config_obj - - return client.models.embed_content(**request_kwargs) - - # Execute API call in thread pool - response = await loop.run_in_executor(None, _call_embed) - - # Extract embeddings from response - if not hasattr(response, "embeddings") or not response.embeddings: - raise RuntimeError("Gemini response did not contain embeddings.") - - # Convert embeddings to numpy array - embeddings = np.array( - [np.array(e.values, dtype=np.float32) for e in response.embeddings] - ) - - # Apply L2 normalization for dimensions < 3072 - # The 3072 dimension embedding is already normalized by Gemini API - if embedding_dim and embedding_dim < 3072: - # Normalize each embedding vector to unit length - norms = np.linalg.norm(embeddings, axis=1, keepdims=True) - # Avoid division by zero - norms = np.where(norms == 0, 1, norms) - embeddings = embeddings / norms - logger.debug( - f"Applied L2 normalization to {len(embeddings)} embeddings of dimension {embedding_dim}" - ) - - # Track token usage if tracker is provided - # Note: Gemini embedding API may not provide usage metadata - if token_tracker and hasattr(response, "usage_metadata"): - 
usage = response.usage_metadata - token_counts = { - "prompt_tokens": getattr(usage, "prompt_token_count", 0), - "total_tokens": getattr(usage, "total_token_count", 0), - } - token_tracker.add_usage(token_counts) - - logger.debug( - f"Generated {len(embeddings)} Gemini embeddings with dimension {embeddings.shape[1]}" - ) - - return embeddings - - __all__ = [ "gemini_complete_if_cache", "gemini_model_complete", - "gemini_embed", ] diff --git a/pyproject.toml b/pyproject.toml index 9650a1cf..a4f16ddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,13 +24,12 @@ dependencies = [ "aiohttp", "configparser", "future", - "google-api-core>=2.0.0,<3.0.0", "google-genai>=1.0.0,<2.0.0", "json_repair", "nano-vectordb", "networkx", - "numpy>=1.24.0,<2.0.0", - "pandas>=2.0.0,<2.4.0", + "numpy", + "pandas>=2.0.0,<2.3.0", "pipmaster", "pydantic", "pypinyin", @@ -42,14 +41,6 @@ dependencies = [ ] [project.optional-dependencies] -# Test framework dependencies (for CI/CD and testing) -pytest = [ - "pytest>=8.4.2", - "pytest-asyncio>=1.2.0", - "pre-commit", - "ruff", -] - api = [ # Core dependencies "aiohttp", @@ -58,9 +49,9 @@ api = [ "json_repair", "nano-vectordb", "networkx", - "numpy>=1.24.0,<2.0.0", - "openai>=1.0.0,<3.0.0", - "pandas>=2.0.0,<2.4.0", + "numpy", + "openai>=1.0.0,<2.0.0", + "pandas>=2.0.0,<2.3.0", "pipmaster", "pydantic", "pypinyin", @@ -69,7 +60,6 @@ api = [ "tenacity", "tiktoken", "xlsxwriter>=3.1.0", - "google-api-core>=2.0.0,<3.0.0", "google-genai>=1.0.0,<2.0.0", # API-specific dependencies "aiofiles", @@ -78,7 +68,7 @@ api = [ "distro", "fastapi", "httpcore", - "httpx>=0.28.1", + "httpx", "jiter", "passlib[bcrypt]", "psutil", @@ -87,68 +77,42 @@ api = [ "python-multipart", "pytz", "uvicorn", - "gunicorn", - # Document processing dependencies (required for API document upload functionality) - "openpyxl>=3.0.0,<4.0.0", # XLSX processing - "pycryptodome>=3.0.0,<4.0.0", # PDF encryption support - "pypdf>=6.1.0", # PDF processing - "python-docx>=0.8.11,<2.0.0", # DOCX processing - "python-pptx>=0.6.21,<2.0.0", # PPTX processing -] - -# Advanced document processing engine (optional) -docling = [ - # On macOS, pytorch and frameworks use Objective-C are not fork-safe, - # and not compatible to gunicorn multi-worker mode - "docling>=2.0.0,<3.0.0; sys_platform != 'darwin'", ] # Offline deployment dependencies (layered design for flexibility) +offline-docs = [ + # Document processing dependencies + "pypdf2>=3.0.0", + "python-docx>=0.8.11,<2.0.0", + "python-pptx>=0.6.21,<2.0.0", + "openpyxl>=3.0.0,<4.0.0", +] + offline-storage = [ # Storage backend dependencies - "redis>=5.0.0,<8.0.0", + "redis>=5.0.0,<7.0.0", "neo4j>=5.0.0,<7.0.0", "pymilvus>=2.6.2,<3.0.0", "pymongo>=4.0.0,<5.0.0", "asyncpg>=0.29.0,<1.0.0", - "qdrant-client>=1.11.0,<2.0.0", + "qdrant-client>=1.7.0,<2.0.0", ] offline-llm = [ # LLM provider dependencies - "openai>=1.0.0,<3.0.0", + "openai>=1.0.0,<2.0.0", "anthropic>=0.18.0,<1.0.0", "ollama>=0.1.0,<1.0.0", "zhipuai>=2.0.0,<3.0.0", "aioboto3>=12.0.0,<16.0.0", "voyageai>=0.2.0,<1.0.0", "llama-index>=0.9.0,<1.0.0", - "google-api-core>=2.0.0,<3.0.0", "google-genai>=1.0.0,<2.0.0", ] offline = [ - # Complete offline package (includes api for document processing, plus storage and LLM) - "lightrag-hku[api,offline-storage,offline-llm]", -] - -test = [ - "lightrag-hku[api]", - "pytest>=8.4.2", - "pytest-asyncio>=1.2.0", - "pre-commit", - "ruff", -] - -evaluation = [ - "lightrag-hku[api]", - "ragas>=0.3.7", - "datasets>=4.3.0", -] - -observability = [ - # LLM observability and tracing 
dependencies - "langfuse>=3.8.1", + # Complete offline package (includes all offline dependencies) + "lightrag-hku[offline-docs,offline-storage,offline-llm]", ] [project.scripts] @@ -173,15 +137,7 @@ include-package-data = true version = {attr = "lightrag.__version__"} [tool.setuptools.package-data] -lightrag = ["api/webui/**/*", "api/static/**/*"] - -[tool.pytest.ini_options] -asyncio_mode = "auto" -asyncio_default_fixture_loop_scope = "function" -testpaths = ["tests"] -python_files = ["test_*.py"] -python_classes = ["Test*"] -python_functions = ["test_*"] +lightrag = ["api/webui/**/*"] [tool.ruff] target-version = "py310" diff --git a/requirements-offline-llm.txt b/requirements-offline-llm.txt index cb0121b7..441abc6e 100644 --- a/requirements-offline-llm.txt +++ b/requirements-offline-llm.txt @@ -3,14 +3,16 @@ # For offline installation: # pip download -r requirements-offline-llm.txt -d ./packages # pip install --no-index --find-links=./packages -r requirements-offline-llm.txt +# +# Recommended: Use pip install lightrag-hku[offline-llm] for the same effect +# Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline-llm.txt -aioboto3>=12.0.0 -anthropic>=0.18.0 -llama-index>=0.9.0 -ollama>=0.1.0 -# LLM provider dependencies -openai>=1.0.0 -torch>=2.0.0 -transformers>=4.30.0 -voyageai>=0.2.0 -zhipuai>=2.0.0 +# LLM provider dependencies (with version constraints matching pyproject.toml) +aioboto3>=12.0.0,<16.0.0 +anthropic>=0.18.0,<1.0.0 +llama-index>=0.9.0,<1.0.0 +ollama>=0.1.0,<1.0.0 +openai>=1.0.0,<2.0.0 +google-genai>=1.0.0,<2.0.0 +voyageai>=0.2.0,<1.0.0 +zhipuai>=2.0.0,<3.0.0 diff --git a/requirements-offline.txt b/requirements-offline.txt index 0582eaca..d6943b11 100644 --- a/requirements-offline.txt +++ b/requirements-offline.txt @@ -13,21 +13,20 @@ anthropic>=0.18.0,<1.0.0 # Storage backend dependencies asyncpg>=0.29.0,<1.0.0 -google-genai>=1.0.0,<2.0.0 # Document processing dependencies llama-index>=0.9.0,<1.0.0 neo4j>=5.0.0,<7.0.0 ollama>=0.1.0,<1.0.0 -openai>=1.0.0,<3.0.0 +openai>=1.0.0,<2.0.0 +google-genai>=1.0.0,<2.0.0 openpyxl>=3.0.0,<4.0.0 -pycryptodome>=3.0.0,<4.0.0 pymilvus>=2.6.2,<3.0.0 pymongo>=4.0.0,<5.0.0 pypdf2>=3.0.0 python-docx>=0.8.11,<2.0.0 python-pptx>=0.6.21,<2.0.0 -qdrant-client>=1.11.0,<2.0.0 -redis>=5.0.0,<8.0.0 +qdrant-client>=1.7.0,<2.0.0 +redis>=5.0.0,<7.0.0 voyageai>=0.2.0,<1.0.0 zhipuai>=2.0.0,<3.0.0
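
For reference, the incremental chain-of-thought filtering that this patch introduces in `_async_stream()` works by re-sanitizing the full accumulated text on every chunk and emitting only the previously unseen suffix, so `<think>...</think>` spans never reach the caller even when a tag is split across chunk boundaries. Below is a minimal, self-contained sketch of that pattern; the toy `fake_model_stream` and the regex-based stand-in for `remove_think_tags` are illustrative assumptions for the demo, not LightRAG APIs.

```python
import asyncio
import re
from typing import AsyncIterator


def remove_think_tags(text: str) -> str:
    # Stand-in for LightRAG's helper: drop <think>...</think> spans,
    # including a still-open span at the end of the buffer.
    return re.sub(r"<think>.*?(</think>|$)", "", text, flags=re.DOTALL)


async def fake_model_stream() -> AsyncIterator[str]:
    # Pretend the model interleaves chain-of-thought with the answer,
    # with the closing tag split across chunk boundaries.
    for chunk in ["<think>reason", "ing...</thi", "nk>Hello", ", world"]:
        yield chunk


async def filtered_stream(source: AsyncIterator[str]) -> AsyncIterator[str]:
    accumulated = ""  # everything received from the model so far
    emitted = ""      # sanitized text already yielded downstream
    async for chunk_text in source:
        accumulated += chunk_text
        sanitized = remove_think_tags(accumulated)
        # Emit only the new suffix; fall back to the whole sanitized
        # string if sanitization rewrote already-emitted text.
        if sanitized.startswith(emitted):
            delta = sanitized[len(emitted):]
        else:
            delta = sanitized
        emitted = sanitized
        if delta:
            yield delta


async def main() -> None:
    async for piece in filtered_stream(fake_model_stream()):
        print(repr(piece))  # prints 'Hello' then ', world'


asyncio.run(main())
```

Re-sanitizing the whole buffer on each chunk is quadratic in the worst case, but it keeps the filter correct across arbitrary chunk boundaries without hand-tracking tag state; the `startswith` guard covers the rare case where sanitization shortens text that was already emitted.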