This commit is contained in:
Raphaël MANSUY 2025-12-04 19:14:28 +08:00
parent 49b0953ac1
commit 086191ae5a
2 changed files with 38 additions and 58 deletions

View file

@ -713,7 +713,6 @@ def create_app(args):
) )
# Step 3: Create optimized embedding function (calls underlying function directly) # Step 3: Create optimized embedding function (calls underlying function directly)
# Note: When model is None, each binding will use its own default model
async def optimized_embedding_function(texts, embedding_dim=None): async def optimized_embedding_function(texts, embedding_dim=None):
try: try:
if binding == "lollms": if binding == "lollms":
@ -725,9 +724,9 @@ def create_app(args):
if isinstance(lollms_embed, EmbeddingFunc) if isinstance(lollms_embed, EmbeddingFunc)
else lollms_embed else lollms_embed
) )
# lollms embed_model is not used (server uses configured vectorizer) return await actual_func(
# Only pass base_url and api_key texts, embed_model=model, host=host, api_key=api_key
return await actual_func(texts, base_url=host, api_key=api_key) )
elif binding == "ollama": elif binding == "ollama":
from lightrag.llm.ollama import ollama_embed from lightrag.llm.ollama import ollama_embed
@ -746,16 +745,13 @@ def create_app(args):
ollama_options = OllamaEmbeddingOptions.options_dict(args) ollama_options = OllamaEmbeddingOptions.options_dict(args)
# Pass embed_model only if provided, let function use its default (bge-m3:latest) return await actual_func(
kwargs = { texts,
"texts": texts, embed_model=model,
"host": host, host=host,
"api_key": api_key, api_key=api_key,
"options": ollama_options, options=ollama_options,
} )
if model:
kwargs["embed_model"] = model
return await actual_func(**kwargs)
elif binding == "azure_openai": elif binding == "azure_openai":
from lightrag.llm.azure_openai import azure_openai_embed from lightrag.llm.azure_openai import azure_openai_embed
@ -764,11 +760,7 @@ def create_app(args):
if isinstance(azure_openai_embed, EmbeddingFunc) if isinstance(azure_openai_embed, EmbeddingFunc)
else azure_openai_embed else azure_openai_embed
) )
# Pass model only if provided, let function use its default otherwise return await actual_func(texts, model=model, api_key=api_key)
kwargs = {"texts": texts, "api_key": api_key}
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
elif binding == "aws_bedrock": elif binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_embed from lightrag.llm.bedrock import bedrock_embed
@ -777,11 +769,7 @@ def create_app(args):
if isinstance(bedrock_embed, EmbeddingFunc) if isinstance(bedrock_embed, EmbeddingFunc)
else bedrock_embed else bedrock_embed
) )
# Pass model only if provided, let function use its default otherwise return await actual_func(texts, model=model)
kwargs = {"texts": texts}
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
elif binding == "jina": elif binding == "jina":
from lightrag.llm.jina import jina_embed from lightrag.llm.jina import jina_embed
@ -790,16 +778,13 @@ def create_app(args):
if isinstance(jina_embed, EmbeddingFunc) if isinstance(jina_embed, EmbeddingFunc)
else jina_embed else jina_embed
) )
# Pass model only if provided, let function use its default (jina-embeddings-v4) return await actual_func(
kwargs = { texts,
"texts": texts, model=model,
"embedding_dim": embedding_dim, embedding_dim=embedding_dim,
"base_url": host, base_url=host,
"api_key": api_key, api_key=api_key,
} )
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
elif binding == "gemini": elif binding == "gemini":
from lightrag.llm.gemini import gemini_embed from lightrag.llm.gemini import gemini_embed
@ -817,19 +802,14 @@ def create_app(args):
gemini_options = GeminiEmbeddingOptions.options_dict(args) gemini_options = GeminiEmbeddingOptions.options_dict(args)
# Pass model only if provided, let function use its default (gemini-embedding-001) return await actual_func(
kwargs = { texts,
"texts": texts, model=model,
"base_url": host, base_url=host,
"api_key": api_key, api_key=api_key,
"embedding_dim": embedding_dim, embedding_dim=embedding_dim,
"task_type": gemini_options.get( task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"),
"task_type", "RETRIEVAL_DOCUMENT" )
),
}
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
else: # openai and compatible else: # openai and compatible
from lightrag.llm.openai import openai_embed from lightrag.llm.openai import openai_embed
@ -838,16 +818,13 @@ def create_app(args):
if isinstance(openai_embed, EmbeddingFunc) if isinstance(openai_embed, EmbeddingFunc)
else openai_embed else openai_embed
) )
# Pass model only if provided, let function use its default (text-embedding-3-small) return await actual_func(
kwargs = { texts,
"texts": texts, model=model,
"base_url": host, base_url=host,
"api_key": api_key, api_key=api_key,
"embedding_dim": embedding_dim, embedding_dim=embedding_dim,
} )
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
except ImportError as e: except ImportError as e:
raise Exception(f"Failed to import {binding} embedding: {e}") raise Exception(f"Failed to import {binding} embedding: {e}")

View file

@ -69,6 +69,7 @@ async def fetch_data(url, headers, data):
) )
async def jina_embed( async def jina_embed(
texts: list[str], texts: list[str],
model: str = "jina-embeddings-v4",
embedding_dim: int = 2048, embedding_dim: int = 2048,
late_chunking: bool = False, late_chunking: bool = False,
base_url: str = None, base_url: str = None,
@ -78,6 +79,8 @@ async def jina_embed(
Args: Args:
texts: List of texts to embed. texts: List of texts to embed.
model: The Jina embedding model to use (default: jina-embeddings-v4).
Supported models: jina-embeddings-v3, jina-embeddings-v4, etc.
embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4). embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4).
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper. **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
Do NOT manually pass this parameter when calling the function directly. Do NOT manually pass this parameter when calling the function directly.
@ -107,7 +110,7 @@ async def jina_embed(
"Authorization": f"Bearer {os.environ['JINA_API_KEY']}", "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
} }
data = { data = {
"model": "jina-embeddings-v4", "model": model,
"task": "text-matching", "task": "text-matching",
"dimensions": embedding_dim, "dimensions": embedding_dim,
"embedding_type": "base64", "embedding_type": "base64",