diff --git a/agent/tools/base.py b/agent/tools/base.py index 791242d59..b6d4dcb4c 100644 --- a/agent/tools/base.py +++ b/agent/tools/base.py @@ -17,6 +17,7 @@ import logging import re import time from copy import deepcopy +import asyncio from functools import partial from typing import TypedDict, List, Any from agent.component.base import ComponentParamBase, ComponentBase @@ -50,10 +51,21 @@ class LLMToolPluginCallSession(ToolCallSession): def tool_call(self, name: str, arguments: dict[str, Any]) -> Any: assert name in self.tools_map, f"LLM tool {name} does not exist" st = timer() - if isinstance(self.tools_map[name], MCPToolCallSession): - resp = self.tools_map[name].tool_call(name, arguments, 60) + tool_obj = self.tools_map[name] + print("#########################################", flush=True) + tool_desc = getattr(tool_obj, "_id", None) or getattr(tool_obj, "component_name", None) or tool_obj.__class__.__name__ + logging.info(f"[ToolCall] start name={name}, tool={tool_desc}, async={hasattr(tool_obj, 'invoke_async')}") + print("#########################################", flush=True) + if isinstance(tool_obj, MCPToolCallSession): + resp = tool_obj.tool_call(name, arguments, 60) else: - resp = self.tools_map[name].invoke(**arguments) + if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async): + resp = asyncio.run(tool_obj.invoke_async(**arguments)) + else: + resp = asyncio.run(asyncio.to_thread(tool_obj.invoke, **arguments)) + print("#########################################", flush=True) + logging.info(f"[ToolCall] done name={name}, tool={tool_desc}, elapsed={timer()-st:.3f}s") + print("#########################################", flush=True) self.callback(name, arguments, resp, elapsed_time=timer()-st) return resp @@ -139,6 +151,30 @@ class ToolBase(ComponentBase): self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time")) return res + async def invoke_async(self, **kwargs): + """ + Async wrapper for tool invocation. + If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking. + Mirrors the exception handling of `invoke`. + """ + if self.check_if_canceled("Tool processing"): + return + + self.set_output("_created_time", time.perf_counter()) + try: + if asyncio.iscoroutinefunction(self._invoke): + res = await self._invoke(**kwargs) + else: + res = await asyncio.to_thread(self._invoke, **kwargs) + except Exception as e: + self._param.outputs["_ERROR"] = {"value": str(e)} + logging.exception(e) + res = str(e) + self._param.debug_inputs = [] + + self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time")) + return res + def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None): chunks = [] aggs = []