diff --git a/utils/stt/funasr_utils.py b/utils/stt/funasr_utils.py index 7ee2b74..07f7e78 100644 --- a/utils/stt/funasr_utils.py +++ b/utils/stt/funasr_utils.py @@ -26,9 +26,12 @@ class FunAutoSpeechRecognizer(STTBase): chunk_ms=480, encoder_chunk_look_back=4, decoder_chunk_look_back=1, + mutli_cache_enable=True, **kwargs): super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug) + self.mutli_cache_enable = mutli_cache_enable + self.asr_model = AutoModel(model=model_path, device=device, **kwargs) self.encoder_chunk_look_back = encoder_chunk_look_back #number of chunks to lookback for encoder self-attention @@ -79,9 +82,12 @@ class FunAutoSpeechRecognizer(STTBase): def _init_asr(self): # 随机初始化一段音频数据 init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16) - self.session_signup("init") - self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id="init") - self.session_signout("init") + if self.mutli_cache_enable: + self.session_signup("init") + self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id="init") + self.session_signout("init") + else: + self.asr_model.generate(input=init_audio_data, cache=None, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back) # print("init ASR model done.") # when chat trying to use asr , sign up @@ -148,9 +154,10 @@ class FunAutoSpeechRecognizer(STTBase): speech_chunk = audio_cache[start_idx:end_idx] # TODO: exceptions processes - print("i:", i) + # print("i:", i) try: - res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id=session_id) + res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id=session_id if self.mutli_cache_enable else None) + # , session_id=session_id except ValueError as e: print(f"ValueError: {e}") continue