cherry-pick 56e0365c
This commit is contained in:
parent
49b0953ac1
commit
086191ae5a
2 changed files with 38 additions and 58 deletions
|
|
@ -713,7 +713,6 @@ def create_app(args):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 3: Create optimized embedding function (calls underlying function directly)
|
# Step 3: Create optimized embedding function (calls underlying function directly)
|
||||||
# Note: When model is None, each binding will use its own default model
|
|
||||||
async def optimized_embedding_function(texts, embedding_dim=None):
|
async def optimized_embedding_function(texts, embedding_dim=None):
|
||||||
try:
|
try:
|
||||||
if binding == "lollms":
|
if binding == "lollms":
|
||||||
|
|
@ -725,9 +724,9 @@ def create_app(args):
|
||||||
if isinstance(lollms_embed, EmbeddingFunc)
|
if isinstance(lollms_embed, EmbeddingFunc)
|
||||||
else lollms_embed
|
else lollms_embed
|
||||||
)
|
)
|
||||||
# lollms embed_model is not used (server uses configured vectorizer)
|
return await actual_func(
|
||||||
# Only pass base_url and api_key
|
texts, embed_model=model, host=host, api_key=api_key
|
||||||
return await actual_func(texts, base_url=host, api_key=api_key)
|
)
|
||||||
elif binding == "ollama":
|
elif binding == "ollama":
|
||||||
from lightrag.llm.ollama import ollama_embed
|
from lightrag.llm.ollama import ollama_embed
|
||||||
|
|
||||||
|
|
@ -746,16 +745,13 @@ def create_app(args):
|
||||||
|
|
||||||
ollama_options = OllamaEmbeddingOptions.options_dict(args)
|
ollama_options = OllamaEmbeddingOptions.options_dict(args)
|
||||||
|
|
||||||
# Pass embed_model only if provided, let function use its default (bge-m3:latest)
|
return await actual_func(
|
||||||
kwargs = {
|
texts,
|
||||||
"texts": texts,
|
embed_model=model,
|
||||||
"host": host,
|
host=host,
|
||||||
"api_key": api_key,
|
api_key=api_key,
|
||||||
"options": ollama_options,
|
options=ollama_options,
|
||||||
}
|
)
|
||||||
if model:
|
|
||||||
kwargs["embed_model"] = model
|
|
||||||
return await actual_func(**kwargs)
|
|
||||||
elif binding == "azure_openai":
|
elif binding == "azure_openai":
|
||||||
from lightrag.llm.azure_openai import azure_openai_embed
|
from lightrag.llm.azure_openai import azure_openai_embed
|
||||||
|
|
||||||
|
|
@ -764,11 +760,7 @@ def create_app(args):
|
||||||
if isinstance(azure_openai_embed, EmbeddingFunc)
|
if isinstance(azure_openai_embed, EmbeddingFunc)
|
||||||
else azure_openai_embed
|
else azure_openai_embed
|
||||||
)
|
)
|
||||||
# Pass model only if provided, let function use its default otherwise
|
return await actual_func(texts, model=model, api_key=api_key)
|
||||||
kwargs = {"texts": texts, "api_key": api_key}
|
|
||||||
if model:
|
|
||||||
kwargs["model"] = model
|
|
||||||
return await actual_func(**kwargs)
|
|
||||||
elif binding == "aws_bedrock":
|
elif binding == "aws_bedrock":
|
||||||
from lightrag.llm.bedrock import bedrock_embed
|
from lightrag.llm.bedrock import bedrock_embed
|
||||||
|
|
||||||
|
|
@ -777,11 +769,7 @@ def create_app(args):
|
||||||
if isinstance(bedrock_embed, EmbeddingFunc)
|
if isinstance(bedrock_embed, EmbeddingFunc)
|
||||||
else bedrock_embed
|
else bedrock_embed
|
||||||
)
|
)
|
||||||
# Pass model only if provided, let function use its default otherwise
|
return await actual_func(texts, model=model)
|
||||||
kwargs = {"texts": texts}
|
|
||||||
if model:
|
|
||||||
kwargs["model"] = model
|
|
||||||
return await actual_func(**kwargs)
|
|
||||||
elif binding == "jina":
|
elif binding == "jina":
|
||||||
from lightrag.llm.jina import jina_embed
|
from lightrag.llm.jina import jina_embed
|
||||||
|
|
||||||
|
|
@ -790,16 +778,13 @@ def create_app(args):
|
||||||
if isinstance(jina_embed, EmbeddingFunc)
|
if isinstance(jina_embed, EmbeddingFunc)
|
||||||
else jina_embed
|
else jina_embed
|
||||||
)
|
)
|
||||||
# Pass model only if provided, let function use its default (jina-embeddings-v4)
|
return await actual_func(
|
||||||
kwargs = {
|
texts,
|
||||||
"texts": texts,
|
model=model,
|
||||||
"embedding_dim": embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
"base_url": host,
|
base_url=host,
|
||||||
"api_key": api_key,
|
api_key=api_key,
|
||||||
}
|
)
|
||||||
if model:
|
|
||||||
kwargs["model"] = model
|
|
||||||
return await actual_func(**kwargs)
|
|
||||||
elif binding == "gemini":
|
elif binding == "gemini":
|
||||||
from lightrag.llm.gemini import gemini_embed
|
from lightrag.llm.gemini import gemini_embed
|
||||||
|
|
||||||
|
|
@ -817,19 +802,14 @@ def create_app(args):
|
||||||
|
|
||||||
gemini_options = GeminiEmbeddingOptions.options_dict(args)
|
gemini_options = GeminiEmbeddingOptions.options_dict(args)
|
||||||
|
|
||||||
# Pass model only if provided, let function use its default (gemini-embedding-001)
|
return await actual_func(
|
||||||
kwargs = {
|
texts,
|
||||||
"texts": texts,
|
model=model,
|
||||||
"base_url": host,
|
base_url=host,
|
||||||
"api_key": api_key,
|
api_key=api_key,
|
||||||
"embedding_dim": embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
"task_type": gemini_options.get(
|
task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"),
|
||||||
"task_type", "RETRIEVAL_DOCUMENT"
|
)
|
||||||
),
|
|
||||||
}
|
|
||||||
if model:
|
|
||||||
kwargs["model"] = model
|
|
||||||
return await actual_func(**kwargs)
|
|
||||||
else: # openai and compatible
|
else: # openai and compatible
|
||||||
from lightrag.llm.openai import openai_embed
|
from lightrag.llm.openai import openai_embed
|
||||||
|
|
||||||
|
|
@ -838,16 +818,13 @@ def create_app(args):
|
||||||
if isinstance(openai_embed, EmbeddingFunc)
|
if isinstance(openai_embed, EmbeddingFunc)
|
||||||
else openai_embed
|
else openai_embed
|
||||||
)
|
)
|
||||||
# Pass model only if provided, let function use its default (text-embedding-3-small)
|
return await actual_func(
|
||||||
kwargs = {
|
texts,
|
||||||
"texts": texts,
|
model=model,
|
||||||
"base_url": host,
|
base_url=host,
|
||||||
"api_key": api_key,
|
api_key=api_key,
|
||||||
"embedding_dim": embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
}
|
)
|
||||||
if model:
|
|
||||||
kwargs["model"] = model
|
|
||||||
return await actual_func(**kwargs)
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise Exception(f"Failed to import {binding} embedding: {e}")
|
raise Exception(f"Failed to import {binding} embedding: {e}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -69,6 +69,7 @@ async def fetch_data(url, headers, data):
|
||||||
)
|
)
|
||||||
async def jina_embed(
|
async def jina_embed(
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
|
model: str = "jina-embeddings-v4",
|
||||||
embedding_dim: int = 2048,
|
embedding_dim: int = 2048,
|
||||||
late_chunking: bool = False,
|
late_chunking: bool = False,
|
||||||
base_url: str = None,
|
base_url: str = None,
|
||||||
|
|
@ -78,6 +79,8 @@ async def jina_embed(
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: List of texts to embed.
|
texts: List of texts to embed.
|
||||||
|
model: The Jina embedding model to use (default: jina-embeddings-v4).
|
||||||
|
Supported models: jina-embeddings-v3, jina-embeddings-v4, etc.
|
||||||
embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4).
|
embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4).
|
||||||
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
|
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
|
||||||
Do NOT manually pass this parameter when calling the function directly.
|
Do NOT manually pass this parameter when calling the function directly.
|
||||||
|
|
@ -107,7 +110,7 @@ async def jina_embed(
|
||||||
"Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
|
"Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
|
||||||
}
|
}
|
||||||
data = {
|
data = {
|
||||||
"model": "jina-embeddings-v4",
|
"model": model,
|
||||||
"task": "text-matching",
|
"task": "text-matching",
|
||||||
"dimensions": embedding_dim,
|
"dimensions": embedding_dim,
|
||||||
"embedding_type": "base64",
|
"embedding_type": "base64",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue