import os
import re
import warnings
from glob import glob
import hashlib
from tqdm.auto import tqdm
import soundfile as sf
import numpy as np
import torch
from typing import Optional, Union
# melo
from melo.api import TTS
from melo.utils import get_text_for_tts_infer
# openvoice
from .openvoice import se_extractor
from .openvoice.api import ToneColorConverter
from .openvoice.mel_processing import spectrogram_torch
# torchaudio
import torchaudio.functional as F
# Directory holding the BASE SPEAKER embeddings (source_se)
SOURCE_SE_DIR = r"D:\python\OpenVoice\checkpoints_v2\base_speakers\ses"

# Directory for cached / intermediate files
CACHE_PATH = r"D:\python\OpenVoice\processed"

OPENVOICE_BASE_TTS = {
    "model_type": "open_voice_base_tts",
    # language to synthesize
    "language": "ZH",
}

OPENVOICE_TONE_COLOR_CONVERTER = {
    "model_type": "open_voice_converter",
    # path to the converter checkpoints
    "converter_path": r"D:\python\OpenVoice\checkpoints_v2\converter",
}

class TextToSpeech:
    def __init__(self,
                 use_tone_convert=True,
                 device="cuda",
                 debug: bool = False,
                 ):
        self.debug = debug
        self.device = device
        self.use_tone_convert = use_tone_convert
        self.source_se = None
        self.target_se = None

        self.initialize_base_tts(**OPENVOICE_BASE_TTS)
        if self.debug:
            print(self.use_tone_convert)
        if self.use_tone_convert:
            self.initialize_tone_color_converter(**OPENVOICE_TONE_COLOR_CONVERTER)
            self.initialize_source_se()

    def initialize_tone_color_converter(self, **kwargs):
        """
        Initialize the tone color converter.
        """
        model_type = kwargs.pop('model_type')
        self.tone_color_converter_model_type = model_type
        if model_type == 'open_voice_converter':
            # load the converter model
            converter_path = kwargs.pop('converter_path')
            self.tone_color_converter = ToneColorConverter(f'{converter_path}/config.json', self.device)
            self.tone_color_converter.load_ckpt(f'{converter_path}/checkpoint.pth')
            if self.debug:
                print("load tone color converter successfully!")
        else:
            raise NotImplementedError(f"only [open_voice_converter] model type expected, but got [{model_type}].")

    def initialize_base_tts(self, **kwargs):
        """
        Initialize the base TTS model.
        """
        model_type = kwargs.pop('model_type')
        self.base_tts_model_type = model_type
        if model_type == "open_voice_base_tts":
            language = kwargs.pop('language')
            self.base_tts_model = TTS(language=language, device=self.device)
            speaker_ids = self.base_tts_model.hps.data.spk2id
            flag = False
            for speaker_key in speaker_ids.keys():
                if flag:
                    warnings.warn(f'loaded model has more than one speaker, only the first speaker is used. The input speaker ids are {speaker_ids}')
                    break
                self.speaker_id = speaker_ids[speaker_key]
                self.speaker_key = speaker_key.lower().replace('_', '-')
                flag = True
            if self.debug:
                print("load base tts model successfully!")
            # the BERT model is loaded lazily on the first TTS call, so warm it up here
            self._base_tts("初始化bert模型。")
        else:
            raise NotImplementedError(f"only [open_voice_base_tts] model type expected, but got [{model_type}].")

    def initialize_source_se(self):
        """
        Initialize the source speaker embedding.
        """
        if self.source_se is not None:
            warnings.warn("replace source speaker embedding with new source speaker embedding!")
        self.source_se = torch.load(os.path.join(SOURCE_SE_DIR, f"{self.speaker_key}.pth"), map_location=self.device)

    def initialize_target_se(self, se: Union[np.ndarray, torch.Tensor]):
        """
        Set the target speaker embedding.
        param:
            se: input embedding, either np.ndarray or torch.Tensor
        """
        if self.target_se is not None:
            warnings.warn("replace target speaker embedding with new target speaker embedding!")
        if isinstance(se, np.ndarray):
            self.target_se = torch.tensor(se.astype(np.float32)).to(self.device)
        elif isinstance(se, torch.Tensor):
            self.target_se = se.float().to(self.device)

    def audio2numpy(self, audio_data: Union[bytes, np.ndarray]):
        """
        Convert an audio byte stream to a numpy array; numpy input is also accepted.
        return: np.float32
        """
        # TODO: decide whether normalization is needed
        if isinstance(audio_data, bytes):
            audio_data = np.frombuffer(audio_data, dtype=np.int16).flatten().astype(np.float32) / 32768.0
        elif isinstance(audio_data, np.ndarray):
            if audio_data.dtype != np.float32:
                audio_data = audio_data.astype(np.int16).flatten().astype(np.float32) / 32768.0
        else:
            raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
        return audio_data

    def audio2emb(self, audio_data: Union[bytes, np.ndarray], rate=44100, vad=True):
        """
        Extract a speaker embedding from audio given as bytes or a numpy array.
        param:
            audio_data: input audio bytes (or numpy array)
            rate: sampling rate of the input audio
            vad: whether to use the VAD model (currently ignored; get_se is called with vad=False)
        return: np.ndarray
        """
        audio_data = self.audio2numpy(audio_data)

        from scipy.io import wavfile
        audio_path = os.path.join(CACHE_PATH, "tmp.wav")
        wavfile.write(audio_path, rate=rate, data=audio_data)

        se, _ = se_extractor.get_se(audio_path, self.tone_color_converter, target_dir=CACHE_PATH, vad=False)
        # device = self.tone_color_converter.device
        # version = self.tone_color_converter.version
        # if self.debug:
        #     print("OpenVoice version:", version)

        # audio_name = f"tmp_{version}_{hashlib.sha256(audio_data.tobytes()).hexdigest()[:16].replace('/','_^')}"

        # if vad:
        #     wavs_folder = se_extractor.split_audio_vad(audio_path, target_dir=CACHE_PATH, audio_name=audio_name)
        # else:
        #     wavs_folder = se_extractor.split_audio_whisper(audio_data, target_dir=CACHE_PATH, audio_name=audio_name)

        # audio_segs = glob(f'{wavs_folder}/*.wav')
        # if len(audio_segs) == 0:
        #     raise NotImplementedError('No audio segments found!')
        # # se, _ = se_extractor.get_se(audio_data, self.tone_color_converter, CACHE_PATH, vad=False)
        # se = self.tone_color_converter.extract_se(audio_segs)
        return se.cpu().detach().numpy()

    def tensor2numpy(self, audio_data: torch.Tensor):
        """
        Convert a torch tensor to a numpy array.
        """
        return audio_data.cpu().detach().float().numpy()

    def numpy2bytes(self, audio_data: np.ndarray):
        """
        Convert a numpy array to bytes.
        Note: samples are scaled by 32768 (int16 range) but stored as int32 values;
        downstream consumers must decode with the matching dtype.
        """
        return (audio_data * 32768.0).astype(np.int32).tobytes()

    def _base_tts(self,
                  text: str,
                  sdp_ratio=0.2,
                  noise_scale=0.6,
                  noise_scale_w=0.8,
                  speed=1.0,
                  quite=True):
        """
        Base speech synthesis.
        param:
            text: text to synthesize
            sdp_ratio: ratio of the stochastic duration predictor (SDP) used during synthesis; in theory, the higher the ratio, the larger the variance of the synthesized intonation.
            noise_scale: noise scale of the sampled noise tensor.
            noise_scale_w: noise scale of the stochastic duration predictor during inference.
            speed: speaking speed
            quite: if True, suppress the progress bar
        return:
            audio: tensor
            sr: sampling rate of the generated audio
        """
        speaker_id = self.speaker_id
        if self.base_tts_model_type != "open_voice_base_tts":
            raise NotImplementedError("only [open_voice_base_tts] model type expected.")
        language = self.base_tts_model.language
        texts = self.base_tts_model.split_sentences_into_pieces(text, language, quite)
        audio_list = []
        if quite:
            tx = texts
        else:
            tx = tqdm(texts)
        for t in tx:
            if language in ['EN', 'ZH_MIX_EN']:
                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
            device = self.base_tts_model.device
            bert, ja_bert, phones, tones, lang_ids = get_text_for_tts_infer(t, language, self.base_tts_model.hps, device, self.base_tts_model.symbol_to_id)
            with torch.no_grad():
                x_tst = phones.to(device).unsqueeze(0)
                tones = tones.to(device).unsqueeze(0)
                lang_ids = lang_ids.to(device).unsqueeze(0)
                bert = bert.to(device).unsqueeze(0)
                ja_bert = ja_bert.to(device).unsqueeze(0)
                x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
                del phones
                speakers = torch.LongTensor([speaker_id]).to(device)
                audio = self.base_tts_model.model.infer(
                    x_tst,
                    x_tst_lengths,
                    speakers,
                    tones,
                    lang_ids,
                    bert,
                    ja_bert,
                    sdp_ratio=sdp_ratio,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=1. / speed,
                )[0][0, 0].data
                del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
            audio_list.append(audio)
        torch.cuda.empty_cache()
        audio_segments = []
        sr = self.base_tts_model.hps.data.sampling_rate
        for segment_data in audio_list:
            audio_segments.append(segment_data.reshape(-1).contiguous())
            # insert a short (50 ms) silence between sentence pieces
            audio_segments.append(torch.tensor([0] * int((sr * 0.05) / speed), dtype=segment_data.dtype, device=segment_data.device))
        audio_segments = torch.cat(audio_segments, dim=-1)
        if self.debug:
            print("generate base speech!")
            print("tts sampling rate:", sr)
            print(f"audio segment length is [{audio_segments.shape}]")
        return audio_segments, sr

    def _convert_tone(self,
                      audio_data: torch.Tensor,
                      source_se: Optional[np.ndarray] = None,
                      target_se: Optional[np.ndarray] = None,
                      tau: float = 0.3,
                      message: str = "default"):
        """
        Tone color conversion.
        param:
            audio_data: audio produced by _base_tts
            source_se: if None, self.source_se is used
            target_se: if None, self.target_se is used
            tau: parameter passed through to the tone color converter's voice_conversion
            message: watermark message (TODO)
        return:
            audio: tensor
            sr: sampling rate of the generated audio
        """
        if source_se is None:
            source_se = self.source_se
        if target_se is None:
            target_se = self.target_se

        hps = self.tone_color_converter.hps
        sr = hps.data.sampling_rate
        if self.debug:
            print("convert sampling rate:", sr)
        audio_data = audio_data.float()

        with torch.no_grad():
            y = audio_data.to(self.tone_color_converter.device)
            y = y.unsqueeze(0)
            spec = spectrogram_torch(y, hps.data.filter_length,
                                     sr, hps.data.hop_length, hps.data.win_length,
                                     center=False).to(self.tone_color_converter.device)
            spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.tone_color_converter.device)
            audio = self.tone_color_converter.model.voice_conversion(spec, spec_lengths, sid_src=source_se, sid_tgt=target_se, tau=tau)[0][0, 0].data
            # audio = self.tone_color_converter.add_watermark(audio, message)
        if self.debug:
            print("tone color has been converted!")
        return audio, sr

    def tts(self,
            text: str,
            sdp_ratio=0.2,
            noise_scale=0.6,
            noise_scale_w=0.8,
            speed=1.0,
            quite=True,
            source_se: Optional[np.ndarray] = None,
            target_se: Optional[np.ndarray] = None,
            tau: float = 0.3,
            message: str = "default"):
        """
        Full pipeline:
            _base_tts()
            _convert_tone()
            tensor2numpy()
            numpy2bytes()
        param:
            see _base_tts and _convert_tone
        return:
            audio: audio data as a byte stream
            sr: sampling rate of the audio data
        """
        audio, sr = self._base_tts(text,
                                   sdp_ratio=sdp_ratio,
                                   noise_scale=noise_scale,
                                   noise_scale_w=noise_scale_w,
                                   speed=speed,
                                   quite=quite)
        if self.use_tone_convert:
            # resample the base TTS output to the converter's sampling rate
            tts_sr = self.base_tts_model.hps.data.sampling_rate
            converter_sr = self.tone_color_converter.hps.data.sampling_rate
            audio = F.resample(audio, tts_sr, converter_sr)
            if self.debug:
                print(audio.dtype)
            audio, sr = self._convert_tone(audio,
                                           source_se=source_se,
                                           target_se=target_se,
                                           tau=tau,
                                           message=message)
        audio = self.tensor2numpy(audio)
        audio = self.numpy2bytes(audio)
        return audio, sr

    def save_audio(self, audio, sample_rate, save_path):
        """
        Save numpy audio data to a local file.
        param:
            audio: audio data as a numpy array
            sample_rate: sampling rate of the data
            save_path: output path
        """
        sf.write(save_path, audio, sample_rate)
        print(f"Audio saved to {save_path}")