From f11339ff92be59dd04275f42244b70b7b9d4d680 Mon Sep 17 00:00:00 2001 From: killua4396 <1223086337@qq.com> Date: Sat, 4 May 2024 11:26:14 +0800 Subject: [PATCH] =?UTF-8?q?1.=E5=B0=86requesst=E5=BA=93=E7=9A=84http?= =?UTF-8?q?=E5=90=8C=E6=AD=A5=E8=AF=B7=E6=B1=82=EF=BC=8C=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=E4=B8=BAhttpx=E7=9A=84=E5=BC=82=E6=AD=A5=E8=AF=B7=E6=B1=82=202?= =?UTF-8?q?.=E5=B0=86=E6=94=B6=E5=88=B0=E5=A4=A7=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E8=BF=94=E5=9B=9E=E6=B6=88=E6=81=AF=E5=90=8E=E7=9A=84=E5=93=8D?= =?UTF-8?q?=E5=BA=94=E6=94=B9=E4=B8=BA=E5=90=8C=E6=AD=A5=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/controllers/chat.py | 405 +++++++++++++++++++++------------------- 1 file changed, 211 insertions(+), 194 deletions(-) diff --git a/app/controllers/chat.py b/app/controllers/chat.py index 4668467..9ff9464 100644 --- a/app/controllers/chat.py +++ b/app/controllers/chat.py @@ -9,8 +9,8 @@ from utils.xf_asr_utils import generate_xf_asr_url from config import get_config import uuid import json -import requests import asyncio +import httpx # 依赖注入获取logger logger = get_logger() @@ -48,17 +48,27 @@ def get_session_content(session_id,redis,db): #解析大模型流式返回内容 def parseChunkDelta(chunk): - decoded_data = chunk.decode('utf-8') - parsed_data = json.loads(decoded_data[6:]) - if 'delta' in parsed_data['choices'][0]: - delta_content = parsed_data['choices'][0]['delta'] - return delta_content['content'] - else: - return "" + try: + if chunk == "": + return "" + chunk_json_str = chunk[6:] + parsed_data = json.loads(chunk_json_str) + if 'delta' in parsed_data['choices'][0]: + delta_content = parsed_data['choices'][0]['delta'] + return delta_content['content'] + else: + return "end" + except KeyError: + logger.error(f"error chunk: {chunk}") #断句函数 -def split_string_with_punctuation(current_sentence,text,is_first): +def split_string_with_punctuation(current_sentence,text,is_first,is_end): result = [] + if is_end: + if current_sentence: + result.append(current_sentence) + current_sentence = '' + return result, current_sentence, is_first for char in text: current_sentence += char if is_first and char in ',.?!,。?!': @@ -210,8 +220,12 @@ async def sct_asr_handler(user_input_q,llm_input_q,user_input_finish_event): logger.debug(f"接收到用户消息: {current_message}") #大模型调用 -async def sct_llm_handler(session_id,llm_info,db,redis,llm_input_q,llm_response_q,llm_response_finish_event): +async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event): logger.debug("llm调用函数启动") + llm_response = "" + current_sentence = "" + is_first = True + is_end = False session_content = get_session_content(session_id,redis,db) messages = json.loads(session_content["messages"]) current_message = await llm_input_q.get() @@ -228,13 +242,37 @@ async def sct_llm_handler(session_id,llm_info,db,redis,llm_input_q,llm_response_ 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", 'Content-Type': 'application/json' } - response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True) - if response.status_code == 200: - for chunk in response.iter_lines(): - if chunk: - chunk_data = parseChunkDelta(chunk) - await llm_response_q.put(chunk_data) - llm_response_finish_event.set() + async with httpx.AsyncClient() as client: + response = await client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) + async for chunk in response.aiter_lines(): + chunk_data = parseChunkDelta(chunk) + is_end = chunk_data == "end" + if not is_end: + llm_response += chunk_data + sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) + for sentence in sentences: + if response_type == RESPONSE_TEXT: + response_message = {"type": "text", "code":200, "msg": sentence} + await ws.send_text(json.dumps(response_message, ensure_ascii=False)) + elif response_type == RESPONSE_AUDIO: + sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True) + response_message = {"type": "text", "code":200, "msg": sentence} + await ws.send_bytes(audio) + await ws.send_text(json.dumps(response_message, ensure_ascii=False)) + logger.debug(f"websocket返回: {sentence}") + if is_end: + logger.debug(f"llm返回结果: {llm_response}") + await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False)) + is_end = False + + session_content = get_session_content(session_id,redis,db) + messages = json.loads(session_content["messages"]) + messages.append({'role': 'assistant', "content": llm_response}) + session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 + redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session + is_first = True + llm_response = "" + chat_finished_event.set() #大模型返回断句 async def sct_llm_response_handler(session_id,redis,db,llm_response_q,split_result_q,llm_response_finish_event): @@ -275,12 +313,9 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis): logger.debug("streaming chat temporary websocket 连接建立") user_input_q = asyncio.Queue() # 用于存储用户输入 llm_input_q = asyncio.Queue() # 用于存储llm输入 - llm_response_q = asyncio.Queue() # 用于存储llm输出 - split_result_q = asyncio.Queue() # 用于存储tts输出 user_input_finish_event = asyncio.Event() - llm_response_finish_event = asyncio.Event() - chat_finish_event = asyncio.Event() + chat_finished_event = asyncio.Event() future_session_id = asyncio.Future() future_response_type = asyncio.Future() asyncio.create_task(sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,user_input_finish_event)) @@ -292,11 +327,9 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis): tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"]) llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"]) - asyncio.create_task(sct_llm_handler(session_id,llm_info,db,redis,llm_input_q,llm_response_q,llm_response_finish_event)) - asyncio.create_task(sct_llm_response_handler(session_id,redis,db,llm_response_q,split_result_q,llm_response_finish_event)) - asyncio.create_task(sct_response_handler(ws,tts_info,response_type,split_result_q,llm_response_finish_event,chat_finish_event)) + asyncio.create_task(sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event)) - while not chat_finish_event.is_set(): + while not chat_finished_event.is_set(): await asyncio.sleep(1) await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False)) await ws.close() @@ -313,7 +346,7 @@ async def scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f while not input_finished_event.is_set(): try: - scl_data_json = json.loads(await ws.receive_text()) + scl_data_json = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=3)) if scl_data_json['is_close']: input_finished_event.set() break @@ -335,6 +368,9 @@ async def scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f if 'state' in scl_data_json and 'method' in scl_data_json: logger.debug("收到心跳包") continue + except asyncio.TimeoutError: + continue + #语音识别 @@ -342,109 +378,90 @@ async def scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_fini logger.debug("语音识别函数启动") current_message = "" while not (input_finished_event.is_set() and user_input_q.empty()): - aduio_frame = await user_input_q.get() - if aduio_frame['is_end']: - asr_result = asr.streaming_recognize(aduio_frame['audio'], is_end=True) - current_message += ''.join(asr_result['text']) - await llm_input_q.put(current_message) - logger.debug(f"接收到用户消息: {current_message}") - else: - asr_result = asr.streaming_recognize(aduio_frame['audio']) - current_message += ''.join(asr_result['text']) + try: + aduio_frame = await asyncio.wait_for(user_input_q.get(),timeout=3) + if aduio_frame['is_end']: + asr_result = asr.streaming_recognize(aduio_frame['audio'], is_end=True) + current_message += ''.join(asr_result['text']) + await llm_input_q.put(current_message) + logger.debug(f"接收到用户消息: {current_message}") + else: + asr_result = asr.streaming_recognize(aduio_frame['audio']) + current_message += ''.join(asr_result['text']) + except asyncio.TimeoutError: + continue asr_finished_event.set() #大模型调用 -async def scl_llm_handler(session_id,llm_info,db,redis,llm_input_q,llm_response_q,asr_finished_event,llm_finished_event): +async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,asr_finished_event,chat_finished_event): logger.debug("llm调用函数启动") - while not (asr_finished_event.is_set() and llm_input_q.empty()): - session_content = get_session_content(session_id,redis,db) - messages = json.loads(session_content["messages"]) - current_message = await llm_input_q.get() - messages.append({'role': 'user', "content": current_message}) - payload = json.dumps({ - "model": llm_info["model"], - "stream": True, - "messages": messages, - "max_tokens": 10000, - "temperature": llm_info["temperature"], - "top_p": llm_info["top_p"] - }) - headers = { - 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", - 'Content-Type': 'application/json' - } - response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True) - if response.status_code == 200: - for chunk in response.iter_lines(): - if chunk: - chunk_data = parseChunkDelta(chunk) - llm_frame = {"message": chunk_data, "is_end": False} - await llm_response_q.put(llm_frame) - llm_frame = {"message": "", "is_end": True} - await llm_response_q.put(llm_frame) - llm_finished_event.set() - -#大模型返回断句 -async def scl_llm_response_handler(session_id,redis,db,llm_response_q,split_result_q,llm_finished_event,split_finished_event): - logger.debug("llm返回处理函数启动") llm_response = "" current_sentence = "" is_first = True - while not (llm_finished_event.is_set() and llm_response_q.empty()): - llm_frame = await llm_response_q.get() - llm_response += llm_frame['message'] - sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,llm_frame['message'],is_first) - for sentence in sentences: - sentence_frame = {"message": sentence, "is_end": False} - await split_result_q.put(sentence_frame) - if llm_frame['is_end']: - sentence_frame = {"message": "", "is_end": True} - await split_result_q.put(sentence_frame) - is_first = True + is_end = False + while not (asr_finished_event.is_set() and llm_input_q.empty()): + try: session_content = get_session_content(session_id,redis,db) messages = json.loads(session_content["messages"]) - messages.append({'role': 'assistant', "content": llm_response}) - session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 - redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session - logger.debug(f"llm返回结果: {llm_response}") - llm_response = "" - current_sentence = "" - split_finished_event.set() + current_message = await asyncio.wait_for(llm_input_q.get(),timeout=3) + messages.append({'role': 'user', "content": current_message}) + payload = json.dumps({ + "model": llm_info["model"], + "stream": True, + "messages": messages, + "max_tokens": 10000, + "temperature": llm_info["temperature"], + "top_p": llm_info["top_p"] + }) + headers = { + 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", + 'Content-Type': 'application/json' + } + async with httpx.AsyncClient() as client: + response = await client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) + async for chunk in response.aiter_lines(): + chunk_data = parseChunkDelta(chunk) + is_end = chunk_data == "end" + if not is_end: + llm_response += chunk_data + sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) + for sentence in sentences: + if response_type == RESPONSE_TEXT: + response_message = {"type": "text", "code":200, "msg": sentence} + await ws.send_text(json.dumps(response_message, ensure_ascii=False)) + elif response_type == RESPONSE_AUDIO: + sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True) + response_message = {"type": "text", "code":200, "msg": sentence} + await ws.send_bytes(audio) + await ws.send_text(json.dumps(response_message, ensure_ascii=False)) + logger.debug(f"websocket返回: {sentence}") + if is_end: + logger.debug(f"llm返回结果: {llm_response}") + await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False)) + is_end = False -#文本返回及语音合成 -async def scl_response_handler(ws,tts_info,response_type,split_result_q,split_finished_event,chat_finish_event): - logger.debug("返回处理函数启动") - while not (split_finished_event.is_set() and split_result_q.empty()): - sentence_frame = await split_result_q.get() - sentence = sentence_frame['message'] - if sentence_frame['is_end']: - await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False)) + session_content = get_session_content(session_id,redis,db) + messages = json.loads(session_content["messages"]) + messages.append({'role': 'assistant', "content": llm_response}) + session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 + redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session + is_first = True + llm_response = "" + except asyncio.TimeoutError: continue - if response_type == RESPONSE_TEXT: - response_message = {"type": "text", "code":200, "msg": sentence} - await ws.send_text(json.dumps(response_message, ensure_ascii=False)) - elif response_type == RESPONSE_AUDIO: - sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True) - response_message = {"type": "text", "code":200, "msg": sentence} - await ws.send_bytes(audio) - await ws.send_text(json.dumps(response_message, ensure_ascii=False)) - logger.debug(f"websocket返回: {sentence}") - chat_finish_event.set() + chat_finished_event.set() async def streaming_chat_lasting_handler(ws,db,redis): logger.debug("streaming chat lasting websocket 连接建立") user_input_q = asyncio.Queue() # 用于存储用户输入 llm_input_q = asyncio.Queue() # 用于存储llm输入 - llm_response_q = asyncio.Queue() # 用于存储llm输出 - split_result_q = asyncio.Queue() # 用于存储llm返回后断句输出 input_finished_event = asyncio.Event() asr_finished_event = asyncio.Event() - llm_finished_event = asyncio.Event() - split_finished_event = asyncio.Event() - chat_finish_event = asyncio.Event() + chat_finished_event = asyncio.Event() future_session_id = asyncio.Future() future_response_type = asyncio.Future() + asyncio.create_task(scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,input_finished_event)) asyncio.create_task(scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event)) @@ -453,11 +470,9 @@ async def streaming_chat_lasting_handler(ws,db,redis): tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"]) llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"]) - asyncio.create_task(scl_llm_handler(session_id,llm_info,db,redis,llm_input_q,llm_response_q,asr_finished_event,llm_finished_event)) - asyncio.create_task(scl_llm_response_handler(session_id,redis,db,llm_response_q,split_result_q,llm_finished_event,split_finished_event)) - asyncio.create_task(scl_response_handler(ws,tts_info,response_type,split_result_q,split_finished_event,chat_finish_event)) + asyncio.create_task(scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,asr_finished_event,chat_finished_event)) - while not chat_finish_event.is_set(): + while not chat_finished_event.is_set(): await asyncio.sleep(3) await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False)) await ws.close() @@ -474,7 +489,7 @@ async def voice_call_audio_producer(ws,audio_q,future,input_finished_event): audio_data = "" while not input_finished_event.is_set(): try: - voice_call_data_json = json.loads(await ws.receive_text()) + voice_call_data_json = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=3)) if not is_future_done: #在第一次循环中读取session_id future.set_result(voice_call_data_json['meta_info']['session_id']) is_future_done = True @@ -488,6 +503,9 @@ async def voice_call_audio_producer(ws,audio_q,future,input_finished_event): await audio_q.put(vad_frame) #将音频数据存入audio_q except KeyError as ke: logger.info(f"收到心跳包") + except asyncio.TimeoutError: + continue + #音频数据消费函数 async def voice_call_audio_consumer(ws,audio_q,asr_result_q,input_finished_event,asr_finished_event): @@ -496,93 +514,98 @@ async def voice_call_audio_consumer(ws,audio_q,asr_result_q,input_finished_event current_message = "" vad_count = 0 while not (input_finished_event.is_set() and audio_q.empty()): - audio_data = await audio_q.get() - if vad.is_speech(audio_data): - if vad_count > 0: - vad_count -= 1 - asr_result = asr.streaming_recognize(audio_data) - current_message += ''.join(asr_result['text']) - else: - vad_count += 1 - if vad_count >= 25: #连续25帧没有语音,则认为说完了 - asr_result = asr.streaming_recognize(audio_data, is_end=True) - if current_message: - logger.debug(f"检测到静默,用户输入为:{current_message}") - await asr_result_q.put(current_message) - text_response = {"type": "user_text", "code": 200, "msg": current_message} - await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据 - current_message = "" - vad_count = 0 + try: + audio_data = await asyncio.wait_for(audio_q.get(),timeout=3) + if vad.is_speech(audio_data): + if vad_count > 0: + vad_count -= 1 + asr_result = asr.streaming_recognize(audio_data) + current_message += ''.join(asr_result['text']) + else: + vad_count += 1 + if vad_count >= 25: #连续25帧没有语音,则认为说完了 + asr_result = asr.streaming_recognize(audio_data, is_end=True) + if current_message: + logger.debug(f"检测到静默,用户输入为:{current_message}") + await asr_result_q.put(current_message) + text_response = {"type": "user_text", "code": 200, "msg": current_message} + await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据 + current_message = "" + vad_count = 0 + except asyncio.TimeoutError: + continue asr_finished_event.set() #asr结果消费以及llm返回生产函数 -async def voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,asr_finished_event,llm_finished_event): +async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_result_q,asr_finished_event,voice_call_end_event): logger.debug("asr结果消费以及llm返回生产函数启动") - while not (asr_finished_event.is_set() and asr_result_q.empty()): - session_content = get_session_content(session_id,redis,db) - messages = json.loads(session_content["messages"]) - current_message = await asr_result_q.get() - messages.append({'role': 'user', "content": current_message}) - payload = json.dumps({ - "model": llm_info["model"], - "stream": True, - "messages": messages, - "max_tokens":10000, - "temperature": llm_info["temperature"], - "top_p": llm_info["top_p"] - }) - - headers = { - 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", - 'Content-Type': 'application/json' - } - response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True) - if response.status_code == 200: - for chunk in response.iter_lines(): - if chunk: - chunk_data =parseChunkDelta(chunk) - llm_frame = {'message':chunk_data,'is_end':False} - await llm_response_q.put(llm_frame) - llm_frame = {'message':"",'is_end':True} - await llm_response_q.put(llm_frame) - llm_finished_event.set() - -#llm结果返回函数 -async def voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,llm_finished_event,split_finished_event): - logger.debug("llm结果返回函数启动") llm_response = "" current_sentence = "" is_first = True - while not (llm_finished_event.is_set() and llm_response_q.empty()): - llm_frame = await llm_response_q.get() - llm_response += llm_frame['message'] - sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,llm_frame['message'],is_first) - for sentence in sentences: - await split_result_q.put(sentence) - if llm_frame['is_end']: - is_first = True + is_end = False + while not (asr_finished_event.is_set() and asr_result_q.empty()): + try: session_content = get_session_content(session_id,redis,db) messages = json.loads(session_content["messages"]) - messages.append({'role': 'assistant', "content": llm_response}) - session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 - redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session - logger.debug(f"llm返回结果: {llm_response}") - llm_response = "" - current_sentence = "" - split_finished_event.set() + current_message = await asyncio.wait_for(asr_result_q.get(),timeout=3) + messages.append({'role': 'user', "content": current_message}) + payload = json.dumps({ + "model": llm_info["model"], + "stream": True, + "messages": messages, + "max_tokens":10000, + "temperature": llm_info["temperature"], + "top_p": llm_info["top_p"] + }) + + headers = { + 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", + 'Content-Type': 'application/json' + } + async with httpx.AsyncClient() as client: + response = await client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) + async for chunk in response.aiter_lines(): + chunk_data = parseChunkDelta(chunk) + is_end = chunk_data == "end" + if not is_end: + llm_response += chunk_data + sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) + for sentence in sentences: + sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True) + text_response = {"type": "llm_text", "code": 200, "msg": sentence} + await ws.send_bytes(audio) #返回音频二进制流数据 + await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据 + logger.debug(f"llm返回结果: {sentence}") + if is_end: + logger.debug(f"llm返回结果: {llm_response}") + await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False)) + is_end = False + + session_content = get_session_content(session_id,redis,db) + messages = json.loads(session_content["messages"]) + messages.append({'role': 'assistant', "content": llm_response}) + session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 + redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session + is_first = True + llm_response = "" + except asyncio.TimeoutError: + continue + voice_call_end_event.set() + #语音合成及返回函数 async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event): logger.debug("语音合成及返回函数启动") while not (split_finished_event.is_set() and split_result_q.empty()): - sentence = await split_result_q.get() - sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True) - text_response = {"type": "llm_text", "code": 200, "msg": sentence} - await ws.send_bytes(audio) #返回音频二进制流数据 - await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据 - logger.debug(f"websocket返回:{sentence}") - asyncio.sleep(0.5) - await ws.close() + try: + sentence = await asyncio.wait_for(split_result_q.get(),timeout=3) + sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True) + text_response = {"type": "llm_text", "code": 200, "msg": sentence} + await ws.send_bytes(audio) #返回音频二进制流数据 + await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据 + logger.debug(f"websocket返回:{sentence}") + except asyncio.TimeoutError: + continue voice_call_end_event.set() @@ -590,13 +613,10 @@ async def voice_call_handler(ws, db, redis): logger.debug("voice_call websocket 连接建立") audio_q = asyncio.Queue() #音频队列 asr_result_q = asyncio.Queue() #语音识别结果队列 - llm_response_q = asyncio.Queue() #大模型返回队列 - split_result_q = asyncio.Queue() #断句结果队列 + input_finished_event = asyncio.Event() #用户输入结束事件 asr_finished_event = asyncio.Event() #语音识别结束事件 - llm_finished_event = asyncio.Event() #大模型结束事件 - split_finished_event = asyncio.Event() #断句结束事件 voice_call_end_event = asyncio.Event() #语音电话终止事件 future = asyncio.Future() #用于获取传输的session_id @@ -608,10 +628,7 @@ async def voice_call_handler(ws, db, redis): tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"]) llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"]) - asyncio.create_task(voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,asr_finished_event,llm_finished_event)) #创建llm处理者 - asyncio.create_task(voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,llm_finished_event,split_finished_event)) #创建llm断句结果 - asyncio.create_task(voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event)) #返回tts音频结果 - + asyncio.create_task(voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_result_q,asr_finished_event,voice_call_end_event)) #创建llm处理者 while not voice_call_end_event.is_set(): await asyncio.sleep(3) await ws.close()