From 09ffeb6ab6586089c3e468c02d4917b880f7f26f Mon Sep 17 00:00:00 2001 From: killua4396 <1223086337@qq.com> Date: Tue, 4 Jun 2024 18:05:23 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E8=AE=AF=E9=A3=9Easr?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/controllers/chat_controller.py | 77 ++++++++++++++++++------------ config/development.py | 8 ++-- utils/xf_asr_utils.py | 27 +++++++++++ 3 files changed, 78 insertions(+), 34 deletions(-) diff --git a/app/controllers/chat_controller.py b/app/controllers/chat_controller.py index 5b5a8b9..2bf22d0 100644 --- a/app/controllers/chat_controller.py +++ b/app/controllers/chat_controller.py @@ -8,9 +8,10 @@ from ..models import UserCharacter, Session, Character, User, Audio from utils.audio_utils import VAD from fastapi import WebSocket, HTTPException, status from datetime import datetime -from utils.xf_asr_utils import generate_xf_asr_url +from utils.xf_asr_utils import generate_xf_asr_url, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv from config import get_config import numpy as np +import websockets import struct import uuid import json @@ -114,7 +115,7 @@ def get_emb(session_id,db): emb_npy = np.load(io.BytesIO(audio_record.emb_data)) return emb_npy except Exception as e: - logger.error("未找到音频:"+str(e)) + logger.debug("未找到音频:"+str(e)) return np.array([]) #-------------------------------------------------------- @@ -247,35 +248,51 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f #语音识别 async def sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event): logger.debug("语音识别函数启动") - is_signup = False - audio = "" - try: - current_message = "" - while not (user_input_finish_event.is_set() and user_input_q.empty()): - if not is_signup: - asr.session_signup(session_id) - is_signup = True - audio_data = await user_input_q.get() - audio += audio_data - asr_result = asr.streaming_recognize(session_id,audio_data) + if Config.STRAM_CHAT.ASR == "LOCAL": + is_signup = False + audio = "" + try: + current_message = "" + while not (user_input_finish_event.is_set() and user_input_q.empty()): + if not is_signup: + asr.session_signup(session_id) + is_signup = True + audio_data = await user_input_q.get() + audio += audio_data + asr_result = asr.streaming_recognize(session_id,audio_data) + current_message += ''.join(asr_result['text']) + asr_result = asr.streaming_recognize(session_id,b'',is_end=True) current_message += ''.join(asr_result['text']) - asr_result = asr.streaming_recognize(session_id,b'',is_end=True) - current_message += ''.join(asr_result['text']) - if current_message == "": - await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False)) - return - current_message = asr.punctuation_correction(current_message) - emotion_dict = asr.emtion_recognition(audio) #情感辨识 - if not isinstance(emotion_dict, str): - max_index = emotion_dict['scores'].index(max(emotion_dict['scores'])) - current_message = f"{current_message},当前说话人的情绪:{emotion_dict['labels'][max_index]}" - await llm_input_q.put(current_message) - asr.session_signout(session_id) - except Exception as e: - asr.session_signout(session_id) - logger.error(f"语音识别函数发生错误: {str(e)}") - logger.debug(f"接收到用户消息: {current_message}") - + if current_message == "": + await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False)) + return + current_message = asr.punctuation_correction(current_message) + emotion_dict = asr.emtion_recognition(audio) #情感辨识 + if not isinstance(emotion_dict, str): + max_index = emotion_dict['scores'].index(max(emotion_dict['scores'])) + current_message = f"{current_message},当前说话人的情绪:{emotion_dict['labels'][max_index]}" + await llm_input_q.put(current_message) + asr.session_signout(session_id) + except Exception as e: + asr.session_signout(session_id) + logger.error(f"语音识别函数发生错误: {str(e)}") + logger.debug(f"接收到用户消息: {current_message}") + elif Config.STRAM_CHAT.ASR == "XF": + status = FIRST_FRAME + async with websockets.connect(generate_xf_asr_url()) as xf_websocket: + while not (user_input_finish_event.is_set() and user_input_q.empty()): + audio_data = await user_input_q.get() + if status == FIRST_FRAME: + await xf_websocket.send(make_first_frame(audio_data)) + status = CONTINUE_FRAME + elif status == CONTINUE_FRAME: + await xf_websocket.send(make_continue_frame(audio_data)) + await xf_websocket.send(make_last_frame("")) + + current_message = parse_xfasr_recv(json.loads(await xf_websocket.recv())) + await llm_input_q.put(current_message) + logger.debug(f"接收到用户消息: {current_message}") + #大模型调用 async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event): logger.debug("llm调用函数启动") diff --git a/config/development.py b/config/development.py index cffca47..16b7fba 100644 --- a/config/development.py +++ b/config/development.py @@ -8,9 +8,9 @@ class DevelopmentConfig: PORT = 8001 #uvicorn运行端口 WORKERS = 12 #uvicorn进程数(通常与cpu核数相同) class XF_ASR: - APP_ID = "your_app_id" #讯飞语音识别APP_ID - API_SECRET = "your_api_secret" #讯飞语音识别API_SECRET - API_KEY = "your_api_key" #讯飞语音识别API_KEY + APP_ID = "f1c121c1" #讯飞语音识别APP_ID + API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0" #讯飞语音识别API_SECRET + API_KEY = "36b316c7977fa534ae1e3bf52157bb92" #讯飞语音识别API_KEY DOMAIN = "iat" LANGUAGE = "zh_cn" ACCENT = "mandarin" @@ -23,6 +23,6 @@ class DevelopmentConfig: URL = "https://api.minimax.chat/v1/t2a_pro", GROUP_ID ="1759482180095975904" class STRAM_CHAT: - ASR = "LOCAL" + ASR = "XF" # 语音识别引擎,可选XF或者LOCAL TTS = "LOCAL" \ No newline at end of file diff --git a/utils/xf_asr_utils.py b/utils/xf_asr_utils.py index d141a8e..e92a866 100644 --- a/utils/xf_asr_utils.py +++ b/utils/xf_asr_utils.py @@ -2,6 +2,7 @@ import datetime import hashlib import base64 import hmac +import json from urllib.parse import urlencode from wsgiref.handlers import format_date_time from datetime import datetime @@ -35,3 +36,29 @@ def generate_xf_asr_url(): } url = url + '?' + urlencode(v) return url + + +def make_first_frame(buf): + first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vad_eos":10000}, + "data":{"status":0,"format":"audio/L16;rate=8000","audio":buf,"encoding":"raw"}} + return json.dumps(first_frame) + +def make_continue_frame(buf): + continue_frame = {"data":{"status":1,"format":"audio/L16;rate=8000","audio":buf,"encoding":"raw"}} + return json.dumps(continue_frame) + +def make_last_frame(buf): + last_frame = {"data":{"status":2,"format":"audio/L16;rate=8000","audio":buf,"encoding":"raw"}} + return json.dumps(last_frame) + +def parse_xfasr_recv(message): + code = message['code'] + if code!=0: + raise Exception("讯飞ASR错误码:"+str(code)) + else: + data = message['data']['result']['ws'] + result = "" + for i in data: + for w in i['cw']: + result += w['w'] + return result \ No newline at end of file