This commit is contained in:
zhaobai 2025-12-01 14:07:11 +08:00 committed by GitHub
commit c68f1602e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -17,6 +17,8 @@ import json
import logging
import re
import time
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
@ -402,19 +404,21 @@ class Canvas(Graph):
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
def _run_batch(f, t):
async def _run_batch(f, t):
if self.is_canceled():
msg = f"Task {self.task_id} has been canceled during batch execution."
logging.info(msg)
raise TaskCanceledException(msg)
with ThreadPoolExecutor(max_workers=5) as executor:
thr = []
loop = asyncio.get_running_loop()
tasks = []
executor = ThreadPoolExecutor(max_workers=5)
try:
i = f
while i < t:
cpn = self.get_component_obj(self.path[i])
if cpn.component_name.lower() in ["begin", "userfillup"]:
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
tasks.append(loop.run_in_executor(executor, partial(cpn.invoke, inputs=kwargs.get("inputs", {}))))
i += 1
else:
for _, ele in cpn.get_input_elements().items():
@ -423,10 +427,11 @@ class Canvas(Graph):
t -= 1
break
else:
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
tasks.append(loop.run_in_executor(executor, partial(cpn.invoke, **cpn.get_input())))
i += 1
for t in thr:
t.result()
await asyncio.gather(*tasks)
finally:
executor.shutdown(wait=False)
def _node_finished(cpn_obj):
return decorate("node_finished",{
@ -453,7 +458,7 @@ class Canvas(Graph):
"component_type": self.get_component_type(self.path[i]),
"thoughts": self.get_component_thoughts(self.path[i])
})
_run_batch(idx, to)
await _run_batch(idx, to)
to = len(self.path)
# post processing of components invocation
for i in range(idx, to):
@ -462,7 +467,24 @@ class Canvas(Graph):
if cpn_obj.component_name.lower() == "message":
if isinstance(cpn_obj.output("content"), partial):
_m = ""
for m in cpn_obj.output("content")():
def generator_producer(gen, q, loop):
try:
for item in gen:
loop.call_soon_threadsafe(q.put_nowait, item)
except Exception:
logging.exception("Error in generator producer")
finally:
loop.call_soon_threadsafe(q.put_nowait, None)
q = asyncio.Queue()
loop = asyncio.get_running_loop()
gen = cpn_obj.output("content")()
threading.Thread(target=generator_producer, args=(gen, q, loop)).start()
while True:
m = await q.get()
if m is None:
break
if not m:
continue
if m == "<think>":