1
0
Fork 0

更新持续流式聊天接口和语音电话接口,使得其可以正常处理完所有数据再退出

This commit is contained in:
killua4396 2024-05-02 09:55:47 +08:00
parent 9f62dbe694
commit f548d57595
1 changed files with 176 additions and 190 deletions

View File

@ -297,95 +297,64 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
#--------------------------------------------------------------------------------------------------- #---------------------------------------------------------------------------------------------------
# 持续流式聊天
async def streaming_chat_lasting_handler(ws, db, redis): #------------------------------------------持续流式聊天----------------------------------------------
print("Websocket连接成功") #处理用户输入
while True: async def scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,input_finished_event):
logger.debug("用户输入处理函数启动")
is_future_done = False
try: try:
print("等待接受") while not input_finished_event.is_set():
data = await asyncio.wait_for(ws.receive_text(), timeout=60) scl_data_json = json.loads(await ws.receive_text())
data_json = json.loads(data) if scl_data_json['is_close']:
if data_json["is_close"]: input_finished_event.set()
close_message = {"type": "close", "code": 200, "msg": ""}
await ws.send_text(json.dumps(close_message, ensure_ascii=False))
print("连接关闭")
await asyncio.sleep(0.5)
await ws.close()
return;
except asyncio.TimeoutError:
print("连接超时")
await ws.close()
return;
current_message = "" # 用于存储用户消息
response_type = RESPONSE_TEXT # 用于获取返回类型
session_id = ""
if Config.STRAM_CHAT.ASR == "LOCAL":
try:
while True:
if data_json["text"]: # 若文字输入不为空,则表示该输入为文字输入
if data_json["meta_info"]["voice_synthesize"]:
response_type = RESPONSE_AUDIO # 查看voice_synthesize判断返回类型
session_id = data_json["meta_info"]["session_id"]
current_message = data_json['text']
break break
if not is_future_done:
if not data_json['meta_info']['is_end']: # 还在发 future_session_id.set_result(scl_data_json['meta_info']['session_id'])
asr_result = asr.streaming_recognize(data_json["audio"]) if scl_data_json['meta_info']['voice_synthesize']:
current_message += ''.join(asr_result['text']) future_response_type.set_result(RESPONSE_AUDIO)
else: # 发完了
asr_result = asr.streaming_recognize(data_json["audio"], is_end=True)
session_id = data_json["meta_info"]["session_id"]
current_message += ''.join(asr_result['text'])
if data_json["meta_info"]["voice_synthesize"]:
response_type = RESPONSE_AUDIO # 查看voice_synthesize判断返回类型
break
data_json = json.loads(await ws.receive_text())
except Exception as e:
error_info = f"接收用户消息错误: {str(e)}"
error_message = {"type": "error", "code": "500", "msg": error_info}
logger.error(error_info)
await ws.send_text(json.dumps(error_message, ensure_ascii=False))
await ws.close()
return
elif Config.STRAM_CHAT.ASR == "REMOTE":
error_info = f"远程ASR服务暂未开通"
error_message = {"type": "error", "code": "500", "msg": error_info}
logger.error(error_info)
await ws.send_text(json.dumps(error_message, ensure_ascii=False))
await ws.close()
return
print(f"接收到用户消息: {current_message}")
# 查询Session记录
session_content_str = ""
if redis.exists(session_id):
session_content_str = redis.get(session_id)
else: else:
session_db = db.query(Session).filter(Session.id == session_id).first() future_response_type.set_result(RESPONSE_TEXT)
if not session_db: is_future_done = True
error_info = f"未找到session记录: {str(e)}" if scl_data_json['text']:
error_message = {"type": "error", "code": 500, "msg": error_info} await llm_input_q.put(scl_data_json['text'])
logger.error(error_info) if scl_data_json['meta_info']['is_end']:
await ws.send_text(json.dumps(error_message, ensure_ascii=False)) user_input_frame = {"audio": scl_data_json['audio'], "is_end": True}
await ws.close() await user_input_q.put(user_input_frame)
return user_input_frame = {"audio": scl_data_json['audio'], "is_end": False}
session_content_str = session_db.content await user_input_q.put(user_input_frame)
except KeyError as ke:
if 'state' in scl_data_json and 'method' in scl_data_json:
logger.debug("收到心跳包")
session_content = json.loads(session_content_str) #语音识别
async def scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event):
    """Speech-recognition stage of the lasting streaming-chat pipeline.

    Pulls audio frames from ``user_input_q``, feeds them to the streaming ASR
    engine and, whenever a frame is flagged ``is_end``, forwards the text
    recognised for that utterance to ``llm_input_q``.

    Args:
        user_input_q: asyncio.Queue of dict frames read here via the keys
            ``"audio"`` and ``"is_end"``.
        llm_input_q: asyncio.Queue that receives one string per finished
            utterance.
        input_finished_event: asyncio.Event; once it is set and the queue is
            drained this stage stops.
        asr_finished_event: asyncio.Event set by this stage right before it
            returns, so the next stage knows no more text will arrive.
    """
    logger.debug("语音识别函数启动")
    current_message = ""
    while not (input_finished_event.is_set() and user_input_q.empty()):
        try:
            # Bounded wait: an unbounded get() would never return if the input
            # stage finishes while the queue is empty, and the loop condition
            # above would never be re-evaluated.
            audio_frame = await asyncio.wait_for(user_input_q.get(), timeout=1)
        except asyncio.TimeoutError:
            continue
        if audio_frame['is_end']:
            asr_result = asr.streaming_recognize(audio_frame['audio'], is_end=True)
            current_message += ''.join(asr_result['text'])
            await llm_input_q.put(current_message)
            logger.debug(f"接收到用户消息: {current_message}")
            # Start the next utterance from scratch; otherwise every later
            # message would be sent with all previous text prepended.
            current_message = ""
        else:
            asr_result = asr.streaming_recognize(audio_frame['audio'])
            current_message += ''.join(asr_result['text'])
    asr_finished_event.set()
#大模型调用
async def scl_llm_handler(session_id,llm_info,db,redis,llm_input_q,llm_response_q,asr_finished_event,llm_finished_event):
logger.debug("llm调用函数启动")
while not (asr_finished_event.is_set() and llm_input_q.empty()):
session_content = get_session_content(session_id,redis,db)
messages = json.loads(session_content["messages"]) messages = json.loads(session_content["messages"])
current_message = await llm_input_q.get()
messages.append({'role': 'user', "content": current_message}) messages.append({'role': 'user', "content": current_message})
token_count = session_content["token"]
try:
payload = json.dumps({ payload = json.dumps({
"model": llm_info["model"], "model": llm_info["model"],
"stream": True, "stream": True,
"messages": messages, "messages": messages,
"tool_choice": "auto",
"max_tokens": 10000, "max_tokens": 10000,
"temperature": llm_info["temperature"], "temperature": llm_info["temperature"],
"top_p": llm_info["top_p"] "top_p": llm_info["top_p"]
@ -395,112 +364,121 @@ async def streaming_chat_lasting_handler(ws, db, redis):
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True) response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True)
except Exception as e:
error_info = f"发送信息给大模型时发生错误: {str(e)}"
error_message = {"type": "error", "code": 500, "msg": error_info}
logger.error(error_info)
await ws.send_text(json.dumps(error_message, ensure_ascii=False))
await ws.close()
return
def split_string_with_punctuation(text):
punctuations = "!?。"
result = []
current_sentence = ""
for char in text:
current_sentence += char
if char in punctuations:
result.append(current_sentence)
current_sentence = ""
# 判断最后一个字符是否为标点符号
if current_sentence and current_sentence[-1] not in punctuations:
# 如果最后一段不以标点符号结尾,则加入拆分数组
result.append(current_sentence)
return result
llm_response = ""
response_buf = ""
try:
if Config.STRAM_CHAT.TTS == "LOCAL":
if response.status_code == 200: if response.status_code == 200:
for chunk in response.iter_lines(): for chunk in response.iter_lines():
if chunk: if chunk:
if response_type == RESPONSE_AUDIO:
chunk_data = parseChunkDelta(chunk) chunk_data = parseChunkDelta(chunk)
llm_response += chunk_data llm_frame = {"message": chunk_data, "is_end": False}
response_buf += chunk_data await llm_response_q.put(llm_frame)
split_buf = split_string_with_punctuation(response_buf) llm_frame = {"message": "", "is_end": True}
response_buf = "" await llm_response_q.put(llm_frame)
if len(split_buf) != 0: llm_finished_event.set()
for sentence in split_buf:
sr, audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
text_response = {"type": "text", "code": 200, "msg": sentence}
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) # 返回文本数据
await ws.send_bytes(audio) # 返回音频二进制流数据
if response_type == RESPONSE_TEXT:
chunk_data = parseChunkDelta(chunk)
llm_response += chunk_data
text_response = {"type": "text", "code": 200, "msg": chunk_data}
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) # 返回文本数据
elif Config.STRAM_CHAT.TTS == "REMOTE": #大模型返回断句
error_info = f"暂不支持远程音频合成" async def scl_llm_response_handler(session_id,redis,db,llm_response_q,split_result_q,llm_finished_event,split_finished_event):
error_message = {"type": "error", "code": 500, "msg": error_info} logger.debug("llm返回处理函数启动")
logger.error(error_info) llm_response = ""
await ws.send_text(json.dumps(error_message, ensure_ascii=False)) current_sentence = ""
await ws.close() is_first = True
return while not (llm_finished_event.is_set() and llm_response_q.empty()):
end_response = {"type": "end", "code": 200, "msg": ""} llm_frame = await llm_response_q.get()
await ws.send_text(json.dumps(end_response, ensure_ascii=False)) # 单次返回结束 llm_response += llm_frame['message']
print(f"llm消息: {llm_response}") sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,llm_frame['message'],is_first)
except Exception as e: for sentence in sentences:
error_info = f"音频合成与向前端返回时错误: {str(e)}" sentence_frame = {"message": sentence, "is_end": False}
error_message = {"type": "error", "code": 500, "msg": error_info} await split_result_q.put(sentence_frame)
logger.error(error_info) if llm_frame['is_end']:
await ws.send_text(json.dumps(error_message, ensure_ascii=False)) sentence_frame = {"message": "", "is_end": True}
await ws.close() await split_result_q.put(sentence_frame)
return is_first = True
session_content = get_session_content(session_id,redis,db)
try: messages = json.loads(session_content["messages"])
messages.append({'role': 'assistant', "content": llm_response}) messages.append({'role': 'assistant', "content": llm_response})
token_count += len(llm_response) session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
session_content["messages"] = json.dumps(messages, ensure_ascii=False) redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
session_content["token"] = token_count logger.debug(f"llm返回结果: {llm_response}")
redis.set(session_id, json.dumps(session_content, ensure_ascii=False)) llm_response = ""
except Exception as e: current_sentence = ""
error_info = f"更新session时错误: {str(e)}" split_finished_event.set()
error_message = {"type": "error", "code": 500, "msg": error_info}
logger.error(error_info) #文本返回及语音合成
async def scl_response_handler(ws,tts_info,response_type,split_result_q,split_finished_event,chat_finish_event):
    """Final stage of the lasting streaming-chat pipeline: reply to the client.

    Reads sentence frames from ``split_result_q`` and writes them to the
    websocket. A frame with ``is_end`` set produces an ``"end"`` message;
    any other frame is sent as text and, when ``response_type`` is
    ``RESPONSE_AUDIO``, is also synthesised and sent as binary audio first.

    Args:
        ws: websocket used via ``send_text`` / ``send_bytes``.
        tts_info: dict read via the keys ``language``, ``speaker_id``,
            ``noise_scale``, ``noise_scale_w`` and ``length_scale``.
        response_type: ``RESPONSE_TEXT`` or ``RESPONSE_AUDIO``.
        split_result_q: asyncio.Queue of dict frames read via the keys
            ``"message"`` and ``"is_end"``.
        split_finished_event: asyncio.Event; once it is set and the queue is
            drained this stage stops.
        chat_finish_event: asyncio.Event set when this stage exits, whether
            normally or through an exception.
    """
    logger.debug("返回处理函数启动")
    try:
        while not (split_finished_event.is_set() and split_result_q.empty()):
            try:
                # Bounded wait so the loop condition is re-checked; an
                # unbounded get() would block forever if the upstream stage
                # finishes while the queue is empty.
                sentence_frame = await asyncio.wait_for(split_result_q.get(), timeout=1)
            except asyncio.TimeoutError:
                continue
            sentence = sentence_frame['message']
            if sentence_frame['is_end']:
                await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
                continue
            if response_type == RESPONSE_TEXT:
                response_message = {"type": "text", "code":200, "msg": sentence}
                await ws.send_text(json.dumps(response_message, ensure_ascii=False))
            elif response_type == RESPONSE_AUDIO:
                # Positional order is (text, language, speaker_id, ...), the
                # same order the voice-call TTS stage in this module uses.
                # NOTE(review): confirm against tts.synthesize's signature.
                _, audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
                response_message = {"type": "text", "code":200, "msg": sentence}
                await ws.send_bytes(audio)
                await ws.send_text(json.dumps(response_message, ensure_ascii=False))
            logger.debug(f"websocket返回: {sentence}")
    finally:
        # Always signal completion so a waiter on this event is not left
        # blocked when a send fails.
        chat_finish_event.set()
async def streaming_chat_lasting_handler(ws,db,redis):
    """Entry point of the lasting streaming-chat websocket.

    Wires five stages together with queues and events: user input -> ASR ->
    LLM call -> sentence splitting -> response/TTS. Waits until the last
    stage reports completion, then sends a ``"close"`` message and closes the
    websocket.

    Args:
        ws: the websocket connection, shared by the input and response stages.
        db: database handle passed through to ``get_session_content``.
        redis: redis client passed through to ``get_session_content`` and the
            stages that read or update the session.
    """
    logger.debug("streaming chat lasting websocket 连接建立")
    user_input_q = asyncio.Queue()    # audio frames from the client
    llm_input_q = asyncio.Queue()     # text to send to the LLM
    llm_response_q = asyncio.Queue()  # raw LLM output chunks
    split_result_q = asyncio.Queue()  # LLM output split into sentences

    # One "finished" event per stage; each stage sets its own when done.
    input_finished_event = asyncio.Event()
    asr_finished_event = asyncio.Event()
    llm_finished_event = asyncio.Event()
    split_finished_event = asyncio.Event()
    chat_finish_event = asyncio.Event()

    # Resolved by the input stage from the first message it receives.
    future_session_id = asyncio.Future()
    future_response_type = asyncio.Future()

    # Keep references to the tasks: the event loop only holds weak references,
    # so an unreferenced task may be garbage-collected before it finishes.
    tasks = [
        asyncio.create_task(scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,input_finished_event)),
        asyncio.create_task(scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event)),
    ]
    try:
        session_id = await future_session_id        # session id from the first message
        response_type = await future_response_type  # text or audio reply
        # Read the session once and take both settings from the same snapshot.
        session_content = get_session_content(session_id,redis,db)
        tts_info = json.loads(session_content["tts_info"])
        llm_info = json.loads(session_content["llm_info"])

        tasks.append(asyncio.create_task(scl_llm_handler(session_id,llm_info,db,redis,llm_input_q,llm_response_q,asr_finished_event,llm_finished_event)))
        tasks.append(asyncio.create_task(scl_llm_response_handler(session_id,redis,db,llm_response_q,split_result_q,llm_finished_event,split_finished_event)))
        tasks.append(asyncio.create_task(scl_response_handler(ws,tts_info,response_type,split_result_q,split_finished_event,chat_finish_event)))

        # Wake up as soon as the last stage finishes instead of polling.
        await chat_finish_event.wait()
        await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
        await ws.close()
    finally:
        # Do not leave stage tasks running after this handler exits.
        for task in tasks:
            if not task.done():
                task.cancel()
    logger.debug("streaming chat lasting websocket 连接断开")
print("处理完毕") #---------------------------------------------------------------------------------------------------
#--------------------------------语音通话接口-------------------------------------- #--------------------------------语音通话接口--------------------------------------
#音频数据生产函数 #音频数据生产函数
async def voice_call_audio_producer(ws,audio_queue,future,shutdown_event): async def voice_call_audio_producer(ws,audio_queue,future,input_finished_event):
logger.debug("音频数据生产函数启动") logger.debug("音频数据生产函数启动")
is_future_done = False is_future_done = False
while not shutdown_event.is_set(): while not input_finished_event.is_set():
voice_call_data_json = json.loads(await ws.receive_text()) voice_call_data_json = json.loads(await ws.receive_text())
if not is_future_done: #在第一次循环中读取session_id if not is_future_done: #在第一次循环中读取session_id
future.set_result(voice_call_data_json['meta_info']['session_id']) future.set_result(voice_call_data_json['meta_info']['session_id'])
is_future_done = True is_future_done = True
if voice_call_data_json["is_close"]: if voice_call_data_json["is_close"]:
shutdown_event.set() input_finished_event.set()
break break
else: else:
await audio_queue.put(voice_call_data_json["audio"]) #将音频数据存入audio_q await audio_queue.put(voice_call_data_json["audio"]) #将音频数据存入audio_q
#音频数据消费函数 #音频数据消费函数
async def voice_call_audio_consumer(audio_q,asr_result_q,shutdown_event): async def voice_call_audio_consumer(audio_q,asr_result_q,input_finished_event,asr_finished_event):
logger.debug("音频数据消费者函数启动") logger.debug("音频数据消费者函数启动")
vad = VAD() vad = VAD()
current_message = "" current_message = ""
vad_count = 0 vad_count = 0
while not (shutdown_event.is_set() and audio_q.empty()): while not (input_finished_event.is_set() and audio_q.empty()):
audio_data = await audio_q.get() audio_data = await audio_q.get()
if vad.is_speech(audio_data): if vad.is_speech(audio_data):
if vad_count > 0: if vad_count > 0:
@ -516,11 +494,12 @@ async def voice_call_audio_consumer(audio_q,asr_result_q,shutdown_event):
await asr_result_q.put(current_message) await asr_result_q.put(current_message)
current_message = "" current_message = ""
vad_count = 0 vad_count = 0
asr_finished_event.set()
#asr结果消费以及llm返回生产函数 #asr结果消费以及llm返回生产函数
async def voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,shutdown_event): async def voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,asr_finished_event,llm_finished_event):
logger.debug("asr结果消费以及llm返回生产函数启动") logger.debug("asr结果消费以及llm返回生产函数启动")
while not (shutdown_event.is_set() and asr_result_q.empty()): while not (asr_finished_event.is_set() and asr_result_q.empty()):
session_content = get_session_content(session_id,redis,db) session_content = get_session_content(session_id,redis,db)
messages = json.loads(session_content["messages"]) messages = json.loads(session_content["messages"])
current_message = await asr_result_q.get() current_message = await asr_result_q.get()
@ -547,14 +526,15 @@ async def voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_r
await llm_response_q.put(llm_frame) await llm_response_q.put(llm_frame)
llm_frame = {'message':"",'is_end':True} llm_frame = {'message':"",'is_end':True}
await llm_response_q.put(llm_frame) await llm_response_q.put(llm_frame)
llm_finished_event.set()
#llm结果返回函数 #llm结果返回函数
async def voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,shutdown_event): async def voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,llm_finished_event,split_finished_event):
logger.debug("llm结果返回函数启动") logger.debug("llm结果返回函数启动")
llm_response = "" llm_response = ""
current_sentence = "" current_sentence = ""
is_first = True is_first = True
while not (shutdown_event.is_set() and llm_response_q.empty()): while not (llm_finished_event.is_set() and llm_response_q.empty()):
llm_frame = await llm_response_q.get() llm_frame = await llm_response_q.get()
llm_response += llm_frame['message'] llm_response += llm_frame['message']
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,llm_frame['message'],is_first) sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,llm_frame['message'],is_first)
@ -570,11 +550,12 @@ async def voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,sp
logger.debug(f"llm返回结果: {llm_response}") logger.debug(f"llm返回结果: {llm_response}")
llm_response = "" llm_response = ""
current_sentence = "" current_sentence = ""
split_finished_event.set()
#语音合成及返回函数 #语音合成及返回函数
async def voice_call_tts_handler(ws,tts_info,split_result_q,shutdown_event): async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event):
logger.debug("语音合成及返回函数启动") logger.debug("语音合成及返回函数启动")
while not (shutdown_event.is_set() and split_result_q.empty()): while not (split_finished_event.is_set() and split_result_q.empty()):
sentence = await split_result_q.get() sentence = await split_result_q.get()
sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True) sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
text_response = {"type": "text", "code": 200, "msg": sentence} text_response = {"type": "text", "code": 200, "msg": sentence}
@ -583,6 +564,7 @@ async def voice_call_tts_handler(ws,tts_info,split_result_q,shutdown_event):
logger.debug(f"websocket返回:{sentence}") logger.debug(f"websocket返回:{sentence}")
asyncio.sleep(0.5) asyncio.sleep(0.5)
await ws.close() await ws.close()
voice_call_end_event.set()
async def voice_call_handler(ws, db, redis): async def voice_call_handler(ws, db, redis):
@ -592,22 +574,26 @@ async def voice_call_handler(ws, db, redis):
llm_response_q = asyncio.Queue() llm_response_q = asyncio.Queue()
split_result_q = asyncio.Queue() split_result_q = asyncio.Queue()
shutdown_event = asyncio.Event() input_finished_event = asyncio.Event()
asr_finished_event = asyncio.Event()
llm_finished_event = asyncio.Event()
split_finished_event = asyncio.Event()
voice_call_end_event = asyncio.Event()
future = asyncio.Future() future = asyncio.Future()
asyncio.create_task(voice_call_audio_producer(ws,audio_q,future,shutdown_event)) #创建音频数据生产者 asyncio.create_task(voice_call_audio_producer(ws,audio_q,future,input_finished_event)) #创建音频数据生产者
asyncio.create_task(voice_call_audio_consumer(audio_q,asr_result_q,shutdown_event)) #创建音频数据消费者 asyncio.create_task(voice_call_audio_consumer(audio_q,asr_result_q,input_finished_event,asr_finished_event)) #创建音频数据消费者
#获取session内容 #获取session内容
session_id = await future #获取session_id session_id = await future #获取session_id
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"]) tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"]) llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
asyncio.create_task(voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,shutdown_event)) #创建llm处理者 asyncio.create_task(voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,asr_finished_event,llm_finished_event)) #创建llm处理者
asyncio.create_task(voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,shutdown_event)) #创建llm断句结果 asyncio.create_task(voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,llm_finished_event,split_finished_event)) #创建llm断句结果
asyncio.create_task(voice_call_tts_handler(ws,tts_info,split_result_q,shutdown_event)) #返回tts音频结果 asyncio.create_task(voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event)) #返回tts音频结果
while not shutdown_event.is_set(): while not voice_call_end_event.is_set():
await asyncio.sleep(5) await asyncio.sleep(3)
await ws.close() await ws.close()
logger.debug("voice_call websocket 连接断开") logger.debug("voice_call websocket 连接断开")
#------------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------------