From f1a844c84a69336cd902ef9975306d8c5a5e15f1 Mon Sep 17 00:00:00 2001 From: Killua777 <1223086337@qq.com> Date: Sat, 4 May 2024 09:07:58 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86=E6=94=B6=E5=88=B0?= =?UTF-8?q?=E5=BF=83=E8=B7=B3=E5=8C=85=E5=90=8E=E4=BC=9A=E6=96=AD=E5=BC=80?= =?UTF-8?q?=E8=BF=9E=E6=8E=A5=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/controllers/chat.py | 58 +++++++++++++++++++++++------------------ utils/audio_utils.py | 7 +++-- 2 files changed, 38 insertions(+), 27 deletions(-) diff --git a/app/controllers/chat.py b/app/controllers/chat.py index dc0d0c1..4668467 100644 --- a/app/controllers/chat.py +++ b/app/controllers/chat.py @@ -70,10 +70,12 @@ def split_string_with_punctuation(current_sentence,text,is_first): current_sentence = '' return result, current_sentence, is_first +#vad预处理 def vad_preprocess(audio): if len(audio)<1280: return ('A'*1280) return audio[:1280],audio[1280:] + #-------------------------------------------------------- # 创建新聊天 @@ -308,8 +310,9 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis): async def scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,input_finished_event): logger.debug("用户输入处理函数启动") is_future_done = False - try: - while not input_finished_event.is_set(): + + while not input_finished_event.is_set(): + try: scl_data_json = json.loads(await ws.receive_text()) if scl_data_json['is_close']: input_finished_event.set() @@ -328,10 +331,12 @@ async def scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f await user_input_q.put(user_input_frame) user_input_frame = {"audio": scl_data_json['audio'], "is_end": False} await user_input_q.put(user_input_frame) - except KeyError as ke: - if 'state' in scl_data_json and 'method' in scl_data_json: - logger.debug("收到心跳包") + except KeyError as ke: + if 'state' in scl_data_json and 'method' in scl_data_json: + logger.debug("收到心跳包") + continue + #语音识别 async def scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event): logger.debug("语音识别函数启动") @@ -463,12 +468,12 @@ async def streaming_chat_lasting_handler(ws,db,redis): #--------------------------------语音通话接口-------------------------------------- #音频数据生产函数 -async def voice_call_audio_producer(ws,audio_queue,future,input_finished_event): +async def voice_call_audio_producer(ws,audio_q,future,input_finished_event): logger.debug("音频数据生产函数启动") is_future_done = False audio_data = "" - try: - while not input_finished_event.is_set(): + while not input_finished_event.is_set(): + try: voice_call_data_json = json.loads(await ws.receive_text()) if not is_future_done: #在第一次循环中读取session_id future.set_result(voice_call_data_json['meta_info']['session_id']) @@ -480,12 +485,12 @@ async def voice_call_audio_producer(ws,audio_queue,future,input_finished_event): audio_data += voice_call_data_json["audio"] while len(audio_data) > 1280: vad_frame,audio_data = vad_preprocess(audio_data) - await audio_queue.put(vad_frame) #将音频数据存入audio_q - except KeyError as ke: - logger.info(f"收到心跳包") + await audio_q.put(vad_frame) #将音频数据存入audio_q + except KeyError as ke: + logger.info(f"收到心跳包") #音频数据消费函数 -async def voice_call_audio_consumer(audio_q,asr_result_q,input_finished_event,asr_finished_event): +async def voice_call_audio_consumer(ws,audio_q,asr_result_q,input_finished_event,asr_finished_event): logger.debug("音频数据消费者函数启动") vad = VAD() current_message = "" @@ -504,6 +509,8 @@ async def voice_call_audio_consumer(audio_q,asr_result_q,input_finished_event,as if current_message: logger.debug(f"检测到静默,用户输入为:{current_message}") await asr_result_q.put(current_message) + text_response = {"type": "user_text", "code": 200, "msg": current_message} + await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据 current_message = "" vad_count = 0 asr_finished_event.set() @@ -570,7 +577,7 @@ async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event while not (split_finished_event.is_set() and split_result_q.empty()): sentence = await split_result_q.get() sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True) - text_response = {"type": "text", "code": 200, "msg": sentence} + text_response = {"type": "llm_text", "code": 200, "msg": sentence} await ws.send_bytes(audio) #返回音频二进制流数据 await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据 logger.debug(f"websocket返回:{sentence}") @@ -581,19 +588,20 @@ async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event async def voice_call_handler(ws, db, redis): logger.debug("voice_call websocket 连接建立") - audio_q = asyncio.Queue() - asr_result_q = asyncio.Queue() - llm_response_q = asyncio.Queue() - split_result_q = asyncio.Queue() - - input_finished_event = asyncio.Event() - asr_finished_event = asyncio.Event() - llm_finished_event = asyncio.Event() - split_finished_event = asyncio.Event() - voice_call_end_event = asyncio.Event() - future = asyncio.Future() + audio_q = asyncio.Queue() #音频队列 + asr_result_q = asyncio.Queue() #语音识别结果队列 + llm_response_q = asyncio.Queue() #大模型返回队列 + split_result_q = asyncio.Queue() #断句结果队列 + + input_finished_event = asyncio.Event() #用户输入结束事件 + asr_finished_event = asyncio.Event() #语音识别结束事件 + llm_finished_event = asyncio.Event() #大模型结束事件 + split_finished_event = asyncio.Event() #断句结束事件 + voice_call_end_event = asyncio.Event() #语音电话终止事件 + + future = asyncio.Future() #用于获取传输的session_id asyncio.create_task(voice_call_audio_producer(ws,audio_q,future,input_finished_event)) #创建音频数据生产者 - asyncio.create_task(voice_call_audio_consumer(audio_q,asr_result_q,input_finished_event,asr_finished_event)) #创建音频数据消费者 + asyncio.create_task(voice_call_audio_consumer(ws,audio_q,asr_result_q,input_finished_event,asr_finished_event)) #创建音频数据消费者 #获取session内容 session_id = await future #获取session_id diff --git a/utils/audio_utils.py b/utils/audio_utils.py index 48d0c72..5e5ba4d 100644 --- a/utils/audio_utils.py +++ b/utils/audio_utils.py @@ -10,5 +10,8 @@ class VAD(): self.min_act_time = min_act_time # 最小活动时间,单位秒 def is_speech(self,data): - byte_data = base64.b64decode(data) - return self.vad.is_speech(byte_data, self.RATE) \ No newline at end of file + try: + byte_data = base64.b64decode(data) + return self.vad.is_speech(byte_data, self.RATE) + except Exception as e: + return False \ No newline at end of file