Add optional embedding dimension parameter control via env var

* Add EMBEDDING_SEND_DIM environment variable
* Update Jina/OpenAI embed functions
* Add send_dimensions to EmbeddingFunc
* Auto-inject embedding_dim when enabled
* Add parameter validation warnings
yangdx 2025-11-07 20:46:40 +08:00
parent d94aae9c5e
commit 33a1482f7f
5 changed files with 76 additions and 18 deletions
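
In short: when EMBEDDING_SEND_DIM=true and the wrapped embedding function accepts an embedding_dim parameter, EmbeddingFunc forwards its declared dimension to the backend (OpenAI and Jina only). A minimal sketch of that wiring, not code from this commit — the lightrag.utils import path for EmbeddingFunc is assumed, and text-embedding-3-large / 1024 are example values:

    # Hedged sketch of the intended wiring, not the server code shown below.
    import os
    from functools import partial

    from lightrag.llm.openai import openai_embed  # import path taken from the diff below
    from lightrag.utils import EmbeddingFunc      # assumed location of the dataclass changed below

    # Same parsing rule the server uses: only the literal string "true" enables it.
    send_dim = os.getenv("EMBEDDING_SEND_DIM", "false").lower() == "true"

    embedding_func = EmbeddingFunc(
        embedding_dim=1024,  # example value; the server takes this from EMBEDDING_DIM
        func=partial(openai_embed, model="text-embedding-3-large"),
        send_dimensions=send_dim,  # when True, embedding_dim=1024 is injected into each call
    )

    # Later, `await embedding_func(["some text"])` calls openai_embed with
    # embedding_dim=1024 only if send_dim is True; otherwise the parameter is not passed.

The server additionally checks the wrapped function's signature for an embedding_dim parameter before enabling send_dimensions, as shown in the second file below.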


@@ -241,6 +241,13 @@ OLLAMA_LLM_NUM_CTX=32768
 ### EMBEDDING_BINDING_HOST: host only for Ollama, endpoint for other Embedding service
 #######################################################################################
 # EMBEDDING_TIMEOUT=30
+### Control whether to send embedding_dim parameter to embedding API
+### Set to 'true' to enable dynamic dimension adjustment (only works for OpenAI and Jina)
+### Set to 'false' (default) to disable sending dimension parameter
+### Note: This is automatically ignored for backends that don't support dimension parameter
+# EMBEDDING_SEND_DIM=false
 EMBEDDING_BINDING=ollama
 EMBEDDING_MODEL=bge-m3:latest
 EMBEDDING_DIM=1024
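
Note that the flag is parsed with a strict string comparison: only the literal 'true' (any case) enables it, so values like '1' or 'yes' are treated as disabled. A quick illustration of the check used in the server code below:

    import os

    def send_dim_enabled() -> bool:
        # Mirrors the parsing added in the server code below.
        return os.getenv("EMBEDDING_SEND_DIM", "false").lower() == "true"

    os.environ["EMBEDDING_SEND_DIM"] = "TRUE"
    print(send_dim_enabled())  # True

    os.environ["EMBEDDING_SEND_DIM"] = "1"
    print(send_dim_enabled())  # False: only the string "true" counts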


@@ -15,7 +15,6 @@ import logging.config
 import sys
 import uvicorn
 import pipmaster as pm
-import inspect
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import RedirectResponse
 from pathlib import Path
@@ -595,7 +594,7 @@ def create_app(args):
         return {}

     def create_optimized_embedding_function(
-        config_cache: LLMConfigCache, binding, model, host, api_key, dimensions, args
+        config_cache: LLMConfigCache, binding, model, host, api_key, args
     ):
         """
         Create optimized embedding function with pre-processed configuration for applicable bindings.
@@ -641,7 +640,7 @@ def create_app(args):
                 from lightrag.llm.jina import jina_embed

                 return await jina_embed(
-                    texts, dimensions=dimensions, base_url=host, api_key=api_key
+                    texts, base_url=host, api_key=api_key
                 )
             else:  # openai and compatible
                 from lightrag.llm.openai import openai_embed
@@ -687,17 +686,43 @@ def create_app(args):
     )

     # Create embedding function with optimized configuration
+    import inspect
+
+    # Create the optimized embedding function
+    optimized_embedding_func = create_optimized_embedding_function(
+        config_cache=config_cache,
+        binding=args.embedding_binding,
+        model=args.embedding_model,
+        host=args.embedding_binding_host,
+        api_key=args.embedding_binding_api_key,
+        args=args,  # Pass args object for fallback option generation
+    )
+
+    # Check environment variable for sending dimensions
+    embedding_send_dim = os.getenv("EMBEDDING_SEND_DIM", "false").lower() == "true"
+
+    # Check if the function signature has embedding_dim parameter
+    # Note: Since optimized_embedding_func is an async function, inspect its signature
+    sig = inspect.signature(optimized_embedding_func)
+    has_embedding_dim_param = 'embedding_dim' in sig.parameters
+
+    # Determine send_dimensions value
+    # Only send dimensions if both conditions are met:
+    # 1. EMBEDDING_SEND_DIM environment variable is true
+    # 2. The function has embedding_dim parameter
+    send_dimensions = embedding_send_dim and has_embedding_dim_param
+
+    logger.info(
+        f"Embedding configuration: send_dimensions={send_dimensions} "
+        f"(env_var={embedding_send_dim}, has_param={has_embedding_dim_param}, "
+        f"binding={args.embedding_binding})"
+    )
+
+    # Create EmbeddingFunc with send_dimensions attribute
     embedding_func = EmbeddingFunc(
         embedding_dim=args.embedding_dim,
-        func=create_optimized_embedding_function(
-            config_cache=config_cache,
-            binding=args.embedding_binding,
-            model=args.embedding_model,
-            host=args.embedding_binding_host,
-            api_key=args.embedding_binding_api_key,
-            dimensions=args.embedding_dim,
-            args=args,  # Pass args object for fallback option generation
-        ),
+        func=optimized_embedding_func,
+        send_dimensions=send_dimensions,
     )

     # Configure rerank function based on args.rerank_bindingparameter
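
The signature probe above is what gates send_dimensions. A toy illustration of the check; embed_with_dim and embed_without_dim are hypothetical stand-ins for the per-binding closures:

    import inspect

    async def embed_with_dim(texts, embedding_dim=None):
        ...

    async def embed_without_dim(texts):
        ...

    # inspect.signature treats async functions like any other callable, so the
    # check simply asks whether the wrapped closure declares embedding_dim.
    print("embedding_dim" in inspect.signature(embed_with_dim).parameters)     # True
    print("embedding_dim" in inspect.signature(embed_without_dim).parameters)  # False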


@@ -69,7 +69,7 @@ async def fetch_data(url, headers, data):
 )
 async def jina_embed(
     texts: list[str],
-    dimensions: int = 2048,
+    embedding_dim: int = 2048,
     late_chunking: bool = False,
     base_url: str = None,
     api_key: str = None,
@@ -78,7 +78,12 @@ async def jina_embed(
     Args:
         texts: List of texts to embed.
-        dimensions: The embedding dimensions (default: 2048 for jina-embeddings-v4).
+        embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4).
+            **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
+            Do NOT manually pass this parameter when calling the function directly.
+            The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
+            Manually passing a different value will trigger a warning and be ignored.
+            When provided (by EmbeddingFunc), it will be passed to the Jina API for dimension reduction.
         late_chunking: Whether to use late chunking.
         base_url: Optional base URL for the Jina API.
         api_key: Optional Jina API key. If None, uses the JINA_API_KEY environment variable.
@@ -104,7 +109,7 @@ async def jina_embed(
     data = {
         "model": "jina-embeddings-v4",
         "task": "text-matching",
-        "dimensions": dimensions,
+        "dimensions": embedding_dim,
         "embedding_type": "base64",
         "input": texts,
     }
@@ -114,7 +119,7 @@ async def jina_embed(
         data["late_chunking"] = late_chunking

     logger.debug(
-        f"Jina embedding request: {len(texts)} texts, dimensions: {dimensions}"
+        f"Jina embedding request: {len(texts)} texts, dimensions: {embedding_dim}"
     )

     try:
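
Since the renamed embedding_dim argument is meant to arrive via EmbeddingFunc injection rather than from callers, a hedged usage sketch; the _jina wrapper is illustrative and the import path for EmbeddingFunc is assumed:

    from lightrag.llm.jina import jina_embed  # shown in this diff
    from lightrag.utils import EmbeddingFunc  # assumed import path

    async def _jina(texts: list[str], embedding_dim: int | None = None):
        # embedding_dim is filled in by EmbeddingFunc when send_dimensions=True
        return await jina_embed(texts, embedding_dim=embedding_dim, api_key="...")

    embedding_func = EmbeddingFunc(embedding_dim=1024, func=_jina, send_dimensions=True)
    # `await embedding_func(["hello"])` ends up sending "dimensions": 1024 to the Jina API.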


@@ -609,7 +609,7 @@ async def openai_embed(
     model: str = "text-embedding-3-small",
     base_url: str | None = None,
     api_key: str | None = None,
-    embedding_dim: int = None,
+    embedding_dim: int | None = None,
     client_configs: dict[str, Any] | None = None,
     token_tracker: Any | None = None,
 ) -> np.ndarray:
@@ -620,7 +620,12 @@ async def openai_embed(
         model: The OpenAI embedding model to use.
         base_url: Optional base URL for the OpenAI API.
         api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
-        embedding_dim: Optional embedding dimension. If None, uses the default embedding dimension for the model. (will be passed to API for dimension reduction).
+        embedding_dim: Optional embedding dimension for dynamic dimension reduction.
+            **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
+            Do NOT manually pass this parameter when calling the function directly.
+            The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
+            Manually passing a different value will trigger a warning and be ignored.
+            When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
         client_configs: Additional configuration options for the AsyncOpenAI client.
             These will override any default configurations but will be overridden by
             explicit parameters (api_key, base_url).
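
The body of openai_embed is not part of this diff. As a hedged sketch of what forwarding embedding_dim to the API typically looks like with the official SDK (the embed helper below is illustrative, not LightRAG code), the dimensions argument is simply omitted when no reduction is requested:

    from openai import AsyncOpenAI

    async def embed(texts: list[str], embedding_dim: int | None = None):
        # Illustrative only: pass `dimensions` to text-embedding-3 models when set.
        client = AsyncOpenAI()
        kwargs = {"model": "text-embedding-3-small", "input": texts}
        if embedding_dim is not None:
            kwargs["dimensions"] = embedding_dim
        response = await client.embeddings.create(**kwargs)
        return [item.embedding for item in response.data]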


@@ -353,8 +353,24 @@ class EmbeddingFunc:
     embedding_dim: int
     func: callable
     max_token_size: int | None = None  # deprecated keep it for compatible only
+    send_dimensions: bool = False  # Control whether to send embedding_dim to the function

     async def __call__(self, *args, **kwargs) -> np.ndarray:
+        # Only inject embedding_dim when send_dimensions is True
+        if self.send_dimensions:
+            # Check if user provided embedding_dim parameter
+            if 'embedding_dim' in kwargs:
+                user_provided_dim = kwargs['embedding_dim']
+                # If user's value differs from class attribute, output warning
+                if user_provided_dim is not None and user_provided_dim != self.embedding_dim:
+                    logger.warning(
+                        f"Ignoring user-provided embedding_dim={user_provided_dim}, "
+                        f"using declared embedding_dim={self.embedding_dim} from decorator"
+                    )
+            # Inject embedding_dim from decorator
+            kwargs['embedding_dim'] = self.embedding_dim
+
         return await self.func(*args, **kwargs)
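
To make the injection and override behaviour concrete, a small runnable example; my_embed is a made-up backend and the import path for EmbeddingFunc is assumed:

    import asyncio
    import numpy as np

    from lightrag.utils import EmbeddingFunc  # assumed import path for the dataclass above

    async def my_embed(texts: list[str], embedding_dim: int | None = None) -> np.ndarray:
        # Stand-in backend: returns zero vectors of the requested width.
        return np.zeros((len(texts), embedding_dim or 768))

    async def main():
        wrapper = EmbeddingFunc(embedding_dim=1024, func=my_embed, send_dimensions=True)

        vecs = await wrapper(["hello"])
        print(vecs.shape)  # (1, 1024): embedding_dim injected automatically

        vecs = await wrapper(["hello"], embedding_dim=512)
        print(vecs.shape)  # still (1, 1024): the mismatched value is warned about and overridden

    asyncio.run(main())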