feat: 增加讯飞asr功能
This commit is contained in:
parent
30fdb9c6bd
commit
09ffeb6ab6
|
@ -8,9 +8,10 @@ from ..models import UserCharacter, Session, Character, User, Audio
|
|||
from utils.audio_utils import VAD
|
||||
from fastapi import WebSocket, HTTPException, status
|
||||
from datetime import datetime
|
||||
from utils.xf_asr_utils import generate_xf_asr_url
|
||||
from utils.xf_asr_utils import generate_xf_asr_url, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
|
||||
from config import get_config
|
||||
import numpy as np
|
||||
import websockets
|
||||
import struct
|
||||
import uuid
|
||||
import json
|
||||
|
@ -114,7 +115,7 @@ def get_emb(session_id,db):
|
|||
emb_npy = np.load(io.BytesIO(audio_record.emb_data))
|
||||
return emb_npy
|
||||
except Exception as e:
|
||||
logger.error("未找到音频:"+str(e))
|
||||
logger.debug("未找到音频:"+str(e))
|
||||
return np.array([])
|
||||
#--------------------------------------------------------
|
||||
|
||||
|
@ -247,35 +248,51 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
|
|||
#语音识别
|
||||
async def sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event):
|
||||
logger.debug("语音识别函数启动")
|
||||
is_signup = False
|
||||
audio = ""
|
||||
try:
|
||||
current_message = ""
|
||||
while not (user_input_finish_event.is_set() and user_input_q.empty()):
|
||||
if not is_signup:
|
||||
asr.session_signup(session_id)
|
||||
is_signup = True
|
||||
audio_data = await user_input_q.get()
|
||||
audio += audio_data
|
||||
asr_result = asr.streaming_recognize(session_id,audio_data)
|
||||
if Config.STRAM_CHAT.ASR == "LOCAL":
|
||||
is_signup = False
|
||||
audio = ""
|
||||
try:
|
||||
current_message = ""
|
||||
while not (user_input_finish_event.is_set() and user_input_q.empty()):
|
||||
if not is_signup:
|
||||
asr.session_signup(session_id)
|
||||
is_signup = True
|
||||
audio_data = await user_input_q.get()
|
||||
audio += audio_data
|
||||
asr_result = asr.streaming_recognize(session_id,audio_data)
|
||||
current_message += ''.join(asr_result['text'])
|
||||
asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
|
||||
current_message += ''.join(asr_result['text'])
|
||||
asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
|
||||
current_message += ''.join(asr_result['text'])
|
||||
if current_message == "":
|
||||
await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
|
||||
return
|
||||
current_message = asr.punctuation_correction(current_message)
|
||||
emotion_dict = asr.emtion_recognition(audio) #情感辨识
|
||||
if not isinstance(emotion_dict, str):
|
||||
max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
|
||||
current_message = f"{current_message},当前说话人的情绪:{emotion_dict['labels'][max_index]}"
|
||||
await llm_input_q.put(current_message)
|
||||
asr.session_signout(session_id)
|
||||
except Exception as e:
|
||||
asr.session_signout(session_id)
|
||||
logger.error(f"语音识别函数发生错误: {str(e)}")
|
||||
logger.debug(f"接收到用户消息: {current_message}")
|
||||
|
||||
if current_message == "":
|
||||
await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
|
||||
return
|
||||
current_message = asr.punctuation_correction(current_message)
|
||||
emotion_dict = asr.emtion_recognition(audio) #情感辨识
|
||||
if not isinstance(emotion_dict, str):
|
||||
max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
|
||||
current_message = f"{current_message},当前说话人的情绪:{emotion_dict['labels'][max_index]}"
|
||||
await llm_input_q.put(current_message)
|
||||
asr.session_signout(session_id)
|
||||
except Exception as e:
|
||||
asr.session_signout(session_id)
|
||||
logger.error(f"语音识别函数发生错误: {str(e)}")
|
||||
logger.debug(f"接收到用户消息: {current_message}")
|
||||
elif Config.STRAM_CHAT.ASR == "XF":
|
||||
status = FIRST_FRAME
|
||||
async with websockets.connect(generate_xf_asr_url()) as xf_websocket:
|
||||
while not (user_input_finish_event.is_set() and user_input_q.empty()):
|
||||
audio_data = await user_input_q.get()
|
||||
if status == FIRST_FRAME:
|
||||
await xf_websocket.send(make_first_frame(audio_data))
|
||||
status = CONTINUE_FRAME
|
||||
elif status == CONTINUE_FRAME:
|
||||
await xf_websocket.send(make_continue_frame(audio_data))
|
||||
await xf_websocket.send(make_last_frame(""))
|
||||
|
||||
current_message = parse_xfasr_recv(json.loads(await xf_websocket.recv()))
|
||||
await llm_input_q.put(current_message)
|
||||
logger.debug(f"接收到用户消息: {current_message}")
|
||||
|
||||
#大模型调用
|
||||
async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event):
|
||||
logger.debug("llm调用函数启动")
|
||||
|
|
|
@ -8,9 +8,9 @@ class DevelopmentConfig:
|
|||
PORT = 8001 #uvicorn运行端口
|
||||
WORKERS = 12 #uvicorn进程数(通常与cpu核数相同)
|
||||
class XF_ASR:
|
||||
APP_ID = "your_app_id" #讯飞语音识别APP_ID
|
||||
API_SECRET = "your_api_secret" #讯飞语音识别API_SECRET
|
||||
API_KEY = "your_api_key" #讯飞语音识别API_KEY
|
||||
APP_ID = "f1c121c1" #讯飞语音识别APP_ID
|
||||
API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0" #讯飞语音识别API_SECRET
|
||||
API_KEY = "36b316c7977fa534ae1e3bf52157bb92" #讯飞语音识别API_KEY
|
||||
DOMAIN = "iat"
|
||||
LANGUAGE = "zh_cn"
|
||||
ACCENT = "mandarin"
|
||||
|
@ -23,6 +23,6 @@ class DevelopmentConfig:
|
|||
URL = "https://api.minimax.chat/v1/t2a_pro",
|
||||
GROUP_ID ="1759482180095975904"
|
||||
class STRAM_CHAT:
|
||||
ASR = "LOCAL"
|
||||
ASR = "XF" # 语音识别引擎,可选XF或者LOCAL
|
||||
TTS = "LOCAL"
|
||||
|
|
@ -2,6 +2,7 @@ import datetime
|
|||
import hashlib
|
||||
import base64
|
||||
import hmac
|
||||
import json
|
||||
from urllib.parse import urlencode
|
||||
from wsgiref.handlers import format_date_time
|
||||
from datetime import datetime
|
||||
|
@ -35,3 +36,29 @@ def generate_xf_asr_url():
|
|||
}
|
||||
url = url + '?' + urlencode(v)
|
||||
return url
|
||||
|
||||
|
||||
def make_first_frame(buf):
|
||||
first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vad_eos":10000},
|
||||
"data":{"status":0,"format":"audio/L16;rate=8000","audio":buf,"encoding":"raw"}}
|
||||
return json.dumps(first_frame)
|
||||
|
||||
def make_continue_frame(buf):
|
||||
continue_frame = {"data":{"status":1,"format":"audio/L16;rate=8000","audio":buf,"encoding":"raw"}}
|
||||
return json.dumps(continue_frame)
|
||||
|
||||
def make_last_frame(buf):
|
||||
last_frame = {"data":{"status":2,"format":"audio/L16;rate=8000","audio":buf,"encoding":"raw"}}
|
||||
return json.dumps(last_frame)
|
||||
|
||||
def parse_xfasr_recv(message):
|
||||
code = message['code']
|
||||
if code!=0:
|
||||
raise Exception("讯飞ASR错误码:"+str(code))
|
||||
else:
|
||||
data = message['data']['result']['ws']
|
||||
result = ""
|
||||
for i in data:
|
||||
for w in i['cw']:
|
||||
result += w['w']
|
||||
return result
|
Loading…
Reference in New Issue