diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index dda22c28d..ac79f04eb 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -21,6 +21,8 @@ import random import re import time from abc import ABC +from copy import deepcopy +from http import HTTPStatus from typing import Any, Protocol from urllib.parse import urljoin @@ -58,13 +60,13 @@ class ToolCallSession(Protocol): class Base(ABC): - def __init__(self, key, model_name, base_url): + def __init__(self, key, model_name, base_url, **kwargs): timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout) self.model_name = model_name # Configure retry parameters - self.max_retries = int(os.environ.get("LLM_MAX_RETRIES", 5)) - self.base_delay = float(os.environ.get("LLM_BASE_DELAY", 2.0)) + self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5))) + self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0))) self.is_tools = False def _get_delay(self, attempt): @@ -96,6 +98,24 @@ class Base(ABC): else: return ERROR_GENERIC + def _clean_conf(self, gen_conf): + if "max_tokens" in gen_conf: + del gen_conf["max_tokens"] + return gen_conf + + def _chat(self, history, gen_conf): + response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf) + + if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]): + return "", 0 + ans = response.choices[0].message.content.strip() + if response.choices[0].finish_reason == "length": + if is_chinese(ans): + ans += LENGTH_NOTIFICATION_CN + else: + ans += LENGTH_NOTIFICATION_EN + return ans, self.total_token_count(response) + def bind_tools(self, toolcall_session, tools): if not (toolcall_session and tools): return @@ -141,12 +161,6 @@ class Base(ABC): args = json.loads(tool_call.function.arguments) tool_response = self.toolcall_session.tool_call(name, args) - # if tool_response.choices[0].finish_reason == "length": - # if is_chinese(ans): - # ans += LENGTH_NOTIFICATION_CN - # else: - # ans += LENGTH_NOTIFICATION_EN - # return ans, tk_count + self.total_token_count(tool_response) history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)}) final_response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf) @@ -184,23 +198,12 @@ class Base(ABC): def chat(self, system, history, gen_conf): if system: history.insert(0, {"role": "system", "content": system}) - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] + gen_conf = self._clean_conf(gen_conf) # Implement exponential backoff retry strategy for attempt in range(self.max_retries): try: - response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf) - - if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]): - return "", 0 - ans = response.choices[0].message.content.strip() - if response.choices[0].finish_reason == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, self.total_token_count(response) + return self._chat(history, gen_conf) except Exception as e: logging.exception("chat_model.Base.chat got exception") # Classify the error @@ -309,12 +312,6 @@ class Base(ABC): ], } ) - # if tool_response.choices[0].finish_reason == "length": - # if is_chinese(ans): - # ans += LENGTH_NOTIFICATION_CN 
- # else: - # ans += LENGTH_NOTIFICATION_EN - # return ans, total_tokens + self.total_token_count(tool_response) history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)}) final_tool_calls = {} response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf) @@ -429,64 +426,64 @@ class Base(ABC): class GptTurbo(Base): - def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"): + def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs): if not base_url: base_url = "https://api.openai.com/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class MoonshotChat(Base): - def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"): + def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1", **kwargs): if not base_url: base_url = "https://api.moonshot.cn/v1" super().__init__(key, model_name, base_url) class XinferenceChat(Base): - def __init__(self, key=None, model_name="", base_url=""): + def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") base_url = urljoin(base_url, "v1") - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class HuggingFaceChat(Base): - def __init__(self, key=None, model_name="", base_url=""): + def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") base_url = urljoin(base_url, "v1") - super().__init__(key, model_name.split("___")[0], base_url) + super().__init__(key, model_name.split("___")[0], base_url, **kwargs) class ModelScopeChat(Base): - def __init__(self, key=None, model_name="", base_url=""): + def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") base_url = urljoin(base_url, "v1") - super().__init__(key, model_name.split("___")[0], base_url) + super().__init__(key, model_name.split("___")[0], base_url, **kwargs) class DeepSeekChat(Base): - def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"): + def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1", **kwargs): if not base_url: base_url = "https://api.deepseek.com/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class AzureChat(Base): def __init__(self, key, model_name, **kwargs): api_key = json.loads(key).get("api_key", "") api_version = json.loads(key).get("api_version", "2024-02-01") - super().__init__(key, model_name, kwargs["base_url"]) + super().__init__(key, model_name, kwargs["base_url"], **kwargs) self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version) self.model_name = model_name class BaiChuanChat(Base): - def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"): + def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs): if not base_url: base_url = "https://api.baichuan-ai.com/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) @staticmethod def _format_params(params): @@ -495,27 +492,26 @@ class BaiChuanChat(Base): 
"top_p": params.get("top_p", 0.85), } - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] - try: - response = self.client.chat.completions.create( - model=self.model_name, - messages=history, - extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]}, - **self._format_params(gen_conf), - ) - ans = response.choices[0].message.content.strip() - if response.choices[0].finish_reason == "length": - if is_chinese([ans]): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, self.total_token_count(response) - except openai.APIError as e: - return "**ERROR**: " + str(e), 0 + def _clean_conf(self, gen_conf): + return { + "temperature": gen_conf.get("temperature", 0.3), + "top_p": gen_conf.get("top_p", 0.85), + } + + def _chat(self, history, gen_conf): + response = self.client.chat.completions.create( + model=self.model_name, + messages=history, + extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]}, + **gen_conf, + ) + ans = response.choices[0].message.content.strip() + if response.choices[0].finish_reason == "length": + if is_chinese([ans]): + ans += LENGTH_NOTIFICATION_CN + else: + ans += LENGTH_NOTIFICATION_EN + return ans, self.total_token_count(response) def chat_streamly(self, system, history, gen_conf): if system: @@ -557,15 +553,15 @@ class BaiChuanChat(Base): class QWenChat(Base): - def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) import dashscope dashscope.api_key = key self.model_name = model_name if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]: - super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1") + super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs) def chat_with_tools(self, system: str, history: list, gen_conf: dict) -> tuple[str, int]: if "max_tokens" in gen_conf: @@ -639,41 +635,22 @@ class QWenChat(Base): else: return "".join(result_list[:-1]), result_list[-1] - def chat(self, system, history, gen_conf): - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] + def _chat(self, history, gen_conf): if self.is_reasoning_model(self.model_name) or self.model_name in ["qwen-vl-plus", "qwen-vl-plus-latest", "qwen-vl-max", "qwen-vl-max-latest"]: - return super().chat(system, history, gen_conf) - - stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true" - if not stream_flag: - from http import HTTPStatus - - if system: - history.insert(0, {"role": "system", "content": system}) - - response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf) - ans = "" - tk_count = 0 - if response.status_code == HTTPStatus.OK: - ans += response.output.choices[0]["message"]["content"] - tk_count += self.total_token_count(response) - if response.output.choices[0].get("finish_reason", "") == "length": - if is_chinese([ans]): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, tk_count - - return "**ERROR**: " + response.message, 
tk_count - else: - g = self._chat_streamly(system, history, gen_conf, incremental_output=True) - result_list = list(g) - error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0] - if len(error_msg_list) > 0: - return "**ERROR**: " + "".join(error_msg_list), 0 - else: - return "".join(result_list[:-1]), result_list[-1] + return super()._chat(history, gen_conf) + response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf) + ans = "" + tk_count = 0 + if response.status_code == HTTPStatus.OK: + ans += response.output.choices[0]["message"]["content"] + tk_count += self.total_token_count(response) + if response.output.choices[0].get("finish_reason", "") == "length": + if is_chinese([ans]): + ans += LENGTH_NOTIFICATION_CN + else: + ans += LENGTH_NOTIFICATION_EN + return ans, tk_count + return "**ERROR**: " + response.message, tk_count def _wrap_toolcall_message(self, old_message, message): if not old_message: @@ -826,32 +803,20 @@ class QWenChat(Base): class ZhipuChat(Base): - def __init__(self, key, model_name="glm-3-turbo", **kwargs): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) self.client = ZhipuAI(api_key=key) self.model_name = model_name - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) + def _clean_conf(self, gen_conf): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] - try: - if "presence_penalty" in gen_conf: - del gen_conf["presence_penalty"] - if "frequency_penalty" in gen_conf: - del gen_conf["frequency_penalty"] - response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf) - ans = response.choices[0].message.content.strip() - if response.choices[0].finish_reason == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, self.total_token_count(response) - except Exception as e: - return "**ERROR**: " + str(e), 0 + if "presence_penalty" in gen_conf: + del gen_conf["presence_penalty"] + if "frequency_penalty" in gen_conf: + del gen_conf["frequency_penalty"] + return gen_conf def chat_with_tools(self, system: str, history: list, gen_conf: dict): if "presence_penalty" in gen_conf: @@ -903,39 +868,31 @@ class ZhipuChat(Base): class OllamaChat(Base): - def __init__(self, key, model_name, **kwargs): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"}) self.model_name = model_name - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) + def _clean_conf(self, gen_conf): + options = {} if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] - try: - # Calculate context size - ctx_size = self._calculate_dynamic_ctx(history) + options["num_predict"] = gen_conf["max_tokens"] + for k in ["temperature", "top_p", "presence_penalty", "frequency_penalty"]: + if k not in gen_conf: + continue + options[k] = gen_conf[k] + return options - options = {"num_ctx": ctx_size} - if "temperature" in gen_conf: - options["temperature"] = gen_conf["temperature"] - if "max_tokens" in gen_conf: - 
options["num_predict"] = gen_conf["max_tokens"] - if "top_p" in gen_conf: - options["top_p"] = gen_conf["top_p"] - if "presence_penalty" in gen_conf: - options["presence_penalty"] = gen_conf["presence_penalty"] - if "frequency_penalty" in gen_conf: - options["frequency_penalty"] = gen_conf["frequency_penalty"] + def _chat(self, history, gen_conf): + # Calculate context size + ctx_size = self._calculate_dynamic_ctx(history) - response = self.client.chat(model=self.model_name, messages=history, options=options) - ans = response["message"]["content"].strip() - token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0) - return ans, token_count - except Exception as e: - return "**ERROR**: " + str(e), 0 + gen_conf["num_ctx"] = ctx_size + response = self.client.chat(model=self.model_name, messages=history, options=gen_conf) + ans = response["message"]["content"].strip() + token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0) + return ans, token_count def chat_streamly(self, system, history, gen_conf): if system: @@ -975,8 +932,8 @@ class OllamaChat(Base): class LocalAIChat(Base): - def __init__(self, key, model_name, base_url): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) if not base_url: raise ValueError("Local llm url cannot be None") @@ -986,36 +943,10 @@ class LocalAIChat(Base): class LocalLLM(Base): - class RPCProxy: - def __init__(self, host, port): - self.host = host - self.port = int(port) - self.__conn() - - def __conn(self): - from multiprocessing.connection import Client - - self._connection = Client((self.host, self.port), authkey=b"infiniflow-token4kevinhu") - - def __getattr__(self, name): - import pickle - - def do_rpc(*args, **kwargs): - for _ in range(3): - try: - self._connection.send(pickle.dumps((name, args, kwargs))) - return pickle.loads(self._connection.recv()) - except Exception: - self.__conn() - raise Exception("RPC connection lost!") - - return do_rpc - - def __init__(self, key, model_name): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) from jina import Client - self.client = Client(port=12345, protocol="grpc", asyncio=True) def _prepare_prompt(self, system, history, gen_conf): @@ -1059,9 +990,7 @@ class LocalLLM(Base): class VolcEngineChat(Base): - def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"): - super().__init__(key, model_name, base_url=None) - + def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs): """ Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special, Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use @@ -1070,7 +999,7 @@ class VolcEngineChat(Base): base_url = base_url if base_url else "https://ark.cn-beijing.volces.com/api/v3" ark_api_key = json.loads(key).get("ark_api_key", "") model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "") - super().__init__(ark_api_key, model_name, base_url) + super().__init__(ark_api_key, model_name, base_url, **kwargs) class MiniMaxChat(Base): @@ -1079,8 +1008,9 @@ class MiniMaxChat(Base): key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", + **kwargs ): - 
super().__init__(key, model_name, base_url=None) + super().__init__(key, model_name, base_url=base_url, **kwargs) if not base_url: base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2" @@ -1088,29 +1018,27 @@ class MiniMaxChat(Base): self.model_name = model_name self.api_key = key - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) + def _clean_conf(self, gen_conf): for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_tokens"]: del gen_conf[k] + return gen_conf + + def _chat(self, history, gen_conf): headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } payload = json.dumps({"model": self.model_name, "messages": history, **gen_conf}) - try: - response = requests.request("POST", url=self.base_url, headers=headers, data=payload) - response = response.json() - ans = response["choices"][0]["message"]["content"].strip() - if response["choices"][0]["finish_reason"] == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, self.total_token_count(response) - except Exception as e: - return "**ERROR**: " + str(e), 0 + response = requests.request("POST", url=self.base_url, headers=headers, data=payload) + response = response.json() + ans = response["choices"][0]["message"]["content"].strip() + if response["choices"][0]["finish_reason"] == "length": + if is_chinese(ans): + ans += LENGTH_NOTIFICATION_CN + else: + ans += LENGTH_NOTIFICATION_EN + return ans, self.total_token_count(response) def chat_streamly(self, system, history, gen_conf): if system: @@ -1159,31 +1087,29 @@ class MiniMaxChat(Base): class MistralChat(Base): - def __init__(self, key, model_name, base_url=None): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) from mistralai.client import MistralClient self.client = MistralClient(api_key=key) self.model_name = model_name - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) + def _clean_conf(self, gen_conf): for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_tokens"]: del gen_conf[k] - try: - response = self.client.chat(model=self.model_name, messages=history, **gen_conf) - ans = response.choices[0].message.content - if response.choices[0].finish_reason == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, self.total_token_count(response) - except openai.APIError as e: - return "**ERROR**: " + str(e), 0 + return gen_conf + + def _chat(self, history, gen_conf): + response = self.client.chat(model=self.model_name, messages=history, **gen_conf) + ans = response.choices[0].message.content + if response.choices[0].finish_reason == "length": + if is_chinese(ans): + ans += LENGTH_NOTIFICATION_CN + else: + ans += LENGTH_NOTIFICATION_EN + return ans, self.total_token_count(response) def chat_streamly(self, system, history, gen_conf): if system: @@ -1214,8 +1140,8 @@ class MistralChat(Base): class BedrockChat(Base): - def __init__(self, key, model_name, **kwargs): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) import boto3 @@ -1230,31 +1156,32 @@ class BedrockChat(Base): else: self.client = 
boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk) - def chat(self, system, history, gen_conf): - from botocore.exceptions import ClientError - + def _clean_conf(self, gen_conf): for k in list(gen_conf.keys()): if k not in ["temperature"]: del gen_conf[k] + return gen_conf + + def _chat(self, history, gen_conf): + system = history[0]["content"] if history and history[0]["role"] == "system" else "" + hist = [] for item in history: - if not isinstance(item["content"], list) and not isinstance(item["content"], tuple): - item["content"] = [{"text": item["content"]}] + if item["role"] == "system": + continue + hist.append(deepcopy(item)) + if not isinstance(hist[-1]["content"], list) and not isinstance(hist[-1]["content"], tuple): + hist[-1]["content"] = [{"text": hist[-1]["content"]}] + # Send the message to the model, using a basic inference configuration. + response = self.client.converse( + modelId=self.model_name, + messages=hist, + inferenceConfig=gen_conf, + system=[{"text": (system if system else "Answer the user's message.")}], + ) - try: - # Send the message to the model, using a basic inference configuration. - response = self.client.converse( - modelId=self.model_name, - messages=history, - inferenceConfig=gen_conf, - system=[{"text": (system if system else "Answer the user's message.")}], - ) - - # Extract and print the response text. - ans = response["output"]["message"]["content"][0]["text"] - return ans, num_tokens_from_string(ans) - - except (ClientError, Exception) as e: - return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0 + # Extract and print the response text. + ans = response["output"]["message"]["content"][0]["text"] + return ans, num_tokens_from_string(ans) def chat_streamly(self, system, history, gen_conf): from botocore.exceptions import ClientError @@ -1295,8 +1222,8 @@ class BedrockChat(Base): class GeminiChat(Base): - def __init__(self, key, model_name, base_url=None): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) from google.generativeai import GenerativeModel, client @@ -1306,15 +1233,21 @@ class GeminiChat(Base): self.model = GenerativeModel(model_name=self.model_name) self.model._client = _client - def chat(self, system, history, gen_conf): - from google.generativeai.types import content_types - - if system: - self.model._system_instruction = content_types.to_content(system) + def _clean_conf(self, gen_conf): for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_tokens"]: del gen_conf[k] + return gen_conf + + def _chat(self, history, gen_conf): + from google.generativeai.types import content_types + system = history[0]["content"] if history and history[0]["role"] == "system" else "" + hist = [] for item in history: + if item["role"] == "system": + continue + hist.append(deepcopy(item)) + item = hist[-1] if "role" in item and item["role"] == "assistant": item["role"] = "model" if "role" in item and item["role"] == "system": @@ -1322,12 +1255,11 @@ class GeminiChat(Base): if "content" in item: item["parts"] = item.pop("content") - try: - response = self.model.generate_content(history, generation_config=gen_conf) - ans = response.text - return ans, response.usage_metadata.total_token_count - except Exception as e: - return "**ERROR**: " + str(e), 0 + if system: + self.model._system_instruction = 
content_types.to_content(system) + response = self.model.generate_content(hist, generation_config=gen_conf) + ans = response.text + return ans, response.usage_metadata.total_token_count def chat_streamly(self, system, history, gen_conf): from google.generativeai.types import content_types @@ -1357,32 +1289,19 @@ class GeminiChat(Base): class GroqChat(Base): - def __init__(self, key, model_name, base_url=""): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) from groq import Groq self.client = Groq(api_key=key) self.model_name = model_name - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) + def _clean_conf(self, gen_conf): for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_tokens"]: del gen_conf[k] - ans = "" - try: - response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf) - ans = response.choices[0].message.content - if response.choices[0].finish_reason == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, self.total_token_count(response) - except Exception as e: - return ans + "\n**ERROR**: " + str(e), 0 + return gen_conf def chat_streamly(self, system, history, gen_conf): if system: @@ -1414,32 +1333,32 @@ class GroqChat(Base): ## openrouter class OpenRouterChat(Base): - def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1"): + def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1", **kwargs): if not base_url: base_url = "https://openrouter.ai/api/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class StepFunChat(Base): - def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1"): + def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1", **kwargs): if not base_url: base_url = "https://api.stepfun.com/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class NvidiaChat(Base): - def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1"): + def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1", **kwargs): if not base_url: base_url = "https://integrate.api.nvidia.com/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class LmStudioChat(Base): - def __init__(self, key, model_name, base_url): + def __init__(self, key, model_name, base_url, **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") base_url = urljoin(base_url, "v1") - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) self.client = OpenAI(api_key="lm-studio", base_url=base_url) self.model_name = model_name @@ -1453,50 +1372,50 @@ class OpenAI_APIChat(Base): class PPIOChat(Base): - def __init__(self, key, model_name, base_url="https://api.ppinfra.com/v3/openai"): + def __init__(self, key, model_name, base_url="https://api.ppinfra.com/v3/openai", **kwargs): if not base_url: base_url = "https://api.ppinfra.com/v3/openai" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class CoHereChat(Base): - def __init__(self, key, model_name, base_url=""): - super().__init__(key, model_name, base_url=None) + 
def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) from cohere import Client self.client = Client(api_key=key) self.model_name = model_name - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) + def _clean_conf(self, gen_conf): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] if "top_p" in gen_conf: gen_conf["p"] = gen_conf.pop("top_p") if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf: gen_conf.pop("presence_penalty") + return gen_conf + + def _chat(self, history, gen_conf): + hist = [] for item in history: + hist.append(deepcopy(item)) + item = hist[-1] if "role" in item and item["role"] == "user": item["role"] = "USER" if "role" in item and item["role"] == "assistant": item["role"] = "CHATBOT" if "content" in item: item["message"] = item.pop("content") - mes = history.pop()["message"] - ans = "" - try: - response = self.client.chat(model=self.model_name, chat_history=history, message=mes, **gen_conf) - ans = response.text - if response.finish_reason == "MAX_TOKENS": - ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" - return ( - ans, - response.meta.tokens.input_tokens + response.meta.tokens.output_tokens, - ) - except Exception as e: - return ans + "\n**ERROR**: " + str(e), 0 + mes = hist.pop()["message"] + response = self.client.chat(model=self.model_name, chat_history=hist, message=mes, **gen_conf) + ans = response.text + if response.finish_reason == "MAX_TOKENS": + ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + return ( + ans, + response.meta.tokens.input_tokens + response.meta.tokens.output_tokens, + ) def chat_streamly(self, system, history, gen_conf): if system: @@ -1535,92 +1454,82 @@ class CoHereChat(Base): class LeptonAIChat(Base): - def __init__(self, key, model_name, base_url=None): + def __init__(self, key, model_name, base_url=None, **kwargs): if not base_url: base_url = urljoin("https://" + model_name + ".lepton.run", "api/v1") - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class TogetherAIChat(Base): - def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"): + def __init__(self, key, model_name, base_url="https://api.together.xyz/v1", **kwargs): if not base_url: base_url = "https://api.together.xyz/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class PerfXCloudChat(Base): - def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"): + def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1", **kwargs): if not base_url: base_url = "https://cloud.perfxlab.cn/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class UpstageChat(Base): - def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"): + def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar", **kwargs): if not base_url: base_url = "https://api.upstage.ai/v1/solar" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class NovitaAIChat(Base): - def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai"): + def __init__(self, key, model_name, 
base_url="https://api.novita.ai/v3/openai", **kwargs): if not base_url: base_url = "https://api.novita.ai/v3/openai" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class SILICONFLOWChat(Base): - def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"): + def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1", **kwargs): if not base_url: base_url = "https://api.siliconflow.cn/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class YiChat(Base): - def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1"): + def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1", **kwargs): if not base_url: base_url = "https://api.lingyiwanwu.com/v1" - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs) class ReplicateChat(Base): - def __init__(self, key, model_name, base_url=None): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) from replicate.client import Client self.model_name = model_name self.client = Client(api_token=key) - self.system = "" - def chat(self, system, history, gen_conf): - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] - if system: - self.system = system - prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]]) - ans = "" - try: - response = self.client.run( - self.model_name, - input={"system_prompt": self.system, "prompt": prompt, **gen_conf}, - ) - ans = "".join(response) - return ans, num_tokens_from_string(ans) - except Exception as e: - return ans + "\n**ERROR**: " + str(e), 0 + def _chat(self, history, gen_conf): + system = history[0]["content"] if history and history[0]["role"] == "system" else "" + prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:] if item["role"] != "system"]) + response = self.client.run( + self.model_name, + input={"system_prompt": system, "prompt": prompt, **gen_conf}, + ) + ans = "".join(response) + return ans, num_tokens_from_string(ans) def chat_streamly(self, system, history, gen_conf): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] - if system: - self.system = system prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]]) ans = "" try: response = self.client.run( self.model_name, - input={"system_prompt": self.system, "prompt": prompt, **gen_conf}, + input={"system_prompt": system, "prompt": prompt, **gen_conf}, ) for resp in response: ans = resp @@ -1633,8 +1542,8 @@ class ReplicateChat(Base): class HunyuanChat(Base): - def __init__(self, key, model_name, base_url=None): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) from tencentcloud.common import credential from tencentcloud.hunyuan.v20230901 import hunyuan_client @@ -1646,33 +1555,24 @@ class HunyuanChat(Base): self.model_name = model_name self.client = hunyuan_client.HunyuanClient(cred, "") - def chat(self, system, history, gen_conf): - from tencentcloud.common.exception.tencent_cloud_sdk_exception import ( - TencentCloudSDKException, - ) - from tencentcloud.hunyuan.v20230901 import models - + def _clean_conf(self, gen_conf): _gen_conf = {} - _history = [{k.capitalize(): v for k, v in 
item.items()} for item in history] - if system: - _history.insert(0, {"Role": "system", "Content": system}) - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] if "temperature" in gen_conf: _gen_conf["Temperature"] = gen_conf["temperature"] if "top_p" in gen_conf: _gen_conf["TopP"] = gen_conf["top_p"] + return _gen_conf + def _chat(self, history, gen_conf): + from tencentcloud.hunyuan.v20230901 import models + + hist = [{k.capitalize(): v for k, v in item.items()} for item in history] req = models.ChatCompletionsRequest() - params = {"Model": self.model_name, "Messages": _history, **_gen_conf} + params = {"Model": self.model_name, "Messages": hist, **gen_conf} req.from_json_string(json.dumps(params)) - ans = "" - try: - response = self.client.ChatCompletions(req) - ans = response.Choices[0].Message.Content - return ans, response.Usage.TotalTokens - except TencentCloudSDKException as e: - return ans + "\n**ERROR**: " + str(e), 0 + response = self.client.ChatCompletions(req) + ans = response.Choices[0].Message.Content + return ans, response.Usage.TotalTokens def chat_streamly(self, system, history, gen_conf): from tencentcloud.common.exception.tencent_cloud_sdk_exception import ( @@ -1718,7 +1618,7 @@ class HunyuanChat(Base): class SparkChat(Base): - def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"): + def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1", **kwargs): if not base_url: base_url = "https://spark-api-open.xf-yun.com/v1" model2version = { @@ -1734,12 +1634,12 @@ class SparkChat(Base): model_version = model2version[model_name] else: model_version = model_name - super().__init__(key, model_version, base_url) + super().__init__(key, model_version, base_url, **kwargs) class BaiduYiyanChat(Base): - def __init__(self, key, model_name, base_url=None): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) import qianfan @@ -1748,27 +1648,20 @@ class BaiduYiyanChat(Base): sk = key.get("yiyan_sk", "") self.client = qianfan.ChatCompletion(ak=ak, sk=sk) self.model_name = model_name.lower() - self.system = "" - def chat(self, system, history, gen_conf): - if system: - self.system = system + def _clean_conf(self, gen_conf): gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1 if "max_tokens" in gen_conf: del gen_conf["max_tokens"] - ans = "" + return gen_conf - try: - response = self.client.do(model=self.model_name, messages=history, system=self.system, **gen_conf).body - ans = response["result"] - return ans, self.total_token_count(response) - - except Exception as e: - return ans + "\n**ERROR**: " + str(e), 0 + def _chat(self, history, gen_conf): + system = history[0]["content"] if history and history[0]["role"] == "system" else "" + response = self.client.do(model=self.model_name, messages=[h for h in history if h["role"] != "system"], system=system, **gen_conf).body + ans = response["result"] + return ans, self.total_token_count(response) def chat_streamly(self, system, history, gen_conf): - if system: - self.system = system gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1 if "max_tokens" in gen_conf: del gen_conf["max_tokens"] @@ -1776,7 +1669,7 @@ class BaiduYiyanChat(Base): total_tokens = 0 try: - response = self.client.do(model=self.model_name, messages=history, 
system=self.system, stream=True, **gen_conf) + response = self.client.do(model=self.model_name, messages=history, system=system, stream=True, **gen_conf) for resp in response: resp = resp.body ans = resp["result"] @@ -1791,18 +1684,15 @@ class BaiduYiyanChat(Base): class AnthropicChat(Base): - def __init__(self, key, model_name, base_url=None): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) import anthropic self.client = anthropic.Anthropic(api_key=key) self.model_name = model_name - self.system = "" - def chat(self, system, history, gen_conf): - if system: - self.system = system + def _clean_conf(self, gen_conf): if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"] if "frequency_penalty" in gen_conf: @@ -1810,29 +1700,26 @@ class AnthropicChat(Base): gen_conf["max_tokens"] = 8192 if "haiku" in self.model_name or "opus" in self.model_name: gen_conf["max_tokens"] = 4096 + return gen_conf - ans = "" - try: - response = self.client.messages.create( - model=self.model_name, - messages=history, - system=self.system, - stream=False, - **gen_conf, - ).to_dict() - ans = response["content"][0]["text"] - if response["stop_reason"] == "max_tokens": - ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" - return ( - ans, - response["usage"]["input_tokens"] + response["usage"]["output_tokens"], - ) - except Exception as e: - return ans + "\n**ERROR**: " + str(e), 0 + def _chat(self, history, gen_conf): + system = history[0]["content"] if history and history[0]["role"] == "system" else "" + response = self.client.messages.create( + model=self.model_name, + messages=[h for h in history if h["role"] != "system"], + system=system, + stream=False, + **gen_conf, + ).to_dict() + ans = response["content"][0]["text"] + if response["stop_reason"] == "max_tokens": + ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" 
+ return ( + ans, + response["usage"]["input_tokens"] + response["usage"]["output_tokens"], + ) def chat_streamly(self, system, history, gen_conf): - if system: - self.system = system if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"] if "frequency_penalty" in gen_conf: @@ -1873,8 +1760,8 @@ class AnthropicChat(Base): class GoogleChat(Base): - def __init__(self, key, model_name, base_url=None): - super().__init__(key, model_name, base_url=None) + def __init__(self, key, model_name, base_url=None, **kwargs): + super().__init__(key, model_name, base_url=base_url, **kwargs) import base64 @@ -1887,7 +1774,6 @@ class GoogleChat(Base): scopes = ["https://www.googleapis.com/auth/cloud-platform"] self.model_name = model_name - self.system = "" if "claude" in self.model_name: from anthropic import AnthropicVertex @@ -1912,53 +1798,53 @@ class GoogleChat(Base): aiplatform.init(project=project_id, location=region) self.client = glm.GenerativeModel(model_name=self.model_name) - def chat(self, system, history, gen_conf): - if system: - self.system = system - + def _clean_conf(self, gen_conf): if "claude" in self.model_name: if "max_tokens" in gen_conf: del gen_conf["max_tokens"] - try: - response = self.client.messages.create( - model=self.model_name, - messages=history, - system=self.system, - stream=False, - **gen_conf, - ).json() - ans = response["content"][0]["text"] - if response["stop_reason"] == "max_tokens": - ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" - return ( - ans, - response["usage"]["input_tokens"] + response["usage"]["output_tokens"], - ) - except Exception as e: - return "\n**ERROR**: " + str(e), 0 else: - self.client._system_instruction = self.system if "max_tokens" in gen_conf: gen_conf["max_output_tokens"] = gen_conf["max_tokens"] for k in list(gen_conf.keys()): if k not in ["temperature", "top_p", "max_output_tokens"]: del gen_conf[k] - for item in history: - if "role" in item and item["role"] == "assistant": - item["role"] = "model" - if "content" in item: - item["parts"] = item.pop("content") - try: - response = self.client.generate_content(history, generation_config=gen_conf) - ans = response.text - return ans, response.usage_metadata.total_token_count - except Exception as e: - return "**ERROR**: " + str(e), 0 + return gen_conf + + def _chat(self, history, gen_conf): + system = history[0]["content"] if history and history[0]["role"] == "system" else "" + if "claude" in self.model_name: + response = self.client.messages.create( + model=self.model_name, + messages=[h for h in history if h["role"] != "system"], + system=system, + stream=False, + **gen_conf, + ).json() + ans = response["content"][0]["text"] + if response["stop_reason"] == "max_tokens": + ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" 
+ return ( + ans, + response["usage"]["input_tokens"] + response["usage"]["output_tokens"], + ) + + self.client._system_instruction = system + hist = [] + for item in history: + if item["role"] == "system": + continue + hist.append(deepcopy(item)) + item = hist[-1] + if "role" in item and item["role"] == "assistant": + item["role"] = "model" + if "content" in item: + item["parts"] = item.pop("content") + + response = self.client.generate_content(hist, generation_config=gen_conf) + ans = response.text + return ans, response.usage_metadata.total_token_count def chat_streamly(self, system, history, gen_conf): - if system: - self.system = system - if "claude" in self.model_name: if "max_tokens" in gen_conf: del gen_conf["max_tokens"] @@ -1968,7 +1854,7 @@ class GoogleChat(Base): response = self.client.messages.create( model=self.model_name, messages=history, - system=self.system, + system=system, stream=True, **gen_conf, ) @@ -1983,7 +1869,7 @@ class GoogleChat(Base): yield total_tokens else: - self.client._system_instruction = self.system + self.client._system_instruction = system if "max_tokens" in gen_conf: gen_conf["max_output_tokens"] = gen_conf["max_tokens"] for k in list(gen_conf.keys()): @@ -2008,8 +1894,8 @@ class GoogleChat(Base): class GPUStackChat(Base): - def __init__(self, key=None, model_name="", base_url=""): + def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") base_url = urljoin(base_url, "v1") - super().__init__(key, model_name, base_url) + super().__init__(key, model_name, base_url, **kwargs)
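
Note (illustrative, not part of the patch): the refactor above moves the retry loop, length-notification handling, and generation-option clean-up into Base.chat() / _clean_conf() / _chat(), and threads **kwargs (max_retries, retry_interval) through every provider constructor. The standalone Python sketch below shows the intended subclass contract under those assumptions; FakeProviderChat, its echo response, and the simplified error handling are hypothetical stand-ins, not the repo's actual _classify_error / backoff logic.

# Illustrative sketch only -- NOT part of the patch. It mirrors the template-method
# shape introduced in rag/llm/chat_model.py: Base.chat() owns the retry loop and
# delegates provider specifics to _clean_conf() / _chat(), which subclasses override.
# "FakeProviderChat" and its fake provider call are hypothetical stand-ins for a real SDK.

import os
import random
import time


class Base:
    def __init__(self, key, model_name, base_url, **kwargs):
        self.model_name = model_name
        # Retry settings may come from kwargs, falling back to the env vars
        # used in the patch (LLM_MAX_RETRIES / LLM_BASE_DELAY).
        self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
        self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))

    def _get_delay(self, attempt):
        # Exponential backoff with jitter (assumption: the real method does something similar).
        return self.base_delay * (2 ** attempt) + random.uniform(0, 1)

    def _clean_conf(self, gen_conf):
        # Default hook: drop options an OpenAI-compatible endpoint would reject.
        gen_conf.pop("max_tokens", None)
        return gen_conf

    def _chat(self, history, gen_conf):
        # Default provider call; subclasses backed by non-OpenAI SDKs override this.
        raise NotImplementedError

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        gen_conf = self._clean_conf(gen_conf)
        for attempt in range(self.max_retries):
            try:
                return self._chat(history, gen_conf)
            except Exception as e:
                # Simplified: the real code classifies the error before deciding to retry.
                if attempt == self.max_retries - 1:
                    return f"**ERROR**: {e}", 0
                time.sleep(self._get_delay(attempt))


class FakeProviderChat(Base):
    """Hypothetical subclass: only the provider-specific hooks are overridden."""

    def _clean_conf(self, gen_conf):
        # Keep only the options this imaginary provider understands.
        return {k: v for k, v in gen_conf.items() if k in ("temperature", "top_p")}

    def _chat(self, history, gen_conf):
        # Pretend provider call; a real subclass would invoke its SDK here.
        return f"echo: {history[-1]['content']}", 42


if __name__ == "__main__":
    llm = FakeProviderChat(key="dummy", model_name="demo", base_url=None, max_retries=2, retry_interval=0.5)
    print(llm.chat("You are terse.", [{"role": "user", "content": "hi"}], {"temperature": 0.3, "max_tokens": 512}))

The point of the template-method split is that OpenAI-compatible providers inherit _chat() unchanged, while SDK-specific providers (Ollama, Bedrock, Gemini, Hunyuan, etc.) override only the two hooks and all of them share a single retry/backoff path in Base.chat().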