blocks the main thread (event-loop) while running Agent tasks such as Canvas.run
This commit is contained in:
parent
918d5a9ff8
commit
cd53716848
1 changed files with 36 additions and 14 deletions
|
|
@ -18,6 +18,8 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
@ -372,7 +374,7 @@ class Canvas(Graph):
|
||||||
for k in kwargs.keys():
|
for k in kwargs.keys():
|
||||||
if k in ["query", "user_id", "files"] and kwargs[k]:
|
if k in ["query", "user_id", "files"] and kwargs[k]:
|
||||||
if k == "files":
|
if k == "files":
|
||||||
self.globals[f"sys.{k}"] = self.get_files(kwargs[k])
|
self.globals[f"sys.{k}"] = await self.get_files(kwargs[k])
|
||||||
else:
|
else:
|
||||||
self.globals[f"sys.{k}"] = kwargs[k]
|
self.globals[f"sys.{k}"] = kwargs[k]
|
||||||
if not self.globals["sys.conversation_turns"] :
|
if not self.globals["sys.conversation_turns"] :
|
||||||
|
|
@ -402,7 +404,7 @@ class Canvas(Graph):
|
||||||
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
|
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
|
||||||
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
|
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
|
||||||
|
|
||||||
def _run_batch(f, t):
|
async def _run_batch(f, t):
|
||||||
if self.is_canceled():
|
if self.is_canceled():
|
||||||
msg = f"Task {self.task_id} has been canceled during batch execution."
|
msg = f"Task {self.task_id} has been canceled during batch execution."
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
@ -426,7 +428,7 @@ class Canvas(Graph):
|
||||||
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
|
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
|
||||||
i += 1
|
i += 1
|
||||||
for t in thr:
|
for t in thr:
|
||||||
t.result()
|
await asyncio.wrap_future(t)
|
||||||
|
|
||||||
def _node_finished(cpn_obj):
|
def _node_finished(cpn_obj):
|
||||||
return decorate("node_finished",{
|
return decorate("node_finished",{
|
||||||
|
|
@ -453,7 +455,7 @@ class Canvas(Graph):
|
||||||
"component_type": self.get_component_type(self.path[i]),
|
"component_type": self.get_component_type(self.path[i]),
|
||||||
"thoughts": self.get_component_thoughts(self.path[i])
|
"thoughts": self.get_component_thoughts(self.path[i])
|
||||||
})
|
})
|
||||||
_run_batch(idx, to)
|
await _run_batch(idx, to)
|
||||||
to = len(self.path)
|
to = len(self.path)
|
||||||
# post processing of components invocation
|
# post processing of components invocation
|
||||||
for i in range(idx, to):
|
for i in range(idx, to):
|
||||||
|
|
@ -462,7 +464,24 @@ class Canvas(Graph):
|
||||||
if cpn_obj.component_name.lower() == "message":
|
if cpn_obj.component_name.lower() == "message":
|
||||||
if isinstance(cpn_obj.output("content"), partial):
|
if isinstance(cpn_obj.output("content"), partial):
|
||||||
_m = ""
|
_m = ""
|
||||||
for m in cpn_obj.output("content")():
|
def generator_producer(gen, q, loop):
|
||||||
|
try:
|
||||||
|
for item in gen:
|
||||||
|
loop.call_soon_threadsafe(q.put_nowait, item)
|
||||||
|
loop.call_soon_threadsafe(q.put_nowait, None)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception("Error in generator producer")
|
||||||
|
loop.call_soon_threadsafe(q.put_nowait, None)
|
||||||
|
|
||||||
|
q = asyncio.Queue()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
gen = cpn_obj.output("content")()
|
||||||
|
threading.Thread(target=generator_producer, args=(gen, q, loop)).start()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
m = await q.get()
|
||||||
|
if m is None:
|
||||||
|
break
|
||||||
if not m:
|
if not m:
|
||||||
continue
|
continue
|
||||||
if m == "<think>":
|
if m == "<think>":
|
||||||
|
|
@ -621,21 +640,24 @@ class Canvas(Graph):
|
||||||
def get_component_input_elements(self, cpnnm):
|
def get_component_input_elements(self, cpnnm):
|
||||||
return self.components[cpnnm]["obj"].get_input_elements()
|
return self.components[cpnnm]["obj"].get_input_elements()
|
||||||
|
|
||||||
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
|
async def get_files(self, files: Union[None, list[dict]]) -> list[str]:
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
if not files:
|
if not files:
|
||||||
return []
|
return []
|
||||||
def image_to_base64(file):
|
def image_to_base64(file):
|
||||||
return "data:{};base64,{}".format(file["mime_type"],
|
return "data:{};base64,{}".format(file["mime_type"],
|
||||||
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
||||||
exe = ThreadPoolExecutor(max_workers=5)
|
|
||||||
threads = []
|
loop = asyncio.get_running_loop()
|
||||||
|
tasks = []
|
||||||
|
with ThreadPoolExecutor(max_workers=5) as exe:
|
||||||
for file in files:
|
for file in files:
|
||||||
if file["mime_type"].find("image") >=0:
|
if file["mime_type"].find("image") >=0:
|
||||||
threads.append(exe.submit(image_to_base64, file))
|
tasks.append(loop.run_in_executor(exe, image_to_base64, file))
|
||||||
continue
|
continue
|
||||||
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
|
tasks.append(loop.run_in_executor(exe, partial(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"])))
|
||||||
return [th.result() for th in threads]
|
|
||||||
|
return await asyncio.gather(*tasks)
|
||||||
|
|
||||||
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
|
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
|
||||||
agent_ids = agent_id.split("-->")
|
agent_ids = agent_id.split("-->")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue