Refactor embedding function creation with proper attribute inheritance
- Extract max_token_size from providers
- Avoid double-wrapping EmbeddingFunc
- Improve configuration priority logic
- Add comprehensive debug logging
- Return complete EmbeddingFunc instance
parent f0254773c6
commit 6b2af2b579
1 changed file with 139 additions and 34 deletions
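The change below revolves around LightRAG's EmbeddingFunc wrapper, which bundles the embedding coroutine with its metadata (embedding_dim, max_token_size, send_dimensions). As orientation only, a minimal stand-in consistent with how the diff uses those attributes could look like the sketch below; the class name and layout here are assumptions, not the project's actual definition.

# Illustrative stand-in only -- not the real lightrag class; fields mirror how the
# diff below reads and writes them (embedding_dim, max_token_size, func, send_dimensions).
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Optional


@dataclass
class EmbeddingFuncSketch:
    embedding_dim: int
    func: Callable[..., Awaitable[Any]]
    max_token_size: Optional[int] = None
    send_dimensions: bool = False

    async def __call__(self, *args, **kwargs) -> Any:
        # Delegate to the wrapped coroutine so the instance itself is awaitable.
        return await self.func(*args, **kwargs)

With that shape in mind, the point of the refactor is to build exactly one such wrapper per server configuration, seeded with whatever defaults the provider module already declares.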
@@ -641,33 +641,102 @@ def create_app(args):
    def create_optimized_embedding_function(
        config_cache: LLMConfigCache, binding, model, host, api_key, args
    ):
    ) -> EmbeddingFunc:
        """
        Create optimized embedding function with pre-processed configuration for applicable bindings.
        Uses lazy imports for all bindings and avoids repeated configuration parsing.
        Create optimized embedding function and return an EmbeddingFunc instance
        with proper max_token_size inheritance from provider defaults.

        This function:
        1. Imports the provider embedding function
        2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
        3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
        4. Returns a properly configured EmbeddingFunc instance
        """

        # Step 1: Import provider function and extract default attributes
        provider_func = None
        default_max_token_size = None
        default_embedding_dim = args.embedding_dim  # Use config as default

        try:
            if binding == "openai":
                from lightrag.llm.openai import openai_embed

                provider_func = openai_embed
            elif binding == "ollama":
                from lightrag.llm.ollama import ollama_embed

                provider_func = ollama_embed
            elif binding == "gemini":
                from lightrag.llm.gemini import gemini_embed

                provider_func = gemini_embed
            elif binding == "jina":
                from lightrag.llm.jina import jina_embed

                provider_func = jina_embed
            elif binding == "azure_openai":
                from lightrag.llm.azure_openai import azure_openai_embed

                provider_func = azure_openai_embed
            elif binding == "aws_bedrock":
                from lightrag.llm.bedrock import bedrock_embed

                provider_func = bedrock_embed
            elif binding == "lollms":
                from lightrag.llm.lollms import lollms_embed

                provider_func = lollms_embed

            # Extract attributes if provider is an EmbeddingFunc
            if provider_func and isinstance(provider_func, EmbeddingFunc):
                default_max_token_size = provider_func.max_token_size
                default_embedding_dim = provider_func.embedding_dim
                logger.debug(
                    f"Extracted from {binding} provider: "
                    f"max_token_size={default_max_token_size}, "
                    f"embedding_dim={default_embedding_dim}"
                )
        except ImportError as e:
            logger.warning(f"Could not import provider function for {binding}: {e}")

        # Step 2: Apply priority (environment variable > provider default)
        final_max_token_size = args.embedding_token_limit or default_max_token_size

        # Step 3: Create optimized embedding function (calls underlying function directly)
        async def optimized_embedding_function(texts, embedding_dim=None):
            try:
                if binding == "lollms":
                    from lightrag.llm.lollms import lollms_embed

                    return await lollms_embed(
                    # Get real function, skip EmbeddingFunc wrapper if present
                    actual_func = (
                        lollms_embed.func
                        if isinstance(lollms_embed, EmbeddingFunc)
                        else lollms_embed
                    )
                    return await actual_func(
                        texts, embed_model=model, host=host, api_key=api_key
                    )
                elif binding == "ollama":
                    from lightrag.llm.ollama import ollama_embed

                    # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
                    # Get real function, skip EmbeddingFunc wrapper if present
                    actual_func = (
                        ollama_embed.func
                        if isinstance(ollama_embed, EmbeddingFunc)
                        else ollama_embed
                    )

                    # Use pre-processed configuration if available
                    if config_cache.ollama_embedding_options is not None:
                        ollama_options = config_cache.ollama_embedding_options
                    else:
                        # Fallback for cases where config cache wasn't initialized properly
                        from lightrag.llm.binding_options import OllamaEmbeddingOptions

                        ollama_options = OllamaEmbeddingOptions.options_dict(args)

                    return await ollama_embed(
                    return await actual_func(
                        texts,
                        embed_model=model,
                        host=host,
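Every binding branch above repeats the same unwrap-before-calling idiom so the dispatcher never awaits an EmbeddingFunc through a second wrapper. A hedged sketch of that idiom as a standalone helper (the helper name is invented; only the .func attribute access and the lollms call shape are taken from the diff, and the import path is assumed):

# Hypothetical helper, not part of the diff: the per-binding unwrap step factored out.
from lightrag.utils import EmbeddingFunc  # import path assumed


def unwrap_embedding_callable(candidate):
    """Return the raw coroutine function, skipping an EmbeddingFunc wrapper if present."""
    if isinstance(candidate, EmbeddingFunc):
        return candidate.func
    return candidate


# Usage mirroring the lollms branch above (sketch):
#     actual_func = unwrap_embedding_callable(lollms_embed)
#     return await actual_func(texts, embed_model=model, host=host, api_key=api_key)

Note also that Step 2 resolves the limit with "args.embedding_token_limit or default_max_token_size", so any truthy environment value wins, while an unset (or zero) limit falls back to the provider default.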
@@ -677,15 +746,30 @@ def create_app(args):
                elif binding == "azure_openai":
                    from lightrag.llm.azure_openai import azure_openai_embed

                    return await azure_openai_embed(texts, model=model, api_key=api_key)
                    actual_func = (
                        azure_openai_embed.func
                        if isinstance(azure_openai_embed, EmbeddingFunc)
                        else azure_openai_embed
                    )
                    return await actual_func(texts, model=model, api_key=api_key)
                elif binding == "aws_bedrock":
                    from lightrag.llm.bedrock import bedrock_embed

                    return await bedrock_embed(texts, model=model)
                    actual_func = (
                        bedrock_embed.func
                        if isinstance(bedrock_embed, EmbeddingFunc)
                        else bedrock_embed
                    )
                    return await actual_func(texts, model=model)
                elif binding == "jina":
                    from lightrag.llm.jina import jina_embed

                    return await jina_embed(
                    actual_func = (
                        jina_embed.func
                        if isinstance(jina_embed, EmbeddingFunc)
                        else jina_embed
                    )
                    return await actual_func(
                        texts,
                        embedding_dim=embedding_dim,
                        base_url=host,
@@ -694,16 +778,21 @@ def create_app(args):
                elif binding == "gemini":
                    from lightrag.llm.gemini import gemini_embed

                    # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
                    actual_func = (
                        gemini_embed.func
                        if isinstance(gemini_embed, EmbeddingFunc)
                        else gemini_embed
                    )

                    # Use pre-processed configuration if available
                    if config_cache.gemini_embedding_options is not None:
                        gemini_options = config_cache.gemini_embedding_options
                    else:
                        # Fallback for cases where config cache wasn't initialized properly
                        from lightrag.llm.binding_options import GeminiEmbeddingOptions

                        gemini_options = GeminiEmbeddingOptions.options_dict(args)

                    return await gemini_embed(
                    return await actual_func(
                        texts,
                        model=model,
                        base_url=host,
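The Ollama and Gemini branches also share a cache-with-fallback pattern: prefer options that were pre-parsed into the config cache, and only re-parse args when the cache was never initialized. A generic, illustrative sketch of that pattern (the helper is invented; only options_dict(args) appears in the diff):

# Illustrative helper, not part of the diff.
def resolve_embedding_options(cached_options, options_cls, args):
    """Prefer options pre-parsed into the config cache; fall back to parsing args."""
    if cached_options is not None:
        return cached_options
    # Fallback for cases where the config cache wasn't initialized properly.
    return options_cls.options_dict(args)


# e.g. (sketch):
#     gemini_options = resolve_embedding_options(
#         config_cache.gemini_embedding_options, GeminiEmbeddingOptions, args
#     )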
@@ -714,7 +803,12 @@ def create_app(args):
                else:  # openai and compatible
                    from lightrag.llm.openai import openai_embed

                    return await openai_embed(
                    actual_func = (
                        openai_embed.func
                        if isinstance(openai_embed, EmbeddingFunc)
                        else openai_embed
                    )
                    return await actual_func(
                        texts,
                        model=model,
                        base_url=host,
@@ -724,7 +818,15 @@ def create_app(args):
            except ImportError as e:
                raise Exception(f"Failed to import {binding} embedding: {e}")

        return optimized_embedding_function
        # Step 4: Wrap in EmbeddingFunc and return
        embedding_func_instance = EmbeddingFunc(
            embedding_dim=default_embedding_dim,
            func=optimized_embedding_function,
            max_token_size=final_max_token_size,
            send_dimensions=False,  # Will be set later based on binding requirements
        )

        return embedding_func_instance

    llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
    embedding_timeout = get_env_value(
@@ -758,25 +860,24 @@ def create_app(args):
        **kwargs,
    )

    # Create embedding function with optimized configuration
    # Create embedding function with optimized configuration and max_token_size inheritance
    import inspect

    # Create the optimized embedding function
    optimized_embedding_func = create_optimized_embedding_function(
    # Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
    embedding_func = create_optimized_embedding_function(
        config_cache=config_cache,
        binding=args.embedding_binding,
        model=args.embedding_model,
        host=args.embedding_binding_host,
        api_key=args.embedding_binding_api_key,
        args=args,  # Pass args object for fallback option generation
        args=args,
    )

    # Get embedding_send_dim from centralized configuration
    embedding_send_dim = args.embedding_send_dim

    # Check if the function signature has embedding_dim parameter
    # Note: Since optimized_embedding_func is an async function, inspect its signature
    sig = inspect.signature(optimized_embedding_func)
    # Check if the underlying function signature has embedding_dim parameter
    sig = inspect.signature(embedding_func.func)
    has_embedding_dim_param = "embedding_dim" in sig.parameters

    # Determine send_dimensions value based on binding type
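Because the factory now returns a wrapper object rather than the dispatcher itself, the signature probe above has to look at embedding_func.func. A self-contained toy (all names invented) showing why the distinction matters:

# Toy demonstration: signature() on a callable wrapper describes its __call__, so the
# embedding_dim capability must be detected on the wrapped dispatcher instead.
import inspect


class _ToyWrapper:
    def __init__(self, func):
        self.func = func

    async def __call__(self, *args, **kwargs):
        return await self.func(*args, **kwargs)


async def _dispatcher(texts, embedding_dim=None):
    return texts  # placeholder body


wrapped = _ToyWrapper(_dispatcher)
print(list(inspect.signature(wrapped).parameters))                     # ['args', 'kwargs']
print("embedding_dim" in inspect.signature(wrapped.func).parameters)   # True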
@@ -794,23 +895,27 @@ def create_app(args):
    else:
        dimension_control = "by not hasparam"

    # Set send_dimensions on the EmbeddingFunc instance
    embedding_func.send_dimensions = send_dimensions

    logger.info(
        f"Send embedding dimension: {send_dimensions} {dimension_control} "
        f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
        f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
        f"binding={args.embedding_binding})"
    )

    # Create EmbeddingFunc with send_dimensions attribute
    embedding_func = EmbeddingFunc(
        embedding_dim=args.embedding_dim,
        func=optimized_embedding_func,
        send_dimensions=send_dimensions,
    )

    # Set max_token_size if EMBEDDING_TOKEN_LIMIT is provided
    if args.embedding_token_limit is not None:
        embedding_func.max_token_size = args.embedding_token_limit
        logger.info(f"Set embedding max_token_size to {args.embedding_token_limit}")
    # Log max_token_size source
    if embedding_func.max_token_size:
        source = (
            "env variable"
            if args.embedding_token_limit
            else f"{args.embedding_binding} provider default"
        )
        logger.info(
            f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
        )
    else:
        logger.info("Embedding max_token_size: not set (90% token warning disabled)")

    # Configure rerank function based on args.rerank_binding parameter
    rerank_model_func = None
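The closing hunk replaces the old "set max_token_size only when EMBEDDING_TOKEN_LIMIT is provided" block with logging that reports where the effective limit came from. The toy below (invented helper name) mirrors that reporting logic, including the disabled-warning branch:

# Sketch with an invented helper name, mirroring the reporting logic in the hunk above.
def describe_max_token_size(limit_from_env, binding, resolved):
    if not resolved:
        return "not set (90% token warning disabled)"
    source = "env variable" if limit_from_env else f"{binding} provider default"
    return f"{resolved} (from {source})"


print(describe_max_token_size(None, "openai", 8191))   # 8191 (from openai provider default)
print(describe_max_token_size(4096, "openai", 4096))   # 4096 (from env variable)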