# TakwayDisplayPlatform/utils/bert_vits2_utils.py
import gc
import os
import numpy as np
import torch
from torch import LongTensor
from typing import Optional
import soundfile as sf
import logging
import gradio as gr
import librosa
# bert_vits2
from .bert_vits2 import utils
from .bert_vits2.infer import get_net_g, latest_version, infer_multilang, infer
from .bert_vits2.config import config
from .bert_vits2 import re_matching
from .bert_vits2.tools.sentence import split_by_language
logger = logging.getLogger(__name__)
class TextToSpeech:
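    """Thin wrapper around Bert-VITS2 inference.

    Loads the generator (net_g) and hyperparameters referenced by
    config.webui_config and exposes synthesize() for one-call TTS.
    """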
def __init__(self,
device='cuda',
):
self.device = device = torch.device(device)
if config.webui_config.debug:
logger.info("Enable DEBUG")
hps = utils.get_hparams_from_file(config.webui_config.config_path)
self.hps = hps
        # Default to the latest version if config.json does not specify one
version = hps.version if hasattr(hps, "version") else latest_version
self.version = version
net_g = get_net_g(
model_path=config.webui_config.model, version=version, device=device, hps=hps
)
self.net_g = net_g
self.speaker_ids = speaker_ids = hps.data.spk2id
self.speakers = speakers = list(speaker_ids.keys())
self.speaker = speakers[0]
self.languages = languages = ["ZH", "JP", "EN", "mix", "auto"]
def free_up_memory(self):
# Prior inference run might have large variables not cleaned up due to exception during the run.
# Free up as much memory as possible to allow this run to be successful.
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def process_mix(self, slice):
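        """Parse one slice produced by re_matching.text_matching: the trailing
        element is the speaker, the rest are (lang, content) pairs that are
        split on "|" into parallel text/language lists."""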
_speaker = slice.pop()
_text, _lang = [], []
for lang, content in slice:
content = content.split("|")
content = [part for part in content if part != ""]
if len(content) == 0:
continue
if len(_text) == 0:
_text = [[part] for part in content]
_lang = [[lang] for part in content]
else:
_text[-1].append(content[0])
_lang[-1].append(lang)
if len(content) > 1:
_text += [[part] for part in content[1:]]
_lang += [[lang] for part in content[1:]]
return _text, _lang, _speaker
def process_auto(self, text):
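        """Split text on "|", auto-detect zh/ja/en per sentence, and return
        parallel lists of sentences and upper-case language codes ("ja" is
        mapped to "JP")."""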
_text, _lang = [], []
for slice in text.split("|"):
if slice == "":
continue
temp_text, temp_lang = [], []
sentences_list = split_by_language(slice, target_languages=["zh", "ja", "en"])
for sentence, lang in sentences_list:
if sentence == "":
continue
temp_text.append(sentence)
if lang == "ja":
lang = "jp"
temp_lang.append(lang.upper())
_text.append(temp_text)
_lang.append(temp_lang)
return _text, _lang
def generate_audio(
self,
slices,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
speaker,
language,
reference_audio,
emotion,
style_text,
style_weight,
skip_start=False,
skip_end=False,
):
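        """Run single-language inference for each slice and return a list of
        16-bit audio arrays; skip_start/skip_end are set so only the first
        slice keeps its leading silence and only the last keeps its trailing
        silence."""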
audio_list = []
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
self.free_up_memory()
with torch.no_grad():
for idx, piece in enumerate(slices):
skip_start = idx != 0
skip_end = idx != len(slices) - 1
audio = infer(
piece,
reference_audio=reference_audio,
emotion=emotion,
sdp_ratio=sdp_ratio,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
length_scale=length_scale,
sid=speaker,
language=language,
hps=self.hps,
net_g=self.net_g,
device=self.device,
skip_start=skip_start,
skip_end=skip_end,
style_text=style_text,
style_weight=style_weight,
)
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
audio_list.append(audio16bit)
return audio_list
def generate_audio_multilang(
self,
slices,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
speaker,
language,
reference_audio,
emotion,
skip_start=False,
skip_end=False,
en_ratio=1.0
):
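        """Multilingual variant of generate_audio: each slice carries its own
        language list and is rendered with infer_multilang."""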
audio_list = []
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
self.free_up_memory()
with torch.no_grad():
for idx, piece in enumerate(slices):
skip_start = idx != 0
skip_end = idx != len(slices) - 1
audio = infer_multilang(
piece,
reference_audio=reference_audio,
emotion=emotion,
sdp_ratio=sdp_ratio,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
length_scale=length_scale,
sid=speaker,
language=language[idx],
hps=self.hps,
net_g=self.net_g,
device=self.device,
skip_start=skip_start,
skip_end=skip_end,
en_ratio=en_ratio
)
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
audio_list.append(audio16bit)
return audio_list
def process_text(self,
text: str,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
reference_audio,
emotion,
style_text=None,
style_weight=0,
en_ratio=1.0
):
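        """Dispatch text to the right generator: "mix" expects explicit
        [speaker]<lang> markup, "auto" detects languages per sentence, and any
        other value is treated as a single fixed language."""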
hps = self.hps
audio_list = []
if language == "mix":
bool_valid, str_valid = re_matching.validate_text(text)
if not bool_valid:
return str_valid, (
hps.data.sampling_rate,
np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
)
for slice in re_matching.text_matching(text):
_text, _lang, _speaker = self.process_mix(slice)
if _speaker is None:
continue
print(f"Text: {_text}\nLang: {_lang}")
audio_list.extend(
self.generate_audio_multilang(
_text,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
_speaker,
_lang,
reference_audio,
emotion,
en_ratio=en_ratio
)
)
elif language.lower() == "auto":
_text, _lang = self.process_auto(text)
print(f"Text: {_text}\nLang: {_lang}")
audio_list.extend(
self.generate_audio_multilang(
_text,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
speaker,
_lang,
reference_audio,
emotion,
en_ratio=en_ratio
)
)
else:
audio_list.extend(
self.generate_audio(
text.split("|"),
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
speaker,
language,
reference_audio,
emotion,
style_text,
style_weight,
)
)
return audio_list
def tts_split(
self,
text: str,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
cut_by_sent,
interval_between_para,
interval_between_sent,
reference_audio,
emotion,
style_text,
style_weight,
en_ratio
):
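        """Split text into paragraphs (and optionally sentences), synthesize
        each piece, insert silence between pieces, and return
        ("Success", (sampling_rate, audio))."""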
while text.find("\n\n") != -1:
text = text.replace("\n\n", "\n")
text = text.replace("|", "")
para_list = re_matching.cut_para(text)
para_list = [p for p in para_list if p != ""]
audio_list = []
for p in para_list:
if not cut_by_sent:
audio_list += self.process_text(
p,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
reference_audio,
emotion,
style_text,
style_weight,
en_ratio
)
                silence = np.zeros(int(44100 * interval_between_para), dtype=np.int16)  # silence length assumes a 44100 Hz model sampling rate
audio_list.append(silence)
else:
audio_list_sent = []
sent_list = re_matching.cut_sent(p)
sent_list = [s for s in sent_list if s != ""]
for s in sent_list:
audio_list_sent += self.process_text(
s,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
reference_audio,
emotion,
style_text,
style_weight,
en_ratio
)
                    silence = np.zeros(int(44100 * interval_between_sent))
                    audio_list_sent.append(silence)
                if (interval_between_para - interval_between_sent) > 0:
                    silence = np.zeros(
                        int(44100 * (interval_between_para - interval_between_sent))
                    )
audio_list_sent.append(silence)
audio16bit = gr.processing_utils.convert_to_16_bit_wav(
np.concatenate(audio_list_sent)
                )  # volume-normalize the whole concatenated sentence group
audio_list.append(audio16bit)
audio_concat = np.concatenate(audio_list)
return ("Success", (self.hps.data.sampling_rate, audio_concat))
def tts_fn(
self,
text: str,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
reference_audio,
emotion,
prompt_mode,
style_text=None,
style_weight=0,
):
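        """Single-pass synthesis (no paragraph/sentence splitting); optionally
        loads an audio prompt and returns ("Success", (sampling_rate, audio))."""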
if style_text == "":
style_text = None
if prompt_mode == "Audio prompt":
            if reference_audio is None:
return ("Invalid audio prompt", None)
else:
reference_audio = self.load_audio(reference_audio)[1]
else:
reference_audio = None
audio_list = self.process_text(
text,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
reference_audio,
emotion,
style_text,
style_weight,
)
audio_concat = np.concatenate(audio_list)
return "Success", (self.hps.data.sampling_rate, audio_concat)
def load_audio(self, path):
        audio, sr = librosa.load(path, sr=48000)  # sr must be passed as a keyword in librosa >= 0.10
# audio = librosa.resample(audio, 44100, 48000)
return sr, audio
def format_utils(self, text, speaker):
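        """Auto-detect languages in text and rewrite it as a "mix"-format
        string: "[speaker]<lang>content|<lang>content|..."."""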
_text, _lang = self.process_auto(text)
res = f"[{speaker}]"
for lang_s, content_s in zip(_lang, _text):
for lang, content in zip(lang_s, content_s):
# res += f"<{lang.lower()}>{content}"
                # some Chinese is misdetected as Japanese, so force it back to Chinese
lang = lang.lower().replace("jp", "zh")
res += f"<{lang}>{content}"
res += "|"
return "mix", res[:-1]
def synthesize(self,
text,
                   speaker_idx=0,  # index into self.speakers selecting the voice
                   sdp_ratio=0.5,
                   noise_scale=0.6,
                   noise_scale_w=0.9,
                   length_scale=1.0,  # larger values give slower speech
                   language="mix",  # one of ["ZH", "EN", "mix"]
                   opt_cut_by_send=False,  # additionally split each paragraph by sentence
                   interval_between_para=1.0,  # pause between paragraphs (seconds); must exceed the sentence pause to take effect
                   interval_between_sent=0.2,  # pause between sentences (seconds); only used when splitting by sentence
                   audio_prompt=None,
                   text_prompt="",
                   prompt_mode="Text prompts",
                   style_text="",  # auxiliary text whose emotion/semantics guide generation (same language as the main text);
                                   # use emotionally charged text ("I'm so happy!!!"), not instruction-style text ("happy");
                                   # the effect is subtle, leave empty to disable
                   style_weight=0.7,  # BERT mixing ratio of main vs. auxiliary text: 0 = main text only, 1 = auxiliary text only
                   en_ratio=1.0  # English speed in mixed Chinese/English text; larger values give slower English
):
"""
return: audio, sample_rate
"""
speaker = self.speakers[speaker_idx]
if language == "mix":
language, text = self.format_utils(text, speaker)
text_output, audio_output = self.tts_split(
text,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
opt_cut_by_send,
interval_between_para,
interval_between_sent,
audio_prompt,
text_prompt,
style_text,
style_weight,
en_ratio
)
else:
text_output, audio_output = self.tts_fn(
text,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
audio_prompt,
text_prompt,
prompt_mode,
style_text,
style_weight
)
# return text_output, audio_output
return audio_output[1], audio_output[0]
def print_speakers_info(self):
for i, speaker in enumerate(self.speakers):
print(f"id: {i}, speaker: {speaker}")