# Source: FunASR/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
# Listing metadata: 331 lines, 14 KiB, Python; revision dated 2024-05-18 15:50:56 +08:00.
# -*- encoding: utf-8 -*-
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import json
import copy
import librosa
import numpy as np
from .utils.utils import (
CharTokenizer,
Hypothesis,
ONNXRuntimeError,
OrtInferSession,
TokenIDConverter,
get_logger,
read_yaml,
)
from .utils.postprocess_utils import sentence_postprocess
from .utils.frontend import WavFrontendOnline, SinusoidalPositionEncoderOnline
# Module-level logger obtained from the package's get_logger helper.
# NOTE(review): this binds the name `logging`, which is also the name of the
# stdlib module; the stdlib module is not imported in this file.
logging = get_logger()
class Paraformer:
    """Streaming Paraformer recognizer backed by two ONNX sessions.

    Audio is pushed chunk by chunk through ``__call__``. Everything that has
    to survive between chunks (feature overlap, integrate-and-fire remainder,
    decoder FSMN memories, position offset) lives in a caller-owned ``cache``
    dict that is created by ``prepare_cache`` and updated in place.
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        chunk_size: List = [5, 10, 5],
        device_id: Union[str, int] = "-1",
        quantize: bool = False,
        intra_op_num_threads: int = 4,
        cache_dir: str = None,
        **kwargs,
    ):
        """Locate (or download / export) the model files and build the sessions.

        Args:
            model_dir: Local model directory, or a name that is handed to
                modelscope's ``snapshot_download`` when the path does not
                exist. ``Path(model_dir)`` is evaluated unconditionally, so
                leaving the ``None`` default raises ``TypeError``.
            batch_size: Stored on the instance as ``self.batch_size``.
            chunk_size: Three frame counts. The code keeps
                ``chunk_size[0] + chunk_size[2]`` frames of overlap and lets
                tokens fire only inside the middle ``chunk_size[1]`` frames
                (presumably [left context, chunk, look-ahead] - TODO confirm).
            device_id: Forwarded to ``OrtInferSession``.
            quantize: Use ``model_quant.onnx`` / ``decoder_quant.onnx``
                instead of ``model.onnx`` / ``decoder.onnx``.
            intra_op_num_threads: Forwarded to ``OrtInferSession``.
            cache_dir: Forwarded to ``snapshot_download``.
            **kwargs: Forwarded to ``AutoModel.export`` when an export is needed.

        Raises:
            ImportError: ``modelscope`` or ``funasr`` is required but missing.
            ValueError: ``model_dir`` could not be resolved through modelscope.
        """
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except ImportError as err:
                # The original raised a bare string, which is itself a
                # TypeError in Python 3 and hid this message from the user.
                raise ImportError(
                    "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n"
                    "\npip3 install -U modelscope\n"
                    "For the users in China, you could install with the command:\n"
                    "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                ) from err
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except Exception as err:
                raise ValueError(
                    "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
                        model_dir
                    )
                ) from err

        encoder_model_file, decoder_model_file = self._model_files(model_dir, quantize)
        if not (os.path.exists(encoder_model_file) and os.path.exists(decoder_model_file)):
            print(".onnx is not exist, begin to export onnx")
            try:
                from funasr import AutoModel
            except ImportError as err:
                raise ImportError(
                    "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n"
                    "\npip3 install -U funasr\n"
                    "For the users in China, you could install with the command:\n"
                    "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                ) from err
            model = AutoModel(model=model_dir)
            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
            # config/cmvn/tokens below are read from the directory returned by
            # export, so the .onnx paths must follow the same directory.
            encoder_model_file, decoder_model_file = self._model_files(model_dir, quantize)

        config_file = os.path.join(model_dir, "config.yaml")
        cmvn_file = os.path.join(model_dir, "am.mvn")
        config = read_yaml(config_file)
        token_list = os.path.join(model_dir, "tokens.json")
        with open(token_list, "r", encoding="utf-8") as f:
            token_list = json.load(f)

        self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontendOnline(cmvn_file=cmvn_file, **config["frontend_conf"])
        self.pe = SinusoidalPositionEncoderOnline()
        self.ort_encoder_infer = OrtInferSession(
            encoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads
        )
        self.ort_decoder_infer = OrtInferSession(
            decoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads
        )

        self.batch_size = batch_size
        # Copy so that instances never alias the shared default list object.
        self.chunk_size = list(chunk_size)
        self.encoder_output_size = config["encoder_conf"]["output_size"]
        self.fsmn_layer = config["decoder_conf"]["num_blocks"]
        self.fsmn_lorder = config["decoder_conf"]["kernel_size"] - 1
        self.fsmn_dims = config["encoder_conf"]["output_size"]
        self.feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
        self.cif_threshold = config["predictor_conf"]["threshold"]
        self.tail_threshold = config["predictor_conf"]["tail_threshold"]

    @staticmethod
    def _model_files(model_dir, quantize: bool) -> Tuple[str, str]:
        """Return the (encoder, decoder) ONNX file paths inside ``model_dir``."""
        suffix = "_quant" if quantize else ""
        return (
            os.path.join(model_dir, f"model{suffix}.onnx"),
            os.path.join(model_dir, f"decoder{suffix}.onnx"),
        )

    def prepare_cache(self, cache: dict = None, batch_size=1):
        """Initialise the streaming state dict, or return it if already filled.

        Args:
            cache: State dict. A non-empty dict is returned untouched; an
                empty dict is populated *in place* (callers that keep a
                reference see the new keys); ``None`` creates a new dict.
            batch_size: Leading dimension of the zero-initialised arrays.

        Returns:
            The populated cache dict.
        """
        if cache is None:
            # A fresh dict per call; a ``{}`` default would be shared and
            # would leak the previous stream's state into the next call.
            cache = {}
        if len(cache) > 0:
            return cache

        cache["start_idx"] = 0
        cache["cif_hidden"] = np.zeros((batch_size, 1, self.encoder_output_size)).astype(np.float32)
        cache["cif_alphas"] = np.zeros((batch_size, 1)).astype(np.float32)
        cache["chunk_size"] = self.chunk_size
        cache["last_chunk"] = False
        cache["feats"] = np.zeros(
            (batch_size, self.chunk_size[0] + self.chunk_size[2], self.feats_dims)
        ).astype(np.float32)
        cache["decoder_fsmn"] = [
            np.zeros((batch_size, self.fsmn_dims, self.fsmn_lorder)).astype(np.float32)
            for _ in range(self.fsmn_layer)
        ]
        return cache

    def add_overlap_chunk(self, feats: np.ndarray, cache: dict = None):
        """Prepend the cached frames to ``feats`` and refresh the cached frames.

        Args:
            feats: New frames; concatenated with ``cache["feats"]`` on axis 1.
            cache: Streaming state with ``"feats"``, ``"is_final"`` and
                ``"last_chunk"``. When missing or empty, ``feats`` is
                returned unchanged.

        Returns:
            The concatenated frames. On a final-but-not-last chunk they are
            zero-padded on axis 1 up to ``sum(self.chunk_size)`` frames.
        """
        if not cache:
            return feats

        overlap_feats = np.concatenate((cache["feats"], feats), axis=1)
        if cache["is_final"]:
            # On the final input only the last chunk_size[0] frames are kept.
            cache["feats"] = overlap_feats[:, -self.chunk_size[0] :, :]
            if not cache["last_chunk"]:
                padding_length = sum(self.chunk_size) - overlap_feats.shape[1]
                overlap_feats = np.pad(overlap_feats, ((0, 0), (0, padding_length), (0, 0)))
        else:
            cache["feats"] = overlap_feats[:, -(self.chunk_size[0] + self.chunk_size[2]) :, :]
        return overlap_feats

    def __call__(self, audio_in: np.ndarray, **kwargs):
        """Recognise one chunk of a stream.

        Args:
            audio_in: Waveform samples of the chunk; a batch axis is added in
                front before processing.
            **kwargs: ``param_dict`` may carry ``"is_final"`` (bool, default
                ``False``) and ``"cache"`` (state dict, updated in place).
                Without a ``"cache"`` entry a throw-away dict is used, so no
                state is carried over to the next call.

        Returns:
            A list of ``{"preds": ...}`` dicts as produced by ``infer``
            (possibly empty). When the final input has to be split into two
            chunks, a one-element list with both results merged is returned.
        """
        waveforms = np.expand_dims(audio_in, axis=0)
        param_dict = kwargs.get("param_dict", dict())
        is_final = param_dict.get("is_final", False)
        cache = param_dict.get("cache", dict())
        asr_res = []

        # Final call with fewer than 16 * 60 samples (presumably 60 ms at
        # 16 kHz - TODO confirm): skip feature extraction and just flush the
        # frames that are still sitting in the cache.
        if waveforms.shape[1] < 16 * 60 and is_final and len(cache) > 0:
            cache["last_chunk"] = True
            feats = cache["feats"]
            feats_len = np.array([feats.shape[1]]).astype(np.int32)
            asr_res = self.infer(feats, feats_len, cache)
            return asr_res

        feats, feats_len = self.extract_feat(waveforms, is_final)
        if feats.shape[1] != 0:
            feats *= self.encoder_output_size**0.5
            cache = self.prepare_cache(cache)
            cache["is_final"] = is_final

            # fbank -> position encoding -> overlap chunk
            feats = self.pe.forward(feats, cache["start_idx"])
            cache["start_idx"] += feats.shape[1]
            if is_final:
                if feats.shape[1] + self.chunk_size[2] <= self.chunk_size[1]:
                    # Everything left fits into a single, last chunk.
                    cache["last_chunk"] = True
                    feats = self.add_overlap_chunk(feats, cache)
                else:
                    # first chunk
                    feats_chunk1 = self.add_overlap_chunk(feats[:, : self.chunk_size[1], :], cache)
                    feats_len = np.array([feats_chunk1.shape[1]]).astype(np.int32)
                    asr_res_chunk1 = self.infer(feats_chunk1, feats_len, cache)

                    # last chunk
                    cache["last_chunk"] = True
                    feats_chunk2 = self.add_overlap_chunk(
                        feats[:, -(feats.shape[1] + self.chunk_size[2] - self.chunk_size[1]) :, :],
                        cache,
                    )
                    feats_len = np.array([feats_chunk2.shape[1]]).astype(np.int32)
                    asr_res_chunk2 = self.infer(feats_chunk2, feats_len, cache)

                    # Merge both partial results key by key: element 0 is
                    # combined with +=, element 1 is extended as a list.
                    asr_res_chunk = asr_res_chunk1 + asr_res_chunk2
                    res = {}
                    for pred in asr_res_chunk:
                        for key, value in pred.items():
                            if key in res:
                                res[key][0] += value[0]
                                res[key][1].extend(value[1])
                            else:
                                res[key] = [value[0], value[1]]
                    return [res]
            else:
                feats = self.add_overlap_chunk(feats, cache)

            feats_len = np.array([feats.shape[1]]).astype(np.int32)
            asr_res = self.infer(feats, feats_len, cache)

        return asr_res

    def infer(self, feats: np.ndarray, feats_len: np.ndarray, cache):
        """Run encoder, integrate-and-fire search and decoder on one chunk.

        Args:
            feats: Chunk features fed to the encoder session.
            feats_len: Frame count(s) of ``feats``.
            cache: Streaming state; ``"decoder_fsmn"`` and the cif entries
                are updated in place.

        Returns:
            A list with one ``{"preds": ...}`` dict per decoded sequence;
            empty when no token fired in this chunk.
        """
        # encoder forward
        enc_input = [feats, feats_len]
        enc, enc_lens, cif_alphas = self.ort_encoder_infer(enc_input)

        # predictor forward
        acoustic_embeds, acoustic_embeds_len = self.cif_search(enc, cif_alphas, cache)

        # decoder forward
        asr_res = []
        if acoustic_embeds.shape[1] > 0:
            dec_input = [enc, enc_lens, acoustic_embeds, acoustic_embeds_len]
            dec_input.extend(cache["decoder_fsmn"])
            dec_output = self.ort_decoder_infer(dec_input)
            logits, sample_ids, cache["decoder_fsmn"] = dec_output[0], dec_output[1], dec_output[2:]
            # Keep only the trailing fsmn_lorder steps of each memory.
            cache["decoder_fsmn"] = [
                item[:, :, -self.fsmn_lorder :] for item in cache["decoder_fsmn"]
            ]

            preds = self.decode(logits, acoustic_embeds_len)
            for pred in preds:
                pred = sentence_postprocess(pred)
                asr_res.append({"preds": pred})

        return asr_res

    def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        """Return a list of waveforms from an array, a path, or a list of paths.

        Args:
            wav_content: A numpy array (returned as-is inside a list), a file
                path, or a list of file paths.
            fs: Passed to ``librosa.load`` as ``sr``.

        Raises:
            TypeError: ``wav_content`` is none of the supported types.
        """

        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]

        if isinstance(wav_content, str):
            return [load_wav(wav_content)]

        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]

        raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")

    def extract_feat(
        self, waveforms: np.ndarray, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute fbank features for a batch of waveforms via the online frontend.

        Returns:
            ``(feats, feats_len)`` cast to float32 and int32 respectively.
        """
        waveforms_lens = np.array(
            [waveform.shape[-1] for waveform in waveforms], dtype=np.int32
        )
        feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
        return feats.astype(np.float32), feats_len.astype(np.int32)

    def decode(self, am_scores: np.ndarray, token_nums: np.ndarray) -> List[List[str]]:
        """Decode every (scores, valid-token-count) pair with ``decode_one``."""
        return [
            self.decode_one(am_score, token_num)
            for am_score, token_num in zip(am_scores, token_nums)
        ]

    def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
        """Greedy-decode one sequence of scores into tokens.

        Args:
            am_score: Scores for one sequence; the argmax over the last axis
                gives the token id per step.
            valid_token_num: Number of leading tokens to keep.

        Returns:
            At most ``valid_token_num`` tokens, with ids 0 and 2 dropped.
        """
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)

        # pad with mask tokens to ensure compatibility with sos/eos tokens
        # asr_model.sos:1  asr_model.eos:2
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)

        # remove sos/eos and get results
        last_pos = -1
        token_int = hyp.yseq[1:last_pos].tolist()

        # remove blank symbol id (assumed to be 0) and any id 2 (eos above)
        token_int = [token_id for token_id in token_int if token_id not in (0, 2)]

        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num]
        return token

    def cif_search(self, hidden, alphas, cache=None):
        """Accumulate ``alphas`` over time and emit one embedding per threshold crossing.

        Args:
            hidden: Encoder output of shape (batch, time, hidden_size).
            alphas: Per-frame weights of shape (batch, time). Modified in
                place: weights outside the middle chunk
                (``chunk_size[0] : chunk_size[0] + chunk_size[1]``) are zeroed.
            cache: Optional streaming state. When it holds ``"cif_alphas"``
                and ``"cif_hidden"`` they are prepended; when
                ``"last_chunk"`` is true a tail frame with weight
                ``self.tail_threshold`` is appended. The un-fired remainder
                is written back as ``(batch, 1)`` / ``(batch, 1, hidden_size)``.
                With ``None`` nothing is read or written.

        Returns:
            ``(embeds, token_len)``: float32 array (batch, max_tokens,
            hidden_size), zero-padded per row, and int32 array (batch,).
        """
        batch_size, len_time, hidden_size = hidden.shape
        token_length = []
        list_frames = []
        cache_alphas = []
        cache_hiddens = []

        # Only frames of the middle chunk may contribute weight.
        alphas[:, : self.chunk_size[0]] = 0.0
        alphas[:, sum(self.chunk_size[:2]) :] = 0.0

        if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
            hidden = np.concatenate((cache["cif_hidden"], hidden), axis=1)
            alphas = np.concatenate((cache["cif_alphas"], alphas), axis=1)
        if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
            tail_hidden = np.zeros((batch_size, 1, hidden_size)).astype(np.float32)
            tail_alphas = np.array([[self.tail_threshold]]).astype(np.float32)
            tail_alphas = np.tile(tail_alphas, (batch_size, 1))
            hidden = np.concatenate((hidden, tail_hidden), axis=1)
            alphas = np.concatenate((alphas, tail_alphas), axis=1)

        len_time = alphas.shape[1]
        for b in range(batch_size):
            integrate = 0.0
            frames = np.zeros(hidden_size).astype(np.float32)
            list_frame = []
            for t in range(len_time):
                alpha = alphas[b][t]
                if alpha + integrate < self.cif_threshold:
                    integrate += alpha
                    frames += alpha * hidden[b][t]
                else:
                    # Fire: top the accumulator up to the threshold, emit it,
                    # and start the next one with the left-over weight.
                    frames += (self.cif_threshold - integrate) * hidden[b][t]
                    list_frame.append(frames)
                    integrate += alpha
                    integrate -= self.cif_threshold
                    frames = integrate * hidden[b][t]

            cache_alphas.append(integrate)
            if integrate > 0.0:
                cache_hiddens.append(frames / integrate)
            else:
                cache_hiddens.append(frames)

            token_length.append(len(list_frame))
            list_frames.append(list_frame)

        max_token_len = max(token_length)
        list_ls = []
        for b in range(batch_size):
            pad_frames = np.zeros((max_token_len - token_length[b], hidden_size)).astype(np.float32)
            if token_length[b] == 0:
                list_ls.append(pad_frames)
            else:
                list_ls.append(np.concatenate((list_frames[b], pad_frames), axis=0))

        if cache is not None:
            # Insert the time axis at position 1 so the shapes match what is
            # concatenated on axis=1 above: (batch, 1) and (batch, 1, hidden).
            cache["cif_alphas"] = np.expand_dims(np.stack(cache_alphas, axis=0), axis=1)
            cache["cif_hidden"] = np.expand_dims(np.stack(cache_hiddens, axis=0), axis=1)

        return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(
            np.int32
        )