From f548d57595906833bca05f216dac4133c9df4ff0 Mon Sep 17 00:00:00 2001
From: killua4396 <1223086337@qq.com>
Date: Thu, 2 May 2024 09:55:47 +0800
Subject: [PATCH] Update the lasting streaming chat and voice call endpoints so
 they finish processing all pending data before exiting
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/controllers/chat.py | 366 +++++++++++++++++++---------------------
 1 file changed, 176 insertions(+), 190 deletions(-)

diff --git a/app/controllers/chat.py b/app/controllers/chat.py
index 4aa2602..434287d 100644
--- a/app/controllers/chat.py
+++ b/app/controllers/chat.py
@@ -297,210 +297,188 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
 #---------------------------------------------------------------------------------------------------
-# 持续流式聊天
-async def streaming_chat_lasting_handler(ws, db, redis):
-    print("Websocket连接成功")
-    while True:
-        try:
-            print("等待接受")
-            data = await asyncio.wait_for(ws.receive_text(), timeout=60)
-            data_json = json.loads(data)
-            if data_json["is_close"]:
-                close_message = {"type": "close", "code": 200, "msg": ""}
-                await ws.send_text(json.dumps(close_message, ensure_ascii=False))
-                print("连接关闭")
-                await asyncio.sleep(0.5)
-                await ws.close()
-                return;
-        except asyncio.TimeoutError:
-            print("连接超时")
-            await ws.close()
-            return;
-        current_message = "" # 用于存储用户消息
-        response_type = RESPONSE_TEXT # 用于获取返回类型
-        session_id = ""
-        if Config.STRAM_CHAT.ASR == "LOCAL":
-            try:
-                while True:
-                    if data_json["text"]: # 若文字输入不为空,则表示该输入为文字输入
-                        if data_json["meta_info"]["voice_synthesize"]:
-                            response_type = RESPONSE_AUDIO # 查看voice_synthesize判断返回类型
-                        session_id = data_json["meta_info"]["session_id"]
-                        current_message = data_json['text']
-                        break
-                    if not data_json['meta_info']['is_end']: # 还在发
-                        asr_result = asr.streaming_recognize(data_json["audio"])
-                        current_message += ''.join(asr_result['text'])
-                    else: # 发完了
-                        asr_result = asr.streaming_recognize(data_json["audio"], is_end=True)
-                        session_id = data_json["meta_info"]["session_id"]
-                        current_message += ''.join(asr_result['text'])
-                        if data_json["meta_info"]["voice_synthesize"]:
-                            response_type = RESPONSE_AUDIO # 查看voice_synthesize判断返回类型
-                        break
-                    data_json = json.loads(await ws.receive_text())
+#------------------------------------------持续流式聊天----------------------------------------------
+#处理用户输入
+async def scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,input_finished_event):
+    logger.debug("用户输入处理函数启动")
+    is_future_done = False
+    try:
+        while not input_finished_event.is_set():
+            scl_data_json = json.loads(await ws.receive_text())
+            if scl_data_json['is_close']:
+                input_finished_event.set()
+                break
+            if not is_future_done:
+                future_session_id.set_result(scl_data_json['meta_info']['session_id'])
+                if scl_data_json['meta_info']['voice_synthesize']:
+                    future_response_type.set_result(RESPONSE_AUDIO)
+                else:
+                    future_response_type.set_result(RESPONSE_TEXT)
+                is_future_done = True
+            if scl_data_json['text']:
+                await llm_input_q.put(scl_data_json['text'])
+            if scl_data_json['meta_info']['is_end']:
+                user_input_frame = {"audio": scl_data_json['audio'], "is_end": True}
+                await user_input_q.put(user_input_frame)
+            else:
user_input_frame = {"audio": scl_data_json['audio'], "is_end": False} + await user_input_q.put(user_input_frame) + except KeyError as ke: + if 'state' in scl_data_json and 'method' in scl_data_json: + logger.debug("收到心跳包") - except Exception as e: - error_info = f"接收用户消息错误: {str(e)}" - error_message = {"type": "error", "code": "500", "msg": error_info} - logger.error(error_info) - await ws.send_text(json.dumps(error_message, ensure_ascii=False)) - await ws.close() - return - elif Config.STRAM_CHAT.ASR == "REMOTE": - error_info = f"远程ASR服务暂未开通" - error_message = {"type": "error", "code": "500", "msg": error_info} - logger.error(error_info) - await ws.send_text(json.dumps(error_message, ensure_ascii=False)) - await ws.close() - return - - print(f"接收到用户消息: {current_message}") - # 查询Session记录 - session_content_str = "" - if redis.exists(session_id): - session_content_str = redis.get(session_id) +#语音识别 +async def scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event): + logger.debug("语音识别函数启动") + current_message = "" + while not (input_finished_event.is_set() and user_input_q.empty()): + aduio_frame = await user_input_q.get() + if aduio_frame['is_end']: + asr_result = asr.streaming_recognize(aduio_frame['audio'], is_end=True) + current_message += ''.join(asr_result['text']) + await llm_input_q.put(current_message) + logger.debug(f"接收到用户消息: {current_message}") else: - session_db = db.query(Session).filter(Session.id == session_id).first() - if not session_db: - error_info = f"未找到session记录: {str(e)}" - error_message = {"type": "error", "code": 500, "msg": error_info} - logger.error(error_info) - await ws.send_text(json.dumps(error_message, ensure_ascii=False)) - await ws.close() - return - session_content_str = session_db.content + asr_result = asr.streaming_recognize(aduio_frame['audio']) + current_message += ''.join(asr_result['text']) + asr_finished_event.set() - session_content = json.loads(session_content_str) - llm_info = json.loads(session_content["llm_info"]) - tts_info = json.loads(session_content["tts_info"]) - user_info = json.loads(session_content["user_info"]) +#大模型调用 +async def scl_llm_handler(session_id,llm_info,db,redis,llm_input_q,llm_response_q,asr_finished_event,llm_finished_event): + logger.debug("llm调用函数启动") + while not (asr_finished_event.is_set() and llm_input_q.empty()): + session_content = get_session_content(session_id,redis,db) messages = json.loads(session_content["messages"]) + current_message = await llm_input_q.get() messages.append({'role': 'user', "content": current_message}) - token_count = session_content["token"] + payload = json.dumps({ + "model": llm_info["model"], + "stream": True, + "messages": messages, + "max_tokens": 10000, + "temperature": llm_info["temperature"], + "top_p": llm_info["top_p"] + }) + headers = { + 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", + 'Content-Type': 'application/json' + } + response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True) + if response.status_code == 200: + for chunk in response.iter_lines(): + if chunk: + chunk_data = parseChunkDelta(chunk) + llm_frame = {"message": chunk_data, "is_end": False} + await llm_response_q.put(llm_frame) + llm_frame = {"message": "", "is_end": True} + await llm_response_q.put(llm_frame) + llm_finished_event.set() - try: - payload = json.dumps({ - "model": llm_info["model"], - "stream": True, - "messages": messages, - "tool_choice": "auto", - "max_tokens": 10000, - "temperature": llm_info["temperature"], - 
"top_p": llm_info["top_p"] - }) - headers = { - 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", - 'Content-Type': 'application/json' - } - response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True) - except Exception as e: - error_info = f"发送信息给大模型时发生错误: {str(e)}" - error_message = {"type": "error", "code": 500, "msg": error_info} - logger.error(error_info) - await ws.send_text(json.dumps(error_message, ensure_ascii=False)) - await ws.close() - return - - def split_string_with_punctuation(text): - punctuations = "!?。" - result = [] - current_sentence = "" - for char in text: - current_sentence += char - if char in punctuations: - result.append(current_sentence) - current_sentence = "" - # 判断最后一个字符是否为标点符号 - if current_sentence and current_sentence[-1] not in punctuations: - # 如果最后一段不以标点符号结尾,则加入拆分数组 - result.append(current_sentence) - return result - - llm_response = "" - response_buf = "" - - try: - if Config.STRAM_CHAT.TTS == "LOCAL": - if response.status_code == 200: - for chunk in response.iter_lines(): - if chunk: - if response_type == RESPONSE_AUDIO: - chunk_data = parseChunkDelta(chunk) - llm_response += chunk_data - response_buf += chunk_data - split_buf = split_string_with_punctuation(response_buf) - response_buf = "" - if len(split_buf) != 0: - for sentence in split_buf: - sr, audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True) - text_response = {"type": "text", "code": 200, "msg": sentence} - await ws.send_text(json.dumps(text_response, ensure_ascii=False)) # 返回文本数据 - await ws.send_bytes(audio) # 返回音频二进制流数据 - if response_type == RESPONSE_TEXT: - chunk_data = parseChunkDelta(chunk) - llm_response += chunk_data - text_response = {"type": "text", "code": 200, "msg": chunk_data} - await ws.send_text(json.dumps(text_response, ensure_ascii=False)) # 返回文本数据 - - elif Config.STRAM_CHAT.TTS == "REMOTE": - error_info = f"暂不支持远程音频合成" - error_message = {"type": "error", "code": 500, "msg": error_info} - logger.error(error_info) - await ws.send_text(json.dumps(error_message, ensure_ascii=False)) - await ws.close() - return - end_response = {"type": "end", "code": 200, "msg": ""} - await ws.send_text(json.dumps(end_response, ensure_ascii=False)) # 单次返回结束 - print(f"llm消息: {llm_response}") - except Exception as e: - error_info = f"音频合成与向前端返回时错误: {str(e)}" - error_message = {"type": "error", "code": 500, "msg": error_info} - logger.error(error_info) - await ws.send_text(json.dumps(error_message, ensure_ascii=False)) - await ws.close() - return - - try: +#大模型返回断句 +async def scl_llm_response_handler(session_id,redis,db,llm_response_q,split_result_q,llm_finished_event,split_finished_event): + logger.debug("llm返回处理函数启动") + llm_response = "" + current_sentence = "" + is_first = True + while not (llm_finished_event.is_set() and llm_response_q.empty()): + llm_frame = await llm_response_q.get() + llm_response += llm_frame['message'] + sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,llm_frame['message'],is_first) + for sentence in sentences: + sentence_frame = {"message": sentence, "is_end": False} + await split_result_q.put(sentence_frame) + if llm_frame['is_end']: + sentence_frame = {"message": "", "is_end": True} + await split_result_q.put(sentence_frame) + is_first = True + session_content = get_session_content(session_id,redis,db) + messages = json.loads(session_content["messages"]) 
             messages.append({'role': 'assistant', "content": llm_response})
-            token_count += len(llm_response)
-            session_content["messages"] = json.dumps(messages, ensure_ascii=False)
-            session_content["token"] = token_count
-            redis.set(session_id, json.dumps(session_content, ensure_ascii=False))
-        except Exception as e:
-            error_info = f"更新session时错误: {str(e)}"
-            error_message = {"type": "error", "code": 500, "msg": error_info}
-            logger.error(error_info)
-            await ws.send_text(json.dumps(error_message, ensure_ascii=False))
-            await ws.close()
-            return
-        print("处理完毕")
+            session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
+            redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
+            logger.debug(f"llm返回结果: {llm_response}")
+            llm_response = ""
+            current_sentence = ""
+    split_finished_event.set()
+
+#文本返回及语音合成
+async def scl_response_handler(ws,tts_info,response_type,split_result_q,split_finished_event,chat_finish_event):
+    logger.debug("返回处理函数启动")
+    while not (split_finished_event.is_set() and split_result_q.empty()):
+        sentence_frame = await split_result_q.get()
+        sentence = sentence_frame['message']
+        if sentence_frame['is_end']:
+            await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
+            continue
+        if response_type == RESPONSE_TEXT:
+            response_message = {"type": "text", "code":200, "msg": sentence}
+            await ws.send_text(json.dumps(response_message, ensure_ascii=False))
+        elif response_type == RESPONSE_AUDIO:
+            sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
+            response_message = {"type": "text", "code":200, "msg": sentence}
+            await ws.send_bytes(audio)
+            await ws.send_text(json.dumps(response_message, ensure_ascii=False))
+        logger.debug(f"websocket返回: {sentence}")
+    chat_finish_event.set()
+
+async def streaming_chat_lasting_handler(ws,db,redis):
+    logger.debug("streaming chat lasting websocket 连接建立")
+    user_input_q = asyncio.Queue() # 用于存储用户输入
+    llm_input_q = asyncio.Queue() # 用于存储llm输入
+    llm_response_q = asyncio.Queue() # 用于存储llm输出
+    split_result_q = asyncio.Queue() # 用于存储llm返回后断句输出
+
+    input_finished_event = asyncio.Event()
+    asr_finished_event = asyncio.Event()
+    llm_finished_event = asyncio.Event()
+    split_finished_event = asyncio.Event()
+    chat_finish_event = asyncio.Event()
+    future_session_id = asyncio.Future()
+    future_response_type = asyncio.Future()
+    asyncio.create_task(scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,input_finished_event))
+    asyncio.create_task(scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event))
+
+    session_id = await future_session_id #获取session_id
+    response_type = await future_response_type #获取返回类型
+    tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
+    llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
+
+    asyncio.create_task(scl_llm_handler(session_id,llm_info,db,redis,llm_input_q,llm_response_q,asr_finished_event,llm_finished_event))
+    asyncio.create_task(scl_llm_response_handler(session_id,redis,db,llm_response_q,split_result_q,llm_finished_event,split_finished_event))
+    asyncio.create_task(scl_response_handler(ws,tts_info,response_type,split_result_q,split_finished_event,chat_finish_event))
+
+    while not chat_finish_event.is_set():
+        await asyncio.sleep(3)
+    await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
+    await ws.close()
+    logger.debug("streaming chat lasting websocket 连接断开")
+#---------------------------------------------------------------------------------------------------
+
 #--------------------------------语音通话接口--------------------------------------
 #音频数据生产函数
-async def voice_call_audio_producer(ws,audio_queue,future,shutdown_event):
+async def voice_call_audio_producer(ws,audio_queue,future,input_finished_event):
     logger.debug("音频数据生产函数启动")
     is_future_done = False
-    while not shutdown_event.is_set():
+    while not input_finished_event.is_set():
         voice_call_data_json = json.loads(await ws.receive_text())
         if not is_future_done: #在第一次循环中读取session_id
             future.set_result(voice_call_data_json['meta_info']['session_id'])
             is_future_done = True
         if voice_call_data_json["is_close"]:
-            shutdown_event.set()
+            input_finished_event.set()
             break
         else:
             await audio_queue.put(voice_call_data_json["audio"]) #将音频数据存入audio_q
 
 #音频数据消费函数
-async def voice_call_audio_consumer(audio_q,asr_result_q,shutdown_event):
+async def voice_call_audio_consumer(audio_q,asr_result_q,input_finished_event,asr_finished_event):
     logger.debug("音频数据消费者函数启动")
     vad = VAD()
     current_message = ""
     vad_count = 0
-    while not (shutdown_event.is_set() and audio_q.empty()):
+    while not (input_finished_event.is_set() and audio_q.empty()):
         audio_data = await audio_q.get()
         if vad.is_speech(audio_data):
             if vad_count > 0:
@@ -516,11 +494,12 @@ async def voice_call_audio_consumer(audio_q,asr_result_q,shutdown_event):
             await asr_result_q.put(current_message)
             current_message = ""
             vad_count = 0
+    asr_finished_event.set()
 
 #asr结果消费以及llm返回生产函数
-async def voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,shutdown_event):
+async def voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,asr_finished_event,llm_finished_event):
     logger.debug("asr结果消费以及llm返回生产函数启动")
-    while not (shutdown_event.is_set() and asr_result_q.empty()):
+    while not (asr_finished_event.is_set() and asr_result_q.empty()):
         session_content = get_session_content(session_id,redis,db)
         messages = json.loads(session_content["messages"])
         current_message = await asr_result_q.get()
@@ -547,14 +526,15 @@ async def voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_r
         await llm_response_q.put(llm_frame)
         llm_frame = {'message':"",'is_end':True}
         await llm_response_q.put(llm_frame)
+    llm_finished_event.set()
 
 #llm结果返回函数
-async def voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,shutdown_event):
+async def voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,llm_finished_event,split_finished_event):
     logger.debug("llm结果返回函数启动")
     llm_response = ""
     current_sentence = ""
     is_first = True
-    while not (shutdown_event.is_set() and llm_response_q.empty()):
+    while not (llm_finished_event.is_set() and llm_response_q.empty()):
         llm_frame = await llm_response_q.get()
         llm_response += llm_frame['message']
         sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,llm_frame['message'],is_first)
@@ -570,11 +550,12 @@ async def voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,sp
         logger.debug(f"llm返回结果: {llm_response}")
         llm_response = ""
         current_sentence = ""
+    split_finished_event.set()
 
 #语音合成及返回函数
-async def voice_call_tts_handler(ws,tts_info,split_result_q,shutdown_event):
+async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event):
     logger.debug("语音合成及返回函数启动")
-    while not (shutdown_event.is_set() and split_result_q.empty()):
+    while not (split_finished_event.is_set() and split_result_q.empty()):
         sentence = await split_result_q.get()
         sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
         text_response = {"type": "text", "code": 200, "msg": sentence}
@@ -583,6 +564,7 @@ async def voice_call_tts_handler(ws,tts_info,split_result_q,shutdown_event):
     logger.debug(f"websocket返回:{sentence}")
     asyncio.sleep(0.5)
     await ws.close()
+    voice_call_end_event.set()
 
 
 async def voice_call_handler(ws, db, redis):
@@ -592,22 +574,26 @@
     llm_response_q = asyncio.Queue()
     split_result_q = asyncio.Queue()
 
-    shutdown_event = asyncio.Event()
+    input_finished_event = asyncio.Event()
+    asr_finished_event = asyncio.Event()
+    llm_finished_event = asyncio.Event()
+    split_finished_event = asyncio.Event()
+    voice_call_end_event = asyncio.Event()
     future = asyncio.Future()
 
-    asyncio.create_task(voice_call_audio_producer(ws,audio_q,future,shutdown_event)) #创建音频数据生产者
-    asyncio.create_task(voice_call_audio_consumer(audio_q,asr_result_q,shutdown_event)) #创建音频数据消费者
+    asyncio.create_task(voice_call_audio_producer(ws,audio_q,future,input_finished_event)) #创建音频数据生产者
+    asyncio.create_task(voice_call_audio_consumer(audio_q,asr_result_q,input_finished_event,asr_finished_event)) #创建音频数据消费者
 
     #获取session内容
     session_id = await future #获取session_id
     tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
     llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
 
-    asyncio.create_task(voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,shutdown_event)) #创建llm处理者
-    asyncio.create_task(voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,shutdown_event)) #创建llm断句结果
-    asyncio.create_task(voice_call_tts_handler(ws,tts_info,split_result_q,shutdown_event)) #返回tts音频结果
+    asyncio.create_task(voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,asr_finished_event,llm_finished_event)) #创建llm处理者
+    asyncio.create_task(voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,llm_finished_event,split_finished_event)) #创建llm断句结果
+    asyncio.create_task(voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event)) #返回tts音频结果
 
-    while not shutdown_event.is_set():
-        await asyncio.sleep(5)
+    while not voice_call_end_event.is_set():
+        await asyncio.sleep(3)
     await ws.close()
     logger.debug("voice_call websocket 连接断开")
 #------------------------------------------------------------------------------------------
\ No newline at end of file