forked from killua/TakwayDisplayPlatform
feature: integrate bert_vits
parent 2b98752db1
commit c58c7c9a5b
@@ -6,7 +6,8 @@ from .abstract import *
 from .public import *
 from .exception import *
 from .dependency import get_logger
-from utils.vits_utils import TextToSpeech
+from utils.vits_utils import TextToSpeech as VITS_TextToSpeech
+from utils.bert_vits2_utils import TextToSpeech as BertVits_TextToSpeech
 from config import Config
 import threading
 import requests
@@ -17,7 +18,11 @@ import time
 import json
 
 # -------- initialize vits --------- #
-vits = TextToSpeech()
+vits = VITS_TextToSpeech()
 # ---------------------------------- #
 
+# ------ initialize bert-vits ------ #
+bert_vits = BertVits_TextToSpeech()
+# ---------------------------------- #
+
 # ------- initialize logger -------- #
@@ -294,6 +299,14 @@ class VITS_TTS(TTS):
     def synthetize(self, assistant, text):
         tts_info = json.loads(assistant.tts_info)
         return vits.synthesize(text, tts_info)
+
+class BertVits_TTS(TTS):
+    def __init__(self):
+        pass
+
+    def synthetize(self, assistant, text):
+        tts_info = json.loads(assistant.tts_info)
+        return bert_vits.synthesize(text, tts_info)
 # --------------------------------- #
 
 
@@ -319,6 +332,8 @@ class TTSFactory:
     def create_tts(self,tts_type:str) -> TTS:
         if tts_type == 'VITS':
             return VITS_TTS()
+        if tts_type == 'BertVits':
+            return BertVits_TTS()
 # --------------------------------- #
 
 
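For orientation, a minimal sketch of driving the factory above, assuming it runs inside this module (the file's name is not shown in the diff). The stub assistant is hypothetical; it only mimics the serialized tts_info field the TTS classes read, with all keys the Bert-VITS2 synthesize() expects (see convert_numpy_to_bytes further down for the return type):

    import json

    # Hypothetical stand-in for the real assistant record.
    class StubAssistant:
        tts_info = json.dumps({
            "speaker_id": 0, "sdp_ratio": 0.5, "noise_scale": 0.6,
            "noise_scale_w": 0.9, "length_scale": 1.0, "language": "ZH",
            "opt_cut_by_send": False, "interval_between_para": 1.0,
            "interval_between_sent": 0.2, "style_text": "",
            "style_weight": 0.7, "en_ratio": 1.0,
        })

    tts = TTSFactory().create_tts("BertVits")     # falls through to None for unknown types
    pcm = tts.synthetize(StubAssistant(), "你好")  # raw int16 PCM bytes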
@@ -420,7 +435,12 @@ class Agent():
         self.tts_audio_service_chain.add_service(TTSAudioRecordService())
 
     def init_recorder(self,user_id):
-        self.recorder = Recorder(user_id)
+        input_sr = 16000
+        if isinstance(self.tts, BertVits_TTS):
+            output_sr = 44100
+        elif isinstance(self.tts, VITS_TTS):
+            output_sr = 22050
+        self.recorder = Recorder(user_id,input_sr,output_sr)
 
     # preprocess the audio the user sends in
     def user_audio_process(self, audio):
@@ -30,12 +30,12 @@ class SentenceSegmentation():
         return self.__sentenceSegmentation(llm_chunk)
 
 class Recorder:
-    def __init__(self, user_id):
+    def __init__(self, user_id, input_sr, output_sr):
         self.input_wav_path = 'storage/wav/'+ datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'i.wav'
         self.output_wav_path = 'storage/wav/'+ datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'o.wav'
         self.out_put_text_path = 'storage/record/'+ datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'o.txt'
-        self.input_sr = 16000
-        self.output_sr = 22050
+        self.input_sr = input_sr
+        self.output_sr = output_sr
         self.user_audio = b''
         self.tts_audio = b''
         self.input_text = ""
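The Recorder now stores whatever rates it is given, so writing its raw byte buffers out as WAVs needs those rates back. A sketch with the stdlib wave module (the Recorder's actual save logic is not shown in this diff, so the helper below is hypothetical):

    import wave

    def write_wav(path, pcm_bytes, sample_rate):
        # Raw int16 mono PCM; a mismatched sample_rate (e.g. 22050 vs the
        # 44100 Bert-VITS2 produces) plays back pitch-shifted.
        with wave.open(path, 'wb') as f:
            f.setnchannels(1)
            f.setsampwidth(2)  # 2 bytes per sample = int16
            f.setframerate(sample_rate)
            f.writeframes(pcm_bytes)

    # write_wav(recorder.output_wav_path, recorder.tts_audio, recorder.output_sr)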
@@ -33,8 +33,13 @@ class update_assistant_deatil_params_request(BaseModel):
     platform:str
     model :str
     temperature :float
+    tts_engine:str
     speaker_id:int
     length_scale:float
+    language:str
+    style_text:str
+    style_weight:float
+
 
 class update_assistant_max_tokens_request(BaseModel):
     max_tokens:int
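A request body matching the extended model might look as follows; the model value is illustrative, not taken from this diff:

    payload = {
        "platform": "MINIMAX",     # -> user_info['llm_type']
        "model": "abab5.5-chat",   # illustrative model name
        "temperature": 0.7,
        "tts_engine": "BertVits",  # -> user_info['tts_type'], consumed by TTSFactory
        "speaker_id": 0,
        "length_scale": 1.0,
        "language": "mix",
        "style_text": "我好开心!!!",
        "style_weight": 0.7,
    }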
config.py
@@ -1,8 +1,5 @@
 class Config:
     SQLITE_URL = 'sqlite:///takway.db'
-    ASR = "XF"       # select the speech recognition engine here
-    LLM = "MINIMAX"  # select the LLM here
-    TTS = "VITS"     # select the speech synthesis engine here
     LOG_LEVEL = "DEBUG"
     class UVICORN:
         HOST = '0.0.0.0'
main.py (17 changed lines)
@@ -122,7 +122,16 @@ async def update_assistant_deatil_params(id: str,request: update_assistant_deatil_params_request):
     llm_info['temperature'] = request.temperature
     tts_info['speaker_id'] = request.speaker_id
     tts_info['length_scale'] = request.length_scale
+    tts_info['language'] = request.language
+    tts_info['style_text'] = request.style_text
+    tts_info['style_weight'] = request.style_weight
+    tts_info['sdp_ratio'] = 0.5
+    tts_info['opt_cut_by_send'] = False
+    tts_info['interval_between_para'] = 1.0
+    tts_info['interval_between_sent'] = 0.2
+    tts_info['en_ratio'] = 1.0
     user_info['llm_type'] = request.platform
+    user_info['tts_type'] = request.tts_engine
     assistant.llm_info = json.dumps(llm_info, ensure_ascii=False)
     assistant.tts_info = json.dumps(tts_info, ensure_ascii=False)
     assistant.user_info = json.dumps(user_info, ensure_ascii=False)
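After this handler runs, the persisted assistant.tts_info carries roughly the keys below (values illustrative). Note that synthesize() also reads noise_scale and noise_scale_w unconditionally, so those must already be present in the assistant's existing tts_info; the handler does not set them:

    # Shape of tts_info after the update above (illustrative values):
    tts_info = {
        "speaker_id": 0, "length_scale": 1.0,        # from the request
        "language": "mix", "style_text": "", "style_weight": 0.7,
        "sdp_ratio": 0.5, "opt_cut_by_send": False,  # fixed defaults set by the handler
        "interval_between_para": 1.0, "interval_between_sent": 0.2, "en_ratio": 1.0,
        # plus noise_scale / noise_scale_w carried over from the existing record
    }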
@@ -227,15 +236,7 @@ async def streaming_chat(ws: WebSocket):
             agent.recorder.input_text = prompt
             logger.debug("Starting the LLM call")
             llm_frames = await agent.chat(assistant, prompt)
-
-            start_time = time.time()
-            is_first_response = True
-
             for llm_frame in llm_frames:
-                if is_first_response:
-                    end_time = time.time()
-                    logger.debug(f"Time to first frame: {round(end_time-start_time,3)}s")
-                    is_first_response = False
                 resp_msgs = agent.llm_msg_process(llm_frame)
                 for resp_msg in resp_msgs:
                     llm_text += resp_msg
@@ -14,3 +14,4 @@ librosa
 aiohttp
 'volcengine-python-sdk[ark]'
 zhipuai
+pyopenjtalk
test.py (13 deleted lines)
@@ -1,13 +0,0 @@
-from utils.bert_vits2_utils import TextToSpeech
-import soundfile as sf
-tts = TextToSpeech()
-tts.print_speakers_info()
-
-audio, sample_rate = tts.synthesize("你好,我好开心", # text
-                                    0, # speaker id
-                                    style_text="我很难过!!!!呜呜呜!!!", # emotion prompt; only used when language == "ZH"
-                                    style_weight=0.9, # emotion prompt weight
-                                    language="mix", # language type: "ZH", "EN" or "mix"
-                                    en_ratio=1.) # English speed in "mix" mode; larger is slower
-save_path = "./tmp2.wav"
-sf.write(save_path, audio, sample_rate)
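The deleted smoke test used the old keyword-argument signature; an equivalent check against the new tts_info-based API would look roughly like this. Note the return type changed too: synthesize() now yields int16 PCM bytes rather than an (audio, sample_rate) pair, and 44100 below matches the rate Agent.init_recorder assumes for Bert-VITS2:

    from utils.bert_vits2_utils import TextToSpeech
    import numpy as np
    import soundfile as sf

    tts = TextToSpeech()
    tts.print_speakers_info()

    tts_info = {
        "speaker_id": 0, "sdp_ratio": 0.5, "noise_scale": 0.6, "noise_scale_w": 0.9,
        "length_scale": 1.0, "language": "mix", "opt_cut_by_send": False,
        "interval_between_para": 1.0, "interval_between_sent": 0.2,
        "style_text": "我很难过!!!!呜呜呜!!!", "style_weight": 0.9, "en_ratio": 1.0,
    }
    pcm = tts.synthesize("你好,我好开心", tts_info)
    sf.write("./tmp2.wav", np.frombuffer(pcm, dtype=np.int16), 44100)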
utils/bert_vits2_utils.py
@@ -395,29 +395,28 @@ class TextToSpeech:
 
     def synthesize(self,
                    text,
-                   speaker_idx=0, # index into self.speakers; selects the speaker
-                   sdp_ratio=0.5,
-                   noise_scale=0.6,
-                   noise_scale_w=0.9,
-                   length_scale=1.0, # larger is slower
-                   language="mix", # one of ["ZH", "EN", "mix"]
-                   opt_cut_by_send=False, # split by sentence, on top of the paragraph split
-                   interval_between_para=1.0, # pause between paragraphs (s); only effective if larger than the sentence pause
-                   interval_between_sent=0.2, # pause between sentences (s); only used when splitting by sentence
-                   audio_prompt=None,
-                   text_prompt="",
-                   prompt_mode="Text prompts",
-                   style_text="", # use the meaning of an auxiliary text to steer generation (same language as the main text);
-                                  # note: not instruction-style text (e.g. "happy") but strongly emotional text (e.g. "I'm so happy!!!");
-                                  # the effect is subtle; leave empty to disable
-                   style_weight=0.7, # bert mixing ratio between main and auxiliary text; 0 = main text only, 1 = auxiliary only
-                   en_ratio=1.0 # English speed in mixed text; larger is slower
+                   tts_info,
                    ):
         """
         return: audio, sample_rate
         """
+        speaker_id = tts_info['speaker_id'] # index into self.speakers; selects the speaker
+        sdp_ratio = tts_info['sdp_ratio']
+        noise_scale = tts_info['noise_scale']
+        noise_scale_w = tts_info['noise_scale_w']
+        length_scale = tts_info['length_scale']
+        language = tts_info['language'] # one of ["ZH", "EN", "mix"]
+        opt_cut_by_send = tts_info['opt_cut_by_send']
+        interval_between_para = tts_info['interval_between_para'] # pause between paragraphs (s); only effective if larger than the sentence pause
+        interval_between_sent = tts_info['interval_between_sent'] # pause between sentences (s); only used when splitting by sentence
+        audio_prompt = None
+        text_prompt = ""
+        prompt_mode = "Text prompts"
+        style_text = tts_info['style_text']
+        style_weight = tts_info['style_weight']
+        en_ratio = tts_info['en_ratio']
 
-        speaker = self.speakers[speaker_idx]
+        speaker = self.speakers[speaker_id]
 
         if language == "mix":
             language, text = self.format_utils(text, speaker)
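One behavioural consequence of the refactor above: the old keyword defaults are gone, so a tts_info missing any of these keys now raises KeyError inside synthesize(). A defensive variant (a sketch of an alternative, not what this commit does) would restore the old defaults with dict.get:

    # Sketch only; the committed code indexes tts_info directly.
    sdp_ratio = tts_info.get('sdp_ratio', 0.5)
    noise_scale = tts_info.get('noise_scale', 0.6)
    noise_scale_w = tts_info.get('noise_scale_w', 0.9)
    length_scale = tts_info.get('length_scale', 1.0)
    language = tts_info.get('language', 'mix')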
@@ -455,9 +454,17 @@ class TextToSpeech:
             style_weight
         )
 
-        # return text_output, audio_output
-        return audio_output[1], audio_output[0]
+        return self.convert_numpy_to_bytes(audio_output[1])
 
     def print_speakers_info(self):
         for i, speaker in enumerate(self.speakers):
             print(f"id: {i}, speaker: {speaker}")
+
+    def convert_numpy_to_bytes(self, audio_data):
+        if isinstance(audio_data, np.ndarray):
+            if audio_data.dtype == np.dtype('float32'):
+                audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
+            audio_data = audio_data.tobytes()
+            return audio_data
+        else:
+            raise TypeError("audio_data must be a numpy array")
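The float32 path in convert_numpy_to_bytes can be sanity-checked in isolation; a quick round trip:

    import numpy as np

    # float32 samples in [-1, 1] -> int16 PCM bytes, as convert_numpy_to_bytes does.
    x = np.sin(np.linspace(0, 2 * np.pi, 8, endpoint=False)).astype(np.float32)
    pcm = np.int16(x * np.iinfo(np.int16).max).tobytes()
    # Decode and compare: only quantization error (one part in 32767) remains.
    y = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / np.iinfo(np.int16).max
    assert np.allclose(x, y, atol=1e-4)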