1
0
Fork 0

update:优化了websocket对数据库连接的获取与释放

This commit is contained in:
killua 2024-06-20 10:34:26 +08:00
parent ce46c0f35b
commit 5a47440e0f
8 changed files with 91 additions and 54 deletions

View File

@ -10,7 +10,7 @@ class ASR(ABC):
class LLM(ABC):
@abstractmethod
def chat(self, assistant, prompt, db):
def chat(self, assistant, prompt):
pass
class TTS(ABC):

View File

@ -56,20 +56,18 @@ class XF_ASR(ASR):
await self.websocket.send(make_continue_frame(audio_data))
elif self.status == LAST_FRAME: #发送最后一帧
await self.websocket.send(make_last_frame(audio_data))
logger.debug("发送完毕")
self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
if self.current_message == "":
raise AsrResultNoneError()
if "进入沉默模式" in self.current_message:
if self.current_message in [""]:
raise SideNoiseError()
if "闭嘴" in self.current_message:
self.is_slience = True
asyncio.create_task(self.websocket.close())
raise EnterSlienceMode()
if "退出沉默模式" in self.current_message:
self.is_slience = False
self.current_message = "已退出沉默模式"
if self.is_slience:
asyncio.create_task(self.websocket.close())
raise SlienceMode()
asyncio.create_task(self.websocket.close())
logger.debug(f"ASR结果: {self.current_message}")
return [{"text":self.current_message, "audio":self.audio}]
current_time = asyncio.get_event_loop().time()
@ -86,7 +84,7 @@ class MINIMAX_LLM(LLM):
def __init__(self):
self.token = 0
async def chat(self, assistant, prompt, db):
async def chat(self, assistant, prompt):
llm_info = json.loads(assistant.llm_info)
messages = json.loads(assistant.messages)
messages.append({'role':'user','content':prompt})
@ -111,12 +109,10 @@ class MINIMAX_LLM(LLM):
yield msg_frame
except LLMResponseEnd:
msg_frame = {"is_end":True,"code":200,"msg":""}
assistant.token = self.token
if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%则重置session
msg_frame['code'] = '201'
as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
db.commit()
assistant.messages = as_query.messages
assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
yield msg_frame
@ -147,7 +143,7 @@ class VOLCENGINE_LLM(LLM):
self.token = 0
self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)
async def chat(self, assistant, prompt, db):
async def chat(self, assistant, prompt):
llm_info = json.loads(assistant.llm_info)
model = self.__get_model(llm_info)
messages = json.loads(assistant.messages)
@ -167,18 +163,20 @@ class VOLCENGINE_LLM(LLM):
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
yield msg_frame
except LLMResponseEnd:
msg_frame = {"is_end":True,"code":20-0,"msg":""}
msg_frame = {"is_end":True,"code":200,"msg":""}
assistant.token = self.token
if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%则重置session
msg_frame['code'] = '201'
as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
db.commit()
assistant.messages = as_query.messages
assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
yield msg_frame
def __get_model(self, llm_info):
if llm_info['model'] == 'doubao-4k-lite':
return Config.VOLCENGINE_LLM.DOUBAO_LITE_4k
elif llm_info['model'] == 'doubao-32k-lite':
return Config.VOLCENGINE_LLM.DOUBAO_LITE_32k
elif llm_info['model'] == 'doubao-32k-pro':
return Config.VOLCENGINE_LLM.DOUBAO_PRO_32k
else:
raise UnknownVolcEngineModelError()
@ -322,31 +320,31 @@ class Agent():
self.recorder = Recorder(user_id)
# 对用户输入的音频进行预处理
def user_audio_process(self, audio, db):
def user_audio_process(self, audio):
return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)
# 进行流式语音识别
async def stream_recognize(self, chunk, db):
async def stream_recognize(self, chunk):
return await self.asr.stream_recognize(chunk)
# 进行Prompt加工
def prompt_process(self, asr_results, db):
def prompt_process(self, asr_results):
return self.prompt_service_chain.prompt_process(asr_results)
# 进行大模型调用
async def chat(self, assistant ,prompt, db):
return self.llm.chat(assistant, prompt, db)
async def chat(self, assistant ,prompt):
return self.llm.chat(assistant, prompt)
# 对大模型的返回进行处理
def llm_msg_process(self, llm_chunk, db):
def llm_msg_process(self, llm_chunk):
return self.llm_msg_service_chain.llm_msg_process(llm_chunk)
# 进行TTS合成
def synthetize(self, assistant, text, db):
def synthetize(self, assistant, text):
return self.tts.synthetize(assistant, text)
# 对合成后的音频进行处理
def tts_audio_process(self, audio, db):
def tts_audio_process(self, audio):
return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder)
# 编码

View File

@ -46,6 +46,11 @@ class SessionNotFoundError(Exception):
super().__init__(message)
self.message = message
class LlmResultNoneError(Exception):
def __init__(self, message="LLM Result is None!"):
super().__init__(message)
self.message = message
# 大模型返回结束(非异常)
class LLMResponseEnd(Exception):
def __init__(self, message="LLM Response End!"):

View File

@ -1,8 +1,10 @@
from sqlalchemy import create_engine, Column, Integer, String, CHAR
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from contextlib import contextmanager
from config import Config
engine = create_engine(Config.SQLITE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
@ -30,10 +32,17 @@ class Assistant(Base):
Base.metadata.create_all(bind=engine)
def get_db():
@contextmanager
def get_db_context():
db = SessionLocal()
try:
yield db
finally:
db.close()
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

View File

@ -6,7 +6,7 @@ class Config:
LOG_LEVEL = "DEBUG"
class UVICORN:
HOST = '0.0.0.0'
PORT = 7878
PORT = 8001
class XF_ASR:
APP_ID = "f1c121c1" #讯飞语音识别APP_ID
API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0" #讯飞语音识别API_SECRET
@ -21,3 +21,5 @@ class Config:
class VOLCENGINE_LLM:
API_KEY = "a1bf964c-5c12-4d2b-ad97-85893e14d55d"
DOUBAO_LITE_4k = "ep-20240612075552-5c7tk"
DOUBAO_LITE_32k = "ep-20240618130753-q85dm"
DOUBAO_PRO_32k = "ep-20240618145315-pm2c6"

62
main.py
View File

@ -2,20 +2,29 @@ from fastapi import FastAPI, Depends, WebSocket, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from config import Config
from app.concrete import Agent
from app.model import Assistant, User, get_db
from app.model import Assistant, User, get_db, get_db_context
from app.schemas import *
from app.dependency import get_logger
from app.exception import *
import asyncio
import uvicorn
import uuid
import json
import time
# 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------
def update_messages(messages, kid_text,llm_text):
messages = json.loads(messages)
def update_messages(assistant, kid_text, llm_text):
messages = json.loads(assistant.messages)
if not kid_text:
raise AsrResultNoneError()
if not llm_text:
raise LlmResultNoneError()
messages.append({"role":"user","content":kid_text})
messages.append({"role":"assistant","content":llm_text})
return json.dumps(messages,ensure_ascii=False)
with get_db_context() as db:
db.query(Assistant).filter(Assistant.id == assistant.id).update({"messages":json.dumps(messages,ensure_ascii=False),"token":assistant.token})
db.commit()
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# 引入logger对象
@ -177,7 +186,7 @@ async def update_user(id: str,request: update_user_request,db=Depends(get_db)):
# 流式聊天websokct接口 ------------------------------------------------------------------------------------------------------------------------------------------------------------
@app.websocket("/api/chat/streaming/temporary")
async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
async def streaming_chat(ws: WebSocket):
await ws.accept()
logger.debug("WebSocket连接成功")
try:
@ -187,8 +196,9 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
llm_text = ""
logger.debug("开始进行ASR识别")
while len(asr_results)==0:
chunk = json.loads(await ws.receive_text())
chunk = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=2))
if assistant is None:
with get_db_context() as db: #使用with语句获取数据库连接自动关闭数据库连接
assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
if assistant is None:
raise SessionNotFoundError()
@ -196,48 +206,60 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
if not agent:
agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type'])
agent.init_recorder(assistant.user_id)
chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
asr_results = await agent.stream_recognize(chunk, db)
chunk["audio"] = agent.user_audio_process(chunk["audio"])
asr_results = await agent.stream_recognize(chunk)
kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果
logger.debug(f"ASR识别成功识别结果为{kid_text}")
prompt = agent.prompt_process(asr_results, db)
prompt = agent.prompt_process(asr_results)
agent.recorder.input_text = prompt
logger.debug("开始调用大模型")
llm_frames = await agent.chat(assistant, prompt, db)
llm_frames = await agent.chat(assistant, prompt)
async for llm_frame in llm_frames:
resp_msgs = agent.llm_msg_process(llm_frame, db)
resp_msgs = agent.llm_msg_process(llm_frame)
for resp_msg in resp_msgs:
llm_text += resp_msg
tts_audio = agent.synthetize(assistant, resp_msg, db)
agent.tts_audio_process(tts_audio, db)
tts_start_time = time.time()
tts_audio = agent.synthetize(assistant, resp_msg)
tts_end_time = time.time()
logger.debug(f"TTS生成音频耗时{tts_end_time-tts_start_time}s")
agent.tts_audio_process(tts_audio)
await ws.send_bytes(agent.encode(resp_msg, tts_audio))
logger.debug(f'websocket返回{resp_msg}')
logger.debug(f"大模型返回结束,返回结果为:{llm_text}")
await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
logger.debug("结束帧发送完毕")
assistant.messages = update_messages(assistant.messages, kid_text ,llm_text)
db.commit()
update_messages(assistant, kid_text ,llm_text)
logger.debug("聊天更新成功")
agent.recorder.output_text = llm_text
agent.save()
logger.debug("音频保存成功")
except EnterSlienceMode:
tts_audio = agent.synthetize(assistant, "已进入沉默模式", db)
tts_audio = agent.synthetize(assistant, "已进入沉默模式")
await ws.send_bytes(agent.encode("已进入沉默模式", tts_audio))
await ws.send_text(json.dumps({"type":"info","code":201,"msg":"进入沉默模式"}, ensure_ascii=False))
except SlienceMode:
await ws.send_text(json.dumps({"type":"info","code":201,"msg":"处于沉默模式"}, ensure_ascii=False))
logger.debug("进入沉默模式")
except AsrResultNoneError:
await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr结果为空"}, ensure_ascii=False))
logger.error("ASR结果为空")
except AbnormalLLMFrame as e:
await ws.send_text(json.dumps({"type":"error","code":502,"msg":str(e)}, ensure_ascii=False))
logger.error(f"LLM模型返回异常错误信息{str(e)}")
except SideNoiseError as e:
await ws.send_text(json.dumps({"type":"error","code":503,"msg":str(e)}, ensure_ascii=False))
await ws.send_text(json.dumps({"type":"error","code":503,"msg":"检测为噪声"}, ensure_ascii=False))
logger.debug("检测为噪声")
except SessionNotFoundError:
await ws.send_text(json.dumps({"type":"error","code":504,"msg":"session不存在"}, ensure_ascii=False))
logger.error("session不存在")
except UnknownVolcEngineModelError:
await ws.send_text(json.dumps({"type":"error","code":505,"msg":"未知的火山引擎模型"}, ensure_ascii=False))
logger.error("未知的火山引擎模型")
except LlmResultNoneError:
await ws.send_text(json.dumps({"type":"error","code":506,"msg":"llm结果返回为空"}, ensure_ascii=False))
logger.error("LLM结果返回为空")
except asyncio.TimeoutError:
await ws.send_text(json.dumps({"type":"error","code":507,"msg":"接收超时"}, ensure_ascii=False))
logger.error("接收超时")
logger.debug("WebSocket连接断开")
logger.debug("")
await ws.close()
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

View File

@ -12,3 +12,4 @@ cn2an
numba
librosa
aiohttp
'volcengine-python-sdk[ark]'

View File

@ -38,7 +38,7 @@ def generate_xf_asr_url():
def make_first_frame(buf):
first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vad_eos":10000},
first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vinfo":1,"vad_eos":1000},
"data":{"status":0,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
return json.dumps(first_frame)