diff --git a/app/concrete.py b/app/concrete.py
index 3ac654c..b218811 100644
--- a/app/concrete.py
+++ b/app/concrete.py
@@ -1,4 +1,5 @@
 from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
+from volcenginesdkarkruntime import Ark
 from .model import Assistant
 from .abstract import *
 from .public import *
@@ -98,9 +99,9 @@ class MINIMAX_LLM(LLM):
                 msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
                 yield msg_frame
             except LLMResponseEnd:
-                msg_frame = {"is_end":True,"msg":""}
+                msg_frame = {"is_end":True,"code":200,"msg":""}
                 if self.token > llm_info['max_tokens'] * 0.8: # reset the session once token usage exceeds 80% of max_tokens
-                    msg_frame['msg'] = 'max_token reached'
+                    msg_frame['code'] = 201
                     as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
                     as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
                     db.commit()
@@ -130,6 +131,51 @@ class MINIMAX_LLM(LLM):
             logger.error(llm_chunk)
             raise AbnormalLLMFrame(f"error llm_chunk:{llm_chunk}")
 
+class VOLCENGINE_LLM(LLM):
+    def __init__(self):
+        self.token = 0
+        self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)
+
+    async def chat(self, assistant, prompt, db):
+        llm_info = json.loads(assistant.llm_info)
+        model = self.__get_model(llm_info)
+        messages = json.loads(assistant.messages)
+        messages.append({'role':'user','content':prompt})
+        stream = self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            stream=True,
+            stream_options={'include_usage': True}
+        )
+        for chunk in stream:
+            try:
+                chunk_msg = self.__parseChunk(chunk)
+                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
+                yield msg_frame
+            except LLMResponseEnd:
+                msg_frame = {"is_end":True,"code":200,"msg":""}
+                if self.token > llm_info['max_tokens'] * 0.8: # reset the session once token usage exceeds 80% of max_tokens
+                    msg_frame['code'] = 201
+                    as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
+                    as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
+                    db.commit()
+                    assistant.messages = as_query.messages
+                yield msg_frame
+
+    def __get_model(self, llm_info):
+        if llm_info['model'] == 'doubao-4k-lite':
+            return Config.VOLCENGINE_LLM.DOUBAO_LITE_4k
+        else:
+            raise UnknownVolcEngineModelError()
+
+    def __parseChunk(self, llm_chunk):
+        if llm_chunk.usage:
+            self.token = llm_chunk.usage.total_tokens
+            raise LLMResponseEnd()
+        if not llm_chunk.choices:
+            raise AbnormalLLMFrame(f"error volcengine llm_chunk:{llm_chunk}")
+        return llm_chunk.choices[0].delta.content
+
 class VITS_TTS(TTS):
     def __init__(self):
         pass
@@ -151,6 +197,8 @@ class LLMFactory:
     def create_llm(self,llm_type:str) -> LLM:
         if llm_type == 'MINIMAX':
             return MINIMAX_LLM()
+        if llm_type == 'VOLCENGINE':
+            return VOLCENGINE_LLM()
 
 class TTSFactory:
     def create_tts(self,tts_type:str) -> TTS:
diff --git a/app/exception.py b/app/exception.py
index 32f0254..dd0e83b 100644
--- a/app/exception.py
+++ b/app/exception.py
@@ -10,6 +10,12 @@ class NoAsrResultsError(Exception):
         super().__init__(message)
         self.message = message
 
+# Unknown VolcEngine model
+class UnknownVolcEngineModelError(Exception):
+    def __init__(self, message="Unknown Volc Engine Model!"):
+        super().__init__(message)
+        self.message = message
+
 # Unknown LLM response frame
 class UnkownLLMFrame(Exception):
     def __init__(self, message="Unkown LLM Frame!"):
diff --git a/config.py b/config.py
index 4282b57..a218b77 100644
--- a/config.py
+++ b/config.py
@@ -17,4 +17,7 @@ class Config:
     VAD_EOS = 10000
     class MINIMAX_LLM:
        API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLph5EiLCJVc2VyTmFtZSI6IumHkSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzY4NTM2NDM3MzE1MDgwODg2IiwiUGhvbmUiOiIxMzEzNjE0NzUyNyIsIkdyb3VwSUQiOiIxNzY4NTM2NDM3MzA2NjkyMjc4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjQtMDUtMTggMTY6MTQ6MDMiLCJpc3MiOiJtaW5pbWF4In0.LypYOkJXwKV6GzDM1dcNn4L0m19o8Q_Lvmn6SkMMb9WAfDJYxEnTc5odm-L4WAWfbur_gY0cQzgoHnI14t4XSaAvqfmcdCrKYpJbKoBmMse_RogJs7KOBt658je3wES4pBUKQll6NbogQB1f93lnA9IYv4aEVldfqglbCikd54XO8E9Ptn4gX9Mp8fUn3lCpZ6_OSlmgZsQySrmt1sDHHzi3DlkdXlFSI38TQSZIa5RhFpI8WSBLIbaKl84OhaDzo7v99k9DUCzb5JGh0eZOnUT0YswbKCPeV8rZ1XUiOVQrna1uiDLvqv54aIt3vsu-LypYmnHxtZ_z4u2gt87pZg"
-        URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
\ No newline at end of file
+        URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
+    class VOLCENGINE_LLM:
+        API_KEY = "a1bf964c-5c12-4d2b-ad97-85893e14d55d"
+        DOUBAO_LITE_4k = "ep-20240612075552-5c7tk"
\ No newline at end of file
diff --git a/main.py b/main.py
index b9bbee9..200136b 100644
--- a/main.py
+++ b/main.py
@@ -5,7 +5,7 @@ from app.concrete import Agent
 from app.model import Assistant, User, get_db
 from app.schemas import *
 from app.dependency import get_logger
-from app.exception import AsrResultNoneError, AbnormalLLMFrame, SideNoiseError, SessionNotFoundError
+from app.exception import *
 import uvicorn
 import uuid
 import json
@@ -222,13 +222,15 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
             agent.save()
             logger.debug("audio saved successfully")
     except AsrResultNoneError:
-        await ws.send_text(json.dumps({"type":"close","code":201,"msg":"asr result is empty"}, ensure_ascii=False))
+        await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr result is empty"}, ensure_ascii=False))
     except AbnormalLLMFrame as e:
-        await ws.send_text(json.dumps({"type":"close","code":202,"msg":str(e)}, ensure_ascii=False))
+        await ws.send_text(json.dumps({"type":"error","code":502,"msg":str(e)}, ensure_ascii=False))
     except SideNoiseError as e:
-        await ws.send_text(json.dumps({"type":"close","code":203,"msg":str(e)}, ensure_ascii=False))
+        await ws.send_text(json.dumps({"type":"error","code":503,"msg":str(e)}, ensure_ascii=False))
     except SessionNotFoundError:
-        await ws.send_text(json.dumps({"type":"close","code":204,"msg":"session not found"}, ensure_ascii=False))
+        await ws.send_text(json.dumps({"type":"error","code":504,"msg":"session does not exist"}, ensure_ascii=False))
+    except UnknownVolcEngineModelError:
+        await ws.send_text(json.dumps({"type":"error","code":505,"msg":"unknown VolcEngine model"}, ensure_ascii=False))
     logger.debug("WebSocket connection closed")
     await ws.close()
 # --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
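
Review notes (follow-up, not part of the patch):

1. VOLCENGINE_LLM.__parseChunk encodes the Ark SDK's streaming contract: with stream_options={'include_usage': True}, the final chunk carries usage (hence total_tokens) and an empty choices list, which the parser converts into LLMResponseEnd. A minimal standalone probe of that contract, assuming the same SDK behavior the patch relies on; the API key and endpoint ID below are placeholders:

    from volcenginesdkarkruntime import Ark

    client = Ark(api_key="YOUR_ARK_API_KEY")  # placeholder; the patch reads Config.VOLCENGINE_LLM.API_KEY
    stream = client.chat.completions.create(
        model="ep-xxxxxxxxxxxxxx",            # placeholder endpoint ID, cf. DOUBAO_LITE_4k
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    for chunk in stream:
        if chunk.usage:                       # final bookkeeping chunk: usage only, no choices
            print("\ntotal_tokens:", chunk.usage.total_tokens)
        elif chunk.choices:                   # ordinary delta chunk
            print(chunk.choices[0].delta.content or "", end="")

Note that the first delta of an OpenAI-style stream may carry content=None; __parseChunk would pass that through as msg, which is why the probe guards with `or ""`.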
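
2. VOLCENGINE_LLM.chat is an async generator, but self.client is the synchronous Ark client, so each `for chunk in stream` step blocks the event loop while waiting on the network. One sketch of a fix that needs only the standard library (Python 3.9+; the helper name iter_stream_async is hypothetical, not part of the patch):

    import asyncio

    async def iter_stream_async(stream):
        """Yield chunks from a blocking SDK stream without stalling the event loop."""
        it = iter(stream)
        sentinel = object()
        while True:
            # next() runs in a worker thread; awaiting it lets other coroutines run
            chunk = await asyncio.to_thread(next, it, sentinel)
            if chunk is sentinel:
                break
            yield chunk

VOLCENGINE_LLM.chat could then use `async for chunk in iter_stream_async(stream):` in place of the bare for loop.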
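
3. The patch moves the WebSocket failure frames from type "close" with codes 201-204 to type "error" with codes 501-505, which also frees code 201 for the in-band session-reset frame emitted by the LLM classes. If the except chain in streaming_chat keeps growing with each new backend, a table-driven handler is one way to keep it flat; a sketch, where the mapping mirrors the patch but WS_ERRORS and handle_session are hypothetical:

    WS_ERRORS = {
        AsrResultNoneError:          (501, "asr result is empty"),
        AbnormalLLMFrame:            (502, None),  # None: fall back to str(e)
        SideNoiseError:              (503, None),
        SessionNotFoundError:        (504, "session does not exist"),
        UnknownVolcEngineModelError: (505, "unknown VolcEngine model"),
    }

    try:
        await handle_session(ws)  # hypothetical: the body of streaming_chat
    except tuple(WS_ERRORS) as e:
        code, msg = WS_ERRORS[type(e)]
        await ws.send_text(json.dumps({"type": "error", "code": code, "msg": msg or str(e)}, ensure_ascii=False))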