feat: 增加讯飞asr功能

This commit is contained in:
killua4396 2024-06-04 18:05:23 +08:00
parent 30fdb9c6bd
commit 09ffeb6ab6
3 changed files with 78 additions and 34 deletions

View File

@ -8,9 +8,10 @@ from ..models import UserCharacter, Session, Character, User, Audio
from utils.audio_utils import VAD from utils.audio_utils import VAD
from fastapi import WebSocket, HTTPException, status from fastapi import WebSocket, HTTPException, status
from datetime import datetime from datetime import datetime
from utils.xf_asr_utils import generate_xf_asr_url from utils.xf_asr_utils import generate_xf_asr_url, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
from config import get_config from config import get_config
import numpy as np import numpy as np
import websockets
import struct import struct
import uuid import uuid
import json import json
@ -114,7 +115,7 @@ def get_emb(session_id,db):
emb_npy = np.load(io.BytesIO(audio_record.emb_data)) emb_npy = np.load(io.BytesIO(audio_record.emb_data))
return emb_npy return emb_npy
except Exception as e: except Exception as e:
logger.error("未找到音频:"+str(e)) logger.debug("未找到音频:"+str(e))
return np.array([]) return np.array([])
#-------------------------------------------------------- #--------------------------------------------------------
@ -247,35 +248,51 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
#语音识别 #语音识别
async def sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event): async def sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event):
logger.debug("语音识别函数启动") logger.debug("语音识别函数启动")
is_signup = False if Config.STRAM_CHAT.ASR == "LOCAL":
audio = "" is_signup = False
try: audio = ""
current_message = "" try:
while not (user_input_finish_event.is_set() and user_input_q.empty()): current_message = ""
if not is_signup: while not (user_input_finish_event.is_set() and user_input_q.empty()):
asr.session_signup(session_id) if not is_signup:
is_signup = True asr.session_signup(session_id)
audio_data = await user_input_q.get() is_signup = True
audio += audio_data audio_data = await user_input_q.get()
asr_result = asr.streaming_recognize(session_id,audio_data) audio += audio_data
asr_result = asr.streaming_recognize(session_id,audio_data)
current_message += ''.join(asr_result['text'])
asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
current_message += ''.join(asr_result['text']) current_message += ''.join(asr_result['text'])
asr_result = asr.streaming_recognize(session_id,b'',is_end=True) if current_message == "":
current_message += ''.join(asr_result['text']) await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
if current_message == "": return
await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False)) current_message = asr.punctuation_correction(current_message)
return emotion_dict = asr.emtion_recognition(audio) #情感辨识
current_message = asr.punctuation_correction(current_message) if not isinstance(emotion_dict, str):
emotion_dict = asr.emtion_recognition(audio) #情感辨识 max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
if not isinstance(emotion_dict, str): current_message = f"{current_message},当前说话人的情绪:{emotion_dict['labels'][max_index]}"
max_index = emotion_dict['scores'].index(max(emotion_dict['scores'])) await llm_input_q.put(current_message)
current_message = f"{current_message},当前说话人的情绪:{emotion_dict['labels'][max_index]}" asr.session_signout(session_id)
await llm_input_q.put(current_message) except Exception as e:
asr.session_signout(session_id) asr.session_signout(session_id)
except Exception as e: logger.error(f"语音识别函数发生错误: {str(e)}")
asr.session_signout(session_id) logger.debug(f"接收到用户消息: {current_message}")
logger.error(f"语音识别函数发生错误: {str(e)}") elif Config.STRAM_CHAT.ASR == "XF":
logger.debug(f"接收到用户消息: {current_message}") status = FIRST_FRAME
async with websockets.connect(generate_xf_asr_url()) as xf_websocket:
while not (user_input_finish_event.is_set() and user_input_q.empty()):
audio_data = await user_input_q.get()
if status == FIRST_FRAME:
await xf_websocket.send(make_first_frame(audio_data))
status = CONTINUE_FRAME
elif status == CONTINUE_FRAME:
await xf_websocket.send(make_continue_frame(audio_data))
await xf_websocket.send(make_last_frame(""))
current_message = parse_xfasr_recv(json.loads(await xf_websocket.recv()))
await llm_input_q.put(current_message)
logger.debug(f"接收到用户消息: {current_message}")
#大模型调用 #大模型调用
async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event): async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event):
logger.debug("llm调用函数启动") logger.debug("llm调用函数启动")

View File

@ -8,9 +8,9 @@ class DevelopmentConfig:
PORT = 8001 #uvicorn运行端口 PORT = 8001 #uvicorn运行端口
WORKERS = 12 #uvicorn进程数(通常与cpu核数相同) WORKERS = 12 #uvicorn进程数(通常与cpu核数相同)
class XF_ASR: class XF_ASR:
APP_ID = "your_app_id" #讯飞语音识别APP_ID APP_ID = "f1c121c1" #讯飞语音识别APP_ID
API_SECRET = "your_api_secret" #讯飞语音识别API_SECRET API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0" #讯飞语音识别API_SECRET
API_KEY = "your_api_key" #讯飞语音识别API_KEY API_KEY = "36b316c7977fa534ae1e3bf52157bb92" #讯飞语音识别API_KEY
DOMAIN = "iat" DOMAIN = "iat"
LANGUAGE = "zh_cn" LANGUAGE = "zh_cn"
ACCENT = "mandarin" ACCENT = "mandarin"
@ -23,6 +23,6 @@ class DevelopmentConfig:
URL = "https://api.minimax.chat/v1/t2a_pro", URL = "https://api.minimax.chat/v1/t2a_pro",
GROUP_ID ="1759482180095975904" GROUP_ID ="1759482180095975904"
class STRAM_CHAT: class STRAM_CHAT:
ASR = "LOCAL" ASR = "XF" # 语音识别引擎可选XF或者LOCAL
TTS = "LOCAL" TTS = "LOCAL"

View File

@ -2,6 +2,7 @@ import datetime
import hashlib import hashlib
import base64 import base64
import hmac import hmac
import json
from urllib.parse import urlencode from urllib.parse import urlencode
from wsgiref.handlers import format_date_time from wsgiref.handlers import format_date_time
from datetime import datetime from datetime import datetime
@ -35,3 +36,29 @@ def generate_xf_asr_url():
} }
url = url + '?' + urlencode(v) url = url + '?' + urlencode(v)
return url return url
def make_first_frame(buf):
    """Build the opening frame (status=0) for the XFyun streaming-ASR websocket.

    The first frame carries the app credentials and the recognition
    business parameters together with the first audio chunk *buf*.
    Returns the frame serialized as a JSON string.
    """
    # NOTE(review): the iat protocol expects the "audio" field to be
    # base64-encoded; presumably the caller supplies it that way — confirm.
    frame = {
        "common": {"app_id": Config.XF_ASR.APP_ID},
        "business": {
            "domain": "iat",
            "language": "zh_cn",
            "accent": "mandarin",
            "vad_eos": 10000,
        },
        # status 0 marks the first frame of an utterance.
        "data": {
            "status": 0,
            "format": "audio/L16;rate=8000",
            "audio": buf,
            "encoding": "raw",
        },
    }
    return json.dumps(frame)
def make_continue_frame(buf):
    """Build an intermediate frame (status=1) carrying one audio chunk.

    Returns the frame serialized as a JSON string.
    """
    payload = {
        "data": {
            "status": 1,  # 1 = continuation frame in the iat protocol
            "format": "audio/L16;rate=8000",
            "audio": buf,
            "encoding": "raw",
        }
    }
    return json.dumps(payload)
def make_last_frame(buf):
    """Build the closing frame (status=2) that ends an utterance.

    Returns the frame serialized as a JSON string.
    """
    payload = {
        "data": {
            "status": 2,  # 2 = final frame in the iat protocol
            "format": "audio/L16;rate=8000",
            "audio": buf,
            "encoding": "raw",
        }
    }
    return json.dumps(payload)
def parse_xfasr_recv(message):
    """Extract the recognized text from a decoded XFyun ASR response dict.

    Raises Exception when the response carries a non-zero error code.
    """
    if message['code'] != 0:
        raise Exception("讯飞ASR错误码" + str(message['code']))
    # Each segment in data.result.ws holds candidate words under 'cw';
    # concatenate every candidate's 'w' text in order.
    segments = message['data']['result']['ws']
    return "".join(word['w'] for segment in segments for word in segment['cw'])