# TakwayDisplayPlatform/main.py

from fastapi import FastAPI, Depends, WebSocket, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from config import Config
from app.concrete import Agent, AsrResultNoneError
from app.model import Assistant, User, get_db
from app.schemas import *
import uvicorn
import uuid
import json
# 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------
def update_messages(messages, llm_text):
    """Append an assistant reply to a JSON-encoded chat history.

    Args:
        messages: JSON string holding a list of ``{"role": ..., "content": ...}``
            dicts (the format written by the assistant endpoints in this file).
        llm_text: Text produced by the LLM for this turn.

    Returns:
        The updated history re-serialized as JSON, with non-ASCII characters
        kept verbatim rather than ``\\u``-escaped.
    """
    reply = {"role": "assistant", "content": llm_text}
    history = json.loads(messages)
    history.append(reply)
    return json.dumps(history, ensure_ascii=False)
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# The application object: every route decorator and middleware below registers on it.
app: FastAPI = FastAPI()
# Assistant CRUD endpoints ------------------------------------------------------------------------------------------------------------------------------------------------------
@app.post("/api/assistants")
async def create_assistant(request: create_assistant_request, db=Depends(get_db)):
    """Create an assistant and return its newly generated id.

    The stored conversation history starts as a single system message built
    from ``request.system_prompt`` and is saved as a JSON string. The token
    counter starts at 0.
    """
    assistant_id = str(uuid.uuid4())
    seed_history = [{"role": "system", "content": request.system_prompt}]
    assistant = Assistant(
        id=assistant_id,
        user_id=request.user_id,
        name=request.name,
        system_prompt=request.system_prompt,
        messages=json.dumps(seed_history, ensure_ascii=False),
        user_info=request.user_info,
        llm_info=request.llm_info,
        tts_info=request.tts_info,
        token=0,
    )
    db.add(assistant)
    db.commit()
    return {"code": 200, "msg": "success", "data": {"id": assistant_id}}
@app.delete("/api/assistants/{id}")
async def delete_assistant(id: str, db=Depends(get_db)):
    """Delete the assistant with the given id.

    A missing assistant is reported in the response body (``"code": 404``).
    The HTTP status itself is not changed.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, 'msg': "assistant not found", "data": {}}
    db.delete(target)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
@app.put("/api/assistants/{id}")
async def update_assistant(id: str, request: update_assistant_request, db=Depends(get_db)):
    """Overwrite an assistant's editable fields with the values in ``request``.

    ``request.messages`` is stored verbatim. Other endpoints in this file treat
    that column as a JSON string, so the client presumably sends it already
    serialized (confirm against callers).
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, 'msg': "assistant not found", "data": {}}
    # Copy each field across in the same order as the request schema lists them.
    for field in ("name", "system_prompt", "messages", "user_info", "llm_info", "tts_info"):
        setattr(target, field, getattr(request, field))
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
@app.get("/api/assistants/{id}")
async def get_assistant(id: str, db=Depends(get_db)):
    """Return the full assistant record, including its stored message history.

    The ORM object is returned directly and FastAPI serializes it. A missing
    assistant yields ``"code": 404`` in the body.
    """
    found = db.query(Assistant).filter(Assistant.id == id).first()
    if not found:
        return {"code": 404, 'msg': "assistant not found", "data": {}}
    return {"code": 200, "msg": "success", "data": found}
@app.get("/api/assistants")
async def get_all_assistants_name_id(db=Depends(get_db)):
    """List every assistant as an ``{"id", "name"}`` pair.

    Only those two columns are queried.
    """
    rows = db.query(Assistant.id, Assistant.name).all()
    summaries = [{"id": row.id, "name": row.name} for row in rows]
    return {"code": 200, "msg": "success", "data": summaries}
@app.post("/api/assistants/{id}/reset_msg")
async def reset_assistant_msg(id: str, db=Depends(get_db)):
    """Discard an assistant's conversation history.

    The history is replaced by a single system message built from the
    assistant's current ``system_prompt``.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, 'msg': "assistant not found", "data": {}}
    fresh_history = [{"role": "system", "content": target.system_prompt}]
    target.messages = json.dumps(fresh_history, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
@app.put("/api/assistants/{id}/system_prompt")
async def update_assistant_system_prompt(id: str, request: update_assistant_system_prompt_request, db=Depends(get_db)):
    """Replace an assistant's ``system_prompt`` column.

    NOTE(review): only ``system_prompt`` changes here. The system message
    already stored inside ``messages`` is left untouched, so the new prompt
    presumably takes effect only after ``/reset_msg``, which rebuilds
    ``messages`` from ``system_prompt``. Confirm this is intended.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, 'msg': "assistant not found", "data": {}}
    target.system_prompt = request.system_prompt
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
@app.put("/api/assistants/{id}/deatil_params")
async def update_assistant_deatil_params(id: str, request: update_assistant_deatil_params_request, db=Depends(get_db)):
    """Patch selected LLM and TTS settings inside the assistant's JSON config columns.

    Sets ``model`` and ``temperature`` in ``llm_info``, and ``speaker_id`` and
    ``length_scale`` in ``tts_info``. All other keys in those JSON objects are
    preserved. (The "deatil" spelling is part of the public route and is kept.)
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, 'msg': "assistant not found", "data": {}}
    llm_cfg = json.loads(target.llm_info)
    tts_cfg = json.loads(target.tts_info)
    llm_cfg['model'], llm_cfg['temperature'] = request.model, request.temperature
    tts_cfg['speaker_id'], tts_cfg['length_scale'] = request.speaker_id, request.length_scale
    target.llm_info = json.dumps(llm_cfg, ensure_ascii=False)
    target.tts_info = json.dumps(tts_cfg, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# 用户增删改查接口 ----------------------------------------------------------------------------------------------------------------------------------------------------------------
@app.post("/api/users")
async def create_user(request: create_user_request, db=Depends(get_db)):
    """Create a user and return its newly generated id.

    NOTE(review): ``request.password`` is stored exactly as received. Unless it
    is already hashed upstream (not visible here), passwords are stored in
    plaintext and should be hashed before persisting. Confirm.
    """
    user_id = str(uuid.uuid4())
    new_user = User(id=user_id, name=request.name, email=request.email, password=request.password)
    db.add(new_user)
    db.commit()
    return {"code": 200, "msg": "success", "data": {"id": user_id}}
@app.delete("/api/users/{id}")
async def delete_user(id: str, db=Depends(get_db)):
    """Delete the user with the given id.

    Raises:
        HTTPException: 404 when no user has this id.
    """
    target = db.query(User).filter(User.id == id).first()
    if not target:
        raise HTTPException(status_code=404, detail="user not found")
    db.delete(target)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
@app.get("/api/users/{id}")
async def get_user(id: str, db=Depends(get_db)):
    """Return a user's record without its ``password`` field.

    Previously the whole ORM object was returned, so the stored password was
    sent to any caller who knew the id. ``create_user`` and ``update_user``
    store that password exactly as received. It is now excluded from the
    payload.

    Raises:
        HTTPException: 404 when no user has this id.
    """
    user = db.query(User).filter(User.id == id).first()
    if not user:
        raise HTTPException(status_code=404, detail="user not found")
    # Mirror FastAPI's default encoding of an arbitrary object: its instance
    # attributes, minus SQLAlchemy's private ``_sa_*`` state (assumes User is
    # a SQLAlchemy model). The password is dropped on top of that.
    public_fields = {
        key: value
        for key, value in vars(user).items()
        if not key.startswith("_sa") and key != "password"
    }
    return {"code": 200, "msg": "success", "data": public_fields}
@app.put("/api/users/{id}")
async def update_user(id: str, request: update_user_request, db=Depends(get_db)):
    """Overwrite a user's name, email and password with the values in ``request``.

    Raises:
        HTTPException: 404 when no user has this id.
    """
    target = db.query(User).filter(User.id == id).first()
    if not target:
        raise HTTPException(status_code=404, detail="user not found")
    for field in ("name", "email", "password"):
        setattr(target, field, getattr(request, field))
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Streaming chat WebSocket endpoint --------------------------------------------------------------------------------------------------------------------------------------------
@app.websocket("/api/chat/streaming/temporary")
async def streaming_chat(ws: WebSocket, db=Depends(get_db)):
    """Run one voice-chat turn over a WebSocket.

    Flow:
    1. Receive JSON text frames. Each carries ``audio`` and
       ``meta_info.session_id``, which is looked up as the assistant id.
       Frames are fed to streaming ASR until it returns a non-empty result.
    2. Build a prompt from the ASR result and stream the LLM reply. Each text
       piece is synthesized and sent as one binary frame from ``agent.encode``.
    3. Send a ``{"type": "close", "code": 200}`` text frame. Then append the
       reply to the assistant's stored history, record the input/output text
       on ``agent.recorder``, call ``agent.save()`` and close the socket.

    Early exits send a ``close`` text frame with a non-200 code and then
    close the socket:
    - ``code`` 201 when ``AsrResultNoneError`` is raised while collecting audio;
    - ``code`` 404 when ``session_id`` matches no assistant. This previously
      crashed with ``AttributeError`` on ``None.user_id``.
    """
    await ws.accept()
    agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
    assistant = None
    asr_results = []
    llm_text = ""
    try:
        while not asr_results:
            chunk = json.loads(await ws.receive_text())
            if assistant is None:
                # The first frame decides which assistant this session talks to.
                assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
                if assistant is None:
                    await ws.send_text(json.dumps({"type": "close", "code": 404, "msg": "assistant not found"}, ensure_ascii=False))
                    await ws.close()
                    return
                agent.init_recorder(assistant.user_id)
            chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
            asr_results = await agent.stream_recognize(chunk, db)
    except AsrResultNoneError:
        await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
        # Close explicitly, as the success path does, instead of leaving the
        # socket to server teardown.
        await ws.close()
        return

    prompt = agent.prompt_process(asr_results, db)
    agent.recorder.input_text = prompt
    llm_frames = await agent.chat(assistant, prompt, db)
    async for llm_frame in llm_frames:
        resp_msgs = agent.llm_msg_process(llm_frame, db)
        for resp_msg in resp_msgs:
            llm_text += resp_msg
            tts_audio = agent.synthetize(assistant, resp_msg, db)
            agent.tts_audio_process(tts_audio, db)
            await ws.send_bytes(agent.encode(resp_msg, tts_audio))
    await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))

    # Only the assistant reply is appended here. Whether the user prompt is
    # added to the history is up to agent.chat, which is not visible in this file.
    assistant.messages = update_messages(assistant.messages, llm_text)
    db.commit()
    agent.recorder.output_text = llm_text
    agent.save()
    await ws.close()
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
app.add_middleware(
    CORSMiddleware,
    # Accept cross-origin requests from any origin, using any method and any header.
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    # NOTE(review): combining a wildcard origin with allow_credentials=True is
    # discouraged. Starlette's CORS docs say origins, methods and headers must
    # be listed explicitly for credentialed requests. No cookie or session
    # auth is visible in this file, so confirm whether credentials are needed.
    allow_credentials=True,
)
# Start the server only when this file is executed directly. Without the guard,
# merely importing the module (e.g. `uvicorn main:app`, or a test importing
# `app`) would launch a blocking server as an import side effect.
if __name__ == "__main__":
    uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)