feat: add iFlytek (XF) ASR support

killua4396 2024-06-04 18:05:23 +08:00
parent 30fdb9c6bd
commit 09ffeb6ab6
3 changed files with 78 additions and 34 deletions

@@ -8,9 +8,10 @@ from ..models import UserCharacter, Session, Character, User, Audio
 from utils.audio_utils import VAD
 from fastapi import WebSocket, HTTPException, status
 from datetime import datetime
-from utils.xf_asr_utils import generate_xf_asr_url
+from utils.xf_asr_utils import generate_xf_asr_url, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
 from config import get_config
 import numpy as np
+import websockets
 import struct
 import uuid
 import json
@@ -114,7 +115,7 @@ def get_emb(session_id,db):
         emb_npy = np.load(io.BytesIO(audio_record.emb_data))
         return emb_npy
     except Exception as e:
-        logger.error("Audio not found: "+str(e))
+        logger.debug("Audio not found: "+str(e))
         return np.array([])
 #--------------------------------------------------------
@@ -247,35 +248,51 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
 # Speech recognition
 async def sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event):
     logger.debug("ASR handler started")
-    is_signup = False
-    audio = ""
-    try:
-        current_message = ""
-        while not (user_input_finish_event.is_set() and user_input_q.empty()):
-            if not is_signup:
-                asr.session_signup(session_id)
-                is_signup = True
-            audio_data = await user_input_q.get()
-            audio += audio_data
-            asr_result = asr.streaming_recognize(session_id,audio_data)
-            current_message += ''.join(asr_result['text'])
-        asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
-        current_message += ''.join(asr_result['text'])
-        if current_message == "":
-            await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
-            return
-        current_message = asr.punctuation_correction(current_message)
-        emotion_dict = asr.emtion_recognition(audio)  # emotion recognition
-        if not isinstance(emotion_dict, str):
-            max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
-            current_message = f"{current_message}, current speaker's emotion: {emotion_dict['labels'][max_index]}"
-        await llm_input_q.put(current_message)
-        asr.session_signout(session_id)
-    except Exception as e:
-        asr.session_signout(session_id)
-        logger.error(f"ASR handler error: {str(e)}")
-    logger.debug(f"Received user message: {current_message}")
+    if Config.STRAM_CHAT.ASR == "LOCAL":
+        is_signup = False
+        audio = ""
+        try:
+            current_message = ""
+            while not (user_input_finish_event.is_set() and user_input_q.empty()):
+                if not is_signup:
+                    asr.session_signup(session_id)
+                    is_signup = True
+                audio_data = await user_input_q.get()
+                audio += audio_data
+                asr_result = asr.streaming_recognize(session_id,audio_data)
+                current_message += ''.join(asr_result['text'])
+            asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
+            current_message += ''.join(asr_result['text'])
+            if current_message == "":
+                await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
+                return
+            current_message = asr.punctuation_correction(current_message)
+            emotion_dict = asr.emtion_recognition(audio)  # emotion recognition
+            if not isinstance(emotion_dict, str):
+                max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
+                current_message = f"{current_message}, current speaker's emotion: {emotion_dict['labels'][max_index]}"
+            await llm_input_q.put(current_message)
+            asr.session_signout(session_id)
+        except Exception as e:
+            asr.session_signout(session_id)
+            logger.error(f"ASR handler error: {str(e)}")
+        logger.debug(f"Received user message: {current_message}")
+    elif Config.STRAM_CHAT.ASR == "XF":
+        status = FIRST_FRAME
+        async with websockets.connect(generate_xf_asr_url()) as xf_websocket:
+            while not (user_input_finish_event.is_set() and user_input_q.empty()):
+                audio_data = await user_input_q.get()
+                if status == FIRST_FRAME:
+                    await xf_websocket.send(make_first_frame(audio_data))
+                    status = CONTINUE_FRAME
+                elif status == CONTINUE_FRAME:
+                    await xf_websocket.send(make_continue_frame(audio_data))
+            await xf_websocket.send(make_last_frame(""))
+            current_message = parse_xfasr_recv(json.loads(await xf_websocket.recv()))
+        await llm_input_q.put(current_message)
+        logger.debug(f"Received user message: {current_message}")
 # LLM call
 async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event):
     logger.debug("LLM handler started")

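The XF branch above sends one first frame (status 0), then continue frames (status 1), a last frame (status 2), and reads a single final result. A standalone sketch of that flow, with one assumed change that is not in this commit: `json.dumps` cannot serialize raw `bytes`, and iFlytek's IAT endpoint expects the `audio` field as base64 text, so the chunks are base64-encoded here:

import asyncio
import base64
import json
import websockets

from utils.xf_asr_utils import (generate_xf_asr_url, make_first_frame,
                                make_continue_frame, make_last_frame, parse_xfasr_recv)

async def recognize(chunks):
    # Exercise the XF frame sequence end to end: status 0 (first frame),
    # status 1 (continue frames), status 2 (last frame), then one final recv.
    async with websockets.connect(generate_xf_asr_url()) as ws:
        first = True
        for chunk in chunks:
            b64 = base64.b64encode(chunk).decode()  # assumption: IAT wants base64 text, not raw bytes
            await ws.send(make_first_frame(b64) if first else make_continue_frame(b64))
            first = False
        await ws.send(make_last_frame(""))
        return parse_xfasr_recv(json.loads(await ws.recv()))

# transcript = asyncio.run(recognize(pcm_chunks))  # pcm_chunks: iterable of 16-bit, 8 kHz PCM byte chunks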
config.py

@@ -8,9 +8,9 @@ class DevelopmentConfig:
     PORT = 8001    # uvicorn port
     WORKERS = 12   # number of uvicorn workers (usually the CPU core count)
     class XF_ASR:
-        APP_ID = "your_app_id"          # iFlytek ASR APP_ID
-        API_SECRET = "your_api_secret"  # iFlytek ASR API_SECRET
-        API_KEY = "your_api_key"        # iFlytek ASR API_KEY
+        APP_ID = "f1c121c1"                              # iFlytek ASR APP_ID
+        API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0"  # iFlytek ASR API_SECRET
+        API_KEY = "36b316c7977fa534ae1e3bf52157bb92"     # iFlytek ASR API_KEY
         DOMAIN = "iat"
         LANGUAGE = "zh_cn"
         ACCENT = "mandarin"
@@ -23,6 +23,6 @@ class DevelopmentConfig:
         URL = "https://api.minimax.chat/v1/t2a_pro",
         GROUP_ID ="1759482180095975904"
     class STRAM_CHAT:
-        ASR = "LOCAL"
+        ASR = "XF"     # ASR engine: XF or LOCAL
         TTS = "LOCAL"

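This hunk replaces the placeholder credentials with live ones, which leaves APP_ID, API_SECRET, and API_KEY in git history. A hedged alternative sketch that keeps the same Config shape but reads them from the environment (the variable names are assumptions, not part of this commit):

# Sketch only: load iFlytek credentials from environment variables instead of
# hardcoding them; XF_APP_ID / XF_API_SECRET / XF_API_KEY are assumed names.
import os

class XF_ASR:
    APP_ID = os.environ.get("XF_APP_ID", "")          # iFlytek ASR APP_ID
    API_SECRET = os.environ.get("XF_API_SECRET", "")  # iFlytek ASR API_SECRET
    API_KEY = os.environ.get("XF_API_KEY", "")        # iFlytek ASR API_KEY
    DOMAIN = "iat"
    LANGUAGE = "zh_cn"
    ACCENT = "mandarin"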
utils/xf_asr_utils.py

@@ -2,6 +2,7 @@ import datetime
 import hashlib
 import base64
 import hmac
+import json
 from urllib.parse import urlencode
 from wsgiref.handlers import format_date_time
 from datetime import datetime
@@ -35,3 +36,29 @@ def generate_xf_asr_url():
     }
     url = url + '?' + urlencode(v)
     return url
+
+def make_first_frame(buf):
+    first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vad_eos":10000},
+                   "data":{"status":0,"format":"audio/L16;rate=8000","audio":buf,"encoding":"raw"}}
+    return json.dumps(first_frame)
+
+def make_continue_frame(buf):
+    continue_frame = {"data":{"status":1,"format":"audio/L16;rate=8000","audio":buf,"encoding":"raw"}}
+    return json.dumps(continue_frame)
+
+def make_last_frame(buf):
+    last_frame = {"data":{"status":2,"format":"audio/L16;rate=8000","audio":buf,"encoding":"raw"}}
+    return json.dumps(last_frame)
+
+def parse_xfasr_recv(message):
+    code = message['code']
+    if code != 0:
+        raise Exception("XF ASR error code: "+str(code))
+    else:
+        data = message['data']['result']['ws']
+        result = ""
+        for i in data:
+            for w in i['cw']:
+                result += w['w']
+        return result
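
To make the expected response shape concrete, here is a hand-built message in the layout parse_xfasr_recv walks (data.result.ws[i].cw[j].w); the sample words are invented:

# Hypothetical response in the shape parse_xfasr_recv expects; values are made up.
sample = {
    "code": 0,
    "data": {"result": {"ws": [
        {"cw": [{"w": "你好"}]},  # each ws entry holds candidate words under cw
        {"cw": [{"w": "世界"}]},
    ]}},
}
assert parse_xfasr_recv(sample) == "你好世界"  # concatenates every cw[j]['w'] in order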