add async method to agent/tools/base.py
This commit is contained in:
parent
057fa0057a
commit
e9356db849
1 changed files with 39 additions and 3 deletions
|
|
@ -17,6 +17,7 @@ import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
import asyncio
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TypedDict, List, Any
|
from typing import TypedDict, List, Any
|
||||||
from agent.component.base import ComponentParamBase, ComponentBase
|
from agent.component.base import ComponentParamBase, ComponentBase
|
||||||
|
|
@ -50,10 +51,21 @@ class LLMToolPluginCallSession(ToolCallSession):
|
||||||
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
|
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
|
||||||
assert name in self.tools_map, f"LLM tool {name} does not exist"
|
assert name in self.tools_map, f"LLM tool {name} does not exist"
|
||||||
st = timer()
|
st = timer()
|
||||||
if isinstance(self.tools_map[name], MCPToolCallSession):
|
tool_obj = self.tools_map[name]
|
||||||
resp = self.tools_map[name].tool_call(name, arguments, 60)
|
print("#########################################", flush=True)
|
||||||
|
tool_desc = getattr(tool_obj, "_id", None) or getattr(tool_obj, "component_name", None) or tool_obj.__class__.__name__
|
||||||
|
logging.info(f"[ToolCall] start name={name}, tool={tool_desc}, async={hasattr(tool_obj, 'invoke_async')}")
|
||||||
|
print("#########################################", flush=True)
|
||||||
|
if isinstance(tool_obj, MCPToolCallSession):
|
||||||
|
resp = tool_obj.tool_call(name, arguments, 60)
|
||||||
else:
|
else:
|
||||||
resp = self.tools_map[name].invoke(**arguments)
|
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
|
||||||
|
resp = asyncio.run(tool_obj.invoke_async(**arguments))
|
||||||
|
else:
|
||||||
|
resp = asyncio.run(asyncio.to_thread(tool_obj.invoke, **arguments))
|
||||||
|
print("#########################################", flush=True)
|
||||||
|
logging.info(f"[ToolCall] done name={name}, tool={tool_desc}, elapsed={timer()-st:.3f}s")
|
||||||
|
print("#########################################", flush=True)
|
||||||
|
|
||||||
self.callback(name, arguments, resp, elapsed_time=timer()-st)
|
self.callback(name, arguments, resp, elapsed_time=timer()-st)
|
||||||
return resp
|
return resp
|
||||||
|
|
@ -139,6 +151,30 @@ class ToolBase(ComponentBase):
|
||||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
async def invoke_async(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Async wrapper for tool invocation.
|
||||||
|
If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
|
||||||
|
Mirrors the exception handling of `invoke`.
|
||||||
|
"""
|
||||||
|
if self.check_if_canceled("Tool processing"):
|
||||||
|
return
|
||||||
|
|
||||||
|
self.set_output("_created_time", time.perf_counter())
|
||||||
|
try:
|
||||||
|
if asyncio.iscoroutinefunction(self._invoke):
|
||||||
|
res = await self._invoke(**kwargs)
|
||||||
|
else:
|
||||||
|
res = await asyncio.to_thread(self._invoke, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
self._param.outputs["_ERROR"] = {"value": str(e)}
|
||||||
|
logging.exception(e)
|
||||||
|
res = str(e)
|
||||||
|
self._param.debug_inputs = []
|
||||||
|
|
||||||
|
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||||
|
return res
|
||||||
|
|
||||||
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
|
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
|
||||||
chunks = []
|
chunks = []
|
||||||
aggs = []
|
aggs = []
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue