Merge pull request #2328 from HKUDS/apply-dim-to-embedding-call

Feat: Add Optional Embedding Dimension Parameter Control with Jina API Compliance
Daniel.y 2025-11-08 02:10:08 +08:00 committed by GitHub
commit f4492d48dc
8 changed files with 142 additions and 53 deletions

View file

@@ -242,6 +242,14 @@ OLLAMA_LLM_NUM_CTX=32768
 ### EMBEDDING_BINDING_HOST: host only for Ollama, endpoint for other Embedding service
 #######################################################################################
 # EMBEDDING_TIMEOUT=30
+### Control whether to send embedding_dim parameter to embedding API
+### IMPORTANT: Jina ALWAYS sends dimension parameter (API requirement) - this setting is ignored for Jina
+### For OpenAI: Set to 'true' to enable dynamic dimension adjustment
+### For OpenAI: Set to 'false' (default) to disable sending dimension parameter
+### Note: Automatically ignored for backends that don't support dimension parameter (e.g., Ollama)
+# EMBEDDING_SEND_DIM=false
 EMBEDDING_BINDING=ollama
 EMBEDDING_MODEL=bge-m3:latest
 EMBEDDING_DIM=1024
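For example, an OpenAI-compatible deployment that should receive the dimension parameter could be configured along these lines (the binding name, model, and dimension are illustrative values, not defaults shipped by this commit):

EMBEDDING_BINDING=openai
EMBEDDING_MODEL=text-embedding-3-small
EMBEDDING_DIM=1024
EMBEDDING_SEND_DIM=true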

View file

@@ -343,6 +343,7 @@ def parse_args() -> argparse.Namespace:
     args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
     args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
     args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
+    args.embedding_send_dim = get_env_value("EMBEDDING_SEND_DIM", False, bool)

     # Inject chunk configuration
     args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
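Outside of the server's config helper, the new flag behaves like an ordinary boolean environment variable; a rough standalone equivalent (hedged: get_env_value's exact truthy-string handling is not shown in this diff) would be:

import os

# Defaults to False when EMBEDDING_SEND_DIM is unset; common truthy spellings are accepted.
embedding_send_dim = os.environ.get("EMBEDDING_SEND_DIM", "false").strip().lower() in ("true", "1", "yes")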

View file

@@ -15,7 +15,6 @@ import logging.config
 import sys
 import uvicorn
 import pipmaster as pm
-import inspect
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import RedirectResponse
 from pathlib import Path
@@ -599,14 +598,14 @@ def create_app(args):
             return {}

     def create_optimized_embedding_function(
-        config_cache: LLMConfigCache, binding, model, host, api_key, dimensions, args
+        config_cache: LLMConfigCache, binding, model, host, api_key, args
     ):
         """
         Create optimized embedding function with pre-processed configuration for applicable bindings.
         Uses lazy imports for all bindings and avoids repeated configuration parsing.
         """

-        async def optimized_embedding_function(texts):
+        async def optimized_embedding_function(texts, embedding_dim=None):
             try:
                 if binding == "lollms":
                     from lightrag.llm.lollms import lollms_embed
@@ -645,13 +644,20 @@
                     from lightrag.llm.jina import jina_embed

                     return await jina_embed(
-                        texts, dimensions=dimensions, base_url=host, api_key=api_key
+                        texts,
+                        embedding_dim=embedding_dim,
+                        base_url=host,
+                        api_key=api_key,
                     )
                 else:  # openai and compatible
                     from lightrag.llm.openai import openai_embed

                     return await openai_embed(
-                        texts, model=model, base_url=host, api_key=api_key
+                        texts,
+                        model=model,
+                        base_url=host,
+                        api_key=api_key,
+                        embedding_dim=embedding_dim,
                     )
             except ImportError as e:
                 raise Exception(f"Failed to import {binding} embedding: {e}")
@@ -691,17 +697,52 @@ def create_app(args):
     )

     # Create embedding function with optimized configuration
+    import inspect
+
+    # Create the optimized embedding function
+    optimized_embedding_func = create_optimized_embedding_function(
+        config_cache=config_cache,
+        binding=args.embedding_binding,
+        model=args.embedding_model,
+        host=args.embedding_binding_host,
+        api_key=args.embedding_binding_api_key,
+        args=args,  # Pass args object for fallback option generation
+    )
+
+    # Get embedding_send_dim from centralized configuration
+    embedding_send_dim = args.embedding_send_dim
+
+    # Check if the function signature has embedding_dim parameter
+    # Note: Since optimized_embedding_func is an async function, inspect its signature
+    sig = inspect.signature(optimized_embedding_func)
+    has_embedding_dim_param = "embedding_dim" in sig.parameters
+
+    # Determine send_dimensions value based on binding type
+    # Jina REQUIRES dimension parameter (forced to True)
+    # OpenAI and others: controlled by EMBEDDING_SEND_DIM environment variable
+    if args.embedding_binding == "jina":
+        # Jina API requires dimension parameter - always send it
+        send_dimensions = has_embedding_dim_param
+        dimension_control = "forced by Jina API"
+    else:
+        # For OpenAI and other bindings, respect EMBEDDING_SEND_DIM setting
+        send_dimensions = embedding_send_dim and has_embedding_dim_param
+        if send_dimensions or not embedding_send_dim:
+            dimension_control = "by env var"
+        else:
+            dimension_control = "by not hasparam"
+
+    logger.info(
+        f"Send embedding dimension: {send_dimensions} {dimension_control} "
+        f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
+        f"binding={args.embedding_binding})"
+    )
+
+    # Create EmbeddingFunc with send_dimensions attribute
     embedding_func = EmbeddingFunc(
         embedding_dim=args.embedding_dim,
-        func=create_optimized_embedding_function(
-            config_cache=config_cache,
-            binding=args.embedding_binding,
-            model=args.embedding_model,
-            host=args.embedding_binding_host,
-            api_key=args.embedding_binding_api_key,
-            dimensions=args.embedding_dim,
-            args=args,  # Pass args object for fallback option generation
-        ),
+        func=optimized_embedding_func,
+        send_dimensions=send_dimensions,
     )

     # Configure rerank function based on args.rerank_binding parameter
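Taken out of the server wiring above, the decision rule reduces to a few lines. The standalone sketch below (function and variable names are placeholders, not part of this commit) shows which combinations actually lead to the dimension being sent:

import inspect

def resolve_send_dimensions(binding: str, embedding_send_dim: bool, embed_callable) -> bool:
    # The wrapped callable must accept embedding_dim at all, otherwise nothing is sent.
    has_embedding_dim_param = "embedding_dim" in inspect.signature(embed_callable).parameters
    if binding == "jina":
        # Jina always receives the dimension (API requirement), regardless of EMBEDDING_SEND_DIM.
        return has_embedding_dim_param
    # OpenAI and other bindings only send it when EMBEDDING_SEND_DIM is enabled.
    return embedding_send_dim and has_embedding_dim_param

async def demo_embed(texts, embedding_dim=None):  # placeholder embedding callable
    ...

print(resolve_send_dimensions("jina", False, demo_embed))    # True
print(resolve_send_dimensions("openai", False, demo_embed))  # False
print(resolve_send_dimensions("openai", True, demo_embed))   # True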

View file

@@ -472,6 +472,9 @@ class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
     _binding_name: ClassVar[str] = "ollama_llm"


+# =============================================================================
+# Binding Options for Gemini
+# =============================================================================
 @dataclass
 class GeminiLLMOptions(BindingOptions):
     """Options for Google Gemini models."""

View file

@@ -69,7 +69,7 @@ async def fetch_data(url, headers, data):
 )
 async def jina_embed(
     texts: list[str],
-    dimensions: int = 2048,
+    embedding_dim: int = 2048,
     late_chunking: bool = False,
     base_url: str = None,
     api_key: str = None,
@@ -78,7 +78,12 @@ async def jina_embed(

     Args:
         texts: List of texts to embed.
-        dimensions: The embedding dimensions (default: 2048 for jina-embeddings-v4).
+        embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4).
+            **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
+            Do NOT manually pass this parameter when calling the function directly.
+            The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
+            Manually passing a different value will trigger a warning and be ignored.
+            When provided (by EmbeddingFunc), it will be passed to the Jina API for dimension reduction.
         late_chunking: Whether to use late chunking.
         base_url: Optional base URL for the Jina API.
         api_key: Optional Jina API key. If None, uses the JINA_API_KEY environment variable.
@@ -104,7 +109,7 @@ async def jina_embed(
     data = {
         "model": "jina-embeddings-v4",
         "task": "text-matching",
-        "dimensions": dimensions,
+        "dimensions": embedding_dim,
         "embedding_type": "base64",
         "input": texts,
     }
@@ -114,7 +119,7 @@ async def jina_embed(
         data["late_chunking"] = late_chunking

     logger.debug(
-        f"Jina embedding request: {len(texts)} texts, dimensions: {dimensions}"
+        f"Jina embedding request: {len(texts)} texts, dimensions: {embedding_dim}"
     )

     try:
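As the docstring insists, callers are not expected to pass embedding_dim themselves. A minimal usage sketch (assuming EmbeddingFunc is importable from lightrag.utils as elsewhere in the project, a JINA_API_KEY in the environment, and an illustrative dimension of 1024):

from lightrag.llm.jina import jina_embed
from lightrag.utils import EmbeddingFunc

# The wrapper owns the dimension; with send_dimensions=True it injects
# embedding_dim=1024 into every call, and jina_embed forwards it as "dimensions".
embedding_func = EmbeddingFunc(
    embedding_dim=1024,
    func=jina_embed,
    send_dimensions=True,
)

# embeddings = await embedding_func(["a query to embed"])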

View file

@@ -613,6 +613,7 @@ async def openai_embed(
     model: str = "text-embedding-3-small",
     base_url: str | None = None,
     api_key: str | None = None,
+    embedding_dim: int | None = None,
     client_configs: dict[str, Any] | None = None,
     token_tracker: Any | None = None,
 ) -> np.ndarray:
@@ -623,6 +624,12 @@ async def openai_embed(
         model: The OpenAI embedding model to use.
         base_url: Optional base URL for the OpenAI API.
         api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        embedding_dim: Optional embedding dimension for dynamic dimension reduction.
+            **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
+            Do NOT manually pass this parameter when calling the function directly.
+            The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
+            Manually passing a different value will trigger a warning and be ignored.
+            When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
         client_configs: Additional configuration options for the AsyncOpenAI client.
             These will override any default configurations but will be overridden by
             explicit parameters (api_key, base_url).
@@ -642,9 +649,19 @@ async def openai_embed(
     )

     async with openai_async_client:
-        response = await openai_async_client.embeddings.create(
-            model=model, input=texts, encoding_format="base64"
-        )
+        # Prepare API call parameters
+        api_params = {
+            "model": model,
+            "input": texts,
+            "encoding_format": "base64",
+        }
+
+        # Add dimensions parameter only if embedding_dim is provided
+        if embedding_dim is not None:
+            api_params["dimensions"] = embedding_dim
+
+        # Make API call
+        response = await openai_async_client.embeddings.create(**api_params)

         if token_tracker and hasattr(response, "usage"):
             token_counts = {
View file

@@ -3425,10 +3425,10 @@ async def _perform_kg_search(
     )
     query_embedding = None
     if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
-        embedding_func_config = text_chunks_db.embedding_func
-        if embedding_func_config and embedding_func_config.func:
+        actual_embedding_func = text_chunks_db.embedding_func
+        if actual_embedding_func:
             try:
-                query_embedding = await embedding_func_config.func([query])
+                query_embedding = await actual_embedding_func([query])
                 query_embedding = query_embedding[
                     0
                 ]  # Extract first embedding from batch result
@@ -4336,25 +4336,21 @@ async def _find_related_text_unit_from_entities(
         num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)

     # Get embedding function from global config
-    embedding_func_config = text_chunks_db.embedding_func
-    if not embedding_func_config:
+    actual_embedding_func = text_chunks_db.embedding_func
+    if not actual_embedding_func:
         logger.warning("No embedding function found, falling back to WEIGHT method")
         kg_chunk_pick_method = "WEIGHT"
     else:
         try:
-            actual_embedding_func = embedding_func_config.func
-
-            selected_chunk_ids = None
-            if actual_embedding_func:
-                selected_chunk_ids = await pick_by_vector_similarity(
-                    query=query,
-                    text_chunks_storage=text_chunks_db,
-                    chunks_vdb=chunks_vdb,
-                    num_of_chunks=num_of_chunks,
-                    entity_info=entities_with_chunks,
-                    embedding_func=actual_embedding_func,
-                    query_embedding=query_embedding,
-                )
+            selected_chunk_ids = await pick_by_vector_similarity(
+                query=query,
+                text_chunks_storage=text_chunks_db,
+                chunks_vdb=chunks_vdb,
+                num_of_chunks=num_of_chunks,
+                entity_info=entities_with_chunks,
+                embedding_func=actual_embedding_func,
+                query_embedding=query_embedding,
+            )

             if selected_chunk_ids == []:
                 kg_chunk_pick_method = "WEIGHT"
@@ -4629,24 +4625,21 @@ async def _find_related_text_unit_from_relations(
         num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2)

     # Get embedding function from global config
-    embedding_func_config = text_chunks_db.embedding_func
-    if not embedding_func_config:
+    actual_embedding_func = text_chunks_db.embedding_func
+    if not actual_embedding_func:
         logger.warning("No embedding function found, falling back to WEIGHT method")
         kg_chunk_pick_method = "WEIGHT"
     else:
         try:
-            actual_embedding_func = embedding_func_config.func
-
-            if actual_embedding_func:
-                selected_chunk_ids = await pick_by_vector_similarity(
-                    query=query,
-                    text_chunks_storage=text_chunks_db,
-                    chunks_vdb=chunks_vdb,
-                    num_of_chunks=num_of_chunks,
-                    entity_info=relations_with_chunks,
-                    embedding_func=actual_embedding_func,
-                    query_embedding=query_embedding,
-                )
+            selected_chunk_ids = await pick_by_vector_similarity(
+                query=query,
+                text_chunks_storage=text_chunks_db,
+                chunks_vdb=chunks_vdb,
+                num_of_chunks=num_of_chunks,
+                entity_info=relations_with_chunks,
+                embedding_func=actual_embedding_func,
+                query_embedding=query_embedding,
+            )

             if selected_chunk_ids == []:
                 kg_chunk_pick_method = "WEIGHT"
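The renames above are more than cosmetic: text_chunks_db.embedding_func is the EmbeddingFunc wrapper itself, and awaiting it directly (instead of reaching into .func) is what lets the wrapper's __call__ inject embedding_dim before the backend is hit. A compressed, illustrative contrast (assuming EmbeddingFunc from lightrag.utils; openai_embed stands in for any wrapped embedder):

from lightrag.llm.openai import openai_embed
from lightrag.utils import EmbeddingFunc

embedding_func = EmbeddingFunc(embedding_dim=1024, func=openai_embed, send_dimensions=True)

# Old call sites:  await embedding_func.func([query])  # bypasses __call__, dimension never injected
# New call sites:  await embedding_func([query])       # __call__ adds embedding_dim=1024, then delegates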

View file

@@ -353,8 +353,29 @@ class EmbeddingFunc:
     embedding_dim: int
     func: callable
     max_token_size: int | None = None  # deprecated keep it for compatible only
+    send_dimensions: bool = (
+        False  # Control whether to send embedding_dim to the function
+    )

     async def __call__(self, *args, **kwargs) -> np.ndarray:
+        # Only inject embedding_dim when send_dimensions is True
+        if self.send_dimensions:
+            # Check if user provided embedding_dim parameter
+            if "embedding_dim" in kwargs:
+                user_provided_dim = kwargs["embedding_dim"]
+                # If user's value differs from class attribute, output warning
+                if (
+                    user_provided_dim is not None
+                    and user_provided_dim != self.embedding_dim
+                ):
+                    logger.warning(
+                        f"Ignoring user-provided embedding_dim={user_provided_dim}, "
+                        f"using declared embedding_dim={self.embedding_dim} from decorator"
+                    )
+
+            # Inject embedding_dim from decorator
+            kwargs["embedding_dim"] = self.embedding_dim
+
         return await self.func(*args, **kwargs)
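Putting the pieces together, the wrapper's injection and warning behavior can be exercised with a stub embedder. This is a self-contained sketch under the assumption that EmbeddingFunc is the dataclass importable from lightrag.utils; fake_embed is not part of the codebase and only reports which dimension arrives:

import asyncio

import numpy as np

from lightrag.utils import EmbeddingFunc

async def fake_embed(texts, embedding_dim=None):
    # Returns zero vectors sized by whatever dimension the wrapper injected.
    return np.zeros((len(texts), embedding_dim or 2048))

wrapper = EmbeddingFunc(embedding_dim=1024, func=fake_embed, send_dimensions=True)

async def main():
    vecs = await wrapper(["hello"])                        # embedding_dim=1024 injected
    clipped = await wrapper(["hello"], embedding_dim=512)  # warns, then still uses 1024
    print(vecs.shape, clipped.shape)                       # (1, 1024) (1, 1024)

asyncio.run(main())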