Merge 3f853ee722 into 6ea4248bdc
This commit is contained in:
commit
c68f1602e3
1 changed files with 31 additions and 9 deletions
|
|
@ -17,6 +17,8 @@ import json
|
|||
import logging
|
||||
import re
|
||||
import time
|
||||
import asyncio
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
|
@ -402,19 +404,21 @@ class Canvas(Graph):
|
|||
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
|
||||
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
|
||||
|
||||
def _run_batch(f, t):
|
||||
async def _run_batch(f, t):
|
||||
if self.is_canceled():
|
||||
msg = f"Task {self.task_id} has been canceled during batch execution."
|
||||
logging.info(msg)
|
||||
raise TaskCanceledException(msg)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
thr = []
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = []
|
||||
executor = ThreadPoolExecutor(max_workers=5)
|
||||
try:
|
||||
i = f
|
||||
while i < t:
|
||||
cpn = self.get_component_obj(self.path[i])
|
||||
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
||||
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
|
||||
tasks.append(loop.run_in_executor(executor, partial(cpn.invoke, inputs=kwargs.get("inputs", {}))))
|
||||
i += 1
|
||||
else:
|
||||
for _, ele in cpn.get_input_elements().items():
|
||||
|
|
@ -423,10 +427,11 @@ class Canvas(Graph):
|
|||
t -= 1
|
||||
break
|
||||
else:
|
||||
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
|
||||
tasks.append(loop.run_in_executor(executor, partial(cpn.invoke, **cpn.get_input())))
|
||||
i += 1
|
||||
for t in thr:
|
||||
t.result()
|
||||
await asyncio.gather(*tasks)
|
||||
finally:
|
||||
executor.shutdown(wait=False)
|
||||
|
||||
def _node_finished(cpn_obj):
|
||||
return decorate("node_finished",{
|
||||
|
|
@ -453,7 +458,7 @@ class Canvas(Graph):
|
|||
"component_type": self.get_component_type(self.path[i]),
|
||||
"thoughts": self.get_component_thoughts(self.path[i])
|
||||
})
|
||||
_run_batch(idx, to)
|
||||
await _run_batch(idx, to)
|
||||
to = len(self.path)
|
||||
# post processing of components invocation
|
||||
for i in range(idx, to):
|
||||
|
|
@ -462,7 +467,24 @@ class Canvas(Graph):
|
|||
if cpn_obj.component_name.lower() == "message":
|
||||
if isinstance(cpn_obj.output("content"), partial):
|
||||
_m = ""
|
||||
for m in cpn_obj.output("content")():
|
||||
def generator_producer(gen, q, loop):
|
||||
try:
|
||||
for item in gen:
|
||||
loop.call_soon_threadsafe(q.put_nowait, item)
|
||||
except Exception:
|
||||
logging.exception("Error in generator producer")
|
||||
finally:
|
||||
loop.call_soon_threadsafe(q.put_nowait, None)
|
||||
|
||||
q = asyncio.Queue()
|
||||
loop = asyncio.get_running_loop()
|
||||
gen = cpn_obj.output("content")()
|
||||
threading.Thread(target=generator_producer, args=(gen, q, loop)).start()
|
||||
|
||||
while True:
|
||||
m = await q.get()
|
||||
if m is None:
|
||||
break
|
||||
if not m:
|
||||
continue
|
||||
if m == "<think>":
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue