update: 封装了豆包模型
This commit is contained in:
parent
286e83e025
commit
76c7adc1b5
|
@ -1,4 +1,5 @@
|
|||
from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
|
||||
from volcenginesdkarkruntime import Ark
|
||||
from .model import Assistant
|
||||
from .abstract import *
|
||||
from .public import *
|
||||
|
@ -98,9 +99,9 @@ class MINIMAX_LLM(LLM):
|
|||
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
|
||||
yield msg_frame
|
||||
except LLMResponseEnd:
|
||||
msg_frame = {"is_end":True,"msg":""}
|
||||
msg_frame = {"is_end":True,"code":200,"msg":""}
|
||||
if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session
|
||||
msg_frame['msg'] = 'max_token reached'
|
||||
msg_frame['code'] = '201'
|
||||
as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
|
||||
as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
|
||||
db.commit()
|
||||
|
@ -130,6 +131,51 @@ class MINIMAX_LLM(LLM):
|
|||
logger.error(llm_chunk)
|
||||
raise AbnormalLLMFrame(f"error llm_chunk:{llm_chunk}")
|
||||
|
||||
class VOLCENGINE_LLM(LLM):
    """LLM backend for VolcEngine's Ark runtime (Doubao models).

    Mirrors the streaming contract of MINIMAX_LLM: `chat` is an async
    generator yielding frames shaped
    {"is_end": bool, "code": 200 | '201', "msg": str}.
    """

    def __init__(self):
        # Total tokens reported by the final usage chunk of the last stream.
        self.token = 0
        self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)

    async def chat(self, assistant, prompt, db):
        """Stream a chat completion for `prompt`.

        Yields one frame (is_end=False, code=200) per content delta, then a
        single final frame (is_end=True). When more than 80% of the session
        token budget is used, the final frame carries code '201' and the
        stored conversation is reset to system prompt + this user turn.

        Raises:
            UnknownVolcEngineModelError: llm_info names an unknown model.
            AbnormalLLMFrame: a stream chunk had no choices.
        """
        llm_info = json.loads(assistant.llm_info)
        model = self.__get_model(llm_info)
        messages = json.loads(assistant.messages)
        messages.append({'role': 'user', 'content': prompt})
        stream = self.client.chat.completions.create(
            model=model,
            messages=messages,
            stream=True,
            stream_options={'include_usage': True}  # final chunk reports token usage
        )
        for chunk in stream:
            try:
                chunk_msg = self.__parseChunk(chunk)
                msg_frame = {"is_end": False, "code": 200, "msg": chunk_msg}
                yield msg_frame
            except LLMResponseEnd:
                # BUG FIX: original wrote "code":20-0, which evaluates to 20.
                # The normal end-of-stream code is 200 (matches MINIMAX_LLM).
                msg_frame = {"is_end": True, "code": 200, "msg": ""}
                if self.token > llm_info['max_tokens'] * 0.8:  # over 80% of token budget: reset session
                    # Consistency with MINIMAX_LLM: tell the client why the
                    # session was reset, and use its '201' code convention.
                    msg_frame['msg'] = 'max_token reached'
                    msg_frame['code'] = '201'
                    as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
                    as_query.messages = json.dumps([{'role': 'system', 'content': assistant.system_prompt},
                                                    {'role': 'user', 'content': prompt}])
                    db.commit()
                    # Keep the in-memory assistant in sync with the DB row.
                    assistant.messages = as_query.messages
                yield msg_frame

    def __get_model(self, llm_info):
        """Map the logical model name in llm_info to a configured endpoint id.

        Raises UnknownVolcEngineModelError for unrecognized model names.
        """
        if llm_info['model'] == 'doubao-4k-lite':
            return Config.VOLCENGINE_LLM.DOUBAO_LITE_4k
        raise UnknownVolcEngineModelError()

    def __parseChunk(self, llm_chunk):
        """Extract the content delta from one stream chunk.

        A chunk carrying `usage` marks end-of-stream: record the total token
        count and raise LLMResponseEnd. A chunk with no choices is malformed
        and raises AbnormalLLMFrame.
        """
        if llm_chunk.usage:
            self.token = llm_chunk.usage.total_tokens
            raise LLMResponseEnd()
        if not llm_chunk.choices:
            raise AbnormalLLMFrame(f"error volcengine llm_chunk:{llm_chunk}")
        return llm_chunk.choices[0].delta.content
|
||||
|
||||
class VITS_TTS(TTS):
    """Placeholder VITS text-to-speech backend."""

    def __init__(self):
        """No-op constructor — no engine state is held here yet."""
|
||||
|
@ -151,6 +197,8 @@ class LLMFactory:
|
|||
def create_llm(self, llm_type: str) -> LLM:
    """Return a concrete LLM for the given provider tag.

    Known tags: 'VOLCENGINE' and 'MINIMAX'. Any other tag falls through
    and yields None (original behavior preserved).
    """
    if llm_type == 'VOLCENGINE':
        return VOLCENGINE_LLM()
    if llm_type == 'MINIMAX':
        return MINIMAX_LLM()
|
||||
|
||||
class TTSFactory:
|
||||
def create_tts(self,tts_type:str) -> TTS:
|
||||
|
|
|
@ -10,6 +10,12 @@ class NoAsrResultsError(Exception):
|
|||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
# Raised when an unrecognized VolcEngine (Doubao) model is requested.
class UnknownVolcEngineModelError(Exception):
    """The llm_info names a VolcEngine model with no configured endpoint."""

    def __init__(self, message="Unknown Volc Engine Model!"):
        # Keep the text on the instance for handlers that read .message.
        self.message = message
        super().__init__(message)
|
||||
|
||||
# 未知LLM返回帧
|
||||
class UnkownLLMFrame(Exception):
|
||||
def __init__(self, message="Unkown LLM Frame!"):
|
||||
|
|
|
@ -18,3 +18,6 @@ class Config:
|
|||
# MiniMax chat-completion API settings.
class MINIMAX_LLM:
    # SECURITY: live API credential committed to source control — move to an
    # environment variable / secret store and rotate this key.
    API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLph5EiLCJVc2VyTmFtZSI6IumHkSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzY4NTM2NDM3MzE1MDgwODg2IiwiUGhvbmUiOiIxMzEzNjE0NzUyNyIsIkdyb3VwSUQiOiIxNzY4NTM2NDM3MzA2NjkyMjc4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjQtMDUtMTggMTY6MTQ6MDMiLCJpc3MiOiJtaW5pbWF4In0.LypYOkJXwKV6GzDM1dcNn4L0m19o8Q_Lvmn6SkMMb9WAfDJYxEnTc5odm-L4WAWfbur_gY0cQzgoHnI14t4XSaAvqfmcdCrKYpJbKoBmMse_RogJs7KOBt658je3wES4pBUKQll6NbogQB1f93lnA9IYv4aEVldfqglbCikd54XO8E9Ptn4gX9Mp8fUn3lCpZ6_OSlmgZsQySrmt1sDHHzi3DlkdXlFSI38TQSZIa5RhFpI8WSBLIbaKl84OhaDzo7v99k9DUCzb5JGh0eZOnUT0YswbKCPeV8rZ1XUiOVQrna1uiDLvqv54aIt3vsu-LypYmnHxtZ_z4u2gt87pZg"
    # Chat-completion endpoint URL.
    URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
|
||||
# VolcEngine Ark (Doubao) API settings.
class VOLCENGINE_LLM:
    # SECURITY: hardcoded API key committed to source control — load from an
    # environment variable / secret store and rotate this key.
    API_KEY = "a1bf964c-5c12-4d2b-ad97-85893e14d55d"
    # Endpoint id resolved for the 'doubao-4k-lite' model name.
    DOUBAO_LITE_4k = "ep-20240612075552-5c7tk"
|
12
main.py
12
main.py
|
@ -5,7 +5,7 @@ from app.concrete import Agent
|
|||
from app.model import Assistant, User, get_db
|
||||
from app.schemas import *
|
||||
from app.dependency import get_logger
|
||||
from app.exception import AsrResultNoneError, AbnormalLLMFrame, SideNoiseError, SessionNotFoundError
|
||||
from app.exception import *
|
||||
import uvicorn
|
||||
import uuid
|
||||
import json
|
||||
|
@ -222,13 +222,15 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
|
|||
agent.save()
|
||||
logger.debug("音频保存成功")
|
||||
except AsrResultNoneError:
|
||||
await ws.send_text(json.dumps({"type":"close","code":201,"msg":"asr结果为空"}, ensure_ascii=False))
|
||||
await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr结果为空"}, ensure_ascii=False))
|
||||
except AbnormalLLMFrame as e:
|
||||
await ws.send_text(json.dumps({"type":"close","code":202,"msg":str(e)}, ensure_ascii=False))
|
||||
await ws.send_text(json.dumps({"type":"error","code":502,"msg":str(e)}, ensure_ascii=False))
|
||||
except SideNoiseError as e:
|
||||
await ws.send_text(json.dumps({"type":"close","code":203,"msg":str(e)}, ensure_ascii=False))
|
||||
await ws.send_text(json.dumps({"type":"error","code":503,"msg":str(e)}, ensure_ascii=False))
|
||||
except SessionNotFoundError:
|
||||
await ws.send_text(json.dumps({"type":"close","code":204,"msg":"session未找到"}, ensure_ascii=False))
|
||||
await ws.send_text(json.dumps({"type":"error","code":504,"msg":"session不存在"}, ensure_ascii=False))
|
||||
except UnknownVolcEngineModelError:
|
||||
await ws.send_text(json.dumps({"type":"error","code":505,"msg":"未知的火山引擎模型"}, ensure_ascii=False))
|
||||
logger.debug("WebSocket连接断开")
|
||||
await ws.close()
|
||||
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
|
Loading…
Reference in New Issue