From 7f5afd0a4df09c4053c174b0f7c23d156c45334f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?=
Date: Thu, 4 Dec 2025 19:14:27 +0800
Subject: [PATCH] cherry-pick 5dec4dea

---
 lightrag/api/lightrag_server.py | 202 ++++++++++++++++++++++++++------
 1 file changed, 166 insertions(+), 36 deletions(-)

diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 939129c0..41a07f7f 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -618,33 +618,108 @@ def create_app(args):
 
     def create_optimized_embedding_function(
         config_cache: LLMConfigCache, binding, model, host, api_key, args
-    ):
+    ) -> EmbeddingFunc:
         """
-        Create optimized embedding function with pre-processed configuration for applicable bindings.
-        Uses lazy imports for all bindings and avoids repeated configuration parsing.
+        Create optimized embedding function and return an EmbeddingFunc instance
+        with proper max_token_size inheritance from provider defaults.
+
+        This function:
+        1. Imports the provider embedding function
+        2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
+        3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
+        4. Returns a properly configured EmbeddingFunc instance
         """
+        # Step 1: Import provider function and extract default attributes
+        provider_func = None
+        provider_max_token_size = None
+        provider_embedding_dim = None
+
+        try:
+            if binding == "openai":
+                from lightrag.llm.openai import openai_embed
+
+                provider_func = openai_embed
+            elif binding == "ollama":
+                from lightrag.llm.ollama import ollama_embed
+
+                provider_func = ollama_embed
+            elif binding == "gemini":
+                from lightrag.llm.gemini import gemini_embed
+
+                provider_func = gemini_embed
+            elif binding == "jina":
+                from lightrag.llm.jina import jina_embed
+
+                provider_func = jina_embed
+            elif binding == "azure_openai":
+                from lightrag.llm.azure_openai import azure_openai_embed
+
+                provider_func = azure_openai_embed
+            elif binding == "aws_bedrock":
+                from lightrag.llm.bedrock import bedrock_embed
+
+                provider_func = bedrock_embed
+            elif binding == "lollms":
+                from lightrag.llm.lollms import lollms_embed
+
+                provider_func = lollms_embed
+
+            # Extract attributes if provider is an EmbeddingFunc
+            if provider_func and isinstance(provider_func, EmbeddingFunc):
+                provider_max_token_size = provider_func.max_token_size
+                provider_embedding_dim = provider_func.embedding_dim
+                logger.debug(
+                    f"Extracted from {binding} provider: "
+                    f"max_token_size={provider_max_token_size}, "
+                    f"embedding_dim={provider_embedding_dim}"
+                )
+        except ImportError as e:
+            logger.warning(f"Could not import provider function for {binding}: {e}")
+
+        # Step 2: Apply priority (user config > provider default)
+        # For max_token_size: explicit env var > provider default > None
+        final_max_token_size = args.embedding_token_limit or provider_max_token_size
+        # For embedding_dim: the user config takes priority; fall back to the
+        # provider default only if the user value is unset (which should not
+        # happen in practice, since args.embedding_dim always carries a value)
+        final_embedding_dim = (
+            args.embedding_dim if args.embedding_dim else provider_embedding_dim
+        )
+
+        # Step 3: Create optimized embedding function (calls underlying function directly)
         async def optimized_embedding_function(texts, embedding_dim=None):
             try:
                 if binding == "lollms":
                     from lightrag.llm.lollms import lollms_embed
 
-                    return await lollms_embed(
+                    # Get real function, skip EmbeddingFunc wrapper if present
+                    actual_func = (
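+                        # Providers shipped as EmbeddingFunc instances expose the
+                        # raw coroutine on .func; calling it directly keeps the
+                        # instance built in Step 4 as the only wrapper.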
+                        lollms_embed.func
+                        if isinstance(lollms_embed, EmbeddingFunc)
+                        else lollms_embed
+                    )
+                    return await actual_func(
                         texts, embed_model=model, host=host, api_key=api_key
                     )
                 elif binding == "ollama":
                     from lightrag.llm.ollama import ollama_embed
 
-                    # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
+                    # Get real function, skip EmbeddingFunc wrapper if present
+                    actual_func = (
+                        ollama_embed.func
+                        if isinstance(ollama_embed, EmbeddingFunc)
+                        else ollama_embed
+                    )
+
+                    # Use pre-processed configuration if available
                     if config_cache.ollama_embedding_options is not None:
                         ollama_options = config_cache.ollama_embedding_options
                     else:
-                        # Fallback for cases where config cache wasn't initialized properly
                         from lightrag.llm.binding_options import OllamaEmbeddingOptions
 
                         ollama_options = OllamaEmbeddingOptions.options_dict(args)
 
-                    return await ollama_embed(
+                    return await actual_func(
                         texts,
                         embed_model=model,
                         host=host,
@@ -654,30 +729,53 @@ def create_app(args):
                 elif binding == "azure_openai":
                     from lightrag.llm.azure_openai import azure_openai_embed
 
-                    return await azure_openai_embed(texts, model=model, api_key=api_key)
+                    actual_func = (
+                        azure_openai_embed.func
+                        if isinstance(azure_openai_embed, EmbeddingFunc)
+                        else azure_openai_embed
+                    )
+                    return await actual_func(texts, model=model, api_key=api_key)
                 elif binding == "aws_bedrock":
                     from lightrag.llm.bedrock import bedrock_embed
 
-                    return await bedrock_embed(texts, model=model)
+                    actual_func = (
+                        bedrock_embed.func
+                        if isinstance(bedrock_embed, EmbeddingFunc)
+                        else bedrock_embed
+                    )
+                    return await actual_func(texts, model=model)
                 elif binding == "jina":
                     from lightrag.llm.jina import jina_embed
 
-                    return await jina_embed(
-                        texts, embedding_dim=embedding_dim, base_url=host, api_key=api_key
+                    actual_func = (
+                        jina_embed.func
+                        if isinstance(jina_embed, EmbeddingFunc)
+                        else jina_embed
+                    )
+                    return await actual_func(
+                        texts,
+                        embedding_dim=embedding_dim,
+                        base_url=host,
+                        api_key=api_key,
                     )
                 elif binding == "gemini":
                     from lightrag.llm.gemini import gemini_embed
 
-                    # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
+                    actual_func = (
+                        gemini_embed.func
+                        if isinstance(gemini_embed, EmbeddingFunc)
+                        else gemini_embed
+                    )
+
+                    # Use pre-processed configuration if available
                     if config_cache.gemini_embedding_options is not None:
                         gemini_options = config_cache.gemini_embedding_options
                     else:
-                        # Fallback for cases where config cache wasn't initialized properly
                         from lightrag.llm.binding_options import GeminiEmbeddingOptions
 
                         gemini_options = GeminiEmbeddingOptions.options_dict(args)
 
-                    return await gemini_embed(
+                    return await actual_func(
                         texts,
                         model=model,
                         base_url=host,
@@ -688,13 +786,36 @@ def create_app(args):
                 else:  # openai and compatible
                     from lightrag.llm.openai import openai_embed
 
-                    return await openai_embed(
-                        texts, model=model, base_url=host, api_key=api_key, embedding_dim=embedding_dim
+                    actual_func = (
+                        openai_embed.func
+                        if isinstance(openai_embed, EmbeddingFunc)
+                        else openai_embed
+                    )
+                    return await actual_func(
+                        texts,
+                        model=model,
+                        base_url=host,
+                        api_key=api_key,
+                        embedding_dim=embedding_dim,
                     )
             except ImportError as e:
                 raise Exception(f"Failed to import {binding} embedding: {e}")
 
-        return optimized_embedding_function
+        # Step 4: Wrap in EmbeddingFunc and return
+        embedding_func_instance = EmbeddingFunc(
+            embedding_dim=final_embedding_dim,
+            func=optimized_embedding_function,
+            max_token_size=final_max_token_size,
+            send_dimensions=False,  # Will be set later based on binding requirements
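+            # max_token_size above resolves as: EMBEDDING_TOKEN_LIMIT env
+            # override first, then the provider default from Step 1, else None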
+        )
+
+        # Log final embedding configuration
+        logger.info(
+            f"Embedding config: binding={binding} model={model} "
+            f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
+        )
+
+        return embedding_func_instance
 
     llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
     embedding_timeout = get_env_value(
@@ -728,25 +849,24 @@ def create_app(args):
             **kwargs,
         )
 
-    # Create embedding function with optimized configuration
+    # Create embedding function with optimized configuration and max_token_size inheritance
     import inspect
 
-    # Create the optimized embedding function
-    optimized_embedding_func = create_optimized_embedding_function(
+    # Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
+    embedding_func = create_optimized_embedding_function(
         config_cache=config_cache,
         binding=args.embedding_binding,
         model=args.embedding_model,
         host=args.embedding_binding_host,
         api_key=args.embedding_binding_api_key,
-        args=args,  # Pass args object for fallback option generation
+        args=args,
    )
 
     # Get embedding_send_dim from centralized configuration
     embedding_send_dim = args.embedding_send_dim
 
-    # Check if the function signature has embedding_dim parameter
-    # Note: Since optimized_embedding_func is an async function, inspect its signature
-    sig = inspect.signature(optimized_embedding_func)
+    # Check if the underlying function signature has embedding_dim parameter
+    sig = inspect.signature(embedding_func.func)
     has_embedding_dim_param = "embedding_dim" in sig.parameters
 
     # Determine send_dimensions value based on binding type
@@ -764,23 +884,27 @@ def create_app(args):
     else:
         dimension_control = "by not hasparam"
 
+    # Set send_dimensions on the EmbeddingFunc instance
+    embedding_func.send_dimensions = send_dimensions
+
     logger.info(
         f"Send embedding dimension: {send_dimensions} {dimension_control} "
-        f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
+        f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
         f"binding={args.embedding_binding})"
     )
 
-    # Create EmbeddingFunc with send_dimensions attribute
-    embedding_func = EmbeddingFunc(
-        embedding_dim=args.embedding_dim,
-        func=optimized_embedding_func,
-        send_dimensions=send_dimensions,
-    )
-
-    # Set max_token_size if EMBEDDING_TOKEN_LIMIT is provided
-    if args.embedding_token_limit is not None:
-        embedding_func.max_token_size = args.embedding_token_limit
-        logger.info(f"Set embedding max_token_size to {args.embedding_token_limit}")
+    # Log max_token_size source
+    if embedding_func.max_token_size:
+        source = (
+            "env variable"
+            if args.embedding_token_limit
+            else f"{args.embedding_binding} provider default"
+        )
+        logger.info(
+            f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
+        )
+    else:
+        logger.info("Embedding max_token_size: not set (90% token warning disabled)")
 
     # Configure rerank function based on args.rerank_binding parameter
     rerank_model_func = None
@@ -1212,6 +1336,12 @@ def check_and_install_dependencies():
 
 
 def main():
+    # Explicitly initialize configuration for clarity
+    # (The proxy will auto-initialize anyway, but this makes intent clear)
+    from .config import initialize_config
+
+    initialize_config()
+
     # Check if running under Gunicorn
     if "GUNICORN_CMD_ARGS" in os.environ:
         # If started with Gunicorn, return directly as Gunicorn will call get_application
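
Illustration (not part of the patch): the core idea above is unwrapping
EmbeddingFunc-decorated providers before re-wrapping them once in Step 4, so
token-limit and dimension handling run exactly once. The sketch below models
only the EmbeddingFunc fields the diff relies on; the real class lives in
lightrag.utils and does more, and fake_provider_embed is a hypothetical
stand-in for a provider function such as openai_embed, which the diff's
isinstance checks anticipate may arrive pre-wrapped.

    import asyncio
    from dataclasses import dataclass
    from typing import Any, Awaitable, Callable, Optional


    @dataclass
    class EmbeddingFunc:
        # Stub with only the fields used in the diff; illustrative, not the
        # actual lightrag.utils implementation.
        embedding_dim: int
        func: Callable[..., Awaitable[Any]]
        max_token_size: Optional[int] = None
        send_dimensions: bool = False


    async def fake_provider_embed(texts, embedding_dim=None):
        # Hypothetical provider: one zero-vector per input text.
        return [[0.0] * (embedding_dim or 4) for _ in texts]


    # A provider shipped pre-wrapped with its defaults.
    provider = EmbeddingFunc(
        embedding_dim=4, func=fake_provider_embed, max_token_size=8192
    )

    # Step 3's unwrap pattern: grab the raw coroutine from .func.
    actual_func = provider.func if isinstance(provider, EmbeddingFunc) else provider

    # Step 2's priority rule: explicit user limit wins, else provider default.
    user_token_limit = None  # stands in for args.embedding_token_limit
    final_max_token_size = user_token_limit or provider.max_token_size  # 8192

    # Step 4: the single outer wrapper the server actually uses.
    wrapped = EmbeddingFunc(
        embedding_dim=provider.embedding_dim,
        func=actual_func,
        max_token_size=final_max_token_size,
    )

    vectors = asyncio.run(wrapped.func(["hello", "world"]))
    print(len(vectors), len(vectors[0]), wrapped.max_token_size)  # 2 4 8192

Because wrapped.func is the raw coroutine rather than another EmbeddingFunc,
the outer instance is the only place where max_token_size and embedding_dim
are applied.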