forked from killua/TakwayDisplayPlatform
update: an Agent is now generated from the information carried by each request (现在会根据每次请求的信息生成相应的Agent)
This commit is contained in:
parent d885684533
commit 286e83e025
@@ -34,6 +34,12 @@ class SideNoiseError(Exception):
         super().__init__(message)
         self.message = message
 
+# Session-not-found exception
+class SessionNotFoundError(Exception):
+    def __init__(self, message="Session Not Found!"):
+        super().__init__(message)
+        self.message = message
+
 # LLM response finished (not an error)
 class LLMResponseEnd(Exception):
     def __init__(self, message="LLM Response End!"):
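Note: the new SessionNotFoundError follows the same pattern as the surrounding exception classes (store the message and pass it to super().__init__). A minimal sketch of how it is intended to be raised and caught, with a hypothetical in-memory session table standing in for the database lookup:

# Sketch only: _SESSIONS and lookup_session are invented for illustration;
# the exception class itself is copied from the diff above.
class SessionNotFoundError(Exception):
    def __init__(self, message="Session Not Found!"):
        super().__init__(message)
        self.message = message

_SESSIONS = {"abc123": {"user_id": 1}}  # hypothetical session store

def lookup_session(session_id: str) -> dict:
    if session_id not in _SESSIONS:
        raise SessionNotFoundError()
    return _SESSIONS[session_id]

try:
    lookup_session("missing-id")
except SessionNotFoundError as e:
    print(e.message)  # -> Session Not Found!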
@@ -30,6 +30,7 @@ class update_assistant_system_prompt_request(BaseModel):
     system_prompt:str
 
 class update_assistant_deatil_params_request(BaseModel):
+    platform:str
     model :str
     temperature :float
     speaker_id:int
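Note: the only change to this request model is the new platform field, which the handler in main.py stores as user_info['llm_type']. A sketch of the extended model, assuming Pydantic and using made-up example values (the length_scale field referenced in main.py is not visible in this hunk and is omitted):

from pydantic import BaseModel

class update_assistant_deatil_params_request(BaseModel):
    platform: str      # new field: which LLM platform this assistant should use
    model: str
    temperature: float
    speaker_id: int

# Example payload; the concrete values are illustrative only.
req = update_assistant_deatil_params_request(
    platform="example-platform", model="example-model", temperature=0.7, speaker_id=0)
print(req.platform)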
main.py (20 changed lines)
@@ -5,7 +5,7 @@ from app.concrete import Agent
 from app.model import Assistant, User, get_db
 from app.schemas import *
 from app.dependency import get_logger
-from app.exception import AsrResultNoneError, AbnormalLLMFrame, SideNoiseError
+from app.exception import AsrResultNoneError, AbnormalLLMFrame, SideNoiseError, SessionNotFoundError
 import uvicorn
 import uuid
 import json
@@ -108,12 +108,15 @@ async def update_assistant_deatil_params(id: str,request: update_assistant_deati
     if assistant:
         llm_info = json.loads(assistant.llm_info)
         tts_info = json.loads(assistant.tts_info)
+        user_info = json.loads(assistant.user_info)
         llm_info['model'] = request.model
         llm_info['temperature'] = request.temperature
         tts_info['speaker_id'] = request.speaker_id
         tts_info['length_scale'] = request.length_scale
+        user_info['llm_type'] = request.platform
         assistant.llm_info = json.dumps(llm_info, ensure_ascii=False)
         assistant.tts_info = json.dumps(tts_info, ensure_ascii=False)
+        assistant.user_info = json.dumps(user_info, ensure_ascii=False)
         db.commit()
         return {"code":200,"msg":"success","data":{}}
     else:
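Note: the assistant's llm_info, tts_info and user_info columns are stored as JSON strings, so the handler follows a loads → mutate → dumps pattern; ensure_ascii=False keeps any non-ASCII text readable in the stored JSON. A minimal sketch of the same round-trip on a stand-in object (the FakeAssistant class and its sample values are invented):

import json

class FakeAssistant:
    user_info = '{"llm_type": "old-platform", "asr_type": "example-asr", "tts_type": "example-tts"}'

assistant = FakeAssistant()
user_info = json.loads(assistant.user_info)          # JSON string -> dict
user_info['llm_type'] = "new-platform"               # request.platform in the real handler
assistant.user_info = json.dumps(user_info, ensure_ascii=False)  # dict -> JSON string
print(assistant.user_info)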
@@ -152,6 +155,12 @@ async def get_user(id: str,db=Depends(get_db)):
     else:
         raise HTTPException(status_code=404, detail="user not found")
 
+# Get all users
+@app.get("/api/users")
+async def get_all_users(db=Depends(get_db)):
+    users = db.query(User).all()
+    return {"code":200,"msg":"success","data":users}
+
 # Update user
 @app.put("/api/users/{id}")
 async def update_user(id: str,request: update_user_request,db=Depends(get_db)):
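Note: the new GET /api/users route returns every user wrapped in the usual {"code", "msg", "data"} envelope. A self-contained sketch of how the endpoint shape can be exercised, assuming FastAPI is installed and with an in-memory list standing in for db.query(User).all():

from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()
FAKE_USERS = [{"id": "u1"}, {"id": "u2"}]  # stand-in for db.query(User).all()

@app.get("/api/users")
async def get_all_users():
    return {"code": 200, "msg": "success", "data": FAKE_USERS}

client = TestClient(app)
resp = client.get("/api/users")
print(resp.json())  # {'code': 200, 'msg': 'success', 'data': [...]}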
@@ -181,9 +190,12 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
             chunk = json.loads(await ws.receive_text())
             if assistant is None:
                 assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
-                agent.init_recorder(assistant.user_id)
+                if assistant is None:
+                    raise SessionNotFoundError()
+                user_info = json.loads(assistant.user_info)
             if not agent:
-                agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
+                agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type'])
+                agent.init_recorder(assistant.user_id)
             chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
             asr_results = await agent.stream_recognize(chunk, db)
             kid_text = asr_results[0]['text'] # asr_results[0] is, by default, the child's (primary user's) ASR result
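Note: this is the core of the commit. Instead of building the Agent from the global Config, the handler now reads asr_type, llm_type and tts_type from the assistant's user_info JSON, so each session gets an Agent matching its own settings, and a missing session raises SessionNotFoundError. A sketch of that selection step with a stand-in Agent class and made-up user_info values (the real Agent comes from app.concrete):

import json

class Agent:  # stand-in that only mirrors the constructor keywords used above
    def __init__(self, asr_type, llm_type, tts_type):
        self.asr_type, self.llm_type, self.tts_type = asr_type, llm_type, tts_type

assistant_user_info = '{"asr_type": "example-asr", "llm_type": "example-llm", "tts_type": "example-tts"}'

user_info = json.loads(assistant_user_info)
agent = Agent(asr_type=user_info['asr_type'],
              llm_type=user_info['llm_type'],
              tts_type=user_info['tts_type'])
print(agent.llm_type)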
@@ -215,6 +227,8 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
         await ws.send_text(json.dumps({"type":"close","code":202,"msg":str(e)}, ensure_ascii=False))
     except SideNoiseError as e:
         await ws.send_text(json.dumps({"type":"close","code":203,"msg":str(e)}, ensure_ascii=False))
+    except SessionNotFoundError:
+        await ws.send_text(json.dumps({"type":"close","code":204,"msg":"session未找到"}, ensure_ascii=False))
     logger.debug("WebSocket连接断开")
     await ws.close()
 # --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
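Note: each handled failure is reported to the client as a JSON "close" frame before the socket is closed; SideNoiseError maps to code 203 and the new SessionNotFoundError to 204 (the except clause for code 202 sits above this hunk and is not shown). A small sketch of that frame format, with the helper function being purely illustrative:

import json

def build_close_frame(code: int, msg: str) -> str:
    # Mirrors the payload sent via ws.send_text in the handlers above.
    return json.dumps({"type": "close", "code": code, "msg": msg}, ensure_ascii=False)

print(build_close_frame(204, "session未找到"))  # "session not found"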