This commit is contained in:
Raphaël MANSUY 2025-12-04 19:14:27 +08:00
parent 2fbc5972f8
commit 7f5afd0a4d

View file

@ -618,33 +618,108 @@ def create_app(args):
def create_optimized_embedding_function(
config_cache: LLMConfigCache, binding, model, host, api_key, args
):
) -> EmbeddingFunc:
"""
Create optimized embedding function with pre-processed configuration for applicable bindings.
Uses lazy imports for all bindings and avoids repeated configuration parsing.
Create optimized embedding function and return an EmbeddingFunc instance
with proper max_token_size inheritance from provider defaults.
This function:
1. Imports the provider embedding function
2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
4. Returns a properly configured EmbeddingFunc instance
"""
# Step 1: Import provider function and extract default attributes
provider_func = None
provider_max_token_size = None
provider_embedding_dim = None
try:
if binding == "openai":
from lightrag.llm.openai import openai_embed
provider_func = openai_embed
elif binding == "ollama":
from lightrag.llm.ollama import ollama_embed
provider_func = ollama_embed
elif binding == "gemini":
from lightrag.llm.gemini import gemini_embed
provider_func = gemini_embed
elif binding == "jina":
from lightrag.llm.jina import jina_embed
provider_func = jina_embed
elif binding == "azure_openai":
from lightrag.llm.azure_openai import azure_openai_embed
provider_func = azure_openai_embed
elif binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_embed
provider_func = bedrock_embed
elif binding == "lollms":
from lightrag.llm.lollms import lollms_embed
provider_func = lollms_embed
# Extract attributes if provider is an EmbeddingFunc
if provider_func and isinstance(provider_func, EmbeddingFunc):
provider_max_token_size = provider_func.max_token_size
provider_embedding_dim = provider_func.embedding_dim
logger.debug(
f"Extracted from {binding} provider: "
f"max_token_size={provider_max_token_size}, "
f"embedding_dim={provider_embedding_dim}"
)
except ImportError as e:
logger.warning(f"Could not import provider function for {binding}: {e}")
# Step 2: Apply priority (user config > provider default)
# For max_token_size: explicit env var > provider default > None
final_max_token_size = args.embedding_token_limit or provider_max_token_size
# For embedding_dim: user config (always has value) takes priority
# Only use provider default if user config is explicitly None (which shouldn't happen)
final_embedding_dim = (
args.embedding_dim if args.embedding_dim else provider_embedding_dim
)
# Step 3: Create optimized embedding function (calls underlying function directly)
async def optimized_embedding_function(texts, embedding_dim=None):
try:
if binding == "lollms":
from lightrag.llm.lollms import lollms_embed
return await lollms_embed(
# Get real function, skip EmbeddingFunc wrapper if present
actual_func = (
lollms_embed.func
if isinstance(lollms_embed, EmbeddingFunc)
else lollms_embed
)
return await actual_func(
texts, embed_model=model, host=host, api_key=api_key
)
elif binding == "ollama":
from lightrag.llm.ollama import ollama_embed
# Use pre-processed configuration if available, otherwise fallback to dynamic parsing
# Get real function, skip EmbeddingFunc wrapper if present
actual_func = (
ollama_embed.func
if isinstance(ollama_embed, EmbeddingFunc)
else ollama_embed
)
# Use pre-processed configuration if available
if config_cache.ollama_embedding_options is not None:
ollama_options = config_cache.ollama_embedding_options
else:
# Fallback for cases where config cache wasn't initialized properly
from lightrag.llm.binding_options import OllamaEmbeddingOptions
ollama_options = OllamaEmbeddingOptions.options_dict(args)
return await ollama_embed(
return await actual_func(
texts,
embed_model=model,
host=host,
@ -654,30 +729,53 @@ def create_app(args):
elif binding == "azure_openai":
from lightrag.llm.azure_openai import azure_openai_embed
return await azure_openai_embed(texts, model=model, api_key=api_key)
actual_func = (
azure_openai_embed.func
if isinstance(azure_openai_embed, EmbeddingFunc)
else azure_openai_embed
)
return await actual_func(texts, model=model, api_key=api_key)
elif binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_embed
return await bedrock_embed(texts, model=model)
actual_func = (
bedrock_embed.func
if isinstance(bedrock_embed, EmbeddingFunc)
else bedrock_embed
)
return await actual_func(texts, model=model)
elif binding == "jina":
from lightrag.llm.jina import jina_embed
return await jina_embed(
texts, embedding_dim=embedding_dim, base_url=host, api_key=api_key
actual_func = (
jina_embed.func
if isinstance(jina_embed, EmbeddingFunc)
else jina_embed
)
return await actual_func(
texts,
embedding_dim=embedding_dim,
base_url=host,
api_key=api_key,
)
elif binding == "gemini":
from lightrag.llm.gemini import gemini_embed
# Use pre-processed configuration if available, otherwise fallback to dynamic parsing
actual_func = (
gemini_embed.func
if isinstance(gemini_embed, EmbeddingFunc)
else gemini_embed
)
# Use pre-processed configuration if available
if config_cache.gemini_embedding_options is not None:
gemini_options = config_cache.gemini_embedding_options
else:
# Fallback for cases where config cache wasn't initialized properly
from lightrag.llm.binding_options import GeminiEmbeddingOptions
gemini_options = GeminiEmbeddingOptions.options_dict(args)
return await gemini_embed(
return await actual_func(
texts,
model=model,
base_url=host,
@ -688,13 +786,36 @@ def create_app(args):
else: # openai and compatible
from lightrag.llm.openai import openai_embed
return await openai_embed(
texts, model=model, base_url=host, api_key=api_key, embedding_dim=embedding_dim
actual_func = (
openai_embed.func
if isinstance(openai_embed, EmbeddingFunc)
else openai_embed
)
return await actual_func(
texts,
model=model,
base_url=host,
api_key=api_key,
embedding_dim=embedding_dim,
)
except ImportError as e:
raise Exception(f"Failed to import {binding} embedding: {e}")
return optimized_embedding_function
# Step 4: Wrap in EmbeddingFunc and return
embedding_func_instance = EmbeddingFunc(
embedding_dim=final_embedding_dim,
func=optimized_embedding_function,
max_token_size=final_max_token_size,
send_dimensions=False, # Will be set later based on binding requirements
)
# Log final embedding configuration
logger.info(
f"Embedding config: binding={binding} model={model} "
f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
)
return embedding_func_instance
llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
embedding_timeout = get_env_value(
@ -728,25 +849,24 @@ def create_app(args):
**kwargs,
)
# Create embedding function with optimized configuration
# Create embedding function with optimized configuration and max_token_size inheritance
import inspect
# Create the optimized embedding function
optimized_embedding_func = create_optimized_embedding_function(
# Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
embedding_func = create_optimized_embedding_function(
config_cache=config_cache,
binding=args.embedding_binding,
model=args.embedding_model,
host=args.embedding_binding_host,
api_key=args.embedding_binding_api_key,
args=args, # Pass args object for fallback option generation
args=args,
)
# Get embedding_send_dim from centralized configuration
embedding_send_dim = args.embedding_send_dim
# Check if the function signature has embedding_dim parameter
# Note: Since optimized_embedding_func is an async function, inspect its signature
sig = inspect.signature(optimized_embedding_func)
# Check if the underlying function signature has embedding_dim parameter
sig = inspect.signature(embedding_func.func)
has_embedding_dim_param = "embedding_dim" in sig.parameters
# Determine send_dimensions value based on binding type
@ -764,23 +884,27 @@ def create_app(args):
else:
dimension_control = "by not hasparam"
# Set send_dimensions on the EmbeddingFunc instance
embedding_func.send_dimensions = send_dimensions
logger.info(
f"Send embedding dimension: {send_dimensions} {dimension_control} "
f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
f"binding={args.embedding_binding})"
)
# Create EmbeddingFunc with send_dimensions attribute
embedding_func = EmbeddingFunc(
embedding_dim=args.embedding_dim,
func=optimized_embedding_func,
send_dimensions=send_dimensions,
)
# Set max_token_size if EMBEDDING_TOKEN_LIMIT is provided
if args.embedding_token_limit is not None:
embedding_func.max_token_size = args.embedding_token_limit
logger.info(f"Set embedding max_token_size to {args.embedding_token_limit}")
# Log max_token_size source
if embedding_func.max_token_size:
source = (
"env variable"
if args.embedding_token_limit
else f"{args.embedding_binding} provider default"
)
logger.info(
f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
)
else:
logger.info("Embedding max_token_size: not set (90% token warning disabled)")
# Configure rerank function based on args.rerank_bindingparameter
rerank_model_func = None
@ -1212,6 +1336,12 @@ def check_and_install_dependencies():
def main():
# Explicitly initialize configuration for clarity
# (The proxy will auto-initialize anyway, but this makes intent clear)
from .config import initialize_config
initialize_config()
# Check if running under Gunicorn
if "GUNICORN_CMD_ARGS" in os.environ:
# If started with Gunicorn, return directly as Gunicorn will call get_application