From ce033dca2bee2d8d3013d894569f992e466ae883 Mon Sep 17 00:00:00 2001 From: killua4396 <1223086337@qq.com> Date: Fri, 24 May 2024 15:08:55 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=90=8E=E7=AB=AF=E5=B0=81=E8=A3=85tts?= =?UTF-8?q?=EF=BC=8C=E5=8F=AF=E4=BB=A5=E9=80=9A=E8=BF=87=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=88=87=E6=8D=A2openvoice=E5=92=8Cvits?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/controllers/chat_controller.py | 42 +++++------------------------- app/controllers/user_controller.py | 2 +- app/dependencies/tts.py | 24 ++++++++++++----- config/development.py | 1 + utils/tts/openvoice_utils.py | 33 +++++++++++------------ utils/tts/vits_utils.py | 9 ++++--- 6 files changed, 49 insertions(+), 62 deletions(-) diff --git a/app/controllers/chat_controller.py b/app/controllers/chat_controller.py index b5007bf..d56ae10 100644 --- a/app/controllers/chat_controller.py +++ b/app/controllers/chat_controller.py @@ -11,6 +11,7 @@ from datetime import datetime from utils.xf_asr_utils import generate_xf_asr_url from config import get_config import numpy as np +import struct import uuid import json import asyncio @@ -319,20 +320,11 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis response_message = {"type": "text", "code":200, "msg": sentence} await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息 elif response_type == RESPONSE_AUDIO: - if target_se.size == 0: - audio,sr = tts._base_tts(text=sentence, - noise_scale=tts_info["noise_scale"], - noise_scale_w=tts_info["noise_scale_w"], - speed=tts_info["speed"]) - else: - audio,sr = tts.synthesize(text=sentence, - noise_scale=tts_info["noise_scale"], - noise_scale_w=tts_info["noise_scale_w"], - speed=tts_info["speed"], - target_se=target_se) + audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se) response_message = {"type": "text", "code":200, "msg": sentence} - await ws.send_bytes(audio) #返回音频数据 - await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息 + response_bytes = json.dumps(response_message, ensure_ascii=False).encode('utf-8') + header = struct.pack('!II',len(response_bytes),len(audio)) + message_bytes = header + response_bytes + audio logger.debug(f"websocket返回: {sentence}") if is_end: logger.debug(f"llm返回结果: {llm_response}") @@ -500,17 +492,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis response_message = {"type": "text", "code":200, "msg": sentence} await ws.send_text(json.dumps(response_message, ensure_ascii=False)) elif response_type == RESPONSE_AUDIO: - if target_se.size == 0: - audio,sr = tts._base_tts(text=sentence, - noise_scale=tts_info["noise_scale"], - noise_scale_w=tts_info["noise_scale_w"], - speed=tts_info["speed"]) - else: - audio,sr = tts.synthesize(text=sentence, - noise_scale=tts_info["noise_scale"], - noise_scale_w=tts_info["noise_scale_w"], - speed=tts_info["speed"], - target_se=target_se) + audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se) response_message = {"type": "text", "code":200, "msg": sentence} await ws.send_bytes(audio) await ws.send_text(json.dumps(response_message, ensure_ascii=False)) @@ -684,17 +666,7 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re llm_response += chunk_data sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) for sentence in sentences: - if target_se.size == 0: - audio,sr = tts._base_tts(text=sentence, - noise_scale=tts_info["noise_scale"], - noise_scale_w=tts_info["noise_scale_w"], - speed=tts_info["speed"]) - else: - audio,sr = tts.synthesize(text=sentence, - noise_scale=tts_info["noise_scale"], - noise_scale_w=tts_info["noise_scale_w"], - speed=tts_info["speed"], - target_se=target_se) + audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se) text_response = {"type": "llm_text", "code": 200, "msg": sentence} await ws.send_bytes(audio) #返回音频二进制流数据 await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据 diff --git a/app/controllers/user_controller.py b/app/controllers/user_controller.py index 9788675..c7b2685 100644 --- a/app/controllers/user_controller.py +++ b/app/controllers/user_controller.py @@ -14,7 +14,7 @@ import io logger = get_logger() #依赖注入获取tts -tts = get_tts() +tts = get_tts("OPENVOICE") #创建用户 async def create_user_handler(user:UserCrateRequest, db: Session): diff --git a/app/dependencies/tts.py b/app/dependencies/tts.py index 616f82d..933b78d 100644 --- a/app/dependencies/tts.py +++ b/app/dependencies/tts.py @@ -1,11 +1,23 @@ -from utils.tts.openvoice_utils import TextToSpeech + from app.dependencies.logger import get_logger +from config import get_config logger = get_logger() +Config = get_config() + +from utils.tts.openvoice_utils import TextToSpeech +openvoice_tts = TextToSpeech(use_tone_convert=True,device='cuda') +logger.info("TTS_OPENVOICE 初始化成功") + +from utils.tts.vits_utils import TextToSpeech +vits_tts = TextToSpeech() +logger.info("TTS_VITS 初始化成功") + + #初始化全局tts对象 -tts = TextToSpeech(use_tone_convert=True,device='cuda') -logger.info("TTS初始化成功") - -def get_tts(): - return tts \ No newline at end of file +def get_tts(tts_type=Config.TTS_UTILS): + if tts_type == "OPENVOICE": + return openvoice_tts + elif tts_type == "VITS": + return vits_tts \ No newline at end of file diff --git a/config/development.py b/config/development.py index 8fd48e6..cffca47 100644 --- a/config/development.py +++ b/config/development.py @@ -2,6 +2,7 @@ class DevelopmentConfig: SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://takway:takway123456@127.0.0.1/takway?charset=utf8mb4" #mysql数据库连接配置 REDIS_URL = "redis://:takway@127.0.0.1:6379/0" #redis数据库连接配置 LOG_LEVEL = "DEBUG" #日志级别 + TTS_UTILS = "VITS" #TTS引擎配置,可选OPENVOICE或者VITS class UVICORN: HOST = "0.0.0.0" #uvicorn放行ip,0.0.0.0代表所有ip PORT = 8001 #uvicorn运行端口 diff --git a/utils/tts/openvoice_utils.py b/utils/tts/openvoice_utils.py index ab0604a..8cf3581 100644 --- a/utils/tts/openvoice_utils.py +++ b/utils/tts/openvoice_utils.py @@ -164,11 +164,14 @@ class TextToSpeech: """ return audio_data.cpu().detach().float().numpy() - def numpy2bytes(self, audio_data: np.ndarray): - """ - numpy类型转bytes - """ - return (audio_data*32768.0).astype(np.int32).tobytes() + def numpy2bytes(self, audio_data): + if isinstance(audio_data, np.ndarray): + if audio_data.dtype == np.dtype('float32'): + audio_data = np.int16(audio_data * np.iinfo(np.int16).max) + audio_data = audio_data.tobytes() + return audio_data + else: + raise TypeError("audio_data must be a numpy array") def _base_tts(self, text: str, @@ -292,14 +295,11 @@ class TextToSpeech: def synthesize(self, text: str, - sdp_ratio=0.2, - noise_scale=0.6, - noise_scale_w=0.8, - speed=1.0, - quite=True, - + tts_info, source_se: Optional[np.ndarray]=None, target_se: Optional[np.ndarray]=None, + sdp_ratio=0.2, + quite=True, tau :float=0.3, message :str="default"): """ @@ -316,11 +316,11 @@ class TextToSpeech: """ audio, sr = self._base_tts(text, sdp_ratio=sdp_ratio, - noise_scale=noise_scale, - noise_scale_w=noise_scale_w, - speed=speed, + noise_scale=tts_info['noise_scale'], + noise_scale_w=tts_info['noise_scale_w'], + speed=tts_info['speed'], quite=quite) - if self.use_tone_convert: + if self.use_tone_convert and target_se.size>0: tts_sr = self.base_tts_model.hps.data.sampling_rate converter_sr = self.tone_color_converter.hps.data.sampling_rate audio = F.resample(audio, tts_sr, converter_sr) @@ -342,4 +342,5 @@ class TextToSpeech: save_path: 保存路径 """ sf.write(save_path, audio, sample_rate) - print(f"Audio saved to {save_path}") \ No newline at end of file + print(f"Audio saved to {save_path}") + \ No newline at end of file diff --git a/utils/tts/vits_utils.py b/utils/tts/vits_utils.py index ddcc55c..c496751 100644 --- a/utils/tts/vits_utils.py +++ b/utils/tts/vits_utils.py @@ -2,6 +2,7 @@ import os import numpy as np import torch from torch import LongTensor +from typing import Optional import soundfile as sf # vits from .vits import utils, commons @@ -79,19 +80,19 @@ class TextToSpeech: print(f"Synthesis time: {time.time() - start_time} s") return audio - def synthesize(self, text, language, speaker_id, noise_scale, noise_scale_w, length_scale, save_audio=False, return_bytes=False): + def synthesize(self, text, tts_info,target_se: Optional[np.ndarray]=None, save_audio=False, return_bytes=True): if not len(text): return "输入文本不能为空!", None text = text.replace('\n', ' ').replace('\r', '').replace(" ", "") if len(text) > 100 and self.limitation: return f"输入文字过长!{len(text)}>100", None - text = self._preprocess_text(text, language) - audio = self._generate_audio(text, speaker_id, noise_scale, noise_scale_w, length_scale) + text = self._preprocess_text(text, tts_info['language']) + audio = self._generate_audio(text, tts_info['speaker_id'], tts_info['noise_scale'], tts_info['noise_scale_w'], tts_info['length_scale']) if self.debug or save_audio: self.save_audio(audio, self.RATE, 'output_file.wav') if return_bytes: audio = self.convert_numpy_to_bytes(audio) - return self.RATE, audio + return audio, self.RATE def convert_numpy_to_bytes(self, audio_data): if isinstance(audio_data, np.ndarray):