1
0
Fork 0

update:优化了websocket对数据库连接的获取与释放

This commit is contained in:
killua 2024-06-20 10:34:26 +08:00
parent ce46c0f35b
commit 5a47440e0f
8 changed files with 91 additions and 54 deletions

View File

@@ -10,7 +10,7 @@ class ASR(ABC):
class LLM(ABC): class LLM(ABC):
@abstractmethod @abstractmethod
def chat(self, assistant, prompt, db): def chat(self, assistant, prompt):
pass pass
class TTS(ABC): class TTS(ABC):

View File

@@ -56,20 +56,18 @@ class XF_ASR(ASR):
await self.websocket.send(make_continue_frame(audio_data)) await self.websocket.send(make_continue_frame(audio_data))
elif self.status == LAST_FRAME: #发送最后一帧 elif self.status == LAST_FRAME: #发送最后一帧
await self.websocket.send(make_last_frame(audio_data)) await self.websocket.send(make_last_frame(audio_data))
logger.debug("发送完毕")
self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv())) self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
if self.current_message == "": if self.current_message == "":
raise AsrResultNoneError() raise AsrResultNoneError()
if "进入沉默模式" in self.current_message: if self.current_message in [""]:
raise SideNoiseError()
if "闭嘴" in self.current_message:
self.is_slience = True self.is_slience = True
asyncio.create_task(self.websocket.close()) asyncio.create_task(self.websocket.close())
raise EnterSlienceMode() raise EnterSlienceMode()
if "退出沉默模式" in self.current_message:
self.is_slience = False
self.current_message = "已退出沉默模式"
if self.is_slience:
asyncio.create_task(self.websocket.close())
raise SlienceMode()
asyncio.create_task(self.websocket.close()) asyncio.create_task(self.websocket.close())
logger.debug(f"ASR结果: {self.current_message}")
return [{"text":self.current_message, "audio":self.audio}] return [{"text":self.current_message, "audio":self.audio}]
current_time = asyncio.get_event_loop().time() current_time = asyncio.get_event_loop().time()
@@ -86,7 +84,7 @@ class MINIMAX_LLM(LLM):
def __init__(self): def __init__(self):
self.token = 0 self.token = 0
async def chat(self, assistant, prompt, db): async def chat(self, assistant, prompt):
llm_info = json.loads(assistant.llm_info) llm_info = json.loads(assistant.llm_info)
messages = json.loads(assistant.messages) messages = json.loads(assistant.messages)
messages.append({'role':'user','content':prompt}) messages.append({'role':'user','content':prompt})
@@ -111,12 +109,10 @@ class MINIMAX_LLM(LLM):
yield msg_frame yield msg_frame
except LLMResponseEnd: except LLMResponseEnd:
msg_frame = {"is_end":True,"code":200,"msg":""} msg_frame = {"is_end":True,"code":200,"msg":""}
assistant.token = self.token
if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%则重置session if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%则重置session
msg_frame['code'] = '201' msg_frame['code'] = '201'
as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first() assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
db.commit()
assistant.messages = as_query.messages
yield msg_frame yield msg_frame
@@ -147,7 +143,7 @@ class VOLCENGINE_LLM(LLM):
self.token = 0 self.token = 0
self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY) self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)
async def chat(self, assistant, prompt, db): async def chat(self, assistant, prompt):
llm_info = json.loads(assistant.llm_info) llm_info = json.loads(assistant.llm_info)
model = self.__get_model(llm_info) model = self.__get_model(llm_info)
messages = json.loads(assistant.messages) messages = json.loads(assistant.messages)
@@ -167,18 +163,20 @@ class VOLCENGINE_LLM(LLM):
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg} msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
yield msg_frame yield msg_frame
except LLMResponseEnd: except LLMResponseEnd:
msg_frame = {"is_end":True,"code":20-0,"msg":""} msg_frame = {"is_end":True,"code":200,"msg":""}
assistant.token = self.token
if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%则重置session if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%则重置session
msg_frame['code'] = '201' msg_frame['code'] = '201'
as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first() assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
db.commit()
assistant.messages = as_query.messages
yield msg_frame yield msg_frame
def __get_model(self, llm_info): def __get_model(self, llm_info):
if llm_info['model'] == 'doubao-4k-lite': if llm_info['model'] == 'doubao-4k-lite':
return Config.VOLCENGINE_LLM.DOUBAO_LITE_4k return Config.VOLCENGINE_LLM.DOUBAO_LITE_4k
elif llm_info['model'] == 'doubao-32k-lite':
return Config.VOLCENGINE_LLM.DOUBAO_LITE_32k
elif llm_info['model'] == 'doubao-32k-pro':
return Config.VOLCENGINE_LLM.DOUBAO_PRO_32k
else: else:
raise UnknownVolcEngineModelError() raise UnknownVolcEngineModelError()
@@ -322,31 +320,31 @@ class Agent():
self.recorder = Recorder(user_id) self.recorder = Recorder(user_id)
# 对用户输入的音频进行预处理 # 对用户输入的音频进行预处理
def user_audio_process(self, audio, db): def user_audio_process(self, audio):
return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder) return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)
# 进行流式语音识别 # 进行流式语音识别
async def stream_recognize(self, chunk, db): async def stream_recognize(self, chunk):
return await self.asr.stream_recognize(chunk) return await self.asr.stream_recognize(chunk)
# 进行Prompt加工 # 进行Prompt加工
def prompt_process(self, asr_results, db): def prompt_process(self, asr_results):
return self.prompt_service_chain.prompt_process(asr_results) return self.prompt_service_chain.prompt_process(asr_results)
# 进行大模型调用 # 进行大模型调用
async def chat(self, assistant ,prompt, db): async def chat(self, assistant ,prompt):
return self.llm.chat(assistant, prompt, db) return self.llm.chat(assistant, prompt)
# 对大模型的返回进行处理 # 对大模型的返回进行处理
def llm_msg_process(self, llm_chunk, db): def llm_msg_process(self, llm_chunk):
return self.llm_msg_service_chain.llm_msg_process(llm_chunk) return self.llm_msg_service_chain.llm_msg_process(llm_chunk)
# 进行TTS合成 # 进行TTS合成
def synthetize(self, assistant, text, db): def synthetize(self, assistant, text):
return self.tts.synthetize(assistant, text) return self.tts.synthetize(assistant, text)
# 对合成后的音频进行处理 # 对合成后的音频进行处理
def tts_audio_process(self, audio, db): def tts_audio_process(self, audio):
return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder) return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder)
# 编码 # 编码

View File

@@ -46,6 +46,11 @@ class SessionNotFoundError(Exception):
super().__init__(message) super().__init__(message)
self.message = message self.message = message
class LlmResultNoneError(Exception):
def __init__(self, message="LLM Result is None!"):
super().__init__(message)
self.message = message
# 大模型返回结束(非异常) # 大模型返回结束(非异常)
class LLMResponseEnd(Exception): class LLMResponseEnd(Exception):
def __init__(self, message="LLM Response End!"): def __init__(self, message="LLM Response End!"):

View File

@@ -1,8 +1,10 @@
from sqlalchemy import create_engine, Column, Integer, String, CHAR from sqlalchemy import create_engine, Column, Integer, String, CHAR
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.orm import sessionmaker, Session
from contextlib import contextmanager
from config import Config from config import Config
engine = create_engine(Config.SQLITE_URL, connect_args={"check_same_thread": False}) engine = create_engine(Config.SQLITE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() Base = declarative_base()
@@ -30,10 +32,17 @@ class Assistant(Base):
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
def get_db(): @contextmanager
def get_db_context():
db = SessionLocal() db = SessionLocal()
try: try:
yield db yield db
finally: finally:
db.close() db.close()
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

View File

@@ -6,7 +6,7 @@ class Config:
LOG_LEVEL = "DEBUG" LOG_LEVEL = "DEBUG"
class UVICORN: class UVICORN:
HOST = '0.0.0.0' HOST = '0.0.0.0'
PORT = 7878 PORT = 8001
class XF_ASR: class XF_ASR:
APP_ID = "f1c121c1" #讯飞语音识别APP_ID APP_ID = "f1c121c1" #讯飞语音识别APP_ID
API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0" #讯飞语音识别API_SECRET API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0" #讯飞语音识别API_SECRET
@@ -20,4 +20,6 @@ class Config:
URL = "https://api.minimax.chat/v1/text/chatcompletion_v2" URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
class VOLCENGINE_LLM: class VOLCENGINE_LLM:
API_KEY = "a1bf964c-5c12-4d2b-ad97-85893e14d55d" API_KEY = "a1bf964c-5c12-4d2b-ad97-85893e14d55d"
DOUBAO_LITE_4k = "ep-20240612075552-5c7tk" DOUBAO_LITE_4k = "ep-20240612075552-5c7tk"
DOUBAO_LITE_32k = "ep-20240618130753-q85dm"
DOUBAO_PRO_32k = "ep-20240618145315-pm2c6"

64
main.py
View File

@@ -2,20 +2,29 @@ from fastapi import FastAPI, Depends, WebSocket, HTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from config import Config from config import Config
from app.concrete import Agent from app.concrete import Agent
from app.model import Assistant, User, get_db from app.model import Assistant, User, get_db, get_db_context
from app.schemas import * from app.schemas import *
from app.dependency import get_logger from app.dependency import get_logger
from app.exception import * from app.exception import *
import asyncio
import uvicorn import uvicorn
import uuid import uuid
import json import json
import time
# 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ # 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------
def update_messages(messages, kid_text,llm_text): def update_messages(assistant, kid_text, llm_text):
messages = json.loads(messages) messages = json.loads(assistant.messages)
if not kid_text:
raise AsrResultNoneError()
if not llm_text:
raise LlmResultNoneError()
messages.append({"role":"user","content":kid_text}) messages.append({"role":"user","content":kid_text})
messages.append({"role":"assistant","content":llm_text}) messages.append({"role":"assistant","content":llm_text})
return json.dumps(messages,ensure_ascii=False) with get_db_context() as db:
db.query(Assistant).filter(Assistant.id == assistant.id).update({"messages":json.dumps(messages,ensure_ascii=False),"token":assistant.token})
db.commit()
# -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# 引入logger对象 # 引入logger对象
@@ -177,7 +186,7 @@ async def update_user(id: str,request: update_user_request,db=Depends(get_db)):
# 流式聊天websokct接口 ------------------------------------------------------------------------------------------------------------------------------------------------------------ # 流式聊天websokct接口 ------------------------------------------------------------------------------------------------------------------------------------------------------------
@app.websocket("/api/chat/streaming/temporary") @app.websocket("/api/chat/streaming/temporary")
async def streaming_chat(ws: WebSocket,db=Depends(get_db)): async def streaming_chat(ws: WebSocket):
await ws.accept() await ws.accept()
logger.debug("WebSocket连接成功") logger.debug("WebSocket连接成功")
try: try:
@@ -187,57 +196,70 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
llm_text = "" llm_text = ""
logger.debug("开始进行ASR识别") logger.debug("开始进行ASR识别")
while len(asr_results)==0: while len(asr_results)==0:
chunk = json.loads(await ws.receive_text()) chunk = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=2))
if assistant is None: if assistant is None:
assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first() with get_db_context() as db: #使用with语句获取数据库连接自动关闭数据库连接
assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
if assistant is None: if assistant is None:
raise SessionNotFoundError() raise SessionNotFoundError()
user_info = json.loads(assistant.user_info) user_info = json.loads(assistant.user_info)
if not agent: if not agent:
agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type']) agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type'])
agent.init_recorder(assistant.user_id) agent.init_recorder(assistant.user_id)
chunk["audio"] = agent.user_audio_process(chunk["audio"], db) chunk["audio"] = agent.user_audio_process(chunk["audio"])
asr_results = await agent.stream_recognize(chunk, db) asr_results = await agent.stream_recognize(chunk)
kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果 kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果
logger.debug(f"ASR识别成功识别结果为{kid_text}") prompt = agent.prompt_process(asr_results)
prompt = agent.prompt_process(asr_results, db)
agent.recorder.input_text = prompt agent.recorder.input_text = prompt
logger.debug("开始调用大模型") logger.debug("开始调用大模型")
llm_frames = await agent.chat(assistant, prompt, db) llm_frames = await agent.chat(assistant, prompt)
async for llm_frame in llm_frames: async for llm_frame in llm_frames:
resp_msgs = agent.llm_msg_process(llm_frame, db) resp_msgs = agent.llm_msg_process(llm_frame)
for resp_msg in resp_msgs: for resp_msg in resp_msgs:
llm_text += resp_msg llm_text += resp_msg
tts_audio = agent.synthetize(assistant, resp_msg, db) tts_start_time = time.time()
agent.tts_audio_process(tts_audio, db) tts_audio = agent.synthetize(assistant, resp_msg)
tts_end_time = time.time()
logger.debug(f"TTS生成音频耗时{tts_end_time-tts_start_time}s")
agent.tts_audio_process(tts_audio)
await ws.send_bytes(agent.encode(resp_msg, tts_audio)) await ws.send_bytes(agent.encode(resp_msg, tts_audio))
logger.debug(f'websocket返回{resp_msg}') logger.debug(f'websocket返回{resp_msg}')
logger.debug(f"大模型返回结束,返回结果为:{llm_text}") logger.debug(f"大模型返回结束,返回结果为:{llm_text}")
await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False)) await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
logger.debug("结束帧发送完毕") logger.debug("结束帧发送完毕")
assistant.messages = update_messages(assistant.messages, kid_text ,llm_text) update_messages(assistant, kid_text ,llm_text)
db.commit()
logger.debug("聊天更新成功") logger.debug("聊天更新成功")
agent.recorder.output_text = llm_text agent.recorder.output_text = llm_text
agent.save() agent.save()
logger.debug("音频保存成功") logger.debug("音频保存成功")
except EnterSlienceMode: except EnterSlienceMode:
tts_audio = agent.synthetize(assistant, "已进入沉默模式", db) tts_audio = agent.synthetize(assistant, "已进入沉默模式")
await ws.send_bytes(agent.encode("已进入沉默模式", tts_audio)) await ws.send_bytes(agent.encode("已进入沉默模式", tts_audio))
await ws.send_text(json.dumps({"type":"info","code":201,"msg":"进入沉默模式"}, ensure_ascii=False)) await ws.send_text(json.dumps({"type":"info","code":201,"msg":"进入沉默模式"}, ensure_ascii=False))
except SlienceMode: logger.debug("进入沉默模式")
await ws.send_text(json.dumps({"type":"info","code":201,"msg":"处于沉默模式"}, ensure_ascii=False))
except AsrResultNoneError: except AsrResultNoneError:
await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr结果为空"}, ensure_ascii=False)) await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr结果为空"}, ensure_ascii=False))
logger.error("ASR结果为空")
except AbnormalLLMFrame as e: except AbnormalLLMFrame as e:
await ws.send_text(json.dumps({"type":"error","code":502,"msg":str(e)}, ensure_ascii=False)) await ws.send_text(json.dumps({"type":"error","code":502,"msg":str(e)}, ensure_ascii=False))
logger.error(f"LLM模型返回异常错误信息{str(e)}")
except SideNoiseError as e: except SideNoiseError as e:
await ws.send_text(json.dumps({"type":"error","code":503,"msg":str(e)}, ensure_ascii=False)) await ws.send_text(json.dumps({"type":"error","code":503,"msg":"检测为噪声"}, ensure_ascii=False))
logger.debug("检测为噪声")
except SessionNotFoundError: except SessionNotFoundError:
await ws.send_text(json.dumps({"type":"error","code":504,"msg":"session不存在"}, ensure_ascii=False)) await ws.send_text(json.dumps({"type":"error","code":504,"msg":"session不存在"}, ensure_ascii=False))
logger.error("session不存在")
except UnknownVolcEngineModelError: except UnknownVolcEngineModelError:
await ws.send_text(json.dumps({"type":"error","code":505,"msg":"未知的火山引擎模型"}, ensure_ascii=False)) await ws.send_text(json.dumps({"type":"error","code":505,"msg":"未知的火山引擎模型"}, ensure_ascii=False))
logger.error("未知的火山引擎模型")
except LlmResultNoneError:
await ws.send_text(json.dumps({"type":"error","code":506,"msg":"llm结果返回为空"}, ensure_ascii=False))
logger.error("LLM结果返回为空")
except asyncio.TimeoutError:
await ws.send_text(json.dumps({"type":"error","code":507,"msg":"接收超时"}, ensure_ascii=False))
logger.error("接收超时")
logger.debug("WebSocket连接断开") logger.debug("WebSocket连接断开")
logger.debug("")
await ws.close() await ws.close()
# -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

View File

@@ -11,4 +11,5 @@ jieba
cn2an cn2an
numba numba
librosa librosa
aiohttp aiohttp
'volcengine-python-sdk[ark]'

View File

@@ -38,7 +38,7 @@ def generate_xf_asr_url():
def make_first_frame(buf): def make_first_frame(buf):
first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vad_eos":10000}, first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vinfo":1,"vad_eos":1000},
"data":{"status":0,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}} "data":{"status":0,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
return json.dumps(first_frame) return json.dumps(first_frame)