1
0
Fork 0
TakwayDisplayPlatform/main.py

294 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import FastAPI, Depends, WebSocket, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from config import Config
from app.concrete import Agent
from app.model import Assistant, User, get_db, get_db_context
from app.schemas import *
from app.dependency import get_logger
from app.exception import *
import asyncio
import uvicorn
import uuid
import json
import time
# Shared helpers ---------------------------------------------------------------
def update_messages(assistant, kid_text, llm_text):
    """Append one user/assistant exchange to an assistant's stored history.

    ``assistant.messages`` is parsed as JSON and extended with a ``user`` turn
    (``kid_text``) followed by an ``assistant`` turn (``llm_text``). The
    re-serialised history is then written back to the row with id
    ``assistant.id``, together with the current ``assistant.token`` value.

    Args:
        assistant: Assistant row; its ``id``, ``messages`` and ``token`` are read.
        kid_text: Text recognised from the primary user's speech.
        llm_text: Full text produced by the LLM for this turn.

    Raises:
        AsrResultNoneError: if ``kid_text`` is empty/falsy.
        LlmResultNoneError: if ``llm_text`` is empty/falsy.
    """
    # The stored history is decoded first, then both turns are validated, so
    # an incomplete exchange never reaches the database.
    history = json.loads(assistant.messages)
    if not kid_text:
        raise AsrResultNoneError()
    if not llm_text:
        raise LlmResultNoneError()
    history.extend([
        {"role": "user", "content": kid_text},
        {"role": "assistant", "content": llm_text},
    ])
    with get_db_context() as db:
        changes = {
            "messages": json.dumps(history, ensure_ascii=False),
            "token": assistant.token,
        }
        db.query(Assistant).filter(Assistant.id == assistant.id).update(changes)
        db.commit()
# ------------------------------------------------------------------------------
# Module-wide logger, obtained from the project's dependency helpers.
logger = get_logger()
# The FastAPI application; every route below is registered on this instance.
app = FastAPI()
# Assistant CRUD ---------------------------------------------------------------
# Create an assistant
@app.post("/api/assistants")
async def create_assistant(request: create_assistant_request, db=Depends(get_db)):
    """Create a new assistant and return its generated UUID.

    The conversation history starts as a single system message built from
    ``request.system_prompt``, and the token counter starts at 0.
    """
    new_id = str(uuid.uuid4())
    seed_history = [{"role": "system", "content": request.system_prompt}]
    db.add(
        Assistant(
            id=new_id,
            user_id=request.user_id,
            name=request.name,
            system_prompt=request.system_prompt,
            messages=json.dumps(seed_history, ensure_ascii=False),
            user_info=request.user_info,
            llm_info=request.llm_info,
            tts_info=request.tts_info,
            token=0,
        )
    )
    db.commit()
    return {"code": 200, "msg": "success", "data": {"id": new_id}}
# Delete an assistant
@app.delete("/api/assistants/{id}")
async def delete_assistant(id: str, db=Depends(get_db)):
    """Delete the assistant with the given id.

    A missing assistant is reported as ``code`` 404 inside the JSON body
    rather than by raising ``HTTPException``.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    db.delete(target)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# Update an assistant
@app.put("/api/assistants/{id}")
async def update_assistant(id: str, request: update_assistant_request, db=Depends(get_db)):
    """Overwrite an assistant's editable fields with the values in ``request``.

    A missing assistant is reported as ``code`` 404 inside the JSON body.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    # NOTE(review): `messages` is stored exactly as received, but other handlers
    # json.loads() this column, so callers presumably send a JSON string — confirm.
    for field in ("name", "system_prompt", "messages", "user_info", "llm_info", "tts_info"):
        setattr(target, field, getattr(request, field))
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# Fetch a single assistant
@app.get("/api/assistants/{id}")
async def get_assistant(id: str, db=Depends(get_db)):
    """Return the full assistant record for ``id`` (``code`` 404 in the body if absent)."""
    found = db.query(Assistant).filter(Assistant.id == id).first()
    if not found:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    return {"code": 200, "msg": "success", "data": found}
# List the id and name of every assistant
@app.get("/api/assistants")
async def get_all_assistants_name_id(db=Depends(get_db)):
    """Return a lightweight listing: one ``{"id", "name"}`` dict per assistant."""
    rows = db.query(Assistant.id, Assistant.name).all()
    summary = [dict(id=row.id, name=row.name) for row in rows]
    return {"code": 200, "msg": "success", "data": summary}
# Reset an assistant's conversation history
@app.post("/api/assistants/{id}/reset_msg")
async def reset_assistant_msg(id: str, db=Depends(get_db)):
    """Discard the conversation, keeping only the system prompt as the first message."""
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    fresh_history = [{"role": "system", "content": target.system_prompt}]
    target.messages = json.dumps(fresh_history, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# Change an assistant's system prompt
@app.put("/api/assistants/{id}/system_prompt")
async def update_assistant_system_prompt(id: str, request: update_assistant_system_prompt_request, db=Depends(get_db)):
    """Replace the system prompt and reset the history to just that prompt."""
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    target.system_prompt = request.system_prompt
    # A new prompt also discards every earlier turn of the conversation.
    fresh_history = [{"role": "system", "content": target.system_prompt}]
    target.messages = json.dumps(fresh_history, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# Update detailed model / voice parameters
@app.put("/api/assistants/{id}/deatil_params")
async def update_assistant_deatil_params(id: str, request: update_assistant_deatil_params_request, db=Depends(get_db)):
    """Patch selected keys inside the assistant's JSON-encoded settings columns.

    * ``llm_info``: ``model`` and ``temperature``
    * ``tts_info``: ``speaker_id`` and ``length_scale``
    * ``user_info``: ``llm_type`` (taken from ``request.platform``)

    Other keys in those JSON objects are left untouched.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    llm_info = json.loads(target.llm_info)
    tts_info = json.loads(target.tts_info)
    user_info = json.loads(target.user_info)
    llm_info.update(model=request.model, temperature=request.temperature)
    tts_info.update(speaker_id=request.speaker_id, length_scale=request.length_scale)
    user_info["llm_type"] = request.platform
    target.llm_info, target.tts_info, target.user_info = [
        json.dumps(info, ensure_ascii=False) for info in (llm_info, tts_info, user_info)
    ]
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# Update max_tokens
@app.put("/api/assistants/{id}/max_tokens")
async def update_assistant_max_tokens(id: str, request: update_assistant_max_tokens_request, db=Depends(get_db)):
    """Set ``max_tokens`` inside the assistant's JSON-encoded ``llm_info``."""
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    patched = {**json.loads(target.llm_info), "max_tokens": request.max_tokens}
    target.llm_info = json.dumps(patched, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# ------------------------------------------------------------------------------
# User CRUD --------------------------------------------------------------------
# Create a user
@app.post("/api/users")
async def create_user(request: create_user_request, db=Depends(get_db)):
    """Create a user and return its generated UUID."""
    new_id = str(uuid.uuid4())
    # NOTE(review): the password is persisted exactly as received; no hashing is
    # visible here — confirm whether that happens elsewhere.
    db.add(User(id=new_id, name=request.name, email=request.email, password=request.password))
    db.commit()
    return {"code": 200, "msg": "success", "data": {"id": new_id}}
# Delete a user
@app.delete("/api/users/{id}")
async def delete_user(id: str, db=Depends(get_db)):
    """Delete the user with the given id; raise HTTP 404 if it does not exist."""
    victim = db.query(User).filter(User.id == id).first()
    if not victim:
        raise HTTPException(status_code=404, detail="user not found")
    db.delete(victim)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# Fetch a single user
@app.get("/api/users/{id}")
async def get_user(id: str, db=Depends(get_db)):
    """Return the full user record for ``id``; raise HTTP 404 if it does not exist."""
    found = db.query(User).filter(User.id == id).first()
    if not found:
        raise HTTPException(status_code=404, detail="user not found")
    return {"code": 200, "msg": "success", "data": found}
# List all users
@app.get("/api/users")
async def get_all_users(db=Depends(get_db)):
    """Return every user row."""
    # NOTE(review): full rows are returned, which presumably includes the stored
    # password column — consider projecting only public fields.
    everyone = db.query(User).all()
    return {"code": 200, "msg": "success", "data": everyone}
# Update a user
@app.put("/api/users/{id}")
async def update_user(id: str, request: update_user_request, db=Depends(get_db)):
    """Overwrite a user's name, email and password; raise HTTP 404 if absent."""
    found = db.query(User).filter(User.id == id).first()
    if not found:
        raise HTTPException(status_code=404, detail="user not found")
    for field in ("name", "email", "password"):
        setattr(found, field, getattr(request, field))
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
# ------------------------------------------------------------------------------
# Streaming chat WebSocket endpoint ---------------------------------------------
@app.websocket("/api/chat/streaming/temporary")
async def streaming_chat(ws: WebSocket):
    """Run one speech-in / speech-out chat turn over a WebSocket.

    Flow:
      1. Receive JSON text frames, each within 1 s, and feed their ``audio`` to
         the agent's streaming ASR until it returns results. The first frame's
         ``meta_info.session_id`` selects the Assistant row. The row's
         ``user_info`` selects the ASR/LLM/TTS backends for the Agent.
      2. Send the recognised prompt to the LLM. For every processed piece of
         LLM output, synthesise speech and send ``agent.encode(text, audio)``
         as a binary frame.
      3. Send a ``{"type": "close", "code": 200}`` text frame, append the
         exchange to the assistant's history and save the recording.

    Known failures are reported to the client as one JSON text frame with
    ``type`` ``"info"``/``"error"`` and a distinct ``code`` (201, 501-507).
    The socket is then closed. Exceptions not listed below propagate and skip
    the explicit close.
    """
    await ws.accept()
    logger.debug("WebSocket连接成功")
    try:
        agent = None
        assistant = None
        asr_results = []
        llm_text = ""
        logger.debug("开始进行ASR识别")
        while len(asr_results) == 0:
            chunk = json.loads(await asyncio.wait_for(ws.receive_text(), timeout=1))
            if assistant is None:
                # Short-lived DB session, closed by the context manager.
                with get_db_context() as db:
                    assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
                # Bug fix: validate BEFORE reading assistant.name. Previously the
                # log line ran first, so an unknown session raised AttributeError.
                # The client then never got the 504 "session不存在" frame.
                if assistant is None:
                    raise SessionNotFoundError()
                logger.debug(f"接收到{assistant.name}的请求")
                user_info = json.loads(assistant.user_info)
            if not agent:
                agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type'])
                agent.init_recorder(assistant.user_id)
            chunk["audio"] = agent.user_audio_process(chunk["audio"])
            asr_results = await agent.stream_recognize(assistant, chunk)
        kid_text = asr_results[0]['text']  # index 0 is taken as the primary user's (child's) ASR result
        prompt = agent.prompt_process(asr_results)
        agent.recorder.input_text = prompt
        logger.debug("开始调用大模型")
        llm_frames = await agent.chat(assistant, prompt)
        start_time = time.time()
        is_first_response = True
        for llm_frame in llm_frames:
            if is_first_response:
                end_time = time.time()
                logger.debug(f"第一帧返回耗时:{round(end_time-start_time,3)}s")
                is_first_response = False
            resp_msgs = agent.llm_msg_process(llm_frame)
            for resp_msg in resp_msgs:
                llm_text += resp_msg
                tts_audio = agent.synthetize(assistant, resp_msg)
                agent.tts_audio_process(tts_audio)
                await ws.send_bytes(agent.encode(resp_msg, tts_audio))
                logger.debug(f'websocket返回{resp_msg}')
        logger.debug(f"大模型返回结束,返回结果为:{llm_text}")
        await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
        logger.debug("结束帧发送完毕")
        update_messages(assistant, kid_text, llm_text)
        logger.debug("聊天更新成功")
        agent.recorder.output_text = llm_text
        agent.save()
        logger.debug("音频保存成功")
    except EnterSlienceMode:
        tts_audio = agent.synthetize(assistant, "已进入沉默模式")
        await ws.send_bytes(agent.encode("已进入沉默模式", tts_audio))
        await ws.send_text(json.dumps({"type":"info","code":201,"msg":"进入沉默模式"}, ensure_ascii=False))
        logger.debug("进入沉默模式")
    except AsrResultNoneError:
        await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr结果为空"}, ensure_ascii=False))
        logger.error("ASR结果为空")
    except AbnormalLLMFrame as e:
        await ws.send_text(json.dumps({"type":"error","code":502,"msg":str(e)}, ensure_ascii=False))
        logger.error(f"LLM模型返回异常错误信息{str(e)}")
    except SideNoiseError:
        await ws.send_text(json.dumps({"type":"error","code":503,"msg":"检测为噪声"}, ensure_ascii=False))
        logger.debug("检测为噪声")
    except SessionNotFoundError:
        await ws.send_text(json.dumps({"type":"error","code":504,"msg":"session不存在"}, ensure_ascii=False))
        logger.error("session不存在")
    except UnknownVolcEngineModelError:
        await ws.send_text(json.dumps({"type":"error","code":505,"msg":"未知的火山引擎模型"}, ensure_ascii=False))
        logger.error("未知的火山引擎模型")
    except LlmResultNoneError:
        await ws.send_text(json.dumps({"type":"error","code":506,"msg":"llm结果返回为空"}, ensure_ascii=False))
        logger.error("LLM结果返回为空")
    except asyncio.TimeoutError:
        await ws.send_text(json.dumps({"type":"error","code":507,"msg":"接收超时"}, ensure_ascii=False))
        logger.error("接收超时")
    logger.debug("WebSocket连接断开")
    logger.debug("")
    await ws.close()
# ------------------------------------------------------------------------------
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# questionable. The CORS spec forbids a wildcard origin on credentialed
# requests. Depending on the Starlette version, this either breaks credentialed
# requests or reflects any Origin back. List explicit origins if cookies or
# auth headers are in use — confirm the intended policy.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow every origin (specific origins may be listed instead)
    allow_credentials=True,
    allow_methods=["*"],  # allow every HTTP method
    allow_headers=["*"],  # allow every request header
)

# Start the server only when run as a script. This keeps imports of this module
# (e.g. `uvicorn main:app` or tests) from blocking on a second server.
if __name__ == "__main__":
    uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)