diff --git a/agent/canvas.py b/agent/canvas.py index c120e20f6..d83c63c4f 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -410,13 +410,15 @@ class Canvas(Graph): logging.info(msg) raise TaskCanceledException(msg) - with ThreadPoolExecutor(max_workers=5) as executor: - thr = [] + loop = asyncio.get_running_loop() + tasks = [] + executor = ThreadPoolExecutor(max_workers=5) + try: i = f while i < t: cpn = self.get_component_obj(self.path[i]) if cpn.component_name.lower() in ["begin", "userfillup"]: - thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {}))) + tasks.append(loop.run_in_executor(executor, partial(cpn.invoke, inputs=kwargs.get("inputs", {})))) i += 1 else: for _, ele in cpn.get_input_elements().items(): @@ -425,10 +427,11 @@ class Canvas(Graph): t -= 1 break else: - thr.append(executor.submit(cpn.invoke, **cpn.get_input())) + tasks.append(loop.run_in_executor(executor, partial(cpn.invoke, **cpn.get_input()))) i += 1 - for t in thr: - await asyncio.wrap_future(t) + await asyncio.gather(*tasks) + finally: + executor.shutdown(wait=False) def _node_finished(cpn_obj): return decorate("node_finished",{ @@ -468,9 +471,9 @@ class Canvas(Graph): try: for item in gen: loop.call_soon_threadsafe(q.put_nowait, item) - loop.call_soon_threadsafe(q.put_nowait, None) except Exception as e: logging.exception("Error in generator producer") + finally: loop.call_soon_threadsafe(q.put_nowait, None) q = asyncio.Queue() @@ -650,14 +653,13 @@ class Canvas(Graph): loop = asyncio.get_running_loop() tasks = [] - with ThreadPoolExecutor(max_workers=5) as exe: - for file in files: - if file["mime_type"].find("image") >=0: - tasks.append(loop.run_in_executor(exe, image_to_base64, file)) - continue - tasks.append(loop.run_in_executor(exe, partial(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))) - - return await asyncio.gather(*tasks) + for file in files: + if file["mime_type"].find("image") >=0: + tasks.append(loop.run_in_executor(None, image_to_base64, file)) + continue + tasks.append(loop.run_in_executor(None, partial(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))) + + return await asyncio.gather(*tasks) def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None): agent_ids = agent_id.split("-->")