feat: give the global ASR instance basic concurrency support
parent 11c429befa
commit 4974570f20
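This commit makes the shared global `asr` recognizer usable by several chats at once: its audio/decoder caches become dicts keyed by session_id, and callers register and deregister sessions explicitly. A minimal usage sketch of the new per-session API (not part of the diff; `asr` and `session_id` come from the handlers below, `audio_chunks` is a stand-in for the streamed frames):

    asr.session_signup(session_id)                        # allocate this session's audio_cache / asr_cache entries
    try:
        current_message = ""
        for chunk in audio_chunks:                        # streamed frames (bytes or int16 numpy arrays)
            result = asr.streaming_recognize(session_id, chunk)
            current_message += ''.join(result['text'])
        result = asr.streaming_recognize(session_id, b'', is_end=True)   # flush the remaining audio
        current_message += ''.join(result['text'])
    finally:
        asr.session_signout(session_id)                   # drop the per-session caches

The handlers below follow the same pattern, signing the session out both on normal completion and in their except branches.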
@@ -223,18 +223,24 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
         logger.error(f"用户输入处理函数发生错误: {str(e)}")

 #语音识别
-async def sct_asr_handler(user_input_q,llm_input_q,user_input_finish_event):
+async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event):
     logger.debug("语音识别函数启动")
+    is_signup = False
     try:
         current_message = ""
         while not (user_input_finish_event.is_set() and user_input_q.empty()):
+            if not is_signup:
+                asr.session_signup(session_id)
+                is_signup = True
             audio_data = await user_input_q.get()
-            asr_result = asr.streaming_recognize(audio_data)
+            asr_result = asr.streaming_recognize(session_id,audio_data)
             current_message += ''.join(asr_result['text'])
-        asr_result = asr.streaming_recognize(b'',is_end=True)
+        asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
         current_message += ''.join(asr_result['text'])
         await llm_input_q.put(current_message)
+        asr.session_signout(session_id)
     except Exception as e:
+        asr.session_signout(session_id)
         logger.error(f"语音识别函数发生错误: {str(e)}")
     logger.debug(f"接收到用户消息: {current_message}")

@@ -305,12 +311,11 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
     future_session_id = asyncio.Future()
     future_response_type = asyncio.Future()
     asyncio.create_task(sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,user_input_finish_event))
-    asyncio.create_task(sct_asr_handler(user_input_q,llm_input_q,user_input_finish_event))
-

     session_id = await future_session_id #获取session_id
     update_session_activity(session_id,db)
     response_type = await future_response_type #获取返回类型
+    asyncio.create_task(sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event))
     tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
     llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])

@@ -346,6 +351,7 @@ async def scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
                 is_future_done = True
             if scl_data_json['text']:
                 await llm_input_q.put(scl_data_json['text'])
+                continue
             if scl_data_json['meta_info']['is_end']:
                 user_input_frame = {"audio": scl_data_json['audio'], "is_end": True}
                 await user_input_q.put(user_input_frame)
@@ -362,25 +368,31 @@ async def scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
             break

 #语音识别
-async def scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event):
+async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_event,asr_finished_event):
     logger.debug("语音识别函数启动")
+    is_signup = False
     current_message = ""
     while not (input_finished_event.is_set() and user_input_q.empty()):
         try:
             aduio_frame = await asyncio.wait_for(user_input_q.get(),timeout=3)
+            if not is_signup:
+                asr.session_signup(session_id)
+                is_signup = True
             if aduio_frame['is_end']:
-                asr_result = asr.streaming_recognize(aduio_frame['audio'], is_end=True)
+                asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'], is_end=True)
                 current_message += ''.join(asr_result['text'])
                 await llm_input_q.put(current_message)
                 logger.debug(f"接收到用户消息: {current_message}")
             else:
-                asr_result = asr.streaming_recognize(aduio_frame['audio'])
+                asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'])
                 current_message += ''.join(asr_result['text'])
         except asyncio.TimeoutError:
             continue
         except Exception as e:
+            asr.session_signout(session_id)
             logger.error(f"语音识别函数发生错误: {str(e)}")
             break
+    asr.session_signout(session_id)
     asr_finished_event.set()

 #大模型调用
@@ -455,7 +467,6 @@ async def streaming_chat_lasting_handler(ws,db,redis):
     future_response_type = asyncio.Future()

     asyncio.create_task(scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,input_finished_event))
-    asyncio.create_task(scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event))

     session_id = await future_session_id #获取session_id
     update_session_activity(session_id,db)
@@ -463,6 +474,7 @@ async def streaming_chat_lasting_handler(ws,db,redis):
     tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
     llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])

+    asyncio.create_task(scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_event,asr_finished_event))
     asyncio.create_task(scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,asr_finished_event,chat_finished_event))

     while not chat_finished_event.is_set():
@@ -505,23 +517,27 @@ async def voice_call_audio_producer(ws,audio_q,future,input_finished_event):


 #音频数据消费函数
-async def voice_call_audio_consumer(ws,audio_q,asr_result_q,input_finished_event,asr_finished_event):
+async def voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_finished_event,asr_finished_event):
     logger.debug("音频数据消费者函数启动")
     vad = VAD()
     current_message = ""
     vad_count = 0
+    is_signup = False
     while not (input_finished_event.is_set() and audio_q.empty()):
         try:
+            if not is_signup:
+                asr.session_signup(session_id)
+                is_signup = True
             audio_data = await asyncio.wait_for(audio_q.get(),timeout=3)
             if vad.is_speech(audio_data):
                 if vad_count > 0:
                     vad_count -= 1
-                asr_result = asr.streaming_recognize(audio_data)
+                asr_result = asr.streaming_recognize(session_id, audio_data)
                 current_message += ''.join(asr_result['text'])
             else:
                 vad_count += 1
                 if vad_count >= 25: #连续25帧没有语音,则认为说完了
-                    asr_result = asr.streaming_recognize(audio_data, is_end=True)
+                    asr_result = asr.streaming_recognize(session_id, audio_data, is_end=True)
                     if current_message:
                         logger.debug(f"检测到静默,用户输入为:{current_message}")
                         await asr_result_q.put(current_message)
@@ -532,8 +548,10 @@ async def voice_call_audio_consumer(ws,audio_q,asr_result_q,input_finished_event
         except asyncio.TimeoutError:
             continue
         except Exception as e:
+            asr.session_signout(session_id)
             logger.error(f"音频数据消费者函数发生错误: {str(e)}")
             break
+    asr.session_signout(session_id)
     asr_finished_event.set()

 #asr结果消费以及llm返回生产函数
@@ -621,7 +639,6 @@ async def voice_call_handler(ws, db, redis):

     future = asyncio.Future() #用于获取传输的session_id
     asyncio.create_task(voice_call_audio_producer(ws,audio_q,future,input_finished_event)) #创建音频数据生产者
-    asyncio.create_task(voice_call_audio_consumer(ws,audio_q,asr_result_q,input_finished_event,asr_finished_event)) #创建音频数据消费者

     #获取session内容
     session_id = await future #获取session_id
@@ -629,6 +646,7 @@ async def voice_call_handler(ws, db, redis):
     tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
     llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])

+    asyncio.create_task(voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_finished_event,asr_finished_event)) #创建音频数据消费者
     asyncio.create_task(voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_result_q,asr_finished_event,voice_call_end_event)) #创建llm处理者
     while not voice_call_end_event.is_set():
         await asyncio.sleep(3)
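Because the recognizer caches are now keyed by session_id, each handler creates its ASR task only after the session_id future has resolved, instead of together with the input-handler task. A condensed sketch of the new ordering in streaming_chat_temporary_handler (the lasting-chat and voice-call handlers follow the same idea):

    future_session_id = asyncio.Future()
    asyncio.create_task(sct_user_input_handler(ws, user_input_q, llm_input_q, future_session_id, future_response_type, user_input_finish_event))
    session_id = await future_session_id                  # wait for the client to identify the session
    asyncio.create_task(sct_asr_handler(session_id, user_input_q, llm_input_q, user_input_finish_event))

The remaining hunks are in the FunAutoSpeechRecognizer class itself.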
@@ -41,11 +41,12 @@ class FunAutoSpeechRecognizer(STTBase):
             self.chunk_size = [0, 10, 5]
         else:
             raise ValueError("`chunk_ms` should be 480 or 600, and type is int.")
         self.chunk_partial_size = self.chunk_size[1] * 960
-        self.audio_cache = None
+        self.audio_cache = {}
         self.asr_cache = {}

+        # self.audio_cache = None
+        # self.asr_cache = {}

         self._init_asr()

@@ -79,24 +80,22 @@ class FunAutoSpeechRecognizer(STTBase):
         # 随机初始化一段音频数据
         init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
         self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
-        self.audio_cache = None
+        self.audio_cache = {}
         self.asr_cache = {}
         # print("init ASR model done.")

-    def recognize(self, audio_data):
-        """recognize audio data to text"""
-        audio_data = self.check_audio_type(audio_data)
-        result = self.asr_model.generate(input=audio_data,
-                                         batch_size_s=300,
-                                         hotword=self.hotwords)
-
-        # print(result)
-        text = ''
-        for res in result:
-            text += res['text']
-        return text
+    # when chat trying to use asr , sign up
+    def session_signup(self,session_id):
+        self.audio_cache[session_id] = None
+        self.asr_cache[session_id] = {}
+
+    # when chat finish using asr , sign out
+    def session_signout(self,session_id):
+        del self.audio_cache[session_id]
+        del self.asr_cache[session_id]

     def streaming_recognize(self,
+                            session_id,
                             audio_data,
                             is_end=False,
                             auto_det_end=False):
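With audio_cache and asr_cache turned into per-session dicts, two chats streaming through the same global recognizer no longer overwrite each other's buffered audio or decoder state. A small illustration (the session ids and chunks here are assumed, not from the diff):

    asr.session_signup("session-a")
    asr.session_signup("session-b")
    asr.streaming_recognize("session-a", chunk_a)         # buffers into audio_cache["session-a"] only
    asr.streaming_recognize("session-b", chunk_b)         # independent entry, no cross-talk
    asr.session_signout("session-a")                      # deletes only session-a's cache entries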
@@ -108,19 +107,22 @@ class FunAutoSpeechRecognizer(STTBase):
             auto_det_end: bool, whether to automatically detect the end of a audio data
         """
         text_dict = dict(text=[], is_end=is_end)

+        audio_cache = self.audio_cache[session_id]
+        asr_cache = self.asr_cache[session_id]
+
         audio_data = self.check_audio_type(audio_data)
-        if self.audio_cache is None:
-            self.audio_cache = audio_data
+        if audio_cache is None:
+            audio_cache = audio_data
         else:
-            # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
-            if self.audio_cache.shape[0] > 0:
-                self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)
+            if audio_cache.shape[0] > 0:
+                audio_cache = np.concatenate([audio_cache, audio_data], axis=0)

-        if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
+        if not is_end and audio_cache.shape[0] < self.chunk_partial_size:
+            self.audio_cache[session_id] = audio_cache
             return text_dict

-        total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
+        total_chunk_num = int((len(audio_cache)-1)/self.chunk_partial_size)

         if is_end:
             # if the audio data is the end of a sentence, \
@@ -131,7 +133,6 @@ class FunAutoSpeechRecognizer(STTBase):
             if auto_det_end:
                 total_chunk_num += 1

-        # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
         end_idx = None
         for i in range(total_chunk_num):
             if auto_det_end:
@@ -144,11 +145,11 @@ class FunAutoSpeechRecognizer(STTBase):
             # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
             # t_stamp = time.time()

-            speech_chunk = self.audio_cache[start_idx:end_idx]
+            speech_chunk = audio_cache[start_idx:end_idx]

             # TODO: exceptions processes
             try:
-                res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
+                res = self.asr_model.generate(input=speech_chunk, cache=asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
             except ValueError as e:
                 print(f"ValueError: {e}")
                 continue
@@ -156,15 +157,186 @@ class FunAutoSpeechRecognizer(STTBase):
             # print(f"each chunk time: {time.time()-t_stamp}")

         if is_end:
-            self.audio_cache = None
-            self.asr_cache = {}
+            audio_cache = None
+            asr_cache = {}
         else:
             if end_idx:
-                self.audio_cache = self.audio_cache[end_idx:] # cut the processed part from audio_cache
+                audio_cache = audio_cache[end_idx:] # cut the processed part from audio_cache
         text_dict['is_end'] = is_end

-        # print(f"text_dict: {text_dict}")
+        self.audio_cache[session_id] = audio_cache
+        self.asr_cache[session_id] = asr_cache
         return text_dict


+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# ####################################################### #
+# FunAutoSpeechRecognizer: https://github.com/alibaba-damo-academy/FunASR
+# ####################################################### #
+# import io
+# import numpy as np
+# import base64
+# import wave
+# from funasr import AutoModel
+# from .base_stt import STTBase
+
+# def decode_str2bytes(data):
+#     # 将Base64编码的字节串解码为字节串
+#     if data is None:
+#         return None
+#     return base64.b64decode(data.encode('utf-8'))
+
+# class FunAutoSpeechRecognizer(STTBase):
+#     def __init__(self,
+#                  model_path="paraformer-zh-streaming",
+#                  device="cuda",
+#                  RATE=16000,
+#                  cfg_path=None,
+#                  debug=False,
+#                  chunk_ms=480,
+#                  encoder_chunk_look_back=4,
+#                  decoder_chunk_look_back=1,
+#                  **kwargs):
+#         super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
+
+#         self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
+
+#         self.encoder_chunk_look_back = encoder_chunk_look_back #number of chunks to lookback for encoder self-attention
+#         self.decoder_chunk_look_back = decoder_chunk_look_back #number of encoder chunks to lookback for decoder cross-attention
+
+#         #[0, 8, 4] 480ms, [0, 10, 5] 600ms
+#         if chunk_ms == 480:
+#             self.chunk_size = [0, 8, 4]
+#         elif chunk_ms == 600:
+#             self.chunk_size = [0, 10, 5]
+#         else:
+#             raise ValueError("`chunk_ms` should be 480 or 600, and type is int.")
+#         self.chunk_partial_size = self.chunk_size[1] * 960
+#         self.audio_cache = None
+#         self.asr_cache = {}
+
+
+
+#         self._init_asr()
+
+#     def check_audio_type(self, audio_data):
+#         """check audio data type and convert it to bytes if necessary."""
+#         if isinstance(audio_data, bytes):
+#             pass
+#         elif isinstance(audio_data, list):
+#             audio_data = b''.join(audio_data)
+#         elif isinstance(audio_data, str):
+#             audio_data = decode_str2bytes(audio_data)
+#         elif isinstance(audio_data, io.BytesIO):
+#             wf = wave.open(audio_data, 'rb')
+#             audio_data = wf.readframes(wf.getnframes())
+#         elif isinstance(audio_data, np.ndarray):
+#             pass
+#         else:
+#             raise TypeError(f"audio_data must be bytes, list, str, \
+#                 io.BytesIO or numpy array, but got {type(audio_data)}")
+
+#         if isinstance(audio_data, bytes):
+#             audio_data = np.frombuffer(audio_data, dtype=np.int16)
+#         elif isinstance(audio_data, np.ndarray):
+#             if audio_data.dtype != np.int16:
+#                 audio_data = audio_data.astype(np.int16)
+#         else:
+#             raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
+#         return audio_data
+
+#     def _init_asr(self):
+#         # 随机初始化一段音频数据
+#         init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
+#         self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
+#         self.audio_cache = None
+#         self.asr_cache = {}
+#         # print("init ASR model done.")
+
+#     def recognize(self, audio_data):
+#         """recognize audio data to text"""
+#         audio_data = self.check_audio_type(audio_data)
+#         result = self.asr_model.generate(input=audio_data,
+#                                          batch_size_s=300,
+#                                          hotword=self.hotwords)
+
+#         # print(result)
+#         text = ''
+#         for res in result:
+#             text += res['text']
+#         return text
+
+#     def streaming_recognize(self,
+#                             audio_data,
+#                             is_end=False,
+#                             auto_det_end=False):
+#         """recognize partial result
+
+#         Args:
+#             audio_data: bytes or numpy array, partial audio data
+#             is_end: bool, whether the audio data is the end of a sentence
+#             auto_det_end: bool, whether to automatically detect the end of a audio data
+#         """
+#         text_dict = dict(text=[], is_end=is_end)
+
+#         audio_data = self.check_audio_type(audio_data)
+#         if self.audio_cache is None:
+#             self.audio_cache = audio_data
+#         else:
+#             # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
+#             if self.audio_cache.shape[0] > 0:
+#                 self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)
+
+#         if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
+#             return text_dict
+
+#         total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
+
+#         if is_end:
+#             # if the audio data is the end of a sentence, \
+#             # we need to add one more chunk to the end to \
+#             # ensure the end of the sentence is recognized correctly.
+#             auto_det_end = True
+
+#         if auto_det_end:
+#             total_chunk_num += 1
+
+#         # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
+#         end_idx = None
+#         for i in range(total_chunk_num):
+#             if auto_det_end:
+#                 is_end = i == total_chunk_num - 1
+#             start_idx = i*self.chunk_partial_size
+#             if auto_det_end:
+#                 end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
+#             else:
+#                 end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
+#             # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
+#             # t_stamp = time.time()
+
+#             speech_chunk = self.audio_cache[start_idx:end_idx]
+
+#             # TODO: exceptions processes
+#             try:
+#                 res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
+#             except ValueError as e:
+#                 print(f"ValueError: {e}")
+#                 continue
+#             text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
+#             # print(f"each chunk time: {time.time()-t_stamp}")
+
+#         if is_end:
+#             self.audio_cache = None
+#             self.asr_cache = {}
+#         else:
+#             if end_idx:
+#                 self.audio_cache = self.audio_cache[end_idx:] # cut the processed part from audio_cache
+#         text_dict['is_end'] = is_end
+
+#         # print(f"text_dict: {text_dict}")
+#         return text_dict