From 3d3f3e32d6ca7def3e824b9bc5701868be33451e Mon Sep 17 00:00:00 2001 From: Scott Davidson Date: Thu, 13 Nov 2025 16:38:30 +0000 Subject: [PATCH] Fix default model base url extraction logic Fixes an issue where default models which used the same factory but different base URLs would all be initialised with the default chat model's base URL and would ignore e.g. the embedding model's base URL config. --- api/db/services/llm_service.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 6ccbf5a94..0f0b22119 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -19,6 +19,7 @@ import re from common.token_utils import num_tokens_from_string from functools import partial from typing import Generator +from api.db import LLMType from api.db.db_models import LLM from api.db.services.common_service import CommonService from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService @@ -32,6 +33,14 @@ def get_init_tenant_llm(user_id): from common import settings tenant_llm = [] + model_configs = { + LLMType.CHAT: settings.CHAT_CFG, + LLMType.EMBEDDING: settings.EMBEDDING_CFG, + LLMType.SPEECH2TEXT: settings.ASR_CFG, + LLMType.IMAGE2TEXT: settings.IMAGE2TEXT_CFG, + LLMType.RERANK: settings.RERANK_CFG, + } + seen = set() factory_configs = [] for factory_config in [ @@ -54,8 +63,8 @@ def get_init_tenant_llm(user_id): "llm_factory": factory_config["factory"], "llm_name": llm.llm_name, "model_type": llm.model_type, - "api_key": factory_config["api_key"], - "api_base": factory_config["base_url"], + "api_key": model_configs.get(llm.model_type, {}).get("api_key", factory_config["api_key"]), + "api_base": model_configs.get(llm.model_type, {}).get("base_url", factory_config["base_url"]), "max_tokens": llm.max_tokens if llm.max_tokens else 8192, } ) @@ -80,8 +89,8 @@ class LLMBundle(LLM4Tenant): def encode(self, texts: list): if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts}) - + generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts}) + safe_texts = [] for text in texts: token_size = num_tokens_from_string(text) @@ -90,7 +99,7 @@ class LLMBundle(LLM4Tenant): safe_texts.append(text[:target_len]) else: safe_texts.append(text) - + embeddings, used_tokens = self.mdl.encode(safe_texts) llm_name = getattr(self, "llm_name", None)