From 4e8a8348f181400304feb4bf4b4a4628698cc0db Mon Sep 17 00:00:00 2001 From: yongtenglei Date: Tue, 2 Dec 2025 16:53:18 +0800 Subject: [PATCH] add async for agent with tools --- agent/canvas.py | 16 +++- agent/component/agent_with_tools.py | 130 +++++++++++++++++++++++++++- agent/component/llm.py | 75 +++++++++++++++- agent/tools/base.py | 7 -- rag/prompts/generator.py | 17 ++++ 5 files changed, 233 insertions(+), 12 deletions(-) diff --git a/agent/canvas.py b/agent/canvas.py index c447b77b3..2dd0c8999 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -415,13 +415,18 @@ class Canvas(Graph): loop = asyncio.get_running_loop() tasks = [] + def _run_async_in_thread(coro_func, **call_kwargs): + return asyncio.run(coro_func(**call_kwargs)) + i = f while i < t: cpn = self.get_component_obj(self.path[i]) task_fn = None + call_kwargs = None if cpn.component_name.lower() in ["begin", "userfillup"]: - task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {})) + call_kwargs = {"inputs": kwargs.get("inputs", {})} + task_fn = cpn.invoke i += 1 else: for _, ele in cpn.get_input_elements().items(): @@ -430,13 +435,18 @@ class Canvas(Graph): t -= 1 break else: - task_fn = partial(cpn.invoke, **cpn.get_input()) + call_kwargs = cpn.get_input() + task_fn = cpn.invoke i += 1 if task_fn is None: continue - tasks.append(loop.run_in_executor(self._thread_pool, task_fn)) + invoke_async = getattr(cpn, "invoke_async", None) + if invoke_async and asyncio.iscoroutinefunction(invoke_async): + tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {})))) + else: + tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {})))) if tasks: await asyncio.gather(*tasks) diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index 979b636af..93892a739 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import json import logging import os @@ -239,6 +240,86 @@ class Agent(LLM, ToolBase): self.set_output("use_tools", use_tools) return ans + async def invoke_async(self, **kwargs): + """ + Async entry: reuse existing logic but offload heavy sync parts via async wrappers to reduce blocking. 
+        """
+        if self.check_if_canceled("Agent processing"):
+            return
+
+        if kwargs.get("user_prompt"):
+            usr_pmt = ""
+            if kwargs.get("reasoning"):
+                usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
+            if kwargs.get("context"):
+                usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
+            if usr_pmt:
+                usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
+            else:
+                usr_pmt = str(kwargs["user_prompt"])
+            self._param.prompts = [{"role": "user", "content": usr_pmt}]
+
+        if not self.tools:
+            if self.check_if_canceled("Agent processing"):
+                return
+            return await asyncio.to_thread(LLM._invoke, self, **kwargs)
+
+        prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
+        output_schema = self._get_output_schema()
+        schema_prompt = ""
+        if output_schema:
+            schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
+            schema_prompt = structured_output_prompt(schema)
+
+        downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
+        ex = self.exception_handler()
+        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
+            self.set_output("content", partial(self.stream_output_with_tools_async, prompt, msg, user_defined_prompt))
+            return
+
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+        use_tools = []
+        ans = ""
+        async for delta_ans, tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
+            if self.check_if_canceled("Agent processing"):
+                return
+            ans += delta_ans
+
+        if ans.find("**ERROR**") >= 0:
+            logging.error(f"Agent._chat got error. response: {ans}")
+            if self.get_exception_default_value():
+                self.set_output("content", self.get_exception_default_value())
+            else:
+                self.set_output("_ERROR", ans)
+            return
+
+        if output_schema:
+            error = ""
+            for _ in range(self._param.max_retries + 1):
+                try:
+                    def clean_formated_answer(ans: str) -> str:
+                        ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
+                        ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
+                        return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
+                    obj = json_repair.loads(clean_formated_answer(ans))
+                    self.set_output("structured", obj)
+                    if use_tools:
+                        self.set_output("use_tools", use_tools)
+                    return obj
+                except Exception:
+                    error = "The answer cannot be parsed as JSON"
+                    ans = self._force_format_to_schema(ans, schema_prompt)
+                    if ans.find("**ERROR**") >= 0:
+                        continue
+
+            self.set_output("_ERROR", error)
+            return
+
+        self.set_output("content", ans)
+        if use_tools:
+            self.set_output("use_tools", use_tools)
+        return ans
+
     def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         answer_without_toolcall = ""
@@ -261,6 +342,54 @@ class Agent(LLM, ToolBase):
         if use_tools:
             self.set_output("use_tools", use_tools)
 
+    async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+        answer_without_toolcall = ""
+        use_tools = []
+        async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
+            if self.check_if_canceled("Agent streaming"):
+                return
+
+            if delta_ans.find("**ERROR**") >= 0:
+                if self.get_exception_default_value():
+                    self.set_output("content", self.get_exception_default_value())
+                    yield self.get_exception_default_value()
+                else:
+                    self.set_output("_ERROR", delta_ans)
+                return
+            answer_without_toolcall += delta_ans
+            yield delta_ans
+
+        self.set_output("content", answer_without_toolcall)
+        if use_tools:
+            self.set_output("use_tools", use_tools)
+
+    async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
+        """
+        Async wrapper: run the synchronous ReAct loop in a worker thread and relay its output through a queue, so the event loop is never blocked.
+        """
+        loop = asyncio.get_running_loop()
+        queue: asyncio.Queue = asyncio.Queue()
+
+        def worker():
+            try:
+                for delta_ans, tk in self._react_with_tools_streamly(prompt, history, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
+                    asyncio.run_coroutine_threadsafe(queue.put((delta_ans, tk)), loop)
+            except Exception as e:
+                asyncio.run_coroutine_threadsafe(queue.put(e), loop)
+            finally:
+                asyncio.run_coroutine_threadsafe(queue.put(StopAsyncIteration), loop)
+
+        # Schedule the worker without awaiting it: awaiting asyncio.to_thread(worker) would buffer the entire stream before the first item could be yielded.
+        loop.run_in_executor(None, worker)
+        while True:
+            item = await queue.get()
+            if item is StopAsyncIteration:
+                break
+            if isinstance(item, Exception):
+                raise item
+            yield item
+
     def _gen_citations(self, text):
         retrievals = self._canvas.get_reference()
         retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
@@ -433,4 +562,3 @@ Respond immediately with your final comprehensive answer.
         for k in self._param.inputs.keys():
             self._param.inputs[k]["value"] = None
         self._param.debug_inputs = {}
-
diff --git a/agent/component/llm.py b/agent/component/llm.py
index a29a36860..5f550c11a 100644
--- a/agent/component/llm.py
+++ b/agent/component/llm.py
@@ -13,12 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 import os
 import re
+import threading
 from copy import deepcopy
-from typing import Any, Generator
+from typing import Any, Generator, AsyncGenerator
 import json_repair
 from functools import partial
 from common.constants import LLMType
@@ -171,6 +173,13 @@
             return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
         return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
 
+    async def _generate_async(self, msg: list[dict], **kwargs) -> str:
+        if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
+            return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
+        if self.imgs and hasattr(self.chat_mdl, "async_chat"):
+            return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
+        return await asyncio.to_thread(self._generate, msg, **kwargs)
+
     def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
         ans = ""
         last_idx = 0
@@ -205,6 +214,70 @@
         for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
             yield delta(txt)
 
+    async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
+        # Prefer async chat_streamly if available
+        async def delta_wrapper(txt_iter):
+            ans = ""
+            last_idx = 0
+            endswith_think = False
+
+            def delta(txt):
+                nonlocal ans, last_idx, endswith_think
+                delta_ans = txt[last_idx:]
+                ans = txt
+
+                if delta_ans.find("<think>") == 0:
+                    last_idx += len("<think>")
+                    return "<think>"
+                elif delta_ans.find("<think>") > 0:
+                    delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
+                    last_idx += len(delta_ans)  # advance past the emitted text; find() on the truncated string would return -1
+                    return delta_ans
+                elif delta_ans.endswith("</think>"):
+                    endswith_think = True
+                elif endswith_think:
+                    endswith_think = False
+                    return "</think>"
+
+                last_idx = len(ans)
+                if ans.endswith("</think>"):
+                    last_idx -= len("</think>")
+                return re.sub(r"(<think>|</think>)", "", delta_ans)
+
+            async for t in txt_iter:
+                yield delta(t)
+
+        if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
+            async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
+                yield t
+            return
+        if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
+            async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
+                yield t
+            return
+
+        # Fallback: run sync stream in thread, bridge results
+        loop = asyncio.get_running_loop()
+        queue: asyncio.Queue = asyncio.Queue()
+
+        def worker():
+            try:
+                for item in self._generate_streamly(msg, **kwargs):
+                    loop.call_soon_threadsafe(queue.put_nowait, item)
+            except Exception as e:
+                loop.call_soon_threadsafe(queue.put_nowait, e)
+            finally:
+                loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
+
+        threading.Thread(target=worker, daemon=True).start()
+        while True:
+            item = await queue.get()
+            if item is StopAsyncIteration:
+                break
+            if isinstance(item, Exception):
+                raise item
+            yield item
+
     async def _stream_output_async(self, prompt, msg):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         answer = ""
diff --git a/agent/tools/base.py b/agent/tools/base.py
index b6d4dcb4c..c8554e075 100644
--- a/agent/tools/base.py
+++ b/agent/tools/base.py
@@ -52,10 +52,6 @@ class LLMToolPluginCallSession(ToolCallSession):
         assert name in self.tools_map, f"LLM tool {name} does not exist"
         st = timer()
         tool_obj = self.tools_map[name]
-        print("#########################################", flush=True)
-        tool_desc = getattr(tool_obj, "_id", None) or getattr(tool_obj, "component_name", None) or tool_obj.__class__.__name__
-        logging.info(f"[ToolCall] start name={name}, tool={tool_desc}, async={hasattr(tool_obj, 'invoke_async')}")
-        print("#########################################", flush=True)
         if isinstance(tool_obj, MCPToolCallSession):
             resp = tool_obj.tool_call(name, arguments, 60)
         else:
             if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
                 resp = asyncio.run(tool_obj.invoke_async(**arguments))
             else:
                 resp = asyncio.run(asyncio.to_thread(tool_obj.invoke, **arguments))
-        print("#########################################", flush=True)
-        logging.info(f"[ToolCall] done name={name}, tool={tool_desc}, elapsed={timer()-st:.3f}s")
-        print("#########################################", flush=True)
         self.callback(name, arguments, resp, elapsed_time=timer()-st)
         return resp
diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py
index dd33d885e..fa3f84679 100644
--- a/rag/prompts/generator.py
+++ b/rag/prompts/generator.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import datetime
 import json
 import logging
@@ -360,6 +361,10 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
     return kwd
 
 
+async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
+    return await asyncio.to_thread(analyze_task, chat_mdl, prompt, task_name, tools_description, user_defined_prompts)
+
+
 def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
     if not tools_description:
         return ""
@@ -378,6 +383,10 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc,
     return json_str, tk_cnt
 
 
+async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
+    return await asyncio.to_thread(next_step, chat_mdl, history, tools_description, task_desc, user_defined_prompts)
+
+
 def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
     tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
     goal = history[1]["content"]
@@ -429,6 +438,14 @@ def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[st
     return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
 
 
+async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
+    return await asyncio.to_thread(reflect, chat_mdl, history, tool_call_res, user_defined_prompts)
+
+
+async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
+    return await asyncio.to_thread(rank_memories, chat_mdl, goal, sub_goal, tool_call_summaries, user_defined_prompts)
+
+
 def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
     meta_data_structure = {}
     for key, values in meta_data.items():
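
Note for reviewers: _react_with_tools_streamly_async and the fallback path of _generate_streamly_async rely on the same sync-to-async generator bridge: a worker thread drives the blocking generator and forwards each item into an asyncio.Queue through a threadsafe call, with a sentinel marking the end of the stream. Below is a minimal, self-contained sketch of that pattern; the names (slow_sync_stream, bridge) are illustrative and not part of this patch.

    import asyncio
    import threading
    import time

    def slow_sync_stream():
        # Stand-in for a blocking generator such as _react_with_tools_streamly.
        for i in range(3):
            time.sleep(0.1)  # simulates a blocking LLM or tool call
            yield f"chunk-{i}"

    async def bridge(sync_gen_factory):
        # Run a sync generator in a worker thread; yield items as they arrive.
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()

        def worker():
            try:
                for item in sync_gen_factory():
                    loop.call_soon_threadsafe(queue.put_nowait, item)
            except Exception as e:
                loop.call_soon_threadsafe(queue.put_nowait, e)
            finally:
                # Sentinel: tells the async side the stream is exhausted.
                loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)

        # Start the thread without awaiting its completion; awaiting here would
        # buffer every chunk before the first one could be yielded.
        threading.Thread(target=worker, daemon=True).start()

        while True:
            item = await queue.get()
            if item is StopAsyncIteration:
                break
            if isinstance(item, Exception):
                raise item
            yield item

    async def main():
        async for chunk in bridge(slow_sync_stream):
            print(chunk)  # chunks arrive incrementally; the event loop stays free

    asyncio.run(main())

This is also why the worker in agent/component/agent_with_tools.py must be scheduled rather than awaited: the consumer loop has to run concurrently with the producer thread.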
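
Note for reviewers: the patch combines two offloading patterns. The wrappers in rag/prompts/generator.py push a blocking call onto a thread with asyncio.to_thread, while canvas.py does the inverse: a thread-pool worker owns no running event loop, so it drives a coroutine to completion with asyncio.run. A minimal sketch of both, with illustrative names (blocking_step, async_step):

    import asyncio
    import time

    def blocking_step(x: int) -> int:
        time.sleep(0.1)  # stands in for a synchronous LLM or tool call
        return x * 2

    async def async_step(x: int) -> int:
        await asyncio.sleep(0.1)  # stands in for a component's invoke_async
        return x + 1

    async def main():
        loop = asyncio.get_running_loop()

        # Pattern 1 (generator.py wrappers): offload a blocking call to a thread.
        doubled = await asyncio.to_thread(blocking_step, 21)

        # Pattern 2 (canvas.py dispatch): the executor thread has no event loop,
        # so it starts a fresh one with asyncio.run() to execute the coroutine.
        incremented = await loop.run_in_executor(None, lambda: asyncio.run(async_step(41)))

        print(doubled, incremented)  # 42 42

    asyncio.run(main())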