forked from killua/TakwayDisplayPlatform
update:优化了websocket对数据库连接的获取与释放
This commit is contained in:
parent
ce46c0f35b
commit
5a47440e0f
|
@ -10,7 +10,7 @@ class ASR(ABC):
|
|||
|
||||
class LLM(ABC):
|
||||
@abstractmethod
|
||||
def chat(self, assistant, prompt, db):
|
||||
def chat(self, assistant, prompt):
|
||||
pass
|
||||
|
||||
class TTS(ABC):
|
||||
|
|
|
@ -56,20 +56,18 @@ class XF_ASR(ASR):
|
|||
await self.websocket.send(make_continue_frame(audio_data))
|
||||
elif self.status == LAST_FRAME: #发送最后一帧
|
||||
await self.websocket.send(make_last_frame(audio_data))
|
||||
logger.debug("发送完毕")
|
||||
self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
|
||||
if self.current_message == "":
|
||||
raise AsrResultNoneError()
|
||||
if "进入沉默模式" in self.current_message:
|
||||
if self.current_message in ["啊"]:
|
||||
raise SideNoiseError()
|
||||
if "闭嘴" in self.current_message:
|
||||
self.is_slience = True
|
||||
asyncio.create_task(self.websocket.close())
|
||||
raise EnterSlienceMode()
|
||||
if "退出沉默模式" in self.current_message:
|
||||
self.is_slience = False
|
||||
self.current_message = "已退出沉默模式"
|
||||
if self.is_slience:
|
||||
asyncio.create_task(self.websocket.close())
|
||||
raise SlienceMode()
|
||||
asyncio.create_task(self.websocket.close())
|
||||
logger.debug(f"ASR结果: {self.current_message}")
|
||||
return [{"text":self.current_message, "audio":self.audio}]
|
||||
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
|
@ -86,7 +84,7 @@ class MINIMAX_LLM(LLM):
|
|||
def __init__(self):
|
||||
self.token = 0
|
||||
|
||||
async def chat(self, assistant, prompt, db):
|
||||
async def chat(self, assistant, prompt):
|
||||
llm_info = json.loads(assistant.llm_info)
|
||||
messages = json.loads(assistant.messages)
|
||||
messages.append({'role':'user','content':prompt})
|
||||
|
@ -111,12 +109,10 @@ class MINIMAX_LLM(LLM):
|
|||
yield msg_frame
|
||||
except LLMResponseEnd:
|
||||
msg_frame = {"is_end":True,"code":200,"msg":""}
|
||||
assistant.token = self.token
|
||||
if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session
|
||||
msg_frame['code'] = '201'
|
||||
as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
|
||||
as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
|
||||
db.commit()
|
||||
assistant.messages = as_query.messages
|
||||
assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
|
||||
yield msg_frame
|
||||
|
||||
|
||||
|
@ -147,7 +143,7 @@ class VOLCENGINE_LLM(LLM):
|
|||
self.token = 0
|
||||
self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)
|
||||
|
||||
async def chat(self, assistant, prompt, db):
|
||||
async def chat(self, assistant, prompt):
|
||||
llm_info = json.loads(assistant.llm_info)
|
||||
model = self.__get_model(llm_info)
|
||||
messages = json.loads(assistant.messages)
|
||||
|
@ -167,18 +163,20 @@ class VOLCENGINE_LLM(LLM):
|
|||
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
|
||||
yield msg_frame
|
||||
except LLMResponseEnd:
|
||||
msg_frame = {"is_end":True,"code":20-0,"msg":""}
|
||||
msg_frame = {"is_end":True,"code":200,"msg":""}
|
||||
assistant.token = self.token
|
||||
if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session
|
||||
msg_frame['code'] = '201'
|
||||
as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
|
||||
as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
|
||||
db.commit()
|
||||
assistant.messages = as_query.messages
|
||||
assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
|
||||
yield msg_frame
|
||||
|
||||
def __get_model(self, llm_info):
|
||||
if llm_info['model'] == 'doubao-4k-lite':
|
||||
return Config.VOLCENGINE_LLM.DOUBAO_LITE_4k
|
||||
elif llm_info['model'] == 'doubao-32k-lite':
|
||||
return Config.VOLCENGINE_LLM.DOUBAO_LITE_32k
|
||||
elif llm_info['model'] == 'doubao-32k-pro':
|
||||
return Config.VOLCENGINE_LLM.DOUBAO_PRO_32k
|
||||
else:
|
||||
raise UnknownVolcEngineModelError()
|
||||
|
||||
|
@ -322,31 +320,31 @@ class Agent():
|
|||
self.recorder = Recorder(user_id)
|
||||
|
||||
# 对用户输入的音频进行预处理
|
||||
def user_audio_process(self, audio, db):
|
||||
def user_audio_process(self, audio):
|
||||
return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)
|
||||
|
||||
# 进行流式语音识别
|
||||
async def stream_recognize(self, chunk, db):
|
||||
async def stream_recognize(self, chunk):
|
||||
return await self.asr.stream_recognize(chunk)
|
||||
|
||||
# 进行Prompt加工
|
||||
def prompt_process(self, asr_results, db):
|
||||
def prompt_process(self, asr_results):
|
||||
return self.prompt_service_chain.prompt_process(asr_results)
|
||||
|
||||
# 进行大模型调用
|
||||
async def chat(self, assistant ,prompt, db):
|
||||
return self.llm.chat(assistant, prompt, db)
|
||||
async def chat(self, assistant ,prompt):
|
||||
return self.llm.chat(assistant, prompt)
|
||||
|
||||
# 对大模型的返回进行处理
|
||||
def llm_msg_process(self, llm_chunk, db):
|
||||
def llm_msg_process(self, llm_chunk):
|
||||
return self.llm_msg_service_chain.llm_msg_process(llm_chunk)
|
||||
|
||||
# 进行TTS合成
|
||||
def synthetize(self, assistant, text, db):
|
||||
def synthetize(self, assistant, text):
|
||||
return self.tts.synthetize(assistant, text)
|
||||
|
||||
# 对合成后的音频进行处理
|
||||
def tts_audio_process(self, audio, db):
|
||||
def tts_audio_process(self, audio):
|
||||
return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder)
|
||||
|
||||
# 编码
|
||||
|
|
|
@ -46,6 +46,11 @@ class SessionNotFoundError(Exception):
|
|||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
class LlmResultNoneError(Exception):
|
||||
def __init__(self, message="LLM Result is None!"):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
# 大模型返回结束(非异常)
|
||||
class LLMResponseEnd(Exception):
|
||||
def __init__(self, message="LLM Response End!"):
|
||||
|
|
11
app/model.py
11
app/model.py
|
@ -1,8 +1,10 @@
|
|||
from sqlalchemy import create_engine, Column, Integer, String, CHAR
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from contextlib import contextmanager
|
||||
from config import Config
|
||||
|
||||
|
||||
engine = create_engine(Config.SQLITE_URL, connect_args={"check_same_thread": False})
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base = declarative_base()
|
||||
|
@ -30,10 +32,17 @@ class Assistant(Base):
|
|||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
def get_db():
|
||||
@contextmanager
|
||||
def get_db_context():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
|
@ -6,7 +6,7 @@ class Config:
|
|||
LOG_LEVEL = "DEBUG"
|
||||
class UVICORN:
|
||||
HOST = '0.0.0.0'
|
||||
PORT = 7878
|
||||
PORT = 8001
|
||||
class XF_ASR:
|
||||
APP_ID = "f1c121c1" #讯飞语音识别APP_ID
|
||||
API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0" #讯飞语音识别API_SECRET
|
||||
|
@ -21,3 +21,5 @@ class Config:
|
|||
class VOLCENGINE_LLM:
|
||||
API_KEY = "a1bf964c-5c12-4d2b-ad97-85893e14d55d"
|
||||
DOUBAO_LITE_4k = "ep-20240612075552-5c7tk"
|
||||
DOUBAO_LITE_32k = "ep-20240618130753-q85dm"
|
||||
DOUBAO_PRO_32k = "ep-20240618145315-pm2c6"
|
64
main.py
64
main.py
|
@ -2,20 +2,29 @@ from fastapi import FastAPI, Depends, WebSocket, HTTPException
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from config import Config
|
||||
from app.concrete import Agent
|
||||
from app.model import Assistant, User, get_db
|
||||
from app.model import Assistant, User, get_db, get_db_context
|
||||
from app.schemas import *
|
||||
from app.dependency import get_logger
|
||||
from app.exception import *
|
||||
import asyncio
|
||||
import uvicorn
|
||||
import uuid
|
||||
import json
|
||||
import time
|
||||
|
||||
|
||||
# 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
def update_messages(messages, kid_text,llm_text):
|
||||
messages = json.loads(messages)
|
||||
def update_messages(assistant, kid_text, llm_text):
|
||||
messages = json.loads(assistant.messages)
|
||||
if not kid_text:
|
||||
raise AsrResultNoneError()
|
||||
if not llm_text:
|
||||
raise LlmResultNoneError()
|
||||
messages.append({"role":"user","content":kid_text})
|
||||
messages.append({"role":"assistant","content":llm_text})
|
||||
return json.dumps(messages,ensure_ascii=False)
|
||||
with get_db_context() as db:
|
||||
db.query(Assistant).filter(Assistant.id == assistant.id).update({"messages":json.dumps(messages,ensure_ascii=False),"token":assistant.token})
|
||||
db.commit()
|
||||
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
# 引入logger对象
|
||||
|
@ -177,7 +186,7 @@ async def update_user(id: str,request: update_user_request,db=Depends(get_db)):
|
|||
|
||||
# 流式聊天websokct接口 ------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
@app.websocket("/api/chat/streaming/temporary")
|
||||
async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
|
||||
async def streaming_chat(ws: WebSocket):
|
||||
await ws.accept()
|
||||
logger.debug("WebSocket连接成功")
|
||||
try:
|
||||
|
@ -187,57 +196,70 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
|
|||
llm_text = ""
|
||||
logger.debug("开始进行ASR识别")
|
||||
while len(asr_results)==0:
|
||||
chunk = json.loads(await ws.receive_text())
|
||||
chunk = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=2))
|
||||
if assistant is None:
|
||||
assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
|
||||
with get_db_context() as db: #使用with语句获取数据库连接,自动关闭数据库连接
|
||||
assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
|
||||
if assistant is None:
|
||||
raise SessionNotFoundError()
|
||||
user_info = json.loads(assistant.user_info)
|
||||
if not agent:
|
||||
agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type'])
|
||||
agent.init_recorder(assistant.user_id)
|
||||
chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
|
||||
asr_results = await agent.stream_recognize(chunk, db)
|
||||
chunk["audio"] = agent.user_audio_process(chunk["audio"])
|
||||
asr_results = await agent.stream_recognize(chunk)
|
||||
kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果
|
||||
logger.debug(f"ASR识别成功,识别结果为:{kid_text}")
|
||||
prompt = agent.prompt_process(asr_results, db)
|
||||
prompt = agent.prompt_process(asr_results)
|
||||
agent.recorder.input_text = prompt
|
||||
logger.debug("开始调用大模型")
|
||||
llm_frames = await agent.chat(assistant, prompt, db)
|
||||
llm_frames = await agent.chat(assistant, prompt)
|
||||
async for llm_frame in llm_frames:
|
||||
resp_msgs = agent.llm_msg_process(llm_frame, db)
|
||||
resp_msgs = agent.llm_msg_process(llm_frame)
|
||||
for resp_msg in resp_msgs:
|
||||
llm_text += resp_msg
|
||||
tts_audio = agent.synthetize(assistant, resp_msg, db)
|
||||
agent.tts_audio_process(tts_audio, db)
|
||||
tts_start_time = time.time()
|
||||
tts_audio = agent.synthetize(assistant, resp_msg)
|
||||
tts_end_time = time.time()
|
||||
logger.debug(f"TTS生成音频耗时:{tts_end_time-tts_start_time}s")
|
||||
agent.tts_audio_process(tts_audio)
|
||||
await ws.send_bytes(agent.encode(resp_msg, tts_audio))
|
||||
logger.debug(f'websocket返回:{resp_msg}')
|
||||
logger.debug(f"大模型返回结束,返回结果为:{llm_text}")
|
||||
await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
|
||||
logger.debug("结束帧发送完毕")
|
||||
assistant.messages = update_messages(assistant.messages, kid_text ,llm_text)
|
||||
db.commit()
|
||||
update_messages(assistant, kid_text ,llm_text)
|
||||
logger.debug("聊天更新成功")
|
||||
agent.recorder.output_text = llm_text
|
||||
agent.save()
|
||||
logger.debug("音频保存成功")
|
||||
except EnterSlienceMode:
|
||||
tts_audio = agent.synthetize(assistant, "已进入沉默模式", db)
|
||||
tts_audio = agent.synthetize(assistant, "已进入沉默模式")
|
||||
await ws.send_bytes(agent.encode("已进入沉默模式", tts_audio))
|
||||
await ws.send_text(json.dumps({"type":"info","code":201,"msg":"进入沉默模式"}, ensure_ascii=False))
|
||||
except SlienceMode:
|
||||
await ws.send_text(json.dumps({"type":"info","code":201,"msg":"处于沉默模式"}, ensure_ascii=False))
|
||||
logger.debug("进入沉默模式")
|
||||
except AsrResultNoneError:
|
||||
await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr结果为空"}, ensure_ascii=False))
|
||||
logger.error("ASR结果为空")
|
||||
except AbnormalLLMFrame as e:
|
||||
await ws.send_text(json.dumps({"type":"error","code":502,"msg":str(e)}, ensure_ascii=False))
|
||||
logger.error(f"LLM模型返回异常,错误信息:{str(e)}")
|
||||
except SideNoiseError as e:
|
||||
await ws.send_text(json.dumps({"type":"error","code":503,"msg":str(e)}, ensure_ascii=False))
|
||||
await ws.send_text(json.dumps({"type":"error","code":503,"msg":"检测为噪声"}, ensure_ascii=False))
|
||||
logger.debug("检测为噪声")
|
||||
except SessionNotFoundError:
|
||||
await ws.send_text(json.dumps({"type":"error","code":504,"msg":"session不存在"}, ensure_ascii=False))
|
||||
logger.error("session不存在")
|
||||
except UnknownVolcEngineModelError:
|
||||
await ws.send_text(json.dumps({"type":"error","code":505,"msg":"未知的火山引擎模型"}, ensure_ascii=False))
|
||||
logger.error("未知的火山引擎模型")
|
||||
except LlmResultNoneError:
|
||||
await ws.send_text(json.dumps({"type":"error","code":506,"msg":"llm结果返回为空"}, ensure_ascii=False))
|
||||
logger.error("LLM结果返回为空")
|
||||
except asyncio.TimeoutError:
|
||||
await ws.send_text(json.dumps({"type":"error","code":507,"msg":"接收超时"}, ensure_ascii=False))
|
||||
logger.error("接收超时")
|
||||
logger.debug("WebSocket连接断开")
|
||||
logger.debug("")
|
||||
await ws.close()
|
||||
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
|
|
@ -12,3 +12,4 @@ cn2an
|
|||
numba
|
||||
librosa
|
||||
aiohttp
|
||||
'volcengine-python-sdk[ark]'
|
|
@ -38,7 +38,7 @@ def generate_xf_asr_url():
|
|||
|
||||
|
||||
def make_first_frame(buf):
|
||||
first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vad_eos":10000},
|
||||
first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vinfo":1,"vad_eos":1000},
|
||||
"data":{"status":0,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
|
||||
return json.dumps(first_frame)
|
||||
|
||||
|
|
Loading…
Reference in New Issue