FunASR/funasr/models/fsmn_vad_streaming/model.py

#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import os
import json
import time
import math
import torch
from torch import nn
from enum import Enum
from dataclasses import dataclass
from funasr.register import tables
from typing import List, Tuple, Dict, Any, Optional
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
class VadStateMachine(Enum):
kVadInStateStartPointNotDetected = 1
kVadInStateInSpeechSegment = 2
kVadInStateEndPointDetected = 3
class FrameState(Enum):
kFrameStateInvalid = -1
kFrameStateSpeech = 1
kFrameStateSil = 0
# final voice/unvoice state per frame
class AudioChangeState(Enum):
kChangeStateSpeech2Speech = 0
kChangeStateSpeech2Sil = 1
kChangeStateSil2Sil = 2
kChangeStateSil2Speech = 3
kChangeStateNoBegin = 4
kChangeStateInvalid = 5
class VadDetectMode(Enum):
kVadSingleUtteranceDetectMode = 0
kVadMutipleUtteranceDetectMode = 1
class VADXOptions:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(
self,
sample_rate: int = 16000,
detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
snr_mode: int = 0,
max_end_silence_time: int = 800,
max_start_silence_time: int = 3000,
do_start_point_detection: bool = True,
do_end_point_detection: bool = True,
window_size_ms: int = 200,
sil_to_speech_time_thres: int = 150,
speech_to_sil_time_thres: int = 150,
speech_2_noise_ratio: float = 1.0,
do_extend: int = 1,
lookback_time_start_point: int = 200,
lookahead_time_end_point: int = 100,
max_single_segment_time: int = 60000,
nn_eval_block_size: int = 8,
dcd_block_size: int = 4,
snr_thres: float = -100.0,
noise_frame_num_used_for_snr: int = 100,
decibel_thres: float = -100.0,
speech_noise_thres: float = 0.6,
fe_prior_thres: float = 1e-4,
silence_pdf_num: int = 1,
sil_pdf_ids: List[int] = [0],
speech_noise_thresh_low: float = -0.1,
speech_noise_thresh_high: float = 0.3,
output_frame_probs: bool = False,
frame_in_ms: int = 10,
frame_length_ms: int = 25,
**kwargs,
):
self.sample_rate = sample_rate
self.detect_mode = detect_mode
self.snr_mode = snr_mode
self.max_end_silence_time = max_end_silence_time
self.max_start_silence_time = max_start_silence_time
self.do_start_point_detection = do_start_point_detection
self.do_end_point_detection = do_end_point_detection
self.window_size_ms = window_size_ms
self.sil_to_speech_time_thres = sil_to_speech_time_thres
self.speech_to_sil_time_thres = speech_to_sil_time_thres
self.speech_2_noise_ratio = speech_2_noise_ratio
self.do_extend = do_extend
self.lookback_time_start_point = lookback_time_start_point
self.lookahead_time_end_point = lookahead_time_end_point
self.max_single_segment_time = max_single_segment_time
self.nn_eval_block_size = nn_eval_block_size
self.dcd_block_size = dcd_block_size
self.snr_thres = snr_thres
self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
self.decibel_thres = decibel_thres
self.speech_noise_thres = speech_noise_thres
self.fe_prior_thres = fe_prior_thres
self.silence_pdf_num = silence_pdf_num
self.sil_pdf_ids = sil_pdf_ids
self.speech_noise_thresh_low = speech_noise_thresh_low
self.speech_noise_thresh_high = speech_noise_thresh_high
self.output_frame_probs = output_frame_probs
self.frame_in_ms = frame_in_ms
self.frame_length_ms = frame_length_ms
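# Note: the timing options above are given in milliseconds and are converted to frame
# counts via frame_in_ms (the 10 ms frame shift). With the defaults, for example,
# WindowDetector spans window_size_ms / frame_in_ms = 200 / 10 = 20 frames, and the
# sil-to-speech vote threshold is sil_to_speech_time_thres / frame_in_ms = 150 / 10 = 15 frames.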
class E2EVadSpeechBufWithDoa(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self):
self.start_ms = 0
self.end_ms = 0
self.buffer = []
self.contain_seg_start_point = False
self.contain_seg_end_point = False
self.doa = 0
def Reset(self):
self.start_ms = 0
self.end_ms = 0
self.buffer = []
self.contain_seg_start_point = False
self.contain_seg_end_point = False
self.doa = 0
class E2EVadFrameProb(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self):
self.noise_prob = 0.0
self.speech_prob = 0.0
self.score = 0.0
self.frame_id = 0
self.frm_state = 0
class WindowDetector(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(
self,
window_size_ms: int,
sil_to_speech_time: int,
speech_to_sil_time: int,
frame_size_ms: int,
):
self.window_size_ms = window_size_ms
self.sil_to_speech_time = sil_to_speech_time
self.speech_to_sil_time = speech_to_sil_time
self.frame_size_ms = frame_size_ms
self.win_size_frame = int(window_size_ms / frame_size_ms)
self.win_sum = 0
self.win_state = [0] * self.win_size_frame  # initialize the sliding window state
self.cur_win_pos = 0
self.pre_frame_state = FrameState.kFrameStateSil
self.cur_frame_state = FrameState.kFrameStateSil
self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
self.voice_last_frame_count = 0
self.noise_last_frame_count = 0
self.hydre_frame_count = 0
def Reset(self) -> None:
self.cur_win_pos = 0
self.win_sum = 0
self.win_state = [0] * self.win_size_frame
self.pre_frame_state = FrameState.kFrameStateSil
self.cur_frame_state = FrameState.kFrameStateSil
self.voice_last_frame_count = 0
self.noise_last_frame_count = 0
self.hydre_frame_count = 0
def GetWinSize(self) -> int:
return int(self.win_size_frame)
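# DetectOneFrame keeps a ring buffer (win_state) of the last win_size_frame per-frame
# speech/silence decisions; win_sum is the count of speech frames currently inside the
# window. A Sil2Speech transition is reported once win_sum reaches
# sil_to_speech_frmcnt_thres, and Speech2Sil once win_sum falls to
# speech_to_sil_frmcnt_thres or below, which smooths out isolated frame flips.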
def DetectOneFrame(
self, frameState: FrameState, frame_count: int, cache: dict = {}
) -> AudioChangeState:
cur_frame_state = FrameState.kFrameStateSil
if frameState == FrameState.kFrameStateSpeech:
cur_frame_state = 1
elif frameState == FrameState.kFrameStateSil:
cur_frame_state = 0
else:
return AudioChangeState.kChangeStateInvalid
self.win_sum -= self.win_state[self.cur_win_pos]
self.win_sum += cur_frame_state
self.win_state[self.cur_win_pos] = cur_frame_state
self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
if (
self.pre_frame_state == FrameState.kFrameStateSil
and self.win_sum >= self.sil_to_speech_frmcnt_thres
):
self.pre_frame_state = FrameState.kFrameStateSpeech
return AudioChangeState.kChangeStateSil2Speech
if (
self.pre_frame_state == FrameState.kFrameStateSpeech
and self.win_sum <= self.speech_to_sil_frmcnt_thres
):
self.pre_frame_state = FrameState.kFrameStateSil
return AudioChangeState.kChangeStateSpeech2Sil
if self.pre_frame_state == FrameState.kFrameStateSil:
return AudioChangeState.kChangeStateSil2Sil
if self.pre_frame_state == FrameState.kFrameStateSpeech:
return AudioChangeState.kChangeStateSpeech2Speech
return AudioChangeState.kChangeStateInvalid
def FrameSizeMs(self) -> int:
return int(self.frame_size_ms)
class Stats(object):
def __init__(
self,
sil_pdf_ids,
max_end_sil_frame_cnt_thresh,
speech_noise_thres,
):
self.data_buf_start_frame = 0
self.frm_cnt = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.continous_silence_frame_count = 0
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.number_end_time_detected = 0
self.sil_frame = 0
self.sil_pdf_ids = sil_pdf_ids
self.noise_average_decibel = -100.0
self.pre_end_silence_detected = False
self.next_seg = True
self.output_data_buf = []
self.output_data_buf_offset = 0
self.frame_probs = []
self.max_end_sil_frame_cnt_thresh = max_end_sil_frame_cnt_thresh
self.speech_noise_thres = speech_noise_thres
self.scores = None
self.max_time_out = False
self.decibel = []
self.data_buf = None
self.data_buf_all = None
self.waveform = None
self.last_drop_frames = 0
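# Stats bundles all per-stream state: frame counters, decibel and score buffers, the
# current VadStateMachine state, and the pending output segments. One instance lives in
# cache["stats"] and is recreated by FsmnVADStreaming.init_cache() for each new stream.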
@tables.register("model_classes", "FsmnVADStreaming")
class FsmnVADStreaming(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(
self,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
vad_post_args: Dict[str, Any] = None,
**kwargs,
):
super().__init__()
self.vad_opts = VADXOptions(**kwargs)
encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(**encoder_conf)
self.encoder = encoder
self.encoder_conf = encoder_conf
def ResetDetection(self, cache: dict = {}):
cache["stats"].continous_silence_frame_count = 0
cache["stats"].latest_confirmed_speech_frame = 0
cache["stats"].lastest_confirmed_silence_frame = -1
cache["stats"].confirmed_start_frame = -1
cache["stats"].confirmed_end_frame = -1
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
cache["windows_detector"].Reset()
cache["stats"].sil_frame = 0
cache["stats"].frame_probs = []
if cache["stats"].output_data_buf:
assert cache["stats"].output_data_buf[-1].contain_seg_end_point == True
drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
real_drop_frames = drop_frames - cache["stats"].last_drop_frames
cache["stats"].last_drop_frames = drop_frames
cache["stats"].data_buf_all = cache["stats"].data_buf_all[
real_drop_frames
* int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) :
]
cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
def ComputeDecibel(self, cache: dict = {}) -> None:
frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
if cache["stats"].data_buf_all is None:
cache["stats"].data_buf_all = cache["stats"].waveform[
0
]  # cache["stats"].data_buf points to cache["stats"].waveform[0]
cache["stats"].data_buf = cache["stats"].data_buf_all
else:
cache["stats"].data_buf_all = torch.cat(
(cache["stats"].data_buf_all, cache["stats"].waveform[0])
)
for offset in range(
0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length
):
cache["stats"].decibel.append(
10
* math.log10(
(cache["stats"].waveform[0][offset : offset + frame_sample_length])
.square()
.sum()
+ 0.000001
)
)
def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
scores = self.encoder(feats, cache=cache["encoder"]).to("cpu") # return B * T * D
assert (
scores.shape[1] == feats.shape[1]
), "The shape between feats and scores does not match"
self.vad_opts.nn_eval_block_size = scores.shape[1]
cache["stats"].frm_cnt += scores.shape[1] # count total frames
if cache["stats"].scores is None:
cache["stats"].scores = scores # the first calculation
else:
cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None: # need check again
while cache["stats"].data_buf_start_frame < frame_idx:
if len(cache["stats"].data_buf) >= int(
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
):
cache["stats"].data_buf_start_frame += 1
cache["stats"].data_buf = cache["stats"].data_buf_all[
(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames)
* int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) :
]
def PopDataToOutputBuf(
self,
start_frm: int,
frm_cnt: int,
first_frm_is_start_point: bool,
last_frm_is_end_point: bool,
end_point_is_sent_end: bool,
cache: dict = {},
) -> None:
self.PopDataBufTillFrame(start_frm, cache=cache)
expected_sample_number = int(
frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
)
if last_frm_is_end_point:
extra_sample = max(
0,
int(
self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
- self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
),
)
expected_sample_number += int(extra_sample)
if end_point_is_sent_end:
expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf))
if len(cache["stats"].data_buf) < expected_sample_number:
print("error in calling pop data_buf\n")
if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point:
cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa())
cache["stats"].output_data_buf[-1].Reset()
cache["stats"].output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
cache["stats"].output_data_buf[-1].end_ms = cache["stats"].output_data_buf[-1].start_ms
cache["stats"].output_data_buf[-1].doa = 0
cur_seg = cache["stats"].output_data_buf[-1]
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
print("warning\n")
out_pos = len(cur_seg.buffer) # cur_seg.buff现在没做任何操作
data_to_pop = 0
if end_point_is_sent_end:
data_to_pop = expected_sample_number
else:
data_to_pop = int(
frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
)
if data_to_pop > len(cache["stats"].data_buf):
print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n')
data_to_pop = len(cache["stats"].data_buf)
expected_sample_number = len(cache["stats"].data_buf)
cur_seg.doa = 0
for sample_cpy_out in range(0, data_to_pop):
# cur_seg.buffer[out_pos ++] = data_buf_.back();
out_pos += 1
for sample_cpy_out in range(data_to_pop, expected_sample_number):
# cur_seg.buffer[out_pos++] = data_buf_.back()
out_pos += 1
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
print("Something wrong with the VAD algorithm\n")
cache["stats"].data_buf_start_frame += frm_cnt
cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
if first_frm_is_start_point:
cur_seg.contain_seg_start_point = True
if last_frm_is_end_point:
cur_seg.contain_seg_end_point = True
def OnSilenceDetected(self, valid_frame: int, cache: dict = {}):
cache["stats"].lastest_confirmed_silence_frame = valid_frame
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
self.PopDataBufTillFrame(valid_frame, cache=cache)
# silence_detected_callback_
# pass
def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None:
cache["stats"].latest_confirmed_speech_frame = valid_frame
self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None:
if self.vad_opts.do_start_point_detection:
pass
if cache["stats"].confirmed_start_frame != -1:
print("not reset vad properly\n")
else:
cache["stats"].confirmed_start_frame = start_frame
if (
not fake_result
and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected
):
self.PopDataToOutputBuf(
cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache
)
def OnVoiceEnd(
self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}
) -> None:
for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
self.OnVoiceDetected(t, cache=cache)
if self.vad_opts.do_end_point_detection:
pass
if cache["stats"].confirmed_end_frame != -1:
print("not reset vad properly\n")
else:
cache["stats"].confirmed_end_frame = end_frame
if not fake_result:
cache["stats"].sil_frame = 0
self.PopDataToOutputBuf(
cache["stats"].confirmed_end_frame, 1, False, True, is_last_frame, cache=cache
)
cache["stats"].number_end_time_detected += 1
def MaybeOnVoiceEndIfLastFrame(
self, is_final_frame: bool, cur_frm_idx: int, cache: dict = {}
) -> None:
if is_final_frame:
self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
def GetLatency(self, cache: dict = {}) -> int:
return int(self.LatencyFrmNumAtStartPoint(cache=cache) * self.vad_opts.frame_in_ms)
def LatencyFrmNumAtStartPoint(self, cache: dict = {}) -> int:
vad_latency = cache["windows_detector"].GetWinSize()
if self.vad_opts.do_extend:
vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
return vad_latency
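# GetFrameState classifies a single frame. Frames below decibel_thres are silence.
# Otherwise the posteriors of the silence pdfs (sil_pdf_ids) are summed into a noise
# score (log-scaled by speech_2_noise_ratio) and the remainder (1 - sum) becomes the
# speech score; the frame counts as speech only if exp(speech_prob) >= exp(noise_prob)
# + speech_noise_thres and the SNR/decibel gates pass. The noise decibel average is
# updated as a running mean over noise_frame_num_used_for_snr frames.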
def GetFrameState(self, t: int, cache: dict = {}):
frame_state = FrameState.kFrameStateInvalid
cur_decibel = cache["stats"].decibel[t]
cur_snr = cur_decibel - cache["stats"].noise_average_decibel
# for each frame, calc log posterior probability of each state
if cur_decibel < self.vad_opts.decibel_thres:
frame_state = FrameState.kFrameStateSil
self.DetectOneFrame(frame_state, t, False, cache=cache)
return frame_state
sum_score = 0.0
noise_prob = 0.0
assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
if len(cache["stats"].sil_pdf_ids) > 0:
assert len(cache["stats"].scores) == 1  # only batch_size = 1 is supported
sil_pdf_scores = [
cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids
]
sum_score = sum(sil_pdf_scores)
noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
total_score = 1.0
sum_score = total_score - sum_score
speech_prob = math.log(sum_score)
if self.vad_opts.output_frame_probs:
frame_prob = E2EVadFrameProb()
frame_prob.noise_prob = noise_prob
frame_prob.speech_prob = speech_prob
frame_prob.score = sum_score
frame_prob.frame_id = t
cache["stats"].frame_probs.append(frame_prob)
if math.exp(speech_prob) >= math.exp(noise_prob) + cache["stats"].speech_noise_thres:
if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
frame_state = FrameState.kFrameStateSpeech
else:
frame_state = FrameState.kFrameStateSil
else:
frame_state = FrameState.kFrameStateSil
if cache["stats"].noise_average_decibel < -99.9:
cache["stats"].noise_average_decibel = cur_decibel
else:
cache["stats"].noise_average_decibel = (
cur_decibel
+ cache["stats"].noise_average_decibel
* (self.vad_opts.noise_frame_num_used_for_snr - 1)
) / self.vad_opts.noise_frame_num_used_for_snr
return frame_state
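# forward() consumes one chunk of features/waveform, updates the per-stream cache, and
# returns the segments confirmed so far. For streaming input a segment may be reported
# incrementally as [beg, -1] (start found, end pending) and later [-1, end]; for
# non-streaming input only complete [beg, end] pairs are returned.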
def forward(
self,
feats: torch.Tensor,
waveform: torch.Tensor,
cache: dict = {},
is_final: bool = False,
**kwargs,
):
# if len(cache) == 0:
# self.AllResetDetection()
# self.waveform = waveform # compute decibel for each frame
cache["stats"].waveform = waveform
is_streaming_input = kwargs.get("is_streaming_input", True)
self.ComputeDecibel(cache=cache)
self.ComputeScores(feats, cache=cache)
if not is_final:
self.DetectCommonFrames(cache=cache)
else:
self.DetectLastFrames(cache=cache)
segments = []
for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now
segment_batch = []
if len(cache["stats"].output_data_buf) > 0:
for i in range(
cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)
):
if (
is_streaming_input
): # in this case, return [beg, -1], [], [-1, end], [beg, end]
if not cache["stats"].output_data_buf[i].contain_seg_start_point:
continue
if (
not cache["stats"].next_seg
and not cache["stats"].output_data_buf[i].contain_seg_end_point
):
continue
start_ms = (
cache["stats"].output_data_buf[i].start_ms
if cache["stats"].next_seg
else -1
)
if cache["stats"].output_data_buf[i].contain_seg_end_point:
end_ms = cache["stats"].output_data_buf[i].end_ms
cache["stats"].next_seg = True
cache["stats"].output_data_buf_offset += 1
else:
end_ms = -1
cache["stats"].next_seg = False
segment = [start_ms, end_ms]
else: # in this case, return [beg, end]
if not is_final and (
not cache["stats"].output_data_buf[i].contain_seg_start_point
or not cache["stats"].output_data_buf[i].contain_seg_end_point
):
continue
segment = [
cache["stats"].output_data_buf[i].start_ms,
cache["stats"].output_data_buf[i].end_ms,
]
cache["stats"].output_data_buf_offset += 1 # need update this parameter
segment_batch.append(segment)
if segment_batch:
segments.append(segment_batch)
# if is_final:
# # reset class variables and clear the dict for the next query
# self.AllResetDetection()
return segments
def init_cache(self, cache: dict = {}, **kwargs):
cache["frontend"] = {}
cache["prev_samples"] = torch.empty(0)
cache["encoder"] = {}
if kwargs.get("max_end_silence_time") is not None:
# update the max_end_silence_time
self.vad_opts.max_end_silence_time = kwargs.get("max_end_silence_time")
windows_detector = WindowDetector(
self.vad_opts.window_size_ms,
self.vad_opts.sil_to_speech_time_thres,
self.vad_opts.speech_to_sil_time_thres,
self.vad_opts.frame_in_ms,
)
windows_detector.Reset()
stats = Stats(
sil_pdf_ids=self.vad_opts.sil_pdf_ids,
max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time
- self.vad_opts.speech_to_sil_time_thres,
speech_noise_thres=self.vad_opts.speech_noise_thres,
)
cache["windows_detector"] = windows_detector
cache["stats"] = stats
return cache
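# inference() splits the incoming audio into chunks of chunk_size ms
# (chunk_stride_samples = chunk_size * frontend.fs / 1000 samples), carries any leftover
# samples in cache["prev_samples"] for the next call, and resets the cache once the
# final chunk (is_final) has been processed.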
def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
cache: dict = {},
**kwargs,
):
if len(cache) == 0:
self.init_cache(cache, **kwargs)
meta_data = {}
chunk_size = kwargs.get("chunk_size", 60000)  # chunk size in ms (default 60000, i.e. non-streaming input)
chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
time1 = time.perf_counter()
is_streaming_input = (
kwargs.get("is_streaming_input", False)
if chunk_size >= 15000
else kwargs.get("is_streaming_input", True)
)
is_final = (
kwargs.get("is_final", False) if is_streaming_input else kwargs.get("is_final", True)
)
cfg = {"is_final": is_final, "is_streaming_input": is_streaming_input}
audio_sample_list = load_audio_text_image_video(
data_in,
fs=frontend.fs,
audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=tokenizer,
cache=cfg,
)
_is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True
is_streaming_input = cfg["is_streaming_input"]
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
assert len(audio_sample_list) == 1, "batch_size must be set 1"
audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
segments = []
for i in range(n):
kwargs["is_final"] = _is_final and i == n - 1
audio_sample_i = audio_sample[i * chunk_stride_samples : (i + 1) * chunk_stride_samples]
# extract fbank feats
speech, speech_lengths = extract_fbank(
[audio_sample_i],
data_type=kwargs.get("data_type", "sound"),
frontend=frontend,
cache=cache["frontend"],
is_final=kwargs["is_final"],
)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = (
speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
)
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
batch = {
"feats": speech,
"waveform": cache["frontend"]["waveforms"],
"is_final": kwargs["is_final"],
"cache": cache,
"is_streaming_input": is_streaming_input,
}
segments_i = self.forward(**batch)
if len(segments_i) > 0:
segments.extend(*segments_i)
cache["prev_samples"] = audio_sample[:-m]
if _is_final:
self.init_cache(cache)
ibest_writer = None
if kwargs.get("output_dir") is not None:
if not hasattr(self, "writer"):
self.writer = DatadirWriter(kwargs.get("output_dir"))
ibest_writer = self.writer[f"{1}best_recog"]
results = []
result_i = {"key": key[0], "value": segments}
# if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
# result_i = json.dumps(result_i)
results.append(result_i)
if ibest_writer is not None:
ibest_writer["text"][key[0]] = segments
return results, meta_data
def export(self, **kwargs):
from .export_meta import export_rebuild_model
models = export_rebuild_model(model=self, **kwargs)
return models
def DetectCommonFrames(self, cache: dict = {}) -> int:
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
frame_state = self.GetFrameState(
cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache
)
self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
return 0
def DetectLastFrames(self, cache: dict = {}) -> int:
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
frame_state = self.GetFrameState(
cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache
)
if i != 0:
self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
else:
self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache)
return 0
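# DetectOneFrame drives the segment state machine: on Sil2Speech a start point is
# confirmed (looking back LatencyFrmNumAtStartPoint frames); while in a speech segment,
# frames are flushed to the output buffer; an end point is confirmed after a long
# enough silence run (max_end_sil_frame_cnt_thresh), when max_single_segment_time is
# exceeded, or on the final frame of the stream.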
def DetectOneFrame(
self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}
) -> None:
tmp_cur_frm_state = FrameState.kFrameStateInvalid
if cur_frm_state == FrameState.kFrameStateSpeech:
# note: math.fabs(1.0) is a constant, so this branch is taken unless fe_prior_thres >= 1.0
if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
tmp_cur_frm_state = FrameState.kFrameStateSpeech
else:
tmp_cur_frm_state = FrameState.kFrameStateSil
elif cur_frm_state == FrameState.kFrameStateSil:
tmp_cur_frm_state = FrameState.kFrameStateSil
state_change = cache["windows_detector"].DetectOneFrame(
tmp_cur_frm_state, cur_frm_idx, cache=cache
)
frm_shift_in_ms = self.vad_opts.frame_in_ms
if AudioChangeState.kChangeStateSil2Speech == state_change:
silence_frame_count = cache["stats"].continous_silence_frame_count
cache["stats"].continous_silence_frame_count = 0
cache["stats"].pre_end_silence_detected = False
start_frame = 0
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
start_frame = max(
cache["stats"].data_buf_start_frame,
cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache),
)
self.OnVoiceStart(start_frame, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
for t in range(start_frame + 1, cur_frm_idx + 1):
self.OnVoiceDetected(t, cache=cache)
elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx):
self.OnVoiceDetected(t, cache=cache)
if (
cur_frm_idx - cache["stats"].confirmed_start_frame + 1
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
):
self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx, cache=cache)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
else:
pass
elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
cache["stats"].continous_silence_frame_count = 0
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
pass
elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if (
cur_frm_idx - cache["stats"].confirmed_start_frame + 1
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
):
self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx, cache=cache)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
else:
pass
elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
cache["stats"].continous_silence_frame_count = 0
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if (
cur_frm_idx - cache["stats"].confirmed_start_frame + 1
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
):
cache["stats"].max_time_out = True
self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx, cache=cache)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
else:
pass
elif AudioChangeState.kChangeStateSil2Sil == state_change:
cache["stats"].continous_silence_frame_count += 1
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
# silence timeout, return zero length decision
if (
(self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value)
and (
cache["stats"].continous_silence_frame_count * frm_shift_in_ms
> self.vad_opts.max_start_silence_time
)
) or (is_final_frame and cache["stats"].number_end_time_detected == 0):
for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
self.OnSilenceDetected(t, cache=cache)
self.OnVoiceStart(0, True, cache=cache)
self.OnVoiceEnd(0, True, False, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
else:
if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
self.OnSilenceDetected(
cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache
)
elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if (
cache["stats"].continous_silence_frame_count * frm_shift_in_ms
>= cache["stats"].max_end_sil_frame_cnt_thresh
):
lookback_frame = int(
cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms
)
if self.vad_opts.do_extend:
lookback_frame -= int(
self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
)
lookback_frame -= 1
lookback_frame = max(0, lookback_frame)
self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif (
cur_frm_idx - cache["stats"].confirmed_start_frame + 1
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
):
self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif self.vad_opts.do_extend and not is_final_frame:
if cache["stats"].continous_silence_frame_count <= int(
self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
):
self.OnVoiceDetected(cur_frm_idx, cache=cache)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
else:
pass
if (
cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected
and self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value
):
self.ResetDetection(cache=cache)
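# A minimal streaming usage sketch, kept out of the library path behind __main__. It
# assumes the standard FunASR AutoModel entry point and the soundfile package; the
# "example.wav" path is a placeholder, not part of this module.
if __name__ == "__main__":
    import soundfile
    from funasr import AutoModel

    vad = AutoModel(model="fsmn-vad")
    speech, sample_rate = soundfile.read("example.wav")  # placeholder input file
    chunk_size_ms = 200
    chunk_stride = int(chunk_size_ms * sample_rate / 1000)
    cache = {}
    total_chunks = int((len(speech) - 1) / chunk_stride) + 1
    for i in range(total_chunks):
        chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
        res = vad.generate(
            input=chunk,
            cache=cache,
            is_final=(i == total_chunks - 1),
            chunk_size=chunk_size_ms,
        )
        if len(res[0]["value"]) > 0:
            # segments arrive as [beg_ms, -1], [-1, end_ms], or [beg_ms, end_ms]
            print(res[0]["value"])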