[update] 更新funasr paraformer_streaming cache管理,增加session_id字段
This commit is contained in:
parent
a943a281d5
commit
4322b03418
|
@ -26,11 +26,9 @@ class FunAutoSpeechRecognizer(STTBase):
|
||||||
chunk_ms=480,
|
chunk_ms=480,
|
||||||
encoder_chunk_look_back=4,
|
encoder_chunk_look_back=4,
|
||||||
decoder_chunk_look_back=1,
|
decoder_chunk_look_back=1,
|
||||||
mutli_cache_enable=True,
|
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
|
super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
|
||||||
|
|
||||||
self.mutli_cache_enable = mutli_cache_enable
|
|
||||||
|
|
||||||
self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
|
self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
|
||||||
|
|
||||||
|
@ -82,12 +80,9 @@ class FunAutoSpeechRecognizer(STTBase):
|
||||||
def _init_asr(self):
|
def _init_asr(self):
|
||||||
# 随机初始化一段音频数据
|
# 随机初始化一段音频数据
|
||||||
init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
|
init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
|
||||||
if self.mutli_cache_enable:
|
|
||||||
self.session_signup("init")
|
self.session_signup("init")
|
||||||
self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id="init")
|
self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id="init")
|
||||||
self.session_signout("init")
|
self.session_signout("init")
|
||||||
else:
|
|
||||||
self.asr_model.generate(input=init_audio_data, cache=None, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
|
|
||||||
# print("init ASR model done.")
|
# print("init ASR model done.")
|
||||||
|
|
||||||
# when chat trying to use asr , sign up
|
# when chat trying to use asr , sign up
|
||||||
|
@ -115,7 +110,6 @@ class FunAutoSpeechRecognizer(STTBase):
|
||||||
text_dict = dict(text=[], is_end=is_end)
|
text_dict = dict(text=[], is_end=is_end)
|
||||||
|
|
||||||
audio_cache = self.audio_cache[session_id]
|
audio_cache = self.audio_cache[session_id]
|
||||||
# asr_cache = self.asr_cache[session_id]
|
|
||||||
|
|
||||||
audio_data = self.check_audio_type(audio_data)
|
audio_data = self.check_audio_type(audio_data)
|
||||||
if audio_cache is None:
|
if audio_cache is None:
|
||||||
|
@ -156,8 +150,7 @@ class FunAutoSpeechRecognizer(STTBase):
|
||||||
# TODO: exceptions processes
|
# TODO: exceptions processes
|
||||||
# print("i:", i)
|
# print("i:", i)
|
||||||
try:
|
try:
|
||||||
res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id=session_id if self.mutli_cache_enable else None)
|
res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id=session_id)
|
||||||
# , session_id=session_id
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
print(f"ValueError: {e}")
|
print(f"ValueError: {e}")
|
||||||
continue
|
continue
|
||||||
|
@ -166,7 +159,6 @@ class FunAutoSpeechRecognizer(STTBase):
|
||||||
|
|
||||||
if is_end:
|
if is_end:
|
||||||
audio_cache = None
|
audio_cache = None
|
||||||
asr_cache = {}
|
|
||||||
else:
|
else:
|
||||||
if end_idx:
|
if end_idx:
|
||||||
audio_cache = audio_cache[end_idx:] # cut the processed part from audio_cache
|
audio_cache = audio_cache[end_idx:] # cut the processed part from audio_cache
|
||||||
|
@ -174,7 +166,6 @@ class FunAutoSpeechRecognizer(STTBase):
|
||||||
|
|
||||||
|
|
||||||
self.audio_cache[session_id] = audio_cache
|
self.audio_cache[session_id] = audio_cache
|
||||||
# self.asr_cache[session_id] = asr_cache
|
|
||||||
return text_dict
|
return text_dict
|
||||||
|
|
||||||
def streaming_recognize_origin(self,
|
def streaming_recognize_origin(self,
|
||||||
|
|
Loading…
Reference in New Issue