Allow embedding models to use provider defaults when unspecified
- Set EMBEDDING_MODEL default to None - Pass model param only when provided - Let providers use their own defaults - Fix lollms embed function params - Add ollama embed_model default param
This commit is contained in:
parent
881b8d3a50
commit
4ab4a7ac94
3 changed files with 66 additions and 37 deletions
|
|
@ -365,8 +365,12 @@ def parse_args() -> argparse.Namespace:
|
||||||
|
|
||||||
# Inject model configuration
|
# Inject model configuration
|
||||||
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
|
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
|
||||||
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
|
# EMBEDDING_MODEL defaults to None - each binding will use its own default model
|
||||||
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
|
# e.g., OpenAI uses "text-embedding-3-small", Jina uses "jina-embeddings-v4"
|
||||||
|
args.embedding_model = get_env_value("EMBEDDING_MODEL", None, special_none=True)
|
||||||
|
# EMBEDDING_DIM defaults to None - each binding will use its own default dimension
|
||||||
|
# Value is inherited from provider defaults via wrap_embedding_func_with_attrs decorator
|
||||||
|
args.embedding_dim = get_env_value("EMBEDDING_DIM", None, int, special_none=True)
|
||||||
args.embedding_send_dim = get_env_value("EMBEDDING_SEND_DIM", False, bool)
|
args.embedding_send_dim = get_env_value("EMBEDDING_SEND_DIM", False, bool)
|
||||||
|
|
||||||
# Inject chunk configuration
|
# Inject chunk configuration
|
||||||
|
|
|
||||||
|
|
@ -713,6 +713,7 @@ def create_app(args):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 3: Create optimized embedding function (calls underlying function directly)
|
# Step 3: Create optimized embedding function (calls underlying function directly)
|
||||||
|
# Note: When model is None, each binding will use its own default model
|
||||||
async def optimized_embedding_function(texts, embedding_dim=None):
|
async def optimized_embedding_function(texts, embedding_dim=None):
|
||||||
try:
|
try:
|
||||||
if binding == "lollms":
|
if binding == "lollms":
|
||||||
|
|
@ -724,9 +725,9 @@ def create_app(args):
|
||||||
if isinstance(lollms_embed, EmbeddingFunc)
|
if isinstance(lollms_embed, EmbeddingFunc)
|
||||||
else lollms_embed
|
else lollms_embed
|
||||||
)
|
)
|
||||||
return await actual_func(
|
# lollms embed_model is not used (server uses configured vectorizer)
|
||||||
texts, embed_model=model, host=host, api_key=api_key
|
# Only pass base_url and api_key
|
||||||
)
|
return await actual_func(texts, base_url=host, api_key=api_key)
|
||||||
elif binding == "ollama":
|
elif binding == "ollama":
|
||||||
from lightrag.llm.ollama import ollama_embed
|
from lightrag.llm.ollama import ollama_embed
|
||||||
|
|
||||||
|
|
@ -745,13 +746,16 @@ def create_app(args):
|
||||||
|
|
||||||
ollama_options = OllamaEmbeddingOptions.options_dict(args)
|
ollama_options = OllamaEmbeddingOptions.options_dict(args)
|
||||||
|
|
||||||
return await actual_func(
|
# Pass embed_model only if provided, let function use its default (bge-m3:latest)
|
||||||
texts,
|
kwargs = {
|
||||||
embed_model=model,
|
"texts": texts,
|
||||||
host=host,
|
"host": host,
|
||||||
api_key=api_key,
|
"api_key": api_key,
|
||||||
options=ollama_options,
|
"options": ollama_options,
|
||||||
)
|
}
|
||||||
|
if model:
|
||||||
|
kwargs["embed_model"] = model
|
||||||
|
return await actual_func(**kwargs)
|
||||||
elif binding == "azure_openai":
|
elif binding == "azure_openai":
|
||||||
from lightrag.llm.azure_openai import azure_openai_embed
|
from lightrag.llm.azure_openai import azure_openai_embed
|
||||||
|
|
||||||
|
|
@ -760,7 +764,11 @@ def create_app(args):
|
||||||
if isinstance(azure_openai_embed, EmbeddingFunc)
|
if isinstance(azure_openai_embed, EmbeddingFunc)
|
||||||
else azure_openai_embed
|
else azure_openai_embed
|
||||||
)
|
)
|
||||||
return await actual_func(texts, model=model, api_key=api_key)
|
# Pass model only if provided, let function use its default otherwise
|
||||||
|
kwargs = {"texts": texts, "api_key": api_key}
|
||||||
|
if model:
|
||||||
|
kwargs["model"] = model
|
||||||
|
return await actual_func(**kwargs)
|
||||||
elif binding == "aws_bedrock":
|
elif binding == "aws_bedrock":
|
||||||
from lightrag.llm.bedrock import bedrock_embed
|
from lightrag.llm.bedrock import bedrock_embed
|
||||||
|
|
||||||
|
|
@ -769,7 +777,11 @@ def create_app(args):
|
||||||
if isinstance(bedrock_embed, EmbeddingFunc)
|
if isinstance(bedrock_embed, EmbeddingFunc)
|
||||||
else bedrock_embed
|
else bedrock_embed
|
||||||
)
|
)
|
||||||
return await actual_func(texts, model=model)
|
# Pass model only if provided, let function use its default otherwise
|
||||||
|
kwargs = {"texts": texts}
|
||||||
|
if model:
|
||||||
|
kwargs["model"] = model
|
||||||
|
return await actual_func(**kwargs)
|
||||||
elif binding == "jina":
|
elif binding == "jina":
|
||||||
from lightrag.llm.jina import jina_embed
|
from lightrag.llm.jina import jina_embed
|
||||||
|
|
||||||
|
|
@ -778,13 +790,16 @@ def create_app(args):
|
||||||
if isinstance(jina_embed, EmbeddingFunc)
|
if isinstance(jina_embed, EmbeddingFunc)
|
||||||
else jina_embed
|
else jina_embed
|
||||||
)
|
)
|
||||||
return await actual_func(
|
# Pass model only if provided, let function use its default (jina-embeddings-v4)
|
||||||
texts,
|
kwargs = {
|
||||||
model=model,
|
"texts": texts,
|
||||||
embedding_dim=embedding_dim,
|
"embedding_dim": embedding_dim,
|
||||||
base_url=host,
|
"base_url": host,
|
||||||
api_key=api_key,
|
"api_key": api_key,
|
||||||
)
|
}
|
||||||
|
if model:
|
||||||
|
kwargs["model"] = model
|
||||||
|
return await actual_func(**kwargs)
|
||||||
elif binding == "gemini":
|
elif binding == "gemini":
|
||||||
from lightrag.llm.gemini import gemini_embed
|
from lightrag.llm.gemini import gemini_embed
|
||||||
|
|
||||||
|
|
@ -802,14 +817,19 @@ def create_app(args):
|
||||||
|
|
||||||
gemini_options = GeminiEmbeddingOptions.options_dict(args)
|
gemini_options = GeminiEmbeddingOptions.options_dict(args)
|
||||||
|
|
||||||
return await actual_func(
|
# Pass model only if provided, let function use its default (gemini-embedding-001)
|
||||||
texts,
|
kwargs = {
|
||||||
model=model,
|
"texts": texts,
|
||||||
base_url=host,
|
"base_url": host,
|
||||||
api_key=api_key,
|
"api_key": api_key,
|
||||||
embedding_dim=embedding_dim,
|
"embedding_dim": embedding_dim,
|
||||||
task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"),
|
"task_type": gemini_options.get(
|
||||||
)
|
"task_type", "RETRIEVAL_DOCUMENT"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
if model:
|
||||||
|
kwargs["model"] = model
|
||||||
|
return await actual_func(**kwargs)
|
||||||
else: # openai and compatible
|
else: # openai and compatible
|
||||||
from lightrag.llm.openai import openai_embed
|
from lightrag.llm.openai import openai_embed
|
||||||
|
|
||||||
|
|
@ -818,13 +838,16 @@ def create_app(args):
|
||||||
if isinstance(openai_embed, EmbeddingFunc)
|
if isinstance(openai_embed, EmbeddingFunc)
|
||||||
else openai_embed
|
else openai_embed
|
||||||
)
|
)
|
||||||
return await actual_func(
|
# Pass model only if provided, let function use its default (text-embedding-3-small)
|
||||||
texts,
|
kwargs = {
|
||||||
model=model,
|
"texts": texts,
|
||||||
base_url=host,
|
"base_url": host,
|
||||||
api_key=api_key,
|
"api_key": api_key,
|
||||||
embedding_dim=embedding_dim,
|
"embedding_dim": embedding_dim,
|
||||||
)
|
}
|
||||||
|
if model:
|
||||||
|
kwargs["model"] = model
|
||||||
|
return await actual_func(**kwargs)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise Exception(f"Failed to import {binding} embedding: {e}")
|
raise Exception(f"Failed to import {binding} embedding: {e}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -173,7 +173,9 @@ async def ollama_model_complete(
|
||||||
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||||
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
async def ollama_embed(
|
||||||
|
texts: list[str], embed_model: str = "bge-m3:latest", **kwargs
|
||||||
|
) -> np.ndarray:
|
||||||
api_key = kwargs.pop("api_key", None)
|
api_key = kwargs.pop("api_key", None)
|
||||||
if not api_key:
|
if not api_key:
|
||||||
api_key = os.getenv("OLLAMA_API_KEY")
|
api_key = os.getenv("OLLAMA_API_KEY")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue