diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index ac79f04eb..c045d2611 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -636,21 +636,49 @@ class QWenChat(Base):
         return "".join(result_list[:-1]), result_list[-1]
 
     def _chat(self, history, gen_conf):
-        if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
-            return super()._chat(history, gen_conf)
-        response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
-        ans = ""
         tk_count = 0
-        if response.status_code == HTTPStatus.OK:
-            ans += response.output.choices[0]["message"]["content"]
-            tk_count += self.total_token_count(response)
-            if response.output.choices[0].get("finish_reason", "") == "length":
-                if is_chinese([ans]):
-                    ans += LENGTH_NOTIFICATION_CN
+        if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]:
+            try:
+                response = super()._chat(history, gen_conf)
+                return response
+            except Exception as e:
+                error_msg = str(e).lower()
+                if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg:
+                    return self._simulate_one_shot_from_stream(history, gen_conf)
                 else:
-                    ans += LENGTH_NOTIFICATION_EN
-            return ans, tk_count
-        return "**ERROR**: " + response.message, tk_count
+                    return "**ERROR**: " + str(e), tk_count
+
+        try:
+            ans = ""
+            response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
+            if response.status_code == HTTPStatus.OK:
+                ans += response.output.choices[0]["message"]["content"]
+                tk_count += self.total_token_count(response)
+                if response.output.choices[0].get("finish_reason", "") == "length":
+                    if is_chinese([ans]):
+                        ans += LENGTH_NOTIFICATION_CN
+                    else:
+                        ans += LENGTH_NOTIFICATION_EN
+                return ans, tk_count
+            return "**ERROR**: " + response.message, tk_count
+        except Exception as e:
+            error_msg = str(e).lower()
+            if "invalid_parameter_error" in error_msg and "only support stream mode" in error_msg:
+                return self._simulate_one_shot_from_stream(history, gen_conf)
+            else:
+                return "**ERROR**: " + str(e), tk_count
+
+    def _simulate_one_shot_from_stream(self, history, gen_conf):
+        """
+        Handles models that require streaming output but need one-shot response.
+ """ + g = self._chat_streamly("", history, gen_conf, incremental_output=True) + result_list = list(g) + error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0] + if len(error_msg_list) > 0: + return "**ERROR**: " + "".join(error_msg_list), 0 + else: + return "".join(result_list[:-1]), result_list[-1] def _wrap_toolcall_message(self, old_message, message): if not old_message: @@ -943,10 +971,10 @@ class LocalAIChat(Base): class LocalLLM(Base): - def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) from jina import Client + self.client = Client(port=12345, protocol="grpc", asyncio=True) def _prepare_prompt(self, system, history, gen_conf): @@ -1003,13 +1031,7 @@ class VolcEngineChat(Base): class MiniMaxChat(Base): - def __init__( - self, - key, - model_name, - base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", - **kwargs - ): + def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) if not base_url: @@ -1241,6 +1263,7 @@ class GeminiChat(Base): def _chat(self, history, gen_conf): from google.generativeai.types import content_types + system = history[0]["content"] if history and history[0]["role"] == "system" else "" hist = [] for item in history: