Fix typo in chat_model.py

This commit is contained in:
yongtenglei 2025-12-01 11:10:08 +08:00
parent 1ba19972af
commit 983fa6320a

View file

@ -33,9 +33,9 @@ from openai.lib.azure import AzureOpenAI
from strenum import StrEnum from strenum import StrEnum
from zhipuai import ZhipuAI from zhipuai import ZhipuAI
from common.token_utils import num_tokens_from_string, total_token_count_from_response
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
from rag.nlp import is_chinese, is_english from rag.nlp import is_chinese, is_english
from common.token_utils import num_tokens_from_string, total_token_count_from_response
# Error message constants # Error message constants
@ -66,7 +66,7 @@ LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to
class Base(ABC): class Base(ABC):
def __init__(self, key, model_name, base_url, **kwargs): def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout) self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
self.model_name = model_name self.model_name = model_name
# Configure retry parameters # Configure retry parameters
@ -127,7 +127,7 @@ class Base(ABC):
"tool_choice", "tool_choice",
"logprobs", "logprobs",
"top_logprobs", "top_logprobs",
"extra_headers" "extra_headers",
} }
gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf} gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf}
@ -1213,7 +1213,7 @@ class GoogleChat(Base):
# Build GenerateContentConfig # Build GenerateContentConfig
try: try:
from google.genai.types import GenerateContentConfig, ThinkingConfig, Content, Part from google.genai.types import Content, GenerateContentConfig, Part, ThinkingConfig
except ImportError as e: except ImportError as e:
logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0") logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0")
raise raise
@ -1242,14 +1242,14 @@ class GoogleChat(Base):
role = "model" if item["role"] == "assistant" else item["role"] role = "model" if item["role"] == "assistant" else item["role"]
content = Content( content = Content(
role=role, role=role,
parts=[Part(text=item["content"])] parts=[Part(text=item["content"])],
) )
contents.append(content) contents.append(content)
response = self.client.models.generate_content( response = self.client.models.generate_content(
model=self.model_name, model=self.model_name,
contents=contents, contents=contents,
config=config config=config,
) )
ans = response.text ans = response.text
@ -1299,7 +1299,7 @@ class GoogleChat(Base):
# Build GenerateContentConfig # Build GenerateContentConfig
try: try:
from google.genai.types import GenerateContentConfig, ThinkingConfig, Content, Part from google.genai.types import Content, GenerateContentConfig, Part, ThinkingConfig
except ImportError as e: except ImportError as e:
logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0") logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0")
raise raise
@ -1326,7 +1326,7 @@ class GoogleChat(Base):
role = "model" if item["role"] == "assistant" else item["role"] role = "model" if item["role"] == "assistant" else item["role"]
content = Content( content = Content(
role=role, role=role,
parts=[Part(text=item["content"])] parts=[Part(text=item["content"])],
) )
contents.append(content) contents.append(content)
@ -1334,7 +1334,7 @@ class GoogleChat(Base):
for chunk in self.client.models.generate_content_stream( for chunk in self.client.models.generate_content_stream(
model=self.model_name, model=self.model_name,
contents=contents, contents=contents,
config=config config=config,
): ):
text = chunk.text text = chunk.text
ans = text ans = text
@ -1406,7 +1406,7 @@ class LiteLLMBase(ABC):
] ]
def __init__(self, key, model_name, base_url=None, **kwargs): def __init__(self, key, model_name, base_url=None, **kwargs):
self.timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) self.timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
self.provider = kwargs.get("provider", "") self.provider = kwargs.get("provider", "")
self.prefix = LITELLM_PROVIDER_PREFIX.get(self.provider, "") self.prefix = LITELLM_PROVIDER_PREFIX.get(self.provider, "")
self.model_name = f"{self.prefix}{model_name}" self.model_name = f"{self.prefix}{model_name}"
@ -1625,6 +1625,7 @@ class LiteLLMBase(ABC):
if self.provider == SupportedLiteLLMProvider.OpenRouter: if self.provider == SupportedLiteLLMProvider.OpenRouter:
if self.provider_order: if self.provider_order:
def _to_order_list(x): def _to_order_list(x):
if x is None: if x is None:
return [] return []
@ -1633,6 +1634,7 @@ class LiteLLMBase(ABC):
if isinstance(x, (list, tuple)): if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()] return [str(s).strip() for s in x if str(s).strip()]
return [] return []
extra_body = {} extra_body = {}
provider_cfg = {} provider_cfg = {}
provider_order = _to_order_list(self.provider_order) provider_order = _to_order_list(self.provider_order)