Fix default model base url extraction logic
Fixes an issue where default models that used the same factory but different base URLs would all be initialised with the default chat model's base URL, ignoring, e.g., the embedding model's base URL config.
parent e8f1a245a6
commit 3d3f3e32d6

1 changed file with 14 additions and 5 deletions
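With this change, api_key and api_base for each default model are resolved from the per-type config first (CHAT_CFG, EMBEDDING_CFG, and so on), and the factory-level config only serves as a fallback. A minimal sketch of that lookup, with a simplified LLMType stand-in and hypothetical config values in place of the real settings objects:

# Sketch of the patched lookup, not the verbatim source.
from enum import Enum

class LLMType(Enum):  # simplified stand-in for api.db.LLMType
    CHAT = "chat"
    EMBEDDING = "embedding"

# hypothetical per-type configs mirroring settings.CHAT_CFG / settings.EMBEDDING_CFG
model_configs = {
    LLMType.CHAT: {"api_key": "k1", "base_url": "https://chat.example.com/v1"},
    LLMType.EMBEDDING: {"base_url": "https://embed.example.com/v1"},
}

# factory-level defaults; before the fix, these were used for every model type
factory_config = {
    "factory": "OpenAI-API-compatible",
    "api_key": "k1",
    "base_url": "https://chat.example.com/v1",
}

def resolve(model_type: LLMType) -> dict:
    # the per-type value wins; the factory config is only the fallback
    cfg = model_configs.get(model_type, {})
    return {
        "api_key": cfg.get("api_key", factory_config["api_key"]),
        "api_base": cfg.get("base_url", factory_config["base_url"]),
    }

# the embedding model now keeps its own base URL instead of inheriting chat's
assert resolve(LLMType.EMBEDDING)["api_base"] == "https://embed.example.com/v1"
assert resolve(LLMType.EMBEDDING)["api_key"] == "k1"  # falls back to the factory key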
@@ -19,6 +19,7 @@ import re
+from common.token_utils import num_tokens_from_string
 from functools import partial
 from typing import Generator
 from api.db import LLMType
 from api.db.db_models import LLM
 from api.db.services.common_service import CommonService
 from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
@@ -32,6 +33,14 @@ def get_init_tenant_llm(user_id):
     from common import settings
     tenant_llm = []

+    model_configs = {
+        LLMType.CHAT: settings.CHAT_CFG,
+        LLMType.EMBEDDING: settings.EMBEDDING_CFG,
+        LLMType.SPEECH2TEXT: settings.ASR_CFG,
+        LLMType.IMAGE2TEXT: settings.IMAGE2TEXT_CFG,
+        LLMType.RERANK: settings.RERANK_CFG,
+    }
+
     seen = set()
     factory_configs = []
     for factory_config in [
@@ -54,8 +63,8 @@ def get_init_tenant_llm(user_id):
                 "llm_factory": factory_config["factory"],
                 "llm_name": llm.llm_name,
                 "model_type": llm.model_type,
-                "api_key": factory_config["api_key"],
-                "api_base": factory_config["base_url"],
+                "api_key": model_configs.get(llm.model_type, {}).get("api_key", factory_config["api_key"]),
+                "api_base": model_configs.get(llm.model_type, {}).get("base_url", factory_config["base_url"]),
                 "max_tokens": llm.max_tokens if llm.max_tokens else 8192,
             }
         )
@@ -80,8 +89,8 @@ class LLMBundle(LLM4Tenant):

     def encode(self, texts: list):
         if self.langfuse:
             generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})

         safe_texts = []
         for text in texts:
             token_size = num_tokens_from_string(text)
@@ -90,7 +99,7 @@ class LLMBundle(LLM4Tenant):
                 safe_texts.append(text[:target_len])
             else:
                 safe_texts.append(text)

         embeddings, used_tokens = self.mdl.encode(safe_texts)

         llm_name = getattr(self, "llm_name", None)
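The encode() part of the change pairs with the num_tokens_from_string import added above: each input text is length-checked and truncated before it reaches the embedding model, so oversized inputs cannot fail the call. A minimal, self-contained sketch of that guard; the token counter here is a crude stand-in for common.token_utils.num_tokens_from_string, and the 0.95 safety factor is an assumption, since the line computing target_len falls outside the hunks shown:

def num_tokens_from_string(text: str) -> int:
    # stand-in estimate (roughly 4 characters per token) instead of a real tokenizer
    return max(1, len(text) // 4)

def make_safe(texts: list[str], max_tokens: int) -> list[str]:
    # truncate any text whose estimated token count exceeds the model limit
    safe_texts = []
    for text in texts:
        token_size = num_tokens_from_string(text)
        if token_size > max_tokens:
            # shrink proportionally, with a margin for estimation error (assumed ratio)
            target_len = int(len(text) * 0.95 * max_tokens / token_size)
            safe_texts.append(text[:target_len])
        else:
            safe_texts.append(text)
    return safe_texts

print([len(t) for t in make_safe(["short", "x" * 10_000], max_tokens=512)])  # [5, 1945]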