Merge pull request #2433 from danielaskdd/fix-jina-embedding

Fix: Add configurable model support for Jina embedding
This commit is contained in:
Daniel.y 2025-11-28 19:36:18 +08:00 committed by GitHub
commit b670544958
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 82 additions and 38 deletions

View file

@ -1 +1 @@
-__api_version__ = "0256"
+__api_version__ = "0257"

View file

@ -365,8 +365,12 @@ def parse_args() -> argparse.Namespace:
# Inject model configuration # Inject model configuration
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest") args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
-    args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
-    args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
+    # EMBEDDING_MODEL defaults to None - each binding will use its own default model
+    # e.g., OpenAI uses "text-embedding-3-small", Jina uses "jina-embeddings-v4"
+    args.embedding_model = get_env_value("EMBEDDING_MODEL", None, special_none=True)
+    # EMBEDDING_DIM defaults to None - each binding will use its own default dimension
+    # Value is inherited from provider defaults via wrap_embedding_func_with_attrs decorator
+    args.embedding_dim = get_env_value("EMBEDDING_DIM", None, int, special_none=True)
     args.embedding_send_dim = get_env_value("EMBEDDING_SEND_DIM", False, bool)
# Inject chunk configuration # Inject chunk configuration

View file

@ -654,6 +654,17 @@ def create_app(args):
2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc 2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping) 3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
4. Returns a properly configured EmbeddingFunc instance 4. Returns a properly configured EmbeddingFunc instance
Configuration Rules:
- When EMBEDDING_MODEL is not set: Uses provider's default model and dimension
(e.g., jina-embeddings-v4 with 2048 dims, text-embedding-3-small with 1536 dims)
- When EMBEDDING_MODEL is set to a custom model: User MUST also set EMBEDDING_DIM
to match the custom model's dimension (e.g., for jina-embeddings-v3, set EMBEDDING_DIM=1024)
Note: The embedding_dim parameter is automatically injected by EmbeddingFunc wrapper
when send_dimensions=True (enabled for Jina and Gemini bindings). This wrapper calls
the underlying provider function directly (.func) to avoid double-wrapping, so we must
explicitly pass embedding_dim to the provider's underlying function.
""" """
# Step 1: Import provider function and extract default attributes # Step 1: Import provider function and extract default attributes
@ -713,6 +724,7 @@ def create_app(args):
) )
# Step 3: Create optimized embedding function (calls underlying function directly) # Step 3: Create optimized embedding function (calls underlying function directly)
# Note: When model is None, each binding will use its own default model
async def optimized_embedding_function(texts, embedding_dim=None): async def optimized_embedding_function(texts, embedding_dim=None):
try: try:
if binding == "lollms": if binding == "lollms":
@ -724,9 +736,9 @@ def create_app(args):
if isinstance(lollms_embed, EmbeddingFunc) if isinstance(lollms_embed, EmbeddingFunc)
else lollms_embed else lollms_embed
) )
-                    return await actual_func(
-                        texts, embed_model=model, host=host, api_key=api_key
-                    )
+                    # lollms embed_model is not used (server uses configured vectorizer)
+                    # Only pass base_url and api_key
+                    return await actual_func(texts, base_url=host, api_key=api_key)
elif binding == "ollama": elif binding == "ollama":
from lightrag.llm.ollama import ollama_embed from lightrag.llm.ollama import ollama_embed
@ -745,13 +757,16 @@ def create_app(args):
ollama_options = OllamaEmbeddingOptions.options_dict(args) ollama_options = OllamaEmbeddingOptions.options_dict(args)
-                    return await actual_func(
-                        texts,
-                        embed_model=model,
-                        host=host,
-                        api_key=api_key,
-                        options=ollama_options,
-                    )
+                    # Pass embed_model only if provided, let function use its default (bge-m3:latest)
+                    kwargs = {
+                        "texts": texts,
+                        "host": host,
+                        "api_key": api_key,
+                        "options": ollama_options,
+                    }
+                    if model:
+                        kwargs["embed_model"] = model
+                    return await actual_func(**kwargs)
elif binding == "azure_openai": elif binding == "azure_openai":
from lightrag.llm.azure_openai import azure_openai_embed from lightrag.llm.azure_openai import azure_openai_embed
@ -760,7 +775,11 @@ def create_app(args):
if isinstance(azure_openai_embed, EmbeddingFunc) if isinstance(azure_openai_embed, EmbeddingFunc)
else azure_openai_embed else azure_openai_embed
) )
-                    return await actual_func(texts, model=model, api_key=api_key)
+                    # Pass model only if provided, let function use its default otherwise
+                    kwargs = {"texts": texts, "api_key": api_key}
+                    if model:
+                        kwargs["model"] = model
+                    return await actual_func(**kwargs)
elif binding == "aws_bedrock": elif binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_embed from lightrag.llm.bedrock import bedrock_embed
@ -769,7 +788,11 @@ def create_app(args):
if isinstance(bedrock_embed, EmbeddingFunc) if isinstance(bedrock_embed, EmbeddingFunc)
else bedrock_embed else bedrock_embed
) )
-                    return await actual_func(texts, model=model)
+                    # Pass model only if provided, let function use its default otherwise
+                    kwargs = {"texts": texts}
+                    if model:
+                        kwargs["model"] = model
+                    return await actual_func(**kwargs)
elif binding == "jina": elif binding == "jina":
from lightrag.llm.jina import jina_embed from lightrag.llm.jina import jina_embed
@ -778,12 +801,16 @@ def create_app(args):
if isinstance(jina_embed, EmbeddingFunc) if isinstance(jina_embed, EmbeddingFunc)
else jina_embed else jina_embed
) )
-                    return await actual_func(
-                        texts,
-                        embedding_dim=embedding_dim,
-                        base_url=host,
-                        api_key=api_key,
-                    )
+                    # Pass model only if provided, let function use its default (jina-embeddings-v4)
+                    kwargs = {
+                        "texts": texts,
+                        "embedding_dim": embedding_dim,
+                        "base_url": host,
+                        "api_key": api_key,
+                    }
+                    if model:
+                        kwargs["model"] = model
+                    return await actual_func(**kwargs)
elif binding == "gemini": elif binding == "gemini":
from lightrag.llm.gemini import gemini_embed from lightrag.llm.gemini import gemini_embed
@ -801,14 +828,19 @@ def create_app(args):
gemini_options = GeminiEmbeddingOptions.options_dict(args) gemini_options = GeminiEmbeddingOptions.options_dict(args)
-                    return await actual_func(
-                        texts,
-                        model=model,
-                        base_url=host,
-                        api_key=api_key,
-                        embedding_dim=embedding_dim,
-                        task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"),
-                    )
+                    # Pass model only if provided, let function use its default (gemini-embedding-001)
+                    kwargs = {
+                        "texts": texts,
+                        "base_url": host,
+                        "api_key": api_key,
+                        "embedding_dim": embedding_dim,
+                        "task_type": gemini_options.get(
+                            "task_type", "RETRIEVAL_DOCUMENT"
+                        ),
+                    }
+                    if model:
+                        kwargs["model"] = model
+                    return await actual_func(**kwargs)
else: # openai and compatible else: # openai and compatible
from lightrag.llm.openai import openai_embed from lightrag.llm.openai import openai_embed
@ -817,13 +849,16 @@ def create_app(args):
if isinstance(openai_embed, EmbeddingFunc) if isinstance(openai_embed, EmbeddingFunc)
else openai_embed else openai_embed
) )
-                    return await actual_func(
-                        texts,
-                        model=model,
-                        base_url=host,
-                        api_key=api_key,
-                        embedding_dim=embedding_dim,
-                    )
+                    # Pass model only if provided, let function use its default (text-embedding-3-small)
+                    kwargs = {
+                        "texts": texts,
+                        "base_url": host,
+                        "api_key": api_key,
+                        "embedding_dim": embedding_dim,
+                    }
+                    if model:
+                        kwargs["model"] = model
+                    return await actual_func(**kwargs)
except ImportError as e: except ImportError as e:
raise Exception(f"Failed to import {binding} embedding: {e}") raise Exception(f"Failed to import {binding} embedding: {e}")

View file

@ -69,6 +69,7 @@ async def fetch_data(url, headers, data):
) )
async def jina_embed( async def jina_embed(
texts: list[str], texts: list[str],
model: str = "jina-embeddings-v4",
embedding_dim: int = 2048, embedding_dim: int = 2048,
late_chunking: bool = False, late_chunking: bool = False,
base_url: str = None, base_url: str = None,
@ -78,6 +79,8 @@ async def jina_embed(
Args: Args:
texts: List of texts to embed. texts: List of texts to embed.
model: The Jina embedding model to use (default: jina-embeddings-v4).
Supported models: jina-embeddings-v3, jina-embeddings-v4, etc.
embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4). embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4).
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper. **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
Do NOT manually pass this parameter when calling the function directly. Do NOT manually pass this parameter when calling the function directly.
@ -107,7 +110,7 @@ async def jina_embed(
"Authorization": f"Bearer {os.environ['JINA_API_KEY']}", "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
} }
data = { data = {
-        "model": "jina-embeddings-v4",
+        "model": model,
"task": "text-matching", "task": "text-matching",
"dimensions": embedding_dim, "dimensions": embedding_dim,
"embedding_type": "base64", "embedding_type": "base64",

View file

@ -173,7 +173,9 @@ async def ollama_model_complete(
 @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
-async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
+async def ollama_embed(
+    texts: list[str], embed_model: str = "bge-m3:latest", **kwargs
+) -> np.ndarray:
api_key = kwargs.pop("api_key", None) api_key = kwargs.pop("api_key", None)
if not api_key: if not api_key:
api_key = os.getenv("OLLAMA_API_KEY") api_key = os.getenv("OLLAMA_API_KEY")