add async for agent with tools

This commit is contained in:
yongtenglei 2025-12-02 16:53:18 +08:00
parent e9356db849
commit 4e8a8348f1
5 changed files with 233 additions and 12 deletions

View file

@ -415,13 +415,18 @@ class Canvas(Graph):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
tasks = [] tasks = []
def _run_async_in_thread(coro_func, **call_kwargs):
return asyncio.run(coro_func(**call_kwargs))
i = f i = f
while i < t: while i < t:
cpn = self.get_component_obj(self.path[i]) cpn = self.get_component_obj(self.path[i])
task_fn = None task_fn = None
call_kwargs = None
if cpn.component_name.lower() in ["begin", "userfillup"]: if cpn.component_name.lower() in ["begin", "userfillup"]:
task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {})) call_kwargs = {"inputs": kwargs.get("inputs", {})}
task_fn = cpn.invoke
i += 1 i += 1
else: else:
for _, ele in cpn.get_input_elements().items(): for _, ele in cpn.get_input_elements().items():
@ -430,13 +435,18 @@ class Canvas(Graph):
t -= 1 t -= 1
break break
else: else:
task_fn = partial(cpn.invoke, **cpn.get_input()) call_kwargs = cpn.get_input()
task_fn = cpn.invoke
i += 1 i += 1
if task_fn is None: if task_fn is None:
continue continue
tasks.append(loop.run_in_executor(self._thread_pool, task_fn)) invoke_async = getattr(cpn, "invoke_async", None)
if invoke_async and asyncio.iscoroutinefunction(invoke_async):
tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
else:
tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
if tasks: if tasks:
await asyncio.gather(*tasks) await asyncio.gather(*tasks)

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import json import json
import logging import logging
import os import os
@ -239,6 +240,86 @@ class Agent(LLM, ToolBase):
self.set_output("use_tools", use_tools) self.set_output("use_tools", use_tools)
return ans return ans
async def invoke_async(self, **kwargs):
"""
Async entry: reuse existing logic but offload heavy sync parts via async wrappers to reduce blocking.
"""
if self.check_if_canceled("Agent processing"):
return
if kwargs.get("user_prompt"):
usr_pmt = ""
if kwargs.get("reasoning"):
usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
if kwargs.get("context"):
usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
if usr_pmt:
usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
else:
usr_pmt = str(kwargs["user_prompt"])
self._param.prompts = [{"role": "user", "content": usr_pmt}]
if not self.tools:
if self.check_if_canceled("Agent processing"):
return
return await asyncio.to_thread(LLM._invoke, self, **kwargs)
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
output_schema = self._get_output_schema()
schema_prompt = ""
if output_schema:
schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
schema_prompt = structured_output_prompt(schema)
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
self.set_output("content", partial(self.stream_output_with_tools_async, prompt, msg, user_defined_prompt))
return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
use_tools = []
ans = ""
async for delta_ans, tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
if self.check_if_canceled("Agent processing"):
return
ans += delta_ans
if ans.find("**ERROR**") >= 0:
logging.error(f"Agent._chat got error. response: {ans}")
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
else:
self.set_output("_ERROR", ans)
return
if output_schema:
error = ""
for _ in range(self._param.max_retries + 1):
try:
def clean_formated_answer(ans: str) -> str:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
obj = json_repair.loads(clean_formated_answer(ans))
self.set_output("structured", obj)
if use_tools:
self.set_output("use_tools", use_tools)
return obj
except Exception:
error = "The answer cannot be parsed as JSON"
ans = self._force_format_to_schema(ans, schema_prompt)
if ans.find("**ERROR**") >= 0:
continue
self.set_output("_ERROR", error)
return
self.set_output("content", ans)
if use_tools:
self.set_output("use_tools", use_tools)
return ans
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}): def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer_without_toolcall = "" answer_without_toolcall = ""
@ -261,6 +342,54 @@ class Agent(LLM, ToolBase):
if use_tools: if use_tools:
self.set_output("use_tools", use_tools) self.set_output("use_tools", use_tools)
async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer_without_toolcall = ""
use_tools = []
async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
if self.check_if_canceled("Agent streaming"):
return
if delta_ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
yield self.get_exception_default_value()
else:
self.set_output("_ERROR", delta_ans)
return
answer_without_toolcall += delta_ans
yield delta_ans
self.set_output("content", answer_without_toolcall)
if use_tools:
self.set_output("use_tools", use_tools)
async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
"""
Async wrapper that offloads synchronous flow to a thread, yielding results without blocking the event loop.
"""
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for delta_ans, tk in self._react_with_tools_streamly(prompt, history, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
asyncio.run_coroutine_threadsafe(queue.put((delta_ans, tk)), loop)
except Exception as e:
asyncio.run_coroutine_threadsafe(queue.put(e), loop)
finally:
asyncio.run_coroutine_threadsafe(queue.put(StopAsyncIteration), loop)
await asyncio.to_thread(worker)
while True:
item = await queue.get()
if item is StopAsyncIteration:
break
if isinstance(item, Exception):
raise item
yield item
def _gen_citations(self, text): def _gen_citations(self, text):
retrievals = self._canvas.get_reference() retrievals = self._canvas.get_reference()
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())} retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
@ -433,4 +562,3 @@ Respond immediately with your final comprehensive answer.
for k in self._param.inputs.keys(): for k in self._param.inputs.keys():
self._param.inputs[k]["value"] = None self._param.inputs[k]["value"] = None
self._param.debug_inputs = {} self._param.debug_inputs = {}

View file

@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import json import json
import logging import logging
import os import os
import re import re
import threading
from copy import deepcopy from copy import deepcopy
from typing import Any, Generator from typing import Any, Generator, AsyncGenerator
import json_repair import json_repair
from functools import partial from functools import partial
from common.constants import LLMType from common.constants import LLMType
@ -171,6 +173,13 @@ class LLM(ComponentBase):
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs) return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs) return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
async def _generate_async(self, msg: list[dict], **kwargs) -> str:
if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
if self.imgs and hasattr(self.chat_mdl, "async_chat"):
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
return await asyncio.to_thread(self._generate, msg, **kwargs)
def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]: def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
ans = "" ans = ""
last_idx = 0 last_idx = 0
@ -205,6 +214,70 @@ class LLM(ComponentBase):
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs): for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
yield delta(txt) yield delta(txt)
async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
# Prefer async chat_streamly if available
async def delta_wrapper(txt_iter):
ans = ""
last_idx = 0
endswith_think = False
def delta(txt):
nonlocal ans, last_idx, endswith_think
delta_ans = txt[last_idx:]
ans = txt
if delta_ans.find("<think>") == 0:
last_idx += len("<think>")
return "<think>"
elif delta_ans.find("<think>") > 0:
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
last_idx += delta_ans.find("<think>")
return delta_ans
elif delta_ans.endswith("</think>"):
endswith_think = True
elif endswith_think:
endswith_think = False
return "</think>"
last_idx = len(ans)
if ans.endswith("</think>"):
last_idx -= len("</think>")
return re.sub(r"(<think>|</think>)", "", delta_ans)
async for t in txt_iter:
yield delta(t)
if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
yield t
return
if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
yield t
return
# Fallback: run sync stream in thread, bridge results
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for item in self._generate_streamly(msg, **kwargs):
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as e:
loop.call_soon_threadsafe(queue.put_nowait, e)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
threading.Thread(target=worker, daemon=True).start()
while True:
item = await queue.get()
if item is StopAsyncIteration:
break
if isinstance(item, Exception):
raise item
yield item
async def _stream_output_async(self, prompt, msg): async def _stream_output_async(self, prompt, msg):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer = "" answer = ""

View file

@ -52,10 +52,6 @@ class LLMToolPluginCallSession(ToolCallSession):
assert name in self.tools_map, f"LLM tool {name} does not exist" assert name in self.tools_map, f"LLM tool {name} does not exist"
st = timer() st = timer()
tool_obj = self.tools_map[name] tool_obj = self.tools_map[name]
print("#########################################", flush=True)
tool_desc = getattr(tool_obj, "_id", None) or getattr(tool_obj, "component_name", None) or tool_obj.__class__.__name__
logging.info(f"[ToolCall] start name={name}, tool={tool_desc}, async={hasattr(tool_obj, 'invoke_async')}")
print("#########################################", flush=True)
if isinstance(tool_obj, MCPToolCallSession): if isinstance(tool_obj, MCPToolCallSession):
resp = tool_obj.tool_call(name, arguments, 60) resp = tool_obj.tool_call(name, arguments, 60)
else: else:
@ -63,9 +59,6 @@ class LLMToolPluginCallSession(ToolCallSession):
resp = asyncio.run(tool_obj.invoke_async(**arguments)) resp = asyncio.run(tool_obj.invoke_async(**arguments))
else: else:
resp = asyncio.run(asyncio.to_thread(tool_obj.invoke, **arguments)) resp = asyncio.run(asyncio.to_thread(tool_obj.invoke, **arguments))
print("#########################################", flush=True)
logging.info(f"[ToolCall] done name={name}, tool={tool_desc}, elapsed={timer()-st:.3f}s")
print("#########################################", flush=True)
self.callback(name, arguments, resp, elapsed_time=timer()-st) self.callback(name, arguments, resp, elapsed_time=timer()-st)
return resp return resp

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import datetime import datetime
import json import json
import logging import logging
@ -360,6 +361,10 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
return kwd return kwd
async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
return await asyncio.to_thread(analyze_task, chat_mdl, prompt, task_name, tools_description, user_defined_prompts)
def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}): def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
if not tools_description: if not tools_description:
return "" return ""
@ -378,6 +383,10 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc,
return json_str, tk_cnt return json_str, tk_cnt
async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
return await asyncio.to_thread(next_step, chat_mdl, history, tools_description, task_desc, user_defined_prompts)
def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}): def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res] tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
goal = history[1]["content"] goal = history[1]["content"]
@ -429,6 +438,14 @@ def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[st
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
return await asyncio.to_thread(reflect, chat_mdl, history, tool_call_res, user_defined_prompts)
async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
return await asyncio.to_thread(rank_memories, chat_mdl, goal, sub_goal, tool_call_summaries, user_defined_prompts)
def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict: def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
meta_data_structure = {} meta_data_structure = {}
for key, values in meta_data.items(): for key, values in meta_data.items():