update: 现在会根据每次请求的信息生成相应的Agent

This commit is contained in:
killua4396 2024-06-11 15:11:23 +08:00
parent d885684533
commit 286e83e025
3 changed files with 24 additions and 3 deletions

View File

@ -33,6 +33,12 @@ class SideNoiseError(Exception):
def __init__(self, message="Side Noise!"):
super().__init__(message)
self.message = message
# Session不存在异常
class SessionNotFoundError(Exception):
def __init__(self, message="Session Not Found!"):
super().__init__(message)
self.message = message
# 大模型返回结束(非异常)
class LLMResponseEnd(Exception):

View File

@ -30,6 +30,7 @@ class update_assistant_system_prompt_request(BaseModel):
system_prompt:str
class update_assistant_deatil_params_request(BaseModel):
platform:str
model :str
temperature :float
speaker_id:int

20
main.py
View File

@ -5,7 +5,7 @@ from app.concrete import Agent
from app.model import Assistant, User, get_db
from app.schemas import *
from app.dependency import get_logger
from app.exception import AsrResultNoneError, AbnormalLLMFrame, SideNoiseError
from app.exception import AsrResultNoneError, AbnormalLLMFrame, SideNoiseError, SessionNotFoundError
import uvicorn
import uuid
import json
@ -108,12 +108,15 @@ async def update_assistant_deatil_params(id: str,request: update_assistant_deati
if assistant:
llm_info = json.loads(assistant.llm_info)
tts_info = json.loads(assistant.tts_info)
user_info = json.loads(assistant.user_info)
llm_info['model'] = request.model
llm_info['temperature'] = request.temperature
tts_info['speaker_id'] = request.speaker_id
tts_info['length_scale'] = request.length_scale
user_info['llm_type'] = request.platform
assistant.llm_info = json.dumps(llm_info, ensure_ascii=False)
assistant.tts_info = json.dumps(tts_info, ensure_ascii=False)
assistant.user_info = json.dumps(user_info, ensure_ascii=False)
db.commit()
return {"code":200,"msg":"success","data":{}}
else:
@ -152,6 +155,12 @@ async def get_user(id: str,db=Depends(get_db)):
else:
raise HTTPException(status_code=404, detail="user not found")
# 获取所有用户
@app.get("/api/users")
async def get_all_users(db=Depends(get_db)):
users = db.query(User).all()
return {"code":200,"msg":"success","data":users}
# 更新用户
@app.put("/api/users/{id}")
async def update_user(id: str,request: update_user_request,db=Depends(get_db)):
@ -181,9 +190,12 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
chunk = json.loads(await ws.receive_text())
if assistant is None:
assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
agent.init_recorder(assistant.user_id)
if assistant is None:
raise SessionNotFoundError()
user_info = json.loads(assistant.user_info)
if not agent:
agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type'])
agent.init_recorder(assistant.user_id)
chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
asr_results = await agent.stream_recognize(chunk, db)
kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果
@ -215,6 +227,8 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
await ws.send_text(json.dumps({"type":"close","code":202,"msg":str(e)}, ensure_ascii=False))
except SideNoiseError as e:
await ws.send_text(json.dumps({"type":"close","code":203,"msg":str(e)}, ensure_ascii=False))
except SessionNotFoundError:
await ws.send_text(json.dumps({"type":"close","code":204,"msg":"session未找到"}, ensure_ascii=False))
logger.debug("WebSocket连接断开")
await ws.close()
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------