Feat:support tts in agent
This commit is contained in:
parent
a6681d6366
commit
ba60f015f2
1 changed files with 30 additions and 2 deletions
|
|
@ -16,6 +16,7 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import inspect
|
||||
import binascii
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
|
@ -28,7 +29,9 @@ from typing import Any, Union, Tuple
|
|||
from agent.component import component_class
|
||||
from agent.component.base import ComponentBase
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.constants import LLMType
|
||||
from common.misc_utils import get_uuid, hash_str2int
|
||||
from common.exceptions import TaskCanceledException
|
||||
from rag.prompts.generator import chunks_format
|
||||
|
|
@ -456,6 +459,7 @@ class Canvas(Graph):
|
|||
self.error = ""
|
||||
idx = len(self.path) - 1
|
||||
partials = []
|
||||
tts_mdl = None
|
||||
while idx < len(self.path):
|
||||
to = len(self.path)
|
||||
for i in range(idx, to):
|
||||
|
|
@ -473,8 +477,11 @@ class Canvas(Graph):
|
|||
cpn = self.get_component(self.path[i])
|
||||
cpn_obj = self.get_component_obj(self.path[i])
|
||||
if cpn_obj.component_name.lower() == "message":
|
||||
if cpn_obj.get_param("auto_play"):
|
||||
tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
|
||||
if isinstance(cpn_obj.output("content"), partial):
|
||||
_m = ""
|
||||
buff_m = ""
|
||||
stream = cpn_obj.output("content")()
|
||||
if inspect.isasyncgen(stream):
|
||||
async for m in stream:
|
||||
|
|
@ -485,7 +492,12 @@ class Canvas(Graph):
|
|||
elif m == "</think>":
|
||||
yield decorate("message", {"content": "", "end_to_think": True})
|
||||
else:
|
||||
yield decorate("message", {"content": m})
|
||||
buff_m += m
|
||||
if len(buff_m) > 16:
|
||||
yield decorate("message", {"content": m,"audio_binary":self.tts(tts_mdl, buff_m)})
|
||||
buff_m = ""
|
||||
else:
|
||||
yield decorate("message", {"content": m})
|
||||
_m += m
|
||||
else:
|
||||
for m in stream:
|
||||
|
|
@ -496,8 +508,16 @@ class Canvas(Graph):
|
|||
elif m == "</think>":
|
||||
yield decorate("message", {"content": "", "end_to_think": True})
|
||||
else:
|
||||
yield decorate("message", {"content": m})
|
||||
buff_m += m
|
||||
if len(buff_m) > 16:
|
||||
yield decorate("message", {"content": m,"audio_binary":self.tts(tts_mdl, buff_m)})
|
||||
buff_m = ""
|
||||
else:
|
||||
yield decorate("message", {"content": m})
|
||||
_m += m
|
||||
if buff_m:
|
||||
yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
|
||||
buff_m = ""
|
||||
cpn_obj.set_output("content", _m)
|
||||
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
|
||||
else:
|
||||
|
|
@ -618,6 +638,14 @@ class Canvas(Graph):
|
|||
return False
|
||||
return True
|
||||
|
||||
def tts(self,tts_mdl, text):
|
||||
if not tts_mdl or not text:
|
||||
return None
|
||||
bin = b""
|
||||
for chunk in tts_mdl.tts(text):
|
||||
bin += chunk
|
||||
return binascii.hexlify(bin).decode("utf-8")
|
||||
|
||||
def get_history(self, window_size):
|
||||
convs = []
|
||||
if window_size <= 0:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue