Fix: no longer block the main (event-loop) thread while running Agent tasks such as Canvas.run

This commit is contained in:
yaol 2025-11-28 17:36:25 +08:00
parent 918d5a9ff8
commit cd53716848

View file

@ -18,6 +18,8 @@ import json
import logging import logging
import re import re
import time import time
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
@ -372,7 +374,7 @@ class Canvas(Graph):
for k in kwargs.keys(): for k in kwargs.keys():
if k in ["query", "user_id", "files"] and kwargs[k]: if k in ["query", "user_id", "files"] and kwargs[k]:
if k == "files": if k == "files":
self.globals[f"sys.{k}"] = self.get_files(kwargs[k]) self.globals[f"sys.{k}"] = await self.get_files(kwargs[k])
else: else:
self.globals[f"sys.{k}"] = kwargs[k] self.globals[f"sys.{k}"] = kwargs[k]
if not self.globals["sys.conversation_turns"] : if not self.globals["sys.conversation_turns"] :
@ -402,7 +404,7 @@ class Canvas(Graph):
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")}) yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
self.retrieval.append({"chunks": {}, "doc_aggs": {}}) self.retrieval.append({"chunks": {}, "doc_aggs": {}})
def _run_batch(f, t): async def _run_batch(f, t):
if self.is_canceled(): if self.is_canceled():
msg = f"Task {self.task_id} has been canceled during batch execution." msg = f"Task {self.task_id} has been canceled during batch execution."
logging.info(msg) logging.info(msg)
@ -426,7 +428,7 @@ class Canvas(Graph):
thr.append(executor.submit(cpn.invoke, **cpn.get_input())) thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
i += 1 i += 1
for t in thr: for t in thr:
t.result() await asyncio.wrap_future(t)
def _node_finished(cpn_obj): def _node_finished(cpn_obj):
return decorate("node_finished",{ return decorate("node_finished",{
@ -453,7 +455,7 @@ class Canvas(Graph):
"component_type": self.get_component_type(self.path[i]), "component_type": self.get_component_type(self.path[i]),
"thoughts": self.get_component_thoughts(self.path[i]) "thoughts": self.get_component_thoughts(self.path[i])
}) })
_run_batch(idx, to) await _run_batch(idx, to)
to = len(self.path) to = len(self.path)
# post processing of components invocation # post processing of components invocation
for i in range(idx, to): for i in range(idx, to):
@ -462,7 +464,24 @@ class Canvas(Graph):
if cpn_obj.component_name.lower() == "message": if cpn_obj.component_name.lower() == "message":
if isinstance(cpn_obj.output("content"), partial): if isinstance(cpn_obj.output("content"), partial):
_m = "" _m = ""
def generator_producer(gen, q, loop):
    """Drain a synchronous generator on a worker thread, forwarding each item
    to an asyncio.Queue owned by *loop* via call_soon_threadsafe.

    A ``None`` sentinel is always enqueued exactly once at the end — on normal
    exhaustion or after an error — so the consumer's ``while True`` loop can
    terminate.  Errors are logged, not propagated (the consumer runs on the
    event loop and cannot see them anyway).
    """
    try:
        for item in gen:
            loop.call_soon_threadsafe(q.put_nowait, item)
    except Exception:
        logging.exception("Error in generator producer")
    finally:
        # Guarantee the end-of-stream sentinel in every exit path.
        loop.call_soon_threadsafe(q.put_nowait, None)
q = asyncio.Queue()
loop = asyncio.get_running_loop()
gen = cpn_obj.output("content")()
threading.Thread(target=generator_producer, args=(gen, q, loop)).start()
while True:
m = await q.get()
if m is None:
break
if not m: if not m:
continue continue
if m == "<think>": if m == "<think>":
@ -621,21 +640,24 @@ class Canvas(Graph):
def get_component_input_elements(self, cpnnm):
    """Return the input elements declared by the named component's object.

    :param cpnnm: component name — key into ``self.components``.
    :return: whatever the component object's ``get_input_elements()`` yields
        (element schema is defined by the component class, not visible here).
    """
    return self.components[cpnnm]["obj"].get_input_elements()
async def get_files(self, files: Union[None, list[dict]]) -> list[str]:
    """Convert uploaded-file descriptors into prompt-ready strings.

    Image files become base64 ``data:`` URLs; all other files are parsed to
    text via ``FileService.parse``.  Every blocking step (blob fetch, base64
    encode, parsing) runs on a thread pool so the event loop is never blocked.

    :param files: list of dicts with at least ``id``, ``name``, ``mime_type``
        and ``created_by`` keys, or ``None``/empty.
    :return: one string per input file, in input order.
    """
    if not files:
        return []
    # Imported lazily: the heavy project import is skipped on the empty fast path.
    from api.db.services.file_service import FileService

    def image_to_base64(file):
        # data-URL form so the image can be consumed inline downstream.
        blob = FileService.get_blob(file["created_by"], file["id"])
        return "data:{};base64,{}".format(file["mime_type"],
                                          base64.b64encode(blob).decode("utf-8"))

    def parse_file(file):
        # Fetch the blob on the worker thread too — doing it while building
        # the task list would block the event loop, which is the very problem
        # this coroutine exists to avoid.
        blob = FileService.get_blob(file["created_by"], file["id"])
        return FileService.parse(file["name"], blob, True, file["created_by"])

    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=5) as exe:
        tasks = []
        for file in files:
            worker = image_to_base64 if file["mime_type"].find("image") >= 0 else parse_file
            tasks.append(loop.run_in_executor(exe, worker, file))
        # Await inside the with-block: ThreadPoolExecutor.__exit__ calls
        # shutdown(wait=True), which would otherwise block the event loop
        # until every submitted task finished.
        return await asyncio.gather(*tasks)
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None): def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
agent_ids = agent_id.split("-->") agent_ids = agent_id.split("-->")