# TakwayDisplayPlatform/utils/bert_vits2_utils.py

import gc
import os
import numpy as np
import torch
from torch import LongTensor
from typing import Optional
import soundfile as sf
import logging
import gradio as gr
import librosa
# bert_vits2
from .bert_vits2 import utils
from .bert_vits2.infer import get_net_g, latest_version, infer_multilang, infer
from .bert_vits2.config import config
from .bert_vits2 import re_matching
from .bert_vits2.tools.sentence import split_by_language
logger = logging.getLogger(__name__)


class TextToSpeech:
    def __init__(self,
                 device='cuda',
                 ):
        self.device = device = torch.device(device)
        if config.webui_config.debug:
            logger.info("Enable DEBUG")
        hps = utils.get_hparams_from_file(config.webui_config.config_path)
        self.hps = hps
        # Default to the latest version if config.json does not specify one.
        version = hps.version if hasattr(hps, "version") else latest_version
        self.version = version
        net_g = get_net_g(
            model_path=config.webui_config.model, version=version, device=device, hps=hps
        )
        self.net_g = net_g
        self.speaker_ids = speaker_ids = hps.data.spk2id
        self.speakers = speakers = list(speaker_ids.keys())
        self.speaker = speakers[0]
        self.languages = languages = ["ZH", "JP", "EN", "mix", "auto"]

    def free_up_memory(self):
        # A prior inference run may have left large variables behind if it raised
        # an exception. Free up as much memory as possible so this run succeeds.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def process_mix(self, slice):
        # `slice` is a list of (lang, text) pairs produced by re_matching.text_matching,
        # with the speaker (or None) as its final element.
        _speaker = slice.pop()
        _text, _lang = [], []
        for lang, content in slice:
            content = content.split("|")
            content = [part for part in content if part != ""]
            if len(content) == 0:
                continue
            if len(_text) == 0:
                _text = [[part] for part in content]
                _lang = [[lang] for part in content]
            else:
                _text[-1].append(content[0])
                _lang[-1].append(lang)
                if len(content) > 1:
                    _text += [[part] for part in content[1:]]
                    _lang += [[lang] for part in content[1:]]
        return _text, _lang, _speaker

    def process_auto(self, text):
        _text, _lang = [], []
        for slice in text.split("|"):
            if slice == "":
                continue
            temp_text, temp_lang = [], []
            sentences_list = split_by_language(slice, target_languages=["zh", "ja", "en"])
            for sentence, lang in sentences_list:
                if sentence == "":
                    continue
                temp_text.append(sentence)
                if lang == "ja":
                    lang = "jp"
                temp_lang.append(lang.upper())
            _text.append(temp_text)
            _lang.append(temp_lang)
        return _text, _lang
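
    # Illustrative example, assuming split_by_language detects each slice's language:
    #   process_auto("你好，世界|Hello world")
    #   -> _text = [["你好，世界"], ["Hello world"]], _lang = [["ZH"], ["EN"]]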

    def generate_audio(
        self,
        slices,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        speaker,
        language,
        reference_audio,
        emotion,
        style_text,
        style_weight,
        skip_start=False,
        skip_end=False,
    ):
        audio_list = []
        # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
        self.free_up_memory()
        with torch.no_grad():
            for idx, piece in enumerate(slices):
                skip_start = idx != 0
                skip_end = idx != len(slices) - 1
                audio = infer(
                    piece,
                    reference_audio=reference_audio,
                    emotion=emotion,
                    sdp_ratio=sdp_ratio,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=length_scale,
                    sid=speaker,
                    language=language,
                    hps=self.hps,
                    net_g=self.net_g,
                    device=self.device,
                    skip_start=skip_start,
                    skip_end=skip_end,
                    style_text=style_text,
                    style_weight=style_weight,
                )
                audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
                audio_list.append(audio16bit)
        return audio_list

    def generate_audio_multilang(
        self,
        slices,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        speaker,
        language,
        reference_audio,
        emotion,
        skip_start=False,
        skip_end=False,
        en_ratio=1.0
    ):
        audio_list = []
        # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
        self.free_up_memory()
        with torch.no_grad():
            for idx, piece in enumerate(slices):
                skip_start = idx != 0
                skip_end = idx != len(slices) - 1
                audio = infer_multilang(
                    piece,
                    reference_audio=reference_audio,
                    emotion=emotion,
                    sdp_ratio=sdp_ratio,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=length_scale,
                    sid=speaker,
                    language=language[idx],
                    hps=self.hps,
                    net_g=self.net_g,
                    device=self.device,
                    skip_start=skip_start,
                    skip_end=skip_end,
                    en_ratio=en_ratio
                )
                audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
                audio_list.append(audio16bit)
        return audio_list

    def process_text(self,
                     text: str,
                     speaker,
                     sdp_ratio,
                     noise_scale,
                     noise_scale_w,
                     length_scale,
                     language,
                     reference_audio,
                     emotion,
                     style_text=None,
                     style_weight=0,
                     en_ratio=1.0
                     ):
        hps = self.hps
        audio_list = []
        if language == "mix":
            bool_valid, str_valid = re_matching.validate_text(text)
            if not bool_valid:
                return str_valid, (
                    hps.data.sampling_rate,
                    np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
                )
            for slice in re_matching.text_matching(text):
                _text, _lang, _speaker = self.process_mix(slice)
                if _speaker is None:
                    continue
                print(f"Text: {_text}\nLang: {_lang}")
                audio_list.extend(
                    self.generate_audio_multilang(
                        _text,
                        sdp_ratio,
                        noise_scale,
                        noise_scale_w,
                        length_scale,
                        _speaker,
                        _lang,
                        reference_audio,
                        emotion,
                        en_ratio=en_ratio
                    )
                )
        elif language.lower() == "auto":
            _text, _lang = self.process_auto(text)
            print(f"Text: {_text}\nLang: {_lang}")
            audio_list.extend(
                self.generate_audio_multilang(
                    _text,
                    sdp_ratio,
                    noise_scale,
                    noise_scale_w,
                    length_scale,
                    speaker,
                    _lang,
                    reference_audio,
                    emotion,
                    en_ratio=en_ratio
                )
            )
        else:
            audio_list.extend(
                self.generate_audio(
                    text.split("|"),
                    sdp_ratio,
                    noise_scale,
                    noise_scale_w,
                    length_scale,
                    speaker,
                    language,
                    reference_audio,
                    emotion,
                    style_text,
                    style_weight,
                )
            )
        return audio_list

    def tts_split(
        self,
        text: str,
        speaker,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        language,
        cut_by_sent,
        interval_between_para,
        interval_between_sent,
        reference_audio,
        emotion,
        style_text,
        style_weight,
        en_ratio
    ):
        while text.find("\n\n") != -1:
            text = text.replace("\n\n", "\n")
        text = text.replace("|", "")
        para_list = re_matching.cut_para(text)
        para_list = [p for p in para_list if p != ""]
        audio_list = []
        for p in para_list:
            if not cut_by_sent:
                audio_list += self.process_text(
                    p,
                    speaker,
                    sdp_ratio,
                    noise_scale,
                    noise_scale_w,
                    length_scale,
                    language,
                    reference_audio,
                    emotion,
                    style_text,
                    style_weight,
                    en_ratio
                )
                # NOTE: the hard-coded 44100 below assumes hps.data.sampling_rate is 44100 Hz.
                silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16)
                audio_list.append(silence)
            else:
                audio_list_sent = []
                sent_list = re_matching.cut_sent(p)
                sent_list = [s for s in sent_list if s != ""]
                for s in sent_list:
                    audio_list_sent += self.process_text(
                        s,
                        speaker,
                        sdp_ratio,
                        noise_scale,
                        noise_scale_w,
                        length_scale,
                        language,
                        reference_audio,
                        emotion,
                        style_text,
                        style_weight,
                        en_ratio
                    )
                    silence = np.zeros((int)(44100 * interval_between_sent))
                    audio_list_sent.append(silence)
                if (interval_between_para - interval_between_sent) > 0:
                    silence = np.zeros(
                        (int)(44100 * (interval_between_para - interval_between_sent))
                    )
                    audio_list_sent.append(silence)
                audio16bit = gr.processing_utils.convert_to_16_bit_wav(
                    np.concatenate(audio_list_sent)
                )  # normalize volume over the whole sentence group
                audio_list.append(audio16bit)
        audio_concat = np.concatenate(audio_list)
        return ("Success", (self.hps.data.sampling_rate, audio_concat))

    def tts_fn(
        self,
        text: str,
        speaker,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        language,
        reference_audio,
        emotion,
        prompt_mode,
        style_text=None,
        style_weight=0,
    ):
        if style_text == "":
            style_text = None
        if prompt_mode == "Audio prompt":
            if reference_audio is None:
                return ("Invalid audio prompt", None)
            else:
                reference_audio = self.load_audio(reference_audio)[1]
        else:
            reference_audio = None
        audio_list = self.process_text(
            text,
            speaker,
            sdp_ratio,
            noise_scale,
            noise_scale_w,
            length_scale,
            language,
            reference_audio,
            emotion,
            style_text,
            style_weight,
        )
        audio_concat = np.concatenate(audio_list)
        return "Success", (self.hps.data.sampling_rate, audio_concat)

    def load_audio(self, path):
        # librosa >= 0.10 requires the sampling rate to be passed as a keyword argument.
        audio, sr = librosa.load(path, sr=48000)
        # audio = librosa.resample(audio, 44100, 48000)
        return sr, audio

    def format_utils(self, text, speaker):
        _text, _lang = self.process_auto(text)
        res = f"[{speaker}]"
        for lang_s, content_s in zip(_lang, _text):
            for lang, content in zip(lang_s, content_s):
                # res += f"<{lang.lower()}>{content}"
                # Some Chinese text gets detected as Japanese; force it back to Chinese.
                lang = lang.lower().replace("jp", "zh")
                res += f"<{lang}>{content}"
            res += "|"
        return "mix", res[:-1]

    def synthesize(self,
                   text,
                   speaker_idx=0,  # index into self.speakers selecting the voice
                   sdp_ratio=0.5,
                   noise_scale=0.6,
                   noise_scale_w=0.9,
                   length_scale=1.0,  # larger values slow the speech down
                   language="mix",  # one of ["ZH", "EN", "mix"]
                   opt_cut_by_send=False,  # additionally split each paragraph into sentences
                   interval_between_para=1.0,  # pause between paragraphs (s); only effective if larger than the sentence pause
                   interval_between_sent=0.2,  # pause between sentences (s); only used when cutting by sentence
                   audio_prompt=None,
                   text_prompt="",
                   prompt_mode="Text prompts",
                   style_text="",  # auxiliary text whose semantics guide the generated prosody (same language as the main text);
                   # use emotionally charged sentences (e.g. "我好快乐!!!") rather than instruction-like words (e.g. "开心");
                   # the effect is subtle, leave empty to disable
                   style_weight=0.7,  # BERT mixing ratio between main and auxiliary text: 0 = main text only, 1 = auxiliary text only
                   en_ratio=1.0  # English speed control in mixed Chinese/English text; larger values slow English down
                   ):
        """
        return: audio, sample_rate
        """
        speaker = self.speakers[speaker_idx]
        if language == "mix":
            language, text = self.format_utils(text, speaker)
            text_output, audio_output = self.tts_split(
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                language,
                opt_cut_by_send,
                interval_between_para,
                interval_between_sent,
                audio_prompt,
                text_prompt,
                style_text,
                style_weight,
                en_ratio
            )
        else:
            text_output, audio_output = self.tts_fn(
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                language,
                audio_prompt,
                text_prompt,
                prompt_mode,
                style_text,
                style_weight
            )
        # return text_output, audio_output
        return audio_output[1], audio_output[0]

    def print_speakers_info(self):
        for i, speaker in enumerate(self.speakers):
            print(f"id: {i}, speaker: {speaker}")