chats and agents

parent a4ff443d0b
commit 8b894a29c4

7 changed files with 654 additions and 120 deletions

agent/canvas.py (109 changed lines)
@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import base64
+import inspect
 import json
 import logging
 import re
@@ -79,6 +81,7 @@ class Graph:
         self.dsl = json.loads(dsl)
         self._tenant_id = tenant_id
         self.task_id = task_id if task_id else get_uuid()
+        self._thread_pool = ThreadPoolExecutor(max_workers=5)
         self.load()

     def load(self):
@@ -328,6 +331,7 @@ class Canvas(Graph):

     async def run(self, **kwargs):
         st = time.perf_counter()
+        self._loop = asyncio.get_running_loop()
         self.message_id = get_uuid()
         created_at = int(time.time())
         self.add_user_input(kwargs.get("query"))
@@ -343,7 +347,7 @@ class Canvas(Graph):
         for k in kwargs.keys():
             if k in ["query", "user_id", "files"] and kwargs[k]:
                 if k == "files":
-                    self.globals[f"sys.{k}"] = self.get_files(kwargs[k])
+                    self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k])
                 else:
                     self.globals[f"sys.{k}"] = kwargs[k]
         if not self.globals["sys.conversation_turns"] :
@@ -373,31 +377,39 @@ class Canvas(Graph):
         yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
         self.retrieval.append({"chunks": {}, "doc_aggs": {}})

-        def _run_batch(f, t):
+        async def _run_batch(f, t):
             if self.is_canceled():
                 msg = f"Task {self.task_id} has been canceled during batch execution."
                 logging.info(msg)
                 raise TaskCanceledException(msg)

-            with ThreadPoolExecutor(max_workers=5) as executor:
-                thr = []
-                i = f
-                while i < t:
-                    cpn = self.get_component_obj(self.path[i])
-                    if cpn.component_name.lower() in ["begin", "userfillup"]:
-                        thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
-                        i += 1
-                    else:
-                        for _, ele in cpn.get_input_elements().items():
-                            if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
-                                self.path.pop(i)
-                                t -= 1
-                                break
-                        else:
-                            thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
-                            i += 1
-                for t in thr:
-                    t.result()
+            loop = asyncio.get_running_loop()
+            tasks = []
+            i = f
+            while i < t:
+                cpn = self.get_component_obj(self.path[i])
+                task_fn = None
+                if cpn.component_name.lower() in ["begin", "userfillup"]:
+                    task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {}))
+                    i += 1
+                else:
+                    for _, ele in cpn.get_input_elements().items():
+                        if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
+                            self.path.pop(i)
+                            t -= 1
+                            break
+                    else:
+                        task_fn = partial(cpn.invoke, **cpn.get_input())
+                        i += 1
+
+                if task_fn is None:
+                    continue
+
+                tasks.append(loop.run_in_executor(self._thread_pool, task_fn))
+
+            if tasks:
+                await asyncio.gather(*tasks)

         def _node_finished(cpn_obj):
             return decorate("node_finished",{
@@ -424,7 +436,7 @@ class Canvas(Graph):
                 "component_type": self.get_component_type(self.path[i]),
                 "thoughts": self.get_component_thoughts(self.path[i])
             })
-        _run_batch(idx, to)
+        await _run_batch(idx, to)
         to = len(self.path)
         # post processing of components invocation
         for i in range(idx, to):
@@ -433,16 +445,29 @@ class Canvas(Graph):
            if cpn_obj.component_name.lower() == "message":
                if isinstance(cpn_obj.output("content"), partial):
                    _m = ""
-                   for m in cpn_obj.output("content")():
-                       if not m:
-                           continue
-                       if m == "<think>":
-                           yield decorate("message", {"content": "", "start_to_think": True})
-                       elif m == "</think>":
-                           yield decorate("message", {"content": "", "end_to_think": True})
-                       else:
-                           yield decorate("message", {"content": m})
-                           _m += m
+                   stream = cpn_obj.output("content")()
+                   if inspect.isasyncgen(stream):
+                       async for m in stream:
+                           if not m:
+                               continue
+                           if m == "<think>":
+                               yield decorate("message", {"content": "", "start_to_think": True})
+                           elif m == "</think>":
+                               yield decorate("message", {"content": "", "end_to_think": True})
+                           else:
+                               yield decorate("message", {"content": m})
+                               _m += m
+                   else:
+                       for m in stream:
+                           if not m:
+                               continue
+                           if m == "<think>":
+                               yield decorate("message", {"content": "", "start_to_think": True})
+                           elif m == "</think>":
+                               yield decorate("message", {"content": "", "end_to_think": True})
+                           else:
+                               yield decorate("message", {"content": m})
+                               _m += m
                    cpn_obj.set_output("content", _m)
                    cite = re.search(r"\[ID:[ 0-9]+\]", _m)
            else:
@@ -590,21 +615,31 @@ class Canvas(Graph):
     def get_component_input_elements(self, cpnnm):
         return self.components[cpnnm]["obj"].get_input_elements()

-    def get_files(self, files: Union[None, list[dict]]) -> list[str]:
+    async def get_files_async(self, files: Union[None, list[dict]]) -> list[str]:
         from api.db.services.file_service import FileService
         if not files:
             return []
         def image_to_base64(file):
             return "data:{};base64,{}".format(file["mime_type"],
                                               base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
-        exe = ThreadPoolExecutor(max_workers=5)
-        threads = []
+        loop = asyncio.get_running_loop()
+        tasks = []
         for file in files:
             if file["mime_type"].find("image") >=0:
-                threads.append(exe.submit(image_to_base64, file))
+                tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
                 continue
-            threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
-        return [th.result() for th in threads]
+            tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
+        return await asyncio.gather(*tasks)
+
+    def get_files(self, files: Union[None, list[dict]]) -> list[str]:
+        """
+        Synchronous wrapper for get_files_async, used by sync component invoke paths.
+        """
+        loop = getattr(self, "_loop", None)
+        if loop and loop.is_running():
+            return asyncio.run_coroutine_threadsafe(self.get_files_async(files), loop).result()
+
+        return asyncio.run(self.get_files_async(files))

     def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
         agent_ids = agent_id.split("-->")
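Note (not part of the commit): the reworked _run_batch above keeps each component.invoke call synchronous but schedules it on the canvas thread pool and awaits the whole batch. A minimal, self-contained sketch of that pattern, with illustrative names (fake_invoke, pool) that are not from this commit:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    def fake_invoke(name, **inputs):
        # stand-in for a blocking component.invoke call
        return f"{name}: {inputs}"

    async def run_batch(names):
        loop = asyncio.get_running_loop()
        pool = ThreadPoolExecutor(max_workers=5)
        # schedule every blocking call on the pool, then wait for all of them together
        tasks = [loop.run_in_executor(pool, partial(fake_invoke, n, inputs={})) for n in names]
        return await asyncio.gather(*tasks)

    # asyncio.run(run_batch(["begin", "retrieval", "message"]))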
@@ -205,6 +205,55 @@ class LLM(ComponentBase):
         for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
             yield delta(txt)

+    async def _stream_output_async(self, prompt, msg):
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+        answer = ""
+        last_idx = 0
+        endswith_think = False
+
+        def delta(txt):
+            nonlocal answer, last_idx, endswith_think
+            delta_ans = txt[last_idx:]
+            answer = txt
+
+            if delta_ans.find("<think>") == 0:
+                last_idx += len("<think>")
+                return "<think>"
+            elif delta_ans.find("<think>") > 0:
+                delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
+                last_idx += delta_ans.find("<think>")
+                return delta_ans
+            elif delta_ans.endswith("</think>"):
+                endswith_think = True
+            elif endswith_think:
+                endswith_think = False
+                return "</think>"
+
+            last_idx = len(answer)
+            if answer.endswith("</think>"):
+                last_idx -= len("</think>")
+            return re.sub(r"(<think>|</think>)", "", delta_ans)
+
+        stream_kwargs = {"images": self.imgs} if self.imgs else {}
+        async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs):
+            if self.check_if_canceled("LLM streaming"):
+                return
+
+            if isinstance(ans, int):
+                continue
+
+            if ans.find("**ERROR**") >= 0:
+                if self.get_exception_default_value():
+                    self.set_output("content", self.get_exception_default_value())
+                    yield self.get_exception_default_value()
+                else:
+                    self.set_output("_ERROR", ans)
+                return
+
+            yield delta(ans)
+
+        self.set_output("content", answer)
+
     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         if self.check_if_canceled("LLM processing"):
@@ -250,7 +299,7 @@ class LLM(ComponentBase):
         downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
         ex = self.exception_handler()
         if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
-            self.set_output("content", partial(self._stream_output, prompt, msg))
+            self.set_output("content", partial(self._stream_output_async, prompt, msg))
             return

         for _ in range(self._param.max_retries+1):
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
+import inspect
 import json
 import os
 import random
@@ -66,8 +68,12 @@ class Message(ComponentBase):
                 v = ""
             ans = ""
             if isinstance(v, partial):
-                for t in v():
-                    ans += t
+                iter_obj = v()
+                if inspect.isasyncgen(iter_obj):
+                    ans = asyncio.run(self._consume_async_gen(iter_obj))
+                else:
+                    for t in iter_obj:
+                        ans += t
             elif isinstance(v, list) and delimiter:
                 ans = delimiter.join([str(vv) for vv in v])
             elif not isinstance(v, str):
@@ -89,7 +95,13 @@ class Message(ComponentBase):
             _kwargs[_n] = v
         return script, _kwargs

-    def _stream(self, rand_cnt:str):
+    async def _consume_async_gen(self, agen):
+        buf = ""
+        async for t in agen:
+            buf += t
+        return buf
+
+    async def _stream(self, rand_cnt:str):
         s = 0
         all_content = ""
         cache = {}
@@ -111,15 +123,27 @@ class Message(ComponentBase):
                     v = ""
                 if isinstance(v, partial):
                     cnt = ""
-                    for t in v():
-                        if self.check_if_canceled("Message streaming"):
-                            return
-                        all_content += t
-                        cnt += t
-                        yield t
+                    iter_obj = v()
+                    if inspect.isasyncgen(iter_obj):
+                        async for t in iter_obj:
+                            if self.check_if_canceled("Message streaming"):
+                                return
+
+                            all_content += t
+                            cnt += t
+                            yield t
+                    else:
+                        for t in iter_obj:
+                            if self.check_if_canceled("Message streaming"):
+                                return
+
+                            all_content += t
+                            cnt += t
+                            yield t
                     self.set_input_value(exp, cnt)
                     continue
+                elif inspect.isawaitable(v):
+                    v = await v
                 elif not isinstance(v, str):
                     try:
                         v = json.dumps(v, ensure_ascii=False)
@@ -181,7 +205,7 @@ class Message(ComponentBase):

         import pypandoc
         doc_id = get_uuid()

         if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}:
             self._param.output_format = "markdown"
@@ -231,11 +255,11 @@ class Message(ComponentBase):

             settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content)
             self.set_output("attachment", {
                 "doc_id":doc_id,
                 "format":self._param.output_format,
                 "file_name":f"{doc_id[:8]}.{self._param.output_format}"})

             logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})")

         except Exception as e:
             logging.error(f"Error converting content to {self._param.output_format}: {e}")
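Note (not part of the commit): both Canvas.run and the Message component now branch on inspect.isasyncgen, because an upstream output("content") partial may produce either a plain generator or an async generator. A minimal sketch of that branching, with illustrative generators:

    import asyncio
    import inspect

    def sync_chunks():
        yield from ("hel", "lo")

    async def async_chunks():
        yield "hel"
        yield "lo"

    async def consume(factory):
        stream = factory()
        out = ""
        if inspect.isasyncgen(stream):
            async for t in stream:   # async generator: must be iterated with async for
                out += t
        else:
            for t in stream:         # plain generator: normal iteration
                out += t
        return out

    # asyncio.run(consume(sync_chunks)) == asyncio.run(consume(async_chunks)) == "hello"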
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 import re
@@ -134,12 +135,12 @@ async def run():
     files = req.get("files", [])
     inputs = req.get("inputs", {})
     user_id = req.get("user_id", current_user.id)
-    if not UserCanvasService.accessible(req["id"], current_user.id):
+    if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
         return get_json_result(
             data=False, message='Only owner of canvas authorized for this operation.',
             code=RetCode.OPERATING_ERROR)

-    e, cvs = UserCanvasService.get_by_id(req["id"])
+    e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
     if not e:
         return get_data_error_result(message="canvas not found.")
@@ -149,7 +150,7 @@ async def run():
     if cvs.canvas_category == CanvasCategory.DataFlow:
         task_id = get_uuid()
         Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
-        ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
+        ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, files[0], 0)
         if not ok:
             return get_data_error_result(message=error_message)
         return get_json_result(data={"message_id": task_id})
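Note (not part of the commit): the handler changes above move blocking service calls (UserCanvasService.*, queue_dataflow) off the event loop with asyncio.to_thread. A minimal sketch of the idea, using a hypothetical blocking_lookup in place of the real services:

    import asyncio
    import time

    def blocking_lookup(canvas_id):
        time.sleep(0.1)          # stands in for a synchronous DB/service call
        return {"id": canvas_id}

    async def handler(canvas_id):
        # run the blocking call in a worker thread so the event loop stays responsive
        return await asyncio.to_thread(blocking_lookup, canvas_id)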
@@ -13,9 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import inspect
 import logging
 import re
+import threading
 from common.token_utils import num_tokens_from_string
 from functools import partial
 from typing import Generator
@@ -242,7 +244,7 @@ class LLMBundle(LLM4Tenant):
         if not self.verbose_tool_use:
             txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

-        if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
+        if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
             logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))

         if self.langfuse:
@@ -279,5 +281,80 @@ class LLMBundle(LLM4Tenant):
             yield ans

         if total_tokens > 0:
-            if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
-                logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
+            if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
+                logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
+
+    def _bridge_sync_stream(self, gen):
+        loop = asyncio.get_running_loop()
+        queue: asyncio.Queue = asyncio.Queue()
+
+        def worker():
+            try:
+                for item in gen:
+                    loop.call_soon_threadsafe(queue.put_nowait, item)
+            except Exception as e:  # pragma: no cover
+                loop.call_soon_threadsafe(queue.put_nowait, e)
+            finally:
+                loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
+
+        threading.Thread(target=worker, daemon=True).start()
+        return queue
+
+    async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
+        chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
+        if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
+            chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
+
+        use_kwargs = self._clean_param(chat_partial, **kwargs)
+
+        if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
+            txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
+        elif hasattr(self.mdl, "async_chat"):
+            txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
+        else:
+            txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
+
+        txt = self._remove_reasoning_content(txt)
+        if not self.verbose_tool_use:
+            txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
+
+        if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
+            logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
+
+        return txt
+
+    async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
+        total_tokens = 0
+        if self.is_tools and self.mdl.is_tools:
+            stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
+        else:
+            stream_fn = getattr(self.mdl, "async_chat_streamly", None)
+
+        if stream_fn:
+            chat_partial = partial(stream_fn, system, history, gen_conf)
+            use_kwargs = self._clean_param(chat_partial, **kwargs)
+            async for txt in chat_partial(**use_kwargs):
+                if isinstance(txt, int):
+                    total_tokens = txt
+                    break
+                yield txt
+            if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
+                logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
+            return
+
+        chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
+        use_kwargs = self._clean_param(chat_partial, **kwargs)
+        queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
+        while True:
+            item = await queue.get()
+            if item is StopAsyncIteration:
+                break
+            if isinstance(item, Exception):
+                raise item
+            if isinstance(item, int):
+                total_tokens = item
+                break
+            yield item
+
+        if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
+            logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
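Note (not part of the commit): _bridge_sync_stream, added both here and in the chat model base class, feeds a synchronous generator into an asyncio.Queue from a worker thread so async callers can consume it without blocking the event loop. A standalone sketch of that bridge using only the standard library:

    import asyncio
    import threading

    def sync_gen():
        yield from ("a", "b", "c")   # stands in for chat_streamly output

    async def consume_bridged(gen):
        loop = asyncio.get_running_loop()
        q: asyncio.Queue = asyncio.Queue()

        def worker():
            try:
                for item in gen:
                    loop.call_soon_threadsafe(q.put_nowait, item)
            finally:
                # sentinel telling the async side the generator is exhausted
                loop.call_soon_threadsafe(q.put_nowait, StopAsyncIteration)

        threading.Thread(target=worker, daemon=True).start()
        out = []
        while True:
            item = await q.get()
            if item is StopAsyncIteration:
                break
            out.append(item)
        return out

    # asyncio.run(consume_bridged(sync_gen())) -> ["a", "b", "c"]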
@@ -50,6 +50,7 @@ class SupportedLiteLLMProvider(StrEnum):
     GiteeAI = "GiteeAI"
     AI_302 = "302.AI"
     JiekouAI = "Jiekou.AI"
+    ZHIPU_AI = "ZHIPU-AI"


 FACTORY_DEFAULT_BASE_URL = {
@@ -71,6 +72,7 @@ FACTORY_DEFAULT_BASE_URL = {
     SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1",
     SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/",
     SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
+    SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
 }

@@ -102,6 +104,7 @@ LITELLM_PROVIDER_PREFIX = {
     SupportedLiteLLMProvider.GiteeAI: "openai/",
     SupportedLiteLLMProvider.AI_302: "openai/",
     SupportedLiteLLMProvider.JiekouAI: "openai/",
+    SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
 }

 ChatModel = globals().get("ChatModel", {})
@@ -19,6 +19,7 @@ import logging
 import os
 import random
 import re
+import threading
 import time
 from abc import ABC
 from copy import deepcopy
@@ -28,10 +29,9 @@ import json_repair
 import litellm
 import openai
 import requests
-from openai import OpenAI
+from openai import AsyncOpenAI, OpenAI
 from openai.lib.azure import AzureOpenAI
 from strenum import StrEnum
-from zhipuai import ZhipuAI

 from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
 from rag.nlp import is_chinese, is_english
@@ -68,6 +68,7 @@ class Base(ABC):
     def __init__(self, key, model_name, base_url, **kwargs):
         timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
         self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
+        self.async_client = AsyncOpenAI(api_key=key, base_url=base_url, timeout=timeout)
         self.model_name = model_name
         # Configure retry parameters
         self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
@@ -134,6 +135,23 @@ class Base(ABC):

         return gen_conf

+    def _bridge_sync_stream(self, gen):
+        """Run a sync generator in a thread and yield asynchronously."""
+        loop = asyncio.get_running_loop()
+        queue: asyncio.Queue = asyncio.Queue()
+
+        def worker():
+            try:
+                for item in gen:
+                    loop.call_soon_threadsafe(queue.put_nowait, item)
+            except Exception as exc:  # pragma: no cover - defensive
+                loop.call_soon_threadsafe(queue.put_nowait, exc)
+            finally:
+                loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
+
+        threading.Thread(target=worker, daemon=True).start()
+        return queue
+
     def _chat(self, history, gen_conf, **kwargs):
         logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
         if self.model_name.lower().find("qwq") >= 0:
@@ -199,6 +217,60 @@ class Base(ABC):
                 ans += LENGTH_NOTIFICATION_EN
             yield ans, tol

+    async def _async_chat_stream(self, history, gen_conf, **kwargs):
+        logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
+        reasoning_start = False
+
+        request_kwargs = {"model": self.model_name, "messages": history, "stream": True, **gen_conf}
+        stop = kwargs.get("stop")
+        if stop:
+            request_kwargs["stop"] = stop
+
+        response = await self.async_client.chat.completions.create(**request_kwargs)
+
+        async for resp in response:
+            if not resp.choices:
+                continue
+            if not resp.choices[0].delta.content:
+                resp.choices[0].delta.content = ""
+            if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
+                ans = ""
+                if not reasoning_start:
+                    reasoning_start = True
+                    ans = "<think>"
+                ans += resp.choices[0].delta.reasoning_content + "</think>"
+            else:
+                reasoning_start = False
+                ans = resp.choices[0].delta.content
+
+            tol = total_token_count_from_response(resp)
+            if not tol:
+                tol = num_tokens_from_string(resp.choices[0].delta.content)
+
+            finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
+            if finish_reason == "length":
+                if is_chinese(ans):
+                    ans += LENGTH_NOTIFICATION_CN
+                else:
+                    ans += LENGTH_NOTIFICATION_EN
+            yield ans, tol
+
+    async def async_chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
+        if system and history and history[0].get("role") != "system":
+            history.insert(0, {"role": "system", "content": system})
+        gen_conf = self._clean_conf(gen_conf)
+        ans = ""
+        total_tokens = 0
+        try:
+            async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs):
+                ans = delta_ans
+                total_tokens += tol
+                yield delta_ans
+        except openai.APIError as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
+
     def _length_stop(self, ans):
         if is_chinese([ans]):
             return ans + LENGTH_NOTIFICATION_CN
@@ -227,7 +299,25 @@
                 time.sleep(delay)
             return None

-        return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+        msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+        logging.error(f"sync base giving up: {msg}")
+        return msg
+
+    async def _exceptions_async(self, e, attempt) -> str | None:
+        logging.exception("OpenAI async completion")
+        error_code = self._classify_error(e)
+        if attempt == self.max_retries:
+            error_code = LLMErrorCode.ERROR_MAX_RETRIES
+
+        if self._should_retry(error_code):
+            delay = self._get_delay()
+            logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+            await asyncio.sleep(delay)
+            return None
+
+        msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+        logging.error(f"async base giving up: {msg}")
+        return msg

     def _verbose_tool_use(self, name, args, res):
         return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
@@ -318,6 +408,60 @@

         assert False, "Shouldn't be here."

+    async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
+        gen_conf = self._clean_conf(gen_conf)
+        if system and history and history[0].get("role") != "system":
+            history.insert(0, {"role": "system", "content": system})
+
+        ans = ""
+        tk_count = 0
+        hist = deepcopy(history)
+        for attempt in range(self.max_retries + 1):
+            history = deepcopy(hist)
+            try:
+                for _ in range(self.max_rounds + 1):
+                    logging.info(f"{self.tools=}")
+                    response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
+                    tk_count += total_token_count_from_response(response)
+                    if any([not response.choices, not response.choices[0].message]):
+                        raise Exception(f"500 response structure error. Response: {response}")
+
+                    if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
+                        if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
+                            ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
+
+                        ans += response.choices[0].message.content
+                        if response.choices[0].finish_reason == "length":
+                            ans = self._length_stop(ans)
+
+                        return ans, tk_count
+
+                    for tool_call in response.choices[0].message.tool_calls:
+                        logging.info(f"Response {tool_call=}")
+                        name = tool_call.function.name
+                        try:
+                            args = json_repair.loads(tool_call.function.arguments)
+                            tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                            history = self._append_history(history, tool_call, tool_response)
+                            ans += self._verbose_tool_use(name, args, tool_response)
+                        except Exception as e:
+                            logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
+                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+                            ans += self._verbose_tool_use(name, {}, str(e))
+
+                logging.warning(f"Exceed max rounds: {self.max_rounds}")
+                history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
+                response, token_count = await self._async_chat(history, gen_conf)
+                ans += response
+                tk_count += token_count
+                return ans, tk_count
+            except Exception as e:
+                e = await self._exceptions_async(e, attempt)
+                if e:
+                    return e, tk_count
+
+        assert False, "Shouldn't be here."
+
     def chat(self, system, history, gen_conf={}, **kwargs):
         if system and history and history[0].get("role") != "system":
             history.insert(0, {"role": "system", "content": system})
@@ -452,6 +596,160 @@

         assert False, "Shouldn't be here."

+    async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
+        gen_conf = self._clean_conf(gen_conf)
+        tools = self.tools
+        if system and history and history[0].get("role") != "system":
+            history.insert(0, {"role": "system", "content": system})
+
+        total_tokens = 0
+        hist = deepcopy(history)
+
+        for attempt in range(self.max_retries + 1):
+            history = deepcopy(hist)
+            try:
+                for _ in range(self.max_rounds + 1):
+                    reasoning_start = False
+                    logging.info(f"{tools=}")
+
+                    response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
+
+                    final_tool_calls = {}
+                    answer = ""
+
+                    async for resp in response:
+                        if not hasattr(resp, "choices") or not resp.choices:
+                            continue
+
+                        delta = resp.choices[0].delta
+
+                        if hasattr(delta, "tool_calls") and delta.tool_calls:
+                            for tool_call in delta.tool_calls:
+                                index = tool_call.index
+                                if index not in final_tool_calls:
+                                    if not tool_call.function.arguments:
+                                        tool_call.function.arguments = ""
+                                    final_tool_calls[index] = tool_call
+                                else:
+                                    final_tool_calls[index].function.arguments += tool_call.function.arguments or ""
+                            continue
+
+                        if not hasattr(delta, "content") or delta.content is None:
+                            delta.content = ""
+
+                        if hasattr(delta, "reasoning_content") and delta.reasoning_content:
+                            ans = ""
+                            if not reasoning_start:
+                                reasoning_start = True
+                                ans = "<think>"
+                            ans += delta.reasoning_content + "</think>"
+                            yield ans
+                        else:
+                            reasoning_start = False
+                            answer += delta.content
+                            yield delta.content
+
+                        tol = total_token_count_from_response(resp)
+                        if not tol:
+                            total_tokens += num_tokens_from_string(delta.content)
+                        else:
+                            total_tokens = tol
+
+                        finish_reason = getattr(resp.choices[0], "finish_reason", "")
+                        if finish_reason == "length":
+                            yield self._length_stop("")
+
+                    if answer:
+                        yield total_tokens
+                        return
+
+                    for tool_call in final_tool_calls.values():
+                        name = tool_call.function.name
+                        try:
+                            args = json_repair.loads(tool_call.function.arguments)
+                            yield self._verbose_tool_use(name, args, "Begin to call...")
+                            tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                            history = self._append_history(history, tool_call, tool_response)
+                            yield self._verbose_tool_use(name, args, tool_response)
+                        except Exception as e:
+                            logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
+                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+                            yield self._verbose_tool_use(name, {}, str(e))
+
+                logging.warning(f"Exceed max rounds: {self.max_rounds}")
+                history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
+
+                response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
+
+                async for resp in response:
+                    if not hasattr(resp, "choices") or not resp.choices:
+                        continue
+                    delta = resp.choices[0].delta
+                    if not hasattr(delta, "content") or delta.content is None:
+                        continue
+                    tol = total_token_count_from_response(resp)
+                    if not tol:
+                        total_tokens += num_tokens_from_string(delta.content)
+                    else:
+                        total_tokens = tol
+                    yield delta.content
+
+                yield total_tokens
+                return
+
+            except Exception as e:
+                e = await self._exceptions_async(e, attempt)
+                if e:
+                    logging.error(f"async_chat_streamly failed: {e}")
+                    yield e
+                    yield total_tokens
+                    return
+
+        assert False, "Shouldn't be here."
+
+    async def _async_chat(self, history, gen_conf, **kwargs):
+        logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
+        if self.model_name.lower().find("qwq") >= 0:
+            logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
+            final_ans = ""
+            tol_token = 0
+            async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs):
+                if delta.startswith("<think>") or delta.endswith("</think>"):
+                    continue
+                final_ans += delta
+                tol_token = tol
+
+            if len(final_ans.strip()) == 0:
+                final_ans = "**ERROR**: Empty response from reasoning model"
+
+            return final_ans.strip(), tol_token
+
+        if self.model_name.lower().find("qwen3") >= 0:
+            kwargs["extra_body"] = {"enable_thinking": False}
+
+        response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
+
+        if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
+            return "", 0
+        ans = response.choices[0].message.content.strip()
+        if response.choices[0].finish_reason == "length":
+            ans = self._length_stop(ans)
+        return ans, total_token_count_from_response(response)
+
+    async def async_chat(self, system, history, gen_conf={}, **kwargs):
+        if system and history and history[0].get("role") != "system":
+            history.insert(0, {"role": "system", "content": system})
+        gen_conf = self._clean_conf(gen_conf)
+
+        for attempt in range(self.max_retries + 1):
+            try:
+                return await self._async_chat(history, gen_conf, **kwargs)
+            except Exception as e:
+                e = await self._exceptions_async(e, attempt)
+                if e:
+                    return e, 0
+        assert False, "Shouldn't be here."
+
     def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
         if system and history and history[0].get("role") != "system":
             history.insert(0, {"role": "system", "content": system})
@@ -637,66 +935,6 @@ class BaiChuanChat(Base):
         yield total_tokens


-class ZhipuChat(Base):
-    _FACTORY_NAME = "ZHIPU-AI"
-
-    def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs):
-        super().__init__(key, model_name, base_url=base_url, **kwargs)
-
-        self.client = ZhipuAI(api_key=key)
-        self.model_name = model_name
-
-    def _clean_conf(self, gen_conf):
-        if "max_tokens" in gen_conf:
-            del gen_conf["max_tokens"]
-        gen_conf = self._clean_conf_plealty(gen_conf)
-        return gen_conf
-
-    def _clean_conf_plealty(self, gen_conf):
-        if "presence_penalty" in gen_conf:
-            del gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf:
-            del gen_conf["frequency_penalty"]
-        return gen_conf
-
-    def chat_with_tools(self, system: str, history: list, gen_conf: dict):
-        gen_conf = self._clean_conf_plealty(gen_conf)
-
-        return super().chat_with_tools(system, history, gen_conf)
-
-    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
-        if system and history and history[0].get("role") != "system":
-            history.insert(0, {"role": "system", "content": system})
-        gen_conf = self._clean_conf(gen_conf)
-        ans = ""
-        tk_count = 0
-        try:
-            logging.info(json.dumps(history, ensure_ascii=False, indent=2))
-            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
-            for resp in response:
-                if not resp.choices[0].delta.content:
-                    continue
-                delta = resp.choices[0].delta.content
-                ans = delta
-                if resp.choices[0].finish_reason == "length":
-                    if is_chinese(ans):
-                        ans += LENGTH_NOTIFICATION_CN
-                    else:
-                        ans += LENGTH_NOTIFICATION_EN
-                    tk_count = total_token_count_from_response(resp)
-                if resp.choices[0].finish_reason == "stop":
-                    tk_count = total_token_count_from_response(resp)
-                yield ans
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield tk_count
-
-    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
-        gen_conf = self._clean_conf_plealty(gen_conf)
-        return super().chat_streamly_with_tools(system, history, gen_conf)
-
-
 class LocalAIChat(Base):
     _FACTORY_NAME = "LocalAI"
@@ -1398,6 +1636,7 @@ class LiteLLMBase(ABC):
         "GiteeAI",
         "302.AI",
         "Jiekou.AI",
+        "ZHIPU-AI",
     ]

     def __init__(self, key, model_name, base_url=None, **kwargs):
@@ -1477,6 +1716,7 @@ class LiteLLMBase(ABC):

     def _chat_streamly(self, history, gen_conf, **kwargs):
         logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
+        gen_conf = self._clean_conf(gen_conf)
         reasoning_start = False

         completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
@@ -1520,6 +1760,96 @@

             yield ans, tol

+    async def async_chat(self, history, gen_conf, **kwargs):
+        logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
+        if self.model_name.lower().find("qwen3") >= 0:
+            kwargs["extra_body"] = {"enable_thinking": False}
+
+        completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
+
+        for attempt in range(self.max_retries + 1):
+            try:
+                response = await litellm.acompletion(
+                    **completion_args,
+                    drop_params=True,
+                    timeout=self.timeout,
+                )
+
+                if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
+                    return "", 0
+                ans = response.choices[0].message.content.strip()
+                if response.choices[0].finish_reason == "length":
+                    ans = self._length_stop(ans)
+
+                return ans, total_token_count_from_response(response)
+            except Exception as e:
+                e = await self._exceptions_async(e, attempt)
+                if e:
+                    return e, 0
+
+        assert False, "Shouldn't be here."
+
+    async def async_chat_streamly(self, system, history, gen_conf, **kwargs):
+        if system and history and history[0].get("role") != "system":
+            history.insert(0, {"role": "system", "content": system})
+        logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
+        gen_conf = self._clean_conf(gen_conf)
+        reasoning_start = False
+        total_tokens = 0
+
+        completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
+        stop = kwargs.get("stop")
+        if stop:
+            completion_args["stop"] = stop
+
+        for attempt in range(self.max_retries + 1):
+            try:
+                stream = await litellm.acompletion(
+                    **completion_args,
+                    drop_params=True,
+                    timeout=self.timeout,
+                )
+
+                async for resp in stream:
+                    if not hasattr(resp, "choices") or not resp.choices:
+                        continue
+
+                    delta = resp.choices[0].delta
+                    if not hasattr(delta, "content") or delta.content is None:
+                        delta.content = ""
+
+                    if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
+                        ans = ""
+                        if not reasoning_start:
+                            reasoning_start = True
+                            ans = "<think>"
+                        ans += delta.reasoning_content + "</think>"
+                    else:
+                        reasoning_start = False
+                        ans = delta.content
+
+                    tol = total_token_count_from_response(resp)
+                    if not tol:
+                        tol = num_tokens_from_string(delta.content)
+                    total_tokens += tol
+
+                    finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
+                    if finish_reason == "length":
+                        if is_chinese(ans):
+                            ans += LENGTH_NOTIFICATION_CN
+                        else:
+                            ans += LENGTH_NOTIFICATION_EN
+
+                    yield ans
+                yield total_tokens
+                return
+            except Exception as e:
+                e = await self._exceptions_async(e, attempt)
+                if e:
+                    yield e
+                    yield total_tokens
+                    return
+
     def _length_stop(self, ans):
         if is_chinese([ans]):
             return ans + LENGTH_NOTIFICATION_CN
@@ -1550,6 +1880,21 @@

         return f"{ERROR_PREFIX}: {error_code} - {str(e)}"

+    async def _exceptions_async(self, e, attempt) -> str | None:
+        logging.exception("LiteLLMBase async completion")
+        error_code = self._classify_error(e)
+        if attempt == self.max_retries:
+            error_code = LLMErrorCode.ERROR_MAX_RETRIES
+
+        if self._should_retry(error_code):
+            delay = self._get_delay()
+            logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+            await asyncio.sleep(delay)
+            return None
+        msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+        logging.error(f"async_chat_streamly giving up: {msg}")
+        return msg
+
     def _verbose_tool_use(self, name, args, res):
         return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
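Note (not part of the commit): the async streaming methods added above (Base.async_chat_streamly, LiteLLMBase.async_chat_streamly, LLMBundle.async_chat_streamly) yield text chunks and end with a single int carrying the total token count. A minimal consumer sketch, assuming a stream shaped that way (fake_stream is illustrative):

    import asyncio

    async def fake_stream():
        for chunk in ("Hel", "lo"):
            yield chunk
        yield 5   # trailing int = total tokens, mirroring the convention above

    async def collect(stream):
        text, tokens = "", 0
        async for item in stream:
            if isinstance(item, int):
                tokens = item
                break
            text += item
        return text, tokens

    # asyncio.run(collect(fake_stream())) -> ("Hello", 5)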