Feat:support tts in agent

This commit is contained in:
buua436 2025-12-03 09:34:09 +08:00
parent a6681d6366
commit ba60f015f2

View file

@ -16,6 +16,7 @@
import asyncio
import base64
import inspect
import binascii
import json
import logging
import re
@ -28,7 +29,9 @@ from typing import Any, Union, Tuple
from agent.component import component_class
from agent.component.base import ComponentBase
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import has_canceled
from common.constants import LLMType
from common.misc_utils import get_uuid, hash_str2int
from common.exceptions import TaskCanceledException
from rag.prompts.generator import chunks_format
@ -456,6 +459,7 @@ class Canvas(Graph):
self.error = ""
idx = len(self.path) - 1
partials = []
tts_mdl = None
while idx < len(self.path):
to = len(self.path)
for i in range(idx, to):
@ -473,8 +477,11 @@ class Canvas(Graph):
cpn = self.get_component(self.path[i])
cpn_obj = self.get_component_obj(self.path[i])
if cpn_obj.component_name.lower() == "message":
if cpn_obj.get_param("auto_play"):
tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
if isinstance(cpn_obj.output("content"), partial):
_m = ""
buff_m = ""
stream = cpn_obj.output("content")()
if inspect.isasyncgen(stream):
async for m in stream:
@ -485,7 +492,12 @@ class Canvas(Graph):
elif m == "</think>":
yield decorate("message", {"content": "", "end_to_think": True})
else:
yield decorate("message", {"content": m})
buff_m += m
if len(buff_m) > 16:
yield decorate("message", {"content": m,"audio_binary":self.tts(tts_mdl, buff_m)})
buff_m = ""
else:
yield decorate("message", {"content": m})
_m += m
else:
for m in stream:
@ -496,8 +508,16 @@ class Canvas(Graph):
elif m == "</think>":
yield decorate("message", {"content": "", "end_to_think": True})
else:
yield decorate("message", {"content": m})
buff_m += m
if len(buff_m) > 16:
yield decorate("message", {"content": m,"audio_binary":self.tts(tts_mdl, buff_m)})
buff_m = ""
else:
yield decorate("message", {"content": m})
_m += m
if buff_m:
yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
buff_m = ""
cpn_obj.set_output("content", _m)
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
else:
@ -618,6 +638,14 @@ class Canvas(Graph):
return False
return True
def tts(self,tts_mdl, text):
if not tts_mdl or not text:
return None
bin = b""
for chunk in tts_mdl.tts(text):
bin += chunk
return binascii.hexlify(bin).decode("utf-8")
def get_history(self, window_size):
convs = []
if window_size <= 0: