forked from killua/TakwayDisplayPlatform
update:优化了websocket对数据库连接的获取与释放
This commit is contained in:
parent
ce46c0f35b
commit
5a47440e0f
|
@ -10,7 +10,7 @@ class ASR(ABC):
|
||||||
|
|
||||||
class LLM(ABC):
|
class LLM(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def chat(self, assistant, prompt, db):
|
def chat(self, assistant, prompt):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class TTS(ABC):
|
class TTS(ABC):
|
||||||
|
|
|
@ -56,20 +56,18 @@ class XF_ASR(ASR):
|
||||||
await self.websocket.send(make_continue_frame(audio_data))
|
await self.websocket.send(make_continue_frame(audio_data))
|
||||||
elif self.status == LAST_FRAME: #发送最后一帧
|
elif self.status == LAST_FRAME: #发送最后一帧
|
||||||
await self.websocket.send(make_last_frame(audio_data))
|
await self.websocket.send(make_last_frame(audio_data))
|
||||||
|
logger.debug("发送完毕")
|
||||||
self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
|
self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
|
||||||
if self.current_message == "":
|
if self.current_message == "":
|
||||||
raise AsrResultNoneError()
|
raise AsrResultNoneError()
|
||||||
if "进入沉默模式" in self.current_message:
|
if self.current_message in ["啊"]:
|
||||||
|
raise SideNoiseError()
|
||||||
|
if "闭嘴" in self.current_message:
|
||||||
self.is_slience = True
|
self.is_slience = True
|
||||||
asyncio.create_task(self.websocket.close())
|
asyncio.create_task(self.websocket.close())
|
||||||
raise EnterSlienceMode()
|
raise EnterSlienceMode()
|
||||||
if "退出沉默模式" in self.current_message:
|
|
||||||
self.is_slience = False
|
|
||||||
self.current_message = "已退出沉默模式"
|
|
||||||
if self.is_slience:
|
|
||||||
asyncio.create_task(self.websocket.close())
|
|
||||||
raise SlienceMode()
|
|
||||||
asyncio.create_task(self.websocket.close())
|
asyncio.create_task(self.websocket.close())
|
||||||
|
logger.debug(f"ASR结果: {self.current_message}")
|
||||||
return [{"text":self.current_message, "audio":self.audio}]
|
return [{"text":self.current_message, "audio":self.audio}]
|
||||||
|
|
||||||
current_time = asyncio.get_event_loop().time()
|
current_time = asyncio.get_event_loop().time()
|
||||||
|
@ -86,7 +84,7 @@ class MINIMAX_LLM(LLM):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.token = 0
|
self.token = 0
|
||||||
|
|
||||||
async def chat(self, assistant, prompt, db):
|
async def chat(self, assistant, prompt):
|
||||||
llm_info = json.loads(assistant.llm_info)
|
llm_info = json.loads(assistant.llm_info)
|
||||||
messages = json.loads(assistant.messages)
|
messages = json.loads(assistant.messages)
|
||||||
messages.append({'role':'user','content':prompt})
|
messages.append({'role':'user','content':prompt})
|
||||||
|
@ -111,12 +109,10 @@ class MINIMAX_LLM(LLM):
|
||||||
yield msg_frame
|
yield msg_frame
|
||||||
except LLMResponseEnd:
|
except LLMResponseEnd:
|
||||||
msg_frame = {"is_end":True,"code":200,"msg":""}
|
msg_frame = {"is_end":True,"code":200,"msg":""}
|
||||||
|
assistant.token = self.token
|
||||||
if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session
|
if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session
|
||||||
msg_frame['code'] = '201'
|
msg_frame['code'] = '201'
|
||||||
as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
|
assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
|
||||||
as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
|
|
||||||
db.commit()
|
|
||||||
assistant.messages = as_query.messages
|
|
||||||
yield msg_frame
|
yield msg_frame
|
||||||
|
|
||||||
|
|
||||||
|
@ -147,7 +143,7 @@ class VOLCENGINE_LLM(LLM):
|
||||||
self.token = 0
|
self.token = 0
|
||||||
self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)
|
self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)
|
||||||
|
|
||||||
async def chat(self, assistant, prompt, db):
|
async def chat(self, assistant, prompt):
|
||||||
llm_info = json.loads(assistant.llm_info)
|
llm_info = json.loads(assistant.llm_info)
|
||||||
model = self.__get_model(llm_info)
|
model = self.__get_model(llm_info)
|
||||||
messages = json.loads(assistant.messages)
|
messages = json.loads(assistant.messages)
|
||||||
|
@ -167,18 +163,20 @@ class VOLCENGINE_LLM(LLM):
|
||||||
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
|
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
|
||||||
yield msg_frame
|
yield msg_frame
|
||||||
except LLMResponseEnd:
|
except LLMResponseEnd:
|
||||||
msg_frame = {"is_end":True,"code":20-0,"msg":""}
|
msg_frame = {"is_end":True,"code":200,"msg":""}
|
||||||
|
assistant.token = self.token
|
||||||
if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session
|
if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session
|
||||||
msg_frame['code'] = '201'
|
msg_frame['code'] = '201'
|
||||||
as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
|
assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
|
||||||
as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
|
|
||||||
db.commit()
|
|
||||||
assistant.messages = as_query.messages
|
|
||||||
yield msg_frame
|
yield msg_frame
|
||||||
|
|
||||||
def __get_model(self, llm_info):
|
def __get_model(self, llm_info):
|
||||||
if llm_info['model'] == 'doubao-4k-lite':
|
if llm_info['model'] == 'doubao-4k-lite':
|
||||||
return Config.VOLCENGINE_LLM.DOUBAO_LITE_4k
|
return Config.VOLCENGINE_LLM.DOUBAO_LITE_4k
|
||||||
|
elif llm_info['model'] == 'doubao-32k-lite':
|
||||||
|
return Config.VOLCENGINE_LLM.DOUBAO_LITE_32k
|
||||||
|
elif llm_info['model'] == 'doubao-32k-pro':
|
||||||
|
return Config.VOLCENGINE_LLM.DOUBAO_PRO_32k
|
||||||
else:
|
else:
|
||||||
raise UnknownVolcEngineModelError()
|
raise UnknownVolcEngineModelError()
|
||||||
|
|
||||||
|
@ -322,31 +320,31 @@ class Agent():
|
||||||
self.recorder = Recorder(user_id)
|
self.recorder = Recorder(user_id)
|
||||||
|
|
||||||
# 对用户输入的音频进行预处理
|
# 对用户输入的音频进行预处理
|
||||||
def user_audio_process(self, audio, db):
|
def user_audio_process(self, audio):
|
||||||
return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)
|
return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)
|
||||||
|
|
||||||
# 进行流式语音识别
|
# 进行流式语音识别
|
||||||
async def stream_recognize(self, chunk, db):
|
async def stream_recognize(self, chunk):
|
||||||
return await self.asr.stream_recognize(chunk)
|
return await self.asr.stream_recognize(chunk)
|
||||||
|
|
||||||
# 进行Prompt加工
|
# 进行Prompt加工
|
||||||
def prompt_process(self, asr_results, db):
|
def prompt_process(self, asr_results):
|
||||||
return self.prompt_service_chain.prompt_process(asr_results)
|
return self.prompt_service_chain.prompt_process(asr_results)
|
||||||
|
|
||||||
# 进行大模型调用
|
# 进行大模型调用
|
||||||
async def chat(self, assistant ,prompt, db):
|
async def chat(self, assistant ,prompt):
|
||||||
return self.llm.chat(assistant, prompt, db)
|
return self.llm.chat(assistant, prompt)
|
||||||
|
|
||||||
# 对大模型的返回进行处理
|
# 对大模型的返回进行处理
|
||||||
def llm_msg_process(self, llm_chunk, db):
|
def llm_msg_process(self, llm_chunk):
|
||||||
return self.llm_msg_service_chain.llm_msg_process(llm_chunk)
|
return self.llm_msg_service_chain.llm_msg_process(llm_chunk)
|
||||||
|
|
||||||
# 进行TTS合成
|
# 进行TTS合成
|
||||||
def synthetize(self, assistant, text, db):
|
def synthetize(self, assistant, text):
|
||||||
return self.tts.synthetize(assistant, text)
|
return self.tts.synthetize(assistant, text)
|
||||||
|
|
||||||
# 对合成后的音频进行处理
|
# 对合成后的音频进行处理
|
||||||
def tts_audio_process(self, audio, db):
|
def tts_audio_process(self, audio):
|
||||||
return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder)
|
return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder)
|
||||||
|
|
||||||
# 编码
|
# 编码
|
||||||
|
|
|
@ -46,6 +46,11 @@ class SessionNotFoundError(Exception):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
|
class LlmResultNoneError(Exception):
|
||||||
|
def __init__(self, message="LLM Result is None!"):
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
|
||||||
# 大模型返回结束(非异常)
|
# 大模型返回结束(非异常)
|
||||||
class LLMResponseEnd(Exception):
|
class LLMResponseEnd(Exception):
|
||||||
def __init__(self, message="LLM Response End!"):
|
def __init__(self, message="LLM Response End!"):
|
||||||
|
|
11
app/model.py
11
app/model.py
|
@ -1,8 +1,10 @@
|
||||||
from sqlalchemy import create_engine, Column, Integer, String, CHAR
|
from sqlalchemy import create_engine, Column, Integer, String, CHAR
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import sessionmaker, Session
|
from sqlalchemy.orm import sessionmaker, Session
|
||||||
|
from contextlib import contextmanager
|
||||||
from config import Config
|
from config import Config
|
||||||
|
|
||||||
|
|
||||||
engine = create_engine(Config.SQLITE_URL, connect_args={"check_same_thread": False})
|
engine = create_engine(Config.SQLITE_URL, connect_args={"check_same_thread": False})
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
@ -30,10 +32,17 @@ class Assistant(Base):
|
||||||
|
|
||||||
Base.metadata.create_all(bind=engine)
|
Base.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
def get_db():
|
@contextmanager
|
||||||
|
def get_db_context():
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
yield db
|
yield db
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
def get_db():
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
|
@ -6,7 +6,7 @@ class Config:
|
||||||
LOG_LEVEL = "DEBUG"
|
LOG_LEVEL = "DEBUG"
|
||||||
class UVICORN:
|
class UVICORN:
|
||||||
HOST = '0.0.0.0'
|
HOST = '0.0.0.0'
|
||||||
PORT = 7878
|
PORT = 8001
|
||||||
class XF_ASR:
|
class XF_ASR:
|
||||||
APP_ID = "f1c121c1" #讯飞语音识别APP_ID
|
APP_ID = "f1c121c1" #讯飞语音识别APP_ID
|
||||||
API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0" #讯飞语音识别API_SECRET
|
API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0" #讯飞语音识别API_SECRET
|
||||||
|
@ -21,3 +21,5 @@ class Config:
|
||||||
class VOLCENGINE_LLM:
|
class VOLCENGINE_LLM:
|
||||||
API_KEY = "a1bf964c-5c12-4d2b-ad97-85893e14d55d"
|
API_KEY = "a1bf964c-5c12-4d2b-ad97-85893e14d55d"
|
||||||
DOUBAO_LITE_4k = "ep-20240612075552-5c7tk"
|
DOUBAO_LITE_4k = "ep-20240612075552-5c7tk"
|
||||||
|
DOUBAO_LITE_32k = "ep-20240618130753-q85dm"
|
||||||
|
DOUBAO_PRO_32k = "ep-20240618145315-pm2c6"
|
64
main.py
64
main.py
|
@ -2,20 +2,29 @@ from fastapi import FastAPI, Depends, WebSocket, HTTPException
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from config import Config
|
from config import Config
|
||||||
from app.concrete import Agent
|
from app.concrete import Agent
|
||||||
from app.model import Assistant, User, get_db
|
from app.model import Assistant, User, get_db, get_db_context
|
||||||
from app.schemas import *
|
from app.schemas import *
|
||||||
from app.dependency import get_logger
|
from app.dependency import get_logger
|
||||||
from app.exception import *
|
from app.exception import *
|
||||||
|
import asyncio
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
# 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
# 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||||
def update_messages(messages, kid_text,llm_text):
|
def update_messages(assistant, kid_text, llm_text):
|
||||||
messages = json.loads(messages)
|
messages = json.loads(assistant.messages)
|
||||||
|
if not kid_text:
|
||||||
|
raise AsrResultNoneError()
|
||||||
|
if not llm_text:
|
||||||
|
raise LlmResultNoneError()
|
||||||
messages.append({"role":"user","content":kid_text})
|
messages.append({"role":"user","content":kid_text})
|
||||||
messages.append({"role":"assistant","content":llm_text})
|
messages.append({"role":"assistant","content":llm_text})
|
||||||
return json.dumps(messages,ensure_ascii=False)
|
with get_db_context() as db:
|
||||||
|
db.query(Assistant).filter(Assistant.id == assistant.id).update({"messages":json.dumps(messages,ensure_ascii=False),"token":assistant.token})
|
||||||
|
db.commit()
|
||||||
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
# 引入logger对象
|
# 引入logger对象
|
||||||
|
@ -177,7 +186,7 @@ async def update_user(id: str,request: update_user_request,db=Depends(get_db)):
|
||||||
|
|
||||||
# 流式聊天websokct接口 ------------------------------------------------------------------------------------------------------------------------------------------------------------
|
# 流式聊天websokct接口 ------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||||
@app.websocket("/api/chat/streaming/temporary")
|
@app.websocket("/api/chat/streaming/temporary")
|
||||||
async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
|
async def streaming_chat(ws: WebSocket):
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
logger.debug("WebSocket连接成功")
|
logger.debug("WebSocket连接成功")
|
||||||
try:
|
try:
|
||||||
|
@ -187,57 +196,70 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
|
||||||
llm_text = ""
|
llm_text = ""
|
||||||
logger.debug("开始进行ASR识别")
|
logger.debug("开始进行ASR识别")
|
||||||
while len(asr_results)==0:
|
while len(asr_results)==0:
|
||||||
chunk = json.loads(await ws.receive_text())
|
chunk = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=2))
|
||||||
if assistant is None:
|
if assistant is None:
|
||||||
assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
|
with get_db_context() as db: #使用with语句获取数据库连接,自动关闭数据库连接
|
||||||
|
assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
|
||||||
if assistant is None:
|
if assistant is None:
|
||||||
raise SessionNotFoundError()
|
raise SessionNotFoundError()
|
||||||
user_info = json.loads(assistant.user_info)
|
user_info = json.loads(assistant.user_info)
|
||||||
if not agent:
|
if not agent:
|
||||||
agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type'])
|
agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type'])
|
||||||
agent.init_recorder(assistant.user_id)
|
agent.init_recorder(assistant.user_id)
|
||||||
chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
|
chunk["audio"] = agent.user_audio_process(chunk["audio"])
|
||||||
asr_results = await agent.stream_recognize(chunk, db)
|
asr_results = await agent.stream_recognize(chunk)
|
||||||
kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果
|
kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果
|
||||||
logger.debug(f"ASR识别成功,识别结果为:{kid_text}")
|
prompt = agent.prompt_process(asr_results)
|
||||||
prompt = agent.prompt_process(asr_results, db)
|
|
||||||
agent.recorder.input_text = prompt
|
agent.recorder.input_text = prompt
|
||||||
logger.debug("开始调用大模型")
|
logger.debug("开始调用大模型")
|
||||||
llm_frames = await agent.chat(assistant, prompt, db)
|
llm_frames = await agent.chat(assistant, prompt)
|
||||||
async for llm_frame in llm_frames:
|
async for llm_frame in llm_frames:
|
||||||
resp_msgs = agent.llm_msg_process(llm_frame, db)
|
resp_msgs = agent.llm_msg_process(llm_frame)
|
||||||
for resp_msg in resp_msgs:
|
for resp_msg in resp_msgs:
|
||||||
llm_text += resp_msg
|
llm_text += resp_msg
|
||||||
tts_audio = agent.synthetize(assistant, resp_msg, db)
|
tts_start_time = time.time()
|
||||||
agent.tts_audio_process(tts_audio, db)
|
tts_audio = agent.synthetize(assistant, resp_msg)
|
||||||
|
tts_end_time = time.time()
|
||||||
|
logger.debug(f"TTS生成音频耗时:{tts_end_time-tts_start_time}s")
|
||||||
|
agent.tts_audio_process(tts_audio)
|
||||||
await ws.send_bytes(agent.encode(resp_msg, tts_audio))
|
await ws.send_bytes(agent.encode(resp_msg, tts_audio))
|
||||||
logger.debug(f'websocket返回:{resp_msg}')
|
logger.debug(f'websocket返回:{resp_msg}')
|
||||||
logger.debug(f"大模型返回结束,返回结果为:{llm_text}")
|
logger.debug(f"大模型返回结束,返回结果为:{llm_text}")
|
||||||
await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
|
await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
|
||||||
logger.debug("结束帧发送完毕")
|
logger.debug("结束帧发送完毕")
|
||||||
assistant.messages = update_messages(assistant.messages, kid_text ,llm_text)
|
update_messages(assistant, kid_text ,llm_text)
|
||||||
db.commit()
|
|
||||||
logger.debug("聊天更新成功")
|
logger.debug("聊天更新成功")
|
||||||
agent.recorder.output_text = llm_text
|
agent.recorder.output_text = llm_text
|
||||||
agent.save()
|
agent.save()
|
||||||
logger.debug("音频保存成功")
|
logger.debug("音频保存成功")
|
||||||
except EnterSlienceMode:
|
except EnterSlienceMode:
|
||||||
tts_audio = agent.synthetize(assistant, "已进入沉默模式", db)
|
tts_audio = agent.synthetize(assistant, "已进入沉默模式")
|
||||||
await ws.send_bytes(agent.encode("已进入沉默模式", tts_audio))
|
await ws.send_bytes(agent.encode("已进入沉默模式", tts_audio))
|
||||||
await ws.send_text(json.dumps({"type":"info","code":201,"msg":"进入沉默模式"}, ensure_ascii=False))
|
await ws.send_text(json.dumps({"type":"info","code":201,"msg":"进入沉默模式"}, ensure_ascii=False))
|
||||||
except SlienceMode:
|
logger.debug("进入沉默模式")
|
||||||
await ws.send_text(json.dumps({"type":"info","code":201,"msg":"处于沉默模式"}, ensure_ascii=False))
|
|
||||||
except AsrResultNoneError:
|
except AsrResultNoneError:
|
||||||
await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr结果为空"}, ensure_ascii=False))
|
await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr结果为空"}, ensure_ascii=False))
|
||||||
|
logger.error("ASR结果为空")
|
||||||
except AbnormalLLMFrame as e:
|
except AbnormalLLMFrame as e:
|
||||||
await ws.send_text(json.dumps({"type":"error","code":502,"msg":str(e)}, ensure_ascii=False))
|
await ws.send_text(json.dumps({"type":"error","code":502,"msg":str(e)}, ensure_ascii=False))
|
||||||
|
logger.error(f"LLM模型返回异常,错误信息:{str(e)}")
|
||||||
except SideNoiseError as e:
|
except SideNoiseError as e:
|
||||||
await ws.send_text(json.dumps({"type":"error","code":503,"msg":str(e)}, ensure_ascii=False))
|
await ws.send_text(json.dumps({"type":"error","code":503,"msg":"检测为噪声"}, ensure_ascii=False))
|
||||||
|
logger.debug("检测为噪声")
|
||||||
except SessionNotFoundError:
|
except SessionNotFoundError:
|
||||||
await ws.send_text(json.dumps({"type":"error","code":504,"msg":"session不存在"}, ensure_ascii=False))
|
await ws.send_text(json.dumps({"type":"error","code":504,"msg":"session不存在"}, ensure_ascii=False))
|
||||||
|
logger.error("session不存在")
|
||||||
except UnknownVolcEngineModelError:
|
except UnknownVolcEngineModelError:
|
||||||
await ws.send_text(json.dumps({"type":"error","code":505,"msg":"未知的火山引擎模型"}, ensure_ascii=False))
|
await ws.send_text(json.dumps({"type":"error","code":505,"msg":"未知的火山引擎模型"}, ensure_ascii=False))
|
||||||
|
logger.error("未知的火山引擎模型")
|
||||||
|
except LlmResultNoneError:
|
||||||
|
await ws.send_text(json.dumps({"type":"error","code":506,"msg":"llm结果返回为空"}, ensure_ascii=False))
|
||||||
|
logger.error("LLM结果返回为空")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
await ws.send_text(json.dumps({"type":"error","code":507,"msg":"接收超时"}, ensure_ascii=False))
|
||||||
|
logger.error("接收超时")
|
||||||
logger.debug("WebSocket连接断开")
|
logger.debug("WebSocket连接断开")
|
||||||
|
logger.debug("")
|
||||||
await ws.close()
|
await ws.close()
|
||||||
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
|
@ -12,3 +12,4 @@ cn2an
|
||||||
numba
|
numba
|
||||||
librosa
|
librosa
|
||||||
aiohttp
|
aiohttp
|
||||||
|
'volcengine-python-sdk[ark]'
|
|
@ -38,7 +38,7 @@ def generate_xf_asr_url():
|
||||||
|
|
||||||
|
|
||||||
def make_first_frame(buf):
|
def make_first_frame(buf):
|
||||||
first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vad_eos":10000},
|
first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vinfo":1,"vad_eos":1000},
|
||||||
"data":{"status":0,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
|
"data":{"status":0,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
|
||||||
return json.dumps(first_frame)
|
return json.dumps(first_frame)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue