diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 7f838f14..8f3fbae1 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -641,33 +641,102 @@ def create_app(args):
 
     def create_optimized_embedding_function(
         config_cache: LLMConfigCache, binding, model, host, api_key, args
-    ):
+    ) -> EmbeddingFunc:
         """
-        Create optimized embedding function with pre-processed configuration for applicable bindings.
-        Uses lazy imports for all bindings and avoids repeated configuration parsing.
+        Create optimized embedding function and return an EmbeddingFunc instance
+        with proper max_token_size inheritance from provider defaults.
+
+        This function:
+        1. Imports the provider embedding function
+        2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
+        3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
+        4. Returns a properly configured EmbeddingFunc instance
         """
 
+        # Step 1: Import provider function and extract default attributes
+        provider_func = None
+        default_max_token_size = None
+        default_embedding_dim = args.embedding_dim  # Use config as default
+
+        try:
+            if binding == "openai":
+                from lightrag.llm.openai import openai_embed
+
+                provider_func = openai_embed
+            elif binding == "ollama":
+                from lightrag.llm.ollama import ollama_embed
+
+                provider_func = ollama_embed
+            elif binding == "gemini":
+                from lightrag.llm.gemini import gemini_embed
+
+                provider_func = gemini_embed
+            elif binding == "jina":
+                from lightrag.llm.jina import jina_embed
+
+                provider_func = jina_embed
+            elif binding == "azure_openai":
+                from lightrag.llm.azure_openai import azure_openai_embed
+
+                provider_func = azure_openai_embed
+            elif binding == "aws_bedrock":
+                from lightrag.llm.bedrock import bedrock_embed
+
+                provider_func = bedrock_embed
+            elif binding == "lollms":
+                from lightrag.llm.lollms import lollms_embed
+
+                provider_func = lollms_embed
+
+            # Extract attributes if provider is an EmbeddingFunc
+            if provider_func and isinstance(provider_func, EmbeddingFunc):
+                default_max_token_size = provider_func.max_token_size
+                default_embedding_dim = provider_func.embedding_dim
+                logger.debug(
+                    f"Extracted from {binding} provider: "
+                    f"max_token_size={default_max_token_size}, "
+                    f"embedding_dim={default_embedding_dim}"
+                )
+        except ImportError as e:
+            logger.warning(f"Could not import provider function for {binding}: {e}")
+
+        # Step 2: Apply priority (environment variable > provider default)
+        final_max_token_size = args.embedding_token_limit or default_max_token_size
+
+        # Step 3: Create optimized embedding function (calls underlying function directly)
         async def optimized_embedding_function(texts, embedding_dim=None):
             try:
                 if binding == "lollms":
                     from lightrag.llm.lollms import lollms_embed
 
-                    return await lollms_embed(
+                    # Get real function, skip EmbeddingFunc wrapper if present
+                    actual_func = (
+                        lollms_embed.func
+                        if isinstance(lollms_embed, EmbeddingFunc)
+                        else lollms_embed
+                    )
+                    return await actual_func(
                         texts, embed_model=model, host=host, api_key=api_key
                     )
                 elif binding == "ollama":
                     from lightrag.llm.ollama import ollama_embed
 
-                    # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
+                    # Get real function, skip EmbeddingFunc wrapper if present
+                    actual_func = (
+                        ollama_embed.func
+                        if isinstance(ollama_embed, EmbeddingFunc)
+                        else ollama_embed
+                    )
+
+                    # Use pre-processed configuration if available
                     if config_cache.ollama_embedding_options is not None:
                         ollama_options = config_cache.ollama_embedding_options
                     else:
-                        # Fallback for cases where config cache wasn't initialized properly
                         from lightrag.llm.binding_options import OllamaEmbeddingOptions
 
                         ollama_options = OllamaEmbeddingOptions.options_dict(args)
 
-                    return await ollama_embed(
+                    return await actual_func(
                         texts,
                         embed_model=model,
                         host=host,
@@ -677,15 +746,30 @@ def create_app(args):
                 elif binding == "azure_openai":
                     from lightrag.llm.azure_openai import azure_openai_embed
 
-                    return await azure_openai_embed(texts, model=model, api_key=api_key)
+                    actual_func = (
+                        azure_openai_embed.func
+                        if isinstance(azure_openai_embed, EmbeddingFunc)
+                        else azure_openai_embed
+                    )
+                    return await actual_func(texts, model=model, api_key=api_key)
                 elif binding == "aws_bedrock":
                     from lightrag.llm.bedrock import bedrock_embed
 
-                    return await bedrock_embed(texts, model=model)
+                    actual_func = (
+                        bedrock_embed.func
+                        if isinstance(bedrock_embed, EmbeddingFunc)
+                        else bedrock_embed
+                    )
+                    return await actual_func(texts, model=model)
                 elif binding == "jina":
                     from lightrag.llm.jina import jina_embed
 
-                    return await jina_embed(
+                    actual_func = (
+                        jina_embed.func
+                        if isinstance(jina_embed, EmbeddingFunc)
+                        else jina_embed
+                    )
+                    return await actual_func(
                         texts,
                         embedding_dim=embedding_dim,
                         base_url=host,
@@ -694,16 +778,21 @@ def create_app(args):
                 elif binding == "gemini":
                     from lightrag.llm.gemini import gemini_embed
 
-                    # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
+                    actual_func = (
+                        gemini_embed.func
+                        if isinstance(gemini_embed, EmbeddingFunc)
+                        else gemini_embed
+                    )
+
+                    # Use pre-processed configuration if available
                     if config_cache.gemini_embedding_options is not None:
                         gemini_options = config_cache.gemini_embedding_options
                     else:
-                        # Fallback for cases where config cache wasn't initialized properly
                         from lightrag.llm.binding_options import GeminiEmbeddingOptions
 
                         gemini_options = GeminiEmbeddingOptions.options_dict(args)
 
-                    return await gemini_embed(
+                    return await actual_func(
                         texts,
                         model=model,
                         base_url=host,
@@ -714,7 +803,12 @@ def create_app(args):
                 else:  # openai and compatible
                     from lightrag.llm.openai import openai_embed
 
-                    return await openai_embed(
+                    actual_func = (
+                        openai_embed.func
+                        if isinstance(openai_embed, EmbeddingFunc)
+                        else openai_embed
+                    )
+                    return await actual_func(
                         texts,
                         model=model,
                         base_url=host,
@@ -724,7 +818,15 @@ def create_app(args):
             except ImportError as e:
                 raise Exception(f"Failed to import {binding} embedding: {e}")
 
-        return optimized_embedding_function
+        # Step 4: Wrap in EmbeddingFunc and return
+        embedding_func_instance = EmbeddingFunc(
+            embedding_dim=default_embedding_dim,
+            func=optimized_embedding_function,
+            max_token_size=final_max_token_size,
+            send_dimensions=False,  # Will be set later based on binding requirements
+        )
+
+        return embedding_func_instance
 
     llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
     embedding_timeout = get_env_value(
@@ -758,25 +860,24 @@ def create_app(args):
             **kwargs,
         )
 
-    # Create embedding function with optimized configuration
+    # Create embedding function with optimized configuration and max_token_size inheritance
    import inspect
 
-    # Create the optimized embedding function
-    optimized_embedding_func = create_optimized_embedding_function(
+    # Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
+    embedding_func = create_optimized_embedding_function(
         config_cache=config_cache,
         binding=args.embedding_binding,
         model=args.embedding_model,
         host=args.embedding_binding_host,
         api_key=args.embedding_binding_api_key,
-        args=args,  # Pass args object for fallback option generation
+        args=args,
     )
 
     # Get embedding_send_dim from centralized configuration
     embedding_send_dim = args.embedding_send_dim
 
-    # Check if the function signature has embedding_dim parameter
-    # Note: Since optimized_embedding_func is an async function, inspect its signature
-    sig = inspect.signature(optimized_embedding_func)
+    # Check if the underlying function signature has embedding_dim parameter
+    sig = inspect.signature(embedding_func.func)
     has_embedding_dim_param = "embedding_dim" in sig.parameters
 
     # Determine send_dimensions value based on binding type
@@ -794,23 +895,27 @@ def create_app(args):
     else:
         dimension_control = "by not hasparam"
 
+    # Set send_dimensions on the EmbeddingFunc instance
+    embedding_func.send_dimensions = send_dimensions
+
     logger.info(
         f"Send embedding dimension: {send_dimensions} {dimension_control} "
-        f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
+        f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
         f"binding={args.embedding_binding})"
     )
 
-    # Create EmbeddingFunc with send_dimensions attribute
-    embedding_func = EmbeddingFunc(
-        embedding_dim=args.embedding_dim,
-        func=optimized_embedding_func,
-        send_dimensions=send_dimensions,
-    )
-
-    # Set max_token_size if EMBEDDING_TOKEN_LIMIT is provided
-    if args.embedding_token_limit is not None:
-        embedding_func.max_token_size = args.embedding_token_limit
-        logger.info(f"Set embedding max_token_size to {args.embedding_token_limit}")
+    # Log max_token_size source
+    if embedding_func.max_token_size:
+        source = (
+            "env variable"
+            if args.embedding_token_limit
+            else f"{args.embedding_binding} provider default"
+        )
+        logger.info(
+            f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
+        )
+    else:
+        logger.info("Embedding max_token_size: not set (90% token warning disabled)")
 
     # Configure rerank function based on args.rerank_binding parameter
     rerank_model_func = None
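Reviewer note: the sketch below is not part of the patch; it only illustrates the unwrap-and-rewrap pattern that Steps 1-4 above apply. EmbeddingFuncStub, fake_provider_embed, and build_embedding_func are hypothetical stand-ins (the stub mirrors only the EmbeddingFunc fields this diff touches), showing how max_token_size is resolved with the env/config limit taking priority over the provider default, and why the raw .func is called so the result is wrapped exactly once.

# Sketch only -- EmbeddingFuncStub approximates lightrag.utils.EmbeddingFunc.
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Optional


@dataclass
class EmbeddingFuncStub:
    embedding_dim: int
    func: Callable[..., Awaitable[Any]]
    max_token_size: Optional[int] = None
    send_dimensions: bool = False

    async def __call__(self, texts):
        # The wrapper is itself awaitable; awaiting it instead of .func is
        # what produces double-wrapping once the result is wrapped again.
        return await self.func(texts)


async def fake_provider_embed(texts):
    # Hypothetical provider call; a real binding would hit an embedding API.
    return [[0.0] * 8 for _ in texts]


def build_embedding_func(provider, env_token_limit, config_dim):
    wrapped = isinstance(provider, EmbeddingFuncStub)
    # Step 1: inherit defaults from the provider wrapper when present.
    default_dim = provider.embedding_dim if wrapped else config_dim
    default_tokens = provider.max_token_size if wrapped else None
    # Step 2: environment/config override beats the provider default.
    final_tokens = env_token_limit or default_tokens
    # Step 3: call the raw coroutine, not the wrapper.
    inner = provider.func if wrapped else provider

    async def optimized(texts, embedding_dim=None):
        return await inner(texts)

    # Step 4: wrap exactly once.
    return EmbeddingFuncStub(
        embedding_dim=default_dim, func=optimized, max_token_size=final_tokens
    )


provider = EmbeddingFuncStub(
    embedding_dim=8, func=fake_provider_embed, max_token_size=8192
)
inherited = build_embedding_func(provider, env_token_limit=None, config_dim=1024)
overridden = build_embedding_func(provider, env_token_limit=512, config_dim=1024)
assert inherited.max_token_size == 8192  # provider default inherited
assert overridden.max_token_size == 512  # env limit wins
assert asyncio.run(inherited(["hi"])) == [[0.0] * 8]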