diff --git a/agent/canvas.py b/agent/canvas.py index 5344d70c3..59f9e95cf 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -17,6 +17,8 @@ import json import logging import re import time +import asyncio +import threading from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from functools import partial @@ -402,19 +404,21 @@ class Canvas(Graph): yield decorate("workflow_started", {"inputs": kwargs.get("inputs")}) self.retrieval.append({"chunks": {}, "doc_aggs": {}}) - def _run_batch(f, t): + async def _run_batch(f, t): if self.is_canceled(): msg = f"Task {self.task_id} has been canceled during batch execution." logging.info(msg) raise TaskCanceledException(msg) - with ThreadPoolExecutor(max_workers=5) as executor: - thr = [] + loop = asyncio.get_running_loop() + tasks = [] + executor = ThreadPoolExecutor(max_workers=5) + try: i = f while i < t: cpn = self.get_component_obj(self.path[i]) if cpn.component_name.lower() in ["begin", "userfillup"]: - thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {}))) + tasks.append(loop.run_in_executor(executor, partial(cpn.invoke, inputs=kwargs.get("inputs", {})))) i += 1 else: for _, ele in cpn.get_input_elements().items(): @@ -423,10 +427,11 @@ class Canvas(Graph): t -= 1 break else: - thr.append(executor.submit(cpn.invoke, **cpn.get_input())) + tasks.append(loop.run_in_executor(executor, partial(cpn.invoke, **cpn.get_input()))) i += 1 - for t in thr: - t.result() + await asyncio.gather(*tasks) + finally: + executor.shutdown(wait=False) def _node_finished(cpn_obj): return decorate("node_finished",{ @@ -453,7 +458,7 @@ class Canvas(Graph): "component_type": self.get_component_type(self.path[i]), "thoughts": self.get_component_thoughts(self.path[i]) }) - _run_batch(idx, to) + await _run_batch(idx, to) to = len(self.path) # post processing of components invocation for i in range(idx, to): @@ -462,7 +467,24 @@ class Canvas(Graph): if cpn_obj.component_name.lower() == "message": if isinstance(cpn_obj.output("content"), partial): _m = "" - for m in cpn_obj.output("content")(): + def generator_producer(gen, q, loop): + try: + for item in gen: + loop.call_soon_threadsafe(q.put_nowait, item) + except Exception: + logging.exception("Error in generator producer") + finally: + loop.call_soon_threadsafe(q.put_nowait, None) + + q = asyncio.Queue() + loop = asyncio.get_running_loop() + gen = cpn_obj.output("content")() + threading.Thread(target=generator_producer, args=(gen, q, loop)).start() + + while True: + m = await q.get() + if m is None: + break if not m: continue if m == "":