TakwayDisplayPlatform/main.py

253 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import FastAPI, Depends, WebSocket, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from config import Config
from app.concrete import Agent
from app.model import Assistant, User, get_db
from app.schemas import *
from app.dependency import get_logger
from app.exception import *
import uvicorn
import uuid
import json
# 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------
def update_messages(messages, kid_text, llm_text):
    """Append one user/assistant exchange to a JSON-encoded chat history.

    Args:
        messages: JSON text that decodes to a list of message dicts.
        kid_text: content of the new ``"user"`` turn.
        llm_text: content of the new ``"assistant"`` turn.

    Returns:
        The extended history re-encoded as JSON text, keeping non-ASCII
        characters unescaped (``ensure_ascii=False``).
    """
    history = json.loads(messages)
    history.extend(
        [
            {"role": "user", "content": kid_text},
            {"role": "assistant", "content": llm_text},
        ]
    )
    return json.dumps(history, ensure_ascii=False)
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Module-level logger obtained from the project's get_logger() helper (app.dependency).
logger = get_logger()
# The FastAPI application; every route in this file is registered on it.
app: FastAPI = FastAPI()
# 增删查改 assistant----------------------------------------------------------------------------------------------------------------------------------------------------------------
@app.post("/api/assistants")
async def create_assistant(request: create_assistant_request, db=Depends(get_db)):
    """Create an assistant and return its generated id.

    The chat history is seeded with a single system message built from
    ``request.system_prompt``, and the token counter starts at 0.

    Response body: ``{"code": 200, "msg": "success", "data": {"id": <uuid4 str>}}``.
    """
    new_id = str(uuid.uuid4())
    seed_history = [{"role": "system", "content": request.system_prompt}]
    record = Assistant(
        id=new_id,
        user_id=request.user_id,
        name=request.name,
        system_prompt=request.system_prompt,
        messages=json.dumps(seed_history, ensure_ascii=False),
        user_info=request.user_info,
        llm_info=request.llm_info,
        tts_info=request.tts_info,
        token=0,
    )
    db.add(record)
    db.commit()
    return {"code": 200, "msg": "success", "data": {"id": new_id}}
@app.delete("/api/assistants/{id}")
async def delete_assistant(id: str, db=Depends(get_db)):
    """Delete the assistant with the given id.

    An unknown id is reported in the response body as ``"code": 404``. No
    HTTPException is raised, unlike the /api/users endpoints.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, 'msg': "assistant not found", "data": {}}
    db.delete(target)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
@app.put("/api/assistants/{id}")
async def update_assistant(id: str, request: update_assistant_request, db=Depends(get_db)):
    """Overwrite an assistant's editable fields with the values in ``request``.

    Replaces name, system_prompt, messages, user_info, llm_info and tts_info
    wholesale; this is not a partial merge. An unknown id is reported in the
    body as ``"code": 404``.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, 'msg': "assistant not found", "data": {}}
    # Same fields, in the same order, as a field-by-field copy.
    for field in ("name", "system_prompt", "messages", "user_info", "llm_info", "tts_info"):
        setattr(target, field, getattr(request, field))
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
@app.get("/api/assistants/{id}")
async def get_assistant(id: str, db=Depends(get_db)):
    """Return the full assistant record for ``id`` (``"code": 404`` in the body if absent)."""
    found = db.query(Assistant).filter(Assistant.id == id).first()
    return (
        {"code": 200, "msg": "success", "data": found}
        if found
        else {"code": 404, 'msg': "assistant not found", "data": {}}
    )
@app.get("/api/assistants")
async def get_all_assistants_name_id(db=Depends(get_db)):
    """List every assistant as an ``{"id", "name"}`` pair; only those two columns are queried."""
    rows = db.query(Assistant.id, Assistant.name).all()
    summaries = [{"id": row_id, "name": row_name} for row_id, row_name in rows]
    return {"code": 200, "msg": "success", "data": summaries}
@app.post("/api/assistants/{id}/reset_msg")
async def reset_assistant_msg(id: str, db=Depends(get_db)):
    """Reset an assistant's chat history so it holds only its system prompt.

    An unknown id is reported in the body as ``"code": 404``.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, 'msg': "assistant not found", "data": {}}
    fresh_history = [{"role": "system", "content": target.system_prompt}]
    target.messages = json.dumps(fresh_history, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
@app.put("/api/assistants/{id}/system_prompt")
async def update_assistant_system_prompt(id: str, request: update_assistant_system_prompt_request, db=Depends(get_db)):
    """Set a new system prompt and reset the chat history to contain only it.

    All earlier conversation turns are discarded. An unknown id is reported
    in the body as ``"code": 404``.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, 'msg': "assistant not found", "data": {}}
    target.system_prompt = request.system_prompt
    target.messages = json.dumps(
        [{"role": "system", "content": target.system_prompt}], ensure_ascii=False
    )
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
@app.put("/api/assistants/{id}/deatil_params")
async def update_assistant_deatil_params(id: str, request: update_assistant_deatil_params_request, db=Depends(get_db)):
    """Patch individual settings in an assistant's JSON-text config columns.

    - ``llm_info``: sets ``model`` and ``temperature``
    - ``tts_info``: sets ``speaker_id`` and ``length_scale``
    - ``user_info``: sets ``llm_type`` from ``request.platform``

    All other keys in each JSON object are preserved. The misspelling
    "deatil" in the route and function name is part of the public API and is
    intentionally kept. An unknown id is reported in the body as
    ``"code": 404``.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, 'msg': "assistant not found", "data": {}}
    llm_cfg, tts_cfg, user_cfg = (
        json.loads(target.llm_info),
        json.loads(target.tts_info),
        json.loads(target.user_info),
    )
    llm_cfg["model"], llm_cfg["temperature"] = request.model, request.temperature
    tts_cfg["speaker_id"], tts_cfg["length_scale"] = request.speaker_id, request.length_scale
    user_cfg["llm_type"] = request.platform
    target.llm_info = json.dumps(llm_cfg, ensure_ascii=False)
    target.tts_info = json.dumps(tts_cfg, ensure_ascii=False)
    target.user_info = json.dumps(user_cfg, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# 用户增删改查接口 ----------------------------------------------------------------------------------------------------------------------------------------------------------------
@app.post("/api/users")
async def create_user(request: create_user_request, db=Depends(get_db)):
    """Create a user and return its generated id.

    NOTE(review): ``request.password`` is passed to ``User`` unchanged, so it
    is stored exactly as received (plaintext unless the model hashes it).
    Confirm this is intended.
    """
    new_id = str(uuid.uuid4())
    db.add(User(id=new_id, name=request.name, email=request.email, password=request.password))
    db.commit()
    return {"code": 200, "msg": "success", "data": {"id": new_id}}
@app.delete("/api/users/{id}")
async def delete_user(id: str, db=Depends(get_db)):
    """Delete the user with the given id.

    Raises:
        HTTPException: 404 if no user has that id.
    """
    target = db.query(User).filter(User.id == id).first()
    if not target:
        raise HTTPException(status_code=404, detail="user not found")
    db.delete(target)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
@app.get("/api/users/{id}")
async def get_user(id: str, db=Depends(get_db)):
    """Return the user record for ``id``.

    NOTE(review): the whole ORM object is returned. Because ``User`` is
    created with a ``password`` field, the response presumably includes it.
    Confirm whether that should be exposed.

    Raises:
        HTTPException: 404 if no user has that id.
    """
    found = db.query(User).filter(User.id == id).first()
    if not found:
        raise HTTPException(status_code=404, detail="user not found")
    return {"code": 200, "msg": "success", "data": found}
@app.get("/api/users")
async def get_all_users(db=Depends(get_db)):
    """Return every user record as full ORM objects.

    NOTE(review): these presumably include the password field. Confirm
    whether that should be exposed.
    """
    return {"code": 200, "msg": "success", "data": db.query(User).all()}
@app.put("/api/users/{id}")
async def update_user(id: str, request: update_user_request, db=Depends(get_db)):
    """Overwrite a user's name, email and password with the values in ``request``.

    Raises:
        HTTPException: 404 if no user has that id.
    """
    target = db.query(User).filter(User.id == id).first()
    if not target:
        raise HTTPException(status_code=404, detail="user not found")
    target.name, target.email, target.password = request.name, request.email, request.password
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# 流式聊天websocket接口 ------------------------------------------------------------------------------------------------------------------------------------------------------------
@app.websocket("/api/chat/streaming/temporary")
async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
    """Handle one streaming voice-chat turn over a WebSocket.

    Flow:
      1. Receive JSON text frames until streaming ASR returns a non-empty
         result list. The first frame's ``meta_info.session_id`` is looked
         up as an ``Assistant`` id. That assistant's ``user_info`` JSON
         supplies ``asr_type``/``llm_type``/``tts_type`` for the ``Agent``.
         Each frame's ``audio`` field is passed through
         ``agent.user_audio_process`` and then to ``agent.stream_recognize``.
      2. Build a prompt from the ASR results and iterate the LLM reply. Each
         text piece from ``agent.llm_msg_process`` is synthesized and sent
         as one binary frame (``agent.encode(text, audio)``).
      3. Send a ``{"type": "close", "code": 200}`` text frame. Then append
         the exchange to ``assistant.messages``, commit, and call
         ``agent.save()``.

    The project exceptions in the ``except`` clauses are reported to the
    client as a JSON text frame (``type`` "info" with code 201, or "error"
    with codes 501-505); the socket is then closed. Any other exception
    propagates and skips the final log line and ``ws.close()``.
    """
    await ws.accept()
    logger.debug("WebSocket连接成功")
    try:
        # agent and assistant are resolved lazily from the first received frame.
        agent = None
        assistant = None
        asr_results = []
        llm_text = ""  # accumulates the full LLM reply across all streamed pieces
        logger.debug("开始进行ASR识别")
        # Keep consuming client frames until ASR returns at least one result.
        while len(asr_results)==0:
            chunk = json.loads(await ws.receive_text())
            if assistant is None:
                # meta_info.session_id doubles as the Assistant primary key.
                assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
                if assistant is None:
                    raise SessionNotFoundError()
                # user_info is stored as JSON text and names the ASR/LLM/TTS backends.
                user_info = json.loads(assistant.user_info)
            if not agent:
                agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type'])
                agent.init_recorder(assistant.user_id)
            chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
            asr_results = await agent.stream_recognize(chunk, db)
        kid_text = asr_results[0]['text']  # by convention, asr_results[0] is the ASR result of the child (the primary user)
        logger.debug(f"ASR识别成功识别结果为{kid_text}")
        prompt = agent.prompt_process(asr_results, db)
        agent.recorder.input_text = prompt
        logger.debug("开始调用大模型")
        llm_frames = await agent.chat(assistant, prompt, db)
        async for llm_frame in llm_frames:
            # Each LLM frame is turned into zero or more text pieces; every
            # piece is synthesized and sent to the client as its own binary frame.
            resp_msgs = agent.llm_msg_process(llm_frame, db)
            for resp_msg in resp_msgs:
                llm_text += resp_msg
                tts_audio = agent.synthetize(assistant, resp_msg, db)
                agent.tts_audio_process(tts_audio, db)
                await ws.send_bytes(agent.encode(resp_msg, tts_audio))
                logger.debug(f'websocket返回{resp_msg}')
        logger.debug(f"大模型返回结束,返回结果为:{llm_text}")
        # Tell the client the reply stream is complete.
        await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
        logger.debug("结束帧发送完毕")
        # Persist this user/assistant exchange into the assistant's chat history.
        assistant.messages = update_messages(assistant.messages, kid_text ,llm_text)
        db.commit()
        logger.debug("聊天更新成功")
        agent.recorder.output_text = llm_text
        agent.save()
        logger.debug("音频保存成功")
    # Known project exceptions are reported to the client as one JSON text frame;
    # execution then falls through to ws.close() below.
    except EnterSlienceMode:
        # NOTE(review): this handler uses agent/assistant, so it presumably is only
        # raised after both are set (otherwise agent is None here) — confirm.
        tts_audio = agent.synthetize(assistant, "已进入沉默模式", db)
        await ws.send_bytes(agent.encode("已进入沉默模式", tts_audio))
        await ws.send_text(json.dumps({"type":"info","code":201,"msg":"进入沉默模式"}, ensure_ascii=False))
    except SlienceMode:
        await ws.send_text(json.dumps({"type":"info","code":201,"msg":"处于沉默模式"}, ensure_ascii=False))
    except AsrResultNoneError:
        await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr结果为空"}, ensure_ascii=False))
    except AbnormalLLMFrame as e:
        await ws.send_text(json.dumps({"type":"error","code":502,"msg":str(e)}, ensure_ascii=False))
    except SideNoiseError as e:
        await ws.send_text(json.dumps({"type":"error","code":503,"msg":str(e)}, ensure_ascii=False))
    except SessionNotFoundError:
        await ws.send_text(json.dumps({"type":"error","code":504,"msg":"session不存在"}, ensure_ascii=False))
    except UnknownVolcEngineModelError:
        await ws.send_text(json.dumps({"type":"error","code":505,"msg":"未知的火山引擎模型"}, ensure_ascii=False))
    logger.debug("WebSocket连接断开")
    await ws.close()
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# CORS: accept any origin, method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; consider listing explicit origins for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow every origin; specific origins may be listed instead
    allow_credentials=True,
    allow_methods=["*"],  # allow every HTTP method
    allow_headers=["*"],  # allow every request header
)

# Start the server only when this file is run as a script (`python main.py`).
# Without this guard, importing the module (e.g. from tests, or via
# `uvicorn main:app`) would call uvicorn.run() at import time and block there.
if __name__ == "__main__":
    uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)