117 lines
4.7 KiB
Python
117 lines
4.7 KiB
Python
import os
|
|
import numpy as np
|
|
import torch
|
|
from torch import LongTensor
|
|
from typing import Optional
|
|
import soundfile as sf
|
|
# vits
|
|
from .vits import utils, commons
|
|
from .vits.models import SynthesizerTrn
|
|
from .vits.text import text_to_sequence
|
|
|
|
def tts_model_init(model_path='./vits_model', device='cuda', checkpoint_name='G_953000.pth'):
    """Build the VITS synthesizer described by ``config.json`` and load its weights.

    Args:
        model_path: Directory that contains ``config.json`` and the generator
            checkpoint file.
        device: Device the network is moved to (anything ``Module.to`` accepts).
        checkpoint_name: File name of the generator checkpoint inside
            ``model_path``. Defaults to the checkpoint name this function
            previously hard-coded.

    Returns:
        A tuple ``(hps_ms, net_g_ms, speakers)``: the hyper-parameter object
        read from the config file, the synthesizer network in eval mode on
        ``device`` with the checkpoint loaded, and ``hps_ms.speakers``.
    """
    config_file = os.path.join(model_path, 'config.json')
    checkpoint_file = os.path.join(model_path, checkpoint_name)

    hps_ms = utils.get_hparams_from_file(config_file)

    net_g_ms = SynthesizerTrn(
        len(hps_ms.symbols),
        # presumably the number of spectrogram bins — confirm against SynthesizerTrn
        hps_ms.data.filter_length // 2 + 1,
        # training segment length expressed in hops rather than samples
        hps_ms.train.segment_size // hps_ms.data.hop_length,
        n_speakers=hps_ms.data.n_speakers,
        **hps_ms.model,
    )
    net_g_ms = net_g_ms.eval().to(device)

    # No optimizer is passed: only the generator weights are restored.
    utils.load_checkpoint(checkpoint_file, net_g_ms, None)

    return hps_ms, net_g_ms, hps_ms.speakers
|
|
|
|
class TextToSpeech:
    """Text-to-speech front end around a multi-speaker VITS synthesizer.

    The constructor loads the model once; ``synthesize`` then turns text into
    audio, returned either as a numpy array or as raw 16-bit PCM bytes.
    """

    def __init__(self,
                 model_path="./utils/tts/vits_model",
                 device='cuda',
                 RATE=22050,
                 debug=False,
                 checkpoint_name='G_953000.pth',
                 ):
        """Load the model found in ``model_path``.

        Args:
            model_path: Directory containing ``config.json`` and the generator
                checkpoint.
            device: Device specification handed to ``torch.device``.
            RATE: Sample rate reported alongside the audio and used when
                writing wav files.
            debug: When true, print progress/timing information and save every
                synthesized clip to ``output_file.wav``.
            checkpoint_name: File name of the generator checkpoint inside
                ``model_path``.
        """
        self.debug = debug
        self.RATE = RATE
        self.device = torch.device(device)
        # Limit text and audio length when running inside Hugging Face Spaces.
        self.limitation = os.getenv("SYSTEM") == "spaces"
        self.hps_ms, self.net_g_ms, self.speakers = self._tts_model_init(model_path, checkpoint_name)

    def _tts_model_init(self, model_path, checkpoint_name='G_953000.pth'):
        """Create the synthesizer from ``config.json`` and load its checkpoint.

        Args:
            model_path: Directory containing ``config.json`` and the checkpoint.
            checkpoint_name: File name of the generator checkpoint.

        Returns:
            A tuple ``(hps_ms, net_g_ms, speakers)``: the hyper-parameter
            object, the network in eval mode on ``self.device``, and
            ``hps_ms.speakers``.
        """
        hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))
        net_g_ms = SynthesizerTrn(
            len(hps_ms.symbols),
            # presumably the number of spectrogram bins — confirm against SynthesizerTrn
            hps_ms.data.filter_length // 2 + 1,
            # training segment length expressed in hops rather than samples
            hps_ms.train.segment_size // hps_ms.data.hop_length,
            n_speakers=hps_ms.data.n_speakers,
            **hps_ms.model,
        )
        net_g_ms = net_g_ms.eval().to(self.device)
        # No optimizer is passed: only the generator weights are restored.
        utils.load_checkpoint(os.path.join(model_path, checkpoint_name), net_g_ms, None)
        if self.debug:
            print("Model loaded.")
        return hps_ms, net_g_ms, hps_ms.speakers

    def _get_text(self, text):
        """Convert ``text`` to a ``LongTensor`` of symbol ids.

        Returns:
            A tuple ``(text_norm, clean_text)``: the id tensor and the cleaned
            text returned by ``text_to_sequence``.
        """
        data_cfg = self.hps_ms.data
        text_norm, clean_text = text_to_sequence(text, self.hps_ms.symbols, data_cfg.text_cleaners)
        if data_cfg.add_blank:
            # The config asks for id 0 to be interleaved between symbols.
            text_norm = commons.intersperse(text_norm, 0)
        return LongTensor(text_norm), clean_text

    def _preprocess_text(self, text, language):
        """Wrap ``text`` in the language tags the text cleaners look for.

        Args:
            text: Text to tag.
            language: ``0`` wraps in ``[ZH]``, ``1`` wraps in ``[JA]``; any
                other value leaves the text untouched.

        Returns:
            The tagged (or unchanged) text.
        """
        if language == 0:
            return f"[ZH]{text}[ZH]"
        if language == 1:
            return f"[JA]{text}[JA]"
        return text

    def _generate_audio(self, text, speaker_id, noise_scale, noise_scale_w, length_scale):
        """Run the synthesizer on already-tagged ``text``.

        Args:
            text: Text as produced by ``_preprocess_text``.
            speaker_id: Integer speaker index passed to the model as ``sid``.
            noise_scale: Forwarded unchanged to ``net_g_ms.infer``.
            noise_scale_w: Forwarded unchanged to ``net_g_ms.infer``.
            length_scale: Forwarded unchanged to ``net_g_ms.infer``.

        Returns:
            The generated waveform as a float32 numpy array on the CPU.
        """
        import time

        # perf_counter is monotonic, so the reported duration cannot go
        # negative if the wall clock is adjusted mid-synthesis.
        started = time.perf_counter()
        stn_tst, _clean_text = self._get_text(text)
        with torch.no_grad():
            x_tst = stn_tst.unsqueeze(0).to(self.device)
            x_tst_lengths = LongTensor([stn_tst.size(0)]).to(self.device)
            sid = LongTensor([speaker_id]).to(self.device)
            outputs = self.net_g_ms.infer(
                x_tst,
                x_tst_lengths,
                sid=sid,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )
            # First output, first batch item, first channel.
            audio = outputs[0][0, 0].data.cpu().float().numpy()
        if self.debug:
            print(f"Synthesis time: {time.perf_counter() - started} s")
        return audio

    def synthesize(self, text, tts_info, target_se: Optional[np.ndarray] = None, save_audio=False, return_bytes=True):
        """Synthesize speech for ``text``.

        Args:
            text: Input text. Newlines, carriage returns and spaces are removed
                before synthesis.
            tts_info: Mapping with the keys ``language``, ``speaker_id``,
                ``noise_scale``, ``noise_scale_w`` and ``length_scale``.
            target_se: Accepted for interface compatibility; not used here.
            save_audio: When true (or when ``self.debug`` is set) the clip is
                also written to ``output_file.wav``.
            return_bytes: When true the audio is returned as 16-bit PCM bytes,
                otherwise as the numpy array produced by the model.

        Returns:
            ``(audio, self.RATE)`` on success. If the text is empty, or too
            long while the length limitation is active, returns
            ``(error_message, None)`` instead.
        """
        # Normalize before validating so that whitespace-only (or None) input
        # is rejected here instead of reaching the model as an empty string.
        text = (text or "").replace('\n', ' ').replace('\r', '').replace(" ", "")
        if not text:
            return "输入文本不能为空!", None
        if self.limitation and len(text) > 100:
            return f"输入文字过长!{len(text)}>100", None

        text = self._preprocess_text(text, tts_info['language'])
        audio = self._generate_audio(
            text,
            tts_info['speaker_id'],
            tts_info['noise_scale'],
            tts_info['noise_scale_w'],
            tts_info['length_scale'],
        )
        if self.debug or save_audio:
            self.save_audio(audio, self.RATE, 'output_file.wav')
        if return_bytes:
            audio = self.convert_numpy_to_bytes(audio)
        return audio, self.RATE

    def convert_numpy_to_bytes(self, audio_data):
        """Serialize an audio array to raw bytes.

        Floating-point arrays are treated as samples in ``[-1, 1]``: they are
        clipped to that range, scaled to the int16 range and converted to
        int16 before serialization. Arrays of any other dtype are serialized
        as they are.

        Args:
            audio_data: Numpy array of audio samples.

        Returns:
            The array contents as ``bytes``.

        Raises:
            TypeError: If ``audio_data`` is not a numpy array.
        """
        if not isinstance(audio_data, np.ndarray):
            raise TypeError("audio_data must be a numpy array")
        if np.issubdtype(audio_data.dtype, np.floating):
            peak = np.iinfo(np.int16).max
            # float16 cannot represent 32767 exactly, so compute in at least
            # float32; float32/float64 inputs keep their own precision.
            work_dtype = np.result_type(audio_data.dtype, np.float32)
            samples = audio_data.astype(work_dtype, copy=False)
            # Clip first: out-of-range samples would otherwise wrap around
            # when cast to int16 and turn into loud clicks.
            audio_data = (np.clip(samples, -1.0, 1.0) * peak).astype(np.int16)
        return audio_data.tobytes()

    def save_audio(self, audio, sample_rate, file_name='output_file.wav'):
        """Write ``audio`` to ``file_name`` with ``soundfile`` and report the path."""
        sf.write(file_name, audio, samplerate=sample_rate)
        print(f"VITS Audio saved to {file_name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|