forked from killua/TakwayPlatform

Fix a bug where the connection was dropped after receiving a heartbeat packet.

This commit is contained in:
parent 20363bf9a1
commit f1a844c84a
@@ -70,10 +70,12 @@ def split_string_with_punctuation(current_sentence,text,is_first):
             current_sentence = ''
     return result, current_sentence, is_first
+
+# VAD preprocessing
 def vad_preprocess(audio):
     if len(audio)<1280:
         return ('A'*1280)
     return audio[:1280],audio[1280:]
 
 
 #--------------------------------------------------------
 
 # Create a new chat
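vad_preprocess above works on the base64 text itself rather than on decoded bytes: it slices 1280 characters off the front of the buffer, and 'A' * 1280 decodes to pure silence, since the base64 digit A stands for six zero bits. Assuming a 16 kHz, 16-bit mono stream (the audio format is not visible in this diff), 1280 base64 characters decode to 960 bytes, i.e. 480 samples or one 30 ms frame, a size webrtcvad accepts. Note that the short-input branch returns a single string while callers unpack two values; it stays unreachable only because the caller below guards with while len(audio_data) > 1280. A minimal standalone sketch:

    import base64

    FRAME_CHARS = 1280  # 1280 base64 chars -> 960 raw bytes (30 ms at 16 kHz, 16-bit mono)

    def vad_preprocess(audio: str):
        """Slice one fixed-size frame off the front of a base64 audio string."""
        if len(audio) < FRAME_CHARS:
            return 'A' * FRAME_CHARS  # all-zero bits: a frame of silence
        return audio[:FRAME_CHARS], audio[FRAME_CHARS:]

    # A 1280-char prefix is still valid base64 because 1280 % 4 == 0.
    frame, rest = vad_preprocess(base64.b64encode(bytes(2000)).decode())
    assert len(base64.b64decode(frame)) == 960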
@@ -308,8 +310,9 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
 async def scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,input_finished_event):
     logger.debug("User input handler started")
     is_future_done = False
-    try:
+
     while not input_finished_event.is_set():
+        try:
             scl_data_json = json.loads(await ws.receive_text())
             if scl_data_json['is_close']:
                 input_finished_event.set()
@@ -328,9 +331,11 @@ async def scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
                 await user_input_q.put(user_input_frame)
                 user_input_frame = {"audio": scl_data_json['audio'], "is_end": False}
                 await user_input_q.put(user_input_frame)
         except KeyError as ke:
             if 'state' in scl_data_json and 'method' in scl_data_json:
                 logger.debug("Received heartbeat packet")
+                continue
+
 
 # Speech recognition
 async def scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event):
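The two added lines above are the commit's subject. A heartbeat frame carries 'state' and 'method' but no 'audio' key, so scl_data_json['audio'] raises KeyError; the handler now recognizes that shape, logs it, and continues with the next frame. A sketch of the two frame shapes being distinguished (the heartbeat's field values are assumptions; only the key names appear in this diff):

    import json

    audio_frame = json.loads('{"is_close": false, "audio": "QUJD"}')
    heartbeat = json.loads('{"state": 1, "method": "ping"}')  # values are guesses

    def is_heartbeat(frame: dict) -> bool:
        # Mirrors the KeyError branch above: 'state' and 'method' present,
        # no audio payload -> treat the frame as a keep-alive.
        return 'state' in frame and 'method' in frame and 'audio' not in frame

    assert not is_heartbeat(audio_frame)
    assert is_heartbeat(heartbeat)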
@@ -463,12 +468,12 @@ async def streaming_chat_lasting_handler(ws,db,redis):
 
 #--------------------------------Voice call interface--------------------------------------
 # Audio data producer function
-async def voice_call_audio_producer(ws,audio_queue,future,input_finished_event):
+async def voice_call_audio_producer(ws,audio_q,future,input_finished_event):
     logger.debug("Audio data producer started")
     is_future_done = False
     audio_data = ""
-    try:
-        while not input_finished_event.is_set():
+    while not input_finished_event.is_set():
+        try:
             voice_call_data_json = json.loads(await ws.receive_text())
             if not is_future_done: # read the session_id in the first loop iteration
                 future.set_result(voice_call_data_json['meta_info']['session_id'])
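Swapping try: and while above is the other half of the fix. In the old layout the except clause sat outside the loop, so the first heartbeat unwound the whole receive loop and the handler returned, which dropped the websocket. With try inside the loop, the exception is confined to a single iteration. A runnable contrast of the two layouts, with the websocket simulated by a list of frames (field values are assumptions):

    import json

    FRAMES = [
        '{"is_close": false, "audio": "QUJD"}',
        '{"state": 1, "method": "ping"}',      # heartbeat: no is_close/audio keys
        '{"is_close": false, "audio": "REVG"}',
        '{"is_close": true}',
    ]

    def old_loop(frames):
        """try outside the while: one heartbeat unwinds the whole loop."""
        received, frames, finished = [], iter(frames), False
        try:
            while not finished:
                data = json.loads(next(frames))
                if data['is_close']:               # KeyError on the heartbeat
                    finished = True
                else:
                    received.append(data['audio'])
        except KeyError:
            pass                                   # loop is gone; handler returns
        return received

    def new_loop(frames):
        """try inside the while: the heartbeat is skipped, the loop survives."""
        received, frames, finished = [], iter(frames), False
        while not finished:
            try:
                data = json.loads(next(frames))
                if data['is_close']:
                    finished = True
                else:
                    received.append(data['audio'])
            except KeyError:
                if 'state' in data and 'method' in data:
                    continue                       # heartbeat: keep receiving
        return received

    print(old_loop(FRAMES))  # ['QUJD'] -- receive loop died at the heartbeat
    print(new_loop(FRAMES))  # ['QUJD', 'REVG'] -- full conversation survives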
@@ -480,12 +485,12 @@ async def voice_call_audio_producer(ws,audio_queue,future,input_finished_event):
             audio_data += voice_call_data_json["audio"]
             while len(audio_data) > 1280:
                 vad_frame,audio_data = vad_preprocess(audio_data)
-                await audio_queue.put(vad_frame) # put the audio data into audio_q
+                await audio_q.put(vad_frame) # put the audio data into audio_q
         except KeyError as ke:
             logger.info(f"Received heartbeat packet")
 
 # Audio data consumer function
-async def voice_call_audio_consumer(audio_q,asr_result_q,input_finished_event,asr_finished_event):
+async def voice_call_audio_consumer(ws,audio_q,asr_result_q,input_finished_event,asr_finished_event):
     logger.debug("Audio data consumer started")
     vad = VAD()
     current_message = ""
@@ -504,6 +509,8 @@ async def voice_call_audio_consumer(audio_q,asr_result_q,input_finished_event,as
                 if current_message:
                     logger.debug(f"Silence detected, user input: {current_message}")
                     await asr_result_q.put(current_message)
+                    text_response = {"type": "user_text", "code": 200, "msg": current_message}
+                    await ws.send_text(json.dumps(text_response, ensure_ascii=False)) # return the text data
                     current_message = ""
                     vad_count = 0
     asr_finished_event.set()
@@ -570,7 +577,7 @@ async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event
     while not (split_finished_event.is_set() and split_result_q.empty()):
         sentence = await split_result_q.get()
         sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
-        text_response = {"type": "text", "code": 200, "msg": sentence}
+        text_response = {"type": "llm_text", "code": 200, "msg": sentence}
         await ws.send_bytes(audio) # return the binary audio stream
         await ws.send_text(json.dumps(text_response, ensure_ascii=False)) # return the text data
         logger.debug(f"websocket returned: {sentence}")
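Renaming the TTS text frame from "text" to "llm_text" pairs with the new "user_text" frame added in the consumer above: a client can now tell the echoed ASR transcription apart from the model's reply instead of receiving two identically typed frames. A sketch of client-side routing for the three frame kinds, using the websockets package and a placeholder URL (both are assumptions; the project's real client is not part of this diff):

    import asyncio
    import json
    import websockets  # assumed client library

    def play_audio(chunk: bytes) -> None:
        print(f"<audio: {len(chunk)} bytes>")        # stand-in for real playback

    def show(speaker: str, text: str) -> None:
        print(f"{speaker}: {text}")

    async def client(url: str = "ws://localhost:8000/voice_call"):  # URL is a guess
        async with websockets.connect(url) as ws:
            async for message in ws:
                if isinstance(message, bytes):       # ws.send_bytes -> synthesized speech
                    play_audio(message)
                    continue
                frame = json.loads(message)
                if frame["type"] == "user_text":     # what ASR heard from the user
                    show("you", frame["msg"])
                elif frame["type"] == "llm_text":    # what the model answered
                    show("assistant", frame["msg"])

    # asyncio.run(client())  # requires the server to be running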
@@ -581,19 +588,20 @@
 
 async def voice_call_handler(ws, db, redis):
     logger.debug("voice_call websocket connection established")
-    audio_q = asyncio.Queue()
-    asr_result_q = asyncio.Queue()
-    llm_response_q = asyncio.Queue()
-    split_result_q = asyncio.Queue()
+    audio_q = asyncio.Queue() # audio queue
+    asr_result_q = asyncio.Queue() # ASR result queue
+    llm_response_q = asyncio.Queue() # LLM response queue
+    split_result_q = asyncio.Queue() # sentence-split result queue
 
-    input_finished_event = asyncio.Event()
-    asr_finished_event = asyncio.Event()
-    llm_finished_event = asyncio.Event()
-    split_finished_event = asyncio.Event()
-    voice_call_end_event = asyncio.Event()
-    future = asyncio.Future()
+    input_finished_event = asyncio.Event() # user-input finished event
+    asr_finished_event = asyncio.Event() # ASR finished event
+    llm_finished_event = asyncio.Event() # LLM finished event
+    split_finished_event = asyncio.Event() # sentence-split finished event
+    voice_call_end_event = asyncio.Event() # voice-call end event
+
+    future = asyncio.Future() # used to receive the transmitted session_id
     asyncio.create_task(voice_call_audio_producer(ws,audio_q,future,input_finished_event)) # create the audio data producer
-    asyncio.create_task(voice_call_audio_consumer(audio_q,asr_result_q,input_finished_event,asr_finished_event)) # create the audio data consumer
+    asyncio.create_task(voice_call_audio_consumer(ws,audio_q,asr_result_q,input_finished_event,asr_finished_event)) # create the audio data consumer
 
     # fetch the session content
     session_id = await future # get the session_id
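voice_call_handler wires the stages together: Queues move data between coroutines, Events signal stage completion, and the Future carries the session_id from the producer's first frame back to the handler that awaits it. A toy version of that wiring, with the finished-events collapsed into a sentinel for brevity:

    import asyncio

    async def producer(frames, audio_q, future):
        for i, frame in enumerate(frames):
            if i == 0:
                future.set_result(frame["session_id"])  # resolve the Future once
            await audio_q.put(frame["audio"])
        await audio_q.put(None)  # sentinel standing in for input_finished_event

    async def consumer(audio_q):
        while (chunk := await audio_q.get()) is not None:
            print("consumed", chunk)

    async def main():
        frames = [{"session_id": "s-1", "audio": "QUJD"},
                  {"session_id": "s-1", "audio": "REVG"}]
        audio_q = asyncio.Queue()
        future = asyncio.Future()
        asyncio.create_task(producer(frames, audio_q, future))
        consumer_task = asyncio.create_task(consumer(audio_q))
        print("session:", await future)  # available as soon as frame 0 arrives
        await consumer_task              # drain the toy pipeline

    asyncio.run(main())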
@@ -10,5 +10,8 @@ class VAD():
         self.min_act_time = min_act_time # minimum active time, in seconds
 
     def is_speech(self,data):
-        byte_data = base64.b64decode(data)
-        return self.vad.is_speech(byte_data, self.RATE)
+        try:
+            byte_data = base64.b64decode(data)
+            return self.vad.is_speech(byte_data, self.RATE)
+        except Exception as e:
+            return False
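The broadened is_speech above swallows every failure, which fixes the crash but also hides the cause. A slightly more observable drop-in variant, assuming a module-level logger like the one the handlers use (not shown in this file's excerpt):

    import base64
    import logging

    logger = logging.getLogger(__name__)

    def is_speech(self, data):  # drop-in replacement for VAD.is_speech
        try:
            byte_data = base64.b64decode(data)
            return self.vad.is_speech(byte_data, self.RATE)
        except Exception as e:
            # Malformed or truncated frames count as silence, but the failure
            # is logged instead of disappearing silently.
            logger.warning(f"VAD rejected a frame: {e}")
            return False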