# TakwayPlatform/utils/tts/openvoice_utils.py
import os
import re
import warnings
from tqdm.auto import tqdm
import soundfile as sf
import numpy as np
import torch
from scipy.io import wavfile
from typing import Optional, Union
# melo
from melo.api import TTS
from melo.utils import get_text_for_tts_infer
# openvoice
from .openvoice import se_extractor
from .openvoice.api import ToneColorConverter
from .openvoice.mel_processing import spectrogram_torch
# torchaudio
import torchaudio.functional as F
# directory holding the BASE SPEAKER embeddings (source_se)
SOURCE_SE_DIR = r"D:\python\OpenVoice\checkpoints_v2\base_speakers\ses"
# directory for cached intermediate files
CACHE_PATH = r"D:\python\OpenVoice\processed"
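# Assumption: the cache directory may not exist yet; create it here so that
# audio2emb() can write its temporary WAV file without failing.
os.makedirs(CACHE_PATH, exist_ok=True)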
OPENVOICE_BASE_TTS = {
    "model_type": "open_voice_base_tts",
    # language to synthesize
    "language": "ZH",
}
OPENVOICE_TONE_COLOR_CONVERTER = {
    "model_type": "open_voice_converter",
    # path to the converter checkpoint directory
    "converter_path": r"D:\python\OpenVoice\checkpoints_v2\converter",
}
class TextToSpeech:
    def __init__(self,
                 use_tone_convert=True,
                 device="cuda",
                 debug: bool = False,
                 ):
        self.debug = debug
        self.device = device
        self.use_tone_convert = use_tone_convert
        # source speaker embedding (the base speaker)
        self.source_se = None
        # target speaker embedding (the voice to clone)
        self.target_se = None
        self.initialize_base_tts(**OPENVOICE_BASE_TTS)
        if self.debug:
            print("use tone converter is", self.use_tone_convert)
        if self.use_tone_convert:
            self.initialize_tone_color_converter(**OPENVOICE_TONE_COLOR_CONVERTER)
            self.initialize_source_se()
    def initialize_tone_color_converter(self, **kwargs):
        """
        Initialize the tone color converter.
        """
        model_type = kwargs.pop('model_type')
        self.tone_color_converter_model_type = model_type
        if model_type == 'open_voice_converter':
            # load the converter checkpoint
            converter_path = kwargs.pop('converter_path')
            self.tone_color_converter = ToneColorConverter(f'{converter_path}/config.json', self.device)
            self.tone_color_converter.load_ckpt(f'{converter_path}/checkpoint.pth')
            if self.debug:
                print("load tone color converter successfully!")
        else:
            raise NotImplementedError(f"only [open_voice_converter] model type expected, but got [{model_type}].")
    def initialize_base_tts(self, **kwargs):
        """
        Initialize the base TTS model.
        """
        model_type = kwargs.pop('model_type')
        self.base_tts_model_type = model_type
        if model_type == "open_voice_base_tts":
            language = kwargs.pop('language')
            self.base_tts_model = TTS(language=language, device=self.device)
            speaker_ids = self.base_tts_model.hps.data.spk2id
            speaker_keys = list(speaker_ids.keys())
            if len(speaker_keys) > 1:
                warnings.warn(f"loaded model has more than one speaker, only the first speaker is used. The input speaker ids are {speaker_ids}")
            # use the first speaker of the loaded model
            speaker_key = speaker_keys[0]
            self.speaker_id = speaker_ids[speaker_key]
            self.speaker_key = speaker_key.lower().replace('_', '-')
            if self.debug:
                print("load base tts model successfully!")
            # the BERT model is loaded lazily on the first TTS call, so warm it up here
            self._base_tts("初始化bert模型。")
        else:
            raise NotImplementedError(f"only [open_voice_base_tts] model type expected, but got [{model_type}].")
    def initialize_source_se(self):
        """
        Load the source speaker embedding for the current base speaker.
        """
        if self.source_se is not None:
            warnings.warn("replacing existing source speaker embedding with a new one!")
        self.source_se = torch.load(os.path.join(SOURCE_SE_DIR, f"{self.speaker_key}.pth"), map_location=self.device)
    def initialize_target_se(self, se: Union[np.ndarray, torch.Tensor]):
        """
        Set the target speaker embedding.
        param:
            se: the embedding, given as an np.ndarray or a torch.Tensor
        """
        if self.target_se is not None:
            warnings.warn("replacing existing target speaker embedding with a new one!")
        if isinstance(se, np.ndarray):
            self.target_se = torch.tensor(se.astype(np.float32)).to(self.device)
        elif isinstance(se, torch.Tensor):
            self.target_se = se.float().to(self.device)
        else:
            raise TypeError(f"se must be np.ndarray or torch.Tensor, but got {type(se)}")
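    # Example (hypothetical): a typical way to register a cloned voice, using
    # audio2emb() below on a reference recording of the target speaker:
    #     emb = tts.audio2emb(ref_bytes, rate=16000, vad=True)
    #     tts.initialize_target_se(emb)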
    def audio2numpy(self, audio_data: Union[bytes, np.ndarray]):
        """
        Convert a byte-stream audio buffer to a numpy array; numpy input is
        passed through. Bytes are assumed to be 16-bit PCM and are normalized
        to [-1, 1].
        return: np.float32
        """
        # TODO: decide whether normalization should be applied conditionally
        if isinstance(audio_data, bytes):
            audio_data = np.frombuffer(audio_data, dtype=np.int16).flatten().astype(np.float32) / 32768.0
        elif isinstance(audio_data, np.ndarray):
            if audio_data.dtype != np.float32:
                # treat non-float input as integer PCM and normalize it
                audio_data = audio_data.astype(np.int16).flatten().astype(np.float32) / 32768.0
        else:
            raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
        return audio_data
    def audio2emb(self, audio_data: Union[bytes, np.ndarray], rate=44100, vad=True):
        """
        Extract a speaker embedding from audio given as bytes or a numpy array.
        param:
            audio_data: the input audio
            rate: sampling rate of the input audio
            vad: whether to use the VAD model when splitting the audio
        return: np.ndarray
        """
        audio_data = self.audio2numpy(audio_data)
        # round-trip through a temporary WAV file, since se_extractor works on file paths
        audio_path = os.path.join(CACHE_PATH, "tmp.wav")
        wavfile.write(audio_path, rate=rate, data=audio_data)
        se, _ = se_extractor.get_se(audio_path, self.tone_color_converter, target_dir=CACHE_PATH, vad=vad)
        return se.cpu().detach().numpy()
    def tensor2numpy(self, audio_data: torch.Tensor):
        """
        Convert a torch tensor to a numpy array.
        """
        return audio_data.cpu().detach().float().numpy()
    def numpy2bytes(self, audio_data: np.ndarray):
        """
        Convert float audio in [-1, 1] to a 16-bit PCM byte stream.
        """
        # clip before scaling so out-of-range samples cannot wrap around;
        # int16 matches the format that audio2numpy expects on the way back in
        audio_data = np.clip(audio_data, -1.0, 32767.0 / 32768.0)
        return (audio_data * 32768.0).astype(np.int16).tobytes()
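    # Example (hypothetical): numpy2bytes mirrors audio2numpy for float audio
    # in [-1, 1], so a round trip should be nearly lossless:
    #     pcm = tts.numpy2bytes(np.zeros(16000, dtype=np.float32))
    #     arr = tts.audio2numpy(pcm)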
    def _base_tts(self,
                  text: str,
                  sdp_ratio=0.2,
                  noise_scale=0.6,
                  noise_scale_w=0.8,
                  speed=1.0,
                  quiet=True):
        """
        Base speech synthesis.
        param:
            text: the text to synthesize
            sdp_ratio: weight of the stochastic duration predictor (SDP); in theory,
                the higher this ratio, the larger the prosodic variance of the speech
            noise_scale: noise scale of the sampled noise tensor
            noise_scale_w: noise scale of the stochastic duration predictor during inference
            speed: speaking speed
            quiet: whether to hide the progress bar
        return:
            audio: tensor
            sr: sampling rate of the generated audio
        """
        speaker_id = self.speaker_id
        if self.base_tts_model_type != "open_voice_base_tts":
            raise NotImplementedError("only [open_voice_base_tts] model type expected.")
        language = self.base_tts_model.language
        texts = self.base_tts_model.split_sentences_into_pieces(text, language, quiet)
        audio_list = []
        tx = texts if quiet else tqdm(texts)
        for t in tx:
            if language in ['EN', 'ZH_MIX_EN']:
                # insert a space between a lowercase and an uppercase letter
                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
            device = self.base_tts_model.device
            bert, ja_bert, phones, tones, lang_ids = get_text_for_tts_infer(t, language, self.base_tts_model.hps, device, self.base_tts_model.symbol_to_id)
            with torch.no_grad():
                x_tst = phones.to(device).unsqueeze(0)
                tones = tones.to(device).unsqueeze(0)
                lang_ids = lang_ids.to(device).unsqueeze(0)
                bert = bert.to(device).unsqueeze(0)
                ja_bert = ja_bert.to(device).unsqueeze(0)
                x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
                del phones
                speakers = torch.LongTensor([speaker_id]).to(device)
                audio = self.base_tts_model.model.infer(
                    x_tst,
                    x_tst_lengths,
                    speakers,
                    tones,
                    lang_ids,
                    bert,
                    ja_bert,
                    sdp_ratio=sdp_ratio,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=1. / speed,
                )[0][0, 0].data
                del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
                audio_list.append(audio)
            torch.cuda.empty_cache()
        # concatenate the segments with a short (50 ms) silence between them
        audio_segments = []
        sr = self.base_tts_model.hps.data.sampling_rate
        for segment_data in audio_list:
            audio_segments.append(segment_data.reshape(-1).contiguous())
            audio_segments.append(torch.zeros(int((sr * 0.05) / speed), dtype=segment_data.dtype, device=segment_data.device))
        audio_segments = torch.cat(audio_segments, dim=-1)
        if self.debug:
            print("generate base speech!")
            print("tts sr:", sr)
            print(f"audio segment length is [{audio_segments.shape}]")
        return audio_segments, sr
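    # Example (hypothetical): base synthesis only, without tone conversion,
    # assuming `tts` is an initialized TextToSpeech instance:
    #     wav, sr = tts._base_tts("你好，世界。", speed=1.1)
    #     tts.save_audio(tts.tensor2numpy(wav), sr, "base.wav")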
    def _convert_tone(self,
                      audio_data: torch.Tensor,
                      source_se: Optional[np.ndarray] = None,
                      target_se: Optional[np.ndarray] = None,
                      tau: float = 0.3,
                      message: str = "default"):
        """
        Tone color conversion.
        param:
            audio_data: the audio produced by _base_tts
            source_se: if None, self.source_se is used
            target_se: if None, self.target_se is used
            tau: sampling noise scale of the converter
            message: watermark message TODO
        return:
            audio: tensor
            sr: sampling rate of the generated audio
        """
        if source_se is None:
            source_se = self.source_se
        if target_se is None:
            target_se = self.target_se
        hps = self.tone_color_converter.hps
        sr = hps.data.sampling_rate
        if self.debug:
            print("convert sr:", sr)
        audio_data = audio_data.float()
        with torch.no_grad():
            y = audio_data.to(self.tone_color_converter.device)
            y = y.unsqueeze(0)
            spec = spectrogram_torch(y, hps.data.filter_length,
                                     sr, hps.data.hop_length, hps.data.win_length,
                                     center=False).to(self.tone_color_converter.device)
            spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.tone_color_converter.device)
            audio = self.tone_color_converter.model.voice_conversion(spec, spec_lengths, sid_src=source_se, sid_tgt=target_se, tau=tau)[0][0, 0].data
            # audio = self.tone_color_converter.add_watermark(audio, message)
        if self.debug:
            print("tone color has been converted!")
        return audio, sr
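    # Note: the converter runs at its own sampling rate (hps.data.sampling_rate),
    # which may differ from the base TTS rate; tts() below resamples with
    # torchaudio.functional.resample before calling this method.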
    def tts(self,
            text: str,
            sdp_ratio=0.2,
            noise_scale=0.6,
            noise_scale_w=0.8,
            speed=1.0,
            quiet=True,
            source_se: Optional[np.ndarray] = None,
            target_se: Optional[np.ndarray] = None,
            tau: float = 0.3,
            message: str = "default"):
        """
        The full pipeline:
            _base_tts()
            _convert_tone()
            tensor2numpy()
            numpy2bytes()
        param:
            see _base_tts and _convert_tone
        return:
            audio: audio data as a byte stream
            sr: sampling rate of the audio data
        """
        audio, sr = self._base_tts(text,
                                   sdp_ratio=sdp_ratio,
                                   noise_scale=noise_scale,
                                   noise_scale_w=noise_scale_w,
                                   speed=speed,
                                   quiet=quiet)
        if self.use_tone_convert:
            # the base TTS and the converter may run at different sampling rates,
            # so resample before converting the tone color
            tts_sr = self.base_tts_model.hps.data.sampling_rate
            converter_sr = self.tone_color_converter.hps.data.sampling_rate
            audio = F.resample(audio, tts_sr, converter_sr)
            if self.debug:
                print(audio.dtype)
            audio, sr = self._convert_tone(audio,
                                           source_se=source_se,
                                           target_se=target_se,
                                           tau=tau,
                                           message=message)
        audio = self.tensor2numpy(audio)
        audio = self.numpy2bytes(audio)
        return audio, sr
    def save_audio(self, audio, sample_rate, save_path):
        """
        Save numpy audio data to a local file.
        param:
            audio: audio data as a numpy array
            sample_rate: sampling rate of the data
            save_path: output path
        """
        sf.write(save_path, audio, sample_rate)
        print(f"Audio saved to {save_path}")