async def _process_stream(self, m, buff_m, _m, decorate, tts_mdl):
    """Fold one streamed chunk into the running buffers and build its event.

    Args:
        m: The chunk just read from the component's output stream.
        buff_m: Text accumulated since the last audio flush.
        _m: The complete message text accumulated so far.
        decorate: Callable ``decorate(event_name, payload)`` that wraps a
            payload dict into the event object the caller yields.
        tts_mdl: Text-to-speech model forwarded to ``self.tts``; may be falsy.

    Returns:
        A ``(buff_m, _m, event)`` tuple with the updated buffers. ``event`` is
        ``None`` when the chunk is empty and nothing should be yielded.
    """
    if not m:
        return buff_m, _m, None

    # NOTE(review): both marker literals are empty in this view, so neither
    # branch can fire after the guard above. Presumably they held the model's
    # "start thinking" / "end thinking" tokens - confirm against the
    # repository before relying on them.
    if m == "":
        return buff_m, _m, decorate("message", {"content": "", "start_to_think": True})

    if m == "":
        return buff_m, _m, decorate("message", {"content": "", "end_to_think": True})

    buff_m += m
    _m += m

    # Short buffers are emitted as text only; audio is synthesised once more
    # than 16 characters have piled up, and the audio buffer is then reset.
    if len(buff_m) <= 16:
        return buff_m, _m, decorate("message", {"content": m})

    event = decorate(
        "message",
        {
            "content": m,
            "audio_binary": self.tts(tts_mdl, buff_m),
        },
    )
    return "", _m, event


def tts(self, tts_mdl, text):
    """Synthesise ``text`` with ``tts_mdl`` and return the audio as hex.

    The text is sanitised first (un-encodable code points, control
    characters and emoji removed, whitespace collapsed, length capped at 500
    characters). Synthesis is best-effort: any failure is logged and reported
    as ``None`` instead of propagating.

    Args:
        tts_mdl: Object exposing ``tts(text)`` that yields ``bytes`` chunks.
        text: Raw text to speak.

    Returns:
        Hex-encoded audio as ``str``, or ``None`` when there is no model,
        nothing speakable remains after cleaning, or synthesis raised.
    """

    def clean_tts_text(text: str) -> str:
        """Return ``text`` reduced to what a TTS engine can pronounce."""
        if not text:
            return ""

        # Round-tripping through UTF-8 drops lone surrogates.
        text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")

        # Strip C0 control characters and DEL, but keep \t, \n and \r so the
        # whitespace pass below can turn them into single spaces.
        text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)

        # An emoji may be followed by variation selectors or zero-width
        # joiners; remove those together with it so no invisible residue
        # is sent to the engine.
        emoji_pattern = re.compile(
            "(?:["
            "\U0001F600-\U0001F64F"
            "\U0001F300-\U0001F5FF"
            "\U0001F680-\U0001F6FF"
            "\U0001F1E0-\U0001F1FF"
            "\U00002700-\U000027BF"
            "\U0001F900-\U0001F9FF"
            "\U0001FA70-\U0001FAFF"
            "][\uFE0E\uFE0F\u200D]*)+"
        )
        text = emoji_pattern.sub("", text)

        text = re.sub(r"\s+", " ", text).strip()

        max_len = 500
        return text[:max_len]

    if not tts_mdl or not text:
        return None
    text = clean_tts_text(text)
    if not text:
        return None
    try:
        audio = b"".join(tts_mdl.tts(text))
    except Exception as e:
        # Deliberate best-effort boundary: a TTS outage must not break the
        # text stream, so log with traceback and report "no audio".
        logging.exception("TTS failed: %s, text=%r", e, text)
        return None
    return binascii.hexlify(audio).decode("utf-8")
# Patterns are compiled once at import time instead of on every call.

# C0 control characters and DEL; \t, \n and \r are kept so that whitespace
# normalisation can turn them into single spaces.
_TTS_CONTROL_RE = re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]")

# Runs of emoji, each optionally followed by variation selectors or
# zero-width joiners, so removing an emoji leaves no invisible residue.
_TTS_EMOJI_RE = re.compile(
    "(?:["
    "\U0001F600-\U0001F64F"
    "\U0001F300-\U0001F5FF"
    "\U0001F680-\U0001F6FF"
    "\U0001F1E0-\U0001F1FF"
    "\U00002700-\U000027BF"
    "\U0001F900-\U0001F9FF"
    "\U0001FA70-\U0001FAFF"
    "][\uFE0E\uFE0F\u200D]*)+"
)

_TTS_WHITESPACE_RE = re.compile(r"\s+")


def clean_tts_text(text: str, max_len: int = 500) -> str:
    """Return ``text`` reduced to what a TTS engine can pronounce.

    Steps, in order: drop code points that cannot be UTF-8 encoded (lone
    surrogates), remove control characters, remove emoji together with their
    trailing modifiers, collapse whitespace runs to one space, trim, and cap
    the length.

    Args:
        text: Raw text; ``None`` or empty yields ``""``.
        max_len: Maximum number of characters kept (default 500).

    Returns:
        The cleaned text, possibly empty.
    """
    if not text:
        return ""

    text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
    text = _TTS_CONTROL_RE.sub("", text)
    text = _TTS_EMOJI_RE.sub("", text)
    text = _TTS_WHITESPACE_RE.sub(" ", text).strip()
    return text[:max_len]


def tts(tts_mdl, text):
    """Synthesise ``text`` with ``tts_mdl`` and return the audio as hex.

    Synthesis is best-effort: any failure is logged and reported as ``None``
    instead of propagating.

    Args:
        tts_mdl: Object exposing ``tts(text)`` that yields ``bytes`` chunks.
        text: Raw text to speak; it is sanitised with ``clean_tts_text``.

    Returns:
        Hex-encoded audio as ``str``, or ``None`` when there is no model,
        nothing speakable remains after cleaning, or synthesis raised.
    """
    if not tts_mdl or not text:
        return None
    text = clean_tts_text(text)
    if not text:
        return None
    try:
        audio = b"".join(tts_mdl.tts(text))
    except Exception as e:
        # Deliberate best-effort boundary: a TTS outage must not break the
        # answer itself, so log with traceback and report "no audio".
        logging.exception("TTS failed: %s, text=%r", e, text)
        return None
    return binascii.hexlify(audio).decode("utf-8")