diff --git a/agent/canvas.py b/agent/canvas.py index 9e95a5611..c120e20f6 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -18,6 +18,8 @@ import json import logging import re import time +import asyncio +import threading from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from functools import partial @@ -372,7 +374,7 @@ class Canvas(Graph): for k in kwargs.keys(): if k in ["query", "user_id", "files"] and kwargs[k]: if k == "files": - self.globals[f"sys.{k}"] = self.get_files(kwargs[k]) + self.globals[f"sys.{k}"] = await self.get_files(kwargs[k]) else: self.globals[f"sys.{k}"] = kwargs[k] if not self.globals["sys.conversation_turns"] : @@ -402,7 +404,7 @@ class Canvas(Graph): yield decorate("workflow_started", {"inputs": kwargs.get("inputs")}) self.retrieval.append({"chunks": {}, "doc_aggs": {}}) - def _run_batch(f, t): + async def _run_batch(f, t): if self.is_canceled(): msg = f"Task {self.task_id} has been canceled during batch execution." logging.info(msg) @@ -426,7 +428,7 @@ class Canvas(Graph): thr.append(executor.submit(cpn.invoke, **cpn.get_input())) i += 1 for t in thr: - t.result() + await asyncio.wrap_future(t) def _node_finished(cpn_obj): return decorate("node_finished",{ @@ -453,7 +455,7 @@ class Canvas(Graph): "component_type": self.get_component_type(self.path[i]), "thoughts": self.get_component_thoughts(self.path[i]) }) - _run_batch(idx, to) + await _run_batch(idx, to) to = len(self.path) # post processing of components invocation for i in range(idx, to): @@ -462,7 +464,24 @@ class Canvas(Graph): if cpn_obj.component_name.lower() == "message": if isinstance(cpn_obj.output("content"), partial): _m = "" - for m in cpn_obj.output("content")(): + def generator_producer(gen, q, loop): + try: + for item in gen: + loop.call_soon_threadsafe(q.put_nowait, item) + loop.call_soon_threadsafe(q.put_nowait, None) + except Exception as e: + logging.exception("Error in generator producer") + loop.call_soon_threadsafe(q.put_nowait, None) + + q = asyncio.Queue() + loop = asyncio.get_running_loop() + gen = cpn_obj.output("content")() + threading.Thread(target=generator_producer, args=(gen, q, loop)).start() + + while True: + m = await q.get() + if m is None: + break if not m: continue if m == "": @@ -621,21 +640,24 @@ class Canvas(Graph): def get_component_input_elements(self, cpnnm): return self.components[cpnnm]["obj"].get_input_elements() - def get_files(self, files: Union[None, list[dict]]) -> list[str]: + async def get_files(self, files: Union[None, list[dict]]) -> list[str]: from api.db.services.file_service import FileService if not files: return [] def image_to_base64(file): return "data:{};base64,{}".format(file["mime_type"], base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8")) - exe = ThreadPoolExecutor(max_workers=5) - threads = [] - for file in files: - if file["mime_type"].find("image") >=0: - threads.append(exe.submit(image_to_base64, file)) - continue - threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"])) - return [th.result() for th in threads] + + loop = asyncio.get_running_loop() + tasks = [] + with ThreadPoolExecutor(max_workers=5) as exe: + for file in files: + if file["mime_type"].find("image") >=0: + tasks.append(loop.run_in_executor(exe, image_to_base64, file)) + continue + tasks.append(loop.run_in_executor(exe, partial(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))) + + return await asyncio.gather(*tasks) def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None): agent_ids = agent_id.split("-->")