Merge pull request #2328 from HKUDS/apply-dim-to-embedding-call
Feat: Add Optional Embedding Dimension Parameter Control with Jina API Compliance
This commit is contained in:
commit
f4492d48dc
8 changed files with 142 additions and 53 deletions
|
|
@ -242,6 +242,14 @@ OLLAMA_LLM_NUM_CTX=32768
|
||||||
### EMBEDDING_BINDING_HOST: host only for Ollama, endpoint for other embedding services
|
### EMBEDDING_BINDING_HOST: host only for Ollama, endpoint for other embedding services
|
||||||
#######################################################################################
|
#######################################################################################
|
||||||
# EMBEDDING_TIMEOUT=30
|
# EMBEDDING_TIMEOUT=30
|
||||||
|
|
||||||
|
### Control whether to send embedding_dim parameter to embedding API
|
||||||
|
### IMPORTANT: The Jina binding ALWAYS sends the dimension parameter (API requirement); this setting is ignored for Jina
|
||||||
|
### For OpenAI: Set to 'true' to enable dynamic dimension adjustment
|
||||||
|
### For OpenAI: Set to 'false' (default) to disable sending dimension parameter
|
||||||
|
### Note: Automatically ignored for backends that don't support dimension parameter (e.g., Ollama)
|
||||||
|
# EMBEDDING_SEND_DIM=false
|
||||||
|
|
||||||
EMBEDDING_BINDING=ollama
|
EMBEDDING_BINDING=ollama
|
||||||
EMBEDDING_MODEL=bge-m3:latest
|
EMBEDDING_MODEL=bge-m3:latest
|
||||||
EMBEDDING_DIM=1024
|
EMBEDDING_DIM=1024
|
||||||
|
|
|
||||||
|
|
@ -343,6 +343,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
|
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
|
||||||
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
|
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
|
||||||
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
|
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
|
||||||
|
args.embedding_send_dim = get_env_value("EMBEDDING_SEND_DIM", False, bool)
|
||||||
|
|
||||||
# Inject chunk configuration
|
# Inject chunk configuration
|
||||||
args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
|
args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ import logging.config
|
||||||
import sys
|
import sys
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
import inspect
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -599,14 +598,14 @@ def create_app(args):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def create_optimized_embedding_function(
|
def create_optimized_embedding_function(
|
||||||
config_cache: LLMConfigCache, binding, model, host, api_key, dimensions, args
|
config_cache: LLMConfigCache, binding, model, host, api_key, args
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create optimized embedding function with pre-processed configuration for applicable bindings.
|
Create optimized embedding function with pre-processed configuration for applicable bindings.
|
||||||
Uses lazy imports for all bindings and avoids repeated configuration parsing.
|
Uses lazy imports for all bindings and avoids repeated configuration parsing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def optimized_embedding_function(texts):
|
async def optimized_embedding_function(texts, embedding_dim=None):
|
||||||
try:
|
try:
|
||||||
if binding == "lollms":
|
if binding == "lollms":
|
||||||
from lightrag.llm.lollms import lollms_embed
|
from lightrag.llm.lollms import lollms_embed
|
||||||
|
|
@ -645,13 +644,20 @@ def create_app(args):
|
||||||
from lightrag.llm.jina import jina_embed
|
from lightrag.llm.jina import jina_embed
|
||||||
|
|
||||||
return await jina_embed(
|
return await jina_embed(
|
||||||
texts, dimensions=dimensions, base_url=host, api_key=api_key
|
texts,
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
base_url=host,
|
||||||
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
else: # openai and compatible
|
else: # openai and compatible
|
||||||
from lightrag.llm.openai import openai_embed
|
from lightrag.llm.openai import openai_embed
|
||||||
|
|
||||||
return await openai_embed(
|
return await openai_embed(
|
||||||
texts, model=model, base_url=host, api_key=api_key
|
texts,
|
||||||
|
model=model,
|
||||||
|
base_url=host,
|
||||||
|
api_key=api_key,
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
)
|
)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise Exception(f"Failed to import {binding} embedding: {e}")
|
raise Exception(f"Failed to import {binding} embedding: {e}")
|
||||||
|
|
@ -691,17 +697,52 @@ def create_app(args):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create embedding function with optimized configuration
|
# Create embedding function with optimized configuration
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
# Create the optimized embedding function
|
||||||
|
optimized_embedding_func = create_optimized_embedding_function(
|
||||||
|
config_cache=config_cache,
|
||||||
|
binding=args.embedding_binding,
|
||||||
|
model=args.embedding_model,
|
||||||
|
host=args.embedding_binding_host,
|
||||||
|
api_key=args.embedding_binding_api_key,
|
||||||
|
args=args, # Pass args object for fallback option generation
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get embedding_send_dim from centralized configuration
|
||||||
|
embedding_send_dim = args.embedding_send_dim
|
||||||
|
|
||||||
|
# Check if the function signature has embedding_dim parameter
|
||||||
|
# Note: optimized_embedding_func is an async function; inspect.signature() still reports its declared parameters
|
||||||
|
sig = inspect.signature(optimized_embedding_func)
|
||||||
|
has_embedding_dim_param = "embedding_dim" in sig.parameters
|
||||||
|
|
||||||
|
# Determine send_dimensions value based on binding type
|
||||||
|
# Jina REQUIRES the dimension parameter (always sent when the embedding function accepts embedding_dim)
|
||||||
|
# OpenAI and others: controlled by EMBEDDING_SEND_DIM environment variable
|
||||||
|
if args.embedding_binding == "jina":
|
||||||
|
# Jina API requires dimension parameter - always send it
|
||||||
|
send_dimensions = has_embedding_dim_param
|
||||||
|
dimension_control = "forced by Jina API"
|
||||||
|
else:
|
||||||
|
# For OpenAI and other bindings, respect EMBEDDING_SEND_DIM setting
|
||||||
|
send_dimensions = embedding_send_dim and has_embedding_dim_param
|
||||||
|
if send_dimensions or not embedding_send_dim:
|
||||||
|
dimension_control = "by env var"
|
||||||
|
else:
|
||||||
|
dimension_control = "by not hasparam"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Send embedding dimension: {send_dimensions} {dimension_control} "
|
||||||
|
f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
|
||||||
|
f"binding={args.embedding_binding})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create EmbeddingFunc with send_dimensions attribute
|
||||||
embedding_func = EmbeddingFunc(
|
embedding_func = EmbeddingFunc(
|
||||||
embedding_dim=args.embedding_dim,
|
embedding_dim=args.embedding_dim,
|
||||||
func=create_optimized_embedding_function(
|
func=optimized_embedding_func,
|
||||||
config_cache=config_cache,
|
send_dimensions=send_dimensions,
|
||||||
binding=args.embedding_binding,
|
|
||||||
model=args.embedding_model,
|
|
||||||
host=args.embedding_binding_host,
|
|
||||||
api_key=args.embedding_binding_api_key,
|
|
||||||
dimensions=args.embedding_dim,
|
|
||||||
args=args, # Pass args object for fallback option generation
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Configure rerank function based on args.rerank_binding parameter
|
# Configure rerank function based on args.rerank_binding parameter
|
||||||
|
|
|
||||||
|
|
@ -472,6 +472,9 @@ class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
|
||||||
_binding_name: ClassVar[str] = "ollama_llm"
|
_binding_name: ClassVar[str] = "ollama_llm"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Binding Options for Gemini
|
||||||
|
# =============================================================================
|
||||||
@dataclass
|
@dataclass
|
||||||
class GeminiLLMOptions(BindingOptions):
|
class GeminiLLMOptions(BindingOptions):
|
||||||
"""Options for Google Gemini models."""
|
"""Options for Google Gemini models."""
|
||||||
|
|
|
||||||
|
|
@ -69,7 +69,7 @@ async def fetch_data(url, headers, data):
|
||||||
)
|
)
|
||||||
async def jina_embed(
|
async def jina_embed(
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
dimensions: int = 2048,
|
embedding_dim: int = 2048,
|
||||||
late_chunking: bool = False,
|
late_chunking: bool = False,
|
||||||
base_url: str = None,
|
base_url: str = None,
|
||||||
api_key: str = None,
|
api_key: str = None,
|
||||||
|
|
@ -78,7 +78,12 @@ async def jina_embed(
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: List of texts to embed.
|
texts: List of texts to embed.
|
||||||
dimensions: The embedding dimensions (default: 2048 for jina-embeddings-v4).
|
embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4).
|
||||||
|
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
|
||||||
|
Do NOT manually pass this parameter when calling the function directly.
|
||||||
|
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
|
||||||
|
Manually passing a different value will trigger a warning and be ignored.
|
||||||
|
When provided (by EmbeddingFunc), it will be passed to the Jina API for dimension reduction.
|
||||||
late_chunking: Whether to use late chunking.
|
late_chunking: Whether to use late chunking.
|
||||||
base_url: Optional base URL for the Jina API.
|
base_url: Optional base URL for the Jina API.
|
||||||
api_key: Optional Jina API key. If None, uses the JINA_API_KEY environment variable.
|
api_key: Optional Jina API key. If None, uses the JINA_API_KEY environment variable.
|
||||||
|
|
@ -104,7 +109,7 @@ async def jina_embed(
|
||||||
data = {
|
data = {
|
||||||
"model": "jina-embeddings-v4",
|
"model": "jina-embeddings-v4",
|
||||||
"task": "text-matching",
|
"task": "text-matching",
|
||||||
"dimensions": dimensions,
|
"dimensions": embedding_dim,
|
||||||
"embedding_type": "base64",
|
"embedding_type": "base64",
|
||||||
"input": texts,
|
"input": texts,
|
||||||
}
|
}
|
||||||
|
|
@ -114,7 +119,7 @@ async def jina_embed(
|
||||||
data["late_chunking"] = late_chunking
|
data["late_chunking"] = late_chunking
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Jina embedding request: {len(texts)} texts, dimensions: {dimensions}"
|
f"Jina embedding request: {len(texts)} texts, dimensions: {embedding_dim}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -613,6 +613,7 @@ async def openai_embed(
|
||||||
model: str = "text-embedding-3-small",
|
model: str = "text-embedding-3-small",
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
|
embedding_dim: int | None = None,
|
||||||
client_configs: dict[str, Any] | None = None,
|
client_configs: dict[str, Any] | None = None,
|
||||||
token_tracker: Any | None = None,
|
token_tracker: Any | None = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
|
@ -623,6 +624,12 @@ async def openai_embed(
|
||||||
model: The OpenAI embedding model to use.
|
model: The OpenAI embedding model to use.
|
||||||
base_url: Optional base URL for the OpenAI API.
|
base_url: Optional base URL for the OpenAI API.
|
||||||
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
|
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
|
||||||
|
embedding_dim: Optional embedding dimension for dynamic dimension reduction.
|
||||||
|
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
|
||||||
|
Do NOT manually pass this parameter when calling the function directly.
|
||||||
|
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
|
||||||
|
Manually passing a different value will trigger a warning and be ignored.
|
||||||
|
When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
|
||||||
client_configs: Additional configuration options for the AsyncOpenAI client.
|
client_configs: Additional configuration options for the AsyncOpenAI client.
|
||||||
These will override any default configurations but will be overridden by
|
These will override any default configurations but will be overridden by
|
||||||
explicit parameters (api_key, base_url).
|
explicit parameters (api_key, base_url).
|
||||||
|
|
@ -642,9 +649,19 @@ async def openai_embed(
|
||||||
)
|
)
|
||||||
|
|
||||||
async with openai_async_client:
|
async with openai_async_client:
|
||||||
response = await openai_async_client.embeddings.create(
|
# Prepare API call parameters
|
||||||
model=model, input=texts, encoding_format="base64"
|
api_params = {
|
||||||
)
|
"model": model,
|
||||||
|
"input": texts,
|
||||||
|
"encoding_format": "base64",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add dimensions parameter only if embedding_dim is provided
|
||||||
|
if embedding_dim is not None:
|
||||||
|
api_params["dimensions"] = embedding_dim
|
||||||
|
|
||||||
|
# Make API call
|
||||||
|
response = await openai_async_client.embeddings.create(**api_params)
|
||||||
|
|
||||||
if token_tracker and hasattr(response, "usage"):
|
if token_tracker and hasattr(response, "usage"):
|
||||||
token_counts = {
|
token_counts = {
|
||||||
|
|
|
||||||
|
|
@ -3425,10 +3425,10 @@ async def _perform_kg_search(
|
||||||
)
|
)
|
||||||
query_embedding = None
|
query_embedding = None
|
||||||
if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
|
if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
|
||||||
embedding_func_config = text_chunks_db.embedding_func
|
actual_embedding_func = text_chunks_db.embedding_func
|
||||||
if embedding_func_config and embedding_func_config.func:
|
if actual_embedding_func:
|
||||||
try:
|
try:
|
||||||
query_embedding = await embedding_func_config.func([query])
|
query_embedding = await actual_embedding_func([query])
|
||||||
query_embedding = query_embedding[
|
query_embedding = query_embedding[
|
||||||
0
|
0
|
||||||
] # Extract first embedding from batch result
|
] # Extract first embedding from batch result
|
||||||
|
|
@ -4336,25 +4336,21 @@ async def _find_related_text_unit_from_entities(
|
||||||
num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)
|
num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)
|
||||||
|
|
||||||
# Get embedding function from global config
|
# Get embedding function from global config
|
||||||
embedding_func_config = text_chunks_db.embedding_func
|
actual_embedding_func = text_chunks_db.embedding_func
|
||||||
if not embedding_func_config:
|
if not actual_embedding_func:
|
||||||
logger.warning("No embedding function found, falling back to WEIGHT method")
|
logger.warning("No embedding function found, falling back to WEIGHT method")
|
||||||
kg_chunk_pick_method = "WEIGHT"
|
kg_chunk_pick_method = "WEIGHT"
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
actual_embedding_func = embedding_func_config.func
|
selected_chunk_ids = await pick_by_vector_similarity(
|
||||||
|
query=query,
|
||||||
selected_chunk_ids = None
|
text_chunks_storage=text_chunks_db,
|
||||||
if actual_embedding_func:
|
chunks_vdb=chunks_vdb,
|
||||||
selected_chunk_ids = await pick_by_vector_similarity(
|
num_of_chunks=num_of_chunks,
|
||||||
query=query,
|
entity_info=entities_with_chunks,
|
||||||
text_chunks_storage=text_chunks_db,
|
embedding_func=actual_embedding_func,
|
||||||
chunks_vdb=chunks_vdb,
|
query_embedding=query_embedding,
|
||||||
num_of_chunks=num_of_chunks,
|
)
|
||||||
entity_info=entities_with_chunks,
|
|
||||||
embedding_func=actual_embedding_func,
|
|
||||||
query_embedding=query_embedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
if selected_chunk_ids == []:
|
if selected_chunk_ids == []:
|
||||||
kg_chunk_pick_method = "WEIGHT"
|
kg_chunk_pick_method = "WEIGHT"
|
||||||
|
|
@ -4629,24 +4625,21 @@ async def _find_related_text_unit_from_relations(
|
||||||
num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2)
|
num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2)
|
||||||
|
|
||||||
# Get embedding function from global config
|
# Get embedding function from global config
|
||||||
embedding_func_config = text_chunks_db.embedding_func
|
actual_embedding_func = text_chunks_db.embedding_func
|
||||||
if not embedding_func_config:
|
if not actual_embedding_func:
|
||||||
logger.warning("No embedding function found, falling back to WEIGHT method")
|
logger.warning("No embedding function found, falling back to WEIGHT method")
|
||||||
kg_chunk_pick_method = "WEIGHT"
|
kg_chunk_pick_method = "WEIGHT"
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
actual_embedding_func = embedding_func_config.func
|
selected_chunk_ids = await pick_by_vector_similarity(
|
||||||
|
query=query,
|
||||||
if actual_embedding_func:
|
text_chunks_storage=text_chunks_db,
|
||||||
selected_chunk_ids = await pick_by_vector_similarity(
|
chunks_vdb=chunks_vdb,
|
||||||
query=query,
|
num_of_chunks=num_of_chunks,
|
||||||
text_chunks_storage=text_chunks_db,
|
entity_info=relations_with_chunks,
|
||||||
chunks_vdb=chunks_vdb,
|
embedding_func=actual_embedding_func,
|
||||||
num_of_chunks=num_of_chunks,
|
query_embedding=query_embedding,
|
||||||
entity_info=relations_with_chunks,
|
)
|
||||||
embedding_func=actual_embedding_func,
|
|
||||||
query_embedding=query_embedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
if selected_chunk_ids == []:
|
if selected_chunk_ids == []:
|
||||||
kg_chunk_pick_method = "WEIGHT"
|
kg_chunk_pick_method = "WEIGHT"
|
||||||
|
|
|
||||||
|
|
@ -353,8 +353,29 @@ class EmbeddingFunc:
|
||||||
embedding_dim: int
|
embedding_dim: int
|
||||||
func: callable
|
func: callable
|
||||||
max_token_size: int | None = None # deprecated keep it for compatible only
|
max_token_size: int | None = None # deprecated keep it for compatible only
|
||||||
|
send_dimensions: bool = (
|
||||||
|
False # Control whether to send embedding_dim to the function
|
||||||
|
)
|
||||||
|
|
||||||
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
||||||
|
# Only inject embedding_dim when send_dimensions is True
|
||||||
|
if self.send_dimensions:
|
||||||
|
# Check if user provided embedding_dim parameter
|
||||||
|
if "embedding_dim" in kwargs:
|
||||||
|
user_provided_dim = kwargs["embedding_dim"]
|
||||||
|
# If user's value differs from class attribute, output warning
|
||||||
|
if (
|
||||||
|
user_provided_dim is not None
|
||||||
|
and user_provided_dim != self.embedding_dim
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"Ignoring user-provided embedding_dim={user_provided_dim}, "
|
||||||
|
f"using declared embedding_dim={self.embedding_dim} from decorator"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Inject embedding_dim from decorator
|
||||||
|
kwargs["embedding_dim"] = self.embedding_dim
|
||||||
|
|
||||||
return await self.func(*args, **kwargs)
|
return await self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue