Feat:support tts in agent
This commit is contained in:
parent
a6681d6366
commit
ba60f015f2
1 changed files with 30 additions and 2 deletions
|
|
@ -16,6 +16,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import inspect
|
import inspect
|
||||||
|
import binascii
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
@ -28,7 +29,9 @@ from typing import Any, Union, Tuple
|
||||||
from agent.component import component_class
|
from agent.component import component_class
|
||||||
from agent.component.base import ComponentBase
|
from agent.component.base import ComponentBase
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.task_service import has_canceled
|
from api.db.services.task_service import has_canceled
|
||||||
|
from common.constants import LLMType
|
||||||
from common.misc_utils import get_uuid, hash_str2int
|
from common.misc_utils import get_uuid, hash_str2int
|
||||||
from common.exceptions import TaskCanceledException
|
from common.exceptions import TaskCanceledException
|
||||||
from rag.prompts.generator import chunks_format
|
from rag.prompts.generator import chunks_format
|
||||||
|
|
@ -456,6 +459,7 @@ class Canvas(Graph):
|
||||||
self.error = ""
|
self.error = ""
|
||||||
idx = len(self.path) - 1
|
idx = len(self.path) - 1
|
||||||
partials = []
|
partials = []
|
||||||
|
tts_mdl = None
|
||||||
while idx < len(self.path):
|
while idx < len(self.path):
|
||||||
to = len(self.path)
|
to = len(self.path)
|
||||||
for i in range(idx, to):
|
for i in range(idx, to):
|
||||||
|
|
@ -473,8 +477,11 @@ class Canvas(Graph):
|
||||||
cpn = self.get_component(self.path[i])
|
cpn = self.get_component(self.path[i])
|
||||||
cpn_obj = self.get_component_obj(self.path[i])
|
cpn_obj = self.get_component_obj(self.path[i])
|
||||||
if cpn_obj.component_name.lower() == "message":
|
if cpn_obj.component_name.lower() == "message":
|
||||||
|
if cpn_obj.get_param("auto_play"):
|
||||||
|
tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
|
||||||
if isinstance(cpn_obj.output("content"), partial):
|
if isinstance(cpn_obj.output("content"), partial):
|
||||||
_m = ""
|
_m = ""
|
||||||
|
buff_m = ""
|
||||||
stream = cpn_obj.output("content")()
|
stream = cpn_obj.output("content")()
|
||||||
if inspect.isasyncgen(stream):
|
if inspect.isasyncgen(stream):
|
||||||
async for m in stream:
|
async for m in stream:
|
||||||
|
|
@ -485,7 +492,12 @@ class Canvas(Graph):
|
||||||
elif m == "</think>":
|
elif m == "</think>":
|
||||||
yield decorate("message", {"content": "", "end_to_think": True})
|
yield decorate("message", {"content": "", "end_to_think": True})
|
||||||
else:
|
else:
|
||||||
yield decorate("message", {"content": m})
|
buff_m += m
|
||||||
|
if len(buff_m) > 16:
|
||||||
|
yield decorate("message", {"content": m,"audio_binary":self.tts(tts_mdl, buff_m)})
|
||||||
|
buff_m = ""
|
||||||
|
else:
|
||||||
|
yield decorate("message", {"content": m})
|
||||||
_m += m
|
_m += m
|
||||||
else:
|
else:
|
||||||
for m in stream:
|
for m in stream:
|
||||||
|
|
@ -496,8 +508,16 @@ class Canvas(Graph):
|
||||||
elif m == "</think>":
|
elif m == "</think>":
|
||||||
yield decorate("message", {"content": "", "end_to_think": True})
|
yield decorate("message", {"content": "", "end_to_think": True})
|
||||||
else:
|
else:
|
||||||
yield decorate("message", {"content": m})
|
buff_m += m
|
||||||
|
if len(buff_m) > 16:
|
||||||
|
yield decorate("message", {"content": m,"audio_binary":self.tts(tts_mdl, buff_m)})
|
||||||
|
buff_m = ""
|
||||||
|
else:
|
||||||
|
yield decorate("message", {"content": m})
|
||||||
_m += m
|
_m += m
|
||||||
|
if buff_m:
|
||||||
|
yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
|
||||||
|
buff_m = ""
|
||||||
cpn_obj.set_output("content", _m)
|
cpn_obj.set_output("content", _m)
|
||||||
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
|
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
|
||||||
else:
|
else:
|
||||||
|
|
@ -618,6 +638,14 @@ class Canvas(Graph):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def tts(self,tts_mdl, text):
|
||||||
|
if not tts_mdl or not text:
|
||||||
|
return None
|
||||||
|
bin = b""
|
||||||
|
for chunk in tts_mdl.tts(text):
|
||||||
|
bin += chunk
|
||||||
|
return binascii.hexlify(bin).decode("utf-8")
|
||||||
|
|
||||||
def get_history(self, window_size):
|
def get_history(self, window_size):
|
||||||
convs = []
|
convs = []
|
||||||
if window_size <= 0:
|
if window_size <= 0:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue