From 286e83e025e57b279f72d7b88308945809be7364 Mon Sep 17 00:00:00 2001 From: killua4396 <1223086337@qq.com> Date: Tue, 11 Jun 2024 15:11:23 +0800 Subject: [PATCH] =?UTF-8?q?update:=20=E7=8E=B0=E5=9C=A8=E4=BC=9A=E6=A0=B9?= =?UTF-8?q?=E6=8D=AE=E6=AF=8F=E6=AC=A1=E8=AF=B7=E6=B1=82=E7=9A=84=E4=BF=A1?= =?UTF-8?q?=E6=81=AF=E7=94=9F=E6=88=90=E7=9B=B8=E5=BA=94=E7=9A=84Agent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/exception.py | 6 ++++++ app/schemas.py | 1 + main.py | 20 +++++++++++++++++--- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/app/exception.py b/app/exception.py index ea02356..32f0254 100644 --- a/app/exception.py +++ b/app/exception.py @@ -33,6 +33,12 @@ class SideNoiseError(Exception): def __init__(self, message="Side Noise!"): super().__init__(message) self.message = message + +# Session不存在异常 +class SessionNotFoundError(Exception): + def __init__(self, message="Session Not Found!"): + super().__init__(message) + self.message = message # 大模型返回结束(非异常) class LLMResponseEnd(Exception): diff --git a/app/schemas.py b/app/schemas.py index f8ffe81..ffb2aea 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -30,6 +30,7 @@ class update_assistant_system_prompt_request(BaseModel): system_prompt:str class update_assistant_deatil_params_request(BaseModel): + platform:str model :str temperature :float speaker_id:int diff --git a/main.py b/main.py index 54cd84f..b9bbee9 100644 --- a/main.py +++ b/main.py @@ -5,7 +5,7 @@ from app.concrete import Agent from app.model import Assistant, User, get_db from app.schemas import * from app.dependency import get_logger -from app.exception import AsrResultNoneError, AbnormalLLMFrame, SideNoiseError +from app.exception import AsrResultNoneError, AbnormalLLMFrame, SideNoiseError, SessionNotFoundError import uvicorn import uuid import json @@ -108,12 +108,15 @@ async def update_assistant_deatil_params(id: str,request: update_assistant_deati if assistant: llm_info = json.loads(assistant.llm_info) tts_info = json.loads(assistant.tts_info) + user_info = json.loads(assistant.user_info) llm_info['model'] = request.model llm_info['temperature'] = request.temperature tts_info['speaker_id'] = request.speaker_id tts_info['length_scale'] = request.length_scale + user_info['llm_type'] = request.platform assistant.llm_info = json.dumps(llm_info, ensure_ascii=False) assistant.tts_info = json.dumps(tts_info, ensure_ascii=False) + assistant.user_info = json.dumps(user_info, ensure_ascii=False) db.commit() return {"code":200,"msg":"success","data":{}} else: @@ -152,6 +155,12 @@ async def get_user(id: str,db=Depends(get_db)): else: raise HTTPException(status_code=404, detail="user not found") +# 获取所有用户 +@app.get("/api/users") +async def get_all_users(db=Depends(get_db)): + users = db.query(User).all() + return {"code":200,"msg":"success","data":users} + # 更新用户 @app.put("/api/users/{id}") async def update_user(id: str,request: update_user_request,db=Depends(get_db)): @@ -181,9 +190,12 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)): chunk = json.loads(await ws.receive_text()) if assistant is None: assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first() - agent.init_recorder(assistant.user_id) + if assistant is None: + raise SessionNotFoundError() + user_info = json.loads(assistant.user_info) if not agent: - agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS) + agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type']) + agent.init_recorder(assistant.user_id) chunk["audio"] = agent.user_audio_process(chunk["audio"], db) asr_results = await agent.stream_recognize(chunk, db) kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果 @@ -215,6 +227,8 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)): await ws.send_text(json.dumps({"type":"close","code":202,"msg":str(e)}, ensure_ascii=False)) except SideNoiseError as e: await ws.send_text(json.dumps({"type":"close","code":203,"msg":str(e)}, ensure_ascii=False)) + except SessionNotFoundError: + await ws.send_text(json.dumps({"type":"close","code":204,"msg":"session未找到"}, ensure_ascii=False)) logger.debug("WebSocket连接断开") await ws.close() # --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------