forked from killua/TakwayPlatform
feat: 后端封装tts,可以通过配置文件切换openvoice和vits
This commit is contained in:
parent
e2f3decfae
commit
ce033dca2b
|
@@ -11,6 +11,7 @@ from datetime import datetime
|
||||||
from utils.xf_asr_utils import generate_xf_asr_url
|
from utils.xf_asr_utils import generate_xf_asr_url
|
||||||
from config import get_config
|
from config import get_config
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import struct
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@@ -319,20 +320,11 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
response_message = {"type": "text", "code":200, "msg": sentence}
|
response_message = {"type": "text", "code":200, "msg": sentence}
|
||||||
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
|
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
|
||||||
elif response_type == RESPONSE_AUDIO:
|
elif response_type == RESPONSE_AUDIO:
|
||||||
if target_se.size == 0:
|
audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
|
||||||
audio,sr = tts._base_tts(text=sentence,
|
|
||||||
noise_scale=tts_info["noise_scale"],
|
|
||||||
noise_scale_w=tts_info["noise_scale_w"],
|
|
||||||
speed=tts_info["speed"])
|
|
||||||
else:
|
|
||||||
audio,sr = tts.synthesize(text=sentence,
|
|
||||||
noise_scale=tts_info["noise_scale"],
|
|
||||||
noise_scale_w=tts_info["noise_scale_w"],
|
|
||||||
speed=tts_info["speed"],
|
|
||||||
target_se=target_se)
|
|
||||||
response_message = {"type": "text", "code":200, "msg": sentence}
|
response_message = {"type": "text", "code":200, "msg": sentence}
|
||||||
await ws.send_bytes(audio) #返回音频数据
|
response_bytes = json.dumps(response_message, ensure_ascii=False).encode('utf-8')
|
||||||
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
|
header = struct.pack('!II',len(response_bytes),len(audio))
|
||||||
|
message_bytes = header + response_bytes + audio
|
||||||
logger.debug(f"websocket返回: {sentence}")
|
logger.debug(f"websocket返回: {sentence}")
|
||||||
if is_end:
|
if is_end:
|
||||||
logger.debug(f"llm返回结果: {llm_response}")
|
logger.debug(f"llm返回结果: {llm_response}")
|
||||||
|
@@ -500,17 +492,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
response_message = {"type": "text", "code":200, "msg": sentence}
|
response_message = {"type": "text", "code":200, "msg": sentence}
|
||||||
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
|
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
|
||||||
elif response_type == RESPONSE_AUDIO:
|
elif response_type == RESPONSE_AUDIO:
|
||||||
if target_se.size == 0:
|
audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
|
||||||
audio,sr = tts._base_tts(text=sentence,
|
|
||||||
noise_scale=tts_info["noise_scale"],
|
|
||||||
noise_scale_w=tts_info["noise_scale_w"],
|
|
||||||
speed=tts_info["speed"])
|
|
||||||
else:
|
|
||||||
audio,sr = tts.synthesize(text=sentence,
|
|
||||||
noise_scale=tts_info["noise_scale"],
|
|
||||||
noise_scale_w=tts_info["noise_scale_w"],
|
|
||||||
speed=tts_info["speed"],
|
|
||||||
target_se=target_se)
|
|
||||||
response_message = {"type": "text", "code":200, "msg": sentence}
|
response_message = {"type": "text", "code":200, "msg": sentence}
|
||||||
await ws.send_bytes(audio)
|
await ws.send_bytes(audio)
|
||||||
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
|
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
|
||||||
|
@@ -684,17 +666,7 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
|
||||||
llm_response += chunk_data
|
llm_response += chunk_data
|
||||||
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
|
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
|
||||||
for sentence in sentences:
|
for sentence in sentences:
|
||||||
if target_se.size == 0:
|
audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
|
||||||
audio,sr = tts._base_tts(text=sentence,
|
|
||||||
noise_scale=tts_info["noise_scale"],
|
|
||||||
noise_scale_w=tts_info["noise_scale_w"],
|
|
||||||
speed=tts_info["speed"])
|
|
||||||
else:
|
|
||||||
audio,sr = tts.synthesize(text=sentence,
|
|
||||||
noise_scale=tts_info["noise_scale"],
|
|
||||||
noise_scale_w=tts_info["noise_scale_w"],
|
|
||||||
speed=tts_info["speed"],
|
|
||||||
target_se=target_se)
|
|
||||||
text_response = {"type": "llm_text", "code": 200, "msg": sentence}
|
text_response = {"type": "llm_text", "code": 200, "msg": sentence}
|
||||||
await ws.send_bytes(audio) #返回音频二进制流数据
|
await ws.send_bytes(audio) #返回音频二进制流数据
|
||||||
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
|
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
|
||||||
|
|
|
@@ -14,7 +14,7 @@ import io
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
#依赖注入获取tts
|
#依赖注入获取tts
|
||||||
tts = get_tts()
|
tts = get_tts("OPENVOICE")
|
||||||
|
|
||||||
#创建用户
|
#创建用户
|
||||||
async def create_user_handler(user:UserCrateRequest, db: Session):
|
async def create_user_handler(user:UserCrateRequest, db: Session):
|
||||||
|
|
|
@@ -1,11 +1,23 @@
|
||||||
from utils.tts.openvoice_utils import TextToSpeech
|
|
||||||
from app.dependencies.logger import get_logger
|
from app.dependencies.logger import get_logger
|
||||||
|
from config import get_config
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
Config = get_config()
|
||||||
|
|
||||||
|
from utils.tts.openvoice_utils import TextToSpeech
|
||||||
|
openvoice_tts = TextToSpeech(use_tone_convert=True,device='cuda')
|
||||||
|
logger.info("TTS_OPENVOICE 初始化成功")
|
||||||
|
|
||||||
|
from utils.tts.vits_utils import TextToSpeech
|
||||||
|
vits_tts = TextToSpeech()
|
||||||
|
logger.info("TTS_VITS 初始化成功")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#初始化全局tts对象
|
#初始化全局tts对象
|
||||||
tts = TextToSpeech(use_tone_convert=True,device='cuda')
|
def get_tts(tts_type=Config.TTS_UTILS):
|
||||||
logger.info("TTS初始化成功")
|
if tts_type == "OPENVOICE":
|
||||||
|
return openvoice_tts
|
||||||
def get_tts():
|
elif tts_type == "VITS":
|
||||||
return tts
|
return vits_tts
|
|
@@ -2,6 +2,7 @@ class DevelopmentConfig:
|
||||||
SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://takway:takway123456@127.0.0.1/takway?charset=utf8mb4" #mysql数据库连接配置
|
SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://takway:takway123456@127.0.0.1/takway?charset=utf8mb4" #mysql数据库连接配置
|
||||||
REDIS_URL = "redis://:takway@127.0.0.1:6379/0" #redis数据库连接配置
|
REDIS_URL = "redis://:takway@127.0.0.1:6379/0" #redis数据库连接配置
|
||||||
LOG_LEVEL = "DEBUG" #日志级别
|
LOG_LEVEL = "DEBUG" #日志级别
|
||||||
|
TTS_UTILS = "VITS" #TTS引擎配置,可选OPENVOICE或者VITS
|
||||||
class UVICORN:
|
class UVICORN:
|
||||||
HOST = "0.0.0.0" #uvicorn放行ip,0.0.0.0代表所有ip
|
HOST = "0.0.0.0" #uvicorn放行ip,0.0.0.0代表所有ip
|
||||||
PORT = 8001 #uvicorn运行端口
|
PORT = 8001 #uvicorn运行端口
|
||||||
|
|
|
@@ -164,11 +164,14 @@ class TextToSpeech:
|
||||||
"""
|
"""
|
||||||
return audio_data.cpu().detach().float().numpy()
|
return audio_data.cpu().detach().float().numpy()
|
||||||
|
|
||||||
def numpy2bytes(self, audio_data: np.ndarray):
|
def numpy2bytes(self, audio_data):
|
||||||
"""
|
if isinstance(audio_data, np.ndarray):
|
||||||
numpy类型转bytes
|
if audio_data.dtype == np.dtype('float32'):
|
||||||
"""
|
audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
|
||||||
return (audio_data*32768.0).astype(np.int32).tobytes()
|
audio_data = audio_data.tobytes()
|
||||||
|
return audio_data
|
||||||
|
else:
|
||||||
|
raise TypeError("audio_data must be a numpy array")
|
||||||
|
|
||||||
def _base_tts(self,
|
def _base_tts(self,
|
||||||
text: str,
|
text: str,
|
||||||
|
@@ -292,14 +295,11 @@ class TextToSpeech:
|
||||||
|
|
||||||
def synthesize(self,
|
def synthesize(self,
|
||||||
text: str,
|
text: str,
|
||||||
sdp_ratio=0.2,
|
tts_info,
|
||||||
noise_scale=0.6,
|
|
||||||
noise_scale_w=0.8,
|
|
||||||
speed=1.0,
|
|
||||||
quite=True,
|
|
||||||
|
|
||||||
source_se: Optional[np.ndarray]=None,
|
source_se: Optional[np.ndarray]=None,
|
||||||
target_se: Optional[np.ndarray]=None,
|
target_se: Optional[np.ndarray]=None,
|
||||||
|
sdp_ratio=0.2,
|
||||||
|
quite=True,
|
||||||
tau :float=0.3,
|
tau :float=0.3,
|
||||||
message :str="default"):
|
message :str="default"):
|
||||||
"""
|
"""
|
||||||
|
@@ -316,11 +316,11 @@ class TextToSpeech:
|
||||||
"""
|
"""
|
||||||
audio, sr = self._base_tts(text,
|
audio, sr = self._base_tts(text,
|
||||||
sdp_ratio=sdp_ratio,
|
sdp_ratio=sdp_ratio,
|
||||||
noise_scale=noise_scale,
|
noise_scale=tts_info['noise_scale'],
|
||||||
noise_scale_w=noise_scale_w,
|
noise_scale_w=tts_info['noise_scale_w'],
|
||||||
speed=speed,
|
speed=tts_info['speed'],
|
||||||
quite=quite)
|
quite=quite)
|
||||||
if self.use_tone_convert:
|
if self.use_tone_convert and target_se.size>0:
|
||||||
tts_sr = self.base_tts_model.hps.data.sampling_rate
|
tts_sr = self.base_tts_model.hps.data.sampling_rate
|
||||||
converter_sr = self.tone_color_converter.hps.data.sampling_rate
|
converter_sr = self.tone_color_converter.hps.data.sampling_rate
|
||||||
audio = F.resample(audio, tts_sr, converter_sr)
|
audio = F.resample(audio, tts_sr, converter_sr)
|
||||||
|
@@ -342,4 +342,5 @@ class TextToSpeech:
|
||||||
save_path: 保存路径
|
save_path: 保存路径
|
||||||
"""
|
"""
|
||||||
sf.write(save_path, audio, sample_rate)
|
sf.write(save_path, audio, sample_rate)
|
||||||
print(f"Audio saved to {save_path}")
|
print(f"Audio saved to {save_path}")
|
||||||
|
|
|
@@ -2,6 +2,7 @@ import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import LongTensor
|
from torch import LongTensor
|
||||||
|
from typing import Optional
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
# vits
|
# vits
|
||||||
from .vits import utils, commons
|
from .vits import utils, commons
|
||||||
|
@@ -79,19 +80,19 @@ class TextToSpeech:
|
||||||
print(f"Synthesis time: {time.time() - start_time} s")
|
print(f"Synthesis time: {time.time() - start_time} s")
|
||||||
return audio
|
return audio
|
||||||
|
|
||||||
def synthesize(self, text, language, speaker_id, noise_scale, noise_scale_w, length_scale, save_audio=False, return_bytes=False):
|
def synthesize(self, text, tts_info,target_se: Optional[np.ndarray]=None, save_audio=False, return_bytes=True):
|
||||||
if not len(text):
|
if not len(text):
|
||||||
return "输入文本不能为空!", None
|
return "输入文本不能为空!", None
|
||||||
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
|
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
|
||||||
if len(text) > 100 and self.limitation:
|
if len(text) > 100 and self.limitation:
|
||||||
return f"输入文字过长!{len(text)}>100", None
|
return f"输入文字过长!{len(text)}>100", None
|
||||||
text = self._preprocess_text(text, language)
|
text = self._preprocess_text(text, tts_info['language'])
|
||||||
audio = self._generate_audio(text, speaker_id, noise_scale, noise_scale_w, length_scale)
|
audio = self._generate_audio(text, tts_info['speaker_id'], tts_info['noise_scale'], tts_info['noise_scale_w'], tts_info['length_scale'])
|
||||||
if self.debug or save_audio:
|
if self.debug or save_audio:
|
||||||
self.save_audio(audio, self.RATE, 'output_file.wav')
|
self.save_audio(audio, self.RATE, 'output_file.wav')
|
||||||
if return_bytes:
|
if return_bytes:
|
||||||
audio = self.convert_numpy_to_bytes(audio)
|
audio = self.convert_numpy_to_bytes(audio)
|
||||||
return self.RATE, audio
|
return audio, self.RATE
|
||||||
|
|
||||||
def convert_numpy_to_bytes(self, audio_data):
|
def convert_numpy_to_bytes(self, audio_data):
|
||||||
if isinstance(audio_data, np.ndarray):
|
if isinstance(audio_data, np.ndarray):
|
||||||
|
|
Loading…
Reference in New Issue