Add optional embedding dimension parameter control via env var
* Add EMBEDDING_SEND_DIM environment variable * Update Jina/OpenAI embed functions * Add send_dimensions to EmbeddingFunc * Auto-inject embedding_dim when enabled * Add parameter validation warnings
This commit is contained in:
parent
d94aae9c5e
commit
33a1482f7f
5 changed files with 76 additions and 18 deletions
|
|
@ -241,6 +241,13 @@ OLLAMA_LLM_NUM_CTX=32768
|
|||
### EMBEDDING_BINDING_HOST: host only for Ollama, endpoint for other Embedding service
|
||||
#######################################################################################
|
||||
# EMBEDDING_TIMEOUT=30
|
||||
|
||||
### Control whether to send embedding_dim parameter to embedding API
|
||||
### Set to 'true' to enable dynamic dimension adjustment (only works for OpenAI and Jina)
|
||||
### Set to 'false' (default) to disable sending dimension parameter
|
||||
### Note: This is automatically ignored for backends that don't support dimension parameter
|
||||
# EMBEDDING_SEND_DIM=false
|
||||
|
||||
EMBEDDING_BINDING=ollama
|
||||
EMBEDDING_MODEL=bge-m3:latest
|
||||
EMBEDDING_DIM=1024
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ import logging.config
|
|||
import sys
|
||||
import uvicorn
|
||||
import pipmaster as pm
|
||||
import inspect
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pathlib import Path
|
||||
|
|
@ -595,7 +594,7 @@ def create_app(args):
|
|||
return {}
|
||||
|
||||
def create_optimized_embedding_function(
|
||||
config_cache: LLMConfigCache, binding, model, host, api_key, dimensions, args
|
||||
config_cache: LLMConfigCache, binding, model, host, api_key, args
|
||||
):
|
||||
"""
|
||||
Create optimized embedding function with pre-processed configuration for applicable bindings.
|
||||
|
|
@ -641,7 +640,7 @@ def create_app(args):
|
|||
from lightrag.llm.jina import jina_embed
|
||||
|
||||
return await jina_embed(
|
||||
texts, dimensions=dimensions, base_url=host, api_key=api_key
|
||||
texts, base_url=host, api_key=api_key
|
||||
)
|
||||
else: # openai and compatible
|
||||
from lightrag.llm.openai import openai_embed
|
||||
|
|
@ -687,17 +686,43 @@ def create_app(args):
|
|||
)
|
||||
|
||||
# Create embedding function with optimized configuration
|
||||
embedding_func = EmbeddingFunc(
|
||||
embedding_dim=args.embedding_dim,
|
||||
func=create_optimized_embedding_function(
|
||||
import inspect
|
||||
|
||||
# Create the optimized embedding function
|
||||
optimized_embedding_func = create_optimized_embedding_function(
|
||||
config_cache=config_cache,
|
||||
binding=args.embedding_binding,
|
||||
model=args.embedding_model,
|
||||
host=args.embedding_binding_host,
|
||||
api_key=args.embedding_binding_api_key,
|
||||
dimensions=args.embedding_dim,
|
||||
args=args, # Pass args object for fallback option generation
|
||||
),
|
||||
)
|
||||
|
||||
# Check environment variable for sending dimensions
|
||||
embedding_send_dim = os.getenv("EMBEDDING_SEND_DIM", "false").lower() == "true"
|
||||
|
||||
# Check if the function signature has embedding_dim parameter
|
||||
# Note: Since optimized_embedding_func is an async function, inspect its signature
|
||||
sig = inspect.signature(optimized_embedding_func)
|
||||
has_embedding_dim_param = 'embedding_dim' in sig.parameters
|
||||
|
||||
# Determine send_dimensions value
|
||||
# Only send dimensions if both conditions are met:
|
||||
# 1. EMBEDDING_SEND_DIM environment variable is true
|
||||
# 2. The function has embedding_dim parameter
|
||||
send_dimensions = embedding_send_dim and has_embedding_dim_param
|
||||
|
||||
logger.info(
|
||||
f"Embedding configuration: send_dimensions={send_dimensions} "
|
||||
f"(env_var={embedding_send_dim}, has_param={has_embedding_dim_param}, "
|
||||
f"binding={args.embedding_binding})"
|
||||
)
|
||||
|
||||
# Create EmbeddingFunc with send_dimensions attribute
|
||||
embedding_func = EmbeddingFunc(
|
||||
embedding_dim=args.embedding_dim,
|
||||
func=optimized_embedding_func,
|
||||
send_dimensions=send_dimensions,
|
||||
)
|
||||
|
||||
# Configure rerank function based on args.rerank_binding parameter
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ async def fetch_data(url, headers, data):
|
|||
)
|
||||
async def jina_embed(
|
||||
texts: list[str],
|
||||
dimensions: int = 2048,
|
||||
embedding_dim: int = 2048,
|
||||
late_chunking: bool = False,
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
|
|
@ -78,7 +78,12 @@ async def jina_embed(
|
|||
|
||||
Args:
|
||||
texts: List of texts to embed.
|
||||
dimensions: The embedding dimensions (default: 2048 for jina-embeddings-v4).
|
||||
embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4).
|
||||
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
|
||||
Do NOT manually pass this parameter when calling the function directly.
|
||||
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
|
||||
Manually passing a different value will trigger a warning and be ignored.
|
||||
When provided (by EmbeddingFunc), it will be passed to the Jina API for dimension reduction.
|
||||
late_chunking: Whether to use late chunking.
|
||||
base_url: Optional base URL for the Jina API.
|
||||
api_key: Optional Jina API key. If None, uses the JINA_API_KEY environment variable.
|
||||
|
|
@ -104,7 +109,7 @@ async def jina_embed(
|
|||
data = {
|
||||
"model": "jina-embeddings-v4",
|
||||
"task": "text-matching",
|
||||
"dimensions": dimensions,
|
||||
"dimensions": embedding_dim,
|
||||
"embedding_type": "base64",
|
||||
"input": texts,
|
||||
}
|
||||
|
|
@ -114,7 +119,7 @@ async def jina_embed(
|
|||
data["late_chunking"] = late_chunking
|
||||
|
||||
logger.debug(
|
||||
f"Jina embedding request: {len(texts)} texts, dimensions: {dimensions}"
|
||||
f"Jina embedding request: {len(texts)} texts, dimensions: {embedding_dim}"
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -609,7 +609,7 @@ async def openai_embed(
|
|||
model: str = "text-embedding-3-small",
|
||||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
embedding_dim: int = None,
|
||||
embedding_dim: int | None = None,
|
||||
client_configs: dict[str, Any] | None = None,
|
||||
token_tracker: Any | None = None,
|
||||
) -> np.ndarray:
|
||||
|
|
@ -620,7 +620,12 @@ async def openai_embed(
|
|||
model: The OpenAI embedding model to use.
|
||||
base_url: Optional base URL for the OpenAI API.
|
||||
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
|
||||
embedding_dim: Optional embedding dimension. If None, uses the default embedding dimension for the model. (will be passed to API for dimension reduction).
|
||||
embedding_dim: Optional embedding dimension for dynamic dimension reduction.
|
||||
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
|
||||
Do NOT manually pass this parameter when calling the function directly.
|
||||
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
|
||||
Manually passing a different value will trigger a warning and be ignored.
|
||||
When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
|
||||
client_configs: Additional configuration options for the AsyncOpenAI client.
|
||||
These will override any default configurations but will be overridden by
|
||||
explicit parameters (api_key, base_url).
|
||||
|
|
|
|||
|
|
@ -353,8 +353,24 @@ class EmbeddingFunc:
|
|||
embedding_dim: int
|
||||
func: callable
|
||||
max_token_size: int | None = None # deprecated keep it for compatible only
|
||||
send_dimensions: bool = False # Control whether to send embedding_dim to the function
|
||||
|
||||
async def __call__(self, *args, **kwargs) -> np.ndarray:
    """Invoke the wrapped embedding function.

    When ``send_dimensions`` is enabled, the declared ``embedding_dim`` is
    forced into ``kwargs`` before the call so the underlying embedding API
    can perform dimension reduction.  A caller-supplied ``embedding_dim``
    that conflicts with the declared value is ignored with a warning.
    """
    if self.send_dimensions:
        # Warn if the caller tried to override the declared dimension.
        caller_dim = kwargs.get('embedding_dim')
        if caller_dim is not None and caller_dim != self.embedding_dim:
            logger.warning(
                f"Ignoring user-provided embedding_dim={caller_dim}, "
                f"using declared embedding_dim={self.embedding_dim} from decorator"
            )

        # Always inject the declared dimension for the wrapped function.
        kwargs['embedding_dim'] = self.embedding_dim

    return await self.func(*args, **kwargs)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue