440 lines
14 KiB
Python
440 lines
14 KiB
Python
|
"""
|
|||
|
版本管理、兼容推理及模型加载实现。
|
|||
|
版本说明:
|
|||
|
1. 版本号与github的release版本号对应,使用哪个release版本训练的模型即对应其版本号
|
|||
|
2. 请在模型的config.json中显示声明版本号,添加一个字段"version" : "你的版本号"
|
|||
|
特殊版本说明:
|
|||
|
1.1.1-fix: 1.1.1版本训练的模型,但是在推理时使用dev的日语修复
|
|||
|
2.3:当前版本
|
|||
|
"""
|
|||
|
import torch
|
|||
|
from . import commons
|
|||
|
from .text import cleaned_text_to_sequence, get_bert
|
|||
|
|
|||
|
# from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
|
|||
|
from typing import Union
|
|||
|
from .text.cleaner import clean_text
|
|||
|
from . import utils
|
|||
|
|
|||
|
from .models import SynthesizerTrn
|
|||
|
from .text.symbols import symbols
|
|||
|
|
|||
|
# from utils.tts.bert_vits2.oldVersion.V220.models import SynthesizerTrn as V220SynthesizerTrn
|
|||
|
# from utils.tts.bert_vits2.oldVersion.V220.text import symbols as V220symbols
|
|||
|
# from utils.tts.bert_vits2.oldVersion.V210.models import SynthesizerTrn as V210SynthesizerTrn
|
|||
|
# from utils.tts.bert_vits2.oldVersion.V210.text import symbols as V210symbols
|
|||
|
# from utils.tts.bert_vits2.oldVersion.V200.models import SynthesizerTrn as V200SynthesizerTrn
|
|||
|
# from utils.tts.bert_vits2.oldVersion.V200.text import symbols as V200symbols
|
|||
|
# from utils.tts.bert_vits2.oldVersion.V111.models import SynthesizerTrn as V111SynthesizerTrn
|
|||
|
# from utils.tts.bert_vits2.oldVersion.V111.text import symbols as V111symbols
|
|||
|
# from utils.tts.bert_vits2.oldVersion.V110.models import SynthesizerTrn as V110SynthesizerTrn
|
|||
|
# from utils.tts.bert_vits2.oldVersion.V110.text import symbols as V110symbols
|
|||
|
# from utils.tts.bert_vits2.oldVersion.V101.models import SynthesizerTrn as V101SynthesizerTrn
|
|||
|
# from utils.tts.bert_vits2.oldVersion.V101.text import symbols as V101symbols
|
|||
|
|
|||
|
# from oldVersion import V111, V110, V101, V200, V210, V220
|
|||
|
|
|||
|
# 当前版本信息
|
|||
|
latest_version = "2.3"
|
|||
|
|
|||
|
# 版本兼容
|
|||
|
# SynthesizerTrnMap = {
|
|||
|
# "2.2": V220SynthesizerTrn,
|
|||
|
# "2.1": V210SynthesizerTrn,
|
|||
|
# "2.0.2-fix": V200SynthesizerTrn,
|
|||
|
# "2.0.1": V200SynthesizerTrn,
|
|||
|
# "2.0": V200SynthesizerTrn,
|
|||
|
# "1.1.1-fix": V111SynthesizerTrn,
|
|||
|
# "1.1.1": V111SynthesizerTrn,
|
|||
|
# "1.1": V110SynthesizerTrn,
|
|||
|
# "1.1.0": V110SynthesizerTrn,
|
|||
|
# "1.0.1": V101SynthesizerTrn,
|
|||
|
# "1.0": V101SynthesizerTrn,
|
|||
|
# "1.0.0": V101SynthesizerTrn,
|
|||
|
# }
|
|||
|
|
|||
|
# symbolsMap = {
|
|||
|
# "2.2": V220symbols,
|
|||
|
# "2.1": V210symbols,
|
|||
|
# "2.0.2-fix": V200symbols,
|
|||
|
# "2.0.1": V200symbols,
|
|||
|
# "2.0": V200symbols,
|
|||
|
# "1.1.1-fix": V111symbols,
|
|||
|
# "1.1.1": V111symbols,
|
|||
|
# "1.1": V110symbols,
|
|||
|
# "1.1.0": V110symbols,
|
|||
|
# "1.0.1": V101symbols,
|
|||
|
# "1.0": V101symbols,
|
|||
|
# "1.0.0": V101symbols,
|
|||
|
# }
|
|||
|
|
|||
|
|
|||
|
# def get_emo_(reference_audio, emotion, sid):
|
|||
|
# emo = (
|
|||
|
# torch.from_numpy(get_emo(reference_audio))
|
|||
|
# if reference_audio and emotion == -1
|
|||
|
# else torch.FloatTensor(
|
|||
|
# np.load(f"emo_clustering/{sid}/cluster_center_{emotion}.npy")
|
|||
|
# )
|
|||
|
# )
|
|||
|
# return emo
|
|||
|
|
|||
|
|
|||
|
def get_net_g(model_path: str, version: str, device: str, hps):
|
|||
|
if version != latest_version:
|
|||
|
net_g = SynthesizerTrnMap[version](
|
|||
|
len(symbolsMap[version]),
|
|||
|
hps.data.filter_length // 2 + 1,
|
|||
|
hps.train.segment_size // hps.data.hop_length,
|
|||
|
n_speakers=hps.data.n_speakers,
|
|||
|
**hps.model,
|
|||
|
).to(device)
|
|||
|
else:
|
|||
|
# 当前版本模型 net_g
|
|||
|
net_g = SynthesizerTrn(
|
|||
|
len(symbols),
|
|||
|
hps.data.filter_length // 2 + 1,
|
|||
|
hps.train.segment_size // hps.data.hop_length,
|
|||
|
n_speakers=hps.data.n_speakers,
|
|||
|
**hps.model,
|
|||
|
).to(device)
|
|||
|
_ = net_g.eval()
|
|||
|
_ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
|
|||
|
return net_g
|
|||
|
|
|||
|
|
|||
|
def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
|
|||
|
style_text = None if style_text == "" else style_text
|
|||
|
# 在此处实现当前版本的get_text
|
|||
|
norm_text, phone, tone, word2ph = clean_text(text, language_str)
|
|||
|
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
|
|||
|
|
|||
|
if hps.data.add_blank:
|
|||
|
phone = commons.intersperse(phone, 0)
|
|||
|
tone = commons.intersperse(tone, 0)
|
|||
|
language = commons.intersperse(language, 0)
|
|||
|
for i in range(len(word2ph)):
|
|||
|
word2ph[i] = word2ph[i] * 2
|
|||
|
word2ph[0] += 1
|
|||
|
bert_ori = get_bert(
|
|||
|
norm_text, word2ph, language_str, device, style_text, style_weight
|
|||
|
)
|
|||
|
del word2ph
|
|||
|
assert bert_ori.shape[-1] == len(phone), phone
|
|||
|
|
|||
|
if language_str == "ZH":
|
|||
|
bert = bert_ori
|
|||
|
ja_bert = torch.randn(1024, len(phone))
|
|||
|
en_bert = torch.randn(1024, len(phone))
|
|||
|
elif language_str == "JP":
|
|||
|
bert = torch.randn(1024, len(phone))
|
|||
|
ja_bert = bert_ori
|
|||
|
en_bert = torch.randn(1024, len(phone))
|
|||
|
elif language_str == "EN":
|
|||
|
bert = torch.randn(1024, len(phone))
|
|||
|
ja_bert = torch.randn(1024, len(phone))
|
|||
|
en_bert = bert_ori
|
|||
|
else:
|
|||
|
raise ValueError("language_str should be ZH, JP or EN")
|
|||
|
|
|||
|
assert bert.shape[-1] == len(
|
|||
|
phone
|
|||
|
), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
|
|||
|
|
|||
|
phone = torch.LongTensor(phone)
|
|||
|
tone = torch.LongTensor(tone)
|
|||
|
language = torch.LongTensor(language)
|
|||
|
return bert, ja_bert, en_bert, phone, tone, language
|
|||
|
|
|||
|
|
|||
|
def infer(
|
|||
|
text,
|
|||
|
emotion: Union[int, str],
|
|||
|
sdp_ratio,
|
|||
|
noise_scale,
|
|||
|
noise_scale_w,
|
|||
|
length_scale,
|
|||
|
sid,
|
|||
|
language,
|
|||
|
hps,
|
|||
|
net_g,
|
|||
|
device,
|
|||
|
reference_audio=None,
|
|||
|
skip_start=False,
|
|||
|
skip_end=False,
|
|||
|
style_text=None,
|
|||
|
style_weight=0.7,
|
|||
|
):
|
|||
|
# # 2.2版本参数位置变了
|
|||
|
# inferMap_V4 = {
|
|||
|
# "2.2": V220.infer,
|
|||
|
# }
|
|||
|
# # 2.1 参数新增 emotion reference_audio skip_start skip_end
|
|||
|
# inferMap_V3 = {
|
|||
|
# "2.1": V210.infer,
|
|||
|
# }
|
|||
|
# # 支持中日英三语版本
|
|||
|
# inferMap_V2 = {
|
|||
|
# "2.0.2-fix": V200.infer,
|
|||
|
# "2.0.1": V200.infer,
|
|||
|
# "2.0": V200.infer,
|
|||
|
# "1.1.1-fix": V111.infer_fix,
|
|||
|
# "1.1.1": V111.infer,
|
|||
|
# "1.1": V110.infer,
|
|||
|
# "1.1.0": V110.infer,
|
|||
|
# }
|
|||
|
# # 仅支持中文版本
|
|||
|
# # 在测试中,并未发现两个版本的模型不能互相通用
|
|||
|
# inferMap_V1 = {
|
|||
|
# "1.0.1": V101.infer,
|
|||
|
# "1.0": V101.infer,
|
|||
|
# "1.0.0": V101.infer,
|
|||
|
# }
|
|||
|
version = hps.version if hasattr(hps, "version") else latest_version
|
|||
|
# 非当前版本,根据版本号选择合适的infer
|
|||
|
if version != latest_version:
|
|||
|
if version in inferMap_V4.keys():
|
|||
|
return inferMap_V4[version](
|
|||
|
text,
|
|||
|
emotion,
|
|||
|
sdp_ratio,
|
|||
|
noise_scale,
|
|||
|
noise_scale_w,
|
|||
|
length_scale,
|
|||
|
sid,
|
|||
|
language,
|
|||
|
hps,
|
|||
|
net_g,
|
|||
|
device,
|
|||
|
reference_audio,
|
|||
|
skip_start,
|
|||
|
skip_end,
|
|||
|
style_text,
|
|||
|
style_weight,
|
|||
|
)
|
|||
|
if version in inferMap_V3.keys():
|
|||
|
return inferMap_V3[version](
|
|||
|
text,
|
|||
|
sdp_ratio,
|
|||
|
noise_scale,
|
|||
|
noise_scale_w,
|
|||
|
length_scale,
|
|||
|
sid,
|
|||
|
language,
|
|||
|
hps,
|
|||
|
net_g,
|
|||
|
device,
|
|||
|
reference_audio,
|
|||
|
emotion,
|
|||
|
skip_start,
|
|||
|
skip_end,
|
|||
|
style_text,
|
|||
|
style_weight,
|
|||
|
)
|
|||
|
if version in inferMap_V2.keys():
|
|||
|
return inferMap_V2[version](
|
|||
|
text,
|
|||
|
sdp_ratio,
|
|||
|
noise_scale,
|
|||
|
noise_scale_w,
|
|||
|
length_scale,
|
|||
|
sid,
|
|||
|
language,
|
|||
|
hps,
|
|||
|
net_g,
|
|||
|
device,
|
|||
|
)
|
|||
|
if version in inferMap_V1.keys():
|
|||
|
return inferMap_V1[version](
|
|||
|
text,
|
|||
|
sdp_ratio,
|
|||
|
noise_scale,
|
|||
|
noise_scale_w,
|
|||
|
length_scale,
|
|||
|
sid,
|
|||
|
hps,
|
|||
|
net_g,
|
|||
|
device,
|
|||
|
)
|
|||
|
# 在此处实现当前版本的推理
|
|||
|
# emo = get_emo_(reference_audio, emotion, sid)
|
|||
|
# if isinstance(reference_audio, np.ndarray):
|
|||
|
# emo = get_clap_audio_feature(reference_audio, device)
|
|||
|
# else:
|
|||
|
# emo = get_clap_text_feature(emotion, device)
|
|||
|
# emo = torch.squeeze(emo, dim=1)
|
|||
|
|
|||
|
bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
|
|||
|
text,
|
|||
|
language,
|
|||
|
hps,
|
|||
|
device,
|
|||
|
style_text=style_text,
|
|||
|
style_weight=style_weight,
|
|||
|
)
|
|||
|
if skip_start:
|
|||
|
phones = phones[3:]
|
|||
|
tones = tones[3:]
|
|||
|
lang_ids = lang_ids[3:]
|
|||
|
bert = bert[:, 3:]
|
|||
|
ja_bert = ja_bert[:, 3:]
|
|||
|
en_bert = en_bert[:, 3:]
|
|||
|
if skip_end:
|
|||
|
phones = phones[:-2]
|
|||
|
tones = tones[:-2]
|
|||
|
lang_ids = lang_ids[:-2]
|
|||
|
bert = bert[:, :-2]
|
|||
|
ja_bert = ja_bert[:, :-2]
|
|||
|
en_bert = en_bert[:, :-2]
|
|||
|
with torch.no_grad():
|
|||
|
x_tst = phones.to(device).unsqueeze(0)
|
|||
|
tones = tones.to(device).unsqueeze(0)
|
|||
|
lang_ids = lang_ids.to(device).unsqueeze(0)
|
|||
|
bert = bert.to(device).unsqueeze(0)
|
|||
|
ja_bert = ja_bert.to(device).unsqueeze(0)
|
|||
|
en_bert = en_bert.to(device).unsqueeze(0)
|
|||
|
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
|||
|
# emo = emo.to(device).unsqueeze(0)
|
|||
|
del phones
|
|||
|
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
|
|||
|
audio = (
|
|||
|
net_g.infer(
|
|||
|
x_tst,
|
|||
|
x_tst_lengths,
|
|||
|
speakers,
|
|||
|
tones,
|
|||
|
lang_ids,
|
|||
|
bert,
|
|||
|
ja_bert,
|
|||
|
en_bert,
|
|||
|
sdp_ratio=sdp_ratio,
|
|||
|
noise_scale=noise_scale,
|
|||
|
noise_scale_w=noise_scale_w,
|
|||
|
length_scale=length_scale,
|
|||
|
en_ratio=1.0
|
|||
|
)[0][0, 0]
|
|||
|
.data.cpu()
|
|||
|
.float()
|
|||
|
.numpy()
|
|||
|
)
|
|||
|
del (
|
|||
|
x_tst,
|
|||
|
tones,
|
|||
|
lang_ids,
|
|||
|
bert,
|
|||
|
x_tst_lengths,
|
|||
|
speakers,
|
|||
|
ja_bert,
|
|||
|
en_bert,
|
|||
|
) # , emo
|
|||
|
if torch.cuda.is_available():
|
|||
|
torch.cuda.empty_cache()
|
|||
|
return audio
|
|||
|
|
|||
|
|
|||
|
def infer_multilang(
|
|||
|
text,
|
|||
|
sdp_ratio,
|
|||
|
noise_scale,
|
|||
|
noise_scale_w,
|
|||
|
length_scale,
|
|||
|
sid,
|
|||
|
language,
|
|||
|
hps,
|
|||
|
net_g,
|
|||
|
device,
|
|||
|
reference_audio=None,
|
|||
|
emotion=None,
|
|||
|
skip_start=False,
|
|||
|
skip_end=False,
|
|||
|
en_ratio=1.0
|
|||
|
):
|
|||
|
bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
|
|||
|
# emo = get_emo_(reference_audio, emotion, sid)
|
|||
|
# if isinstance(reference_audio, np.ndarray):
|
|||
|
# emo = get_clap_audio_feature(reference_audio, device)
|
|||
|
# else:
|
|||
|
# emo = get_clap_text_feature(emotion, device)
|
|||
|
# emo = torch.squeeze(emo, dim=1)
|
|||
|
for idx, (txt, lang) in enumerate(zip(text, language)):
|
|||
|
_skip_start = (idx != 0) or (skip_start and idx == 0)
|
|||
|
_skip_end = (idx != len(language) - 1) or skip_end
|
|||
|
(
|
|||
|
temp_bert,
|
|||
|
temp_ja_bert,
|
|||
|
temp_en_bert,
|
|||
|
temp_phones,
|
|||
|
temp_tones,
|
|||
|
temp_lang_ids,
|
|||
|
) = get_text(txt, lang, hps, device)
|
|||
|
if _skip_start:
|
|||
|
temp_bert = temp_bert[:, 3:]
|
|||
|
temp_ja_bert = temp_ja_bert[:, 3:]
|
|||
|
temp_en_bert = temp_en_bert[:, 3:]
|
|||
|
temp_phones = temp_phones[3:]
|
|||
|
temp_tones = temp_tones[3:]
|
|||
|
temp_lang_ids = temp_lang_ids[3:]
|
|||
|
if _skip_end:
|
|||
|
temp_bert = temp_bert[:, :-2]
|
|||
|
temp_ja_bert = temp_ja_bert[:, :-2]
|
|||
|
temp_en_bert = temp_en_bert[:, :-2]
|
|||
|
temp_phones = temp_phones[:-2]
|
|||
|
temp_tones = temp_tones[:-2]
|
|||
|
temp_lang_ids = temp_lang_ids[:-2]
|
|||
|
bert.append(temp_bert)
|
|||
|
ja_bert.append(temp_ja_bert)
|
|||
|
en_bert.append(temp_en_bert)
|
|||
|
phones.append(temp_phones)
|
|||
|
tones.append(temp_tones)
|
|||
|
lang_ids.append(temp_lang_ids)
|
|||
|
bert = torch.concatenate(bert, dim=1)
|
|||
|
ja_bert = torch.concatenate(ja_bert, dim=1)
|
|||
|
en_bert = torch.concatenate(en_bert, dim=1)
|
|||
|
phones = torch.concatenate(phones, dim=0)
|
|||
|
tones = torch.concatenate(tones, dim=0)
|
|||
|
lang_ids = torch.concatenate(lang_ids, dim=0)
|
|||
|
with torch.no_grad():
|
|||
|
x_tst = phones.to(device).unsqueeze(0)
|
|||
|
tones = tones.to(device).unsqueeze(0)
|
|||
|
lang_ids = lang_ids.to(device).unsqueeze(0)
|
|||
|
bert = bert.to(device).unsqueeze(0)
|
|||
|
ja_bert = ja_bert.to(device).unsqueeze(0)
|
|||
|
en_bert = en_bert.to(device).unsqueeze(0)
|
|||
|
# emo = emo.to(device).unsqueeze(0)
|
|||
|
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
|||
|
del phones
|
|||
|
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
|
|||
|
audio = (
|
|||
|
net_g.infer(
|
|||
|
x_tst,
|
|||
|
x_tst_lengths,
|
|||
|
speakers,
|
|||
|
tones,
|
|||
|
lang_ids,
|
|||
|
bert,
|
|||
|
ja_bert,
|
|||
|
en_bert,
|
|||
|
sdp_ratio=sdp_ratio,
|
|||
|
noise_scale=noise_scale,
|
|||
|
noise_scale_w=noise_scale_w,
|
|||
|
length_scale=length_scale,
|
|||
|
en_ratio=en_ratio
|
|||
|
)[0][0, 0]
|
|||
|
.data.cpu()
|
|||
|
.float()
|
|||
|
.numpy()
|
|||
|
)
|
|||
|
del (
|
|||
|
x_tst,
|
|||
|
tones,
|
|||
|
lang_ids,
|
|||
|
bert,
|
|||
|
x_tst_lengths,
|
|||
|
speakers,
|
|||
|
ja_bert,
|
|||
|
en_bert,
|
|||
|
) # , emo
|
|||
|
if torch.cuda.is_available():
|
|||
|
torch.cuda.empty_cache()
|
|||
|
return audio
|