feat: 后端封装tts,可以通过配置文件切换openvoice和vits

This commit is contained in:
killua4396 2024-05-24 15:08:55 +08:00
parent e2f3decfae
commit ce033dca2b
6 changed files with 49 additions and 62 deletions

View File

@ -11,6 +11,7 @@ from datetime import datetime
from utils.xf_asr_utils import generate_xf_asr_url
from config import get_config
import numpy as np
import struct
import uuid
import json
import asyncio
@ -319,20 +320,11 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
elif response_type == RESPONSE_AUDIO:
if target_se.size == 0:
audio,sr = tts._base_tts(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["speed"])
else:
audio,sr = tts.synthesize(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["speed"],
target_se=target_se)
audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_bytes(audio) #返回音频数据
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
response_bytes = json.dumps(response_message, ensure_ascii=False).encode('utf-8')
header = struct.pack('!II',len(response_bytes),len(audio))
message_bytes = header + response_bytes + audio
logger.debug(f"websocket返回: {sentence}")
if is_end:
logger.debug(f"llm返回结果: {llm_response}")
@ -500,17 +492,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
elif response_type == RESPONSE_AUDIO:
if target_se.size == 0:
audio,sr = tts._base_tts(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["speed"])
else:
audio,sr = tts.synthesize(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["speed"],
target_se=target_se)
audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_bytes(audio)
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
@ -684,17 +666,7 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
llm_response += chunk_data
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
for sentence in sentences:
if target_se.size == 0:
audio,sr = tts._base_tts(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["speed"])
else:
audio,sr = tts.synthesize(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["speed"],
target_se=target_se)
audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
text_response = {"type": "llm_text", "code": 200, "msg": sentence}
await ws.send_bytes(audio) #返回音频二进制流数据
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据

View File

@ -14,7 +14,7 @@ import io
logger = get_logger()
#依赖注入获取tts
tts = get_tts()
tts = get_tts("OPENVOICE")
#创建用户
async def create_user_handler(user:UserCrateRequest, db: Session):

View File

@ -1,11 +1,23 @@
from utils.tts.openvoice_utils import TextToSpeech
from app.dependencies.logger import get_logger
from config import get_config
logger = get_logger()
Config = get_config()
from utils.tts.openvoice_utils import TextToSpeech
openvoice_tts = TextToSpeech(use_tone_convert=True,device='cuda')
logger.info("TTS_OPENVOICE 初始化成功")
from utils.tts.vits_utils import TextToSpeech
vits_tts = TextToSpeech()
logger.info("TTS_VITS 初始化成功")
#初始化全局tts对象
tts = TextToSpeech(use_tone_convert=True,device='cuda')
logger.info("TTS初始化成功")
def get_tts():
return tts
def get_tts(tts_type=Config.TTS_UTILS):
if tts_type == "OPENVOICE":
return openvoice_tts
elif tts_type == "VITS":
return vits_tts

View File

@ -2,6 +2,7 @@ class DevelopmentConfig:
SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://takway:takway123456@127.0.0.1/takway?charset=utf8mb4" #mysql数据库连接配置
REDIS_URL = "redis://:takway@127.0.0.1:6379/0" #redis数据库连接配置
LOG_LEVEL = "DEBUG" #日志级别
TTS_UTILS = "VITS" #TTS引擎配置可选OPENVOICE或者VITS
class UVICORN:
HOST = "0.0.0.0" #uvicorn放行ip0.0.0.0代表所有ip
PORT = 8001 #uvicorn运行端口

View File

@ -164,11 +164,14 @@ class TextToSpeech:
"""
return audio_data.cpu().detach().float().numpy()
def numpy2bytes(self, audio_data: np.ndarray):
"""
numpy类型转bytes
"""
return (audio_data*32768.0).astype(np.int32).tobytes()
def numpy2bytes(self, audio_data):
if isinstance(audio_data, np.ndarray):
if audio_data.dtype == np.dtype('float32'):
audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
audio_data = audio_data.tobytes()
return audio_data
else:
raise TypeError("audio_data must be a numpy array")
def _base_tts(self,
text: str,
@ -292,14 +295,11 @@ class TextToSpeech:
def synthesize(self,
text: str,
sdp_ratio=0.2,
noise_scale=0.6,
noise_scale_w=0.8,
speed=1.0,
quite=True,
tts_info,
source_se: Optional[np.ndarray]=None,
target_se: Optional[np.ndarray]=None,
sdp_ratio=0.2,
quite=True,
tau :float=0.3,
message :str="default"):
"""
@ -316,11 +316,11 @@ class TextToSpeech:
"""
audio, sr = self._base_tts(text,
sdp_ratio=sdp_ratio,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
speed=speed,
noise_scale=tts_info['noise_scale'],
noise_scale_w=tts_info['noise_scale_w'],
speed=tts_info['speed'],
quite=quite)
if self.use_tone_convert:
if self.use_tone_convert and target_se.size>0:
tts_sr = self.base_tts_model.hps.data.sampling_rate
converter_sr = self.tone_color_converter.hps.data.sampling_rate
audio = F.resample(audio, tts_sr, converter_sr)
@ -343,3 +343,4 @@ class TextToSpeech:
"""
sf.write(save_path, audio, sample_rate)
print(f"Audio saved to {save_path}")

View File

@ -2,6 +2,7 @@ import os
import numpy as np
import torch
from torch import LongTensor
from typing import Optional
import soundfile as sf
# vits
from .vits import utils, commons
@ -79,19 +80,19 @@ class TextToSpeech:
print(f"Synthesis time: {time.time() - start_time} s")
return audio
def synthesize(self, text, language, speaker_id, noise_scale, noise_scale_w, length_scale, save_audio=False, return_bytes=False):
def synthesize(self, text, tts_info,target_se: Optional[np.ndarray]=None, save_audio=False, return_bytes=True):
if not len(text):
return "输入文本不能为空!", None
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
if len(text) > 100 and self.limitation:
return f"输入文字过长!{len(text)}>100", None
text = self._preprocess_text(text, language)
audio = self._generate_audio(text, speaker_id, noise_scale, noise_scale_w, length_scale)
text = self._preprocess_text(text, tts_info['language'])
audio = self._generate_audio(text, tts_info['speaker_id'], tts_info['noise_scale'], tts_info['noise_scale_w'], tts_info['length_scale'])
if self.debug or save_audio:
self.save_audio(audio, self.RATE, 'output_file.wav')
if return_bytes:
audio = self.convert_numpy_to_bytes(audio)
return self.RATE, audio
return audio, self.RATE
def convert_numpy_to_bytes(self, audio_data):
if isinstance(audio_data, np.ndarray):