diff --git a/utils/stt/funasr_utils.py b/utils/stt/funasr_utils.py index 07f7e78..9b7eec7 100644 --- a/utils/stt/funasr_utils.py +++ b/utils/stt/funasr_utils.py @@ -26,11 +26,9 @@ class FunAutoSpeechRecognizer(STTBase): chunk_ms=480, encoder_chunk_look_back=4, decoder_chunk_look_back=1, - mutli_cache_enable=True, **kwargs): super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug) - self.mutli_cache_enable = mutli_cache_enable self.asr_model = AutoModel(model=model_path, device=device, **kwargs) @@ -82,12 +80,9 @@ class FunAutoSpeechRecognizer(STTBase): def _init_asr(self): # 随机初始化一段音频数据 init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16) - if self.mutli_cache_enable: - self.session_signup("init") - self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id="init") - self.session_signout("init") - else: - self.asr_model.generate(input=init_audio_data, cache=None, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back) + self.session_signup("init") + self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id="init") + self.session_signout("init") # print("init ASR model done.") # when chat trying to use asr , sign up @@ -115,7 +110,6 @@ class FunAutoSpeechRecognizer(STTBase): text_dict = dict(text=[], is_end=is_end) audio_cache = self.audio_cache[session_id] - # asr_cache = self.asr_cache[session_id] audio_data = self.check_audio_type(audio_data) if audio_cache is None: @@ -156,8 +150,7 @@ class FunAutoSpeechRecognizer(STTBase): # TODO: exceptions processes # print("i:", i) try: - res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id=session_id if self.mutli_cache_enable else None) - # , session_id=session_id + res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id=session_id) except ValueError as e: print(f"ValueError: {e}") continue @@ -166,7 +159,6 @@ class FunAutoSpeechRecognizer(STTBase): if is_end: audio_cache = None - asr_cache = {} else: if end_idx: audio_cache = audio_cache[end_idx:] # cut the processed part from audio_cache @@ -174,7 +166,6 @@ class FunAutoSpeechRecognizer(STTBase): self.audio_cache[session_id] = audio_cache - # self.asr_cache[session_id] = asr_cache return text_dict def streaming_recognize_origin(self,