forked from killua/TakwayPlatform
feat: 增加讯飞asr功能
This commit is contained in:
parent
30fdb9c6bd
commit
09ffeb6ab6
|
@ -8,9 +8,10 @@ from ..models import UserCharacter, Session, Character, User, Audio
|
||||||
from utils.audio_utils import VAD
|
from utils.audio_utils import VAD
|
||||||
from fastapi import WebSocket, HTTPException, status
|
from fastapi import WebSocket, HTTPException, status
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from utils.xf_asr_utils import generate_xf_asr_url
|
from utils.xf_asr_utils import generate_xf_asr_url, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
|
||||||
from config import get_config
|
from config import get_config
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import websockets
|
||||||
import struct
|
import struct
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
|
@ -114,7 +115,7 @@ def get_emb(session_id,db):
|
||||||
emb_npy = np.load(io.BytesIO(audio_record.emb_data))
|
emb_npy = np.load(io.BytesIO(audio_record.emb_data))
|
||||||
return emb_npy
|
return emb_npy
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("未找到音频:"+str(e))
|
logger.debug("未找到音频:"+str(e))
|
||||||
return np.array([])
|
return np.array([])
|
||||||
#--------------------------------------------------------
|
#--------------------------------------------------------
|
||||||
|
|
||||||
|
@ -247,35 +248,51 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
|
||||||
#语音识别
|
#语音识别
|
||||||
async def sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event):
    """Turn the queued user audio into text and hand it to the LLM stage.

    The engine is chosen by ``Config.STRAM_CHAT.ASR``:

    * ``"LOCAL"`` - feed every chunk to the in-process ``asr`` object, then run
      punctuation correction and emotion recognition on the result.
    * ``"XF"``    - stream the chunks to the iFlytek (XF) websocket API and
      collect the recognised text it sends back.

    In both modes an empty recognition result makes the handler send a
    ``{"type": "close", "code": 201}`` message on ``ws`` instead of queueing
    anything for the LLM.

    Args:
        ws: client websocket; only ``send_text`` is used here.
        session_id: key used to sign the session in/out of the local ASR.
        user_input_q: asyncio-style queue of audio chunks.
            NOTE(review): chunks are concatenated onto a ``str`` in LOCAL mode,
            so they are presumably base64 text - confirm against the producer.
        llm_input_q: queue that receives the final recognised message.
        user_input_finish_event: set by the producer once no more audio will
            be queued.

    NOTE(review): the loop condition is checked before ``await get()``; if the
    producer sets the event while the queue is empty and this coroutine is
    already waiting in ``get()``, it will wait forever. Confirm the producer
    always queues a chunk before/after setting the event.
    """
    logger.debug("语音识别函数启动")
    current_message = ""

    if Config.STRAM_CHAT.ASR == "LOCAL":
        is_signup = False
        audio = ""
        try:
            while not (user_input_finish_event.is_set() and user_input_q.empty()):
                if not is_signup:
                    asr.session_signup(session_id)
                    is_signup = True
                audio_data = await user_input_q.get()
                audio += audio_data
                asr_result = asr.streaming_recognize(session_id,audio_data)
                current_message += ''.join(asr_result['text'])
            # Flush the recogniser so it emits whatever text is still buffered.
            asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
            current_message += ''.join(asr_result['text'])
            if current_message == "":
                await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
                return
            current_message = asr.punctuation_correction(current_message)
            emotion_dict = asr.emtion_recognition(audio)  # emotion recognition
            # A str return value is treated as "no emotion result available".
            if not isinstance(emotion_dict, str):
                max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
                current_message = f"{current_message},当前说话人的情绪:{emotion_dict['labels'][max_index]}"
            await llm_input_q.put(current_message)
        except Exception as e:
            logger.error(f"语音识别函数发生错误: {str(e)}")
        finally:
            # Always release the ASR session - previously the empty-result
            # early return above skipped the sign-out.
            asr.session_signout(session_id)
        logger.debug(f"接收到用户消息: {current_message}")

    elif Config.STRAM_CHAT.ASR == "XF":
        # A plain flag replaces the FIRST_FRAME/CONTINUE_FRAME constants, which
        # are not defined or imported in this module, and avoids shadowing the
        # imported fastapi ``status`` name.
        first_frame_sent = False
        try:
            async with websockets.connect(generate_xf_asr_url()) as xf_websocket:
                while not (user_input_finish_event.is_set() and user_input_q.empty()):
                    audio_data = await user_input_q.get()
                    if first_frame_sent:
                        await xf_websocket.send(make_continue_frame(audio_data))
                    else:
                        await xf_websocket.send(make_first_frame(audio_data))
                        first_frame_sent = True

                # Only close the stream if one was actually opened; a lone
                # "last" frame without a first frame is rejected by the server.
                if first_frame_sent:
                    await xf_websocket.send(make_last_frame(""))
                    # The service answers with several result frames; the one
                    # carrying data.status == 2 is the final one. Reading a
                    # single frame would truncate the transcript.
                    while True:
                        response = json.loads(await xf_websocket.recv())
                        current_message += parse_xfasr_recv(response)
                        if (response.get("data") or {}).get("status") == 2:
                            break

            if current_message == "":
                # Same contract as the LOCAL branch for "nothing recognised".
                await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
                return
            await llm_input_q.put(current_message)
        except Exception as e:
            logger.error(f"语音识别函数发生错误: {str(e)}")
        logger.debug(f"接收到用户消息: {current_message}")
|
||||||
|
|
||||||
#大模型调用
|
#大模型调用
|
||||||
async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event):
|
async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event):
|
||||||
logger.debug("llm调用函数启动")
|
logger.debug("llm调用函数启动")
|
||||||
|
|
|
@ -8,9 +8,9 @@ class DevelopmentConfig:
|
||||||
PORT = 8001 #uvicorn运行端口
|
PORT = 8001 #uvicorn运行端口
|
||||||
WORKERS = 12 #uvicorn进程数(通常与cpu核数相同)
|
WORKERS = 12 #uvicorn进程数(通常与cpu核数相同)
|
||||||
class XF_ASR:
    # iFlytek (XF) speech-recognition settings.
    #
    # SECURITY: live credentials must never be committed. The values below are
    # placeholders; supply the real ones through a local, untracked override.
    # Any key that has already been pushed should be treated as leaked and
    # rotated in the iFlytek console.
    APP_ID = "your_app_id"          # iFlytek ASR APP_ID
    API_SECRET = "your_api_secret"  # iFlytek ASR API_SECRET
    API_KEY = "your_api_key"        # iFlytek ASR API_KEY
    DOMAIN = "iat"
    LANGUAGE = "zh_cn"
    ACCENT = "mandarin"
|
||||||
|
@ -23,6 +23,6 @@ class DevelopmentConfig:
|
||||||
URL = "https://api.minimax.chat/v1/t2a_pro",
|
URL = "https://api.minimax.chat/v1/t2a_pro",
|
||||||
GROUP_ID ="1759482180095975904"
|
GROUP_ID ="1759482180095975904"
|
||||||
class STRAM_CHAT:
    # Engine selection for the streaming-chat pipeline.
    ASR = "XF"      # speech-recognition engine: "XF" or "LOCAL"
    TTS = "LOCAL"
|
||||||
|
|
|
@ -2,6 +2,7 @@ import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
import base64
|
import base64
|
||||||
import hmac
|
import hmac
|
||||||
|
import json
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
from wsgiref.handlers import format_date_time
|
from wsgiref.handlers import format_date_time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -35,3 +36,29 @@ def generate_xf_asr_url():
|
||||||
}
|
}
|
||||||
url = url + '?' + urlencode(v)
|
url = url + '?' + urlencode(v)
|
||||||
return url
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
def _xf_audio_payload(buf):
    """Return *buf* in the text form required by the XF ``audio`` field.

    Raw bytes are base64-encoded (``json.dumps`` cannot serialise bytes and
    the API expects base64 text). Strings are passed through untouched.
    NOTE(review): strings are assumed to be base64 already - confirm with the
    caller that produces the audio chunks.
    """
    if isinstance(buf, (bytes, bytearray, memoryview)):
        return base64.b64encode(bytes(buf)).decode("ascii")
    return buf


def _xf_data_section(status, buf):
    """Build the ``data`` object shared by all three frame kinds."""
    return {
        "status": status,
        "format": "audio/L16;rate=8000",
        "audio": _xf_audio_payload(buf),
        "encoding": "raw",
    }


def make_first_frame(buf):
    """Serialise the opening frame (status 0) of an XF ASR stream.

    Besides the audio chunk it carries the app id and the business
    parameters. Domain/language/accent are read from ``Config.XF_ASR`` and
    fall back to the previously hard-coded values when the config does not
    define them.

    Args:
        buf: first audio chunk (base64 text, or raw bytes to be encoded).

    Returns:
        The frame as a JSON string.
    """
    xf_cfg = Config.XF_ASR
    first_frame = {
        "common": {"app_id": xf_cfg.APP_ID},
        "business": {
            "domain": getattr(xf_cfg, "DOMAIN", "iat"),
            "language": getattr(xf_cfg, "LANGUAGE", "zh_cn"),
            "accent": getattr(xf_cfg, "ACCENT", "mandarin"),
            "vad_eos": 10000,
        },
        "data": _xf_data_section(0, buf),
    }
    return json.dumps(first_frame)


def make_continue_frame(buf):
    """Serialise an intermediate frame (status 1) carrying one audio chunk.

    Args:
        buf: audio chunk (base64 text, or raw bytes to be encoded).

    Returns:
        The frame as a JSON string.
    """
    return json.dumps({"data": _xf_data_section(1, buf)})


def make_last_frame(buf):
    """Serialise the closing frame (status 2) that ends the audio stream.

    Args:
        buf: final audio chunk; may be an empty string.

    Returns:
        The frame as a JSON string.
    """
    return json.dumps({"data": _xf_data_section(2, buf)})
|
||||||
|
|
||||||
|
def parse_xfasr_recv(message):
    """Extract the recognised text from one decoded XF ASR response frame.

    Args:
        message: the response already parsed from JSON into a dict.

    Returns:
        The concatenation of every ``w`` entry found under
        ``data.result.ws[*].cw[*]``; an empty string when the frame carries
        no result section.

    Raises:
        Exception: when the frame's ``code`` is not 0. The text contains the
            code and, if present, the server-supplied ``message``.
    """
    code = message.get("code")
    if code != 0:
        detail = message.get("message")
        suffix = f" ({detail})" if detail else ""
        raise Exception(f"讯飞ASR错误码:{code}{suffix}")

    # Tolerate frames without a result section instead of raising KeyError.
    result = (message.get("data") or {}).get("result") or {}
    return "".join(
        candidate["w"]
        for segment in result.get("ws", [])
        for candidate in segment["cw"]
    )
|
Loading…
Reference in New Issue