feat: 后端封装tts,可以通过配置文件切换openvoice和vits

This commit is contained in:
killua4396 2024-05-24 15:08:55 +08:00
parent e2f3decfae
commit ce033dca2b
6 changed files with 49 additions and 62 deletions

View File

@ -11,6 +11,7 @@ from datetime import datetime
from utils.xf_asr_utils import generate_xf_asr_url from utils.xf_asr_utils import generate_xf_asr_url
from config import get_config from config import get_config
import numpy as np import numpy as np
import struct
import uuid import uuid
import json import json
import asyncio import asyncio
@ -319,20 +320,11 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
response_message = {"type": "text", "code":200, "msg": sentence} response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息 await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
elif response_type == RESPONSE_AUDIO: elif response_type == RESPONSE_AUDIO:
if target_se.size == 0: audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
audio,sr = tts._base_tts(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["speed"])
else:
audio,sr = tts.synthesize(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["speed"],
target_se=target_se)
response_message = {"type": "text", "code":200, "msg": sentence} response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_bytes(audio) #返回音频数据 response_bytes = json.dumps(response_message, ensure_ascii=False).encode('utf-8')
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息 header = struct.pack('!II',len(response_bytes),len(audio))
message_bytes = header + response_bytes + audio
logger.debug(f"websocket返回: {sentence}") logger.debug(f"websocket返回: {sentence}")
if is_end: if is_end:
logger.debug(f"llm返回结果: {llm_response}") logger.debug(f"llm返回结果: {llm_response}")
@ -500,17 +492,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
response_message = {"type": "text", "code":200, "msg": sentence} response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) await ws.send_text(json.dumps(response_message, ensure_ascii=False))
elif response_type == RESPONSE_AUDIO: elif response_type == RESPONSE_AUDIO:
if target_se.size == 0: audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
audio,sr = tts._base_tts(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["speed"])
else:
audio,sr = tts.synthesize(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["speed"],
target_se=target_se)
response_message = {"type": "text", "code":200, "msg": sentence} response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_bytes(audio) await ws.send_bytes(audio)
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) await ws.send_text(json.dumps(response_message, ensure_ascii=False))
@ -684,17 +666,7 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
llm_response += chunk_data llm_response += chunk_data
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
for sentence in sentences: for sentence in sentences:
if target_se.size == 0: audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
audio,sr = tts._base_tts(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["speed"])
else:
audio,sr = tts.synthesize(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["speed"],
target_se=target_se)
text_response = {"type": "llm_text", "code": 200, "msg": sentence} text_response = {"type": "llm_text", "code": 200, "msg": sentence}
await ws.send_bytes(audio) #返回音频二进制流数据 await ws.send_bytes(audio) #返回音频二进制流数据
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据 await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据

View File

@ -14,7 +14,7 @@ import io
logger = get_logger() logger = get_logger()
#依赖注入获取tts #依赖注入获取tts
tts = get_tts() tts = get_tts("OPENVOICE")
#创建用户 #创建用户
async def create_user_handler(user:UserCrateRequest, db: Session): async def create_user_handler(user:UserCrateRequest, db: Session):

View File

@ -1,11 +1,23 @@
from utils.tts.openvoice_utils import TextToSpeech
from app.dependencies.logger import get_logger from app.dependencies.logger import get_logger
from config import get_config
logger = get_logger() logger = get_logger()
Config = get_config()
from utils.tts.openvoice_utils import TextToSpeech
openvoice_tts = TextToSpeech(use_tone_convert=True,device='cuda')
logger.info("TTS_OPENVOICE 初始化成功")
from utils.tts.vits_utils import TextToSpeech
vits_tts = TextToSpeech()
logger.info("TTS_VITS 初始化成功")
#初始化全局tts对象 #初始化全局tts对象
tts = TextToSpeech(use_tone_convert=True,device='cuda') def get_tts(tts_type=Config.TTS_UTILS):
logger.info("TTS初始化成功") if tts_type == "OPENVOICE":
return openvoice_tts
def get_tts(): elif tts_type == "VITS":
return tts return vits_tts

View File

@ -2,6 +2,7 @@ class DevelopmentConfig:
SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://takway:takway123456@127.0.0.1/takway?charset=utf8mb4" #mysql数据库连接配置 SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://takway:takway123456@127.0.0.1/takway?charset=utf8mb4" #mysql数据库连接配置
REDIS_URL = "redis://:takway@127.0.0.1:6379/0" #redis数据库连接配置 REDIS_URL = "redis://:takway@127.0.0.1:6379/0" #redis数据库连接配置
LOG_LEVEL = "DEBUG" #日志级别 LOG_LEVEL = "DEBUG" #日志级别
TTS_UTILS = "VITS" #TTS引擎配置可选OPENVOICE或者VITS
class UVICORN: class UVICORN:
HOST = "0.0.0.0" #uvicorn放行ip0.0.0.0代表所有ip HOST = "0.0.0.0" #uvicorn放行ip0.0.0.0代表所有ip
PORT = 8001 #uvicorn运行端口 PORT = 8001 #uvicorn运行端口

View File

@ -164,11 +164,14 @@ class TextToSpeech:
""" """
return audio_data.cpu().detach().float().numpy() return audio_data.cpu().detach().float().numpy()
def numpy2bytes(self, audio_data: np.ndarray): def numpy2bytes(self, audio_data):
""" if isinstance(audio_data, np.ndarray):
numpy类型转bytes if audio_data.dtype == np.dtype('float32'):
""" audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
return (audio_data*32768.0).astype(np.int32).tobytes() audio_data = audio_data.tobytes()
return audio_data
else:
raise TypeError("audio_data must be a numpy array")
def _base_tts(self, def _base_tts(self,
text: str, text: str,
@ -292,14 +295,11 @@ class TextToSpeech:
def synthesize(self, def synthesize(self,
text: str, text: str,
sdp_ratio=0.2, tts_info,
noise_scale=0.6,
noise_scale_w=0.8,
speed=1.0,
quite=True,
source_se: Optional[np.ndarray]=None, source_se: Optional[np.ndarray]=None,
target_se: Optional[np.ndarray]=None, target_se: Optional[np.ndarray]=None,
sdp_ratio=0.2,
quite=True,
tau :float=0.3, tau :float=0.3,
message :str="default"): message :str="default"):
""" """
@ -316,11 +316,11 @@ class TextToSpeech:
""" """
audio, sr = self._base_tts(text, audio, sr = self._base_tts(text,
sdp_ratio=sdp_ratio, sdp_ratio=sdp_ratio,
noise_scale=noise_scale, noise_scale=tts_info['noise_scale'],
noise_scale_w=noise_scale_w, noise_scale_w=tts_info['noise_scale_w'],
speed=speed, speed=tts_info['speed'],
quite=quite) quite=quite)
if self.use_tone_convert: if self.use_tone_convert and target_se.size>0:
tts_sr = self.base_tts_model.hps.data.sampling_rate tts_sr = self.base_tts_model.hps.data.sampling_rate
converter_sr = self.tone_color_converter.hps.data.sampling_rate converter_sr = self.tone_color_converter.hps.data.sampling_rate
audio = F.resample(audio, tts_sr, converter_sr) audio = F.resample(audio, tts_sr, converter_sr)
@ -343,3 +343,4 @@ class TextToSpeech:
""" """
sf.write(save_path, audio, sample_rate) sf.write(save_path, audio, sample_rate)
print(f"Audio saved to {save_path}") print(f"Audio saved to {save_path}")

View File

@ -2,6 +2,7 @@ import os
import numpy as np import numpy as np
import torch import torch
from torch import LongTensor from torch import LongTensor
from typing import Optional
import soundfile as sf import soundfile as sf
# vits # vits
from .vits import utils, commons from .vits import utils, commons
@ -79,19 +80,19 @@ class TextToSpeech:
print(f"Synthesis time: {time.time() - start_time} s") print(f"Synthesis time: {time.time() - start_time} s")
return audio return audio
def synthesize(self, text, language, speaker_id, noise_scale, noise_scale_w, length_scale, save_audio=False, return_bytes=False): def synthesize(self, text, tts_info,target_se: Optional[np.ndarray]=None, save_audio=False, return_bytes=True):
if not len(text): if not len(text):
return "输入文本不能为空!", None return "输入文本不能为空!", None
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "") text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
if len(text) > 100 and self.limitation: if len(text) > 100 and self.limitation:
return f"输入文字过长!{len(text)}>100", None return f"输入文字过长!{len(text)}>100", None
text = self._preprocess_text(text, language) text = self._preprocess_text(text, tts_info['language'])
audio = self._generate_audio(text, speaker_id, noise_scale, noise_scale_w, length_scale) audio = self._generate_audio(text, tts_info['speaker_id'], tts_info['noise_scale'], tts_info['noise_scale_w'], tts_info['length_scale'])
if self.debug or save_audio: if self.debug or save_audio:
self.save_audio(audio, self.RATE, 'output_file.wav') self.save_audio(audio, self.RATE, 'output_file.wav')
if return_bytes: if return_bytes:
audio = self.convert_numpy_to_bytes(audio) audio = self.convert_numpy_to_bytes(audio)
return self.RATE, audio return audio, self.RATE
def convert_numpy_to_bytes(self, audio_data): def convert_numpy_to_bytes(self, audio_data):
if isinstance(audio_data, np.ndarray): if isinstance(audio_data, np.ndarray):