Feat:support tts in agent

This commit is contained in:
buua436 2025-12-03 09:34:09 +08:00
parent a6681d6366
commit ba60f015f2

View file

@ -16,6 +16,7 @@
import asyncio import asyncio
import base64 import base64
import inspect import inspect
import binascii
import json import json
import logging import logging
import re import re
@ -28,7 +29,9 @@ from typing import Any, Union, Tuple
from agent.component import component_class from agent.component import component_class
from agent.component.base import ComponentBase from agent.component.base import ComponentBase
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import has_canceled from api.db.services.task_service import has_canceled
from common.constants import LLMType
from common.misc_utils import get_uuid, hash_str2int from common.misc_utils import get_uuid, hash_str2int
from common.exceptions import TaskCanceledException from common.exceptions import TaskCanceledException
from rag.prompts.generator import chunks_format from rag.prompts.generator import chunks_format
@ -456,6 +459,7 @@ class Canvas(Graph):
self.error = "" self.error = ""
idx = len(self.path) - 1 idx = len(self.path) - 1
partials = [] partials = []
tts_mdl = None
while idx < len(self.path): while idx < len(self.path):
to = len(self.path) to = len(self.path)
for i in range(idx, to): for i in range(idx, to):
@ -473,8 +477,11 @@ class Canvas(Graph):
cpn = self.get_component(self.path[i]) cpn = self.get_component(self.path[i])
cpn_obj = self.get_component_obj(self.path[i]) cpn_obj = self.get_component_obj(self.path[i])
if cpn_obj.component_name.lower() == "message": if cpn_obj.component_name.lower() == "message":
if cpn_obj.get_param("auto_play"):
tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
if isinstance(cpn_obj.output("content"), partial): if isinstance(cpn_obj.output("content"), partial):
_m = "" _m = ""
buff_m = ""
stream = cpn_obj.output("content")() stream = cpn_obj.output("content")()
if inspect.isasyncgen(stream): if inspect.isasyncgen(stream):
async for m in stream: async for m in stream:
@ -485,7 +492,12 @@ class Canvas(Graph):
elif m == "</think>": elif m == "</think>":
yield decorate("message", {"content": "", "end_to_think": True}) yield decorate("message", {"content": "", "end_to_think": True})
else: else:
yield decorate("message", {"content": m}) buff_m += m
if len(buff_m) > 16:
yield decorate("message", {"content": m,"audio_binary":self.tts(tts_mdl, buff_m)})
buff_m = ""
else:
yield decorate("message", {"content": m})
_m += m _m += m
else: else:
for m in stream: for m in stream:
@ -496,8 +508,16 @@ class Canvas(Graph):
elif m == "</think>": elif m == "</think>":
yield decorate("message", {"content": "", "end_to_think": True}) yield decorate("message", {"content": "", "end_to_think": True})
else: else:
yield decorate("message", {"content": m}) buff_m += m
if len(buff_m) > 16:
yield decorate("message", {"content": m,"audio_binary":self.tts(tts_mdl, buff_m)})
buff_m = ""
else:
yield decorate("message", {"content": m})
_m += m _m += m
if buff_m:
yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
buff_m = ""
cpn_obj.set_output("content", _m) cpn_obj.set_output("content", _m)
cite = re.search(r"\[ID:[ 0-9]+\]", _m) cite = re.search(r"\[ID:[ 0-9]+\]", _m)
else: else:
@ -618,6 +638,14 @@ class Canvas(Graph):
return False return False
return True return True
def tts(self,tts_mdl, text):
if not tts_mdl or not text:
return None
bin = b""
for chunk in tts_mdl.tts(text):
bin += chunk
return binascii.hexlify(bin).decode("utf-8")
def get_history(self, window_size): def get_history(self, window_size):
convs = [] convs = []
if window_size <= 0: if window_size <= 0: