# -*- encoding: utf-8 -*-

import copy
import json
import os.path
from pathlib import Path
from typing import List, Tuple, Union

import librosa
import numpy as np

from .utils.frontend import SinusoidalPositionEncoderOnline, WavFrontendOnline
from .utils.postprocess_utils import sentence_postprocess
from .utils.utils import (
    CharTokenizer,
    Hypothesis,
    ONNXRuntimeError,
    OrtInferSession,
    TokenIDConverter,
    get_logger,
    read_yaml,
)

logging = get_logger()


class Paraformer:
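    """Streaming (online) Paraformer ASR inference with ONNX Runtime.

    Runs a chunked fbank frontend -> encoder -> CIF predictor -> FSMN decoder
    pipeline, carrying intermediate state between calls in a ``cache`` dict.
    """
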
    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        chunk_size: List = [5, 10, 5],
        device_id: Union[str, int] = "-1",
        quantize: bool = False,
        intra_op_num_threads: int = 4,
        cache_dir: str = None,
        **kwargs,
    ):

        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except ImportError:
                raise ImportError(
                    "You are exporting the model from modelscope, please install "
                    "modelscope and try again. To install modelscope:\n"
                    "\npip3 install -U modelscope\n"
                    "For users in China, you can install it with:\n"
                    "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                )
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except Exception:
                raise ValueError(
                    "model_dir must be a model name on modelscope or a local path "
                    "downloaded from modelscope, but is {}".format(model_dir)
                )

        encoder_model_file = os.path.join(model_dir, "model.onnx")
        decoder_model_file = os.path.join(model_dir, "decoder.onnx")
        if quantize:
            encoder_model_file = os.path.join(model_dir, "model_quant.onnx")
            decoder_model_file = os.path.join(model_dir, "decoder_quant.onnx")
        if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
            logging.info("ONNX model files do not exist, exporting them with funasr")
            try:
                from funasr import AutoModel
            except ImportError:
                raise ImportError(
                    "You are exporting ONNX models, please install funasr and try "
                    "again. To install funasr:\n"
                    "\npip3 install -U funasr\n"
                    "For users in China, you can install it with:\n"
                    "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                )

            model = AutoModel(model=model_dir)
            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)

        config_file = os.path.join(model_dir, "config.yaml")
        cmvn_file = os.path.join(model_dir, "am.mvn")
        config = read_yaml(config_file)
        token_list_file = os.path.join(model_dir, "tokens.json")
        with open(token_list_file, "r", encoding="utf-8") as f:
            token_list = json.load(f)

        self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontendOnline(cmvn_file=cmvn_file, **config["frontend_conf"])
        self.pe = SinusoidalPositionEncoderOnline()
        self.ort_encoder_infer = OrtInferSession(
            encoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads
        )
        self.ort_decoder_infer = OrtInferSession(
            decoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads
        )

        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.encoder_output_size = config["encoder_conf"]["output_size"]
        self.fsmn_layer = config["decoder_conf"]["num_blocks"]
        self.fsmn_lorder = config["decoder_conf"]["kernel_size"] - 1
        self.fsmn_dims = config["encoder_conf"]["output_size"]
        # LFR stacks lfr_m consecutive mel frames into one feature vector
        self.feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
        self.cif_threshold = config["predictor_conf"]["threshold"]
        self.tail_threshold = config["predictor_conf"]["tail_threshold"]

    def prepare_cache(self, cache: dict = None, batch_size: int = 1):
        # avoid a shared mutable default argument; reuse a populated cache as-is
        if cache is None:
            cache = {}
        if len(cache) > 0:
            return cache
        cache["start_idx"] = 0
        cache["cif_hidden"] = np.zeros((batch_size, 1, self.encoder_output_size)).astype(np.float32)
        cache["cif_alphas"] = np.zeros((batch_size, 1)).astype(np.float32)
        cache["chunk_size"] = self.chunk_size
        cache["last_chunk"] = False
        cache["feats"] = np.zeros(
            (batch_size, self.chunk_size[0] + self.chunk_size[2], self.feats_dims)
        ).astype(np.float32)
        cache["decoder_fsmn"] = []
        for _ in range(self.fsmn_layer):
            fsmn_cache = np.zeros((batch_size, self.fsmn_dims, self.fsmn_lorder)).astype(np.float32)
            cache["decoder_fsmn"].append(fsmn_cache)
        return cache
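
    # With the default chunk_size [5, 10, 5], each streaming step feeds
    # chunk_size[1] = 10 new feature frames; add_overlap_chunk prepends the
    # chunk_size[0] + chunk_size[2] = 10 cached context frames, so the encoder
    # always sees a sum(chunk_size) = 20-frame window.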
    def add_overlap_chunk(self, feats: np.ndarray, cache: dict = None):
        if not cache:
            return feats
        overlap_feats = np.concatenate((cache["feats"], feats), axis=1)
        if cache["is_final"]:
            # process the last chunk: keep only the left context, and pad the
            # window to full size unless this is the final flush
            cache["feats"] = overlap_feats[:, -self.chunk_size[0] :, :]
            if not cache["last_chunk"]:
                padding_length = sum(self.chunk_size) - overlap_feats.shape[1]
                overlap_feats = np.pad(overlap_feats, ((0, 0), (0, padding_length), (0, 0)))
        else:
            cache["feats"] = overlap_feats[:, -(self.chunk_size[0] + self.chunk_size[2]) :, :]
        return overlap_feats

    def __call__(self, audio_in: np.ndarray, **kwargs):
        waveforms = np.expand_dims(audio_in, axis=0)
        param_dict = kwargs.get("param_dict", dict())
        is_final = param_dict.get("is_final", False)
        cache = param_dict.get("cache", dict())
        asr_res = []

        # fewer than 16 * 60 = 960 new samples (60 ms at the expected 16 kHz):
        # just flush the cached features as the last chunk
        if waveforms.shape[1] < 16 * 60 and is_final and len(cache) > 0:
            cache["last_chunk"] = True
            feats = cache["feats"]
            feats_len = np.array([feats.shape[1]]).astype(np.int32)
            asr_res = self.infer(feats, feats_len, cache)
            return asr_res

        feats, feats_len = self.extract_feat(waveforms, is_final)
        if feats.shape[1] != 0:
            feats *= self.encoder_output_size**0.5
            cache = self.prepare_cache(cache)
            cache["is_final"] = is_final

            # fbank -> position encoding -> overlap chunk
            feats = self.pe.forward(feats, cache["start_idx"])
            cache["start_idx"] += feats.shape[1]
            if is_final:
                if feats.shape[1] + self.chunk_size[2] <= self.chunk_size[1]:
                    cache["last_chunk"] = True
                    feats = self.add_overlap_chunk(feats, cache)
                else:
                    # the remaining audio spans two chunks: first chunk
                    feats_chunk1 = self.add_overlap_chunk(feats[:, : self.chunk_size[1], :], cache)
                    feats_len = np.array([feats_chunk1.shape[1]]).astype(np.int32)
                    asr_res_chunk1 = self.infer(feats_chunk1, feats_len, cache)

                    # last chunk
                    cache["last_chunk"] = True
                    feats_chunk2 = self.add_overlap_chunk(
                        feats[:, -(feats.shape[1] + self.chunk_size[2] - self.chunk_size[1]) :, :],
                        cache,
                    )
                    feats_len = np.array([feats_chunk2.shape[1]]).astype(np.int32)
                    asr_res_chunk2 = self.infer(feats_chunk2, feats_len, cache)

                    # merge the two partial results into one
                    asr_res_chunk = asr_res_chunk1 + asr_res_chunk2
                    res = {}
                    for pred in asr_res_chunk:
                        for key, value in pred.items():
                            if key in res:
                                res[key][0] += value[0]
                                res[key][1].extend(value[1])
                            else:
                                res[key] = [value[0], value[1]]
                    return [res]
            else:
                feats = self.add_overlap_chunk(feats, cache)

            feats_len = np.array([feats.shape[1]]).astype(np.int32)
            asr_res = self.infer(feats, feats_len, cache)

        return asr_res
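
    # infer() returns a list of dicts mapping "preds" to the output of
    # sentence_postprocess; judging from the merging logic in __call__ above,
    # that is a text string followed by a list of tokens.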
    def infer(self, feats: np.ndarray, feats_len: np.ndarray, cache):
        # encoder forward
        enc_input = [feats, feats_len]
        enc, enc_lens, cif_alphas = self.ort_encoder_infer(enc_input)

        # predictor (CIF) forward
        acoustic_embeds, acoustic_embeds_len = self.cif_search(enc, cif_alphas, cache)

        # decoder forward
        asr_res = []
        if acoustic_embeds.shape[1] > 0:
            dec_input = [enc, enc_lens, acoustic_embeds, acoustic_embeds_len]
            dec_input.extend(cache["decoder_fsmn"])
            dec_output = self.ort_decoder_infer(dec_input)
            logits, sample_ids, cache["decoder_fsmn"] = dec_output[0], dec_output[1], dec_output[2:]
            # keep only the FSMN lookback window for the next call
            cache["decoder_fsmn"] = [
                item[:, :, -self.fsmn_lorder :] for item in cache["decoder_fsmn"]
            ]

            preds = self.decode(logits, acoustic_embeds_len)
            for pred in preds:
                pred = sentence_postprocess(pred)
                asr_res.append({"preds": pred})

        return asr_res

    def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]

        if isinstance(wav_content, str):
            return [load_wav(wav_content)]

        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]

        raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")

    def extract_feat(
        self, waveforms: np.ndarray, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
        for idx, waveform in enumerate(waveforms):
            waveforms_lens[idx] = waveform.shape[-1]

        feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
        return feats.astype(np.float32), feats_len.astype(np.int32)

    def decode(self, am_scores: np.ndarray, token_nums: np.ndarray) -> List[List[str]]:
        return [
            self.decode_one(am_score, token_num)
            for am_score, token_num in zip(am_scores, token_nums)
        ]

    def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)

        # pad with sos/eos ids to stay compatible with the training layout
        # (asr_model.sos: 1, asr_model.eos: 2)
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)

        # remove sos/eos and get results
        last_pos = -1
        token_int = hyp.yseq[1:last_pos].tolist()

        # remove the blank symbol id (assumed to be 0) and any stray eos id (2)
        token_int = list(filter(lambda x: x not in (0, 2), token_int))

        # change integer ids to tokens
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num]
        return token

    def cif_search(self, hidden, alphas, cache=None):
        # Continuous Integrate-and-Fire: accumulate per-frame weights (alphas)
        # and emit one acoustic embedding each time the sum crosses the
        # firing threshold.
        batch_size, len_time, hidden_size = hidden.shape
        token_length = []
        list_fires = []
        list_frames = []
        cache_alphas = []
        cache_hiddens = []
        # ignore the left/right context frames of the overlap window
        alphas[:, : self.chunk_size[0]] = 0.0
        alphas[:, sum(self.chunk_size[:2]) :] = 0.0
        if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
            hidden = np.concatenate((cache["cif_hidden"], hidden), axis=1)
            alphas = np.concatenate((cache["cif_alphas"], alphas), axis=1)
        if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
            # append a tail frame so a partially integrated token still fires
            tail_hidden = np.zeros((batch_size, 1, hidden_size)).astype(np.float32)
            tail_alphas = np.array([[self.tail_threshold]]).astype(np.float32)
            tail_alphas = np.tile(tail_alphas, (batch_size, 1))
            hidden = np.concatenate((hidden, tail_hidden), axis=1)
            alphas = np.concatenate((alphas, tail_alphas), axis=1)

        len_time = alphas.shape[1]
        for b in range(batch_size):
            integrate = 0.0
            frames = np.zeros(hidden_size).astype(np.float32)
            list_frame = []
            list_fire = []
            for t in range(len_time):
                alpha = alphas[b][t]
                if alpha + integrate < self.cif_threshold:
                    # keep integrating below the threshold
                    integrate += alpha
                    list_fire.append(integrate)
                    frames += alpha * hidden[b][t]
                else:
                    # fire: emit the integrated frame and carry the remainder
                    frames += (self.cif_threshold - integrate) * hidden[b][t]
                    list_frame.append(frames)
                    integrate += alpha
                    list_fire.append(integrate)
                    integrate -= self.cif_threshold
                    frames = integrate * hidden[b][t]

            # carry the partially integrated state over to the next chunk
            cache_alphas.append(integrate)
            if integrate > 0.0:
                cache_hiddens.append(frames / integrate)
            else:
                cache_hiddens.append(frames)

            token_length.append(len(list_frame))
            list_fires.append(list_fire)
            list_frames.append(list_frame)

        # pad every batch entry to the longest token sequence
        max_token_len = max(token_length)
        list_ls = []
        for b in range(batch_size):
            pad_frames = np.zeros((max_token_len - token_length[b], hidden_size)).astype(np.float32)
            if token_length[b] == 0:
                list_ls.append(pad_frames)
            else:
                list_ls.append(np.concatenate((list_frames[b], pad_frames), axis=0))

        cache["cif_alphas"] = np.stack(cache_alphas, axis=0)
        cache["cif_alphas"] = np.expand_dims(cache["cif_alphas"], axis=0)
        cache["cif_hidden"] = np.stack(cache_hiddens, axis=0)
        cache["cif_hidden"] = np.expand_dims(cache["cif_hidden"], axis=0)

        return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(
            np.int32
        )
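

if __name__ == "__main__":
    # Minimal streaming sketch. The model directory and the 16 kHz mono wav
    # path are placeholders (assumptions), and the step of chunk_size[1] * 960
    # samples assumes the usual 60 ms frame stride at 16 kHz; adjust both for
    # your setup. Because of the relative imports above, run this as a module,
    # e.g. `python -m <your_package>.<this_module>`.
    model = Paraformer("/path/to/paraformer-online-onnx", chunk_size=[5, 10, 5])
    speech, _ = librosa.load("/path/to/test_16k.wav", sr=16000)

    step = model.chunk_size[1] * 960  # samples fed per streaming step
    param_dict = {"cache": dict()}
    text = ""
    for start in range(0, len(speech), step):
        chunk = speech[start : start + step]
        param_dict["is_final"] = start + step >= len(speech)
        for res in model(chunk, param_dict=param_dict):
            text += res["preds"][0]
    print(text)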