Merge pull request #2433 from danielaskdd/fix-jina-embedding

Fix: Add configurable model support for Jina embedding
This commit is contained in:
Daniel.y 2025-11-28 19:36:18 +08:00 committed by GitHub
commit b670544958
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 82 additions and 38 deletions

View file

@ -1 +1 @@
__api_version__ = "0256"
__api_version__ = "0257"

View file

@ -365,8 +365,12 @@ def parse_args() -> argparse.Namespace:
# Inject model configuration
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
# EMBEDDING_MODEL defaults to None - each binding will use its own default model
# e.g., OpenAI uses "text-embedding-3-small", Jina uses "jina-embeddings-v4"
args.embedding_model = get_env_value("EMBEDDING_MODEL", None, special_none=True)
# EMBEDDING_DIM defaults to None - each binding will use its own default dimension
# Value is inherited from provider defaults via wrap_embedding_func_with_attrs decorator
args.embedding_dim = get_env_value("EMBEDDING_DIM", None, int, special_none=True)
args.embedding_send_dim = get_env_value("EMBEDDING_SEND_DIM", False, bool)
# Inject chunk configuration

View file

@ -654,6 +654,17 @@ def create_app(args):
2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
4. Returns a properly configured EmbeddingFunc instance
Configuration Rules:
- When EMBEDDING_MODEL is not set: Uses provider's default model and dimension
(e.g., jina-embeddings-v4 with 2048 dims, text-embedding-3-small with 1536 dims)
- When EMBEDDING_MODEL is set to a custom model: User MUST also set EMBEDDING_DIM
to match the custom model's dimension (e.g., for jina-embeddings-v3, set EMBEDDING_DIM=1024)
Note: The embedding_dim parameter is automatically injected by EmbeddingFunc wrapper
when send_dimensions=True (enabled for Jina and Gemini bindings). This wrapper calls
the underlying provider function directly (.func) to avoid double-wrapping, so we must
explicitly pass embedding_dim to the provider's underlying function.
"""
# Step 1: Import provider function and extract default attributes
@ -713,6 +724,7 @@ def create_app(args):
)
# Step 3: Create optimized embedding function (calls underlying function directly)
# Note: When model is None, each binding will use its own default model
async def optimized_embedding_function(texts, embedding_dim=None):
try:
if binding == "lollms":
@ -724,9 +736,9 @@ def create_app(args):
if isinstance(lollms_embed, EmbeddingFunc)
else lollms_embed
)
return await actual_func(
texts, embed_model=model, host=host, api_key=api_key
)
# lollms embed_model is not used (server uses configured vectorizer)
# Only pass base_url and api_key
return await actual_func(texts, base_url=host, api_key=api_key)
elif binding == "ollama":
from lightrag.llm.ollama import ollama_embed
@ -745,13 +757,16 @@ def create_app(args):
ollama_options = OllamaEmbeddingOptions.options_dict(args)
return await actual_func(
texts,
embed_model=model,
host=host,
api_key=api_key,
options=ollama_options,
)
# Pass embed_model only if provided, let function use its default (bge-m3:latest)
kwargs = {
"texts": texts,
"host": host,
"api_key": api_key,
"options": ollama_options,
}
if model:
kwargs["embed_model"] = model
return await actual_func(**kwargs)
elif binding == "azure_openai":
from lightrag.llm.azure_openai import azure_openai_embed
@ -760,7 +775,11 @@ def create_app(args):
if isinstance(azure_openai_embed, EmbeddingFunc)
else azure_openai_embed
)
return await actual_func(texts, model=model, api_key=api_key)
# Pass model only if provided, let function use its default otherwise
kwargs = {"texts": texts, "api_key": api_key}
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
elif binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_embed
@ -769,7 +788,11 @@ def create_app(args):
if isinstance(bedrock_embed, EmbeddingFunc)
else bedrock_embed
)
return await actual_func(texts, model=model)
# Pass model only if provided, let function use its default otherwise
kwargs = {"texts": texts}
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
elif binding == "jina":
from lightrag.llm.jina import jina_embed
@ -778,12 +801,16 @@ def create_app(args):
if isinstance(jina_embed, EmbeddingFunc)
else jina_embed
)
return await actual_func(
texts,
embedding_dim=embedding_dim,
base_url=host,
api_key=api_key,
)
# Pass model only if provided, let function use its default (jina-embeddings-v4)
kwargs = {
"texts": texts,
"embedding_dim": embedding_dim,
"base_url": host,
"api_key": api_key,
}
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
elif binding == "gemini":
from lightrag.llm.gemini import gemini_embed
@ -801,14 +828,19 @@ def create_app(args):
gemini_options = GeminiEmbeddingOptions.options_dict(args)
return await actual_func(
texts,
model=model,
base_url=host,
api_key=api_key,
embedding_dim=embedding_dim,
task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"),
)
# Pass model only if provided, let function use its default (gemini-embedding-001)
kwargs = {
"texts": texts,
"base_url": host,
"api_key": api_key,
"embedding_dim": embedding_dim,
"task_type": gemini_options.get(
"task_type", "RETRIEVAL_DOCUMENT"
),
}
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
else: # openai and compatible
from lightrag.llm.openai import openai_embed
@ -817,13 +849,16 @@ def create_app(args):
if isinstance(openai_embed, EmbeddingFunc)
else openai_embed
)
return await actual_func(
texts,
model=model,
base_url=host,
api_key=api_key,
embedding_dim=embedding_dim,
)
# Pass model only if provided, let function use its default (text-embedding-3-small)
kwargs = {
"texts": texts,
"base_url": host,
"api_key": api_key,
"embedding_dim": embedding_dim,
}
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
except ImportError as e:
raise Exception(f"Failed to import {binding} embedding: {e}")

View file

@ -69,6 +69,7 @@ async def fetch_data(url, headers, data):
)
async def jina_embed(
texts: list[str],
model: str = "jina-embeddings-v4",
embedding_dim: int = 2048,
late_chunking: bool = False,
base_url: str = None,
@ -78,6 +79,8 @@ async def jina_embed(
Args:
texts: List of texts to embed.
model: The Jina embedding model to use (default: jina-embeddings-v4).
Supported models: jina-embeddings-v3, jina-embeddings-v4, etc.
embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4).
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
Do NOT manually pass this parameter when calling the function directly.
@ -107,7 +110,7 @@ async def jina_embed(
"Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
}
data = {
"model": "jina-embeddings-v4",
"model": model,
"task": "text-matching",
"dimensions": embedding_dim,
"embedding_type": "base64",

View file

@ -173,7 +173,9 @@ async def ollama_model_complete(
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
async def ollama_embed(
texts: list[str], embed_model: str = "bge-m3:latest", **kwargs
) -> np.ndarray:
api_key = kwargs.pop("api_key", None)
if not api_key:
api_key = os.getenv("OLLAMA_API_KEY")