From d930e71410ada143d61b8156d23acee437063034 Mon Sep 17 00:00:00 2001
From: IrvingGao <1729854488@qq.com>
Date: Wed, 15 May 2024 21:56:10 +0800
Subject: [PATCH] [update] Update funasr paraformer_streaming cache management, add session_id field
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 utils/stt/funasr_utils.py | 82 +++++++++++++++++++++++++++++++++++++--
 1 file changed, 79 insertions(+), 3 deletions(-)

diff --git a/utils/stt/funasr_utils.py b/utils/stt/funasr_utils.py
index 58a01d9..7ee2b74 100644
--- a/utils/stt/funasr_utils.py
+++ b/utils/stt/funasr_utils.py
@@ -79,9 +79,9 @@ class FunAutoSpeechRecognizer(STTBase):
     def _init_asr(self):
         # Randomly initialize a short segment of audio data
         init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
-        self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
-        self.audio_cache = {}
-        self.asr_cache = {}
+        self.session_signup("init")
+        self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id="init")
+        self.session_signout("init")
         # print("init ASR model done.")
 
     # when chat trying to use asr , sign up
@@ -108,6 +108,82 @@ class FunAutoSpeechRecognizer(STTBase):
         """
         text_dict = dict(text=[], is_end=is_end)
 
+        audio_cache = self.audio_cache[session_id]
+        # asr_cache = self.asr_cache[session_id]
+
+        audio_data = self.check_audio_type(audio_data)
+        if audio_cache is None:
+            audio_cache = audio_data
+        else:
+            if audio_cache.shape[0] > 0:
+                audio_cache = np.concatenate([audio_cache, audio_data], axis=0)
+
+        if not is_end and audio_cache.shape[0] < self.chunk_partial_size:
+            self.audio_cache[session_id] = audio_cache
+            return text_dict
+
+        total_chunk_num = int((len(audio_cache)-1)/self.chunk_partial_size)
+
+        if is_end:
+            # if the audio data is the end of a sentence, \
+            # we need to add one more chunk to the end to \
+            # ensure the end of the sentence is recognized correctly.
+            auto_det_end = True
+
+        if auto_det_end:
+            total_chunk_num += 1
+
+        end_idx = None
+        for i in range(total_chunk_num):
+            if auto_det_end:
+                is_end = i == total_chunk_num - 1
+            start_idx = i*self.chunk_partial_size
+            if auto_det_end:
+                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
+            else:
+                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
+            # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
+            # t_stamp = time.time()
+
+            speech_chunk = audio_cache[start_idx:end_idx]
+
+            # TODO: exception handling
+            print("i:", i)
+            try:
+                res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id=session_id)
+            except ValueError as e:
+                print(f"ValueError: {e}")
+                continue
+            text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
+            # print(f"each chunk time: {time.time()-t_stamp}")
+
+        if is_end:
+            audio_cache = None
+            asr_cache = {}
+        else:
+            if end_idx:
+                audio_cache = audio_cache[end_idx:]  # cut the processed part from audio_cache
+            text_dict['is_end'] = is_end
+
+
+        self.audio_cache[session_id] = audio_cache
+        # self.asr_cache[session_id] = asr_cache
+        return text_dict
+
+    def streaming_recognize_origin(self,
+                            session_id,
+                            audio_data,
+                            is_end=False,
+                            auto_det_end=False):
+        """recognize partial result
+
+        Args:
+            audio_data: bytes or numpy array, partial audio data
+            is_end: bool, whether the audio data is the end of a sentence
+            auto_det_end: bool, whether to automatically detect the end of the audio data
+        """
+        text_dict = dict(text=[], is_end=is_end)
+
         audio_cache = self.audio_cache[session_id]
         asr_cache = self.asr_cache[session_id]
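Reviewer note: the hunks above route all per-session state through self.audio_cache[session_id] (the per-session asr_cache handling is left commented out) and rely on session_signup()/session_signout() to create and tear down those slots. A minimal usage sketch of the resulting API follows; the no-argument constructor, the 16 kHz sample rate, and the exact call signature of streaming_recognize() are assumptions, not something this patch shows.

    import numpy as np
    from utils.stt.funasr_utils import FunAutoSpeechRecognizer

    asr = FunAutoSpeechRecognizer()   # assumption: default constructor loads the streaming ASR model
    session_id = "demo-session"       # hypothetical per-client identifier
    asr.session_signup(session_id)    # allocate this session's audio cache slot

    # Simulate ~3 s of 16 kHz int16 audio arriving in 15 partial pieces (sample rate is an assumption).
    audio = np.random.randint(-32768, 32767, size=16000 * 3, dtype=np.int16)
    pieces = np.array_split(audio, 15)

    for i, piece in enumerate(pieces):
        result = asr.streaming_recognize(session_id, piece, is_end=(i == len(pieces) - 1))
        if result["text"]:            # partial hypotheses accumulate per processed chunk
            print("partial:", "".join(result["text"]))

    asr.session_signout(session_id)   # drop the cached audio so other sessions do not see stale state

Keeping signup/signout explicit per session_id is what lets several clients stream through a single recognizer instance without sharing audio buffers.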