TakwayDisplayPlatform/main.py

229 lines
10 KiB
Python
Raw Normal View History

2024-06-09 22:54:13 +08:00
from fastapi import FastAPI, Depends, WebSocket, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from config import Config
from app.concrete import Agent
2024-06-09 22:54:13 +08:00
from app.model import Assistant, User, get_db
from app.schemas import *
from app.dependency import get_logger
from app.exception import AsrResultNoneError, AbnormalLLMFrame, SideNoiseError
2024-06-09 22:54:13 +08:00
import uvicorn
import uuid
import json
# 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------
def update_messages(messages, kid_text, llm_text):
    """Append one user/assistant exchange to a JSON-encoded chat history.

    Args:
        messages: JSON string that decodes to a list of message dicts
            (each ``{"role": ..., "content": ...}``).
        kid_text: Text of the user's turn; stored with role ``"user"``.
        llm_text: Text of the model's reply; stored with role ``"assistant"``.

    Returns:
        The extended history re-serialized as a JSON string, with non-ASCII
        characters kept verbatim (``ensure_ascii=False``).
    """
    history = json.loads(messages)
    new_turns = (
        {"role": "user", "content": kid_text},
        {"role": "assistant", "content": llm_text},
    )
    history.extend(new_turns)
    return json.dumps(history, ensure_ascii=False)
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Module-wide logger, obtained from the project's dependency helper.
logger = get_logger()

# The FastAPI application; every HTTP and WebSocket route below attaches to it.
app = FastAPI()
# 增删查改 assistant----------------------------------------------------------------------------------------------------------------------------------------------------------------
# Create an assistant
@app.post("/api/assistants")
async def create_assistant(request: create_assistant_request, db=Depends(get_db)):
    """Create and persist a new assistant.

    The stored conversation history is seeded with a single system message
    built from ``request.system_prompt``, and the token counter starts at 0.

    Returns:
        ``{"code": 200, "msg": "success", "data": {"id": <new uuid4 string>}}``
    """
    new_id = str(uuid.uuid4())
    seed_history = [{"role": "system", "content": request.system_prompt}]
    db.add(
        Assistant(
            id=new_id,
            user_id=request.user_id,
            name=request.name,
            system_prompt=request.system_prompt,
            messages=json.dumps(seed_history, ensure_ascii=False),
            user_info=request.user_info,
            llm_info=request.llm_info,
            tts_info=request.tts_info,
            token=0,
        )
    )
    db.commit()
    return {"code": 200, "msg": "success", "data": {"id": new_id}}
# Delete an assistant
@app.delete("/api/assistants/{id}")
async def delete_assistant(id: str, db=Depends(get_db)):
    """Delete the assistant with the given id.

    A missing assistant is reported inside the JSON envelope (``code`` 404),
    not through the HTTP status code.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    db.delete(target)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# Update an assistant
@app.put("/api/assistants/{id}")
async def update_assistant(id: str, request: update_assistant_request, db=Depends(get_db)):
    """Overwrite the editable fields of an existing assistant.

    ``name``, ``system_prompt``, ``messages``, ``user_info``, ``llm_info`` and
    ``tts_info`` are copied verbatim from the request; no JSON validation of
    the copied values happens here. A missing assistant yields ``code`` 404
    in the envelope.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    editable_fields = ("name", "system_prompt", "messages", "user_info", "llm_info", "tts_info")
    for field in editable_fields:
        setattr(target, field, getattr(request, field))
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# Fetch one assistant
@app.get("/api/assistants/{id}")
async def get_assistant(id: str, db=Depends(get_db)):
    """Return the full assistant record for ``id`` inside the response envelope.

    A missing assistant yields ``code`` 404 in the envelope.
    """
    found = db.query(Assistant).filter(Assistant.id == id).first()
    if not found:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    return {"code": 200, "msg": "success", "data": found}
# List the id and name of every assistant
@app.get("/api/assistants")
async def get_all_assistants_name_id(db=Depends(get_db)):
    """Return ``[{"id": ..., "name": ...}, ...]`` for all assistants (unpaginated)."""
    rows = db.query(Assistant.id, Assistant.name).all()
    summary = [{"id": row_id, "name": row_name} for row_id, row_name in rows]
    return {"code": 200, "msg": "success", "data": summary}
# Reset an assistant's conversation history
@app.post("/api/assistants/{id}/reset_msg")
async def reset_assistant_msg(id: str, db=Depends(get_db)):
    """Discard the stored chat history, keeping only the current system prompt.

    A missing assistant yields ``code`` 404 in the envelope.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    fresh_history = [{"role": "system", "content": target.system_prompt}]
    target.messages = json.dumps(fresh_history, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# Change an assistant's system prompt
@app.put("/api/assistants/{id}/system_prompt")
async def update_assistant_system_prompt(id: str, request: update_assistant_system_prompt_request, db=Depends(get_db)):
    """Replace the system prompt and reset the history to just that prompt.

    Any earlier conversation turns are discarded. A missing assistant yields
    ``code`` 404 in the envelope.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    prompt = request.system_prompt
    target.system_prompt = prompt
    target.messages = json.dumps([{"role": "system", "content": prompt}], ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# Update individual LLM / TTS parameters
@app.put("/api/assistants/{id}/deatil_params")
async def update_assistant_deatil_params(id: str, request: update_assistant_deatil_params_request, db=Depends(get_db)):
    """Patch selected LLM and TTS settings of an assistant.

    ``model`` and ``temperature`` are written into the JSON object stored in
    ``llm_info``; ``speaker_id`` and ``length_scale`` into ``tts_info``. Any
    other keys already present in those objects are preserved.

    The misspelled route/function name ("deatil") is intentionally kept:
    renaming it would break existing callers of this path.

    Returns:
        The usual envelope; ``code`` 404 when the assistant does not exist.
    """
    assistant = db.query(Assistant).filter(Assistant.id == id).first()
    if not assistant:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    # json.loads raises on None or "" (previously surfacing as a 500). Treat a
    # missing/empty settings column as an empty object so the patch can still
    # be applied. NOTE(review): whether these columns can actually be NULL
    # depends on the request schemas — confirm.
    llm_info = json.loads(assistant.llm_info or "{}")
    tts_info = json.loads(assistant.tts_info or "{}")
    llm_info["model"] = request.model
    llm_info["temperature"] = request.temperature
    tts_info["speaker_id"] = request.speaker_id
    tts_info["length_scale"] = request.length_scale
    assistant.llm_info = json.dumps(llm_info, ensure_ascii=False)
    assistant.tts_info = json.dumps(tts_info, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# 用户增删改查接口 ----------------------------------------------------------------------------------------------------------------------------------------------------------------
# Create a user
@app.post("/api/users")
async def create_user(request: create_user_request, db=Depends(get_db)):
    """Register a new user and return its generated id.

    NOTE(review): the password is handed to the model exactly as received; no
    hashing happens in this handler. Confirm whether the User model hashes it
    — otherwise plaintext credentials are stored.

    Returns:
        ``{"code": 200, "msg": "success", "data": {"id": <new uuid4 string>}}``
    """
    user_id = str(uuid.uuid4())
    db.add(User(id=user_id, name=request.name, email=request.email, password=request.password))
    db.commit()
    return {"code": 200, "msg": "success", "data": {"id": user_id}}
# Delete a user
@app.delete("/api/users/{id}")
async def delete_user(id: str, db=Depends(get_db)):
    """Delete the user with the given id.

    Raises:
        HTTPException: status 404 when no user has this id.
    """
    victim = db.query(User).filter(User.id == id).first()
    if not victim:
        raise HTTPException(status_code=404, detail="user not found")
    db.delete(victim)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# Fetch a user
@app.get("/api/users/{id}")
async def get_user(id: str, db=Depends(get_db)):
    """Return the stored user record for ``id`` inside the response envelope.

    NOTE(review): the ORM object is returned as-is, so its serialized form
    presumably includes the ``password`` attribute — consider returning only
    non-sensitive fields.

    Raises:
        HTTPException: status 404 when no user has this id.
    """
    found = db.query(User).filter(User.id == id).first()
    if not found:
        raise HTTPException(status_code=404, detail="user not found")
    return {"code": 200, "msg": "success", "data": found}
# Update a user
@app.put("/api/users/{id}")
async def update_user(id: str, request: update_user_request, db=Depends(get_db)):
    """Overwrite a user's name, email and password with the request values.

    Raises:
        HTTPException: status 404 when no user has this id.
    """
    existing = db.query(User).filter(User.id == id).first()
    if not existing:
        raise HTTPException(status_code=404, detail="user not found")
    for field in ("name", "email", "password"):
        setattr(existing, field, getattr(request, field))
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# 流式聊天websocket接口 ------------------------------------------------------------------------------------------------------------------------------------------------------------
async def _run_chat_turn(ws: WebSocket, db):
    """Run one ASR -> LLM -> TTS exchange over an already-accepted WebSocket.

    Reads JSON text frames until ASR yields a non-empty result. The first
    frame's ``meta_info.session_id`` selects the assistant. Each processed LLM
    text piece is synthesized and sent back as a binary frame. Then a
    ``{"type": "close", "code": 200}`` text frame is sent, the exchange is
    appended to the assistant's stored history, and the agent recording is
    saved.

    If no assistant matches the session id, a ``{"type": "close", "code": 404}``
    frame is sent and the turn ends without touching the database.

    Exceptions from the agent calls (e.g. AsrResultNoneError, AbnormalLLMFrame,
    SideNoiseError) propagate to the caller.
    """
    agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
    assistant = None
    asr_results = []
    llm_text = ""
    logger.debug("开始进行ASR识别")
    while len(asr_results) == 0:
        chunk = json.loads(await ws.receive_text())
        if assistant is None:
            assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
            if assistant is None:
                # Previously this fell through to `assistant.user_id` and crashed
                # with AttributeError; tell the client instead.
                await ws.send_text(json.dumps({"type": "close", "code": 404, "msg": "assistant not found"}, ensure_ascii=False))
                return
            agent.init_recorder(assistant.user_id)
        chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
        asr_results = await agent.stream_recognize(chunk, db)
    # asr_results[0] is treated as the primary user's (the child's) result.
    kid_text = asr_results[0]['text']
    logger.debug(f"ASR识别成功识别结果为{kid_text}")
    prompt = agent.prompt_process(asr_results, db)
    agent.recorder.input_text = prompt
    logger.debug("开始调用大模型")
    llm_frames = await agent.chat(assistant, prompt, db)
    async for llm_frame in llm_frames:
        for resp_msg in agent.llm_msg_process(llm_frame, db):
            llm_text += resp_msg
            tts_audio = agent.synthetize(assistant, resp_msg, db)
            agent.tts_audio_process(tts_audio, db)
            await ws.send_bytes(agent.encode(resp_msg, tts_audio))
            logger.debug(f'websocket返回{resp_msg}')
    logger.debug(f"大模型返回结束,返回结果为:{llm_text}")
    await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
    logger.debug("结束帧发送完毕")
    assistant.messages = update_messages(assistant.messages, kid_text, llm_text)
    db.commit()
    logger.debug("聊天更新成功")
    agent.recorder.output_text = llm_text
    agent.save()
    logger.debug("音频保存成功")


@app.websocket("/api/chat/streaming/temporary")
async def streaming_chat(ws: WebSocket, db=Depends(get_db)):
    """Streaming voice-chat endpoint: one exchange per connection.

    Accepts the socket and runs a single chat turn. Known agent failures are
    translated into a final ``{"type": "close", ...}`` text frame:
    201 for an empty ASR result, 202 for an abnormal LLM frame, and 203 for
    side noise. The server then closes the socket. If the client disconnects
    mid-turn, the event is logged and no close is attempted.
    """
    # Function-local import keeps this edit self-contained (fastapi is already a dependency).
    from fastapi import WebSocketDisconnect

    await ws.accept()
    logger.debug("WebSocket连接成功")
    try:
        await _run_chat_turn(ws, db)
    except AsrResultNoneError:
        await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": "asr结果为空"}, ensure_ascii=False))
    except AbnormalLLMFrame as e:
        await ws.send_text(json.dumps({"type": "close", "code": 202, "msg": str(e)}, ensure_ascii=False))
    except SideNoiseError as e:
        await ws.send_text(json.dumps({"type": "close", "code": 203, "msg": str(e)}, ensure_ascii=False))
    except WebSocketDisconnect:
        # The peer is already gone; log it and skip sending a close frame.
        logger.debug("WebSocket连接断开")
        return
    logger.debug("WebSocket连接断开")
    await ws.close()
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Cross-origin policy: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended, or list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow every origin; can be narrowed to specific origins
    allow_credentials=True,
    allow_methods=["*"],  # allow every HTTP method
    allow_headers=["*"],  # allow every request header
)

if __name__ == "__main__":
    # Start the server only when this file is run as a script. Previously
    # uvicorn.run executed at import time, so importing the module (e.g.
    # `uvicorn main:app`, or tests) would launch a server as a side effect.
    uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)