bert-vits2-utils

This commit is contained in:
bing 2024-06-23 20:39:44 +08:00
parent 3f78fcb12c
commit 33c952d327
37 changed files with 137723 additions and 1 deletions

4
.gitignore vendored
View File

@ -8,3 +8,7 @@ takway.db
/storage/** /storage/**
/app.log /app.log
**/*.pth
**/bert/*
**/emotional/*

13
test.py Normal file
View File

@ -0,0 +1,13 @@
from utils.bert_vits2_utils import TextToSpeech
import soundfile as sf
tts = TextToSpeech()
tts.print_speakers_info()
audio, sample_rate= tts.synthesize("你好,我好开心", # 文本
0, # 说话人 id
style_text="我很难过!!!!呜呜呜!!!", # 情绪prompt当language=="ZH" 才有效
style_weight=0.9, # 情绪prompt权重
language="mix", # 语言类型,包括 "ZH" "EN" "mix"
en_ratio=1.) # mix语言类型下英文文本速度越大速度越慢
save_path = "./tmp2.wav"
sf.write(save_path, audio, sample_rate)

BIN
tmp2.wav Normal file

Binary file not shown.

View File

View File

@ -0,0 +1,464 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from . import commons
import logging
logger = logging.getLogger(__name__)
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
class Encoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
window_size=4,
isflow=True,
**kwargs
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size
# if isflow:
# cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
# self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
# self.cond_layer = weight_norm(cond_layer, name='weight')
# self.gin_channels = 256
self.cond_layer_idx = self.n_layers
if "gin_channels" in kwargs:
self.gin_channels = kwargs["gin_channels"]
if self.gin_channels != 0:
self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
# vits2 says 3rd block, so idx is 2 by default
self.cond_layer_idx = (
kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
)
logging.debug(self.gin_channels, self.cond_layer_idx)
assert (
self.cond_layer_idx < self.n_layers
), "cond_layer_idx should be less than n_layers"
self.drop = nn.Dropout(p_dropout)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
window_size=window_size,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, g=None):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
if i == self.cond_layer_idx and g is not None:
g = self.spk_emb_linear(g.transpose(1, 2))
g = g.transpose(1, 2)
x = x + g
x = x * x_mask
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class Decoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
proximal_bias=False,
proximal_init=True,
**kwargs
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.drop = nn.Dropout(p_dropout)
self.self_attn_layers = nn.ModuleList()
self.norm_layers_0 = nn.ModuleList()
self.encdec_attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.self_attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
proximal_bias=proximal_bias,
proximal_init=proximal_init,
)
)
self.norm_layers_0.append(LayerNorm(hidden_channels))
self.encdec_attn_layers.append(
MultiHeadAttention(
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
causal=True,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, h, h_mask):
"""
x: decoder input
h: encoder output
"""
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y)
x = self.norm_layers_0[i](x + y)
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class MultiHeadAttention(nn.Module):
def __init__(
self,
channels,
out_channels,
n_heads,
p_dropout=0.0,
window_size=None,
heads_share=True,
block_length=None,
proximal_bias=False,
proximal_init=False,
):
super().__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.p_dropout = p_dropout
self.window_size = window_size
self.heads_share = heads_share
self.block_length = block_length
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.attn = None
self.k_channels = channels // n_heads
self.conv_q = nn.Conv1d(channels, channels, 1)
self.conv_k = nn.Conv1d(channels, channels, 1)
self.conv_v = nn.Conv1d(channels, channels, 1)
self.conv_o = nn.Conv1d(channels, out_channels, 1)
self.drop = nn.Dropout(p_dropout)
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
nn.init.xavier_uniform_(self.conv_v.weight)
if proximal_init:
with torch.no_grad():
self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias)
def forward(self, x, c, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, t_t = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None:
assert (
t_s == t_t
), "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(
query / math.sqrt(self.k_channels), key_relative_embeddings
)
scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(
device=scores.device, dtype=scores.dtype
)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None:
assert (
t_s == t_t
), "Local attention is only available for self-attention."
block_mask = (
torch.ones_like(scores)
.triu(-self.block_length)
.tril(self.block_length)
)
scores = scores.masked_fill(block_mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_v, t_s
)
output = output + self._matmul_with_relative_values(
relative_weights, value_relative_embeddings
)
output = (
output.transpose(2, 3).contiguous().view(b, d, t_t)
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn
def _matmul_with_relative_values(self, x, y):
"""
x: [b, h, l, m]
y: [h or 1, m, d]
ret: [b, h, l, d]
"""
ret = torch.matmul(x, y.unsqueeze(0))
return ret
def _matmul_with_relative_keys(self, x, y):
"""
x: [b, h, l, d]
y: [h or 1, m, d]
ret: [b, h, l, m]
"""
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret
def _get_relative_embeddings(self, relative_embeddings, length):
2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = F.pad(
relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
)
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[
:, slice_start_position:slice_end_position
]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
"""
x: [b, h, l, 2*l-1]
ret: [b, h, l, l]
"""
batch, heads, length, _ = x.size()
# Concat columns of pad to shift from relative to absolute indexing.
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
)
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
:, :, :length, length - 1 :
]
return x_final
def _absolute_position_to_relative_position(self, x):
"""
x: [b, h, l, l]
ret: [b, h, l, 2*l-1]
"""
batch, heads, length, _ = x.size()
# pad along column
x = F.pad(
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
)
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final
def _attention_bias_proximal(self, length):
"""Bias for self-attention to encourage attention to close positions.
Args:
length: an integer scalar.
Returns:
a Tensor with shape [1, 1, length, length]
"""
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
def __init__(
self,
in_channels,
out_channels,
filter_channels,
kernel_size,
p_dropout=0.0,
activation=None,
causal=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.activation = activation
self.causal = causal
if causal:
self.padding = self._causal_padding
else:
self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
self.drop = nn.Dropout(p_dropout)
def forward(self, x, x_mask):
x = self.conv_1(self.padding(x * x_mask))
if self.activation == "gelu":
x = x * torch.sigmoid(1.702 * x)
else:
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(self.padding(x * x_mask))
return x * x_mask
def _causal_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
def _same_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x

158
utils/bert_vits2/commons.py Normal file
View File

@ -0,0 +1,158 @@
import math
import torch
from torch.nn import functional as F
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape):
layer = pad_shape[::-1]
pad_shape = [item for sublist in layer for item in sublist]
return pad_shape
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += (
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
return kl
def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
def slice_segments(x, ids_str, segment_size=4):
gather_indices = ids_str.view(x.size(0), 1, 1).repeat(
1, x.size(1), 1
) + torch.arange(segment_size, device=x.device)
return torch.gather(x, 2, gather_indices)
def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
num_timescales - 1
)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def convert_pad_shape(pad_shape):
layer = pad_shape[::-1]
pad_shape = [item for sublist in layer for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2, 3) * mask
return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1.0 / norm_type)
return total_norm

262
utils/bert_vits2/config.py Normal file
View File

@ -0,0 +1,262 @@
"""
@Desc: 全局配置文件读取
"""
import argparse
import yaml
from typing import Dict, List
import os
import shutil
import sys
class Resample_config:
"""重采样配置"""
def __init__(self, in_dir: str, out_dir: str, sampling_rate: int = 44100):
self.sampling_rate: int = sampling_rate # 目标采样率
self.in_dir: str = in_dir # 待处理音频目录路径
self.out_dir: str = out_dir # 重采样输出路径
@classmethod
def from_dict(cls, dataset_path: str, data: Dict[str, any]):
"""从字典中生成实例"""
# 不检查路径是否有效此逻辑在resample.py中处理
data["in_dir"] = os.path.join(dataset_path, data["in_dir"])
data["out_dir"] = os.path.join(dataset_path, data["out_dir"])
return cls(**data)
class Preprocess_text_config:
"""数据预处理配置"""
def __init__(
self,
transcription_path: str,
cleaned_path: str,
train_path: str,
val_path: str,
config_path: str,
val_per_lang: int = 5,
max_val_total: int = 10000,
clean: bool = True,
):
self.transcription_path: str = (
transcription_path # 原始文本文件路径,文本格式应为{wav_path}|{speaker_name}|{language}|{text}。
)
self.cleaned_path: str = (
cleaned_path # 数据清洗后文本路径,可以不填。不填则将在原始文本目录生成
)
self.train_path: str = (
train_path # 训练集路径,可以不填。不填则将在原始文本目录生成
)
self.val_path: str = (
val_path # 验证集路径,可以不填。不填则将在原始文本目录生成
)
self.config_path: str = config_path # 配置文件路径
self.val_per_lang: int = val_per_lang # 每个speaker的验证集条数
self.max_val_total: int = (
max_val_total # 验证集最大条数,多于的会被截断并放到训练集中
)
self.clean: bool = clean # 是否进行数据清洗
@classmethod
def from_dict(cls, dataset_path: str, data: Dict[str, any]):
"""从字典中生成实例"""
data["transcription_path"] = os.path.join(
dataset_path, data["transcription_path"]
)
if data["cleaned_path"] == "" or data["cleaned_path"] is None:
data["cleaned_path"] = None
else:
data["cleaned_path"] = os.path.join(dataset_path, data["cleaned_path"])
data["train_path"] = os.path.join(dataset_path, data["train_path"])
data["val_path"] = os.path.join(dataset_path, data["val_path"])
data["config_path"] = os.path.join(dataset_path, data["config_path"])
return cls(**data)
class Bert_gen_config:
"""bert_gen 配置"""
def __init__(
self,
config_path: str,
num_processes: int = 2,
device: str = "cuda",
use_multi_device: bool = False,
):
self.config_path = config_path
self.num_processes = num_processes
self.device = device
self.use_multi_device = use_multi_device
@classmethod
def from_dict(cls, dataset_path: str, data: Dict[str, any]):
data["config_path"] = os.path.join(dataset_path, data["config_path"])
return cls(**data)
class Emo_gen_config:
"""emo_gen 配置"""
def __init__(
self,
config_path: str,
num_processes: int = 2,
device: str = "cuda",
use_multi_device: bool = False,
):
self.config_path = config_path
self.num_processes = num_processes
self.device = device
self.use_multi_device = use_multi_device
@classmethod
def from_dict(cls, dataset_path: str, data: Dict[str, any]):
data["config_path"] = os.path.join(dataset_path, data["config_path"])
return cls(**data)
class Train_ms_config:
"""训练配置"""
def __init__(
self,
config_path: str,
env: Dict[str, any],
base: Dict[str, any],
model: str,
num_workers: int,
spec_cache: bool,
keep_ckpts: int,
):
self.env = env # 需要加载的环境变量
self.base = base # 底模配置
self.model = (
model # 训练模型存储目录该路径为相对于dataset_path的路径而非项目根目录
)
self.config_path = config_path # 配置文件路径
self.num_workers = num_workers # worker数量
self.spec_cache = spec_cache # 是否启用spec缓存
self.keep_ckpts = keep_ckpts # ckpt数量
@classmethod
def from_dict(cls, dataset_path: str, data: Dict[str, any]):
# data["model"] = os.path.join(dataset_path, data["model"])
data["config_path"] = os.path.join(dataset_path, data["config_path"])
return cls(**data)
class Webui_config:
"""webui 配置"""
def __init__(
self,
device: str,
model: str,
config_path: str,
language_identification_library: str,
port: int = 7860,
share: bool = False,
debug: bool = False,
):
self.device: str = device
self.model: str = model # 端口号
self.config_path: str = config_path # 是否公开部署,对外网开放
self.port: int = port # 是否开启debug模式
self.share: bool = share # 模型路径
self.debug: bool = debug # 配置文件路径
self.language_identification_library: str = (
language_identification_library # 语种识别库
)
@classmethod
def from_dict(cls, dataset_path: str, data: Dict[str, any]):
data["config_path"] = os.path.join(dataset_path, data["config_path"])
data["model"] = os.path.join(dataset_path, data["model"])
return cls(**data)
class Server_config:
def __init__(
self, models: List[Dict[str, any]], port: int = 5000, device: str = "cuda"
):
self.models: List[Dict[str, any]] = models # 需要加载的所有模型的配置
self.port: int = port # 端口号
self.device: str = device # 模型默认使用设备
@classmethod
def from_dict(cls, data: Dict[str, any]):
return cls(**data)
class Translate_config:
"""翻译api配置"""
def __init__(self, app_key: str, secret_key: str):
self.app_key = app_key
self.secret_key = secret_key
@classmethod
def from_dict(cls, data: Dict[str, any]):
return cls(**data)
class Config:
def __init__(self, config_path: str):
if not os.path.isfile(config_path) and os.path.isfile("default_config.yml"):
shutil.copy(src="default_config.yml", dst=config_path)
print(
f"已根据默认配置文件default_config.yml生成配置文件{config_path}。请按该配置文件的说明进行配置后重新运行。"
)
print("如无特殊需求请勿修改default_config.yml或备份该文件。")
sys.exit(0)
print(os.getcwd())
with open(file=config_path, mode="r", encoding="utf-8") as file:
yaml_config: Dict[str, any] = yaml.safe_load(file.read())
dataset_path: str = yaml_config["dataset_path"]
openi_token: str = yaml_config["openi_token"]
self.dataset_path: str = dataset_path
self.mirror: str = yaml_config["mirror"]
self.openi_token: str = openi_token
self.resample_config: Resample_config = Resample_config.from_dict(
dataset_path, yaml_config["resample"]
)
self.preprocess_text_config: Preprocess_text_config = (
Preprocess_text_config.from_dict(
dataset_path, yaml_config["preprocess_text"]
)
)
self.bert_gen_config: Bert_gen_config = Bert_gen_config.from_dict(
dataset_path, yaml_config["bert_gen"]
)
self.emo_gen_config: Emo_gen_config = Emo_gen_config.from_dict(
dataset_path, yaml_config["emo_gen"]
)
self.train_ms_config: Train_ms_config = Train_ms_config.from_dict(
dataset_path, yaml_config["train_ms"]
)
self.webui_config: Webui_config = Webui_config.from_dict(
dataset_path, yaml_config["webui"]
)
self.server_config: Server_config = Server_config.from_dict(
yaml_config["server"]
)
self.translate_config: Translate_config = Translate_config.from_dict(
yaml_config["translate"]
)
parser = argparse.ArgumentParser()
# 为避免与以前的config.json起冲突将其更名如下
parser.add_argument("-y", "--yml_config", type=str, default="./utils/bert_vits2/config.yml")
args, _ = parser.parse_known_args()
config = Config(args.yml_config)

180
utils/bert_vits2/config.yml Normal file
View File

@ -0,0 +1,180 @@
# 全局配置
# 对于希望在同一时间使用多个配置文件的情况例如两个GPU同时跑两个训练集通过环境变量指定配置文件不指定则默认为./config.yml
# 拟提供通用路径配置,统一存放数据,避免数据放得很乱
# 每个数据集与其对应的模型存放至统一路径下后续所有的路径配置均为相对于datasetPath的路径
# 不填或者填空则路径为相对于项目根目录的路径
# dataset_path: "data/Genshin-KL"
dataset_path: "utils/bert_vits2/data/mix"
# 模型镜像源默认huggingface使用openi镜像源需指定openi_token
mirror: ""
openi_token: "" # openi token
# resample 音频重采样配置
# 注意, “:” 后需要加空格
resample:
# 目标重采样率
sampling_rate: 44100
# sampling_rate: 16000
# 音频文件输入路径,重采样会将该路径下所有.wav音频文件重采样
# 请填入相对于datasetPath的相对路径
in_dir: "audios/raw" # 相对于根目录的路径为 /datasetPath/in_dir
# 音频文件重采样后输出路径
out_dir: "audios/wavs"
# preprocess_text 数据集预处理相关配置
# 注意, “:” 后需要加空格
preprocess_text:
# 原始文本文件路径,文本格式应为{wav_path}|{speaker_name}|{language}|{text}。
transcription_path: "filelists/你的数据集文本.list"
# 数据清洗后文本路径,可以不填。不填则将在原始文本目录生成
cleaned_path: ""
# 训练集路径
train_path: "filelists/train.list"
# 验证集路径
val_path: "filelists/val.list"
# 配置文件路径
config_path: "config.json"
# 每个语言的验证集条数
val_per_lang: 4
# 验证集最大条数,多于的会被截断并放到训练集中
max_val_total: 12
# 是否进行数据清洗
clean: true
# bert_gen 相关配置
# 注意, “:” 后需要加空格
bert_gen:
# 训练数据集配置文件路径
config_path: "config.json"
# 并行数
num_processes: 4
# 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理
# 该选项同时决定了get_bert_feature的默认设备
device: "cuda"
# 使用多卡推理
use_multi_device: false
# emo_gen 相关配置
# 注意, “:” 后需要加空格
emo_gen:
# 训练数据集配置文件路径
config_path: "config.json"
# 并行数
num_processes: 4
# 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理
device: "cuda"
# 使用多卡推理
use_multi_device: false
# train 训练配置
# 注意, “:” 后需要加空格
train_ms:
env:
MASTER_ADDR: "localhost"
MASTER_PORT: 10086
WORLD_SIZE: 1
LOCAL_RANK: 0
RANK: 0
# 可以填写任意名的环境变量
# THE_ENV_VAR_YOU_NEED_TO_USE: "1234567"
# 底模设置
base:
use_base_model: false
repo_id: "Stardust_minus/Bert-VITS2"
model_image: "Bert-VITS2_2.3底模" # openi网页的模型名
# 训练模型存储目录与旧版本的区别原先数据集是存放在logs/model_name下的现在改为统一存放在Data/你的数据集/models下
model: "models"
# 配置文件路径
config_path: "configs/config.json"
# 训练使用的worker不建议超过CPU核心数
num_workers: 16
# 关闭此项可以节约接近70%的磁盘空间但是可能导致实际训练速度变慢和更高的CPU使用率。
spec_cache: False
# 保存的检查点数量,多于此数目的权重会被删除来节省空间。
keep_ckpts: 3
# webui webui配置
# 注意, “:” 后需要加空格
webui:
# 推理设备
device: "cuda"
# 模型路径
# model: "models/G_32000.pth"
model: "models/G_250000.pth"
# 配置文件路径
config_path: "configs/config.json"
# 端口号
port: 7861
# 是否公开部署,对外网开放
share: false
# 是否开启debug模式
debug: false
# 语种识别库可选langid, fastlid
language_identification_library: "langid"
# server-fastapi配置
# 注意, “:” 后需要加空格
# 注意,本配置下的所有配置均为相对于根目录的路径
server:
# 端口号
port: 5005
# 模型默认使用设备:但是当前并没有实现这个配置。
device: "cuda"
# 需要加载的所有模型的配置,可以填多个模型,也可以不填模型,等网页成功后手动加载模型
# 不加载模型的配置格式删除默认给的两个模型配置给models赋值 [ ]也就是空列表。参考模型2的speakers 即 models: [ ]
# 注意所有模型都必须正确配置model与config的路径空路径会导致加载错误。
# 也可以不填模型等网页加载成功后手动填写models。
models:
- # 模型的路径
model: ""
# 模型config.json的路径
config: ""
# 模型使用设备,若填写则会覆盖默认配置
device: "cuda"
# 模型默认使用的语言
language: "ZH"
# 模型人物默认参数
# 不必填写所有人物,不填的使用默认值
# 暂时不用填写,当前尚未实现按人区分配置
speakers:
- speaker: "科比"
sdp_ratio: 0.2
noise_scale: 0.6
noise_scale_w: 0.8
length_scale: 1
- speaker: "五条悟"
sdp_ratio: 0.3
noise_scale: 0.7
noise_scale_w: 0.8
length_scale: 0.5
- speaker: "安倍晋三"
sdp_ratio: 0.2
noise_scale: 0.6
noise_scale_w: 0.8
length_scale: 1.2
- # 模型的路径
model: ""
# 模型config.json的路径
config: ""
# 模型使用设备,若填写则会覆盖默认配置
device: "cpu"
# 模型默认使用的语言
language: "JP"
# 模型人物默认参数
# 不必填写所有人物,不填的使用默认值
speakers: [ ] # 也可以不填
# 百度翻译开放平台 api配置
# api接入文档 https://api.fanyi.baidu.com/doc/21
# 请不要在github等网站公开分享你的app id 与 key
translate:
# 你的APPID
"app_key": ""
# 你的密钥
"secret_key": ""

View File

@ -0,0 +1,112 @@
{
"train": {
"log_interval": 400,
"eval_interval": 2000,
"seed": 42,
"epochs": 1000,
"learning_rate": 0.00002,
"betas": [
0.8,
0.99
],
"eps": 1e-09,
"batch_size": 24,
"bf16_run": false,
"lr_decay": 0.99995,
"segment_size": 16384,
"init_lr_ratio": 1,
"warmup_epochs": 0,
"c_mel": 45,
"c_kl": 1.0,
"c_commit": 100,
"skip_optimizer": true,
"freeze_ZH_bert": false,
"freeze_JP_bert": false,
"freeze_EN_bert": false,
"freeze_emo": false
},
"data": {
"training_files": "data/mix/train.list",
"validation_files": "data/mix/val.list",
"max_wav_value": 32768.0,
"sampling_rate": 44100,
"filter_length": 2048,
"hop_length": 512,
"win_length": 2048,
"n_mel_channels": 128,
"mel_fmin": 0.0,
"mel_fmax": null,
"add_blank": true,
"n_speakers": 5,
"cleaned_text": true,
"spk2id": {
"可莉": 0,
"钟离": 1,
"八重神子": 2,
"枫原万叶": 3,
"胡桃": 4
}
},
"model": {
"use_spk_conditioned_encoder": true,
"use_noise_scaled_mas": true,
"use_mel_posterior_encoder": false,
"use_duration_discriminator": true,
"inter_channels": 192,
"hidden_channels": 192,
"filter_channels": 768,
"n_heads": 2,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": 0.1,
"resblock": "1",
"resblock_kernel_sizes": [
3,
7,
11
],
"resblock_dilation_sizes": [
[
1,
3,
5
],
[
1,
3,
5
],
[
1,
3,
5
]
],
"upsample_rates": [
8,
8,
2,
2,
2
],
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [
16,
16,
8,
2,
2
],
"n_layers_q": 3,
"use_spectral_norm": false,
"gin_channels": 512,
"slm": {
"model": "./slm/wavlm-base-plus",
"sr": 16000,
"hidden": 768,
"nlayers": 13,
"initial_channel": 64
}
},
"version": "2.3"
}

439
utils/bert_vits2/infer.py Normal file
View File

@ -0,0 +1,439 @@
"""
版本管理兼容推理及模型加载实现
版本说明
1. 版本号与github的release版本号对应使用哪个release版本训练的模型即对应其版本号
2. 请在模型的config.json中显示声明版本号添加一个字段"version" : "你的版本号"
特殊版本说明
1.1.1-fix 1.1.1版本训练的模型但是在推理时使用dev的日语修复
2.3当前版本
"""
import torch
from . import commons
from .text import cleaned_text_to_sequence, get_bert
# from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
from typing import Union
from .text.cleaner import clean_text
from . import utils
from .models import SynthesizerTrn
from .text.symbols import symbols
# from utils.tts.bert_vits2.oldVersion.V220.models import SynthesizerTrn as V220SynthesizerTrn
# from utils.tts.bert_vits2.oldVersion.V220.text import symbols as V220symbols
# from utils.tts.bert_vits2.oldVersion.V210.models import SynthesizerTrn as V210SynthesizerTrn
# from utils.tts.bert_vits2.oldVersion.V210.text import symbols as V210symbols
# from utils.tts.bert_vits2.oldVersion.V200.models import SynthesizerTrn as V200SynthesizerTrn
# from utils.tts.bert_vits2.oldVersion.V200.text import symbols as V200symbols
# from utils.tts.bert_vits2.oldVersion.V111.models import SynthesizerTrn as V111SynthesizerTrn
# from utils.tts.bert_vits2.oldVersion.V111.text import symbols as V111symbols
# from utils.tts.bert_vits2.oldVersion.V110.models import SynthesizerTrn as V110SynthesizerTrn
# from utils.tts.bert_vits2.oldVersion.V110.text import symbols as V110symbols
# from utils.tts.bert_vits2.oldVersion.V101.models import SynthesizerTrn as V101SynthesizerTrn
# from utils.tts.bert_vits2.oldVersion.V101.text import symbols as V101symbols
# from oldVersion import V111, V110, V101, V200, V210, V220
# 当前版本信息
latest_version = "2.3"
# 版本兼容
# SynthesizerTrnMap = {
# "2.2": V220SynthesizerTrn,
# "2.1": V210SynthesizerTrn,
# "2.0.2-fix": V200SynthesizerTrn,
# "2.0.1": V200SynthesizerTrn,
# "2.0": V200SynthesizerTrn,
# "1.1.1-fix": V111SynthesizerTrn,
# "1.1.1": V111SynthesizerTrn,
# "1.1": V110SynthesizerTrn,
# "1.1.0": V110SynthesizerTrn,
# "1.0.1": V101SynthesizerTrn,
# "1.0": V101SynthesizerTrn,
# "1.0.0": V101SynthesizerTrn,
# }
# symbolsMap = {
# "2.2": V220symbols,
# "2.1": V210symbols,
# "2.0.2-fix": V200symbols,
# "2.0.1": V200symbols,
# "2.0": V200symbols,
# "1.1.1-fix": V111symbols,
# "1.1.1": V111symbols,
# "1.1": V110symbols,
# "1.1.0": V110symbols,
# "1.0.1": V101symbols,
# "1.0": V101symbols,
# "1.0.0": V101symbols,
# }
# def get_emo_(reference_audio, emotion, sid):
# emo = (
# torch.from_numpy(get_emo(reference_audio))
# if reference_audio and emotion == -1
# else torch.FloatTensor(
# np.load(f"emo_clustering/{sid}/cluster_center_{emotion}.npy")
# )
# )
# return emo
def get_net_g(model_path: str, version: str, device: str, hps):
if version != latest_version:
net_g = SynthesizerTrnMap[version](
len(symbolsMap[version]),
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).to(device)
else:
# 当前版本模型 net_g
net_g = SynthesizerTrn(
len(symbols),
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).to(device)
_ = net_g.eval()
_ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
return net_g
def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
style_text = None if style_text == "" else style_text
# 在此处实现当前版本的get_text
norm_text, phone, tone, word2ph = clean_text(text, language_str)
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
if hps.data.add_blank:
phone = commons.intersperse(phone, 0)
tone = commons.intersperse(tone, 0)
language = commons.intersperse(language, 0)
for i in range(len(word2ph)):
word2ph[i] = word2ph[i] * 2
word2ph[0] += 1
bert_ori = get_bert(
norm_text, word2ph, language_str, device, style_text, style_weight
)
del word2ph
assert bert_ori.shape[-1] == len(phone), phone
if language_str == "ZH":
bert = bert_ori
ja_bert = torch.randn(1024, len(phone))
en_bert = torch.randn(1024, len(phone))
elif language_str == "JP":
bert = torch.randn(1024, len(phone))
ja_bert = bert_ori
en_bert = torch.randn(1024, len(phone))
elif language_str == "EN":
bert = torch.randn(1024, len(phone))
ja_bert = torch.randn(1024, len(phone))
en_bert = bert_ori
else:
raise ValueError("language_str should be ZH, JP or EN")
assert bert.shape[-1] == len(
phone
), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
phone = torch.LongTensor(phone)
tone = torch.LongTensor(tone)
language = torch.LongTensor(language)
return bert, ja_bert, en_bert, phone, tone, language
def infer(
text,
emotion: Union[int, str],
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
sid,
language,
hps,
net_g,
device,
reference_audio=None,
skip_start=False,
skip_end=False,
style_text=None,
style_weight=0.7,
):
# # 2.2版本参数位置变了
# inferMap_V4 = {
# "2.2": V220.infer,
# }
# # 2.1 参数新增 emotion reference_audio skip_start skip_end
# inferMap_V3 = {
# "2.1": V210.infer,
# }
# # 支持中日英三语版本
# inferMap_V2 = {
# "2.0.2-fix": V200.infer,
# "2.0.1": V200.infer,
# "2.0": V200.infer,
# "1.1.1-fix": V111.infer_fix,
# "1.1.1": V111.infer,
# "1.1": V110.infer,
# "1.1.0": V110.infer,
# }
# # 仅支持中文版本
# # 在测试中,并未发现两个版本的模型不能互相通用
# inferMap_V1 = {
# "1.0.1": V101.infer,
# "1.0": V101.infer,
# "1.0.0": V101.infer,
# }
version = hps.version if hasattr(hps, "version") else latest_version
# 非当前版本根据版本号选择合适的infer
if version != latest_version:
if version in inferMap_V4.keys():
return inferMap_V4[version](
text,
emotion,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
sid,
language,
hps,
net_g,
device,
reference_audio,
skip_start,
skip_end,
style_text,
style_weight,
)
if version in inferMap_V3.keys():
return inferMap_V3[version](
text,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
sid,
language,
hps,
net_g,
device,
reference_audio,
emotion,
skip_start,
skip_end,
style_text,
style_weight,
)
if version in inferMap_V2.keys():
return inferMap_V2[version](
text,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
sid,
language,
hps,
net_g,
device,
)
if version in inferMap_V1.keys():
return inferMap_V1[version](
text,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
sid,
hps,
net_g,
device,
)
# 在此处实现当前版本的推理
# emo = get_emo_(reference_audio, emotion, sid)
# if isinstance(reference_audio, np.ndarray):
# emo = get_clap_audio_feature(reference_audio, device)
# else:
# emo = get_clap_text_feature(emotion, device)
# emo = torch.squeeze(emo, dim=1)
bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
text,
language,
hps,
device,
style_text=style_text,
style_weight=style_weight,
)
if skip_start:
phones = phones[3:]
tones = tones[3:]
lang_ids = lang_ids[3:]
bert = bert[:, 3:]
ja_bert = ja_bert[:, 3:]
en_bert = en_bert[:, 3:]
if skip_end:
phones = phones[:-2]
tones = tones[:-2]
lang_ids = lang_ids[:-2]
bert = bert[:, :-2]
ja_bert = ja_bert[:, :-2]
en_bert = en_bert[:, :-2]
with torch.no_grad():
x_tst = phones.to(device).unsqueeze(0)
tones = tones.to(device).unsqueeze(0)
lang_ids = lang_ids.to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
ja_bert = ja_bert.to(device).unsqueeze(0)
en_bert = en_bert.to(device).unsqueeze(0)
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
# emo = emo.to(device).unsqueeze(0)
del phones
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
audio = (
net_g.infer(
x_tst,
x_tst_lengths,
speakers,
tones,
lang_ids,
bert,
ja_bert,
en_bert,
sdp_ratio=sdp_ratio,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
length_scale=length_scale,
en_ratio=1.0
)[0][0, 0]
.data.cpu()
.float()
.numpy()
)
del (
x_tst,
tones,
lang_ids,
bert,
x_tst_lengths,
speakers,
ja_bert,
en_bert,
) # , emo
if torch.cuda.is_available():
torch.cuda.empty_cache()
return audio
def infer_multilang(
text,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
sid,
language,
hps,
net_g,
device,
reference_audio=None,
emotion=None,
skip_start=False,
skip_end=False,
en_ratio=1.0
):
bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
# emo = get_emo_(reference_audio, emotion, sid)
# if isinstance(reference_audio, np.ndarray):
# emo = get_clap_audio_feature(reference_audio, device)
# else:
# emo = get_clap_text_feature(emotion, device)
# emo = torch.squeeze(emo, dim=1)
for idx, (txt, lang) in enumerate(zip(text, language)):
_skip_start = (idx != 0) or (skip_start and idx == 0)
_skip_end = (idx != len(language) - 1) or skip_end
(
temp_bert,
temp_ja_bert,
temp_en_bert,
temp_phones,
temp_tones,
temp_lang_ids,
) = get_text(txt, lang, hps, device)
if _skip_start:
temp_bert = temp_bert[:, 3:]
temp_ja_bert = temp_ja_bert[:, 3:]
temp_en_bert = temp_en_bert[:, 3:]
temp_phones = temp_phones[3:]
temp_tones = temp_tones[3:]
temp_lang_ids = temp_lang_ids[3:]
if _skip_end:
temp_bert = temp_bert[:, :-2]
temp_ja_bert = temp_ja_bert[:, :-2]
temp_en_bert = temp_en_bert[:, :-2]
temp_phones = temp_phones[:-2]
temp_tones = temp_tones[:-2]
temp_lang_ids = temp_lang_ids[:-2]
bert.append(temp_bert)
ja_bert.append(temp_ja_bert)
en_bert.append(temp_en_bert)
phones.append(temp_phones)
tones.append(temp_tones)
lang_ids.append(temp_lang_ids)
bert = torch.concatenate(bert, dim=1)
ja_bert = torch.concatenate(ja_bert, dim=1)
en_bert = torch.concatenate(en_bert, dim=1)
phones = torch.concatenate(phones, dim=0)
tones = torch.concatenate(tones, dim=0)
lang_ids = torch.concatenate(lang_ids, dim=0)
with torch.no_grad():
x_tst = phones.to(device).unsqueeze(0)
tones = tones.to(device).unsqueeze(0)
lang_ids = lang_ids.to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
ja_bert = ja_bert.to(device).unsqueeze(0)
en_bert = en_bert.to(device).unsqueeze(0)
# emo = emo.to(device).unsqueeze(0)
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
del phones
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
audio = (
net_g.infer(
x_tst,
x_tst_lengths,
speakers,
tones,
lang_ids,
bert,
ja_bert,
en_bert,
sdp_ratio=sdp_ratio,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
length_scale=length_scale,
en_ratio=en_ratio
)[0][0, 0]
.data.cpu()
.float()
.numpy()
)
del (
x_tst,
tones,
lang_ids,
bert,
x_tst_lengths,
speakers,
ja_bert,
en_bert,
) # , emo
if torch.cuda.is_available():
torch.cuda.empty_cache()
return audio

1081
utils/bert_vits2/models.py Normal file

File diff suppressed because it is too large Load Diff

580
utils/bert_vits2/modules.py Normal file
View File

@ -0,0 +1,580 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Conv1d
from torch.nn.utils import weight_norm, remove_weight_norm
from . import commons
from .commons import init_weights, get_padding
from .transforms import piecewise_rational_quadratic_transform
from .attentions import Encoder
LRELU_SLOPE = 0.1
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class ConvReluNorm(nn.Module):
def __init__(
self,
in_channels,
hidden_channels,
out_channels,
kernel_size,
n_layers,
p_dropout,
):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
assert n_layers > 1, "Number of layers should be larger than 0."
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(
nn.Conv1d(
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(
nn.Conv1d(
hidden_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2,
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
class DDSConv(nn.Module):
"""
Dilated and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(
channels,
channels,
kernel_size,
groups=channels,
dilation=dilation,
padding=padding,
)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class WN(torch.nn.Module):
def __init__(
self,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0,
p_dropout=0,
):
super(WN, self).__init__()
assert kernel_size % 2 == 1
self.hidden_channels = hidden_channels
self.kernel_size = (kernel_size,)
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.p_dropout = p_dropout
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
self.drop = nn.Dropout(p_dropout)
if gin_channels != 0:
cond_layer = torch.nn.Conv1d(
gin_channels, 2 * hidden_channels * n_layers, 1
)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
for i in range(n_layers):
dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(
hidden_channels,
2 * hidden_channels,
kernel_size,
dilation=dilation,
padding=padding,
)
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
if g is not None:
g = self.cond_layer(g)
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
else:
g_l = torch.zeros_like(x_in)
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = self.drop(acts)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:, : self.hidden_channels, :]
x = (x + res_acts) * x_mask
output = output + res_skip_acts[:, self.hidden_channels :, :]
else:
output = output + res_skip_acts
return output * x_mask
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
self.convs2.apply(init_weights)
def forward(self, x, x_mask=None):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c2(xt)
x = xt + x
if x_mask is not None:
x = x * x_mask
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
self.convs.apply(init_weights)
def forward(self, x, x_mask=None):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c(xt)
x = xt + x
if x_mask is not None:
x = x * x_mask
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class Log(nn.Module):
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class Flip(nn.Module):
def forward(self, x, *args, reverse=False, **kwargs):
x = torch.flip(x, [1])
if not reverse:
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
return x, logdet
else:
return x
class ElementwiseAffine(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels, 1))
self.logs = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class ResidualCouplingLayer(nn.Module):
def __init__(
self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=0,
gin_channels=0,
mean_only=False,
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.half_channels = channels // 2
self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=p_dropout,
gin_channels=gin_channels,
)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
else:
m = stats
logs = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(logs) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1, 2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-logs) * x_mask
x = torch.cat([x0, x1], 1)
return x
class ConvFlow(nn.Module):
def __init__(
self,
in_channels,
filter_channels,
kernel_size,
n_layers,
num_bins=10,
tail_bound=5.0,
):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.num_bins = num_bins
self.tail_bound = tail_bound
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
self.proj = nn.Conv1d(
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
self.filter_channels
)
unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(
x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1, 2])
if not reverse:
return x, logdet
else:
return x
class TransformerCouplingLayer(nn.Module):
def __init__(
self,
channels,
hidden_channels,
kernel_size,
n_layers,
n_heads,
p_dropout=0,
filter_channels=0,
mean_only=False,
wn_sharing_parameter=None,
gin_channels=0,
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.half_channels = channels // 2
self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = (
Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
isflow=True,
gin_channels=gin_channels,
)
if wn_sharing_parameter is None
else wn_sharing_parameter
)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
else:
m = stats
logs = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(logs) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1, 2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-logs) * x_mask
x = torch.cat([x0, x1], 1)
return x

View File

@ -0,0 +1,16 @@
from numpy import zeros, int32, float32
from torch import from_numpy
from .core import maximum_path_jit
def maximum_path(neg_cent, mask):
device = neg_cent.device
dtype = neg_cent.dtype
neg_cent = neg_cent.data.cpu().numpy().astype(float32)
path = zeros(neg_cent.shape, dtype=int32)
t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
return from_numpy(path).to(device=device, dtype=dtype)

View File

@ -0,0 +1,46 @@
import numba
@numba.jit(
numba.void(
numba.int32[:, :, ::1],
numba.float32[:, :, ::1],
numba.int32[::1],
numba.int32[::1],
),
nopython=True,
nogil=True,
)
def maximum_path_jit(paths, values, t_ys, t_xs):
b = paths.shape[0]
max_neg_val = -1e9
for i in range(int(b)):
path = paths[i]
value = values[i]
t_y = t_ys[i]
t_x = t_xs[i]
v_prev = v_cur = 0.0
index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.0
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (
index == y or value[y - 1, index] < value[y - 1, index - 1]
):
index = index - 1

View File

@ -0,0 +1,81 @@
import re
def extract_language_and_text_updated(speaker, dialogue):
# 使用正则表达式匹配<语言>标签和其后的文本
pattern_language_text = r"<(\S+?)>([^<]+)"
matches = re.findall(pattern_language_text, dialogue, re.DOTALL)
speaker = speaker[1:-1]
# 清理文本:去除两边的空白字符
matches_cleaned = [(lang.upper(), text.strip()) for lang, text in matches]
matches_cleaned.append(speaker)
return matches_cleaned
def validate_text(input_text):
# 验证说话人的正则表达式
pattern_speaker = r"(\[\S+?\])((?:\s*<\S+?>[^<\[\]]+?)+)"
# 使用re.DOTALL标志使.匹配包括换行符在内的所有字符
matches = re.findall(pattern_speaker, input_text, re.DOTALL)
# 对每个匹配到的说话人内容进行进一步验证
for _, dialogue in matches:
language_text_matches = extract_language_and_text_updated(_, dialogue)
if not language_text_matches:
return (
False,
"Error: Invalid format detected in dialogue content. Please check your input.",
)
# 如果输入的文本中没有找到任何匹配项
if not matches:
return (
False,
"Error: No valid speaker format detected. Please check your input.",
)
return True, "Input is valid."
def text_matching(text: str) -> list:
speaker_pattern = r"(\[\S+?\])(.+?)(?=\[\S+?\]|$)"
matches = re.findall(speaker_pattern, text, re.DOTALL)
result = []
for speaker, dialogue in matches:
result.append(extract_language_and_text_updated(speaker, dialogue))
return result
def cut_para(text):
splitted_para = re.split("[\n]", text) # 按段分
splitted_para = [
sentence.strip() for sentence in splitted_para if sentence.strip()
] # 删除空字符串
return splitted_para
def cut_sent(para):
para = re.sub("([。!;\?])([^”’])", r"\1\n\2", para) # 单字符断句符
para = re.sub("(\.{6})([^”’])", r"\1\n\2", para) # 英文省略号
para = re.sub("(\{2})([^”’])", r"\1\n\2", para) # 中文省略号
para = re.sub("([。!?\?][”’])([^,。!?\?])", r"\1\n\2", para)
para = para.rstrip() # 段尾如果有多余的\n就去掉它
return para.split("\n")
if __name__ == "__main__":
text = """
[说话人1]
[说话人2]<zh>你好吗<jp>元気ですか<jp>こんにちは世界<zh>你好吗
[说话人3]<zh>谢谢<jp>どういたしまして
"""
text_matching(text)
# 测试函数
test_text = """
[说话人1]<zh>你好こんにちは<jp>こんにちは世界
[说话人2]<zh>你好吗
"""
text_matching(test_text)
res = validate_text(test_text)
print(res)

View File

@ -0,0 +1,63 @@
from ..text.symbols import *
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
def cleaned_text_to_sequence(cleaned_text, tones, language):
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
text: string to convert to a sequence
Returns:
List of integers corresponding to the symbols in the text
"""
phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
tone_start = language_tone_start_map[language]
tones = [i + tone_start for i in tones]
lang_id = language_id_map[language]
lang_ids = [lang_id for i in phones]
return phones, tones, lang_ids
def get_bert(norm_text, word2ph, language, device, style_text=None, style_weight=0.7):
from .chinese_bert import get_bert_feature as zh_bert
from .english_bert_mock import get_bert_feature as en_bert
from .japanese_bert import get_bert_feature as jp_bert
lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert}
bert = lang_bert_func_map[language](
norm_text, word2ph, device, style_text, style_weight
)
return bert
def check_bert_models():
relative_path = r"./utils/bert_vits2/"
import json
from pathlib import Path
from ..config import config
from .bert_utils import _check_bert
if config.mirror.lower() == "openi":
import openi
kwargs = {"token": config.openi_token} if config.openi_token else {}
openi.login(**kwargs)
with open(relative_path+"bert/bert_models.json", "r") as fp:
models = json.load(fp)
for k, v in models.items():
local_path = Path(relative_path+"bert").joinpath(k)
_check_bert(v["repo_id"], v["files"], local_path)
def init_openjtalk():
import platform
if platform.platform() == "Linux":
import pyopenjtalk
pyopenjtalk.g2p("こんにちは,世界。")
init_openjtalk()
check_bert_models()

View File

@ -0,0 +1,23 @@
from pathlib import Path
from huggingface_hub import hf_hub_download
from ..config import config
MIRROR: str = config.mirror
def _check_bert(repo_id, files, local_path):
for file in files:
if not Path(local_path).joinpath(file).exists():
if MIRROR.lower() == "openi":
import openi
openi.model.download_model(
"Stardust_minus/Bert-VITS2", repo_id.split("/")[-1], "./bert"
)
else:
hf_hub_download(
repo_id, file, local_dir=local_path, local_dir_use_symlinks=False
)

View File

@ -0,0 +1,206 @@
import os
import re
from pypinyin import lazy_pinyin, Style
from ..text.symbols import punctuation
from ..text.tone_sandhi import ToneSandhi
try:
from tn.chinese.normalizer import Normalizer
normalizer = Normalizer().normalize
except ImportError:
import cn2an
print("tn.chinese.normalizer not found, use cn2an normalizer")
normalizer = lambda x: cn2an.transform(x, "an2cn")
current_file_path = os.path.dirname(__file__)
pinyin_to_symbol_map = {
line.split("\t")[0]: line.strip().split("\t")[1]
for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
}
import jieba.posseg as psg
rep_map = {
"": ",",
"": ",",
"": ",",
"": ".",
"": "!",
"": "?",
"\n": ".",
"·": ",",
"": ",",
"...": "",
"$": ".",
"": "'",
"": "'",
'"': "'",
"": "'",
"": "'",
"": "'",
"": "'",
"(": "'",
")": "'",
"": "'",
"": "'",
"": "'",
"": "'",
"[": "'",
"]": "'",
"": "-",
"": "-",
"~": "-",
"": "'",
"": "'",
}
tone_modifier = ToneSandhi()
def replace_punctuation(text):
text = text.replace("", "").replace("", "")
pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub(
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
)
return replaced_text
def g2p(text):
pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
phones, tones, word2ph = _g2p(sentences)
assert sum(word2ph) == len(phones)
assert len(word2ph) == len(text) # Sometimes it will crash,you can add a try-catch.
phones = ["_"] + phones + ["_"]
tones = [0] + tones + [0]
word2ph = [1] + word2ph + [1]
return phones, tones, word2ph
def _get_initials_finals(word):
initials = []
finals = []
orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
orig_finals = lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
)
for c, v in zip(orig_initials, orig_finals):
initials.append(c)
finals.append(v)
return initials, finals
def _g2p(segments):
phones_list = []
tones_list = []
word2ph = []
for seg in segments:
# Replace all English words in the sentence
seg = re.sub("[a-zA-Z]+", "", seg)
seg_cut = psg.lcut(seg)
initials = []
finals = []
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
for word, pos in seg_cut:
if pos == "eng":
continue
sub_initials, sub_finals = _get_initials_finals(word)
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
initials.append(sub_initials)
finals.append(sub_finals)
# assert len(sub_initials) == len(sub_finals) == len(word)
initials = sum(initials, [])
finals = sum(finals, [])
#
for c, v in zip(initials, finals):
raw_pinyin = c + v
# NOTE: post process for pypinyin outputs
# we discriminate i, ii and iii
if c == v:
assert c in punctuation
phone = [c]
tone = "0"
word2ph.append(1)
else:
v_without_tone = v[:-1]
tone = v[-1]
pinyin = c + v_without_tone
assert tone in "12345"
if c:
# 多音节
v_rep_map = {
"uei": "ui",
"iou": "iu",
"uen": "un",
}
if v_without_tone in v_rep_map.keys():
pinyin = c + v_rep_map[v_without_tone]
else:
# 单音节
pinyin_rep_map = {
"ing": "ying",
"i": "yi",
"in": "yin",
"u": "wu",
}
if pinyin in pinyin_rep_map.keys():
pinyin = pinyin_rep_map[pinyin]
else:
single_rep_map = {
"v": "yu",
"e": "e",
"i": "y",
"u": "w",
}
if pinyin[0] in single_rep_map.keys():
pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
phone = pinyin_to_symbol_map[pinyin].split(" ")
word2ph.append(len(phone))
phones_list += phone
tones_list += [int(tone)] * len(phone)
return phones_list, tones_list, word2ph
def text_normalize(text):
text = normalizer(text)
text = replace_punctuation(text)
return text
def get_bert_feature(text, word2ph):
from text import chinese_bert
return chinese_bert.get_bert_feature(text, word2ph)
if __name__ == "__main__":
from text.chinese_bert import get_bert_feature
text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
text = text_normalize(text)
print(text)
phones, tones, word2ph = g2p(text)
bert = get_bert_feature(text, word2ph)
print(phones, tones, word2ph, bert.shape)
# # 示例用法
# text = "这是一个示例文本:,你好!这是一个测试...."
# print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试

View File

@ -0,0 +1,119 @@
import sys
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from ..config import config
LOCAL_PATH = "./utils/bert_vits2/bert/chinese-roberta-wwm-ext-large"
tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
models = dict()
def get_bert_feature(
text,
word2ph,
device=config.bert_gen_config.device,
style_text=None,
style_weight=0.7,
):
if (
sys.platform == "darwin"
and torch.backends.mps.is_available()
and device == "cpu"
):
device = "mps"
if not device:
device = "cuda"
if device not in models.keys():
models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device)
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)
res = models[device](**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
if style_text:
style_inputs = tokenizer(style_text, return_tensors="pt")
for i in style_inputs:
style_inputs[i] = style_inputs[i].to(device)
style_res = models[device](**style_inputs, output_hidden_states=True)
style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
style_res_mean = style_res.mean(0)
assert len(word2ph) == len(text) + 2
word2phone = word2ph
phone_level_feature = []
for i in range(len(word2phone)):
if style_text:
repeat_feature = (
res[i].repeat(word2phone[i], 1) * (1 - style_weight)
+ style_res_mean.repeat(word2phone[i], 1) * style_weight
)
else:
repeat_feature = res[i].repeat(word2phone[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
if __name__ == "__main__":
word_level_feature = torch.rand(38, 1024) # 12个词,每个词1024维特征
word2phone = [
1,
2,
1,
2,
2,
1,
2,
2,
1,
2,
2,
1,
2,
2,
2,
2,
2,
1,
1,
2,
2,
1,
2,
2,
2,
2,
1,
2,
2,
2,
2,
2,
1,
2,
2,
2,
2,
1,
]
# 计算总帧数
total_frames = sum(word2phone)
print(word_level_feature.shape)
print(word2phone)
phone_level_feature = []
for i in range(len(word2phone)):
print(word_level_feature[i].shape)
# 对每个词重复word2phone[i]次
repeat_feature = word_level_feature[i].repeat(word2phone[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
print(phone_level_feature.shape) # torch.Size([36, 1024])

View File

@ -0,0 +1,29 @@
from ..text import chinese, japanese, english, cleaned_text_to_sequence
# from text import chinese, cleaned_text_to_sequence
language_module_map = {"ZH": chinese, "JP": japanese, "EN": english}
# language_module_map = {"ZH": chinese}
def clean_text(text, language):
language_module = language_module_map[language]
norm_text = language_module.text_normalize(text)
phones, tones, word2ph = language_module.g2p(norm_text)
return norm_text, phones, tones, word2ph
def clean_text_bert(text, language):
language_module = language_module_map[language]
norm_text = language_module.text_normalize(text)
phones, tones, word2ph = language_module.g2p(norm_text)
bert = language_module.get_bert_feature(norm_text, word2ph)
return phones, tones, bert
def text_to_sequence(text, language):
norm_text, phones, tones, word2ph = clean_text(text, language)
return cleaned_text_to_sequence(phones, tones, language)
if __name__ == "__main__":
pass

File diff suppressed because it is too large Load Diff

Binary file not shown.

View File

@ -0,0 +1,494 @@
import pickle
import os
import re
from g2p_en import G2p
from transformers import DebertaV2Tokenizer
from ..text import symbols
from ..text.symbols import punctuation
current_file_path = os.path.dirname(__file__)
CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
_g2p = G2p()
LOCAL_PATH = "./utils/bert_vits2/bert/deberta-v3-large"
tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH)
arpa = {
"AH0",
"S",
"AH1",
"EY2",
"AE2",
"EH0",
"OW2",
"UH0",
"NG",
"B",
"G",
"AY0",
"M",
"AA0",
"F",
"AO0",
"ER2",
"UH1",
"IY1",
"AH2",
"DH",
"IY0",
"EY1",
"IH0",
"K",
"N",
"W",
"IY2",
"T",
"AA1",
"ER1",
"EH2",
"OY0",
"UH2",
"UW1",
"Z",
"AW2",
"AW1",
"V",
"UW2",
"AA2",
"ER",
"AW0",
"UW0",
"R",
"OW1",
"EH1",
"ZH",
"AE0",
"IH2",
"IH",
"Y",
"JH",
"P",
"AY1",
"EY0",
"OY2",
"TH",
"HH",
"D",
"ER0",
"CH",
"AO1",
"AE1",
"AO2",
"OY1",
"AY2",
"IH1",
"OW0",
"L",
"SH",
}
def post_replace_ph(ph):
rep_map = {
"": ",",
"": ",",
"": ",",
"": ".",
"": "!",
"": "?",
"\n": ".",
"·": ",",
"": ",",
"": "...",
"···": "...",
"・・・": "...",
"v": "V",
}
if ph in rep_map.keys():
ph = rep_map[ph]
if ph in symbols:
return ph
if ph not in symbols:
ph = "UNK"
return ph
rep_map = {
"": ",",
"": ",",
"": ",",
"": ".",
"": "!",
"": "?",
"\n": ".",
"": ".",
"": "...",
"···": "...",
"・・・": "...",
"·": ",",
"": ",",
"": ",",
"$": ".",
"": "'",
"": "'",
'"': "'",
"": "'",
"": "'",
"": "'",
"": "'",
"(": "'",
")": "'",
"": "'",
"": "'",
"": "'",
"": "'",
"[": "'",
"]": "'",
"": "-",
"": "-",
"": "-",
"~": "-",
"": "'",
"": "'",
}
def replace_punctuation(text):
pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
# replaced_text = re.sub(
# r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
# + "".join(punctuation)
# + r"]+",
# "",
# replaced_text,
# )
return replaced_text
def read_dict():
g2p_dict = {}
start_line = 49
with open(CMU_DICT_PATH) as f:
line = f.readline()
line_index = 1
while line:
if line_index >= start_line:
line = line.strip()
word_split = line.split(" ")
word = word_split[0]
syllable_split = word_split[1].split(" - ")
g2p_dict[word] = []
for syllable in syllable_split:
phone_split = syllable.split(" ")
g2p_dict[word].append(phone_split)
line_index = line_index + 1
line = f.readline()
return g2p_dict
def cache_dict(g2p_dict, file_path):
with open(file_path, "wb") as pickle_file:
pickle.dump(g2p_dict, pickle_file)
def get_dict():
if os.path.exists(CACHE_PATH):
with open(CACHE_PATH, "rb") as pickle_file:
g2p_dict = pickle.load(pickle_file)
else:
g2p_dict = read_dict()
cache_dict(g2p_dict, CACHE_PATH)
return g2p_dict
eng_dict = get_dict()
def refine_ph(phn):
tone = 0
if re.search(r"\d$", phn):
tone = int(phn[-1]) + 1
phn = phn[:-1]
else:
tone = 3
return phn.lower(), tone
def refine_syllables(syllables):
tones = []
phonemes = []
for phn_list in syllables:
for i in range(len(phn_list)):
phn = phn_list[i]
phn, tone = refine_ph(phn)
phonemes.append(phn)
tones.append(tone)
return phonemes, tones
import inflect
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
]
# List of (ipa, lazy ipa) pairs:
_lazy_ipa = [
(re.compile("%s" % x[0]), x[1])
for x in [
("r", "ɹ"),
("æ", "e"),
("ɑ", "a"),
("ɔ", "o"),
("ð", "z"),
("θ", "s"),
("ɛ", "e"),
("ɪ", "i"),
("ʊ", "u"),
("ʒ", "ʥ"),
("ʤ", "ʥ"),
("ˈ", ""),
]
]
# List of (ipa, lazy ipa2) pairs:
_lazy_ipa2 = [
(re.compile("%s" % x[0]), x[1])
for x in [
("r", "ɹ"),
("ð", "z"),
("θ", "s"),
("ʒ", "ʑ"),
("ʤ", ""),
("ˈ", ""),
]
]
# List of (ipa, ipa2) pairs
_ipa_to_ipa2 = [
(re.compile("%s" % x[0]), x[1]) for x in [("r", "ɹ"), ("ʤ", ""), ("ʧ", "")]
]
def _expand_dollars(m):
match = m.group(1)
parts = match.split(".")
if len(parts) > 2:
return match + " dollars" # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = "dollar" if dollars == 1 else "dollars"
return "%s %s" % (dollars, dollar_unit)
elif cents:
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s" % (cents, cent_unit)
else:
return "zero dollars"
def _remove_commas(m):
return m.group(1).replace(",", "")
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return "two thousand"
elif num > 2000 and num < 2010:
return "two thousand " + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + " hundred"
else:
return _inflect.number_to_words(
num, andword="", zero="oh", group=2
).replace(", ", " ")
else:
return _inflect.number_to_words(num, andword="")
def _expand_decimal_point(m):
return m.group(1).replace(".", " point ")
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r"\1 pounds", text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
def text_normalize(text):
text = normalize_numbers(text)
text = replace_punctuation(text)
text = re.sub(r"([,;.\?\!])([\w])", r"\1 \2", text)
return text
def distribute_phone(n_phone, n_word):
phones_per_word = [0] * n_word
for task in range(n_phone):
min_tasks = min(phones_per_word)
min_index = phones_per_word.index(min_tasks)
phones_per_word[min_index] += 1
return phones_per_word
def sep_text(text):
words = re.split(r"([,;.\?\!\s+])", text)
words = [word for word in words if word.strip() != ""]
return words
def text_to_words(text):
tokens = tokenizer.tokenize(text)
words = []
for idx, t in enumerate(tokens):
if t.startswith(""):
words.append([t[1:]])
else:
if t in punctuation:
if idx == len(tokens) - 1:
words.append([f"{t}"])
else:
if (
not tokens[idx + 1].startswith("")
and tokens[idx + 1] not in punctuation
):
if idx == 0:
words.append([])
words[-1].append(f"{t}")
else:
words.append([f"{t}"])
else:
if idx == 0:
words.append([])
words[-1].append(f"{t}")
return words
def g2p(text):
phones = []
tones = []
phone_len = []
# words = sep_text(text)
# tokens = [tokenizer.tokenize(i) for i in words]
words = text_to_words(text)
for word in words:
temp_phones, temp_tones = [], []
if len(word) > 1:
if "'" in word:
word = ["".join(word)]
for w in word:
if w in punctuation:
temp_phones.append(w)
temp_tones.append(0)
continue
if w.upper() in eng_dict:
phns, tns = refine_syllables(eng_dict[w.upper()])
temp_phones += [post_replace_ph(i) for i in phns]
temp_tones += tns
# w2ph.append(len(phns))
else:
phone_list = list(filter(lambda p: p != " ", _g2p(w)))
phns = []
tns = []
for ph in phone_list:
if ph in arpa:
ph, tn = refine_ph(ph)
phns.append(ph)
tns.append(tn)
else:
phns.append(ph)
tns.append(0)
temp_phones += [post_replace_ph(i) for i in phns]
temp_tones += tns
phones += temp_phones
tones += temp_tones
phone_len.append(len(temp_phones))
# phones = [post_replace_ph(i) for i in phones]
word2ph = []
for token, pl in zip(words, phone_len):
word_len = len(token)
aaa = distribute_phone(pl, word_len)
word2ph += aaa
phones = ["_"] + phones + ["_"]
tones = [0] + tones + [0]
word2ph = [1] + word2ph + [1]
assert len(phones) == len(tones), text
assert len(phones) == sum(word2ph), text
return phones, tones, word2ph
def get_bert_feature(text, word2ph):
from text import english_bert_mock
return english_bert_mock.get_bert_feature(text, word2ph)
if __name__ == "__main__":
# print(get_dict())
# print(eng_word_to_phoneme("hello"))
print(g2p("In this paper, we propose 1 DSPGAN, a GAN-based universal vocoder."))
# all_phones = set()
# for k, syllables in eng_dict.items():
# for group in syllables:
# for ph in group:
# all_phones.add(ph)
# print(all_phones)

View File

@ -0,0 +1,61 @@
import sys
import torch
from transformers import DebertaV2Model, DebertaV2Tokenizer
from ..config import config
LOCAL_PATH = "./utils/bert_vits2/bert/deberta-v3-large"
tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH)
models = dict()
def get_bert_feature(
text,
word2ph,
device=config.bert_gen_config.device,
style_text=None,
style_weight=0.7,
):
if (
sys.platform == "darwin"
and torch.backends.mps.is_available()
and device == "cpu"
):
device = "mps"
if not device:
device = "cuda"
if device not in models.keys():
models[device] = DebertaV2Model.from_pretrained(LOCAL_PATH).to(device)
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)
res = models[device](**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
if style_text:
style_inputs = tokenizer(style_text, return_tensors="pt")
for i in style_inputs:
style_inputs[i] = style_inputs[i].to(device)
style_res = models[device](**style_inputs, output_hidden_states=True)
style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
style_res_mean = style_res.mean(0)
assert len(word2ph) == res.shape[0], (text, res.shape[0], len(word2ph))
word2phone = word2ph
phone_level_feature = []
for i in range(len(word2phone)):
if style_text:
repeat_feature = (
res[i].repeat(word2phone[i], 1) * (1 - style_weight)
+ style_res_mean.repeat(word2phone[i], 1) * style_weight
)
else:
repeat_feature = res[i].repeat(word2phone[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T

View File

@ -0,0 +1,720 @@
# Convert Japanese text to phonemes which is
# compatible with Julius https://github.com/julius-speech/segmentation-kit
import re
import unicodedata
from transformers import AutoTokenizer
from ..text import punctuation, symbols
from num2words import num2words
import pyopenjtalk
import jaconv
# Mapping of hiragana to phonetic representation
hiragana_map = {
"う゛ぁ": " v a",
"う゛ぃ": " v i",
"う゛ぇ": " v e",
"う゛ぉ": " v o",
"う゛ゅ": " by u",
"ぅ゛": " v u",
# ゔ等の処理を追加
"ゔぁ": " v a",
"ゔぃ": " v i",
"ゔぇ": " v e",
"ゔぉ": " v o",
"ゔゅ": " by u",
# 2文字からなる変換規則
"あぁ": " a a",
"いぃ": " i i",
"いぇ": " i e",
"いゃ": " y a",
"うぅ": " u:",
"えぇ": " e e",
"おぉ": " o:",
"かぁ": " k a:",
"きぃ": " k i:",
"くぅ": " k u:",
"くゃ": " ky a",
"くゅ": " ky u",
"くょ": " ky o",
"けぇ": " k e:",
"こぉ": " k o:",
"がぁ": " g a:",
"ぎぃ": " g i:",
"ぐぅ": " g u:",
"ぐゃ": " gy a",
"ぐゅ": " gy u",
"ぐょ": " gy o",
"げぇ": " g e:",
"ごぉ": " g o:",
"さぁ": " s a:",
"しぃ": " sh i",
"すぅ": " s u:",
"すゃ": " sh a",
"すゅ": " sh u",
"すょ": " sh o",
"せぇ": " s e:",
"そぉ": " s o:",
"ざぁ": " z a:",
"じぃ": " j i:",
"ずぅ": " z u:",
"ずゃ": " zy a",
"ずゅ": " zy u",
"ずょ": " zy o",
"ぜぇ": " z e:",
"ぞぉ": " z o:",
"たぁ": " t a:",
"ちぃ": " ch i",
"つぁ": " ts a",
"つぃ": " ts i",
"つぅ": " ts u",
"つゃ": " ch a",
"つゅ": " ch u",
"つょ": " ch o",
"つぇ": " ts e",
"つぉ": " ts o",
"てぇ": " t e:",
"とぉ": " t o:",
"だぁ": " d a:",
"ぢぃ": " j i:",
"づぅ": " d u:",
"づゃ": " zy a",
"づゅ": " zy u",
"づょ": " zy o",
"でぇ": " d e:",
"なぁ": " n a:",
"にぃ": " n i:",
"ぬぅ": " n u:",
"ぬゃ": " ny a",
"ぬゅ": " ny u",
"ぬょ": " ny o",
"ねぇ": " n e:",
"のぉ": " n o:",
"はぁ": " h a:",
"ひぃ": " h i:",
"ふぅ": " f u:",
"ふゃ": " hy a",
"へぇ": " h e:",
"ほぉ": " h o:",
"ばぁ": " b a:",
"びぃ": " b i:",
"ぶぅ": " b u:",
"ぶゅ": " by u",
"べぇ": " b e:",
"ぼぉ": " b o:",
"ぱぁ": " p a:",
"ぴぃ": " p i:",
"ぷぅ": " p u:",
"ぷゃ": " py a",
"ぷゅ": " py u",
"ぷょ": " py o",
"ぺぇ": " p e:",
"ぽぉ": " p o:",
"まぁ": " m a:",
"みぃ": " m i:",
"むぅ": " m u:",
"むゃ": " my a",
"むゅ": " my u",
"むょ": " my o",
"めぇ": " m e:",
"もぉ": " m o:",
"やぁ": " y a:",
"ゆぅ": " y u:",
"ゆゃ": " y a:",
"ゆゅ": " y u:",
"ゆょ": " y o:",
"よぉ": " y o:",
"らぁ": " r a:",
"りぃ": " r i:",
"るぅ": " r u:",
"るゃ": " ry a",
"るゅ": " ry u",
"るょ": " ry o",
"れぇ": " r e:",
"ろぉ": " r o:",
"わぁ": " w a:",
"をぉ": " o:",
"う゛": " b u",
"でぃ": " d i",
"でゃ": " dy a",
"でゅ": " dy u",
"でょ": " dy o",
"てぃ": " t i",
"てゃ": " ty a",
"てゅ": " ty u",
"てょ": " ty o",
"すぃ": " s i",
"ずぁ": " z u",
"ずぃ": " z i",
"ずぇ": " z e",
"ずぉ": " z o",
"きゃ": " ky a",
"きゅ": " ky u",
"きょ": " ky o",
"しゃ": " sh a",
"しゅ": " sh u",
"しぇ": " sh e",
"しょ": " sh o",
"ちゃ": " ch a",
"ちゅ": " ch u",
"ちぇ": " ch e",
"ちょ": " ch o",
"とぅ": " t u",
"とゃ": " ty a",
"とゅ": " ty u",
"とょ": " ty o",
"どぁ": " d o ",
"どぅ": " d u",
"どゃ": " dy a",
"どゅ": " dy u",
"どょ": " dy o",
"どぉ": " d o:",
"にゃ": " ny a",
"にゅ": " ny u",
"にょ": " ny o",
"ひゃ": " hy a",
"ひゅ": " hy u",
"ひょ": " hy o",
"みゃ": " my a",
"みゅ": " my u",
"みょ": " my o",
"りゃ": " ry a",
"りゅ": " ry u",
"りょ": " ry o",
"ぎゃ": " gy a",
"ぎゅ": " gy u",
"ぎょ": " gy o",
"ぢぇ": " j e",
"ぢゃ": " j a",
"ぢゅ": " j u",
"ぢょ": " j o",
"じぇ": " j e",
"じゃ": " j a",
"じゅ": " j u",
"じょ": " j o",
"びゃ": " by a",
"びゅ": " by u",
"びょ": " by o",
"ぴゃ": " py a",
"ぴゅ": " py u",
"ぴょ": " py o",
"うぁ": " u a",
"うぃ": " w i",
"うぇ": " w e",
"うぉ": " w o",
"ふぁ": " f a",
"ふぃ": " f i",
"ふゅ": " hy u",
"ふょ": " hy o",
"ふぇ": " f e",
"ふぉ": " f o",
# 1音からなる変換規則
"": " a",
"": " i",
"": " u",
"": " v u", # ゔの処理を追加
"": " e",
"": " o",
"": " k a",
"": " k i",
"": " k u",
"": " k e",
"": " k o",
"": " s a",
"": " sh i",
"": " s u",
"": " s e",
"": " s o",
"": " t a",
"": " ch i",
"": " ts u",
"": " t e",
"": " t o",
"": " n a",
"": " n i",
"": " n u",
"": " n e",
"": " n o",
"": " h a",
"": " h i",
"": " f u",
"": " h e",
"": " h o",
"": " m a",
"": " m i",
"": " m u",
"": " m e",
"": " m o",
"": " r a",
"": " r i",
"": " r u",
"": " r e",
"": " r o",
"": " g a",
"": " g i",
"": " g u",
"": " g e",
"": " g o",
"": " z a",
"": " j i",
"": " z u",
"": " z e",
"": " z o",
"": " d a",
"": " j i",
"": " z u",
"": " d e",
"": " d o",
"": " b a",
"": " b i",
"": " b u",
"": " b e",
"": " b o",
"": " p a",
"": " p i",
"": " p u",
"": " p e",
"": " p o",
"": " y a",
"": " y u",
"": " y o",
"": " w a",
"": " i",
"": " e",
"": " N",
"": " q",
# ここまでに処理されてない ぁぃぅぇぉ はそのまま大文字扱い
"": " a",
"": " i",
"": " u",
"": " e",
"": " o",
"": " w a",
# 長音の処理
# for (pattern, replace_str) in JULIUS_LONG_VOWEL:
# text = pattern.sub(replace_str, text)
# text = text.replace("o u", "o:") # おう -> おーの音便
"": ":",
"": ":",
"": ":",
"-": ":",
# その他特別な処理
"": " o",
# ここまでに処理されていないゅ等もそのまま大文字扱い(追加)
"": " y a",
"": " y u",
"": " y o",
}
def hiragana2p(txt: str) -> str:
"""
Modification of `jaconv.hiragana2julius`.
- avoid using `:`, instead, `あーーー` -> `a a a a`.
- avoid converting `o u` to `o o` (because the input is already actual `yomi`).
- avoid using `N` for `` (for compatibility)
- use `v` for `` related text.
- add bare `` `` `` to `y a` `y u` `y o` (for compatibility).
"""
result = []
skip = 0
for i in range(len(txt)):
if skip:
skip -= 1
continue
for length in range(3, 0, -1):
if txt[i : i + length] in hiragana_map:
result.append(hiragana_map[txt[i : i + length]])
skip = length - 1
break
txt = "".join(result)
txt = txt.strip()
txt = txt.replace(":+", ":")
# ここまで`jaconv.hiragana2julius`と音便処理と長音処理をのぞいて同じ
# ここから`k a:: k i:`→`k a a a k i i`のように`:`の数だけ繰り返す処理
pattern = r"(\w)(:*)"
replacement = lambda m: m.group(1) + (" " + m.group(1)) * len(m.group(2))
txt = re.sub(pattern, replacement, txt)
txt = txt.replace("N", "n") # 促音のNをnに変換
return txt
def kata2phoneme(text: str) -> str:
"""Convert katakana text to phonemes."""
text = text.strip()
if text == "":
return [""]
elif text.startswith(""):
return [""] + kata2phoneme(text[1:])
res = []
prev = None
while text:
if re.match(_MARKS, text):
res.append(text)
text = text[1:]
continue
if text.startswith(""):
if prev:
res.append(prev[-1])
text = text[1:]
continue
res += hiragana2p(jaconv.kata2hira(text)).split(" ")
break
# res = _COLON_RX.sub(":", res)
return res
_SYMBOL_TOKENS = set(list("・、。?!"))
_NO_YOMI_TOKENS = set(list("「」『』―()[][]"))
_MARKS = re.compile(
r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
)
def text2sep_kata(text: str):
parsed = pyopenjtalk.run_frontend(text)
res = []
sep = []
for parts in parsed:
word, yomi = replace_punctuation(parts["string"]), parts["pron"].replace(
"", ""
)
if yomi:
if re.match(_MARKS, yomi):
if len(word) > 1:
word = [replace_punctuation(i) for i in list(word)]
yomi = word
res += yomi
sep += word
continue
elif word not in rep_map.keys() and word not in rep_map.values():
word = ","
yomi = word
res.append(yomi)
else:
if word in _SYMBOL_TOKENS:
res.append(word)
elif word in ("", ""):
res.append("")
elif word in _NO_YOMI_TOKENS:
pass
else:
res.append(word)
sep.append(word)
return sep, res, get_accent(parsed)
def get_accent(parsed):
labels = pyopenjtalk.make_label(parsed)
phonemes = []
accents = []
for n, label in enumerate(labels):
phoneme = re.search(r"\-([^\+]*)\+", label).group(1)
if phoneme not in ["sil", "pau"]:
phonemes.append(phoneme.replace("cl", "q").lower())
else:
continue
a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
a2 = int(re.search(r"\+(\d+)\+", label).group(1))
if re.search(r"\-([^\+]*)\+", labels[n + 1]).group(1) in ["sil", "pau"]:
a2_next = -1
else:
a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
# Falling
if a1 == 0 and a2_next == a2 + 1:
accents.append(-1)
# Rising
elif a2 == 1 and a2_next == 2:
accents.append(1)
else:
accents.append(0)
return list(zip(phonemes, accents))
_ALPHASYMBOL_YOMI = {
"#": "シャープ",
"%": "パーセント",
"&": "アンド",
"+": "プラス",
"-": "マイナス",
":": "コロン",
";": "セミコロン",
"<": "小なり",
"=": "イコール",
">": "大なり",
"@": "アット",
"a": "エー",
"b": "ビー",
"c": "シー",
"d": "ディー",
"e": "イー",
"f": "エフ",
"g": "ジー",
"h": "エイチ",
"i": "アイ",
"j": "ジェー",
"k": "ケー",
"l": "エル",
"m": "エム",
"n": "エヌ",
"o": "オー",
"p": "ピー",
"q": "キュー",
"r": "アール",
"s": "エス",
"t": "ティー",
"u": "ユー",
"v": "ブイ",
"w": "ダブリュー",
"x": "エックス",
"y": "ワイ",
"z": "ゼット",
"α": "アルファ",
"β": "ベータ",
"γ": "ガンマ",
"δ": "デルタ",
"ε": "イプシロン",
"ζ": "ゼータ",
"η": "イータ",
"θ": "シータ",
"ι": "イオタ",
"κ": "カッパ",
"λ": "ラムダ",
"μ": "ミュー",
"ν": "ニュー",
"ξ": "クサイ",
"ο": "オミクロン",
"π": "パイ",
"ρ": "ロー",
"σ": "シグマ",
"τ": "タウ",
"υ": "ウプシロン",
"φ": "ファイ",
"χ": "カイ",
"ψ": "プサイ",
"ω": "オメガ",
}
_NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+")
_CURRENCY_MAP = {"$": "ドル", "¥": "", "£": "ポンド", "": "ユーロ"}
_CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])")
_NUMBER_RX = re.compile(r"[0-9]+(\.[0-9]+)?")
def japanese_convert_numbers_to_words(text: str) -> str:
res = _NUMBER_WITH_SEPARATOR_RX.sub(lambda m: m[0].replace(",", ""), text)
res = _CURRENCY_RX.sub(lambda m: m[2] + _CURRENCY_MAP.get(m[1], m[1]), res)
res = _NUMBER_RX.sub(lambda m: num2words(m[0], lang="ja"), res)
return res
def japanese_convert_alpha_symbols_to_words(text: str) -> str:
return "".join([_ALPHASYMBOL_YOMI.get(ch, ch) for ch in text.lower()])
def is_japanese_character(char):
# 定义日语文字系统的 Unicode 范围
japanese_ranges = [
(0x3040, 0x309F), # 平假名
(0x30A0, 0x30FF), # 片假名
(0x4E00, 0x9FFF), # 汉字 (CJK Unified Ideographs)
(0x3400, 0x4DBF), # 汉字扩展 A
(0x20000, 0x2A6DF), # 汉字扩展 B
# 可以根据需要添加其他汉字扩展范围
]
# 将字符的 Unicode 编码转换为整数
char_code = ord(char)
# 检查字符是否在任何一个日语范围内
for start, end in japanese_ranges:
if start <= char_code <= end:
return True
return False
rep_map = {
"": ",",
"": ",",
"": ",",
"": ".",
"": "!",
"": "?",
"\n": ".",
"": ".",
"": "...",
"···": "...",
"・・・": "...",
"·": ",",
"": ",",
"": ",",
"$": ".",
"": "'",
"": "'",
'"': "'",
"": "'",
"": "'",
"": "'",
"": "'",
"(": "'",
")": "'",
"": "'",
"": "'",
"": "'",
"": "'",
"[": "'",
"]": "'",
"": "-",
"": "-",
"": "-",
"~": "-",
"": "'",
"": "'",
}
def replace_punctuation(text):
pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub(
r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
+ "".join(punctuation)
+ r"]+",
"",
replaced_text,
)
return replaced_text
def text_normalize(text):
res = unicodedata.normalize("NFKC", text)
res = japanese_convert_numbers_to_words(res)
# res = "".join([i for i in res if is_japanese_character(i)])
res = replace_punctuation(res)
res = res.replace("", "")
return res
def distribute_phone(n_phone, n_word):
phones_per_word = [0] * n_word
for task in range(n_phone):
min_tasks = min(phones_per_word)
min_index = phones_per_word.index(min_tasks)
phones_per_word[min_index] += 1
return phones_per_word
def handle_long(sep_phonemes):
for i in range(len(sep_phonemes)):
if sep_phonemes[i][0] == "":
sep_phonemes[i][0] = sep_phonemes[i - 1][-1]
if "" in sep_phonemes[i]:
for j in range(len(sep_phonemes[i])):
if sep_phonemes[i][j] == "":
sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1]
return sep_phonemes
tokenizer = AutoTokenizer.from_pretrained("./utils/bert_vits2/bert/deberta-v2-large-japanese-char-wwm")
def align_tones(phones, tones):
res = []
for pho in phones:
temp = [0] * len(pho)
for idx, p in enumerate(pho):
if len(tones) == 0:
break
if p == tones[0][0]:
temp[idx] = tones[0][1]
if idx > 0:
temp[idx] += temp[idx - 1]
tones.pop(0)
temp = [0] + temp
temp = temp[:-1]
if -1 in temp:
temp = [i + 1 for i in temp]
res.append(temp)
res = [i for j in res for i in j]
assert not any([i < 0 for i in res]) and not any([i > 1 for i in res])
return res
def rearrange_tones(tones, phones):
res = [0] * len(tones)
for i in range(len(tones)):
if i == 0:
if tones[i] not in punctuation:
res[i] = 1
elif tones[i] == prev:
if phones[i] in punctuation:
res[i] = 0
else:
res[i] = 1
elif tones[i] > prev:
res[i] = 2
elif tones[i] < prev:
res[i - 1] = 3
res[i] = 1
prev = tones[i]
return res
def g2p(norm_text):
sep_text, sep_kata, acc = text2sep_kata(norm_text)
sep_tokenized = []
for i in sep_text:
if i not in punctuation:
sep_tokenized.append(tokenizer.tokenize(i))
else:
sep_tokenized.append([i])
sep_phonemes = handle_long([kata2phoneme(i) for i in sep_kata])
# 异常处理MeCab不认识的词的话会一路传到这里来然后炸掉。目前来看只有那些超级稀有的生僻词会出现这种情况
for i in sep_phonemes:
for j in i:
assert j in symbols, (sep_text, sep_kata, sep_phonemes)
tones = align_tones(sep_phonemes, acc)
word2ph = []
for token, phoneme in zip(sep_tokenized, sep_phonemes):
phone_len = len(phoneme)
word_len = len(token)
aaa = distribute_phone(phone_len, word_len)
word2ph += aaa
phones = ["_"] + [j for i in sep_phonemes for j in i] + ["_"]
# tones = [0] + rearrange_tones(tones, phones[1:-1]) + [0]
tones = [0] + tones + [0]
word2ph = [1] + word2ph + [1]
assert len(phones) == len(tones)
return phones, tones, word2ph
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained("./utils/bert_vits2/bert/deberta-v2-large-japanese")
text = "hello,こんにちは、世界ー!……"
from text.japanese_bert import get_bert_feature
text = text_normalize(text)
print(text)
phones, tones, word2ph = g2p(text)
bert = get_bert_feature(text, word2ph)
print(phones, tones, word2ph, bert.shape)

View File

@ -0,0 +1,65 @@
import sys
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from ..config import config
from ..text.japanese import text2sep_kata
LOCAL_PATH = "./utils/bert_vits2/bert/deberta-v2-large-japanese-char-wwm"
tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
models = dict()
def get_bert_feature(
text,
word2ph,
device=config.bert_gen_config.device,
style_text=None,
style_weight=0.7,
):
text = "".join(text2sep_kata(text)[0])
if style_text:
style_text = "".join(text2sep_kata(style_text)[0])
if (
sys.platform == "darwin"
and torch.backends.mps.is_available()
and device == "cpu"
):
device = "mps"
if not device:
device = "cuda"
if device not in models.keys():
models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device)
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)
res = models[device](**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
if style_text:
style_inputs = tokenizer(style_text, return_tensors="pt")
for i in style_inputs:
style_inputs[i] = style_inputs[i].to(device)
style_res = models[device](**style_inputs, output_hidden_states=True)
style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
style_res_mean = style_res.mean(0)
assert len(word2ph) == len(text) + 2
word2phone = word2ph
phone_level_feature = []
for i in range(len(word2phone)):
if style_text:
repeat_feature = (
res[i].repeat(word2phone[i], 1) * (1 - style_weight)
+ style_res_mean.repeat(word2phone[i], 1) * style_weight
)
else:
repeat_feature = res[i].repeat(word2phone[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T

View File

@ -0,0 +1,429 @@
a AA a
ai AA ai
an AA an
ang AA ang
ao AA ao
ba b a
bai b ai
ban b an
bang b ang
bao b ao
bei b ei
ben b en
beng b eng
bi b i
bian b ian
biao b iao
bie b ie
bin b in
bing b ing
bo b o
bu b u
ca c a
cai c ai
can c an
cang c ang
cao c ao
ce c e
cei c ei
cen c en
ceng c eng
cha ch a
chai ch ai
chan ch an
chang ch ang
chao ch ao
che ch e
chen ch en
cheng ch eng
chi ch ir
chong ch ong
chou ch ou
chu ch u
chua ch ua
chuai ch uai
chuan ch uan
chuang ch uang
chui ch ui
chun ch un
chuo ch uo
ci c i0
cong c ong
cou c ou
cu c u
cuan c uan
cui c ui
cun c un
cuo c uo
da d a
dai d ai
dan d an
dang d ang
dao d ao
de d e
dei d ei
den d en
deng d eng
di d i
dia d ia
dian d ian
diao d iao
die d ie
ding d ing
diu d iu
dong d ong
dou d ou
du d u
duan d uan
dui d ui
dun d un
duo d uo
e EE e
ei EE ei
en EE en
eng EE eng
er EE er
fa f a
fan f an
fang f ang
fei f ei
fen f en
feng f eng
fo f o
fou f ou
fu f u
ga g a
gai g ai
gan g an
gang g ang
gao g ao
ge g e
gei g ei
gen g en
geng g eng
gong g ong
gou g ou
gu g u
gua g ua
guai g uai
guan g uan
guang g uang
gui g ui
gun g un
guo g uo
ha h a
hai h ai
han h an
hang h ang
hao h ao
he h e
hei h ei
hen h en
heng h eng
hong h ong
hou h ou
hu h u
hua h ua
huai h uai
huan h uan
huang h uang
hui h ui
hun h un
huo h uo
ji j i
jia j ia
jian j ian
jiang j iang
jiao j iao
jie j ie
jin j in
jing j ing
jiong j iong
jiu j iu
ju j v
jv j v
juan j van
jvan j van
jue j ve
jve j ve
jun j vn
jvn j vn
ka k a
kai k ai
kan k an
kang k ang
kao k ao
ke k e
kei k ei
ken k en
keng k eng
kong k ong
kou k ou
ku k u
kua k ua
kuai k uai
kuan k uan
kuang k uang
kui k ui
kun k un
kuo k uo
la l a
lai l ai
lan l an
lang l ang
lao l ao
le l e
lei l ei
leng l eng
li l i
lia l ia
lian l ian
liang l iang
liao l iao
lie l ie
lin l in
ling l ing
liu l iu
lo l o
long l ong
lou l ou
lu l u
luan l uan
lun l un
luo l uo
lv l v
lve l ve
ma m a
mai m ai
man m an
mang m ang
mao m ao
me m e
mei m ei
men m en
meng m eng
mi m i
mian m ian
miao m iao
mie m ie
min m in
ming m ing
miu m iu
mo m o
mou m ou
mu m u
na n a
nai n ai
nan n an
nang n ang
nao n ao
ne n e
nei n ei
nen n en
neng n eng
ni n i
nian n ian
niang n iang
niao n iao
nie n ie
nin n in
ning n ing
niu n iu
nong n ong
nou n ou
nu n u
nuan n uan
nun n un
nuo n uo
nv n v
nve n ve
o OO o
ou OO ou
pa p a
pai p ai
pan p an
pang p ang
pao p ao
pei p ei
pen p en
peng p eng
pi p i
pian p ian
piao p iao
pie p ie
pin p in
ping p ing
po p o
pou p ou
pu p u
qi q i
qia q ia
qian q ian
qiang q iang
qiao q iao
qie q ie
qin q in
qing q ing
qiong q iong
qiu q iu
qu q v
qv q v
quan q van
qvan q van
que q ve
qve q ve
qun q vn
qvn q vn
ran r an
rang r ang
rao r ao
re r e
ren r en
reng r eng
ri r ir
rong r ong
rou r ou
ru r u
rua r ua
ruan r uan
rui r ui
run r un
ruo r uo
sa s a
sai s ai
san s an
sang s ang
sao s ao
se s e
sen s en
seng s eng
sha sh a
shai sh ai
shan sh an
shang sh ang
shao sh ao
she sh e
shei sh ei
shen sh en
sheng sh eng
shi sh ir
shou sh ou
shu sh u
shua sh ua
shuai sh uai
shuan sh uan
shuang sh uang
shui sh ui
shun sh un
shuo sh uo
si s i0
song s ong
sou s ou
su s u
suan s uan
sui s ui
sun s un
suo s uo
ta t a
tai t ai
tan t an
tang t ang
tao t ao
te t e
tei t ei
teng t eng
ti t i
tian t ian
tiao t iao
tie t ie
ting t ing
tong t ong
tou t ou
tu t u
tuan t uan
tui t ui
tun t un
tuo t uo
wa w a
wai w ai
wan w an
wang w ang
wei w ei
wen w en
weng w eng
wo w o
wu w u
xi x i
xia x ia
xian x ian
xiang x iang
xiao x iao
xie x ie
xin x in
xing x ing
xiong x iong
xiu x iu
xu x v
xv x v
xuan x van
xvan x van
xue x ve
xve x ve
xun x vn
xvn x vn
ya y a
yan y En
yang y ang
yao y ao
ye y E
yi y i
yin y in
ying y ing
yo y o
yong y ong
you y ou
yu y v
yv y v
yuan y van
yvan y van
yue y ve
yve y ve
yun y vn
yvn y vn
za z a
zai z ai
zan z an
zang z ang
zao z ao
ze z e
zei z ei
zen z en
zeng z eng
zha zh a
zhai zh ai
zhan zh an
zhang zh ang
zhao zh ao
zhe zh e
zhei zh ei
zhen zh en
zheng zh eng
zhi zh ir
zhong zh ong
zhou zh ou
zhu zh u
zhua zh ua
zhuai zh uai
zhuan zh uan
zhuang zh uang
zhui zh ui
zhun zh un
zhuo zh uo
zi z i0
zong z ong
zou z ou
zu z u
zuan z uan
zui z ui
zun z un
zuo z uo

View File

@ -0,0 +1,187 @@
punctuation = ["!", "?", "", ",", ".", "'", "-"]
pu_symbols = punctuation + ["SP", "UNK"]
pad = "_"
# chinese
zh_symbols = [
"E",
"En",
"a",
"ai",
"an",
"ang",
"ao",
"b",
"c",
"ch",
"d",
"e",
"ei",
"en",
"eng",
"er",
"f",
"g",
"h",
"i",
"i0",
"ia",
"ian",
"iang",
"iao",
"ie",
"in",
"ing",
"iong",
"ir",
"iu",
"j",
"k",
"l",
"m",
"n",
"o",
"ong",
"ou",
"p",
"q",
"r",
"s",
"sh",
"t",
"u",
"ua",
"uai",
"uan",
"uang",
"ui",
"un",
"uo",
"v",
"van",
"ve",
"vn",
"w",
"x",
"y",
"z",
"zh",
"AA",
"EE",
"OO",
]
num_zh_tones = 6
# japanese
ja_symbols = [
"N",
"a",
"a:",
"b",
"by",
"ch",
"d",
"dy",
"e",
"e:",
"f",
"g",
"gy",
"h",
"hy",
"i",
"i:",
"j",
"k",
"ky",
"m",
"my",
"n",
"ny",
"o",
"o:",
"p",
"py",
"q",
"r",
"ry",
"s",
"sh",
"t",
"ts",
"ty",
"u",
"u:",
"w",
"y",
"z",
"zy",
]
num_ja_tones = 2
# English
en_symbols = [
"aa",
"ae",
"ah",
"ao",
"aw",
"ay",
"b",
"ch",
"d",
"dh",
"eh",
"er",
"ey",
"f",
"g",
"hh",
"ih",
"iy",
"jh",
"k",
"l",
"m",
"n",
"ng",
"ow",
"oy",
"p",
"r",
"s",
"sh",
"t",
"th",
"uh",
"uw",
"V",
"w",
"y",
"z",
"zh",
]
num_en_tones = 4
# combine all symbols
normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
symbols = [pad] + normal_symbols + pu_symbols
sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
# combine all tones
num_tones = num_zh_tones + num_ja_tones + num_en_tones
# language maps
language_id_map = {"ZH": 0, "JP": 1, "EN": 2}
num_languages = len(language_id_map.keys())
language_tone_start_map = {
"ZH": 0,
"JP": num_zh_tones,
"EN": num_zh_tones + num_ja_tones,
}
if __name__ == "__main__":
a = set(zh_symbols)
b = set(en_symbols)
print(sorted(a & b))

View File

@ -0,0 +1,776 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import Tuple
import jieba
from pypinyin import lazy_pinyin
from pypinyin import Style
class ToneSandhi:
def __init__(self):
self.must_neural_tone_words = {
"麻烦",
"麻利",
"鸳鸯",
"高粱",
"骨头",
"骆驼",
"马虎",
"首饰",
"馒头",
"馄饨",
"风筝",
"难为",
"队伍",
"阔气",
"闺女",
"门道",
"锄头",
"铺盖",
"铃铛",
"铁匠",
"钥匙",
"里脊",
"里头",
"部分",
"那么",
"道士",
"造化",
"迷糊",
"连累",
"这么",
"这个",
"运气",
"过去",
"软和",
"转悠",
"踏实",
"跳蚤",
"跟头",
"趔趄",
"财主",
"豆腐",
"讲究",
"记性",
"记号",
"认识",
"规矩",
"见识",
"裁缝",
"补丁",
"衣裳",
"衣服",
"衙门",
"街坊",
"行李",
"行当",
"蛤蟆",
"蘑菇",
"薄荷",
"葫芦",
"葡萄",
"萝卜",
"荸荠",
"苗条",
"苗头",
"苍蝇",
"芝麻",
"舒服",
"舒坦",
"舌头",
"自在",
"膏药",
"脾气",
"脑袋",
"脊梁",
"能耐",
"胳膊",
"胭脂",
"胡萝",
"胡琴",
"胡同",
"聪明",
"耽误",
"耽搁",
"耷拉",
"耳朵",
"老爷",
"老实",
"老婆",
"老头",
"老太",
"翻腾",
"罗嗦",
"罐头",
"编辑",
"结实",
"红火",
"累赘",
"糨糊",
"糊涂",
"精神",
"粮食",
"簸箕",
"篱笆",
"算计",
"算盘",
"答应",
"笤帚",
"笑语",
"笑话",
"窟窿",
"窝囊",
"窗户",
"稳当",
"稀罕",
"称呼",
"秧歌",
"秀气",
"秀才",
"福气",
"祖宗",
"砚台",
"码头",
"石榴",
"石头",
"石匠",
"知识",
"眼睛",
"眯缝",
"眨巴",
"眉毛",
"相声",
"盘算",
"白净",
"痢疾",
"痛快",
"疟疾",
"疙瘩",
"疏忽",
"畜生",
"生意",
"甘蔗",
"琵琶",
"琢磨",
"琉璃",
"玻璃",
"玫瑰",
"玄乎",
"狐狸",
"状元",
"特务",
"牲口",
"牙碜",
"牌楼",
"爽快",
"爱人",
"热闹",
"烧饼",
"烟筒",
"烂糊",
"点心",
"炊帚",
"灯笼",
"火候",
"漂亮",
"滑溜",
"溜达",
"温和",
"清楚",
"消息",
"浪头",
"活泼",
"比方",
"正经",
"欺负",
"模糊",
"槟榔",
"棺材",
"棒槌",
"棉花",
"核桃",
"栅栏",
"柴火",
"架势",
"枕头",
"枇杷",
"机灵",
"本事",
"木头",
"木匠",
"朋友",
"月饼",
"月亮",
"暖和",
"明白",
"时候",
"新鲜",
"故事",
"收拾",
"收成",
"提防",
"挖苦",
"挑剔",
"指甲",
"指头",
"拾掇",
"拳头",
"拨弄",
"招牌",
"招呼",
"抬举",
"护士",
"折腾",
"扫帚",
"打量",
"打算",
"打点",
"打扮",
"打听",
"打发",
"扎实",
"扁担",
"戒指",
"懒得",
"意识",
"意思",
"情形",
"悟性",
"怪物",
"思量",
"怎么",
"念头",
"念叨",
"快活",
"忙活",
"志气",
"心思",
"得罪",
"张罗",
"弟兄",
"开通",
"应酬",
"庄稼",
"干事",
"帮手",
"帐篷",
"希罕",
"师父",
"师傅",
"巴结",
"巴掌",
"差事",
"工夫",
"岁数",
"屁股",
"尾巴",
"少爷",
"小气",
"小伙",
"将就",
"对头",
"对付",
"寡妇",
"家伙",
"客气",
"实在",
"官司",
"学问",
"学生",
"字号",
"嫁妆",
"媳妇",
"媒人",
"婆家",
"娘家",
"委屈",
"姑娘",
"姐夫",
"妯娌",
"妥当",
"妖精",
"奴才",
"女婿",
"头发",
"太阳",
"大爷",
"大方",
"大意",
"大夫",
"多少",
"多么",
"外甥",
"壮实",
"地道",
"地方",
"在乎",
"困难",
"嘴巴",
"嘱咐",
"嘟囔",
"嘀咕",
"喜欢",
"喇嘛",
"喇叭",
"商量",
"唾沫",
"哑巴",
"哈欠",
"哆嗦",
"咳嗽",
"和尚",
"告诉",
"告示",
"含糊",
"吓唬",
"后头",
"名字",
"名堂",
"合同",
"吆喝",
"叫唤",
"口袋",
"厚道",
"厉害",
"千斤",
"包袱",
"包涵",
"匀称",
"勤快",
"动静",
"动弹",
"功夫",
"力气",
"前头",
"刺猬",
"刺激",
"别扭",
"利落",
"利索",
"利害",
"分析",
"出息",
"凑合",
"凉快",
"冷战",
"冤枉",
"冒失",
"养活",
"关系",
"先生",
"兄弟",
"便宜",
"使唤",
"佩服",
"作坊",
"体面",
"位置",
"似的",
"伙计",
"休息",
"什么",
"人家",
"亲戚",
"亲家",
"交情",
"云彩",
"事情",
"买卖",
"主意",
"丫头",
"丧气",
"两口",
"东西",
"东家",
"世故",
"不由",
"不在",
"下水",
"下巴",
"上头",
"上司",
"丈夫",
"丈人",
"一辈",
"那个",
"菩萨",
"父亲",
"母亲",
"咕噜",
"邋遢",
"费用",
"冤家",
"甜头",
"介绍",
"荒唐",
"大人",
"泥鳅",
"幸福",
"熟悉",
"计划",
"扑腾",
"蜡烛",
"姥爷",
"照顾",
"喉咙",
"吉他",
"弄堂",
"蚂蚱",
"凤凰",
"拖沓",
"寒碜",
"糟蹋",
"倒腾",
"报复",
"逻辑",
"盘缠",
"喽啰",
"牢骚",
"咖喱",
"扫把",
"惦记",
}
self.must_not_neural_tone_words = {
"男子",
"女子",
"分子",
"原子",
"量子",
"莲子",
"石子",
"瓜子",
"电子",
"人人",
"虎虎",
}
self.punc = ":,;。?!“”‘’':,;.?!"
# the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
# e.g.
# word: "家里"
# pos: "s"
# finals: ['ia1', 'i3']
def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]:
# reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
for j, item in enumerate(word):
if (
j - 1 >= 0
and item == word[j - 1]
and pos[0] in {"n", "v", "a"}
and word not in self.must_not_neural_tone_words
):
finals[j] = finals[j][:-1] + "5"
ge_idx = word.find("")
if len(word) >= 1 and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
finals[-1] = finals[-1][:-1] + "5"
elif len(word) >= 1 and word[-1] in "的地得":
finals[-1] = finals[-1][:-1] + "5"
# e.g. 走了, 看着, 去过
# elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
# finals[-1] = finals[-1][:-1] + "5"
elif (
len(word) > 1
and word[-1] in "们子"
and pos in {"r", "n"}
and word not in self.must_not_neural_tone_words
):
finals[-1] = finals[-1][:-1] + "5"
# e.g. 桌上, 地下, 家里
elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
finals[-1] = finals[-1][:-1] + "5"
# e.g. 上来, 下去
elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
finals[-1] = finals[-1][:-1] + "5"
# 个做量词
elif (
ge_idx >= 1
and (
word[ge_idx - 1].isnumeric()
or word[ge_idx - 1] in "几有两半多各整每做是"
)
) or word == "":
finals[ge_idx] = finals[ge_idx][:-1] + "5"
else:
if (
word in self.must_neural_tone_words
or word[-2:] in self.must_neural_tone_words
):
finals[-1] = finals[-1][:-1] + "5"
word_list = self._split_word(word)
finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
for i, word in enumerate(word_list):
# conventional neural in Chinese
if (
word in self.must_neural_tone_words
or word[-2:] in self.must_neural_tone_words
):
finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
finals = sum(finals_list, [])
return finals
def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]:
# e.g. 看不懂
if len(word) == 3 and word[1] == "":
finals[1] = finals[1][:-1] + "5"
else:
for i, char in enumerate(word):
# "不" before tone4 should be bu2, e.g. 不怕
if char == "" and i + 1 < len(word) and finals[i + 1][-1] == "4":
finals[i] = finals[i][:-1] + "2"
return finals
def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
# "一" in number sequences, e.g. 一零零, 二一零
if word.find("") != -1 and all(
[item.isnumeric() for item in word if item != ""]
):
return finals
# "一" between reduplication words should be yi5, e.g. 看一看
elif len(word) == 3 and word[1] == "" and word[0] == word[-1]:
finals[1] = finals[1][:-1] + "5"
# when "一" is ordinal word, it should be yi1
elif word.startswith("第一"):
finals[1] = finals[1][:-1] + "1"
else:
for i, char in enumerate(word):
if char == "" and i + 1 < len(word):
# "一" before tone4 should be yi2, e.g. 一段
if finals[i + 1][-1] == "4":
finals[i] = finals[i][:-1] + "2"
# "一" before non-tone4 should be yi4, e.g. 一天
else:
# "一" 后面如果是标点,还读一声
if word[i + 1] not in self.punc:
finals[i] = finals[i][:-1] + "4"
return finals
def _split_word(self, word: str) -> List[str]:
word_list = jieba.cut_for_search(word)
word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
first_subword = word_list[0]
first_begin_idx = word.find(first_subword)
if first_begin_idx == 0:
second_subword = word[len(first_subword) :]
new_word_list = [first_subword, second_subword]
else:
second_subword = word[: -len(first_subword)]
new_word_list = [second_subword, first_subword]
return new_word_list
def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
if len(word) == 2 and self._all_tone_three(finals):
finals[0] = finals[0][:-1] + "2"
elif len(word) == 3:
word_list = self._split_word(word)
if self._all_tone_three(finals):
# disyllabic + monosyllabic, e.g. 蒙古/包
if len(word_list[0]) == 2:
finals[0] = finals[0][:-1] + "2"
finals[1] = finals[1][:-1] + "2"
# monosyllabic + disyllabic, e.g. 纸/老虎
elif len(word_list[0]) == 1:
finals[1] = finals[1][:-1] + "2"
else:
finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
if len(finals_list) == 2:
for i, sub in enumerate(finals_list):
# e.g. 所有/人
if self._all_tone_three(sub) and len(sub) == 2:
finals_list[i][0] = finals_list[i][0][:-1] + "2"
# e.g. 好/喜欢
elif (
i == 1
and not self._all_tone_three(sub)
and finals_list[i][0][-1] == "3"
and finals_list[0][-1][-1] == "3"
):
finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
finals = sum(finals_list, [])
# split idiom into two words who's length is 2
elif len(word) == 4:
finals_list = [finals[:2], finals[2:]]
finals = []
for sub in finals_list:
if self._all_tone_three(sub):
sub[0] = sub[0][:-1] + "2"
finals += sub
return finals
def _all_tone_three(self, finals: List[str]) -> bool:
return all(x[-1] == "3" for x in finals)
# merge "不" and the word behind it
# if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error
def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
new_seg = []
last_word = ""
for word, pos in seg:
if last_word == "":
word = last_word + word
if word != "":
new_seg.append((word, pos))
last_word = word[:]
if last_word == "":
new_seg.append((last_word, "d"))
last_word = ""
return new_seg
# function 1: merge "一" and reduplication words in it's left and right, e.g. "听","一","听" ->"听一听"
# function 2: merge single "一" and the word behind it
# if don't merge, "一" sometimes appears alone according to jieba, which may occur sandhi error
# e.g.
# input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
# output seg: [['听一听', 'v']]
def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
new_seg = [] * len(seg)
# function 1
i = 0
while i < len(seg):
word, pos = seg[i]
if (
i - 1 >= 0
and word == ""
and i + 1 < len(seg)
and seg[i - 1][0] == seg[i + 1][0]
and seg[i - 1][1] == "v"
):
new_seg[i - 1][0] = new_seg[i - 1][0] + "" + new_seg[i - 1][0]
i += 2
else:
if (
i - 2 >= 0
and seg[i - 1][0] == ""
and seg[i - 2][0] == word
and pos == "v"
):
continue
else:
new_seg.append([word, pos])
i += 1
seg = [i for i in new_seg if len(i) > 0]
new_seg = []
# function 2
for i, (word, pos) in enumerate(seg):
if new_seg and new_seg[-1][0] == "":
new_seg[-1][0] = new_seg[-1][0] + word
else:
new_seg.append([word, pos])
return new_seg
# the first and the second words are all_tone_three
def _merge_continuous_three_tones(
self, seg: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
new_seg = []
sub_finals_list = [
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
for (word, pos) in seg
]
assert len(sub_finals_list) == len(seg)
merge_last = [False] * len(seg)
for i, (word, pos) in enumerate(seg):
if (
i - 1 >= 0
and self._all_tone_three(sub_finals_list[i - 1])
and self._all_tone_three(sub_finals_list[i])
and not merge_last[i - 1]
):
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
if (
not self._is_reduplication(seg[i - 1][0])
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
):
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
merge_last[i] = True
else:
new_seg.append([word, pos])
else:
new_seg.append([word, pos])
return new_seg
def _is_reduplication(self, word: str) -> bool:
return len(word) == 2 and word[0] == word[1]
# the last char of first word and the first char of second word is tone_three
def _merge_continuous_three_tones_2(
self, seg: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
new_seg = []
sub_finals_list = [
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
for (word, pos) in seg
]
assert len(sub_finals_list) == len(seg)
merge_last = [False] * len(seg)
for i, (word, pos) in enumerate(seg):
if (
i - 1 >= 0
and sub_finals_list[i - 1][-1][-1] == "3"
and sub_finals_list[i][0][-1] == "3"
and not merge_last[i - 1]
):
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
if (
not self._is_reduplication(seg[i - 1][0])
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
):
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
merge_last[i] = True
else:
new_seg.append([word, pos])
else:
new_seg.append([word, pos])
return new_seg
def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
new_seg = []
for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and word == "" and seg[i - 1][0] != "#":
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
else:
new_seg.append([word, pos])
return new_seg
def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
new_seg = []
for i, (word, pos) in enumerate(seg):
if new_seg and word == new_seg[-1][0]:
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
else:
new_seg.append([word, pos])
return new_seg
def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
seg = self._merge_bu(seg)
try:
seg = self._merge_yi(seg)
except:
print("_merge_yi failed")
seg = self._merge_reduplication(seg)
seg = self._merge_continuous_three_tones(seg)
seg = self._merge_continuous_three_tones_2(seg)
seg = self._merge_er(seg)
return seg
def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]:
finals = self._bu_sandhi(word, finals)
finals = self._yi_sandhi(word, finals)
finals = self._neural_sandhi(word, pos, finals)
finals = self._three_sandhi(word, finals)
return finals

View File

@ -0,0 +1,3 @@
"""
工具包
"""

View File

@ -0,0 +1,197 @@
import regex as re
try:
from ..config import config
LANGUAGE_IDENTIFICATION_LIBRARY = (
config.webui_config.language_identification_library
)
except:
LANGUAGE_IDENTIFICATION_LIBRARY = "langid"
module = LANGUAGE_IDENTIFICATION_LIBRARY.lower()
langid_languages = [
"af",
"am",
"an",
"ar",
"as",
"az",
"be",
"bg",
"bn",
"br",
"bs",
"ca",
"cs",
"cy",
"da",
"de",
"dz",
"el",
"en",
"eo",
"es",
"et",
"eu",
"fa",
"fi",
"fo",
"fr",
"ga",
"gl",
"gu",
"he",
"hi",
"hr",
"ht",
"hu",
"hy",
"id",
"is",
"it",
"ja",
"jv",
"ka",
"kk",
"km",
"kn",
"ko",
"ku",
"ky",
"la",
"lb",
"lo",
"lt",
"lv",
"mg",
"mk",
"ml",
"mn",
"mr",
"ms",
"mt",
"nb",
"ne",
"nl",
"nn",
"no",
"oc",
"or",
"pa",
"pl",
"ps",
"pt",
"qu",
"ro",
"ru",
"rw",
"se",
"si",
"sk",
"sl",
"sq",
"sr",
"sv",
"sw",
"ta",
"te",
"th",
"tl",
"tr",
"ug",
"uk",
"ur",
"vi",
"vo",
"wa",
"xh",
"zh",
"zu",
]
def classify_language(text: str, target_languages: list = None) -> str:
if module == "fastlid" or module == "fasttext":
from fastlid import fastlid, supported_langs
classifier = fastlid
if target_languages != None:
target_languages = [
lang for lang in target_languages if lang in supported_langs
]
fastlid.set_languages = target_languages
elif module == "langid":
import langid
classifier = langid.classify
if target_languages != None:
target_languages = [
lang for lang in target_languages if lang in langid_languages
]
langid.set_languages(target_languages)
else:
raise ValueError(f"Wrong module {module}")
lang = classifier(text)[0]
return lang
def classify_zh_ja(text: str) -> str:
for idx, char in enumerate(text):
unicode_val = ord(char)
# 检测日语字符
if 0x3040 <= unicode_val <= 0x309F or 0x30A0 <= unicode_val <= 0x30FF:
return "ja"
# 检测汉字字符
if 0x4E00 <= unicode_val <= 0x9FFF:
# 检查周围的字符
next_char = text[idx + 1] if idx + 1 < len(text) else None
if next_char and (
0x3040 <= ord(next_char) <= 0x309F or 0x30A0 <= ord(next_char) <= 0x30FF
):
return "ja"
return "zh"
def split_alpha_nonalpha(text, mode=1):
if mode == 1:
pattern = r"(?<=[\u4e00-\u9fff\u3040-\u30FF\d\s])(?=[\p{Latin}])|(?<=[\p{Latin}\s])(?=[\u4e00-\u9fff\u3040-\u30FF\d])"
elif mode == 2:
pattern = r"(?<=[\u4e00-\u9fff\u3040-\u30FF\s])(?=[\p{Latin}\d])|(?<=[\p{Latin}\d\s])(?=[\u4e00-\u9fff\u3040-\u30FF])"
else:
raise ValueError("Invalid mode. Supported modes are 1 and 2.")
return re.split(pattern, text)
if __name__ == "__main__":
text = "这是一个测试文本"
print(classify_language(text))
print(classify_zh_ja(text)) # "zh"
text = "これはテストテキストです"
print(classify_language(text))
print(classify_zh_ja(text)) # "ja"
text = "vits和Bert-VITS2是tts模型。花费3days.花费3天。Take 3 days"
print(split_alpha_nonalpha(text, mode=1))
# output: ['vits', '和', 'Bert-VITS', '2是', 'tts', '模型。花费3', 'days.花费3天。Take 3 days']
print(split_alpha_nonalpha(text, mode=2))
# output: ['vits', '和', 'Bert-VITS2', '是', 'tts', '模型。花费', '3days.花费', '3', '天。Take 3 days']
text = "vits 和 Bert-VITS2 是 tts 模型。花费3days.花费3天。Take 3 days"
print(split_alpha_nonalpha(text, mode=1))
# output: ['vits ', '和 ', 'Bert-VITS', '2 ', '是 ', 'tts ', '模型。花费3', 'days.花费3天。Take ', '3 ', 'days']
text = "vits 和 Bert-VITS2 是 tts 模型。花费3days.花费3天。Take 3 days"
print(split_alpha_nonalpha(text, mode=2))
# output: ['vits ', '和 ', 'Bert-VITS2 ', '是 ', 'tts ', '模型。花费', '3days.花费', '3', '天。Take ', '3 ', 'days']

View File

@ -0,0 +1,17 @@
"""
logger封装
"""
from loguru import logger
import sys
# 移除所有默认的处理器
logger.remove()
# 自定义格式并添加到标准输出
log_format = (
"<g>{time:MM-DD HH:mm:ss}</g> <lvl>{level:<9}</lvl>| {file}:{line} | {message}"
)
logger.add(sys.stdout, format=log_format, backtrace=True, diagnose=True)

View File

@ -0,0 +1,173 @@
import logging
import regex as re
from ..tools.classify_language import classify_language, split_alpha_nonalpha
def check_is_none(item) -> bool:
"""none -> True, not none -> False"""
return (
item is None
or (isinstance(item, str) and str(item).isspace())
or str(item) == ""
)
def markup_language(text: str, target_languages: list = None) -> str:
pattern = (
r"[\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\<\>\=\?\@\[\]\{\}\\\\\^\_\`"
r"\!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」"
r"『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘\'\\\\\‟…‧﹏.]+"
)
sentences = re.split(pattern, text)
pre_lang = ""
p = 0
if target_languages is not None:
sorted_target_languages = sorted(target_languages)
if sorted_target_languages in [["en", "zh"], ["en", "ja"], ["en", "ja", "zh"]]:
new_sentences = []
for sentence in sentences:
new_sentences.extend(split_alpha_nonalpha(sentence))
sentences = new_sentences
for sentence in sentences:
if check_is_none(sentence):
continue
lang = classify_language(sentence, target_languages)
if pre_lang == "":
text = text[:p] + text[p:].replace(
sentence, f"[{lang.upper()}]{sentence}", 1
)
p += len(f"[{lang.upper()}]")
elif pre_lang != lang:
text = text[:p] + text[p:].replace(
sentence, f"[{pre_lang.upper()}][{lang.upper()}]{sentence}", 1
)
p += len(f"[{pre_lang.upper()}][{lang.upper()}]")
pre_lang = lang
p += text[p:].index(sentence) + len(sentence)
text += f"[{pre_lang.upper()}]"
return text
def split_by_language(text: str, target_languages: list = None) -> list:
pattern = (
r"[\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\<\>\=\?\@\[\]\{\}\\\\\^\_\`"
r"\\。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」"
r"『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘\'\\\\\‟…‧﹏.]+"
)
sentences = re.split(pattern, text)
pre_lang = ""
start = 0
end = 0
sentences_list = []
if target_languages is not None:
sorted_target_languages = sorted(target_languages)
if sorted_target_languages in [["en", "zh"], ["en", "ja"], ["en", "ja", "zh"]]:
new_sentences = []
for sentence in sentences:
new_sentences.extend(split_alpha_nonalpha(sentence))
sentences = new_sentences
for sentence in sentences:
if check_is_none(sentence):
continue
lang = classify_language(sentence, target_languages)
end += text[end:].index(sentence)
if pre_lang != "" and pre_lang != lang:
sentences_list.append((text[start:end], pre_lang))
start = end
end += len(sentence)
pre_lang = lang
sentences_list.append((text[start:], pre_lang))
return sentences_list
def sentence_split(text: str, max: int) -> list:
pattern = r"[!(),—+\-.:;??。,、;:]+"
sentences = re.split(pattern, text)
discarded_chars = re.findall(pattern, text)
sentences_list, count, p = [], 0, 0
# 按被分割的符号遍历
for i, discarded_chars in enumerate(discarded_chars):
count += len(sentences[i]) + len(discarded_chars)
if count >= max:
sentences_list.append(text[p : p + count].strip())
p += count
count = 0
# 加入最后剩余的文本
if p < len(text):
sentences_list.append(text[p:])
return sentences_list
def sentence_split_and_markup(text, max=50, lang="auto", speaker_lang=None):
# 如果该speaker只支持一种语言
if speaker_lang is not None and len(speaker_lang) == 1:
if lang.upper() not in ["AUTO", "MIX"] and lang.lower() != speaker_lang[0]:
logging.debug(
f'lang "{lang}" is not in speaker_lang {speaker_lang},automatically set lang={speaker_lang[0]}'
)
lang = speaker_lang[0]
sentences_list = []
if lang.upper() != "MIX":
if max <= 0:
sentences_list.append(
markup_language(text, speaker_lang)
if lang.upper() == "AUTO"
else f"[{lang.upper()}]{text}[{lang.upper()}]"
)
else:
for i in sentence_split(text, max):
if check_is_none(i):
continue
sentences_list.append(
markup_language(i, speaker_lang)
if lang.upper() == "AUTO"
else f"[{lang.upper()}]{i}[{lang.upper()}]"
)
else:
sentences_list.append(text)
for i in sentences_list:
logging.debug(i)
return sentences_list
if __name__ == "__main__":
text = "这几天心里颇不宁静。今晚在院子里坐着乘凉,忽然想起日日走过的荷塘,在这满月的光里,总该另有一番样子吧。月亮渐渐地升高了,墙外马路上孩子们的欢笑,已经听不见了;妻在屋里拍着闰儿,迷迷糊糊地哼着眠歌。我悄悄地披了大衫,带上门出去。"
print(markup_language(text, target_languages=None))
print(sentence_split(text, max=50))
print(sentence_split_and_markup(text, max=50, lang="auto", speaker_lang=None))
text = "你好,这是一段用来测试自动标注的文本。こんにちは,これは自動ラベリングのテスト用テキストです.Hello, this is a piece of text to test autotagging.你好今天我们要介绍VITS项目其重点是使用了GAN Duration predictor和transformer flow,并且接入了Bert模型来提升韵律。Bert embedding会在稍后介绍。"
print(split_by_language(text, ["zh", "ja", "en"]))
text = "vits和Bert-VITS2是tts模型。花费3days.花费3天。Take 3 days"
print(split_by_language(text, ["zh", "ja", "en"]))
# output: [('vits', 'en'), ('和', 'ja'), ('Bert-VITS', 'en'), ('2是', 'zh'), ('tts', 'en'), ('模型。花费3', 'zh'), ('days.', 'en'), ('花费3天。', 'zh'), ('Take 3 days', 'en')]
print(split_by_language(text, ["zh", "en"]))
# output: [('vits', 'en'), ('和', 'zh'), ('Bert-VITS', 'en'), ('2是', 'zh'), ('tts', 'en'), ('模型。花费3', 'zh'), ('days.', 'en'), ('花费3天。', 'zh'), ('Take 3 days', 'en')]
text = "vits 和 Bert-VITS2 是 tts 模型。花费 3 days. 花费 3天。Take 3 days"
print(split_by_language(text, ["zh", "en"]))
# output: [('vits ', 'en'), ('和 ', 'zh'), ('Bert-VITS2 ', 'en'), ('是 ', 'zh'), ('tts ', 'en'), ('模型。花费 ', 'zh'), ('3 days. ', 'en'), ('花费 3天。', 'zh'), ('Take 3 days', 'en')]

View File

@ -0,0 +1,62 @@
"""
翻译api
"""
from ..config import config
import random
import hashlib
import requests
def translate(Sentence: str, to_Language: str = "jp", from_Language: str = ""):
"""
:param Sentence: 待翻译语句
:param from_Language: 待翻译语句语言
:param to_Language: 目标语言
:return: 翻译后语句 出错时返回None
常见语言代码中文 zh 英语 en 日语 jp
"""
appid = config.translate_config.app_key
key = config.translate_config.secret_key
if appid == "" or key == "":
return "请开发者在config.yml中配置app_key与secret_key"
url = "https://fanyi-api.baidu.com/api/trans/vip/translate"
texts = Sentence.splitlines()
outTexts = []
for t in texts:
if t != "":
# 签名计算 参考文档 https://api.fanyi.baidu.com/product/113
salt = str(random.randint(1, 100000))
signString = appid + t + salt + key
hs = hashlib.md5()
hs.update(signString.encode("utf-8"))
signString = hs.hexdigest()
if from_Language == "":
from_Language = "auto"
headers = {"Content-Type": "application/x-www-form-urlencoded"}
payload = {
"q": t,
"from": from_Language,
"to": to_Language,
"appid": appid,
"salt": salt,
"sign": signString,
}
# 发送请求
try:
response = requests.post(
url=url, data=payload, headers=headers, timeout=3
)
response = response.json()
if "trans_result" in response.keys():
result = response["trans_result"][0]
if "dst" in result.keys():
dst = result["dst"]
outTexts.append(dst)
except Exception:
return Sentence
else:
outTexts.append(t)
return "\n".join(outTexts)

View File

@ -0,0 +1,209 @@
import torch
from torch.nn import functional as F
import numpy as np
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
)
return outputs, logabsdet
def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
(
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
)
return outputs, logabsdet
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet

460
utils/bert_vits2/utils.py Normal file
View File

@ -0,0 +1,460 @@
import os
import glob
import argparse
import logging
import json
import shutil
import subprocess
import numpy as np
from huggingface_hub import hf_hub_download
from scipy.io.wavfile import read
import torch
import re
MATPLOTLIB_FLAG = False
logger = logging.getLogger(__name__)
def download_emo_models(mirror, repo_id, model_name):
if mirror == "openi":
import openi
openi.model.download_model(
"Stardust_minus/Bert-VITS2",
repo_id.split("/")[-1],
"./emotional",
)
else:
hf_hub_download(
repo_id,
"pytorch_model.bin",
local_dir=model_name,
local_dir_use_symlinks=False,
)
def download_checkpoint(
dir_path, repo_config, token=None, regex="G_*.pth", mirror="openi"
):
repo_id = repo_config["repo_id"]
f_list = glob.glob(os.path.join(dir_path, regex))
if f_list:
print("Use existed model, skip downloading.")
return
if mirror.lower() == "openi":
import openi
kwargs = {"token": token} if token else {}
openi.login(**kwargs)
model_image = repo_config["model_image"]
openi.model.download_model(repo_id, model_image, dir_path)
fs = glob.glob(os.path.join(dir_path, model_image, "*.pth"))
for file in fs:
shutil.move(file, dir_path)
shutil.rmtree(os.path.join(dir_path, model_image))
else:
for file in ["DUR_0.pth", "D_0.pth", "G_0.pth"]:
hf_hub_download(
repo_id, file, local_dir=dir_path, local_dir_use_symlinks=False
)
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
iteration = checkpoint_dict["iteration"]
learning_rate = checkpoint_dict["learning_rate"]
if (
optimizer is not None
and not skip_optimizer
and checkpoint_dict["optimizer"] is not None
):
optimizer.load_state_dict(checkpoint_dict["optimizer"])
elif optimizer is None and not skip_optimizer:
# else: Disable this line if Infer and resume checkpoint,then enable the line upper
new_opt_dict = optimizer.state_dict()
new_opt_dict_params = new_opt_dict["param_groups"][0]["params"]
new_opt_dict["param_groups"] = checkpoint_dict["optimizer"]["param_groups"]
new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params
optimizer.load_state_dict(new_opt_dict)
saved_state_dict = checkpoint_dict["model"]
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
new_state_dict = {}
for k, v in state_dict.items():
try:
# assert "emb_g" not in k
new_state_dict[k] = saved_state_dict[k]
assert saved_state_dict[k].shape == v.shape, (
saved_state_dict[k].shape,
v.shape,
)
except:
# For upgrading from the old version
if "ja_bert_proj" in k:
v = torch.zeros_like(v)
logger.warn(
f"Seems you are using the old version of the model, the {k} is automatically set to zero for backward compatibility"
)
else:
logger.error(f"{k} is not in the checkpoint")
new_state_dict[k] = v
if hasattr(model, "module"):
model.module.load_state_dict(new_state_dict, strict=False)
else:
model.load_state_dict(new_state_dict, strict=False)
logger.info(
"Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
)
return model, optimizer, learning_rate, iteration
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
logger.info(
"Saving model and optimizer state at iteration {} to {}".format(
iteration, checkpoint_path
)
)
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(
{
"model": state_dict,
"iteration": iteration,
"optimizer": optimizer.state_dict(),
"learning_rate": learning_rate,
},
checkpoint_path,
)
def summarize(
writer,
global_step,
scalars={},
histograms={},
images={},
audios={},
audio_sampling_rate=22050,
):
for k, v in scalars.items():
writer.add_scalar(k, v, global_step)
for k, v in histograms.items():
writer.add_histogram(k, v, global_step)
for k, v in images.items():
writer.add_image(k, v, global_step, dataformats="HWC")
for k, v in audios.items():
writer.add_audio(k, v, global_step, audio_sampling_rate)
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
f_list = glob.glob(os.path.join(dir_path, regex))
f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
x = f_list[-1]
return x
def plot_spectrogram_to_numpy(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def plot_alignment_to_numpy(alignment, info=None):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(
alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
)
fig.colorbar(im, ax=ax)
xlabel = "Decoder timestep"
if info is not None:
xlabel += "\n\n" + info
plt.xlabel(xlabel)
plt.ylabel("Encoder timestep")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def load_wav_to_torch(full_path):
sampling_rate, data = read(full_path)
return torch.FloatTensor(data.astype(np.float32)), sampling_rate
def load_filepaths_and_text(filename, split="|"):
with open(filename, encoding="utf-8") as f:
filepaths_and_text = [line.strip().split(split) for line in f]
return filepaths_and_text
def get_hparams(init=True):
parser = argparse.ArgumentParser()
parser.add_argument(
"-c",
"--config",
type=str,
default="./configs/base.json",
help="JSON file for configuration",
)
parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
args = parser.parse_args()
model_dir = os.path.join("./logs", args.model)
if not os.path.exists(model_dir):
os.makedirs(model_dir)
config_path = args.config
config_save_path = os.path.join(model_dir, "config.json")
if init:
with open(config_path, "r", encoding="utf-8") as f:
data = f.read()
with open(config_save_path, "w", encoding="utf-8") as f:
f.write(data)
else:
with open(config_save_path, "r", vencoding="utf-8") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
hparams.model_dir = model_dir
return hparams
def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
"""Freeing up space by deleting saved ckpts
Arguments:
path_to_models -- Path to the model directory
n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
sort_by_time -- True -> chronologically delete ckpts
False -> lexicographically delete ckpts
"""
import re
ckpts_files = [
f
for f in os.listdir(path_to_models)
if os.path.isfile(os.path.join(path_to_models, f))
]
def name_key(_f):
return int(re.compile("._(\\d+)\\.pth").match(_f).group(1))
def time_key(_f):
return os.path.getmtime(os.path.join(path_to_models, _f))
sort_key = time_key if sort_by_time else name_key
def x_sorted(_x):
return sorted(
[f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
key=sort_key,
)
to_del = [
os.path.join(path_to_models, fn)
for fn in (
x_sorted("G")[:-n_ckpts_to_keep]
+ x_sorted("D")[:-n_ckpts_to_keep]
+ x_sorted("WD")[:-n_ckpts_to_keep]
)
]
def del_info(fn):
return logger.info(f".. Free up space by deleting ckpt {fn}")
def del_routine(x):
return [os.remove(x), del_info(x)]
[del_routine(fn) for fn in to_del]
def get_hparams_from_dir(model_dir):
config_save_path = os.path.join(model_dir, "config.json")
with open(config_save_path, "r", encoding="utf-8") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
hparams.model_dir = model_dir
return hparams
def get_hparams_from_file(config_path):
# print("config_path: ", config_path)
with open(config_path, "r", encoding="utf-8") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
return hparams
def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warn(
"{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir
)
)
return
cur_hash = subprocess.getoutput("git rev-parse HEAD")
path = os.path.join(model_dir, "githash")
if os.path.exists(path):
saved_hash = open(path).read()
if saved_hash != cur_hash:
logger.warn(
"git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8], cur_hash[:8]
)
)
else:
open(path, "w").write(cur_hash)
def get_logger(model_dir, filename="train.log"):
global logger
logger = logging.getLogger(os.path.basename(model_dir))
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
if not os.path.exists(model_dir):
os.makedirs(model_dir)
h = logging.FileHandler(os.path.join(model_dir, filename))
h.setLevel(logging.DEBUG)
h.setFormatter(formatter)
logger.addHandler(h)
return logger
class HParams:
def __init__(self, **kwargs):
for k, v in kwargs.items():
if type(v) == dict:
v = HParams(**v)
self[k] = v
def keys(self):
return self.__dict__.keys()
def items(self):
return self.__dict__.items()
def values(self):
return self.__dict__.values()
def __len__(self):
return len(self.__dict__)
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, value):
return setattr(self, key, value)
def __contains__(self, key):
return key in self.__dict__
def __repr__(self):
return self.__dict__.__repr__()
def load_model(model_path, config_path):
hps = get_hparams_from_file(config_path)
net = SynthesizerTrn(
# len(symbols),
108,
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).to("cpu")
_ = net.eval()
_ = load_checkpoint(model_path, net, None, skip_optimizer=True)
return net
def mix_model(
network1, network2, output_path, voice_ratio=(0.5, 0.5), tone_ratio=(0.5, 0.5)
):
if hasattr(network1, "module"):
state_dict1 = network1.module.state_dict()
state_dict2 = network2.module.state_dict()
else:
state_dict1 = network1.state_dict()
state_dict2 = network2.state_dict()
for k in state_dict1.keys():
if k not in state_dict2.keys():
continue
if "enc_p" in k:
state_dict1[k] = (
state_dict1[k].clone() * tone_ratio[0]
+ state_dict2[k].clone() * tone_ratio[1]
)
else:
state_dict1[k] = (
state_dict1[k].clone() * voice_ratio[0]
+ state_dict2[k].clone() * voice_ratio[1]
)
for k in state_dict2.keys():
if k not in state_dict1.keys():
state_dict1[k] = state_dict2[k].clone()
torch.save(
{"model": state_dict1, "iteration": 0, "optimizer": None, "learning_rate": 0},
output_path,
)
def get_steps(model_path):
matches = re.findall(r"\d+", model_path)
return matches[-1] if matches else None

463
utils/bert_vits2_utils.py Normal file
View File

@ -0,0 +1,463 @@
import gc
import os
import numpy as np
import torch
from torch import LongTensor
from typing import Optional
import soundfile as sf
import logging
import gradio as gr
import librosa
# bert_vits2
from .bert_vits2 import utils
from .bert_vits2.infer import get_net_g, latest_version, infer_multilang, infer
from .bert_vits2.config import config
from .bert_vits2 import re_matching
from .bert_vits2.tools.sentence import split_by_language
logger = logging.getLogger(__name__)
class TextToSpeech:
def __init__(self,
device='cuda',
):
self.device = device = torch.device(device)
if config.webui_config.debug:
logger.info("Enable DEBUG")
hps = utils.get_hparams_from_file(config.webui_config.config_path)
self.hps = hps
# 若config.json中未指定版本则默认为最新版本
version = hps.version if hasattr(hps, "version") else latest_version
self.version = version
net_g = get_net_g(
model_path=config.webui_config.model, version=version, device=device, hps=hps
)
self.net_g = net_g
self.speaker_ids = speaker_ids = hps.data.spk2id
self.speakers = speakers = list(speaker_ids.keys())
self.speaker = speakers[0]
self.languages = languages = ["ZH", "JP", "EN", "mix", "auto"]
def free_up_memory(self):
# Prior inference run might have large variables not cleaned up due to exception during the run.
# Free up as much memory as possible to allow this run to be successful.
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def process_mix(self, slice):
_speaker = slice.pop()
_text, _lang = [], []
for lang, content in slice:
content = content.split("|")
content = [part for part in content if part != ""]
if len(content) == 0:
continue
if len(_text) == 0:
_text = [[part] for part in content]
_lang = [[lang] for part in content]
else:
_text[-1].append(content[0])
_lang[-1].append(lang)
if len(content) > 1:
_text += [[part] for part in content[1:]]
_lang += [[lang] for part in content[1:]]
return _text, _lang, _speaker
def process_auto(self, text):
_text, _lang = [], []
for slice in text.split("|"):
if slice == "":
continue
temp_text, temp_lang = [], []
sentences_list = split_by_language(slice, target_languages=["zh", "ja", "en"])
for sentence, lang in sentences_list:
if sentence == "":
continue
temp_text.append(sentence)
if lang == "ja":
lang = "jp"
temp_lang.append(lang.upper())
_text.append(temp_text)
_lang.append(temp_lang)
return _text, _lang
def generate_audio(
self,
slices,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
speaker,
language,
reference_audio,
emotion,
style_text,
style_weight,
skip_start=False,
skip_end=False,
):
audio_list = []
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
self.free_up_memory()
with torch.no_grad():
for idx, piece in enumerate(slices):
skip_start = idx != 0
skip_end = idx != len(slices) - 1
audio = infer(
piece,
reference_audio=reference_audio,
emotion=emotion,
sdp_ratio=sdp_ratio,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
length_scale=length_scale,
sid=speaker,
language=language,
hps=self.hps,
net_g=self.net_g,
device=self.device,
skip_start=skip_start,
skip_end=skip_end,
style_text=style_text,
style_weight=style_weight,
)
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
audio_list.append(audio16bit)
return audio_list
def generate_audio_multilang(
self,
slices,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
speaker,
language,
reference_audio,
emotion,
skip_start=False,
skip_end=False,
en_ratio=1.0
):
audio_list = []
# silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
self.free_up_memory()
with torch.no_grad():
for idx, piece in enumerate(slices):
skip_start = idx != 0
skip_end = idx != len(slices) - 1
audio = infer_multilang(
piece,
reference_audio=reference_audio,
emotion=emotion,
sdp_ratio=sdp_ratio,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
length_scale=length_scale,
sid=speaker,
language=language[idx],
hps=self.hps,
net_g=self.net_g,
device=self.device,
skip_start=skip_start,
skip_end=skip_end,
en_ratio=en_ratio
)
audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
audio_list.append(audio16bit)
return audio_list
def process_text(self,
text: str,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
reference_audio,
emotion,
style_text=None,
style_weight=0,
en_ratio=1.0
):
hps = self.hps
audio_list = []
if language == "mix":
bool_valid, str_valid = re_matching.validate_text(text)
if not bool_valid:
return str_valid, (
hps.data.sampling_rate,
np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
)
for slice in re_matching.text_matching(text):
_text, _lang, _speaker = self.process_mix(slice)
if _speaker is None:
continue
print(f"Text: {_text}\nLang: {_lang}")
audio_list.extend(
self.generate_audio_multilang(
_text,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
_speaker,
_lang,
reference_audio,
emotion,
en_ratio=en_ratio
)
)
elif language.lower() == "auto":
_text, _lang = self.process_auto(text)
print(f"Text: {_text}\nLang: {_lang}")
audio_list.extend(
self.generate_audio_multilang(
_text,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
speaker,
_lang,
reference_audio,
emotion,
en_ratio=en_ratio
)
)
else:
audio_list.extend(
self.generate_audio(
text.split("|"),
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
speaker,
language,
reference_audio,
emotion,
style_text,
style_weight,
)
)
return audio_list
def tts_split(
self,
text: str,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
cut_by_sent,
interval_between_para,
interval_between_sent,
reference_audio,
emotion,
style_text,
style_weight,
en_ratio
):
while text.find("\n\n") != -1:
text = text.replace("\n\n", "\n")
text = text.replace("|", "")
para_list = re_matching.cut_para(text)
para_list = [p for p in para_list if p != ""]
audio_list = []
for p in para_list:
if not cut_by_sent:
audio_list += self.process_text(
p,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
reference_audio,
emotion,
style_text,
style_weight,
en_ratio
)
silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16)
audio_list.append(silence)
else:
audio_list_sent = []
sent_list = re_matching.cut_sent(p)
sent_list = [s for s in sent_list if s != ""]
for s in sent_list:
audio_list_sent += self.process_text(
s,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
reference_audio,
emotion,
style_text,
style_weight,
en_ratio
)
silence = np.zeros((int)(44100 * interval_between_sent))
audio_list_sent.append(silence)
if (interval_between_para - interval_between_sent) > 0:
silence = np.zeros(
(int)(44100 * (interval_between_para - interval_between_sent))
)
audio_list_sent.append(silence)
audio16bit = gr.processing_utils.convert_to_16_bit_wav(
np.concatenate(audio_list_sent)
) # 对完整句子做音量归一
audio_list.append(audio16bit)
audio_concat = np.concatenate(audio_list)
return ("Success", (self.hps.data.sampling_rate, audio_concat))
def tts_fn(
self,
text: str,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
reference_audio,
emotion,
prompt_mode,
style_text=None,
style_weight=0,
):
if style_text == "":
style_text = None
if prompt_mode == "Audio prompt":
if reference_audio == None:
return ("Invalid audio prompt", None)
else:
reference_audio = self.load_audio(reference_audio)[1]
else:
reference_audio = None
audio_list = self.process_text(
text,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
reference_audio,
emotion,
style_text,
style_weight,
)
audio_concat = np.concatenate(audio_list)
return "Success", (self.hps.data.sampling_rate, audio_concat)
def load_audio(self, path):
audio, sr = librosa.load(path, 48000)
# audio = librosa.resample(audio, 44100, 48000)
return sr, audio
def format_utils(self, text, speaker):
_text, _lang = self.process_auto(text)
res = f"[{speaker}]"
for lang_s, content_s in zip(_lang, _text):
for lang, content in zip(lang_s, content_s):
# res += f"<{lang.lower()}>{content}"
# 部分中文会被识别成日文,强转成中文
lang = lang.lower().replace("jp", "zh")
res += f"<{lang}>{content}"
res += "|"
return "mix", res[:-1]
def synthesize(self,
text,
speaker_idx=0, # self.speakers 的 index指定说话
sdp_ratio=0.5,
noise_scale=0.6,
noise_scale_w=0.9,
length_scale=1.0, # 越大语速越慢
language="mix", # ["ZH", "EN", "mix"] 三选一
opt_cut_by_send=False, # 按句切分 在按段落切分的基础上再按句子切分文本
interval_between_para=1.0, # 段间停顿(秒),需要大于句间停顿才有效
interval_between_sent=0.2, # 句间停顿(秒),勾选按句切分才生效
audio_prompt=None,
text_prompt="",
prompt_mode="Text prompts",
style_text="", # "使用辅助文本的语意来辅助生成对话(语言保持与主文本相同)\n\n"
# "**注意**:不要使用**指令式文本**(如:开心),要使用**带有强烈情感的文本**(如:我好快乐!!!)\n\n"
# "效果较不明确,留空即为不使用该功能"
style_weight=0.7, # "主文本和辅助文本的bert混合比率0表示仅主文本1表示仅辅助文本
en_ratio=1.0 # 中英混合时,英文速度控制,越大英文速度越慢
):
"""
return: audio, sample_rate
"""
speaker = self.speakers[speaker_idx]
if language == "mix":
language, text = self.format_utils(text, speaker)
text_output, audio_output = self.tts_split(
text,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
opt_cut_by_send,
interval_between_para,
interval_between_sent,
audio_prompt,
text_prompt,
style_text,
style_weight,
en_ratio
)
else:
text_output, audio_output = self.tts_fn(
text,
speaker,
sdp_ratio,
noise_scale,
noise_scale_w,
length_scale,
language,
audio_prompt,
text_prompt,
prompt_mode,
style_text,
style_weight
)
# return text_output, audio_output
return audio_output[1], audio_output[0]
def print_speakers_info(self):
for i, speaker in enumerate(self.speakers):
print(f"id: {i}, speaker: {speaker}")