add async method to agent/tools/base.py

This commit is contained in:
yongtenglei 2025-12-02 16:21:28 +08:00
parent 057fa0057a
commit e9356db849

View file

@ -17,6 +17,7 @@ import logging
import re
import time
from copy import deepcopy
import asyncio
from functools import partial
from typing import TypedDict, List, Any
from agent.component.base import ComponentParamBase, ComponentBase
@ -50,10 +51,21 @@ class LLMToolPluginCallSession(ToolCallSession):
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
assert name in self.tools_map, f"LLM tool {name} does not exist"
st = timer()
if isinstance(self.tools_map[name], MCPToolCallSession):
resp = self.tools_map[name].tool_call(name, arguments, 60)
tool_obj = self.tools_map[name]
print("#########################################", flush=True)
tool_desc = getattr(tool_obj, "_id", None) or getattr(tool_obj, "component_name", None) or tool_obj.__class__.__name__
logging.info(f"[ToolCall] start name={name}, tool={tool_desc}, async={hasattr(tool_obj, 'invoke_async')}")
print("#########################################", flush=True)
if isinstance(tool_obj, MCPToolCallSession):
resp = tool_obj.tool_call(name, arguments, 60)
else:
resp = self.tools_map[name].invoke(**arguments)
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
resp = asyncio.run(tool_obj.invoke_async(**arguments))
else:
resp = asyncio.run(asyncio.to_thread(tool_obj.invoke, **arguments))
print("#########################################", flush=True)
logging.info(f"[ToolCall] done name={name}, tool={tool_desc}, elapsed={timer()-st:.3f}s")
print("#########################################", flush=True)
self.callback(name, arguments, resp, elapsed_time=timer()-st)
return resp
@ -139,6 +151,30 @@ class ToolBase(ComponentBase):
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return res
async def invoke_async(self, **kwargs):
"""
Async wrapper for tool invocation.
If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
Mirrors the exception handling of `invoke`.
"""
if self.check_if_canceled("Tool processing"):
return
self.set_output("_created_time", time.perf_counter())
try:
if asyncio.iscoroutinefunction(self._invoke):
res = await self._invoke(**kwargs)
else:
res = await asyncio.to_thread(self._invoke, **kwargs)
except Exception as e:
self._param.outputs["_ERROR"] = {"value": str(e)}
logging.exception(e)
res = str(e)
self._param.debug_inputs = []
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return res
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
chunks = []
aggs = []