From 5a47440e0f211e41f8bc1f465c4b84bd6464d6d4 Mon Sep 17 00:00:00 2001 From: killua <1223086337@qq.com> Date: Thu, 20 Jun 2024 10:34:26 +0800 Subject: [PATCH] =?UTF-8?q?update:=E4=BC=98=E5=8C=96=E4=BA=86websocket?= =?UTF-8?q?=E5=AF=B9=E6=95=B0=E6=8D=AE=E5=BA=93=E8=BF=9E=E6=8E=A5=E7=9A=84?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=E4=B8=8E=E9=87=8A=E6=94=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/abstract.py | 2 +- app/concrete.py | 50 ++++++++++++++++----------------- app/exception.py | 5 ++++ app/model.py | 13 +++++++-- config.py | 6 ++-- main.py | 64 +++++++++++++++++++++++++++++-------------- requirements.txt | 3 +- utils/xf_asr_utils.py | 2 +- 8 files changed, 91 insertions(+), 54 deletions(-) diff --git a/app/abstract.py b/app/abstract.py index e622f70..13947b7 100644 --- a/app/abstract.py +++ b/app/abstract.py @@ -10,7 +10,7 @@ class ASR(ABC): class LLM(ABC): @abstractmethod - def chat(self, assistant, prompt, db): + def chat(self, assistant, prompt): pass class TTS(ABC): diff --git a/app/concrete.py b/app/concrete.py index 8263b38..3c19988 100644 --- a/app/concrete.py +++ b/app/concrete.py @@ -56,20 +56,18 @@ class XF_ASR(ASR): await self.websocket.send(make_continue_frame(audio_data)) elif self.status == LAST_FRAME: #发送最后一帧 await self.websocket.send(make_last_frame(audio_data)) + logger.debug("发送完毕") self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv())) if self.current_message == "": raise AsrResultNoneError() - if "进入沉默模式" in self.current_message: + if self.current_message in ["啊"]: + raise SideNoiseError() + if "闭嘴" in self.current_message: self.is_slience = True asyncio.create_task(self.websocket.close()) raise EnterSlienceMode() - if "退出沉默模式" in self.current_message: - self.is_slience = False - self.current_message = "已退出沉默模式" - if self.is_slience: - asyncio.create_task(self.websocket.close()) - raise SlienceMode() asyncio.create_task(self.websocket.close()) + logger.debug(f"ASR结果: {self.current_message}") return [{"text":self.current_message, "audio":self.audio}] current_time = asyncio.get_event_loop().time() @@ -86,7 +84,7 @@ class MINIMAX_LLM(LLM): def __init__(self): self.token = 0 - async def chat(self, assistant, prompt, db): + async def chat(self, assistant, prompt): llm_info = json.loads(assistant.llm_info) messages = json.loads(assistant.messages) messages.append({'role':'user','content':prompt}) @@ -111,12 +109,10 @@ class MINIMAX_LLM(LLM): yield msg_frame except LLMResponseEnd: msg_frame = {"is_end":True,"code":200,"msg":""} + assistant.token = self.token if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session msg_frame['code'] = '201' - as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first() - as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}]) - db.commit() - assistant.messages = as_query.messages + assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}]) yield msg_frame @@ -147,7 +143,7 @@ class VOLCENGINE_LLM(LLM): self.token = 0 self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY) - async def chat(self, assistant, prompt, db): + async def chat(self, assistant, prompt): llm_info = json.loads(assistant.llm_info) model = self.__get_model(llm_info) messages = json.loads(assistant.messages) @@ -167,18 +163,20 @@ class VOLCENGINE_LLM(LLM): msg_frame = {"is_end":False,"code":200,"msg":chunk_msg} yield msg_frame except LLMResponseEnd: - msg_frame = {"is_end":True,"code":20-0,"msg":""} + msg_frame = {"is_end":True,"code":200,"msg":""} + assistant.token = self.token if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session msg_frame['code'] = '201' - as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first() - as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}]) - db.commit() - assistant.messages = as_query.messages + assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}]) yield msg_frame def __get_model(self, llm_info): if llm_info['model'] == 'doubao-4k-lite': return Config.VOLCENGINE_LLM.DOUBAO_LITE_4k + elif llm_info['model'] == 'doubao-32k-lite': + return Config.VOLCENGINE_LLM.DOUBAO_LITE_32k + elif llm_info['model'] == 'doubao-32k-pro': + return Config.VOLCENGINE_LLM.DOUBAO_PRO_32k else: raise UnknownVolcEngineModelError() @@ -322,31 +320,31 @@ class Agent(): self.recorder = Recorder(user_id) # 对用户输入的音频进行预处理 - def user_audio_process(self, audio, db): + def user_audio_process(self, audio): return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder) # 进行流式语音识别 - async def stream_recognize(self, chunk, db): + async def stream_recognize(self, chunk): return await self.asr.stream_recognize(chunk) # 进行Prompt加工 - def prompt_process(self, asr_results, db): + def prompt_process(self, asr_results): return self.prompt_service_chain.prompt_process(asr_results) # 进行大模型调用 - async def chat(self, assistant ,prompt, db): - return self.llm.chat(assistant, prompt, db) + async def chat(self, assistant ,prompt): + return self.llm.chat(assistant, prompt) # 对大模型的返回进行处理 - def llm_msg_process(self, llm_chunk, db): + def llm_msg_process(self, llm_chunk): return self.llm_msg_service_chain.llm_msg_process(llm_chunk) # 进行TTS合成 - def synthetize(self, assistant, text, db): + def synthetize(self, assistant, text): return self.tts.synthetize(assistant, text) # 对合成后的音频进行处理 - def tts_audio_process(self, audio, db): + def tts_audio_process(self, audio): return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder) # 编码 diff --git a/app/exception.py b/app/exception.py index f5fdbcd..b8f7133 100644 --- a/app/exception.py +++ b/app/exception.py @@ -46,6 +46,11 @@ class SessionNotFoundError(Exception): super().__init__(message) self.message = message +class LlmResultNoneError(Exception): + def __init__(self, message="LLM Result is None!"): + super().__init__(message) + self.message = message + # 大模型返回结束(非异常) class LLMResponseEnd(Exception): def __init__(self, message="LLM Response End!"): diff --git a/app/model.py b/app/model.py index 84336d0..188c3c5 100644 --- a/app/model.py +++ b/app/model.py @@ -1,8 +1,10 @@ from sqlalchemy import create_engine, Column, Integer, String, CHAR from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker, Session +from contextlib import contextmanager from config import Config + engine = create_engine(Config.SQLITE_URL, connect_args={"check_same_thread": False}) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() @@ -30,10 +32,17 @@ class Assistant(Base): Base.metadata.create_all(bind=engine) -def get_db(): +@contextmanager +def get_db_context(): db = SessionLocal() try: yield db finally: db.close() - \ No newline at end of file + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() \ No newline at end of file diff --git a/config.py b/config.py index a218b77..a31ce8d 100644 --- a/config.py +++ b/config.py @@ -6,7 +6,7 @@ class Config: LOG_LEVEL = "DEBUG" class UVICORN: HOST = '0.0.0.0' - PORT = 7878 + PORT = 8001 class XF_ASR: APP_ID = "f1c121c1" #讯飞语音识别APP_ID API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0" #讯飞语音识别API_SECRET @@ -20,4 +20,6 @@ class Config: URL = "https://api.minimax.chat/v1/text/chatcompletion_v2" class VOLCENGINE_LLM: API_KEY = "a1bf964c-5c12-4d2b-ad97-85893e14d55d" - DOUBAO_LITE_4k = "ep-20240612075552-5c7tk" \ No newline at end of file + DOUBAO_LITE_4k = "ep-20240612075552-5c7tk" + DOUBAO_LITE_32k = "ep-20240618130753-q85dm" + DOUBAO_PRO_32k = "ep-20240618145315-pm2c6" \ No newline at end of file diff --git a/main.py b/main.py index 4962d82..dd97d21 100644 --- a/main.py +++ b/main.py @@ -2,20 +2,29 @@ from fastapi import FastAPI, Depends, WebSocket, HTTPException from fastapi.middleware.cors import CORSMiddleware from config import Config from app.concrete import Agent -from app.model import Assistant, User, get_db +from app.model import Assistant, User, get_db, get_db_context from app.schemas import * from app.dependency import get_logger from app.exception import * +import asyncio import uvicorn import uuid import json +import time + # 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ -def update_messages(messages, kid_text,llm_text): - messages = json.loads(messages) +def update_messages(assistant, kid_text, llm_text): + messages = json.loads(assistant.messages) + if not kid_text: + raise AsrResultNoneError() + if not llm_text: + raise LlmResultNoneError() messages.append({"role":"user","content":kid_text}) messages.append({"role":"assistant","content":llm_text}) - return json.dumps(messages,ensure_ascii=False) + with get_db_context() as db: + db.query(Assistant).filter(Assistant.id == assistant.id).update({"messages":json.dumps(messages,ensure_ascii=False),"token":assistant.token}) + db.commit() # -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- # 引入logger对象 @@ -177,7 +186,7 @@ async def update_user(id: str,request: update_user_request,db=Depends(get_db)): # 流式聊天websokct接口 ------------------------------------------------------------------------------------------------------------------------------------------------------------ @app.websocket("/api/chat/streaming/temporary") -async def streaming_chat(ws: WebSocket,db=Depends(get_db)): +async def streaming_chat(ws: WebSocket): await ws.accept() logger.debug("WebSocket连接成功") try: @@ -187,57 +196,70 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)): llm_text = "" logger.debug("开始进行ASR识别") while len(asr_results)==0: - chunk = json.loads(await ws.receive_text()) + chunk = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=2)) if assistant is None: - assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first() + with get_db_context() as db: #使用with语句获取数据库连接,自动关闭数据库连接 + assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first() if assistant is None: raise SessionNotFoundError() user_info = json.loads(assistant.user_info) if not agent: agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type']) agent.init_recorder(assistant.user_id) - chunk["audio"] = agent.user_audio_process(chunk["audio"], db) - asr_results = await agent.stream_recognize(chunk, db) + chunk["audio"] = agent.user_audio_process(chunk["audio"]) + asr_results = await agent.stream_recognize(chunk) kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果 - logger.debug(f"ASR识别成功,识别结果为:{kid_text}") - prompt = agent.prompt_process(asr_results, db) + prompt = agent.prompt_process(asr_results) agent.recorder.input_text = prompt logger.debug("开始调用大模型") - llm_frames = await agent.chat(assistant, prompt, db) + llm_frames = await agent.chat(assistant, prompt) async for llm_frame in llm_frames: - resp_msgs = agent.llm_msg_process(llm_frame, db) + resp_msgs = agent.llm_msg_process(llm_frame) for resp_msg in resp_msgs: llm_text += resp_msg - tts_audio = agent.synthetize(assistant, resp_msg, db) - agent.tts_audio_process(tts_audio, db) + tts_start_time = time.time() + tts_audio = agent.synthetize(assistant, resp_msg) + tts_end_time = time.time() + logger.debug(f"TTS生成音频耗时:{tts_end_time-tts_start_time}s") + agent.tts_audio_process(tts_audio) await ws.send_bytes(agent.encode(resp_msg, tts_audio)) logger.debug(f'websocket返回:{resp_msg}') logger.debug(f"大模型返回结束,返回结果为:{llm_text}") await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False)) logger.debug("结束帧发送完毕") - assistant.messages = update_messages(assistant.messages, kid_text ,llm_text) - db.commit() + update_messages(assistant, kid_text ,llm_text) logger.debug("聊天更新成功") agent.recorder.output_text = llm_text agent.save() logger.debug("音频保存成功") except EnterSlienceMode: - tts_audio = agent.synthetize(assistant, "已进入沉默模式", db) + tts_audio = agent.synthetize(assistant, "已进入沉默模式") await ws.send_bytes(agent.encode("已进入沉默模式", tts_audio)) await ws.send_text(json.dumps({"type":"info","code":201,"msg":"进入沉默模式"}, ensure_ascii=False)) - except SlienceMode: - await ws.send_text(json.dumps({"type":"info","code":201,"msg":"处于沉默模式"}, ensure_ascii=False)) + logger.debug("进入沉默模式") except AsrResultNoneError: await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr结果为空"}, ensure_ascii=False)) + logger.error("ASR结果为空") except AbnormalLLMFrame as e: await ws.send_text(json.dumps({"type":"error","code":502,"msg":str(e)}, ensure_ascii=False)) + logger.error(f"LLM模型返回异常,错误信息:{str(e)}") except SideNoiseError as e: - await ws.send_text(json.dumps({"type":"error","code":503,"msg":str(e)}, ensure_ascii=False)) + await ws.send_text(json.dumps({"type":"error","code":503,"msg":"检测为噪声"}, ensure_ascii=False)) + logger.debug("检测为噪声") except SessionNotFoundError: await ws.send_text(json.dumps({"type":"error","code":504,"msg":"session不存在"}, ensure_ascii=False)) + logger.error("session不存在") except UnknownVolcEngineModelError: await ws.send_text(json.dumps({"type":"error","code":505,"msg":"未知的火山引擎模型"}, ensure_ascii=False)) + logger.error("未知的火山引擎模型") + except LlmResultNoneError: + await ws.send_text(json.dumps({"type":"error","code":506,"msg":"llm结果返回为空"}, ensure_ascii=False)) + logger.error("LLM结果返回为空") + except asyncio.TimeoutError: + await ws.send_text(json.dumps({"type":"error","code":507,"msg":"接收超时"}, ensure_ascii=False)) + logger.error("接收超时") logger.debug("WebSocket连接断开") + logger.debug("") await ws.close() # -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/requirements.txt b/requirements.txt index 24bbb3c..e6c81a2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ jieba cn2an numba librosa -aiohttp \ No newline at end of file +aiohttp +'volcengine-python-sdk[ark]' \ No newline at end of file diff --git a/utils/xf_asr_utils.py b/utils/xf_asr_utils.py index 3b1b071..9581ad0 100644 --- a/utils/xf_asr_utils.py +++ b/utils/xf_asr_utils.py @@ -38,7 +38,7 @@ def generate_xf_asr_url(): def make_first_frame(buf): - first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vad_eos":10000}, + first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vinfo":1,"vad_eos":1000}, "data":{"status":0,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}} return json.dumps(first_frame)