# TakwayDisplayPlatform/utils/bert_vits2_utils.py

import gc
import os
import numpy as np
import torch
from torch import LongTensor
from typing import Optional
import soundfile as sf
import logging
import gradio as gr
import librosa
# bert_vits2
from .bert_vits2 import utils
from .bert_vits2.infer import get_net_g, latest_version, infer_multilang, infer
from .bert_vits2.config import config
from .bert_vits2 import re_matching
from .bert_vits2.tools.sentence import split_by_language
logger = logging.getLogger(__name__)
class TextToSpeech:
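    """Thin wrapper around Bert-VITS2 inference: loads the model once and exposes text-to-speech helpers."""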
def __init__(self,
device='cuda',
):
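        """Load the Bert-VITS2 hparams, generator network, and speaker list described by config.webui_config."""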
self.device = device = torch.device(device)
if config.webui_config.debug:
logger.info("Enable DEBUG")
hps = utils.get_hparams_from_file(config.webui_config.config_path)
self.hps = hps
        # Default to the latest model version if config.json does not specify one
version = hps.version if hasattr(hps, "version") else latest_version
self.version = version
net_g = get_net_g(
model_path=config.webui_config.model, version=version, device=device, hps=hps
)
self.net_g = net_g
self.speaker_ids = speaker_ids = hps.data.spk2id
self.speakers = speakers = list(speaker_ids.keys())
self.speaker = speakers[0]
self.languages = languages = ["ZH", "JP", "EN", "mix", "auto"]
def free_up_memory(self):
# Prior inference run might have large variables not cleaned up due to exception during the run.
# Free up as much memory as possible to allow this run to be successful.
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def process_mix(self, slice):
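        """Parse one mix-mode slice into nested text/language lists plus the speaker popped from its end."""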
_speaker = slice.pop()
_text, _lang = [], []
for lang, content in slice:
content = content.split("|")
content = [part for part in content if part != ""]
if len(content) == 0:
continue
if len(_text) == 0:
_text = [[part] for part in content]
_lang = [[lang] for part in content]
else:
_text[-1].append(content[0])
_lang[-1].append(lang)
if len(content) > 1:
_text += [[part] for part in content[1:]]
_lang += [[lang] for part in content[1:]]
return _text, _lang, _speaker
def process_auto(self, text):
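        """Split `text` on "|" and auto-detect the language (ZH/JP/EN) of each sentence."""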
_text, _lang = [], []
for slice in text.split("|"):
if slice == "":
continue
temp_text, temp_lang = [], []
sentences_list = split_by_language(slice, target_languages=["zh", "ja", "en"])
for sentence, lang in sentences_list:
if sentence == "":
continue
temp_text.append(sentence)
if lang == "ja":
lang = "jp"
temp_lang.append(lang.upper())
_text.append(temp_text)
_lang.append(temp_lang)
return _text, _lang
def generate_audio(
self,
slices,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
speaker,
language,
reference_audio,
emotion,
style_text,
style_weight,
skip_start=False,
skip_end=False,
):
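        """Run single-language inference on each slice and return a list of 16-bit audio chunks."""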
audio_list = []
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
self.free_up_memory()
with torch.no_grad():
for idx, piece in enumerate(slices):
skip_start = idx != 0
skip_end = idx != len(slices) - 1
audio = infer(
piece,
reference_audio=reference_audio,
emotion=emotion,
sdp_ratio=sdp_ratio,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
length_scale=length_scale,
sid=speaker,
language=language,
hps=self.hps,
net_g=self.net_g,
device=self.device,
skip_start=skip_start,
skip_end=skip_end,
style_text=style_text,
style_weight=style_weight,
)
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
audio_list.append(audio16bit)
return audio_list
def generate_audio_multilang(
self,
slices,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
speaker,
language,
reference_audio,
emotion,
skip_start=False,
skip_end=False,
en_ratio=1.0
):
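        """Run multilingual inference on each slice (languages given per slice) and return 16-bit audio chunks."""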
audio_list = []
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
self.free_up_memory()
with torch.no_grad():
for idx, piece in enumerate(slices):
skip_start = idx != 0
skip_end = idx != len(slices) - 1
audio = infer_multilang(
piece,
reference_audio=reference_audio,
emotion=emotion,
sdp_ratio=sdp_ratio,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
length_scale=length_scale,
sid=speaker,
language=language[idx],
hps=self.hps,
net_g=self.net_g,
device=self.device,
skip_start=skip_start,
skip_end=skip_end,
en_ratio=en_ratio
)
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
audio_list.append(audio16bit)
return audio_list
def process_text(self,
text: str,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
reference_audio,
emotion,
style_text=None,
style_weight=0,
en_ratio=1.0
):
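        """Dispatch `text` to mix, auto, or single-language generation and return the list of audio chunks."""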
hps = self.hps
audio_list = []
if language == "mix":
bool_valid, str_valid = re_matching.validate_text(text)
if not bool_valid:
return str_valid, (
hps.data.sampling_rate,
np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
)
for slice in re_matching.text_matching(text):
_text, _lang, _speaker = self.process_mix(slice)
if _speaker is None:
continue
print(f"Text: {_text}\nLang: {_lang}")
audio_list.extend(
self.generate_audio_multilang(
_text,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
_speaker,
_lang,
reference_audio,
emotion,
en_ratio=en_ratio
)
)
elif language.lower() == "auto":
_text, _lang = self.process_auto(text)
print(f"Text: {_text}\nLang: {_lang}")
audio_list.extend(
self.generate_audio_multilang(
_text,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
speaker,
_lang,
reference_audio,
emotion,
en_ratio=en_ratio
)
)
else:
audio_list.extend(
self.generate_audio(
text.split("|"),
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
speaker,
language,
reference_audio,
emotion,
style_text,
style_weight,
)
)
return audio_list
def tts_split(
self,
text: str,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
cut_by_sent,
interval_between_para,
interval_between_sent,
reference_audio,
emotion,
style_text,
style_weight,
en_ratio
):
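        """Cut `text` into paragraphs (and optionally sentences), synthesize each piece,
        and join the pieces with silences of the configured lengths."""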
while text.find("\n\n") != -1:
text = text.replace("\n\n", "\n")
text = text.replace("|", "")
para_list = re_matching.cut_para(text)
para_list = [p for p in para_list if p != ""]
audio_list = []
for p in para_list:
if not cut_by_sent:
audio_list += self.process_text(
p,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
reference_audio,
emotion,
style_text,
style_weight,
en_ratio
)
                silence = np.zeros(int(44100 * interval_between_para), dtype=np.int16)  # paragraph pause; assumes a 44.1 kHz model
audio_list.append(silence)
else:
audio_list_sent = []
sent_list = re_matching.cut_sent(p)
sent_list = [s for s in sent_list if s != ""]
for s in sent_list:
audio_list_sent += self.process_text(
s,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
reference_audio,
emotion,
style_text,
style_weight,
en_ratio
)
                    silence = np.zeros(int(44100 * interval_between_sent))  # sentence pause; assumes a 44.1 kHz model
audio_list_sent.append(silence)
if (interval_between_para - interval_between_sent) > 0:
                    silence = np.zeros(
                        int(44100 * (interval_between_para - interval_between_sent))
                    )
audio_list_sent.append(silence)
audio16bit = gr.processing_utils.convert_to_16_bit_wav(
np.concatenate(audio_list_sent)
                )  # normalize the volume over the complete sentence group
audio_list.append(audio16bit)
audio_concat = np.concatenate(audio_list)
return ("Success", (self.hps.data.sampling_rate, audio_concat))
def tts_fn(
self,
text: str,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
reference_audio,
emotion,
prompt_mode,
style_text=None,
style_weight=0,
):
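        """Synthesize `text` in one pass and return ("Success", (sampling_rate, audio))."""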
if style_text == "":
style_text = None
if prompt_mode == "Audio prompt":
            if reference_audio is None:
return ("Invalid audio prompt", None)
else:
reference_audio = self.load_audio(reference_audio)[1]
else:
reference_audio = None
audio_list = self.process_text(
text,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
reference_audio,
emotion,
style_text,
style_weight,
)
audio_concat = np.concatenate(audio_list)
return "Success", (self.hps.data.sampling_rate, audio_concat)
def load_audio(self, path):
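        """Load an audio file resampled to 48 kHz and return (sampling_rate, samples)."""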
        audio, sr = librosa.load(path, sr=48000)  # pass sr as a keyword; newer librosa versions reject positional sr
# audio = librosa.resample(audio, 44100, 48000)
return sr, audio
def format_utils(self, text, speaker):
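        """Auto-detect languages and convert plain text into the "[speaker]<lang>text|..." mix format, returning ("mix", formatted_text)."""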
_text, _lang = self.process_auto(text)
res = f"[{speaker}]"
for lang_s, content_s in zip(_lang, _text):
for lang, content in zip(lang_s, content_s):
# res += f"<{lang.lower()}>{content}"
                # Some Chinese text is misdetected as Japanese, so force it back to Chinese
lang = lang.lower().replace("jp", "zh")
res += f"<{lang}>{content}"
res += "|"
return "mix", res[:-1]
def synthesize(self,
text,
tts_info,
):
"""
return: audio, sample_rate
"""
        speaker_id = tts_info['speaker_id']  # index into self.speakers selecting the voice
sdp_ratio = tts_info['sdp_ratio']
noise_scale = tts_info['noise_scale']
noise_scale_w = tts_info['noise_scale_w']
length_scale = tts_info['length_scale']
        language = tts_info['language']  # one of "ZH", "EN", "mix"
        opt_cut_by_send = tts_info['opt_cut_by_send']
        interval_between_para = tts_info['interval_between_para']  # pause between paragraphs (seconds); only effective if larger than the sentence pause
        interval_between_sent = tts_info['interval_between_sent']  # pause between sentences (seconds); only used when cutting by sentence
audio_prompt = None
text_prompt = ""
prompt_mode = "Text prompts"
style_text = tts_info['style_text']
style_weight = tts_info['style_weight']
en_ratio = tts_info['en_ratio']
speaker = self.speakers[speaker_id]
if language == "mix":
language, text = self.format_utils(text, speaker)
text_output, audio_output = self.tts_split(
text,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
opt_cut_by_send,
interval_between_para,
interval_between_sent,
audio_prompt,
text_prompt,
style_text,
style_weight,
en_ratio
)
else:
text_output, audio_output = self.tts_fn(
text,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
audio_prompt,
text_prompt,
prompt_mode,
style_text,
style_weight
)
return self.convert_numpy_to_bytes(audio_output[1])
def print_speakers_info(self):
for i, speaker in enumerate(self.speakers):
print(f"id: {i}, speaker: {speaker}")
def convert_numpy_to_bytes(self, audio_data):
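        """Convert a numpy audio array (float32 or int16) into raw 16-bit PCM bytes."""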
if isinstance(audio_data, np.ndarray):
if audio_data.dtype == np.dtype('float32'):
audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
audio_data = audio_data.tobytes()
return audio_data
else:
raise TypeError("audio_data must be a numpy array")
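

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only).
# Assumes config.webui_config points at a valid Bert-VITS2 model/config pair;
# the parameter values below are placeholders, not tuned defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tts = TextToSpeech(device="cuda" if torch.cuda.is_available() else "cpu")
    tts.print_speakers_info()
    tts_info = {
        "speaker_id": 0,                # index into tts.speakers
        "sdp_ratio": 0.2,
        "noise_scale": 0.6,
        "noise_scale_w": 0.8,
        "length_scale": 1.0,
        "language": "ZH",               # "ZH", "EN", or "mix"
        "opt_cut_by_send": False,
        "interval_between_para": 1.0,   # seconds
        "interval_between_sent": 0.2,   # seconds
        "style_text": "",
        "style_weight": 0.7,
        "en_ratio": 1.0,
    }
    pcm_bytes = tts.synthesize("你好,世界。", tts_info)
    # Raw 16-bit PCM at tts.hps.data.sampling_rate; write to disk for inspection.
    with open("output.pcm", "wb") as f:
        f.write(pcm_bytes)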