forked from killua/TakwayDisplayPlatform
update: wrapped the Doubao model
This commit is contained in:
parent 286e83e025 · commit 76c7adc1b5
@@ -1,4 +1,5 @@
 from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
+from volcenginesdkarkruntime import Ark
 from .model import Assistant
 from .abstract import *
 from .public import *
@@ -98,9 +99,9 @@ class MINIMAX_LLM(LLM):
                 msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
                 yield msg_frame
             except LLMResponseEnd:
-                msg_frame = {"is_end":True,"msg":""}
+                msg_frame = {"is_end":True,"code":200,"msg":""}
                 if self.token > llm_info['max_tokens'] * 0.8: # if token usage exceeds 80% of max_tokens, reset the session
-                    msg_frame['msg'] = 'max_token reached'
+                    msg_frame['code'] = 201
                     as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
                     as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
                     db.commit()
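
In both backends, a final frame with code 201 now signals that the session history was reset because token usage crossed 80% of max_tokens (replacing the old 'max_token reached' text message). A minimal consumer-side sketch of that signal; `frame` here is a hypothetical name for one msg_frame dict yielded by chat():

# Sketch: reacting to the 201 "session was reset" signal on the final frame.
# `frame` is hypothetical; it stands for one dict yielded by MINIMAX_LLM.chat.
if frame['is_end'] and frame['code'] == 201:
    # The server rebuilt messages down to system prompt + latest user turn,
    # so a client-side transcript should mark the history break here.
    print('[context reset: token budget nearly exhausted]')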
@@ -130,6 +131,51 @@ class MINIMAX_LLM(LLM):
             logger.error(llm_chunk)
             raise AbnormalLLMFrame(f"error llm_chunk:{llm_chunk}")

+class VOLCENGINE_LLM(LLM):
+    def __init__(self):
+        self.token = 0
+        self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)
+
+    async def chat(self, assistant, prompt, db):
+        llm_info = json.loads(assistant.llm_info)
+        model = self.__get_model(llm_info)
+        messages = json.loads(assistant.messages)
+        messages.append({'role':'user','content':prompt})
+        stream = self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            stream=True,
+            stream_options={'include_usage': True}
+        )
+        for chunk in stream:
+            try:
+                chunk_msg = self.__parseChunk(chunk)
+                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
+                yield msg_frame
+            except LLMResponseEnd:
+                msg_frame = {"is_end":True,"code":200,"msg":""}
+                if self.token > llm_info['max_tokens'] * 0.8: # if token usage exceeds 80% of max_tokens, reset the session
+                    msg_frame['code'] = 201
+                    as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
+                    as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
+                    db.commit()
+                    assistant.messages = as_query.messages
+                yield msg_frame
+
+    def __get_model(self, llm_info):
+        if llm_info['model'] == 'doubao-4k-lite':
+            return Config.VOLCENGINE_LLM.DOUBAO_LITE_4k
+        else:
+            raise UnknownVolcEngineModelError()
+
+    def __parseChunk(self, llm_chunk):
+        if llm_chunk.usage:
+            self.token = llm_chunk.usage.total_tokens
+            raise LLMResponseEnd()
+        if not llm_chunk.choices:
+            raise AbnormalLLMFrame(f"error volcengine llm_chunk:{llm_chunk}")
+        return llm_chunk.choices[0].delta.content
+
 class VITS_TTS(TTS):
     def __init__(self):
         pass
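
VOLCENGINE_LLM.chat is an async generator yielding frame dicts; note it drives the Ark stream with a plain synchronous for loop, so the event loop blocks while chunks arrive. A consumption sketch, assuming a hypothetical StubAssistant that mocks only the fields chat() reads (db is unused on this happy path):

# Sketch: driving the new async generator with stand-in objects.
import asyncio
import json

class StubAssistant:
    id = 1
    system_prompt = "You are a helpful assistant."
    llm_info = json.dumps({'model': 'doubao-4k-lite', 'max_tokens': 4096})
    messages = json.dumps([{'role': 'system', 'content': 'You are a helpful assistant.'}])

async def demo():
    llm = VOLCENGINE_LLM()
    async for frame in llm.chat(StubAssistant(), '你好', db=None):  # db only touched on session reset
        if frame['is_end']:
            break
        print(frame['msg'], end='', flush=True)

# asyncio.run(demo())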
@@ -151,6 +197,8 @@ class LLMFactory:
     def create_llm(self,llm_type:str) -> LLM:
         if llm_type == 'MINIMAX':
             return MINIMAX_LLM()
+        if llm_type == 'VOLCENGINE':
+            return VOLCENGINE_LLM()

 class TTSFactory:
     def create_tts(self,tts_type:str) -> TTS:
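
Selecting the new backend is then a one-liner at the call site, sketched below. Note that create_llm falls through and returns None for unrecognized type strings, so callers may want an explicit check:

# Hypothetical call site; the type string would come from stored assistant config.
llm = LLMFactory().create_llm('VOLCENGINE')   # -> VOLCENGINE_LLM instance
assert llm is not None, 'unknown llm_type'    # create_llm has no else branch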
@@ -10,6 +10,12 @@ class NoAsrResultsError(Exception):
         super().__init__(message)
         self.message = message

+# Unknown VolcEngine model
+class UnknownVolcEngineModelError(Exception):
+    def __init__(self, message="Unknown Volc Engine Model!"):
+        super().__init__(message)
+        self.message = message
+
 # Unknown LLM response frame
 class UnkownLLMFrame(Exception):
     def __init__(self, message="Unkown LLM Frame!"):
@@ -18,3 +18,6 @@ class Config:
     class MINIMAX_LLM:
         API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLph5EiLCJVc2VyTmFtZSI6IumHkSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzY4NTM2NDM3MzE1MDgwODg2IiwiUGhvbmUiOiIxMzEzNjE0NzUyNyIsIkdyb3VwSUQiOiIxNzY4NTM2NDM3MzA2NjkyMjc4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjQtMDUtMTggMTY6MTQ6MDMiLCJpc3MiOiJtaW5pbWF4In0.LypYOkJXwKV6GzDM1dcNn4L0m19o8Q_Lvmn6SkMMb9WAfDJYxEnTc5odm-L4WAWfbur_gY0cQzgoHnI14t4XSaAvqfmcdCrKYpJbKoBmMse_RogJs7KOBt658je3wES4pBUKQll6NbogQB1f93lnA9IYv4aEVldfqglbCikd54XO8E9Ptn4gX9Mp8fUn3lCpZ6_OSlmgZsQySrmt1sDHHzi3DlkdXlFSI38TQSZIa5RhFpI8WSBLIbaKl84OhaDzo7v99k9DUCzb5JGh0eZOnUT0YswbKCPeV8rZ1XUiOVQrna1uiDLvqv54aIt3vsu-LypYmnHxtZ_z4u2gt87pZg"
         URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
+    class VOLCENGINE_LLM:
+        API_KEY = "a1bf964c-5c12-4d2b-ad97-85893e14d55d"
+        DOUBAO_LITE_4k = "ep-20240612075552-5c7tk"
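
Ark addresses Doubao deployments by endpoint ID (the "ep-..." strings) rather than by model name, which is why __get_model maps the friendly name 'doubao-4k-lite' to a configured endpoint. A table-driven sketch that would keep that mapping in one place as more models are added (MODEL_MAP is hypothetical, not part of this commit):

# Sketch: dict lookup instead of growing an if/else chain in __get_model.
MODEL_MAP = {
    'doubao-4k-lite': Config.VOLCENGINE_LLM.DOUBAO_LITE_4k,
}

def get_model(llm_info):
    try:
        return MODEL_MAP[llm_info['model']]
    except KeyError:
        raise UnknownVolcEngineModelError()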
main.py (12 changed lines)
@@ -5,7 +5,7 @@ from app.concrete import Agent
 from app.model import Assistant, User, get_db
 from app.schemas import *
 from app.dependency import get_logger
-from app.exception import AsrResultNoneError, AbnormalLLMFrame, SideNoiseError, SessionNotFoundError
+from app.exception import *
 import uvicorn
 import uuid
 import json
@@ -222,13 +222,15 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
        agent.save()
        logger.debug("音频保存成功")
    except AsrResultNoneError:
-        await ws.send_text(json.dumps({"type":"close","code":201,"msg":"asr结果为空"}, ensure_ascii=False))
+        await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr结果为空"}, ensure_ascii=False))
    except AbnormalLLMFrame as e:
-        await ws.send_text(json.dumps({"type":"close","code":202,"msg":str(e)}, ensure_ascii=False))
+        await ws.send_text(json.dumps({"type":"error","code":502,"msg":str(e)}, ensure_ascii=False))
    except SideNoiseError as e:
-        await ws.send_text(json.dumps({"type":"close","code":203,"msg":str(e)}, ensure_ascii=False))
+        await ws.send_text(json.dumps({"type":"error","code":503,"msg":str(e)}, ensure_ascii=False))
    except SessionNotFoundError:
-        await ws.send_text(json.dumps({"type":"close","code":204,"msg":"session未找到"}, ensure_ascii=False))
+        await ws.send_text(json.dumps({"type":"error","code":504,"msg":"session不存在"}, ensure_ascii=False))
+    except UnknownVolcEngineModelError:
+        await ws.send_text(json.dumps({"type":"error","code":505,"msg":"未知的火山引擎模型"}, ensure_ascii=False))
    logger.debug("WebSocket连接断开")
    await ws.close()
 # --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
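
Error frames switch from type "close" with 20x codes to type "error" with codes 501-505, so clients can tell failure shutdowns apart from ordinary stream traffic. A minimal client-side dispatch sketch (handle_error and handle_payload are hypothetical callbacks):

# Sketch: dispatching on the new error-frame scheme.
import json

def on_message(raw: str):
    frame = json.loads(raw)
    if frame.get('type') == 'error':      # 501..505: server-side failure
        handle_error(frame['code'], frame['msg'])
    else:
        handle_payload(frame)             # normal streaming payload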