diff --git a/agent/canvas.py b/agent/canvas.py
index 3e15814aa..1921671d2 100644
--- a/agent/canvas.py
+++ b/agent/canvas.py
@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import asyncio
import base64
+import inspect
import json
import logging
import re
@@ -79,6 +81,7 @@ class Graph:
self.dsl = json.loads(dsl)
self._tenant_id = tenant_id
self.task_id = task_id if task_id else get_uuid()
+ self._thread_pool = ThreadPoolExecutor(max_workers=5)
self.load()
def load(self):
@@ -328,6 +331,7 @@ class Canvas(Graph):
async def run(self, **kwargs):
st = time.perf_counter()
+ self._loop = asyncio.get_running_loop()
self.message_id = get_uuid()
created_at = int(time.time())
self.add_user_input(kwargs.get("query"))
@@ -343,7 +347,7 @@ class Canvas(Graph):
for k in kwargs.keys():
if k in ["query", "user_id", "files"] and kwargs[k]:
if k == "files":
- self.globals[f"sys.{k}"] = self.get_files(kwargs[k])
+ self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k])
else:
self.globals[f"sys.{k}"] = kwargs[k]
if not self.globals["sys.conversation_turns"] :
@@ -373,31 +377,39 @@ class Canvas(Graph):
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
- def _run_batch(f, t):
+ async def _run_batch(f, t):
if self.is_canceled():
msg = f"Task {self.task_id} has been canceled during batch execution."
logging.info(msg)
raise TaskCanceledException(msg)
- with ThreadPoolExecutor(max_workers=5) as executor:
- thr = []
- i = f
- while i < t:
- cpn = self.get_component_obj(self.path[i])
- if cpn.component_name.lower() in ["begin", "userfillup"]:
- thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
- i += 1
+ loop = asyncio.get_running_loop()
+ tasks = []
+ i = f
+ while i < t:
+ cpn = self.get_component_obj(self.path[i])
+ task_fn = None
+
+ if cpn.component_name.lower() in ["begin", "userfillup"]:
+ task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {}))
+ i += 1
+ else:
+ for _, ele in cpn.get_input_elements().items():
+ if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
+ self.path.pop(i)
+ t -= 1
+ break
else:
- for _, ele in cpn.get_input_elements().items():
- if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
- self.path.pop(i)
- t -= 1
- break
- else:
- thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
- i += 1
- for t in thr:
- t.result()
+ task_fn = partial(cpn.invoke, **cpn.get_input())
+ i += 1
+
+ if task_fn is None:
+ continue
+
+ tasks.append(loop.run_in_executor(self._thread_pool, task_fn))
+
+ if tasks:
+ await asyncio.gather(*tasks)
def _node_finished(cpn_obj):
return decorate("node_finished",{
@@ -424,7 +436,7 @@ class Canvas(Graph):
"component_type": self.get_component_type(self.path[i]),
"thoughts": self.get_component_thoughts(self.path[i])
})
- _run_batch(idx, to)
+ await _run_batch(idx, to)
to = len(self.path)
# post processing of components invocation
for i in range(idx, to):
@@ -433,16 +445,29 @@ class Canvas(Graph):
if cpn_obj.component_name.lower() == "message":
if isinstance(cpn_obj.output("content"), partial):
_m = ""
- for m in cpn_obj.output("content")():
- if not m:
- continue
- if m == "<think>":
- yield decorate("message", {"content": "", "start_to_think": True})
- elif m == "</think>":
- yield decorate("message", {"content": "", "end_to_think": True})
- else:
- yield decorate("message", {"content": m})
- _m += m
+ stream = cpn_obj.output("content")()
+ if inspect.isasyncgen(stream):
+ async for m in stream:
+ if not m:
+ continue
+ if m == "<think>":
+ yield decorate("message", {"content": "", "start_to_think": True})
+ elif m == "</think>":
+ yield decorate("message", {"content": "", "end_to_think": True})
+ else:
+ yield decorate("message", {"content": m})
+ _m += m
+ else:
+ for m in stream:
+ if not m:
+ continue
+ if m == "<think>":
+ yield decorate("message", {"content": "", "start_to_think": True})
+ elif m == "</think>":
+ yield decorate("message", {"content": "", "end_to_think": True})
+ else:
+ yield decorate("message", {"content": m})
+ _m += m
cpn_obj.set_output("content", _m)
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
else:
@@ -590,21 +615,31 @@ class Canvas(Graph):
def get_component_input_elements(self, cpnnm):
return self.components[cpnnm]["obj"].get_input_elements()
- def get_files(self, files: Union[None, list[dict]]) -> list[str]:
+ async def get_files_async(self, files: Union[None, list[dict]]) -> list[str]:
from api.db.services.file_service import FileService
if not files:
return []
def image_to_base64(file):
return "data:{};base64,{}".format(file["mime_type"],
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
- exe = ThreadPoolExecutor(max_workers=5)
- threads = []
+ loop = asyncio.get_running_loop()
+ tasks = []
for file in files:
if file["mime_type"].find("image") >=0:
- threads.append(exe.submit(image_to_base64, file))
+ tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
continue
- threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
- return [th.result() for th in threads]
+ tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
+ return await asyncio.gather(*tasks)
+
+ def get_files(self, files: Union[None, list[dict]]) -> list[str]:
+ """
+ Synchronous wrapper for get_files_async, used by sync component invoke paths.
+ """
+ loop = getattr(self, "_loop", None)
+ if loop and loop.is_running():
+ return asyncio.run_coroutine_threadsafe(self.get_files_async(files), loop).result()
+
+ return asyncio.run(self.get_files_async(files))
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
agent_ids = agent_id.split("-->")
diff --git a/agent/component/llm.py b/agent/component/llm.py
index 807bbc288..3f135b57b 100644
--- a/agent/component/llm.py
+++ b/agent/component/llm.py
@@ -205,6 +205,55 @@ class LLM(ComponentBase):
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
yield delta(txt)
+ async def _stream_output_async(self, prompt, msg):
+ _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+ answer = ""
+ last_idx = 0
+ endswith_think = False
+
+ def delta(txt):
+ nonlocal answer, last_idx, endswith_think
+ delta_ans = txt[last_idx:]
+ answer = txt
+
+ if delta_ans.find("<think>") == 0:
+ last_idx += len("<think>")
+ return "<think>"
+ elif delta_ans.find("<think>") > 0:
+ delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
+ last_idx += delta_ans.find("<think>")
+ return delta_ans
+ elif delta_ans.endswith("</think>"):
+ endswith_think = True
+ elif endswith_think:
+ endswith_think = False
+ return "</think>"
+
+ last_idx = len(answer)
+ if answer.endswith("</think>"):
+ last_idx -= len("</think>")
+ return re.sub(r"(<think>|</think>)", "", delta_ans)
+
+ stream_kwargs = {"images": self.imgs} if self.imgs else {}
+ async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs):
+ if self.check_if_canceled("LLM streaming"):
+ return
+
+ if isinstance(ans, int):
+ continue
+
+ if ans.find("**ERROR**") >= 0:
+ if self.get_exception_default_value():
+ self.set_output("content", self.get_exception_default_value())
+ yield self.get_exception_default_value()
+ else:
+ self.set_output("_ERROR", ans)
+ return
+
+ yield delta(ans)
+
+ self.set_output("content", answer)
+
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("LLM processing"):
@@ -250,7 +299,7 @@ class LLM(ComponentBase):
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
- self.set_output("content", partial(self._stream_output, prompt, msg))
+ self.set_output("content", partial(self._stream_output_async, prompt, msg))
return
for _ in range(self._param.max_retries+1):
diff --git a/agent/component/message.py b/agent/component/message.py
index ac1d2beae..28349a7c3 100644
--- a/agent/component/message.py
+++ b/agent/component/message.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import asyncio
+import inspect
import json
import os
import random
@@ -66,8 +68,12 @@ class Message(ComponentBase):
v = ""
ans = ""
if isinstance(v, partial):
- for t in v():
- ans += t
+ iter_obj = v()
+ if inspect.isasyncgen(iter_obj):
+ ans = asyncio.run(self._consume_async_gen(iter_obj))
+ else:
+ for t in iter_obj:
+ ans += t
elif isinstance(v, list) and delimiter:
ans = delimiter.join([str(vv) for vv in v])
elif not isinstance(v, str):
@@ -89,7 +95,13 @@ class Message(ComponentBase):
_kwargs[_n] = v
return script, _kwargs
- def _stream(self, rand_cnt:str):
+ async def _consume_async_gen(self, agen):
+ buf = ""
+ async for t in agen:
+ buf += t
+ return buf
+
+ async def _stream(self, rand_cnt:str):
s = 0
all_content = ""
cache = {}
@@ -111,15 +123,27 @@ class Message(ComponentBase):
v = ""
if isinstance(v, partial):
cnt = ""
- for t in v():
- if self.check_if_canceled("Message streaming"):
- return
+ iter_obj = v()
+ if inspect.isasyncgen(iter_obj):
+ async for t in iter_obj:
+ if self.check_if_canceled("Message streaming"):
+ return
- all_content += t
- cnt += t
- yield t
+ all_content += t
+ cnt += t
+ yield t
+ else:
+ for t in iter_obj:
+ if self.check_if_canceled("Message streaming"):
+ return
+
+ all_content += t
+ cnt += t
+ yield t
self.set_input_value(exp, cnt)
continue
+ elif inspect.isawaitable(v):
+ v = await v
elif not isinstance(v, str):
try:
v = json.dumps(v, ensure_ascii=False)
@@ -181,7 +205,7 @@ class Message(ComponentBase):
import pypandoc
doc_id = get_uuid()
-
+
if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}:
self._param.output_format = "markdown"
@@ -231,11 +255,11 @@ class Message(ComponentBase):
settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content)
self.set_output("attachment", {
- "doc_id":doc_id,
- "format":self._param.output_format,
+ "doc_id":doc_id,
+ "format":self._param.output_format,
"file_name":f"{doc_id[:8]}.{self._param.output_format}"})
logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})")
except Exception as e:
- logging.error(f"Error converting content to {self._param.output_format}: {e}")
\ No newline at end of file
+ logging.error(f"Error converting content to {self._param.output_format}: {e}")
diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py
index 86ffaedb1..69b956dc1 100644
--- a/api/apps/canvas_app.py
+++ b/api/apps/canvas_app.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import asyncio
import json
import logging
import re
@@ -134,12 +135,12 @@ async def run():
files = req.get("files", [])
inputs = req.get("inputs", {})
user_id = req.get("user_id", current_user.id)
- if not UserCanvasService.accessible(req["id"], current_user.id):
+ if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
- e, cvs = UserCanvasService.get_by_id(req["id"])
+ e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
@@ -149,7 +150,7 @@ async def run():
if cvs.canvas_category == CanvasCategory.DataFlow:
task_id = get_uuid()
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
- ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
+ ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, files[0], 0)
if not ok:
return get_data_error_result(message=error_message)
return get_json_result(data={"message_id": task_id})
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 4d4ccaa57..a681341d4 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import asyncio
import inspect
import logging
import re
+import threading
from common.token_utils import num_tokens_from_string
from functools import partial
from typing import Generator
@@ -242,7 +244,7 @@ class LLMBundle(LLM4Tenant):
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
- if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
+ if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
if self.langfuse:
@@ -279,5 +281,80 @@ class LLMBundle(LLM4Tenant):
yield ans
if total_tokens > 0:
- if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
- logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
+ if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
+ logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
+
+ def _bridge_sync_stream(self, gen):
+ loop = asyncio.get_running_loop()
+ queue: asyncio.Queue = asyncio.Queue()
+
+ def worker():
+ try:
+ for item in gen:
+ loop.call_soon_threadsafe(queue.put_nowait, item)
+ except Exception as e: # pragma: no cover
+ loop.call_soon_threadsafe(queue.put_nowait, e)
+ finally:
+ loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
+
+ threading.Thread(target=worker, daemon=True).start()
+ return queue
+
+ async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
+ chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
+ if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
+ chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
+
+ use_kwargs = self._clean_param(chat_partial, **kwargs)
+
+ if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
+ txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
+ elif hasattr(self.mdl, "async_chat"):
+ txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
+ else:
+ txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
+
+ txt = self._remove_reasoning_content(txt)
+ if not self.verbose_tool_use:
+ txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
+
+ if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
+ logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
+
+ return txt
+
+ async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
+ total_tokens = 0
+ if self.is_tools and self.mdl.is_tools:
+ stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
+ else:
+ stream_fn = getattr(self.mdl, "async_chat_streamly", None)
+
+ if stream_fn:
+ chat_partial = partial(stream_fn, system, history, gen_conf)
+ use_kwargs = self._clean_param(chat_partial, **kwargs)
+ async for txt in chat_partial(**use_kwargs):
+ if isinstance(txt, int):
+ total_tokens = txt
+ break
+ yield txt
+ if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
+ logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
+ return
+
+ chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
+ use_kwargs = self._clean_param(chat_partial, **kwargs)
+ queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
+ while True:
+ item = await queue.get()
+ if item is StopAsyncIteration:
+ break
+ if isinstance(item, Exception):
+ raise item
+ if isinstance(item, int):
+ total_tokens = item
+ break
+ yield item
+
+ if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
+ logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 897fec65f..1913646a2 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -50,6 +50,7 @@ class SupportedLiteLLMProvider(StrEnum):
GiteeAI = "GiteeAI"
AI_302 = "302.AI"
JiekouAI = "Jiekou.AI"
+ ZHIPU_AI = "ZHIPU-AI"
FACTORY_DEFAULT_BASE_URL = {
@@ -71,6 +72,7 @@ FACTORY_DEFAULT_BASE_URL = {
SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1",
SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/",
SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
+ SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
}
@@ -102,6 +104,7 @@ LITELLM_PROVIDER_PREFIX = {
SupportedLiteLLMProvider.GiteeAI: "openai/",
SupportedLiteLLMProvider.AI_302: "openai/",
SupportedLiteLLMProvider.JiekouAI: "openai/",
+ SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
}
ChatModel = globals().get("ChatModel", {})
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index cce5b2454..e2a5c19cb 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -19,6 +19,7 @@ import logging
import os
import random
import re
+import threading
import time
from abc import ABC
from copy import deepcopy
@@ -28,10 +29,9 @@ import json_repair
import litellm
import openai
import requests
-from openai import OpenAI
+from openai import AsyncOpenAI, OpenAI
from openai.lib.azure import AzureOpenAI
from strenum import StrEnum
-from zhipuai import ZhipuAI
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
from rag.nlp import is_chinese, is_english
@@ -68,6 +68,7 @@ class Base(ABC):
def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
+ self.async_client = AsyncOpenAI(api_key=key, base_url=base_url, timeout=timeout)
self.model_name = model_name
# Configure retry parameters
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
@@ -134,6 +135,23 @@ class Base(ABC):
return gen_conf
+ def _bridge_sync_stream(self, gen):
+ """Run a sync generator in a thread and yield asynchronously."""
+ loop = asyncio.get_running_loop()
+ queue: asyncio.Queue = asyncio.Queue()
+
+ def worker():
+ try:
+ for item in gen:
+ loop.call_soon_threadsafe(queue.put_nowait, item)
+ except Exception as exc: # pragma: no cover - defensive
+ loop.call_soon_threadsafe(queue.put_nowait, exc)
+ finally:
+ loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
+
+ threading.Thread(target=worker, daemon=True).start()
+ return queue
+
def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
@@ -199,6 +217,60 @@ class Base(ABC):
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
+ async def _async_chat_stream(self, history, gen_conf, **kwargs):
+ logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
+ reasoning_start = False
+
+ request_kwargs = {"model": self.model_name, "messages": history, "stream": True, **gen_conf}
+ stop = kwargs.get("stop")
+ if stop:
+ request_kwargs["stop"] = stop
+
+ response = await self.async_client.chat.completions.create(**request_kwargs)
+
+ async for resp in response:
+ if not resp.choices:
+ continue
+ if not resp.choices[0].delta.content:
+ resp.choices[0].delta.content = ""
+ if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
+ ans = ""
+ if not reasoning_start:
+ reasoning_start = True
+ ans = "<think>"
+ ans += resp.choices[0].delta.reasoning_content + "</think>"
+ else:
+ reasoning_start = False
+ ans = resp.choices[0].delta.content
+
+ tol = total_token_count_from_response(resp)
+ if not tol:
+ tol = num_tokens_from_string(resp.choices[0].delta.content)
+
+ finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
+ if finish_reason == "length":
+ if is_chinese(ans):
+ ans += LENGTH_NOTIFICATION_CN
+ else:
+ ans += LENGTH_NOTIFICATION_EN
+ yield ans, tol
+
+ async def async_chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
+ if system and history and history[0].get("role") != "system":
+ history.insert(0, {"role": "system", "content": system})
+ gen_conf = self._clean_conf(gen_conf)
+ ans = ""
+ total_tokens = 0
+ try:
+ async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs):
+ ans = delta_ans
+ total_tokens += tol
+ yield delta_ans
+ except openai.APIError as e:
+ yield ans + "\n**ERROR**: " + str(e)
+
+ yield total_tokens
+
def _length_stop(self, ans):
if is_chinese([ans]):
return ans + LENGTH_NOTIFICATION_CN
@@ -227,7 +299,25 @@ class Base(ABC):
time.sleep(delay)
return None
- return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+ msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+ logging.error(f"sync base giving up: {msg}")
+ return msg
+
+ async def _exceptions_async(self, e, attempt) -> str | None:
+ logging.exception("OpenAI async completion")
+ error_code = self._classify_error(e)
+ if attempt == self.max_retries:
+ error_code = LLMErrorCode.ERROR_MAX_RETRIES
+
+ if self._should_retry(error_code):
+ delay = self._get_delay()
+ logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+ await asyncio.sleep(delay)
+ return None
+
+ msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+ logging.error(f"async base giving up: {msg}")
+ return msg
def _verbose_tool_use(self, name, args, res):
return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
@@ -318,6 +408,60 @@ class Base(ABC):
assert False, "Shouldn't be here."
+ async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
+ gen_conf = self._clean_conf(gen_conf)
+ if system and history and history[0].get("role") != "system":
+ history.insert(0, {"role": "system", "content": system})
+
+ ans = ""
+ tk_count = 0
+ hist = deepcopy(history)
+ for attempt in range(self.max_retries + 1):
+ history = deepcopy(hist)
+ try:
+ for _ in range(self.max_rounds + 1):
+ logging.info(f"{self.tools=}")
+ response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
+ tk_count += total_token_count_from_response(response)
+ if any([not response.choices, not response.choices[0].message]):
+ raise Exception(f"500 response structure error. Response: {response}")
+
+ if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
+ if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
+ ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
+
+ ans += response.choices[0].message.content
+ if response.choices[0].finish_reason == "length":
+ ans = self._length_stop(ans)
+
+ return ans, tk_count
+
+ for tool_call in response.choices[0].message.tool_calls:
+ logging.info(f"Response {tool_call=}")
+ name = tool_call.function.name
+ try:
+ args = json_repair.loads(tool_call.function.arguments)
+ tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+ history = self._append_history(history, tool_call, tool_response)
+ ans += self._verbose_tool_use(name, args, tool_response)
+ except Exception as e:
+ logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
+ history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+ ans += self._verbose_tool_use(name, {}, str(e))
+
+ logging.warning(f"Exceed max rounds: {self.max_rounds}")
+ history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
+ response, token_count = await self._async_chat(history, gen_conf)
+ ans += response
+ tk_count += token_count
+ return ans, tk_count
+ except Exception as e:
+ e = await self._exceptions_async(e, attempt)
+ if e:
+ return e, tk_count
+
+ assert False, "Shouldn't be here."
+
def chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
@@ -452,6 +596,160 @@ class Base(ABC):
assert False, "Shouldn't be here."
+ async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
+ gen_conf = self._clean_conf(gen_conf)
+ tools = self.tools
+ if system and history and history[0].get("role") != "system":
+ history.insert(0, {"role": "system", "content": system})
+
+ total_tokens = 0
+ hist = deepcopy(history)
+
+ for attempt in range(self.max_retries + 1):
+ history = deepcopy(hist)
+ try:
+ for _ in range(self.max_rounds + 1):
+ reasoning_start = False
+ logging.info(f"{tools=}")
+
+ response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
+
+ final_tool_calls = {}
+ answer = ""
+
+ async for resp in response:
+ if not hasattr(resp, "choices") or not resp.choices:
+ continue
+
+ delta = resp.choices[0].delta
+
+ if hasattr(delta, "tool_calls") and delta.tool_calls:
+ for tool_call in delta.tool_calls:
+ index = tool_call.index
+ if index not in final_tool_calls:
+ if not tool_call.function.arguments:
+ tool_call.function.arguments = ""
+ final_tool_calls[index] = tool_call
+ else:
+ final_tool_calls[index].function.arguments += tool_call.function.arguments or ""
+ continue
+
+ if not hasattr(delta, "content") or delta.content is None:
+ delta.content = ""
+
+ if hasattr(delta, "reasoning_content") and delta.reasoning_content:
+ ans = ""
+ if not reasoning_start:
+ reasoning_start = True
+ ans = "<think>"
+ ans += delta.reasoning_content + "</think>"
+ yield ans
+ else:
+ reasoning_start = False
+ answer += delta.content
+ yield delta.content
+
+ tol = total_token_count_from_response(resp)
+ if not tol:
+ total_tokens += num_tokens_from_string(delta.content)
+ else:
+ total_tokens = tol
+
+ finish_reason = getattr(resp.choices[0], "finish_reason", "")
+ if finish_reason == "length":
+ yield self._length_stop("")
+
+ if answer:
+ yield total_tokens
+ return
+
+ for tool_call in final_tool_calls.values():
+ name = tool_call.function.name
+ try:
+ args = json_repair.loads(tool_call.function.arguments)
+ yield self._verbose_tool_use(name, args, "Begin to call...")
+ tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+ history = self._append_history(history, tool_call, tool_response)
+ yield self._verbose_tool_use(name, args, tool_response)
+ except Exception as e:
+ logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
+ history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+ yield self._verbose_tool_use(name, {}, str(e))
+
+ logging.warning(f"Exceed max rounds: {self.max_rounds}")
+ history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
+
+ response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
+
+ async for resp in response:
+ if not hasattr(resp, "choices") or not resp.choices:
+ continue
+ delta = resp.choices[0].delta
+ if not hasattr(delta, "content") or delta.content is None:
+ continue
+ tol = total_token_count_from_response(resp)
+ if not tol:
+ total_tokens += num_tokens_from_string(delta.content)
+ else:
+ total_tokens = tol
+ yield delta.content
+
+ yield total_tokens
+ return
+
+ except Exception as e:
+ e = await self._exceptions_async(e, attempt)
+ if e:
+ logging.error(f"async_chat_streamly failed: {e}")
+ yield e
+ yield total_tokens
+ return
+
+ assert False, "Shouldn't be here."
+
+ async def _async_chat(self, history, gen_conf, **kwargs):
+ logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
+ if self.model_name.lower().find("qwq") >= 0:
+ logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
+ final_ans = ""
+ tol_token = 0
+ async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs):
+ if delta.startswith("<think>") or delta.endswith("</think>"):
+ continue
+ final_ans += delta
+ tol_token = tol
+
+ if len(final_ans.strip()) == 0:
+ final_ans = "**ERROR**: Empty response from reasoning model"
+
+ return final_ans.strip(), tol_token
+
+ if self.model_name.lower().find("qwen3") >= 0:
+ kwargs["extra_body"] = {"enable_thinking": False}
+
+ response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
+
+ if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
+ return "", 0
+ ans = response.choices[0].message.content.strip()
+ if response.choices[0].finish_reason == "length":
+ ans = self._length_stop(ans)
+ return ans, total_token_count_from_response(response)
+
+ async def async_chat(self, system, history, gen_conf={}, **kwargs):
+ if system and history and history[0].get("role") != "system":
+ history.insert(0, {"role": "system", "content": system})
+ gen_conf = self._clean_conf(gen_conf)
+
+ for attempt in range(self.max_retries + 1):
+ try:
+ return await self._async_chat(history, gen_conf, **kwargs)
+ except Exception as e:
+ e = await self._exceptions_async(e, attempt)
+ if e:
+ return e, 0
+ assert False, "Shouldn't be here."
+
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
@@ -637,66 +935,6 @@ class BaiChuanChat(Base):
yield total_tokens
class ZhipuChat(Base):
    """Chat wrapper for the ZHIPU-AI (GLM) OpenAI-compatible API.

    NOTE(review): this class is being removed by the surrounding patch —
    "ZHIPU-AI" is re-routed to the LiteLLM-backed factory list instead.
    """
    _FACTORY_NAME = "ZHIPU-AI"

    def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def _clean_conf(self, gen_conf):
        # ZHIPU-AI rejects "max_tokens" and the penalty knobs; strip them
        # in place before forwarding the config to the client.
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        gen_conf = self._clean_conf_plealty(gen_conf)
        return gen_conf

    def _clean_conf_plealty(self, gen_conf):
        # NOTE(review): "plealty" is a typo for "penalty"; name kept as-is
        # since this code is scheduled for deletion.
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        return gen_conf

    def chat_with_tools(self, system: str, history: list, gen_conf: dict):
        # Only the penalty keys are stripped here (not max_tokens) before
        # delegating to the base implementation.
        gen_conf = self._clean_conf_plealty(gen_conf)

        return super().chat_with_tools(system, history, gen_conf)

    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        # NOTE(review): mutable default `gen_conf={}` plus in-place cleaning is
        # a shared-state hazard; left unchanged in this deleted code.
        if system and history and history[0].get("role") != "system":
            history.insert(0, {"role": "system", "content": system})
        gen_conf = self._clean_conf(gen_conf)
        ans = ""
        tk_count = 0
        try:
            logging.info(json.dumps(history, ensure_ascii=False, indent=2))
            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                # Each yield carries only the latest delta, not the running
                # concatenation — callers accumulate the chunks themselves.
                ans = delta
                if resp.choices[0].finish_reason == "length":
                    # Truncated by max length: append the localized notice.
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                    tk_count = total_token_count_from_response(resp)
                if resp.choices[0].finish_reason == "stop":
                    tk_count = total_token_count_from_response(resp)
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        # Final yield is the integer token count (stream protocol used
        # throughout this module: str chunks, then one int).
        yield tk_count

    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
        gen_conf = self._clean_conf_plealty(gen_conf)
        return super().chat_streamly_with_tools(system, history, gen_conf)
-
class LocalAIChat(Base):
_FACTORY_NAME = "LocalAI"
@@ -1398,6 +1636,7 @@ class LiteLLMBase(ABC):
"GiteeAI",
"302.AI",
"Jiekou.AI",
+ "ZHIPU-AI",
]
def __init__(self, key, model_name, base_url=None, **kwargs):
@@ -1477,6 +1716,7 @@ class LiteLLMBase(ABC):
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
+ gen_conf = self._clean_conf(gen_conf)
reasoning_start = False
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
@@ -1520,6 +1760,96 @@ class LiteLLMBase(ABC):
yield ans, tol
async def async_chat(self, history, gen_conf, **kwargs):
    """Non-streaming async completion through litellm, with retries.

    Returns:
        tuple[str, int]: (answer or error message, total token count).
    """
    logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))

    completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)

    # BUG FIX: previously `kwargs["extra_body"] = {...}` was set for qwen3 but
    # `kwargs` was never forwarded to `litellm.acompletion`, so the option was
    # a dead store and "thinking" was never actually disabled.
    if self.model_name.lower().find("qwen3") >= 0:
        completion_args["extra_body"] = {"enable_thinking": False}

    for attempt in range(self.max_retries + 1):
        try:
            response = await litellm.acompletion(
                **completion_args,
                drop_params=True,
                timeout=self.timeout,
            )

            if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
                return "", 0
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                # Append the truncation notice when the reply hit the limit.
                ans = self._length_stop(ans)

            return ans, total_token_count_from_response(response)
        except Exception as e:
            # Returns None to retry (after sleeping), or the error message.
            e = await self._exceptions_async(e, attempt)
            if e:
                return e, 0

    raise AssertionError("Shouldn't be here.")
+
async def async_chat_streamly(self, system, history, gen_conf, **kwargs):
    """Streaming async completion through litellm.

    Yields str chunks (reasoning content wrapped in think tags when
    ``with_reasoning`` is truthy), then a final int: the total token count.
    """
    if system and history and history[0].get("role") != "system":
        history.insert(0, {"role": "system", "content": system})
    logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
    gen_conf = self._clean_conf(gen_conf)
    reasoning_start = False
    total_tokens = 0

    completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
    stop = kwargs.get("stop")
    if stop:
        completion_args["stop"] = stop

    for attempt in range(self.max_retries + 1):
        try:
            stream = await litellm.acompletion(
                **completion_args,
                drop_params=True,
                timeout=self.timeout,
            )

            async for resp in stream:
                if not hasattr(resp, "choices") or not resp.choices:
                    continue

                delta = resp.choices[0].delta
                if not hasattr(delta, "content") or delta.content is None:
                    delta.content = ""

                if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
                    # BUG FIX: the tag literals had been reduced to empty
                    # strings, so reasoning chunks streamed out unmarked.
                    # Restored the think-tag wrapping — TODO confirm exact tags.
                    ans = ""
                    if not reasoning_start:
                        reasoning_start = True
                        ans = "<think>"
                    ans += delta.reasoning_content + "</think>"
                else:
                    reasoning_start = False
                    ans = delta.content

                tol = total_token_count_from_response(resp)
                if not tol:
                    # Provider gave no usage info; estimate from the chunk.
                    tol = num_tokens_from_string(delta.content)
                total_tokens += tol

                finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
                if finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN

                yield ans

            # Stream exhausted cleanly: emit the token total and stop.
            yield total_tokens
            return
        except Exception as e:
            e = await self._exceptions_async(e, attempt)
            if e:
                yield e
                yield total_tokens
                return
+
def _length_stop(self, ans):
if is_chinese([ans]):
return ans + LENGTH_NOTIFICATION_CN
@@ -1550,6 +1880,21 @@ class LiteLLMBase(ABC):
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+ async def _exceptions_async(self, e, attempt) -> str | None:
+ logging.exception("LiteLLMBase async completion")
+ error_code = self._classify_error(e)
+ if attempt == self.max_retries:
+ error_code = LLMErrorCode.ERROR_MAX_RETRIES
+
+ if self._should_retry(error_code):
+ delay = self._get_delay()
+ logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+ await asyncio.sleep(delay)
+ return None
+ msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+ logging.error(f"async_chat_streamly giving up: {msg}")
+ return msg
+
def _verbose_tool_use(self, name, args, res):
return "" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + ""