From 8b894a29c4500d2e340e89439044b5f5567f831f Mon Sep 17 00:00:00 2001 From: yongtenglei Date: Fri, 28 Nov 2025 19:24:15 +0800 Subject: [PATCH] chats and agents --- agent/canvas.py | 109 +++++--- agent/component/llm.py | 51 +++- agent/component/message.py | 50 +++- api/apps/canvas_app.py | 7 +- api/db/services/llm_service.py | 83 +++++- rag/llm/__init__.py | 3 + rag/llm/chat_model.py | 471 ++++++++++++++++++++++++++++----- 7 files changed, 654 insertions(+), 120 deletions(-) diff --git a/agent/canvas.py b/agent/canvas.py index 3e15814aa..1921671d2 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import base64 +import inspect import json import logging import re @@ -79,6 +81,7 @@ class Graph: self.dsl = json.loads(dsl) self._tenant_id = tenant_id self.task_id = task_id if task_id else get_uuid() + self._thread_pool = ThreadPoolExecutor(max_workers=5) self.load() def load(self): @@ -328,6 +331,7 @@ class Canvas(Graph): async def run(self, **kwargs): st = time.perf_counter() + self._loop = asyncio.get_running_loop() self.message_id = get_uuid() created_at = int(time.time()) self.add_user_input(kwargs.get("query")) @@ -343,7 +347,7 @@ class Canvas(Graph): for k in kwargs.keys(): if k in ["query", "user_id", "files"] and kwargs[k]: if k == "files": - self.globals[f"sys.{k}"] = self.get_files(kwargs[k]) + self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k]) else: self.globals[f"sys.{k}"] = kwargs[k] if not self.globals["sys.conversation_turns"] : @@ -373,31 +377,39 @@ class Canvas(Graph): yield decorate("workflow_started", {"inputs": kwargs.get("inputs")}) self.retrieval.append({"chunks": {}, "doc_aggs": {}}) - def _run_batch(f, t): + async def _run_batch(f, t): if self.is_canceled(): msg = f"Task {self.task_id} has been canceled during batch execution." 
logging.info(msg) raise TaskCanceledException(msg) - with ThreadPoolExecutor(max_workers=5) as executor: - thr = [] - i = f - while i < t: - cpn = self.get_component_obj(self.path[i]) - if cpn.component_name.lower() in ["begin", "userfillup"]: - thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {}))) - i += 1 + loop = asyncio.get_running_loop() + tasks = [] + i = f + while i < t: + cpn = self.get_component_obj(self.path[i]) + task_fn = None + + if cpn.component_name.lower() in ["begin", "userfillup"]: + task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {})) + i += 1 + else: + for _, ele in cpn.get_input_elements().items(): + if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0: + self.path.pop(i) + t -= 1 + break else: - for _, ele in cpn.get_input_elements().items(): - if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0: - self.path.pop(i) - t -= 1 - break - else: - thr.append(executor.submit(cpn.invoke, **cpn.get_input())) - i += 1 - for t in thr: - t.result() + task_fn = partial(cpn.invoke, **cpn.get_input()) + i += 1 + + if task_fn is None: + continue + + tasks.append(loop.run_in_executor(self._thread_pool, task_fn)) + + if tasks: + await asyncio.gather(*tasks) def _node_finished(cpn_obj): return decorate("node_finished",{ @@ -424,7 +436,7 @@ class Canvas(Graph): "component_type": self.get_component_type(self.path[i]), "thoughts": self.get_component_thoughts(self.path[i]) }) - _run_batch(idx, to) + await _run_batch(idx, to) to = len(self.path) # post processing of components invocation for i in range(idx, to): @@ -433,16 +445,29 @@ class Canvas(Graph): if cpn_obj.component_name.lower() == "message": if isinstance(cpn_obj.output("content"), partial): _m = "" - for m in cpn_obj.output("content")(): - if not m: - continue - if m == "": - yield decorate("message", 
{"content": "", "start_to_think": True}) - elif m == "": - yield decorate("message", {"content": "", "end_to_think": True}) - else: - yield decorate("message", {"content": m}) - _m += m + stream = cpn_obj.output("content")() + if inspect.isasyncgen(stream): + async for m in stream: + if not m: + continue + if m == "": + yield decorate("message", {"content": "", "start_to_think": True}) + elif m == "": + yield decorate("message", {"content": "", "end_to_think": True}) + else: + yield decorate("message", {"content": m}) + _m += m + else: + for m in stream: + if not m: + continue + if m == "": + yield decorate("message", {"content": "", "start_to_think": True}) + elif m == "": + yield decorate("message", {"content": "", "end_to_think": True}) + else: + yield decorate("message", {"content": m}) + _m += m cpn_obj.set_output("content", _m) cite = re.search(r"\[ID:[ 0-9]+\]", _m) else: @@ -590,21 +615,31 @@ class Canvas(Graph): def get_component_input_elements(self, cpnnm): return self.components[cpnnm]["obj"].get_input_elements() - def get_files(self, files: Union[None, list[dict]]) -> list[str]: + async def get_files_async(self, files: Union[None, list[dict]]) -> list[str]: from api.db.services.file_service import FileService if not files: return [] def image_to_base64(file): return "data:{};base64,{}".format(file["mime_type"], base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8")) - exe = ThreadPoolExecutor(max_workers=5) - threads = [] + loop = asyncio.get_running_loop() + tasks = [] for file in files: if file["mime_type"].find("image") >=0: - threads.append(exe.submit(image_to_base64, file)) + tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file)) continue - threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"])) - return [th.result() for th in threads] + tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, 
file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"])) + return await asyncio.gather(*tasks) + + def get_files(self, files: Union[None, list[dict]]) -> list[str]: + """ + Synchronous wrapper for get_files_async, used by sync component invoke paths. + """ + loop = getattr(self, "_loop", None) + if loop and loop.is_running(): + return asyncio.run_coroutine_threadsafe(self.get_files_async(files), loop).result() + + return asyncio.run(self.get_files_async(files)) def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None): agent_ids = agent_id.split("-->") diff --git a/agent/component/llm.py b/agent/component/llm.py index 807bbc288..3f135b57b 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -205,6 +205,55 @@ class LLM(ComponentBase): for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs): yield delta(txt) + async def _stream_output_async(self, prompt, msg): + _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) + answer = "" + last_idx = 0 + endswith_think = False + + def delta(txt): + nonlocal answer, last_idx, endswith_think + delta_ans = txt[last_idx:] + answer = txt + + if delta_ans.find("") == 0: + last_idx += len("") + return "" + elif delta_ans.find("") > 0: + delta_ans = txt[last_idx:last_idx + delta_ans.find("")] + last_idx += delta_ans.find("") + return delta_ans + elif delta_ans.endswith(""): + endswith_think = True + elif endswith_think: + endswith_think = False + return "" + + last_idx = len(answer) + if answer.endswith(""): + last_idx -= len("") + return re.sub(r"(|)", "", delta_ans) + + stream_kwargs = {"images": self.imgs} if self.imgs else {} + async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs): + if self.check_if_canceled("LLM streaming"): + return + + if 
isinstance(ans, int): + continue + + if ans.find("**ERROR**") >= 0: + if self.get_exception_default_value(): + self.set_output("content", self.get_exception_default_value()) + yield self.get_exception_default_value() + else: + self.set_output("_ERROR", ans) + return + + yield delta(ans) + + self.set_output("content", answer) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) def _invoke(self, **kwargs): if self.check_if_canceled("LLM processing"): @@ -250,7 +299,7 @@ class LLM(ComponentBase): downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else [] ex = self.exception_handler() if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]): - self.set_output("content", partial(self._stream_output, prompt, msg)) + self.set_output("content", partial(self._stream_output_async, prompt, msg)) return for _ in range(self._param.max_retries+1): diff --git a/agent/component/message.py b/agent/component/message.py index ac1d2beae..28349a7c3 100644 --- a/agent/component/message.py +++ b/agent/component/message.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import asyncio +import inspect import json import os import random @@ -66,8 +68,12 @@ class Message(ComponentBase): v = "" ans = "" if isinstance(v, partial): - for t in v(): - ans += t + iter_obj = v() + if inspect.isasyncgen(iter_obj): + ans = asyncio.run(self._consume_async_gen(iter_obj)) + else: + for t in iter_obj: + ans += t elif isinstance(v, list) and delimiter: ans = delimiter.join([str(vv) for vv in v]) elif not isinstance(v, str): @@ -89,7 +95,13 @@ class Message(ComponentBase): _kwargs[_n] = v return script, _kwargs - def _stream(self, rand_cnt:str): + async def _consume_async_gen(self, agen): + buf = "" + async for t in agen: + buf += t + return buf + + async def _stream(self, rand_cnt:str): s = 0 all_content = "" cache = {} @@ -111,15 +123,27 @@ class Message(ComponentBase): v = "" if isinstance(v, partial): cnt = "" - for t in v(): - if self.check_if_canceled("Message streaming"): - return + iter_obj = v() + if inspect.isasyncgen(iter_obj): + async for t in iter_obj: + if self.check_if_canceled("Message streaming"): + return - all_content += t - cnt += t - yield t + all_content += t + cnt += t + yield t + else: + for t in iter_obj: + if self.check_if_canceled("Message streaming"): + return + + all_content += t + cnt += t + yield t self.set_input_value(exp, cnt) continue + elif inspect.isawaitable(v): + v = await v elif not isinstance(v, str): try: v = json.dumps(v, ensure_ascii=False) @@ -181,7 +205,7 @@ class Message(ComponentBase): import pypandoc doc_id = get_uuid() - + if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}: self._param.output_format = "markdown" @@ -231,11 +255,11 @@ class Message(ComponentBase): settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content) self.set_output("attachment", { - "doc_id":doc_id, - "format":self._param.output_format, + "doc_id":doc_id, + "format":self._param.output_format, "file_name":f"{doc_id[:8]}.{self._param.output_format}"}) logging.info(f"Converted content 
uploaded as {doc_id} (format={self._param.output_format})") except Exception as e: - logging.error(f"Error converting content to {self._param.output_format}: {e}") \ No newline at end of file + logging.error(f"Error converting content to {self._param.output_format}: {e}") diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index 86ffaedb1..69b956dc1 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import json import logging import re @@ -134,12 +135,12 @@ async def run(): files = req.get("files", []) inputs = req.get("inputs", {}) user_id = req.get("user_id", current_user.id) - if not UserCanvasService.accessible(req["id"], current_user.id): + if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', code=RetCode.OPERATING_ERROR) - e, cvs = UserCanvasService.get_by_id(req["id"]) + e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"]) if not e: return get_data_error_result(message="canvas not found.") @@ -149,7 +150,7 @@ async def run(): if cvs.canvas_category == CanvasCategory.DataFlow: task_id = get_uuid() Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"]) - ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0) + ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, files[0], 0) if not ok: return get_data_error_result(message=error_message) return get_json_result(data={"message_id": task_id}) diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 4d4ccaa57..a681341d4 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -13,9 +13,11 @@ # See the 
License for the specific language governing permissions and # limitations under the License. # +import asyncio import inspect import logging import re +import threading from common.token_utils import num_tokens_from_string from functools import partial from typing import Generator @@ -242,7 +244,7 @@ class LLMBundle(LLM4Tenant): if not self.verbose_tool_use: txt = re.sub(r".*?", "", txt, flags=re.DOTALL) - if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): + if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens)) if self.langfuse: @@ -279,5 +281,80 @@ class LLMBundle(LLM4Tenant): yield ans if total_tokens > 0: - if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name): - logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt)) + if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name): + logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens)) + + def _bridge_sync_stream(self, gen): + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue() + + def worker(): + try: + for item in gen: + loop.call_soon_threadsafe(queue.put_nowait, item) + except Exception as e: # pragma: no cover + loop.call_soon_threadsafe(queue.put_nowait, e) + finally: + loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration) + + threading.Thread(target=worker, daemon=True).start() + return queue + + async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs): + chat_partial = partial(self.mdl.chat, system, 
history, gen_conf, **kwargs) + if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"): + chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs) + + use_kwargs = self._clean_param(chat_partial, **kwargs) + + if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools: + txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs) + elif hasattr(self.mdl, "async_chat"): + txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs) + else: + txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs) + + txt = self._remove_reasoning_content(txt) + if not self.verbose_tool_use: + txt = re.sub(r".*?", "", txt, flags=re.DOTALL) + + if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): + logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens)) + + return txt + + async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs): + total_tokens = 0 + if self.is_tools and self.mdl.is_tools: + stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None) + else: + stream_fn = getattr(self.mdl, "async_chat_streamly", None) + + if stream_fn: + chat_partial = partial(stream_fn, system, history, gen_conf) + use_kwargs = self._clean_param(chat_partial, **kwargs) + async for txt in chat_partial(**use_kwargs): + if isinstance(txt, int): + total_tokens = txt + break + yield txt + if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name): + logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens)) + return + + chat_partial = partial(self.mdl.chat_streamly_with_tools if 
(self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf) + use_kwargs = self._clean_param(chat_partial, **kwargs) + queue = self._bridge_sync_stream(chat_partial(**use_kwargs)) + while True: + item = await queue.get() + if item is StopAsyncIteration: + break + if isinstance(item, Exception): + raise item + if isinstance(item, int): + total_tokens = item + break + yield item + + if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name): + logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens)) diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 897fec65f..1913646a2 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -50,6 +50,7 @@ class SupportedLiteLLMProvider(StrEnum): GiteeAI = "GiteeAI" AI_302 = "302.AI" JiekouAI = "Jiekou.AI" + ZHIPU_AI = "ZHIPU-AI" FACTORY_DEFAULT_BASE_URL = { @@ -71,6 +72,7 @@ FACTORY_DEFAULT_BASE_URL = { SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1", SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/", SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai", + SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4", } @@ -102,6 +104,7 @@ LITELLM_PROVIDER_PREFIX = { SupportedLiteLLMProvider.GiteeAI: "openai/", SupportedLiteLLMProvider.AI_302: "openai/", SupportedLiteLLMProvider.JiekouAI: "openai/", + SupportedLiteLLMProvider.ZHIPU_AI: "openai/", } ChatModel = globals().get("ChatModel", {}) diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index cce5b2454..e2a5c19cb 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -19,6 +19,7 @@ import logging import os import random import re +import threading import time from abc import ABC from copy import deepcopy @@ -28,10 +29,9 @@ import json_repair import litellm import openai import requests -from openai 
import OpenAI +from openai import AsyncOpenAI, OpenAI from openai.lib.azure import AzureOpenAI from strenum import StrEnum -from zhipuai import ZhipuAI from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider from rag.nlp import is_chinese, is_english @@ -68,6 +68,7 @@ class Base(ABC): def __init__(self, key, model_name, base_url, **kwargs): timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout) + self.async_client = AsyncOpenAI(api_key=key, base_url=base_url, timeout=timeout) self.model_name = model_name # Configure retry parameters self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5))) @@ -134,6 +135,23 @@ class Base(ABC): return gen_conf + def _bridge_sync_stream(self, gen): + """Run a sync generator in a thread and yield asynchronously.""" + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue() + + def worker(): + try: + for item in gen: + loop.call_soon_threadsafe(queue.put_nowait, item) + except Exception as exc: # pragma: no cover - defensive + loop.call_soon_threadsafe(queue.put_nowait, exc) + finally: + loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration) + + threading.Thread(target=worker, daemon=True).start() + return queue + def _chat(self, history, gen_conf, **kwargs): logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) if self.model_name.lower().find("qwq") >= 0: @@ -199,6 +217,60 @@ class Base(ABC): ans += LENGTH_NOTIFICATION_EN yield ans, tol + async def _async_chat_stream(self, history, gen_conf, **kwargs): + logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) + reasoning_start = False + + request_kwargs = {"model": self.model_name, "messages": history, "stream": True, **gen_conf} + stop = kwargs.get("stop") + if stop: + request_kwargs["stop"] = stop + + response = await 
self.async_client.chat.completions.create(**request_kwargs) + + async for resp in response: + if not resp.choices: + continue + if not resp.choices[0].delta.content: + resp.choices[0].delta.content = "" + if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content: + ans = "" + if not reasoning_start: + reasoning_start = True + ans = "" + ans += resp.choices[0].delta.reasoning_content + "" + else: + reasoning_start = False + ans = resp.choices[0].delta.content + + tol = total_token_count_from_response(resp) + if not tol: + tol = num_tokens_from_string(resp.choices[0].delta.content) + + finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else "" + if finish_reason == "length": + if is_chinese(ans): + ans += LENGTH_NOTIFICATION_CN + else: + ans += LENGTH_NOTIFICATION_EN + yield ans, tol + + async def async_chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): + if system and history and history[0].get("role") != "system": + history.insert(0, {"role": "system", "content": system}) + gen_conf = self._clean_conf(gen_conf) + ans = "" + total_tokens = 0 + try: + async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs): + ans = delta_ans + total_tokens += tol + yield delta_ans + except openai.APIError as e: + yield ans + "\n**ERROR**: " + str(e) + + yield total_tokens + def _length_stop(self, ans): if is_chinese([ans]): return ans + LENGTH_NOTIFICATION_CN @@ -227,7 +299,25 @@ class Base(ABC): time.sleep(delay) return None - return f"{ERROR_PREFIX}: {error_code} - {str(e)}" + msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}" + logging.error(f"sync base giving up: {msg}") + return msg + + async def _exceptions_async(self, e, attempt) -> str | None: + logging.exception("OpenAI async completion") + error_code = self._classify_error(e) + if attempt == self.max_retries: + error_code = LLMErrorCode.ERROR_MAX_RETRIES + + if 
self._should_retry(error_code): + delay = self._get_delay() + logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})") + await asyncio.sleep(delay) + return None + + msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}" + logging.error(f"async base giving up: {msg}") + return msg def _verbose_tool_use(self, name, args, res): return "" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "" @@ -318,6 +408,60 @@ class Base(ABC): assert False, "Shouldn't be here." + async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): + gen_conf = self._clean_conf(gen_conf) + if system and history and history[0].get("role") != "system": + history.insert(0, {"role": "system", "content": system}) + + ans = "" + tk_count = 0 + hist = deepcopy(history) + for attempt in range(self.max_retries + 1): + history = deepcopy(hist) + try: + for _ in range(self.max_rounds + 1): + logging.info(f"{self.tools=}") + response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf) + tk_count += total_token_count_from_response(response) + if any([not response.choices, not response.choices[0].message]): + raise Exception(f"500 response structure error. 
Response: {response}") + + if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls: + if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content: + ans += "" + response.choices[0].message.reasoning_content + "" + + ans += response.choices[0].message.content + if response.choices[0].finish_reason == "length": + ans = self._length_stop(ans) + + return ans, tk_count + + for tool_call in response.choices[0].message.tool_calls: + logging.info(f"Response {tool_call=}") + name = tool_call.function.name + try: + args = json_repair.loads(tool_call.function.arguments) + tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args) + history = self._append_history(history, tool_call, tool_response) + ans += self._verbose_tool_use(name, args, tool_response) + except Exception as e: + logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") + history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) + ans += self._verbose_tool_use(name, {}, str(e)) + + logging.warning(f"Exceed max rounds: {self.max_rounds}") + history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) + response, token_count = await self._async_chat(history, gen_conf) + ans += response + tk_count += token_count + return ans, tk_count + except Exception as e: + e = await self._exceptions_async(e, attempt) + if e: + return e, tk_count + + assert False, "Shouldn't be here." + def chat(self, system, history, gen_conf={}, **kwargs): if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) @@ -452,6 +596,160 @@ class Base(ABC): assert False, "Shouldn't be here." 
+ async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): + gen_conf = self._clean_conf(gen_conf) + tools = self.tools + if system and history and history[0].get("role") != "system": + history.insert(0, {"role": "system", "content": system}) + + total_tokens = 0 + hist = deepcopy(history) + + for attempt in range(self.max_retries + 1): + history = deepcopy(hist) + try: + for _ in range(self.max_rounds + 1): + reasoning_start = False + logging.info(f"{tools=}") + + response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf) + + final_tool_calls = {} + answer = "" + + async for resp in response: + if not hasattr(resp, "choices") or not resp.choices: + continue + + delta = resp.choices[0].delta + + if hasattr(delta, "tool_calls") and delta.tool_calls: + for tool_call in delta.tool_calls: + index = tool_call.index + if index not in final_tool_calls: + if not tool_call.function.arguments: + tool_call.function.arguments = "" + final_tool_calls[index] = tool_call + else: + final_tool_calls[index].function.arguments += tool_call.function.arguments or "" + continue + + if not hasattr(delta, "content") or delta.content is None: + delta.content = "" + + if hasattr(delta, "reasoning_content") and delta.reasoning_content: + ans = "" + if not reasoning_start: + reasoning_start = True + ans = "" + ans += delta.reasoning_content + "" + yield ans + else: + reasoning_start = False + answer += delta.content + yield delta.content + + tol = total_token_count_from_response(resp) + if not tol: + total_tokens += num_tokens_from_string(delta.content) + else: + total_tokens = tol + + finish_reason = getattr(resp.choices[0], "finish_reason", "") + if finish_reason == "length": + yield self._length_stop("") + + if answer: + yield total_tokens + return + + for tool_call in final_tool_calls.values(): + name = tool_call.function.name + try: + args = 
json_repair.loads(tool_call.function.arguments) + yield self._verbose_tool_use(name, args, "Begin to call...") + tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args) + history = self._append_history(history, tool_call, tool_response) + yield self._verbose_tool_use(name, args, tool_response) + except Exception as e: + logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") + history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) + yield self._verbose_tool_use(name, {}, str(e)) + + logging.warning(f"Exceed max rounds: {self.max_rounds}") + history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) + + response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf) + + async for resp in response: + if not hasattr(resp, "choices") or not resp.choices: + continue + delta = resp.choices[0].delta + if not hasattr(delta, "content") or delta.content is None: + continue + tol = total_token_count_from_response(resp) + if not tol: + total_tokens += num_tokens_from_string(delta.content) + else: + total_tokens = tol + yield delta.content + + yield total_tokens + return + + except Exception as e: + e = await self._exceptions_async(e, attempt) + if e: + logging.error(f"async_chat_streamly failed: {e}") + yield e + yield total_tokens + return + + assert False, "Shouldn't be here." 
+ + async def _async_chat(self, history, gen_conf, **kwargs): + logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) + if self.model_name.lower().find("qwq") >= 0: + logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly") + final_ans = "" + tol_token = 0 + async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs): + if delta.startswith("") or delta.endswith(""): + continue + final_ans += delta + tol_token = tol + + if len(final_ans.strip()) == 0: + final_ans = "**ERROR**: Empty response from reasoning model" + + return final_ans.strip(), tol_token + + if self.model_name.lower().find("qwen3") >= 0: + kwargs["extra_body"] = {"enable_thinking": False} + + response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs) + + if not response.choices or not response.choices[0].message or not response.choices[0].message.content: + return "", 0 + ans = response.choices[0].message.content.strip() + if response.choices[0].finish_reason == "length": + ans = self._length_stop(ans) + return ans, total_token_count_from_response(response) + + async def async_chat(self, system, history, gen_conf={}, **kwargs): + if system and history and history[0].get("role") != "system": + history.insert(0, {"role": "system", "content": system}) + gen_conf = self._clean_conf(gen_conf) + + for attempt in range(self.max_retries + 1): + try: + return await self._async_chat(history, gen_conf, **kwargs) + except Exception as e: + e = await self._exceptions_async(e, attempt) + if e: + return e, 0 + assert False, "Shouldn't be here." 
+ def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) @@ -637,66 +935,6 @@ class BaiChuanChat(Base): yield total_tokens -class ZhipuChat(Base): - _FACTORY_NAME = "ZHIPU-AI" - - def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs): - super().__init__(key, model_name, base_url=base_url, **kwargs) - - self.client = ZhipuAI(api_key=key) - self.model_name = model_name - - def _clean_conf(self, gen_conf): - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] - gen_conf = self._clean_conf_plealty(gen_conf) - return gen_conf - - def _clean_conf_plealty(self, gen_conf): - if "presence_penalty" in gen_conf: - del gen_conf["presence_penalty"] - if "frequency_penalty" in gen_conf: - del gen_conf["frequency_penalty"] - return gen_conf - - def chat_with_tools(self, system: str, history: list, gen_conf: dict): - gen_conf = self._clean_conf_plealty(gen_conf) - - return super().chat_with_tools(system, history, gen_conf) - - def chat_streamly(self, system, history, gen_conf={}, **kwargs): - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - gen_conf = self._clean_conf(gen_conf) - ans = "" - tk_count = 0 - try: - logging.info(json.dumps(history, ensure_ascii=False, indent=2)) - response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf) - for resp in response: - if not resp.choices[0].delta.content: - continue - delta = resp.choices[0].delta.content - ans = delta - if resp.choices[0].finish_reason == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - tk_count = total_token_count_from_response(resp) - if resp.choices[0].finish_reason == "stop": - tk_count = total_token_count_from_response(resp) - yield ans - except Exception as e: - yield ans 
+ "\n**ERROR**: " + str(e) - - yield tk_count - - def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict): - gen_conf = self._clean_conf_plealty(gen_conf) - return super().chat_streamly_with_tools(system, history, gen_conf) - - class LocalAIChat(Base): _FACTORY_NAME = "LocalAI" @@ -1398,6 +1636,7 @@ class LiteLLMBase(ABC): "GiteeAI", "302.AI", "Jiekou.AI", + "ZHIPU-AI", ] def __init__(self, key, model_name, base_url=None, **kwargs): @@ -1477,6 +1716,7 @@ class LiteLLMBase(ABC): def _chat_streamly(self, history, gen_conf, **kwargs): logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) + gen_conf = self._clean_conf(gen_conf) reasoning_start = False completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf) @@ -1520,6 +1760,96 @@ class LiteLLMBase(ABC): yield ans, tol + async def async_chat(self, history, gen_conf, **kwargs): + logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) + if self.model_name.lower().find("qwen3") >= 0: + kwargs["extra_body"] = {"enable_thinking": False} + + completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf) + + for attempt in range(self.max_retries + 1): + try: + response = await litellm.acompletion( + **completion_args, + drop_params=True, + timeout=self.timeout, + ) + + if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]): + return "", 0 + ans = response.choices[0].message.content.strip() + if response.choices[0].finish_reason == "length": + ans = self._length_stop(ans) + + return ans, total_token_count_from_response(response) + except Exception as e: + e = await self._exceptions_async(e, attempt) + if e: + return e, 0 + + assert False, "Shouldn't be here." 
+
+    async def async_chat_streamly(self, system, history, gen_conf, **kwargs):
+        if system and history and history[0].get("role") != "system":
+            history.insert(0, {"role": "system", "content": system})
+        logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
+        gen_conf = self._clean_conf(gen_conf)
+        reasoning_start = False
+        total_tokens = 0
+
+        completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
+        stop = kwargs.get("stop")
+        if stop:
+            completion_args["stop"] = stop
+
+        for attempt in range(self.max_retries + 1):
+            try:
+                stream = await litellm.acompletion(
+                    **completion_args,
+                    drop_params=True,
+                    timeout=self.timeout,
+                )
+
+                async for resp in stream:
+                    if not hasattr(resp, "choices") or not resp.choices:
+                        continue
+
+                    delta = resp.choices[0].delta
+                    if not hasattr(delta, "content") or delta.content is None:
+                        delta.content = ""
+
+                    if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
+                        ans = ""
+                        if not reasoning_start:
+                            reasoning_start = True
+                            ans = "<think>"
+                        ans += delta.reasoning_content + "</think>"
+                    else:
+                        reasoning_start = False
+                        ans = delta.content
+
+                    tol = total_token_count_from_response(resp)
+                    if not tol:
+                        tol = num_tokens_from_string(delta.content)
+                    total_tokens += tol
+
+                    finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
+                    if finish_reason == "length":
+                        if is_chinese(ans):
+                            ans += LENGTH_NOTIFICATION_CN
+                        else:
+                            ans += LENGTH_NOTIFICATION_EN
+
+                    yield ans
+                yield total_tokens
+                return
+            except Exception as e:
+                e = await self._exceptions_async(e, attempt)
+                if e:
+                    yield e
+                    yield total_tokens
+                    return
+
     def _length_stop(self, ans):
         if is_chinese([ans]):
             return ans + LENGTH_NOTIFICATION_CN
@@ -1550,6 +1880,21 @@ class LiteLLMBase(ABC):
 
         return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
 
+    async def _exceptions_async(self, e, attempt) -> str | None:
+        
logging.exception("LiteLLMBase async completion")
+        error_code = self._classify_error(e)
+        if attempt == self.max_retries:
+            error_code = LLMErrorCode.ERROR_MAX_RETRIES
+
+        if self._should_retry(error_code):
+            delay = self._get_delay()
+            logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+            await asyncio.sleep(delay)
+            return None
+        msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+        logging.error(f"async_chat_streamly giving up: {msg}")
+        return msg
+
     def _verbose_tool_use(self, name, args, res):
         return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"