add async method to agent/tools/base.py

This commit is contained in:
yongtenglei 2025-12-02 16:21:28 +08:00
parent 057fa0057a
commit e9356db849

View file

@ -17,6 +17,7 @@ import logging
import re import re
import time import time
from copy import deepcopy from copy import deepcopy
import asyncio
from functools import partial from functools import partial
from typing import TypedDict, List, Any from typing import TypedDict, List, Any
from agent.component.base import ComponentParamBase, ComponentBase from agent.component.base import ComponentParamBase, ComponentBase
@ -50,10 +51,21 @@ class LLMToolPluginCallSession(ToolCallSession):
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any: def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
assert name in self.tools_map, f"LLM tool {name} does not exist" assert name in self.tools_map, f"LLM tool {name} does not exist"
st = timer() st = timer()
if isinstance(self.tools_map[name], MCPToolCallSession): tool_obj = self.tools_map[name]
resp = self.tools_map[name].tool_call(name, arguments, 60) print("#########################################", flush=True)
tool_desc = getattr(tool_obj, "_id", None) or getattr(tool_obj, "component_name", None) or tool_obj.__class__.__name__
logging.info(f"[ToolCall] start name={name}, tool={tool_desc}, async={hasattr(tool_obj, 'invoke_async')}")
print("#########################################", flush=True)
if isinstance(tool_obj, MCPToolCallSession):
resp = tool_obj.tool_call(name, arguments, 60)
else: else:
resp = self.tools_map[name].invoke(**arguments) if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
resp = asyncio.run(tool_obj.invoke_async(**arguments))
else:
resp = asyncio.run(asyncio.to_thread(tool_obj.invoke, **arguments))
print("#########################################", flush=True)
logging.info(f"[ToolCall] done name={name}, tool={tool_desc}, elapsed={timer()-st:.3f}s")
print("#########################################", flush=True)
self.callback(name, arguments, resp, elapsed_time=timer()-st) self.callback(name, arguments, resp, elapsed_time=timer()-st)
return resp return resp
@ -139,6 +151,30 @@ class ToolBase(ComponentBase):
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time")) self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return res return res
async def invoke_async(self, **kwargs):
"""
Async wrapper for tool invocation.
If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
Mirrors the exception handling of `invoke`.
"""
if self.check_if_canceled("Tool processing"):
return
self.set_output("_created_time", time.perf_counter())
try:
if asyncio.iscoroutinefunction(self._invoke):
res = await self._invoke(**kwargs)
else:
res = await asyncio.to_thread(self._invoke, **kwargs)
except Exception as e:
self._param.outputs["_ERROR"] = {"value": str(e)}
logging.exception(e)
res = str(e)
self._param.debug_inputs = []
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return res
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None): def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
chunks = [] chunks = []
aggs = [] aggs = []