diff --git a/.gitignore b/.gitignore
index c28b5ad..78ba9c6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,7 @@
 vits_model
 nohup.out
 /app/takway-ai.top.key
-/app/takway-ai.top.pem
\ No newline at end of file
+/app/takway-ai.top.pem
+
+tests/assets/BarbieDollsVoice.mp3
+utils/tts/openvoice_model/checkpoint.pth
diff --git a/app/controllers/chat_controller.py b/app/controllers/chat_controller.py
index 0556153..eee3cfc 100644
--- a/app/controllers/chat_controller.py
+++ b/app/controllers/chat_controller.py
@@ -8,7 +8,7 @@ from ..models import UserCharacter, Session, Character, User, Audio
 from utils.audio_utils import VAD
 from fastapi import WebSocket, HTTPException, status
 from datetime import datetime
-from utils.xf_asr_utils import generate_xf_asr_url, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
+from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
 from config import get_config
 import numpy as np
 import websockets
@@ -280,22 +280,41 @@ async def sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event):
             logger.debug(f"接收到用户消息: {current_message}")
     elif Config.STRAM_CHAT.ASR == "XF":
         status = FIRST_FRAME
-        async with websockets.connect(generate_xf_asr_url()) as xf_websocket:
-            while not (user_input_finish_event.is_set() and user_input_q.empty()):
+        xf_websocket = await xf_asr_websocket_factory() #获取一个讯飞语音识别接口websocket连接
+        segment_duration_threshold = 25 #设置一个连接时长上限,讯飞语音接口超过30秒会自动断开连接,所以该值设置成25秒
+        segment_start_time = asyncio.get_event_loop().time()
+        current_message = ""
+        while not (user_input_finish_event.is_set() and user_input_q.empty()):
+            try:
                 audio_data = await user_input_q.get()
+                current_time = asyncio.get_event_loop().time()
+                if current_time - segment_start_time > segment_duration_threshold:
+                    await xf_websocket.send(make_last_frame())
+                    current_message += parse_xfasr_recv(await xf_websocket.recv())
+                    await xf_websocket.close()
+
+                    xf_websocket = await xf_asr_websocket_factory() #重建一个websocket连接
+                    status = FIRST_FRAME
+                    segment_start_time = current_time
                 if status == FIRST_FRAME:
                     await xf_websocket.send(make_first_frame(audio_data))
                     status = CONTINUE_FRAME
                 elif status == CONTINUE_FRAME:
                     await xf_websocket.send(make_continue_frame(audio_data))
-            await xf_websocket.send(make_last_frame(""))
-
-            current_message = parse_xfasr_recv(json.loads(await xf_websocket.recv()))
-            if current_message in ["","嗯"]:
-                await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
-                return
-            await llm_input_q.put(current_message)
-            logger.debug(f"接收到用户消息: {current_message}")
+            except websockets.exceptions.ConnectionClosedOK:
+                logger.debug("讯飞语音识别接口连接断开,重新创建连接")
+                xf_websocket = await xf_asr_websocket_factory() #重建一个websocket连接
+                status = FIRST_FRAME
+                segment_start_time = asyncio.get_event_loop().time()
+        await xf_websocket.send(make_last_frame())
+        current_message += parse_xfasr_recv(await xf_websocket.recv())
+        await xf_websocket.close()
+        if current_message in ["嗯", ""]:
+            await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
+            return
+        await llm_input_q.put(current_message)
+        logger.debug(f"接收到用户消息: {current_message}")
+
 
 #大模型调用
 async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event):
diff --git a/utils/xf_asr_utils.py b/utils/xf_asr_utils.py
index 7f9bb3f..7fa2000 100644
--- a/utils/xf_asr_utils.py
+++ b/utils/xf_asr_utils.py
@@ -1,3 +1,4 @@
+import websockets
 import datetime
 import hashlib
 import base64
@@ -61,4 +62,8 @@ def parse_xfasr_recv(message):
     for i in data:
         for w in i['cw']:
             result += w['w']
-    return result
\ No newline at end of file
+    return result
+
+async def xf_asr_websocket_factory():
+    url = generate_xf_asr_url()
+    return await websockets.connect(url)
\ No newline at end of file