update: wrapped the Doubao model

killua4396 2024-06-12 17:18:47 +08:00
parent 286e83e025
commit 76c7adc1b5
4 changed files with 67 additions and 8 deletions


@@ -1,4 +1,5 @@
 from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
+from volcenginesdkarkruntime import Ark
 from .model import Assistant
 from .abstract import *
 from .public import *
@@ -98,9 +99,9 @@ class MINIMAX_LLM(LLM):
                 msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
                 yield msg_frame
             except LLMResponseEnd:
-                msg_frame = {"is_end":True,"msg":""}
+                msg_frame = {"is_end":True,"code":200,"msg":""}
                 if self.token > llm_info['max_tokens'] * 0.8: # reset the session once token usage passes 80%
-                    msg_frame['msg'] = 'max_token reached'
+                    msg_frame['code'] = '201'
                     as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
                     as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
                     db.commit()
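The frame convention above is what every LLM wrapper in this file yields: {"is_end": False, "code": 200, "msg": chunk} for streamed text, then a closing frame whose code flips to '201' when token usage has passed 80% of max_tokens and the stored conversation history was reset. A minimal consumer sketch of that contract (collect_reply and its parameters are illustrative names, not part of this commit):

async def collect_reply(llm, assistant, prompt, db):
    # Drain the async generator returned by chat() and join the text chunks.
    parts = []
    async for frame in llm.chat(assistant, prompt, db):
        if not frame["is_end"]:
            parts.append(frame["msg"])    # code 200: streamed text chunk
        elif frame["code"] == '201':      # history was reset (token budget ~80% used)
            pass                          # e.g. tell the client the session restarted
    return "".join(parts)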
@@ -130,6 +131,51 @@ class MINIMAX_LLM(LLM):
             logger.error(llm_chunk)
             raise AbnormalLLMFrame(f"error llm_chunk:{llm_chunk}")
 
+class VOLCENGINE_LLM(LLM):
+    def __init__(self):
+        self.token = 0
+        self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)
+
+    async def chat(self, assistant, prompt, db):
+        llm_info = json.loads(assistant.llm_info)
+        model = self.__get_model(llm_info)
+        messages = json.loads(assistant.messages)
+        messages.append({'role':'user','content':prompt})
+        stream = self.client.chat.completions.create(
+            model = model,
+            messages=messages,
+            stream=True,
+            stream_options={'include_usage': True}
+        )
+        for chunk in stream:
+            try:
+                chunk_msg = self.__parseChunk(chunk)
+                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
+                yield msg_frame
+            except LLMResponseEnd:
+                msg_frame = {"is_end":True,"code":200,"msg":""}
+                if self.token > llm_info['max_tokens'] * 0.8: # reset the session once token usage passes 80%
+                    msg_frame['code'] = '201'
+                    as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
+                    as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
+                    db.commit()
+                    assistant.messages = as_query.messages
+                yield msg_frame
+
+    def __get_model(self, llm_info):
+        if llm_info['model'] == 'doubao-4k-lite':
+            return Config.VOLCENGINE_LLM.DOUBAO_LITE_4k
+        else:
+            raise UnknownVolcEngineModelError()
+
+    def __parseChunk(self, llm_chunk):
+        if llm_chunk.usage:
+            self.token = llm_chunk.usage.total_tokens
+            raise LLMResponseEnd()
+        if not llm_chunk.choices:
+            raise AbnormalLLMFrame(f"error volcengine llm_chunk:{llm_chunk}")
+        return llm_chunk.choices[0].delta.content
+
 class VITS_TTS(TTS):
     def __init__(self):
         pass
@@ -151,6 +197,8 @@ class LLMFactory:
     def create_llm(self,llm_type:str) -> LLM:
         if llm_type == 'MINIMAX':
             return MINIMAX_LLM()
+        if llm_type == 'VOLCENGINE':
+            return VOLCENGINE_LLM()
 
 class TTSFactory:
     def create_tts(self,tts_type:str) -> TTS:
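The new wrapper follows the Ark SDK's streaming contract: with stream_options={'include_usage': True}, the last chunk carries only usage totals and an empty choices list, which __parseChunk converts into LLMResponseEnd. A standalone sketch of that raw call pattern, independent of the wrapper (the API key and endpoint ID below are placeholders, not working credentials):

from volcenginesdkarkruntime import Ark

client = Ark(api_key="YOUR_ARK_API_KEY")         # placeholder key
stream = client.chat.completions.create(
    model="ep-xxxxxxxxxxxxxx-xxxxx",             # Ark endpoint ID, as stored in Config.VOLCENGINE_LLM
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in stream:
    if chunk.usage:                               # final chunk: usage totals, no choices
        print("\ntotal tokens:", chunk.usage.total_tokens)
    elif chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")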


@@ -10,6 +10,12 @@ class NoAsrResultsError(Exception):
         super().__init__(message)
         self.message = message
 
+# Unknown VolcEngine model
+class UnknownVolcEngineModelError(Exception):
+    def __init__(self, message="Unknown Volc Engine Model!"):
+        super().__init__(message)
+        self.message = message
+
 # Unknown LLM response frame
 class UnkownLLMFrame(Exception):
     def __init__(self, message="Unkown LLM Frame!"):


@@ -18,3 +18,6 @@ class Config:
     class MINIMAX_LLM:
         API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLph5EiLCJVc2VyTmFtZSI6IumHkSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzY4NTM2NDM3MzE1MDgwODg2IiwiUGhvbmUiOiIxMzEzNjE0NzUyNyIsIkdyb3VwSUQiOiIxNzY4NTM2NDM3MzA2NjkyMjc4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjQtMDUtMTggMTY6MTQ6MDMiLCJpc3MiOiJtaW5pbWF4In0.LypYOkJXwKV6GzDM1dcNn4L0m19o8Q_Lvmn6SkMMb9WAfDJYxEnTc5odm-L4WAWfbur_gY0cQzgoHnI14t4XSaAvqfmcdCrKYpJbKoBmMse_RogJs7KOBt658je3wES4pBUKQll6NbogQB1f93lnA9IYv4aEVldfqglbCikd54XO8E9Ptn4gX9Mp8fUn3lCpZ6_OSlmgZsQySrmt1sDHHzi3DlkdXlFSI38TQSZIa5RhFpI8WSBLIbaKl84OhaDzo7v99k9DUCzb5JGh0eZOnUT0YswbKCPeV8rZ1XUiOVQrna1uiDLvqv54aIt3vsu-LypYmnHxtZ_z4u2gt87pZg"
         URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
+    class VOLCENGINE_LLM:
+        API_KEY = "a1bf964c-5c12-4d2b-ad97-85893e14d55d"
+        DOUBAO_LITE_4k = "ep-20240612075552-5c7tk"
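Note that DOUBAO_LITE_4k stores an Ark inference endpoint ID (the ep-... value), not a public model name; __get_model in the LLM wrapper translates the user-facing 'doubao-4k-lite' string into it. If more Doubao variants are wired up later, a table keeps that mapping in one place (a sketch only: DOUBAO_PRO_4k and 'doubao-4k-pro' are hypothetical names, not part of this commit):

# Hypothetical table-driven replacement for VOLCENGINE_LLM.__get_model.
_MODEL_MAP = {
    'doubao-4k-lite': Config.VOLCENGINE_LLM.DOUBAO_LITE_4k,
    # 'doubao-4k-pro': Config.VOLCENGINE_LLM.DOUBAO_PRO_4k,  # hypothetical future entry
}

def __get_model(self, llm_info):
    try:
        return _MODEL_MAP[llm_info['model']]
    except KeyError:
        raise UnknownVolcEngineModelError()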

main.py

@@ -5,7 +5,7 @@ from app.concrete import Agent
 from app.model import Assistant, User, get_db
 from app.schemas import *
 from app.dependency import get_logger
-from app.exception import AsrResultNoneError, AbnormalLLMFrame, SideNoiseError, SessionNotFoundError
+from app.exception import *
 import uvicorn
 import uuid
 import json
@@ -222,13 +222,15 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
         agent.save()
         logger.debug("音频保存成功")
     except AsrResultNoneError:
-        await ws.send_text(json.dumps({"type":"close","code":201,"msg":"asr结果为空"}, ensure_ascii=False))
+        await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr结果为空"}, ensure_ascii=False))
     except AbnormalLLMFrame as e:
-        await ws.send_text(json.dumps({"type":"close","code":202,"msg":str(e)}, ensure_ascii=False))
+        await ws.send_text(json.dumps({"type":"error","code":502,"msg":str(e)}, ensure_ascii=False))
     except SideNoiseError as e:
-        await ws.send_text(json.dumps({"type":"close","code":203,"msg":str(e)}, ensure_ascii=False))
+        await ws.send_text(json.dumps({"type":"error","code":503,"msg":str(e)}, ensure_ascii=False))
     except SessionNotFoundError:
-        await ws.send_text(json.dumps({"type":"close","code":204,"msg":"session未找到"}, ensure_ascii=False))
+        await ws.send_text(json.dumps({"type":"error","code":504,"msg":"session不存在"}, ensure_ascii=False))
+    except UnknownVolcEngineModelError:
+        await ws.send_text(json.dumps({"type":"error","code":505,"msg":"未知的火山引擎模型"}, ensure_ascii=False))
     logger.debug("WebSocket连接断开")
     await ws.close()
 # --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
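With this change every failure path sends a uniform {"type": "error", "code": 5xx, "msg": ...} frame before the socket closes: 501 for an empty ASR result, 502 for an abnormal LLM frame, 503 for side noise, 504 for an unknown session, and 505 for an unknown VolcEngine model. A minimal client-side sketch against that contract (the URL and the third-party websockets package are assumptions, not part of the repo):

import asyncio
import json
import websockets  # third-party package, assumed for illustration

async def listen(url="ws://localhost:8000/streaming_chat"):  # placeholder URL
    async with websockets.connect(url) as ws:
        async for raw in ws:
            frame = json.loads(raw)
            if frame.get("type") == "error":
                # codes 501-505 map to the except branches above
                print(f"server error {frame['code']}: {frame['msg']}")
                break

asyncio.run(listen())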