diff --git a/app/controllers/chat_controller.py b/app/controllers/chat_controller.py
index 4d06359..02b34f5 100644
--- a/app/controllers/chat_controller.py
+++ b/app/controllers/chat_controller.py
@@ -223,18 +223,24 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
         logger.error(f"用户输入处理函数发生错误: {str(e)}")
 
 #语音识别
-async def sct_asr_handler(user_input_q,llm_input_q,user_input_finish_event):
+async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event):
     logger.debug("语音识别函数启动")
+    is_signup = False
     try:
         current_message = ""
         while not (user_input_finish_event.is_set() and user_input_q.empty()):
+            if not is_signup:
+                asr.session_signup(session_id)
+                is_signup = True
             audio_data = await user_input_q.get()
-            asr_result = asr.streaming_recognize(audio_data)
+            asr_result = asr.streaming_recognize(session_id,audio_data)
             current_message += ''.join(asr_result['text'])
-        asr_result = asr.streaming_recognize(b'',is_end=True)
+        asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
         current_message += ''.join(asr_result['text'])
         await llm_input_q.put(current_message)
+        asr.session_signout(session_id)
     except Exception as e:
+        asr.session_signout(session_id)
         logger.error(f"语音识别函数发生错误: {str(e)}")
     logger.debug(f"接收到用户消息: {current_message}")
 
@@ -305,12 +311,11 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
     future_session_id = asyncio.Future()
     future_response_type = asyncio.Future()
     asyncio.create_task(sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,user_input_finish_event))
-    asyncio.create_task(sct_asr_handler(user_input_q,llm_input_q,user_input_finish_event))
-
     session_id = await future_session_id #获取session_id
     update_session_activity(session_id,db)
     response_type = await future_response_type #获取返回类型
+    asyncio.create_task(sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event))
 
     tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
     llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
@@ -346,6 +351,7 @@ async def scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
                 is_future_done = True
             if scl_data_json['text']:
                 await llm_input_q.put(scl_data_json['text'])
+                continue
             if scl_data_json['meta_info']['is_end']:
                 user_input_frame = {"audio": scl_data_json['audio'], "is_end": True}
                 await user_input_q.put(user_input_frame)
@@ -362,25 +368,31 @@ async def scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
             break
 
 #语音识别
-async def scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event):
+async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_event,asr_finished_event):
     logger.debug("语音识别函数启动")
+    is_signup = False
     current_message = ""
     while not (input_finished_event.is_set() and user_input_q.empty()):
        try:
            aduio_frame = await asyncio.wait_for(user_input_q.get(),timeout=3)
+           if not is_signup:
+               asr.session_signup(session_id)
+               is_signup = True
           if aduio_frame['is_end']:
-              asr_result = asr.streaming_recognize(aduio_frame['audio'], is_end=True)
+              asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'], is_end=True)
               current_message += ''.join(asr_result['text'])
               await llm_input_q.put(current_message)
               logger.debug(f"接收到用户消息: {current_message}")
          else:
-              asr_result = asr.streaming_recognize(aduio_frame['audio'])
+              asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'])
              current_message += ''.join(asr_result['text'])
        except asyncio.TimeoutError:
            continue
        except Exception as e:
+           asr.session_signout(session_id)
            logger.error(f"语音识别函数发生错误: {str(e)}")
            break
+    asr.session_signout(session_id)
     asr_finished_event.set()
 
 #大模型调用
@@ -455,7 +467,6 @@ async def streaming_chat_lasting_handler(ws,db,redis):
     future_response_type = asyncio.Future()
     asyncio.create_task(scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,input_finished_event))
-    asyncio.create_task(scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event))
 
     session_id = await future_session_id #获取session_id
     update_session_activity(session_id,db)
@@ -463,6 +474,7 @@ async def streaming_chat_lasting_handler(ws,db,redis):
 
     tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
     llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
+    asyncio.create_task(scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_event,asr_finished_event))
     asyncio.create_task(scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,asr_finished_event,chat_finished_event))
 
     while not chat_finished_event.is_set():
@@ -505,23 +517,27 @@ async def voice_call_audio_producer(ws,audio_q,future,input_finished_event):
 
 #音频数据消费函数
-async def voice_call_audio_consumer(ws,audio_q,asr_result_q,input_finished_event,asr_finished_event):
+async def voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_finished_event,asr_finished_event):
     logger.debug("音频数据消费者函数启动")
     vad = VAD()
     current_message = ""
     vad_count = 0
+    is_signup = False
     while not (input_finished_event.is_set() and audio_q.empty()):
         try:
+            if not is_signup:
+                asr.session_signup(session_id)
+                is_signup = True
             audio_data = await asyncio.wait_for(audio_q.get(),timeout=3)
             if vad.is_speech(audio_data):
                 if vad_count > 0:
                     vad_count -= 1
-                asr_result = asr.streaming_recognize(audio_data)
+                asr_result = asr.streaming_recognize(session_id, audio_data)
                 current_message += ''.join(asr_result['text'])
             else:
                 vad_count += 1
                 if vad_count >= 25: #连续25帧没有语音,则认为说完了
-                    asr_result = asr.streaming_recognize(audio_data, is_end=True)
+                    asr_result = asr.streaming_recognize(session_id, audio_data, is_end=True)
                     if current_message:
                         logger.debug(f"检测到静默,用户输入为:{current_message}")
                         await asr_result_q.put(current_message)
@@ -532,8 +548,10 @@ async def voice_call_audio_consumer(ws,audio_q,asr_result_q,input_finished_event
         except asyncio.TimeoutError:
             continue
         except Exception as e:
+            asr.session_signout(session_id)
             logger.error(f"音频数据消费者函数发生错误: {str(e)}")
             break
+    asr.session_signout(session_id)
     asr_finished_event.set()
 
 #asr结果消费以及llm返回生产函数
@@ -621,7 +639,6 @@ async def voice_call_handler(ws, db, redis):
     future = asyncio.Future() #用于获取传输的session_id
     asyncio.create_task(voice_call_audio_producer(ws,audio_q,future,input_finished_event)) #创建音频数据生产者
-    asyncio.create_task(voice_call_audio_consumer(ws,audio_q,asr_result_q,input_finished_event,asr_finished_event)) #创建音频数据消费者
 
     #获取session内容
     session_id = await future #获取session_id
 
@@ -629,6 +646,7 @@ async def voice_call_handler(ws, db, redis):
     tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
     llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
+    asyncio.create_task(voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_finished_event,asr_finished_event)) #创建音频数据消费者
     asyncio.create_task(voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_result_q,asr_finished_event,voice_call_end_event)) #创建llm处理者
 
     while not voice_call_end_event.is_set():
         await asyncio.sleep(3)
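The controller-side pattern is the same in all three handlers above (sct, scl, voice_call): register the session with asr.session_signup(session_id) before the first chunk, pass session_id into every asr.streaming_recognize(...) call, flush with is_end=True, and release the per-session caches with asr.session_signout(session_id) on both the normal and the error path. A minimal sketch of that lifecycle follows; only asr.session_signup, asr.session_signout and asr.streaming_recognize come from this diff, while the helper name, the queue wiring and the None sentinel are illustrative, and try/finally is used here simply so signout runs exactly once on every path.

    import asyncio

    async def run_asr_for_session(asr, session_id: str, audio_q: asyncio.Queue) -> str:
        asr.session_signup(session_id)        # allocate this session's audio/asr caches
        text = ""
        try:
            while True:
                frame = await audio_q.get()
                if frame is None:             # sentinel from the producer: no more audio
                    break
                text += ''.join(asr.streaming_recognize(session_id, frame)['text'])
            # flush the recognizer so the tail of the utterance is decoded
            text += ''.join(asr.streaming_recognize(session_id, b'', is_end=True)['text'])
        finally:
            asr.session_signout(session_id)   # always release the caches, even on error
        return text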
diff --git a/utils/stt/funasr_utils.py b/utils/stt/funasr_utils.py
index 84cc9e9..58a01d9 100644
--- a/utils/stt/funasr_utils.py
+++ b/utils/stt/funasr_utils.py
@@ -41,11 +41,12 @@ class FunAutoSpeechRecognizer(STTBase):
             self.chunk_size = [0, 10, 5]
         else:
             raise ValueError("`chunk_ms` should be 480 or 600, and type is int.")
-        self.chunk_partial_size = self.chunk_size[1] * 960
-        self.audio_cache = None
+        self.chunk_partial_size = self.chunk_size[1] * 960
+        self.audio_cache = {}
         self.asr_cache = {}
-
-
+
+        # self.audio_cache = None
+        # self.asr_cache = {}
         self._init_asr()
 
@@ -79,24 +80,22 @@ class FunAutoSpeechRecognizer(STTBase):
         # 随机初始化一段音频数据
         init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
         self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
-        self.audio_cache = None
+        self.audio_cache = {}
         self.asr_cache = {}
         # print("init ASR model done.")
 
-    def recognize(self, audio_data):
-        """recognize audio data to text"""
-        audio_data = self.check_audio_type(audio_data)
-        result = self.asr_model.generate(input=audio_data,
-                                         batch_size_s=300,
-                                         hotword=self.hotwords)
-
-        # print(result)
-        text = ''
-        for res in result:
-            text += res['text']
-        return text
-
-    def streaming_recognize(self,
+    # sign up when a chat starts using the asr
+    def session_signup(self,session_id):
+        self.audio_cache[session_id] = None
+        self.asr_cache[session_id] = {}
+
+    # sign out when the chat is done with the asr
+    def session_signout(self,session_id):
+        del self.audio_cache[session_id]
+        del self.asr_cache[session_id]
+
+    def streaming_recognize(self,
+                            session_id,
                             audio_data,
                             is_end=False,
                             auto_det_end=False):
@@ -108,19 +107,22 @@ class FunAutoSpeechRecognizer(STTBase):
             auto_det_end: bool, whether to automatically detect the end of a audio data
         """
         text_dict = dict(text=[], is_end=is_end)
+
+        audio_cache = self.audio_cache[session_id]
+        asr_cache = self.asr_cache[session_id]
         audio_data = self.check_audio_type(audio_data)
-        if self.audio_cache is None:
-            self.audio_cache = audio_data
+        if audio_cache is None:
+            audio_cache = audio_data
         else:
-            # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
-            if self.audio_cache.shape[0] > 0:
-                self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)
+            if audio_cache.shape[0] > 0:
+                audio_cache = np.concatenate([audio_cache, audio_data], axis=0)
 
-        if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
+        if not is_end and audio_cache.shape[0] < self.chunk_partial_size:
+            self.audio_cache[session_id] = audio_cache
             return text_dict
 
-        total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
+        total_chunk_num = int((len(audio_cache)-1)/self.chunk_partial_size)
 
         if is_end:
             # if the audio data is the end of a sentence, \
@@ -131,7 +133,6 @@ class FunAutoSpeechRecognizer(STTBase):
 
         if auto_det_end:
             total_chunk_num += 1
 
-        # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
         end_idx = None
         for i in range(total_chunk_num):
             if auto_det_end:
@@ -144,11 +145,11 @@ class FunAutoSpeechRecognizer(STTBase):
             # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
             # t_stamp = time.time()
 
-            speech_chunk = self.audio_cache[start_idx:end_idx]
+            speech_chunk = audio_cache[start_idx:end_idx]
 
             # TODO: exceptions processes
             try:
-                res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
+                res = self.asr_model.generate(input=speech_chunk, cache=asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
             except ValueError as e:
                 print(f"ValueError: {e}")
                 continue
@@ -156,15 +157,186 @@ class FunAutoSpeechRecognizer(STTBase):
             # print(f"each chunk time: {time.time()-t_stamp}")
 
         if is_end:
-            self.audio_cache = None
-            self.asr_cache = {}
+            audio_cache = None
+            asr_cache = {}
         else:
             if end_idx:
-                self.audio_cache = self.audio_cache[end_idx:] # cut the processed part from audio_cache
+                audio_cache = audio_cache[end_idx:] # cut the processed part from audio_cache
             text_dict['is_end'] = is_end
 
-        # print(f"text_dict: {text_dict}")
+
+        self.audio_cache[session_id] = audio_cache
+        self.asr_cache[session_id] = asr_cache
         return text_dict
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# ####################################################### #
+# FunAutoSpeechRecognizer: https://github.com/alibaba-damo-academy/FunASR
+# ####################################################### #
+# import io
+# import numpy as np
+# import base64
+# import wave
+# from funasr import AutoModel
+# from .base_stt import STTBase
+
+# def decode_str2bytes(data):
+#     # 将Base64编码的字节串解码为字节串
+#     if data is None:
+#         return None
+#     return base64.b64decode(data.encode('utf-8'))
+
+# class FunAutoSpeechRecognizer(STTBase):
+#     def __init__(self,
+#                  model_path="paraformer-zh-streaming",
+#                  device="cuda",
+#                  RATE=16000,
+#                  cfg_path=None,
+#                  debug=False,
+#                  chunk_ms=480,
+#                  encoder_chunk_look_back=4,
+#                  decoder_chunk_look_back=1,
+#                  **kwargs):
+#         super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
+
+#         self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
+
+#         self.encoder_chunk_look_back = encoder_chunk_look_back #number of chunks to lookback for encoder self-attention
+#         self.decoder_chunk_look_back = decoder_chunk_look_back #number of encoder chunks to lookback for decoder cross-attention
+
+#         #[0, 8, 4] 480ms, [0, 10, 5] 600ms
+#         if chunk_ms == 480:
+#             self.chunk_size = [0, 8, 4]
+#         elif chunk_ms == 600:
+#             self.chunk_size = [0, 10, 5]
+#         else:
+#             raise ValueError("`chunk_ms` should be 480 or 600, and type is int.")
+#         self.chunk_partial_size = self.chunk_size[1] * 960
+#         self.audio_cache = None
+#         self.asr_cache = {}
+
+
+
+#         self._init_asr()
+
+#     def check_audio_type(self, audio_data):
+#         """check audio data type and convert it to bytes if necessary."""
+#         if isinstance(audio_data, bytes):
+#             pass
+#         elif isinstance(audio_data, list):
+#             audio_data = b''.join(audio_data)
+#         elif isinstance(audio_data, str):
+#             audio_data = decode_str2bytes(audio_data)
+#         elif isinstance(audio_data, io.BytesIO):
+#             wf = wave.open(audio_data, 'rb')
+#             audio_data = wf.readframes(wf.getnframes())
+#         elif isinstance(audio_data, np.ndarray):
+#             pass
+#         else:
+#             raise TypeError(f"audio_data must be bytes, list, str, \
+#                             io.BytesIO or numpy array, but got {type(audio_data)}")
+
+#         if isinstance(audio_data, bytes):
+#             audio_data = np.frombuffer(audio_data, dtype=np.int16)
+#         elif isinstance(audio_data, np.ndarray):
+#             if audio_data.dtype != np.int16:
+#                 audio_data = audio_data.astype(np.int16)
+#         else:
+#             raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
+#         return audio_data
+
+#     def _init_asr(self):
+#         # 随机初始化一段音频数据
+#         init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
+#         self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
+#         self.audio_cache = None
+#         self.asr_cache = {}
+#         # print("init ASR model done.")
+
+#     def recognize(self, audio_data):
+#         """recognize audio data to text"""
+#         audio_data = self.check_audio_type(audio_data)
+#         result = self.asr_model.generate(input=audio_data,
+#                                          batch_size_s=300,
+#                                          hotword=self.hotwords)
+
+#         # print(result)
+#         text = ''
+#         for res in result:
+#             text += res['text']
+#         return text
+
+#     def streaming_recognize(self,
+#                             audio_data,
+#                             is_end=False,
+#                             auto_det_end=False):
+#         """recognize partial result
+
+#         Args:
+#             audio_data: bytes or numpy array, partial audio data
+#             is_end: bool, whether the audio data is the end of a sentence
+#             auto_det_end: bool, whether to automatically detect the end of a audio data
+#         """
+#         text_dict = dict(text=[], is_end=is_end)
+
+#         audio_data = self.check_audio_type(audio_data)
+#         if self.audio_cache is None:
+#             self.audio_cache = audio_data
+#         else:
+#             # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
+#             if self.audio_cache.shape[0] > 0:
+#                 self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)
+
+#         if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
+#             return text_dict
+
+#         total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
+
+#         if is_end:
+#             # if the audio data is the end of a sentence, \
+#             # we need to add one more chunk to the end to \
+#             # ensure the end of the sentence is recognized correctly.
+#             auto_det_end = True
+
+#         if auto_det_end:
+#             total_chunk_num += 1
+
+#         # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
+#         end_idx = None
+#         for i in range(total_chunk_num):
+#             if auto_det_end:
+#                 is_end = i == total_chunk_num - 1
+#             start_idx = i*self.chunk_partial_size
+#             if auto_det_end:
+#                 end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
+#             else:
+#                 end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
+#             # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
+#             # t_stamp = time.time()
+
+#             speech_chunk = self.audio_cache[start_idx:end_idx]
+
+#             # TODO: exceptions processes
+#             try:
+#                 res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
+#             except ValueError as e:
+#                 print(f"ValueError: {e}")
+#                 continue
+#             text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
+#             # print(f"each chunk time: {time.time()-t_stamp}")
+
+#         if is_end:
+#             self.audio_cache = None
+#             self.asr_cache = {}
+#         else:
+#             if end_idx:
+#                 self.audio_cache = self.audio_cache[end_idx:] # cut the processed part from audio_cache
+#         text_dict['is_end'] = is_end
+
+#         # print(f"text_dict: {text_dict}")
+#         return text_dict
+
+
\ No newline at end of file
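After this change the recognizer keeps one audio_cache/asr_cache entry per session_id, so interleaved requests from different chats no longer overwrite each other's streaming state. A rough usage sketch under assumed values (the import path follows the repository layout shown in the diff; the model path, session ids, and the 960-sample silent frame are illustrative placeholders):

    import numpy as np
    from utils.stt.funasr_utils import FunAutoSpeechRecognizer

    asr = FunAutoSpeechRecognizer(model_path="paraformer-zh-streaming", chunk_ms=480)

    # each chat registers its own cache slots before streaming
    asr.session_signup("session-a")
    asr.session_signup("session-b")

    frame = np.zeros(960, dtype=np.int16)   # placeholder audio frame

    # interleaved calls stay isolated because the caches are dicts keyed by session_id
    asr.streaming_recognize("session-a", frame)
    asr.streaming_recognize("session-b", frame)
    text_a = asr.streaming_recognize("session-a", b'', is_end=True)['text']
    text_b = asr.streaming_recognize("session-b", b'', is_end=True)['text']

    # release the caches once each chat is finished
    asr.session_signout("session-a")
    asr.session_signout("session-b")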