From 4ab4a7ac949aa0bb96983b3e9eda46e977803d21 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Nov 2025 16:57:33 +0800 Subject: [PATCH] Allow embedding models to use provider defaults when unspecified - Set EMBEDDING_MODEL default to None - Set EMBEDDING_DIM default to None - Pass model param only when provided - Let providers use their own defaults - Fix lollms embed function params - Add ollama embed_model default param --- lightrag/api/config.py | 8 ++- lightrag/api/lightrag_server.py | 91 +++++++++++++++++++++------------ lightrag/llm/ollama.py | 4 +- 3 files changed, 66 insertions(+), 37 deletions(-) diff --git a/lightrag/api/config.py b/lightrag/api/config.py index 4f59d3c1..4d8ab1e1 100644 --- a/lightrag/api/config.py +++ b/lightrag/api/config.py @@ -365,8 +365,12 @@ def parse_args() -> argparse.Namespace: # Inject model configuration args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest") - args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest") - args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int) + # EMBEDDING_MODEL defaults to None - each binding will use its own default model + # e.g., OpenAI uses "text-embedding-3-small", Jina uses "jina-embeddings-v4" + args.embedding_model = get_env_value("EMBEDDING_MODEL", None, special_none=True) + # EMBEDDING_DIM defaults to None - each binding will use its own default dimension + # Value is inherited from provider defaults via wrap_embedding_func_with_attrs decorator + args.embedding_dim = get_env_value("EMBEDDING_DIM", None, int, special_none=True) args.embedding_send_dim = get_env_value("EMBEDDING_SEND_DIM", False, bool) # Inject chunk configuration diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 16e59d21..930358a7 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -713,6 +713,7 @@ def create_app(args): ) # Step 3: Create optimized embedding function (calls underlying function directly) + # Note: When model is None, each binding will
use its own default model async def optimized_embedding_function(texts, embedding_dim=None): try: if binding == "lollms": @@ -724,9 +725,9 @@ def create_app(args): if isinstance(lollms_embed, EmbeddingFunc) else lollms_embed ) - return await actual_func( - texts, embed_model=model, host=host, api_key=api_key - ) + # lollms embed_model is not used (server uses configured vectorizer) + # Only pass base_url and api_key + return await actual_func(texts, base_url=host, api_key=api_key) elif binding == "ollama": from lightrag.llm.ollama import ollama_embed @@ -745,13 +746,16 @@ def create_app(args): ollama_options = OllamaEmbeddingOptions.options_dict(args) - return await actual_func( - texts, - embed_model=model, - host=host, - api_key=api_key, - options=ollama_options, - ) + # Pass embed_model only if provided, let function use its default (bge-m3:latest) + kwargs = { + "texts": texts, + "host": host, + "api_key": api_key, + "options": ollama_options, + } + if model: + kwargs["embed_model"] = model + return await actual_func(**kwargs) elif binding == "azure_openai": from lightrag.llm.azure_openai import azure_openai_embed @@ -760,7 +764,11 @@ def create_app(args): if isinstance(azure_openai_embed, EmbeddingFunc) else azure_openai_embed ) - return await actual_func(texts, model=model, api_key=api_key) + # Pass model only if provided, let function use its default otherwise + kwargs = {"texts": texts, "api_key": api_key} + if model: + kwargs["model"] = model + return await actual_func(**kwargs) elif binding == "aws_bedrock": from lightrag.llm.bedrock import bedrock_embed @@ -769,7 +777,11 @@ def create_app(args): if isinstance(bedrock_embed, EmbeddingFunc) else bedrock_embed ) - return await actual_func(texts, model=model) + # Pass model only if provided, let function use its default otherwise + kwargs = {"texts": texts} + if model: + kwargs["model"] = model + return await actual_func(**kwargs) elif binding == "jina": from lightrag.llm.jina import jina_embed @@ -778,13 
+790,16 @@ def create_app(args): if isinstance(jina_embed, EmbeddingFunc) else jina_embed ) - return await actual_func( - texts, - model=model, - embedding_dim=embedding_dim, - base_url=host, - api_key=api_key, - ) + # Pass model only if provided, let function use its default (jina-embeddings-v4) + kwargs = { + "texts": texts, + "embedding_dim": embedding_dim, + "base_url": host, + "api_key": api_key, + } + if model: + kwargs["model"] = model + return await actual_func(**kwargs) elif binding == "gemini": from lightrag.llm.gemini import gemini_embed @@ -802,14 +817,19 @@ def create_app(args): gemini_options = GeminiEmbeddingOptions.options_dict(args) - return await actual_func( - texts, - model=model, - base_url=host, - api_key=api_key, - embedding_dim=embedding_dim, - task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"), - ) + # Pass model only if provided, let function use its default (gemini-embedding-001) + kwargs = { + "texts": texts, + "base_url": host, + "api_key": api_key, + "embedding_dim": embedding_dim, + "task_type": gemini_options.get( + "task_type", "RETRIEVAL_DOCUMENT" + ), + } + if model: + kwargs["model"] = model + return await actual_func(**kwargs) else: # openai and compatible from lightrag.llm.openai import openai_embed @@ -818,13 +838,16 @@ def create_app(args): if isinstance(openai_embed, EmbeddingFunc) else openai_embed ) - return await actual_func( - texts, - model=model, - base_url=host, - api_key=api_key, - embedding_dim=embedding_dim, - ) + # Pass model only if provided, let function use its default (text-embedding-3-small) + kwargs = { + "texts": texts, + "base_url": host, + "api_key": api_key, + "embedding_dim": embedding_dim, + } + if model: + kwargs["model"] = model + return await actual_func(**kwargs) except ImportError as e: raise Exception(f"Failed to import {binding} embedding: {e}") diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py index e35dc293..cd633e80 100644 --- a/lightrag/llm/ollama.py +++ 
b/lightrag/llm/ollama.py @@ -173,7 +173,9 @@ async def ollama_model_complete( @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) -async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: +async def ollama_embed( + texts: list[str], embed_model: str = "bge-m3:latest", **kwargs +) -> np.ndarray: api_key = kwargs.pop("api_key", None) if not api_key: api_key = os.getenv("OLLAMA_API_KEY")