# TakwayDisplayPlatform/utils/vits_utils.py
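"""Thin wrapper around a multilingual VITS checkpoint.

TextToSpeech loads the generator weights from ./utils/vits_model and
exposes synthesize(), which turns language-tagged text into 16-bit PCM
bytes or a float32 numpy waveform.
"""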
import os
import time
import numpy as np
import torch
from torch import LongTensor
from typing import Optional
import soundfile as sf
# vits
from .vits import utils, commons
from .vits.models import SynthesizerTrn
from .vits.text import text_to_sequence


class TextToSpeech:
    def __init__(self,
                 model_path="./utils/vits_model",
                 device='cuda',
                 RATE=22050,
                 debug=False,
                 ):
        self.debug = debug
        self.RATE = RATE
        self.device = torch.device(device)
        # Limit text and audio length when running in Hugging Face Spaces
        self.limitation = os.getenv("SYSTEM") == "spaces"
        self.hps_ms, self.net_g_ms, self.speakers = self._tts_model_init(model_path)
        self._init_jieba()

    def _init_jieba(self):
        # Warm-up pass: synthesize one short utterance ("初始化" = "initialize")
        # so that jieba's dictionary cache and the model's first-run overhead
        # are paid here rather than on the first real request.
        text = self._preprocess_text("初始化", 0)
        self._generate_audio(text, 100, 0.6, 0.668, 1.0)

    def _tts_model_init(self, model_path):
        # Build the synthesizer from config.json and load the generator checkpoint
        hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))
        net_g_ms = SynthesizerTrn(
            len(hps_ms.symbols),
            hps_ms.data.filter_length // 2 + 1,
            hps_ms.train.segment_size // hps_ms.data.hop_length,
            n_speakers=hps_ms.data.n_speakers,
            **hps_ms.model)
        net_g_ms = net_g_ms.eval().to(self.device)
        speakers = hps_ms.speakers
        utils.load_checkpoint(os.path.join(model_path, 'G_953000.pth'), net_g_ms, None)
        if self.debug:
            print("Model loaded.")
        return hps_ms, net_g_ms, speakers

    def _get_text(self, text):
        # Convert raw text to a phoneme-id sequence using the model's text cleaners
        text_norm, clean_text = text_to_sequence(text, self.hps_ms.symbols, self.hps_ms.data.text_cleaners)
        if self.hps_ms.data.add_blank:
            # Interleave blank tokens between symbols, as VITS expects when
            # trained with add_blank
            text_norm = commons.intersperse(text_norm, 0)
        text_norm = LongTensor(text_norm)
        return text_norm, clean_text

    def _preprocess_text(self, text, language):
        # Wrap the text in the language tags the multilingual cleaner expects:
        # 0 -> Chinese, 1 -> Japanese; anything else passes through unchanged
        if language == 0:
            return f"[ZH]{text}[ZH]"
        elif language == 1:
            return f"[JA]{text}[JA]"
        return text

    def _generate_audio(self, text, speaker_id, noise_scale, noise_scale_w, length_scale):
        start_time = time.time()
        stn_tst, clean_text = self._get_text(text)
        with torch.no_grad():
            x_tst = stn_tst.unsqueeze(0).to(self.device)
            x_tst_lengths = LongTensor([stn_tst.size(0)]).to(self.device)
            speaker_id = LongTensor([speaker_id]).to(self.device)
            audio = self.net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id,
                                        noise_scale=noise_scale, noise_scale_w=noise_scale_w,
                                        length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
        if self.debug:
            print(f"Synthesis time: {time.time() - start_time} s")
        return audio

    def synthesize(self, text, tts_info, target_se: Optional[np.ndarray] = None,
                   save_audio=False, return_bytes=True):
        # target_se and save_audio are accepted for API compatibility but
        # are currently unused in this method
        if not len(text):
            return b''
        # Strip newlines and spaces; the cleaners expect contiguous CJK text
        text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
        if len(text) > 100 and self.limitation:
            return f"Input text too long! {len(text)} > 100", None
        text = self._preprocess_text(text, tts_info['language'])
        audio = self._generate_audio(text, tts_info['speaker_id'], tts_info['noise_scale'],
                                     tts_info['noise_scale_w'], tts_info['length_scale'])
        if return_bytes:
            audio = self.convert_numpy_to_bytes(audio)
        return audio

    def convert_numpy_to_bytes(self, audio_data):
        if isinstance(audio_data, np.ndarray):
            if audio_data.dtype == np.dtype('float32'):
                # Scale float32 samples in [-1, 1] up to 16-bit PCM range
                audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
            audio_data = audio_data.tobytes()
            return audio_data
        else:
            raise TypeError("audio_data must be a numpy array")

    def save_audio(self, audio, sample_rate, file_name='output_file.wav'):
        sf.write(file_name, audio, samplerate=sample_rate)
        print(f"VITS Audio saved to {file_name}")