feat: wrap TTS on the backend; OpenVoice and VITS can now be switched via the config file
parent e2f3decfae
commit ce033dca2b
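
As a whole, the handlers no longer branch between tts._base_tts and tts.synthesize: they obtain an engine from get_tts() (driven by Config.TTS_UTILS) and pass the per-session tts_info dict straight through. A minimal usage sketch, assuming the module path app.dependencies.tts and illustrative parameter values (each engine only reads its own subset of the keys, see the hunks below):

    import numpy as np
    from app.dependencies.tts import get_tts   # module path assumed from the diff below

    tts = get_tts()   # backend chosen by Config.TTS_UTILS ("OPENVOICE" or "VITS") instead of being hard-wired
    tts_info = {      # union of the keys read by the two engines; values are placeholders
        "noise_scale": 0.667, "noise_scale_w": 0.8, "speed": 1.0,
        "length_scale": 1.0, "speaker_id": 0, "language": 0,
    }
    audio, sr = tts.synthesize(text="sample sentence", tts_info=tts_info, target_se=np.array([]))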
@@ -11,6 +11,7 @@ from datetime import datetime
 from utils.xf_asr_utils import generate_xf_asr_url
+from config import get_config
 import numpy as np
 import struct
 import uuid
 import json
 import asyncio

@@ -319,20 +320,11 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
 response_message = {"type": "text", "code":200, "msg": sentence}
 await ws.send_text(json.dumps(response_message, ensure_ascii=False)) # return the text message
 elif response_type == RESPONSE_AUDIO:
-if target_se.size == 0:
-    audio,sr = tts._base_tts(text=sentence,
-                             noise_scale=tts_info["noise_scale"],
-                             noise_scale_w=tts_info["noise_scale_w"],
-                             speed=tts_info["speed"])
-else:
-    audio,sr = tts.synthesize(text=sentence,
-                              noise_scale=tts_info["noise_scale"],
-                              noise_scale_w=tts_info["noise_scale_w"],
-                              speed=tts_info["speed"],
-                              target_se=target_se)
+audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
 response_message = {"type": "text", "code":200, "msg": sentence}
-await ws.send_bytes(audio) # return the audio data
-await ws.send_text(json.dumps(response_message, ensure_ascii=False)) # return the text message
+response_bytes = json.dumps(response_message, ensure_ascii=False).encode('utf-8')
+header = struct.pack('!II',len(response_bytes),len(audio))
+message_bytes = header + response_bytes + audio
 logger.debug(f"websocket返回: {sentence}")
 if is_end:
 logger.debug(f"llm返回结果: {llm_response}")
@@ -500,17 +492,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
 response_message = {"type": "text", "code":200, "msg": sentence}
 await ws.send_text(json.dumps(response_message, ensure_ascii=False))
 elif response_type == RESPONSE_AUDIO:
-if target_se.size == 0:
-    audio,sr = tts._base_tts(text=sentence,
-                             noise_scale=tts_info["noise_scale"],
-                             noise_scale_w=tts_info["noise_scale_w"],
-                             speed=tts_info["speed"])
-else:
-    audio,sr = tts.synthesize(text=sentence,
-                              noise_scale=tts_info["noise_scale"],
-                              noise_scale_w=tts_info["noise_scale_w"],
-                              speed=tts_info["speed"],
-                              target_se=target_se)
+audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
 response_message = {"type": "text", "code":200, "msg": sentence}
 await ws.send_bytes(audio)
 await ws.send_text(json.dumps(response_message, ensure_ascii=False))

@@ -684,17 +666,7 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
 llm_response += chunk_data
 sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
 for sentence in sentences:
-if target_se.size == 0:
-    audio,sr = tts._base_tts(text=sentence,
-                             noise_scale=tts_info["noise_scale"],
-                             noise_scale_w=tts_info["noise_scale_w"],
-                             speed=tts_info["speed"])
-else:
-    audio,sr = tts.synthesize(text=sentence,
-                              noise_scale=tts_info["noise_scale"],
-                              noise_scale_w=tts_info["noise_scale_w"],
-                              speed=tts_info["speed"],
-                              target_se=target_se)
+audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
 text_response = {"type": "llm_text", "code": 200, "msg": sentence}
 await ws.send_bytes(audio) # return the audio binary stream
 await ws.send_text(json.dumps(text_response, ensure_ascii=False)) # return the text data

@@ -14,7 +14,7 @@ import io
 logger = get_logger()
 
 # get the tts instance via dependency injection
-tts = get_tts()
+tts = get_tts("OPENVOICE")
 
 # create user
 async def create_user_handler(user:UserCrateRequest, db: Session):

@@ -1,11 +1,23 @@
-from utils.tts.openvoice_utils import TextToSpeech
 from app.dependencies.logger import get_logger
+from config import get_config
 
 logger = get_logger()
+Config = get_config()
+
+from utils.tts.openvoice_utils import TextToSpeech
+openvoice_tts = TextToSpeech(use_tone_convert=True,device='cuda')
+logger.info("TTS_OPENVOICE 初始化成功")
+
+from utils.tts.vits_utils import TextToSpeech
+vits_tts = TextToSpeech()
+logger.info("TTS_VITS 初始化成功")
 
-# initialize the global tts object
-tts = TextToSpeech(use_tone_convert=True,device='cuda')
-logger.info("TTS初始化成功")
-
-def get_tts():
-    return tts
+def get_tts(tts_type=Config.TTS_UTILS):
+    if tts_type == "OPENVOICE":
+        return openvoice_tts
+    elif tts_type == "VITS":
+        return vits_tts
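
With this factory the engine defaults to whatever Config.TTS_UTILS names, while callers can still pin a backend explicitly, as the handler hunk above does with get_tts("OPENVOICE"). Note that a tts_type matching neither branch falls through and returns None. A small sketch, assuming get_tts is imported from this dependency module:

    tts = get_tts()                   # follows Config.TTS_UTILS ("VITS" in the config hunk below)
    openvoice = get_tts("OPENVOICE")  # explicit override, same as in the handler above
    assert get_tts("XTTS") is None    # hypothetical engine name: no branch matches, so None is returned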
@@ -2,6 +2,7 @@ class DevelopmentConfig:
     SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://takway:takway123456@127.0.0.1/takway?charset=utf8mb4" # MySQL database connection settings
     REDIS_URL = "redis://:takway@127.0.0.1:6379/0" # Redis connection settings
     LOG_LEVEL = "DEBUG" # log level
+    TTS_UTILS = "VITS" # TTS engine, either OPENVOICE or VITS
     class UVICORN:
         HOST = "0.0.0.0" # IP uvicorn binds to; 0.0.0.0 means all interfaces
         PORT = 8001 # port uvicorn listens on
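
Switching engines is now a one-line config change rather than a code change; get_tts() in the dependency hunk above reads this value as its default. A small illustration, assuming DevelopmentConfig is the active config returned by get_config():

    from config import get_config

    Config = get_config()
    print(Config.TTS_UTILS)   # "VITS" with the setting above; change it to "OPENVOICE" to switch engines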
@@ -164,11 +164,14 @@ class TextToSpeech:
         """
         return audio_data.cpu().detach().float().numpy()
 
-    def numpy2bytes(self, audio_data: np.ndarray):
-        """
-        Convert numpy data to bytes
-        """
-        return (audio_data*32768.0).astype(np.int32).tobytes()
+    def numpy2bytes(self, audio_data):
+        if isinstance(audio_data, np.ndarray):
+            if audio_data.dtype == np.dtype('float32'):
+                audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
+            audio_data = audio_data.tobytes()
+            return audio_data
+        else:
+            raise TypeError("audio_data must be a numpy array")
 
     def _base_tts(self,
                   text: str,
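
The rewritten numpy2bytes scales float32 audio to 16-bit PCM before serializing it (the old version multiplied by 32768.0 and cast to int32). A standalone sketch of the equivalent conversion, for illustration only:

    import numpy as np

    def float32_to_pcm16(audio: np.ndarray) -> bytes:
        # mirrors the new numpy2bytes: float32 samples in [-1, 1] -> int16 PCM bytes (native byte order)
        if not isinstance(audio, np.ndarray):
            raise TypeError("audio must be a numpy array")
        if audio.dtype == np.float32:
            audio = np.int16(audio * np.iinfo(np.int16).max)
        return audio.tobytes()

    pcm = float32_to_pcm16(np.zeros(16000, dtype=np.float32))  # one second of silence at 16 kHz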
@@ -292,14 +295,11 @@ class TextToSpeech:
 
     def synthesize(self,
                    text: str,
-                   sdp_ratio=0.2,
-                   noise_scale=0.6,
-                   noise_scale_w=0.8,
-                   speed=1.0,
-                   quite=True,
-
+                   tts_info,
                    source_se: Optional[np.ndarray]=None,
                    target_se: Optional[np.ndarray]=None,
+                   sdp_ratio=0.2,
+                   quite=True,
                    tau :float=0.3,
                    message :str="default"):
         """

@@ -316,11 +316,11 @@ class TextToSpeech:
         """
         audio, sr = self._base_tts(text,
                                    sdp_ratio=sdp_ratio,
-                                   noise_scale=noise_scale,
-                                   noise_scale_w=noise_scale_w,
-                                   speed=speed,
+                                   noise_scale=tts_info['noise_scale'],
+                                   noise_scale_w=tts_info['noise_scale_w'],
+                                   speed=tts_info['speed'],
                                    quite=quite)
-        if self.use_tone_convert:
+        if self.use_tone_convert and target_se.size>0:
             tts_sr = self.base_tts_model.hps.data.sampling_rate
             converter_sr = self.tone_color_converter.hps.data.sampling_rate
             audio = F.resample(audio, tts_sr, converter_sr)
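
Because the check is now use_tone_convert and target_se.size>0, the branch on the speaker embedding lives inside synthesize rather than in the handlers: an empty array falls back to plain base TTS, a non-empty embedding goes through the tone-color converter. A sketch with the OpenVoice-side keys (module path and values assumed):

    import numpy as np
    from app.dependencies.tts import get_tts   # path assumed

    openvoice = get_tts("OPENVOICE")
    tts_info = {"noise_scale": 0.667, "noise_scale_w": 0.8, "speed": 1.0}  # keys read by _base_tts above

    # empty embedding -> tone conversion skipped, plain base TTS output
    audio, sr = openvoice.synthesize(text="sample sentence", tts_info=tts_info, target_se=np.array([]))
    # with a real speaker embedding (shape and dtype assumed), the resample/convert path above runs instead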
@@ -342,4 +342,5 @@ class TextToSpeech:
         save_path: path to save the audio to
         """
         sf.write(save_path, audio, sample_rate)
         print(f"Audio saved to {save_path}")
+        print(f"Audio saved to {save_path}")

@@ -2,6 +2,7 @@ import os
 import numpy as np
 import torch
 from torch import LongTensor
+from typing import Optional
 import soundfile as sf
 # vits
 from .vits import utils, commons

@@ -79,19 +80,19 @@ class TextToSpeech:
         print(f"Synthesis time: {time.time() - start_time} s")
         return audio
 
-    def synthesize(self, text, language, speaker_id, noise_scale, noise_scale_w, length_scale, save_audio=False, return_bytes=False):
+    def synthesize(self, text, tts_info,target_se: Optional[np.ndarray]=None, save_audio=False, return_bytes=True):
         if not len(text):
             return "输入文本不能为空!", None
         text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
         if len(text) > 100 and self.limitation:
             return f"输入文字过长!{len(text)}>100", None
-        text = self._preprocess_text(text, language)
-        audio = self._generate_audio(text, speaker_id, noise_scale, noise_scale_w, length_scale)
+        text = self._preprocess_text(text, tts_info['language'])
+        audio = self._generate_audio(text, tts_info['speaker_id'], tts_info['noise_scale'], tts_info['noise_scale_w'], tts_info['length_scale'])
         if self.debug or save_audio:
             self.save_audio(audio, self.RATE, 'output_file.wav')
         if return_bytes:
             audio = self.convert_numpy_to_bytes(audio)
-        return self.RATE, audio
+        return audio, self.RATE
 
     def convert_numpy_to_bytes(self, audio_data):
         if isinstance(audio_data, np.ndarray):
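
Note the VITS-side changes: synthesize now takes the same (text, tts_info, target_se) shape as the OpenVoice engine, return_bytes defaults to True, and the return order flips to (audio, sample_rate), matching the audio,sr unpacking in the handlers. A brief sketch (module path and parameter values assumed):

    from app.dependencies.tts import get_tts   # path assumed

    vits = get_tts("VITS")
    tts_info = {"language": 0, "speaker_id": 0, "noise_scale": 0.667,
                "noise_scale_w": 0.8, "length_scale": 1.0}  # placeholder values
    audio, sr = vits.synthesize("sample sentence", tts_info)
    # audio is already PCM bytes (return_bytes=True by default), ready for ws.send_bytes(audio)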