From 3c60c9a4185cd9ca246931046f27ab4c861fdb14 Mon Sep 17 00:00:00 2001 From: killua4396 <1223086337@qq.com> Date: Sun, 5 May 2024 14:50:26 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AF=B9=E8=AF=AD=E9=9F=B3=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E8=BF=9B=E8=A1=8C=E4=BA=86=E7=AE=80=E7=95=A5=E7=9A=84=E5=BC=82?= =?UTF-8?q?=E5=B8=B8=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/controllers/chat.py | 178 +++++++++++++++++++++++----------------- 1 file changed, 104 insertions(+), 74 deletions(-) diff --git a/app/controllers/chat.py b/app/controllers/chat.py index f187a32..6b09398 100644 --- a/app/controllers/chat.py +++ b/app/controllers/chat.py @@ -63,22 +63,25 @@ def parseChunkDelta(chunk): #断句函数 def split_string_with_punctuation(current_sentence,text,is_first,is_end): - result = [] - if is_end: - if current_sentence: - result.append(current_sentence) - current_sentence = '' + try: + result = [] + if is_end: + if current_sentence: + result.append(current_sentence) + current_sentence = '' + return result, current_sentence, is_first + for char in text: + current_sentence += char + if is_first and char in ',.?!,。?!': + result.append(current_sentence) + current_sentence = '' + is_first = False + elif char in '。?!': + result.append(current_sentence) + current_sentence = '' return result, current_sentence, is_first - for char in text: - current_sentence += char - if is_first and char in ',.?!,。?!': - result.append(current_sentence) - current_sentence = '' - is_first = False - elif char in '。?!': - result.append(current_sentence) - current_sentence = '' - return result, current_sentence, is_first + except Exception as e: + logger.error(f"断句时出现错误: {str(e)}") #vad预处理 def vad_preprocess(audio): @@ -203,74 +206,84 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f break await user_input_q.put(sct_data_json['audio']) except KeyError as ke: - if sct_data_json['state'] == 1 and sct_data_json['method'] == 'heartbeat': + if 'state' in sct_data_json and 'method' in sct_data_json: logger.debug("收到心跳包") + except Exception as e: + logger.error(f"用户输入处理函数发生错误: {str(e)}") #语音识别 async def sct_asr_handler(user_input_q,llm_input_q,user_input_finish_event): logger.debug("语音识别函数启动") - current_message = "" - while not (user_input_finish_event.is_set() and user_input_q.empty()): - audio_data = await user_input_q.get() - asr_result = asr.streaming_recognize(audio_data) + try: + current_message = "" + while not (user_input_finish_event.is_set() and user_input_q.empty()): + audio_data = await user_input_q.get() + asr_result = asr.streaming_recognize(audio_data) + current_message += ''.join(asr_result['text']) + asr_result = asr.streaming_recognize(b'',is_end=True) current_message += ''.join(asr_result['text']) - asr_result = asr.streaming_recognize(b'',is_end=True) - current_message += ''.join(asr_result['text']) - await llm_input_q.put(current_message) + await llm_input_q.put(current_message) + except Exception as e: + logger.error(f"语音识别函数发生错误: {str(e)}") logger.debug(f"接收到用户消息: {current_message}") #大模型调用 async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event): logger.debug("llm调用函数启动") - llm_response = "" - current_sentence = "" - is_first = True - is_end = False - session_content = get_session_content(session_id,redis,db) - messages = json.loads(session_content["messages"]) - current_message = await llm_input_q.get() - messages.append({'role': 'user', "content": current_message}) - payload = json.dumps({ - "model": llm_info["model"], - "stream": True, - "messages": messages, - "max_tokens": 10000, - "temperature": llm_info["temperature"], - "top_p": llm_info["top_p"] - }) - headers = { - 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", - 'Content-Type': 'application/json' - } - response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True) - for chunk in response.iter_lines(): - chunk_data = parseChunkDelta(chunk) - is_end = chunk_data == "end" - if not is_end: - llm_response += chunk_data - sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) - for sentence in sentences: - if response_type == RESPONSE_TEXT: - response_message = {"type": "text", "code":200, "msg": sentence} - await ws.send_text(json.dumps(response_message, ensure_ascii=False)) - elif response_type == RESPONSE_AUDIO: - sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True) - response_message = {"type": "text", "code":200, "msg": sentence} - await ws.send_bytes(audio) - await ws.send_text(json.dumps(response_message, ensure_ascii=False)) + try: + llm_response = "" + current_sentence = "" + is_first = True + is_end = False + session_content = get_session_content(session_id,redis,db) + messages = json.loads(session_content["messages"]) + current_message = await llm_input_q.get() + messages.append({'role': 'user', "content": current_message}) + payload = json.dumps({ + "model": llm_info["model"], + "stream": True, + "messages": messages, + "max_tokens": 10000, + "temperature": llm_info["temperature"], + "top_p": llm_info["top_p"] + }) + headers = { + 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", + 'Content-Type': 'application/json' + } + response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True) #调用大模型 + except Exception as e: + logger.error(f"llm调用发生错误: {str(e)}") + try: + for chunk in response.iter_lines(): + chunk_data = parseChunkDelta(chunk) + is_end = chunk_data == "end" + if not is_end: + llm_response += chunk_data + sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) #断句 + for sentence in sentences: + if response_type == RESPONSE_TEXT: + response_message = {"type": "text", "code":200, "msg": sentence} + await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息 + elif response_type == RESPONSE_AUDIO: + sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True) + response_message = {"type": "text", "code":200, "msg": sentence} + await ws.send_bytes(audio) #返回音频数据 + await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息 logger.debug(f"websocket返回: {sentence}") - if is_end: - logger.debug(f"llm返回结果: {llm_response}") - await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False)) - is_end = False - - session_content = get_session_content(session_id,redis,db) - messages = json.loads(session_content["messages"]) - messages.append({'role': 'assistant', "content": llm_response}) - session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 - redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session - is_first = True - llm_response = "" + if is_end: + logger.debug(f"llm返回结果: {llm_response}") + await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False)) + is_end = False #重置is_end标志位 + session_content = get_session_content(session_id,redis,db) + messages = json.loads(session_content["messages"]) + messages.append({'role': 'assistant', "content": llm_response}) + session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 + redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session + is_first = True + llm_response = "" + except Exception as e: + logger.error(f"处理llm返回结果发生错误: {str(e)}") chat_finished_event.set() async def streaming_chat_temporary_handler(ws: WebSocket, db, redis): @@ -334,9 +347,10 @@ async def scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f continue except asyncio.TimeoutError: continue + except Exception as e: + logger.error(f"用户输入处理函数发生错误: {str(e)}") + break - - #语音识别 async def scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event): logger.debug("语音识别函数启动") @@ -354,6 +368,9 @@ async def scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_fini current_message += ''.join(asr_result['text']) except asyncio.TimeoutError: continue + except Exception as e: + logger.error(f"语音识别函数发生错误: {str(e)}") + break asr_finished_event.set() #大模型调用 @@ -413,6 +430,9 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis llm_response = "" except asyncio.TimeoutError: continue + except Exception as e: + logger.error(f"处理llm返回结果发生错误: {str(e)}") + break chat_finished_event.set() async def streaming_chat_lasting_handler(ws,db,redis): @@ -466,9 +486,13 @@ async def voice_call_audio_producer(ws,audio_q,future,input_finished_event): vad_frame,audio_data = vad_preprocess(audio_data) await audio_q.put(vad_frame) #将音频数据存入audio_q except KeyError as ke: - logger.info(f"收到心跳包") + if 'state' in voice_call_data_json and 'method' in voice_call_data_json: + logger.info(f"收到心跳包") except asyncio.TimeoutError: continue + except Exception as e: + logger.error(f"音频数据生产函数发生错误: {str(e)}") + break #音频数据消费函数 @@ -498,6 +522,9 @@ async def voice_call_audio_consumer(ws,audio_q,asr_result_q,input_finished_event vad_count = 0 except asyncio.TimeoutError: continue + except Exception as e: + logger.error(f"音频数据消费者函数发生错误: {str(e)}") + break asr_finished_event.set() #asr结果消费以及llm返回生产函数 @@ -553,6 +580,9 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re llm_response = "" except asyncio.TimeoutError: continue + except Exception as e: + logger.error(f"处理llm返回结果发生错误: {str(e)}") + break voice_call_end_event.set()