295 lines
14 KiB
Python
295 lines
14 KiB
Python
from fastapi import FastAPI, Depends, WebSocket, HTTPException
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from config import Config
|
||
from app.concrete import Agent
|
||
from app.model import Assistant, User, get_db, get_db_context
|
||
from app.schemas import *
|
||
from app.dependency import get_logger
|
||
from app.exception import *
|
||
import asyncio
|
||
import uvicorn
|
||
import uuid
|
||
import json
|
||
import time
|
||
|
||
|
||
# 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||
def update_messages(assistant, kid_text, llm_text):
    """Append one user/assistant exchange to the assistant's stored history.

    ``assistant.messages`` is decoded from JSON, the user turn (``kid_text``)
    and the assistant turn (``llm_text``) are appended, and the re-encoded
    history is written back to the matching ``Assistant`` row together with
    ``assistant.token``.

    Raises:
        AsrResultNoneError: ``kid_text`` is empty or None.
        LlmResultNoneError: ``llm_text`` is empty or None.
    """
    history = json.loads(assistant.messages)

    # Both sides of the exchange are required before anything is persisted.
    if not kid_text:
        raise AsrResultNoneError()
    if not llm_text:
        raise LlmResultNoneError()

    history.extend(
        [
            {"role": "user", "content": kid_text},
            {"role": "assistant", "content": llm_text},
        ]
    )

    with get_db_context() as db:
        row_query = db.query(Assistant).filter(Assistant.id == assistant.id)
        row_query.update(
            {
                "messages": json.dumps(history, ensure_ascii=False),
                "token": assistant.token,
            }
        )
        db.commit()
|
||
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||
|
||
# Module-level logger obtained from the project's dependency helper
logger = get_logger()

# The FastAPI application every route below is registered on
app = FastAPI()
|
||
|
||
# Assistant CRUD endpoints -------------------------------------------------------------------------------------------------------------------------------------------------------
|
||
# Create an assistant
@app.post("/api/assistants")
async def create_assistant(request: create_assistant_request, db=Depends(get_db)):
    """Insert a new ``Assistant`` row and return its generated id.

    The id is a fresh UUID4 string. The stored message history starts with a
    single system message holding ``request.system_prompt``, and the token
    counter starts at 0.
    """
    assistant_id = str(uuid.uuid4())
    initial_history = [{"role": "system", "content": request.system_prompt}]

    record = Assistant(
        id=assistant_id,
        user_id=request.user_id,
        name=request.name,
        system_prompt=request.system_prompt,
        messages=json.dumps(initial_history, ensure_ascii=False),
        user_info=request.user_info,
        llm_info=request.llm_info,
        tts_info=request.tts_info,
        token=0,
    )
    db.add(record)
    db.commit()

    return {"code": 200, "msg": "success", "data": {"id": assistant_id}}
|
||
|
||
# Delete an assistant
@app.delete("/api/assistants/{id}")
async def delete_assistant(id: str, db=Depends(get_db)):
    """Delete the assistant with the given id.

    Returns a body with ``code`` 200 on success, or ``code`` 404 in the body
    when no such assistant exists.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}

    db.delete(target)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
|
||
|
||
# Update an assistant
@app.put("/api/assistants/{id}")
async def update_assistant(id: str, request: update_assistant_request, db=Depends(get_db)):
    """Overwrite the editable fields of an assistant with the request values.

    Copies ``name``, ``system_prompt``, ``messages``, ``user_info``,
    ``llm_info`` and ``tts_info`` from the request onto the row. Returns a
    body with ``code`` 404 when the assistant does not exist.
    """
    record = db.query(Assistant).filter(Assistant.id == id).first()
    if not record:
        return {"code": 404, "msg": "assistant not found", "data": {}}

    copied_fields = (
        "name",
        "system_prompt",
        "messages",
        "user_info",
        "llm_info",
        "tts_info",
    )
    for field_name in copied_fields:
        setattr(record, field_name, getattr(request, field_name))

    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
|
||
|
||
# Fetch a single assistant
@app.get("/api/assistants/{id}")
async def get_assistant(id: str, db=Depends(get_db)):
    """Return the assistant row with the given id as the response ``data``.

    Returns a body with ``code`` 404 and empty ``data`` when it does not exist.
    """
    found = db.query(Assistant).filter(Assistant.id == id).first()
    if not found:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    return {"code": 200, "msg": "success", "data": found}
|
||
|
||
# List the id and name of every assistant
@app.get("/api/assistants")
async def get_all_assistants_name_id(db=Depends(get_db)):
    """Return ``{"id", "name"}`` pairs for all assistants."""
    rows = db.query(Assistant.id, Assistant.name).all()
    summaries = [{"id": row.id, "name": row.name} for row in rows]
    return {"code": 200, "msg": "success", "data": summaries}
|
||
|
||
# Reset an assistant's message history
@app.post("/api/assistants/{id}/reset_msg")
async def reset_assistant_msg(id: str, db=Depends(get_db)):
    """Replace the stored history with only the assistant's system prompt.

    Returns a body with ``code`` 404 when the assistant does not exist.
    """
    record = db.query(Assistant).filter(Assistant.id == id).first()
    if not record:
        return {"code": 404, "msg": "assistant not found", "data": {}}

    fresh_history = [{"role": "system", "content": record.system_prompt}]
    record.messages = json.dumps(fresh_history, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
|
||
|
||
# Change an assistant's system_prompt
@app.put("/api/assistants/{id}/system_prompt")
async def update_assistant_system_prompt(id: str, request: update_assistant_system_prompt_request, db=Depends(get_db)):
    """Store a new system prompt and restart the history from it.

    The previous message history is discarded: ``messages`` becomes a single
    system message carrying the new prompt. Returns a body with ``code`` 404
    when the assistant does not exist.
    """
    record = db.query(Assistant).filter(Assistant.id == id).first()
    if not record:
        return {"code": 404, "msg": "assistant not found", "data": {}}

    record.system_prompt = request.system_prompt
    restarted_history = [{"role": "system", "content": record.system_prompt}]
    record.messages = json.dumps(restarted_history, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
|
||
|
||
# Update detailed LLM / TTS / engine parameters
@app.put("/api/assistants/{id}/deatil_params")
async def update_assistant_deatil_params(id: str, request: update_assistant_deatil_params_request, db=Depends(get_db)):
    """Merge request values into the assistant's JSON-encoded settings.

    ``llm_info``, ``tts_info`` and ``user_info`` are decoded, updated and
    re-encoded. Several TTS fields are not taken from the request and are
    always written with fixed values. Returns a body with ``code`` 404 when
    the assistant does not exist.
    """
    record = db.query(Assistant).filter(Assistant.id == id).first()
    if not record:
        return {"code": 404, "msg": "assistant not found", "data": {}}

    llm_settings = json.loads(record.llm_info)
    tts_settings = json.loads(record.tts_info)
    user_settings = json.loads(record.user_info)

    llm_settings.update(
        {
            "model": request.model,
            "temperature": request.temperature,
        }
    )
    tts_settings.update(
        {
            "speaker_id": request.speaker_id,
            "length_scale": request.length_scale,
            "language": request.language,
            "style_text": request.style_text,
            "style_weight": request.style_weight,
            # The remaining TTS fields are fixed regardless of the request.
            "sdp_ratio": 0.5,
            "opt_cut_by_send": False,
            "interval_between_para": 1.0,
            "interval_between_sent": 0.2,
            "en_ratio": 1.0,
        }
    )
    user_settings.update(
        {
            "llm_type": request.platform,
            "tts_type": request.tts_engine,
        }
    )

    record.llm_info = json.dumps(llm_settings, ensure_ascii=False)
    record.tts_info = json.dumps(tts_settings, ensure_ascii=False)
    record.user_info = json.dumps(user_settings, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
|
||
|
||
# Update max_tokens
@app.put("/api/assistants/{id}/max_tokens")
async def update_assistant_max_tokens(id: str, request: update_assistant_max_tokens_request, db=Depends(get_db)):
    """Set ``max_tokens`` inside the assistant's JSON-encoded ``llm_info``.

    Returns a body with ``code`` 404 when the assistant does not exist.
    """
    record = db.query(Assistant).filter(Assistant.id == id).first()
    if not record:
        return {"code": 404, "msg": "assistant not found", "data": {}}

    llm_settings = json.loads(record.llm_info)
    llm_settings["max_tokens"] = request.max_tokens
    record.llm_info = json.dumps(llm_settings, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
|
||
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||
|
||
|
||
|
||
# 用户增删改查接口 ----------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||
# Add a user
@app.post("/api/users")
async def create_user(request: create_user_request, db=Depends(get_db)):
    """Insert a new ``User`` row with a generated UUID4 id and return that id."""
    user_id = str(uuid.uuid4())

    # NOTE(review): the password is stored exactly as received in the request;
    # confirm whether hashing is expected to happen somewhere else.
    record = User(
        id=user_id,
        name=request.name,
        email=request.email,
        password=request.password,
    )
    db.add(record)
    db.commit()

    return {"code": 200, "msg": "success", "data": {"id": user_id}}
|
||
|
||
# Delete a user
@app.delete("/api/users/{id}")
async def delete_user(id: str, db=Depends(get_db)):
    """Delete the user with the given id.

    Raises:
        HTTPException: status 404 when no such user exists.
    """
    target = db.query(User).filter(User.id == id).first()
    if not target:
        raise HTTPException(status_code=404, detail="user not found")

    db.delete(target)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
|
||
|
||
# Fetch a user
@app.get("/api/users/{id}")
async def get_user(id: str, db=Depends(get_db)):
    """Return the user row with the given id as the response ``data``.

    Raises:
        HTTPException: status 404 when no such user exists.
    """
    found = db.query(User).filter(User.id == id).first()
    if not found:
        raise HTTPException(status_code=404, detail="user not found")

    # NOTE(review): the whole row object is returned; verify which columns
    # end up in the serialized response.
    return {"code": 200, "msg": "success", "data": found}
|
||
|
||
# List all users
@app.get("/api/users")
async def get_all_users(db=Depends(get_db)):
    """Return every ``User`` row as the response ``data``."""
    all_rows = db.query(User).all()
    return {"code": 200, "msg": "success", "data": all_rows}
|
||
|
||
# Update a user
@app.put("/api/users/{id}")
async def update_user(id: str, request: update_user_request, db=Depends(get_db)):
    """Overwrite a user's ``name``, ``email`` and ``password`` from the request.

    Raises:
        HTTPException: status 404 when no such user exists.
    """
    record = db.query(User).filter(User.id == id).first()
    if not record:
        raise HTTPException(status_code=404, detail="user not found")

    for field_name in ("name", "email", "password"):
        setattr(record, field_name, getattr(request, field_name))

    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
|
||
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||
|
||
# Streaming chat websocket endpoint -----------------------------------------------------------------------------------------------------------------------------------------------
|
||
@app.websocket("/api/chat/streaming/temporary")
async def streaming_chat(ws: WebSocket):
    """Run one speech -> LLM -> speech exchange over a websocket.

    Flow:
      1. Accept the socket and read JSON text frames (1 second timeout per
         frame) until the agent's streaming recognizer returns a non-empty
         result list.
      2. On the first frame, load the ``Assistant`` whose id equals
         ``chunk['meta_info']['session_id']`` and build an ``Agent`` from the
         ``asr_type`` / ``llm_type`` / ``tts_type`` entries of its
         ``user_info`` JSON.
      3. Build the prompt from the recognition results, call the LLM, and for
         every response message synthesize audio and send it as a binary frame.
      4. Send a ``close`` text frame, persist the exchange with
         ``update_messages`` and save the agent's recording.

    Each handled failure is reported to the client as one JSON text frame:
    201 silence mode (info), 501 empty ASR result, 502 abnormal LLM frame,
    503 side noise, 504 unknown session, 505 unknown VolcEngine model,
    506 empty LLM result, 507 receive timeout. Any other exception propagates
    and skips the explicit ``ws.close()`` at the end.
    """
    await ws.accept()
    logger.debug("WebSocket连接成功")
    try:
        agent = None
        assistant = None
        asr_results = []
        llm_text = ""
        logger.debug("开始进行ASR识别")
        while len(asr_results) == 0:
            chunk = json.loads(await asyncio.wait_for(ws.receive_text(), timeout=1))
            if assistant is None:
                # The context manager closes the database connection on exit.
                with get_db_context() as db:
                    assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
                    # Check for a missing session BEFORE touching any attribute:
                    # logging assistant.name first raised AttributeError on None
                    # and made the 504 "session not found" reply unreachable.
                    if assistant is None:
                        raise SessionNotFoundError()
                    logger.debug(f"接收到{assistant.name}的请求")
                    user_info = json.loads(assistant.user_info)
            if not agent:
                agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type'])
                agent.init_recorder(assistant.user_id)
            chunk["audio"] = agent.user_audio_process(chunk["audio"])
            asr_results = await agent.stream_recognize(assistant, chunk)

        # Entry [0] of the ASR results is treated as the child's (primary user's) result.
        kid_text = asr_results[0]['text']
        prompt = agent.prompt_process(asr_results)
        agent.recorder.input_text = prompt

        logger.debug("开始调用大模型")
        llm_frames = await agent.chat(assistant, prompt)
        for llm_frame in llm_frames:
            resp_msgs = agent.llm_msg_process(llm_frame)
            for resp_msg in resp_msgs:
                llm_text += resp_msg
                tts_audio = agent.synthetize(assistant, resp_msg)
                agent.tts_audio_process(tts_audio)
                await ws.send_bytes(agent.encode(resp_msg, tts_audio))
                logger.debug(f'websocket返回:{resp_msg}')
        logger.debug(f"大模型返回结束,返回结果为:{llm_text}")

        await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
        logger.debug("结束帧发送完毕")

        update_messages(assistant, kid_text, llm_text)
        logger.debug("聊天更新成功")

        agent.recorder.output_text = llm_text
        agent.save()
        logger.debug("音频保存成功")
    except EnterSlienceMode:
        # Announce silence mode with synthesized audio, then an info frame.
        tts_audio = agent.synthetize(assistant, "已进入沉默模式")
        await ws.send_bytes(agent.encode("已进入沉默模式", tts_audio))
        await ws.send_text(json.dumps({"type":"info","code":201,"msg":"进入沉默模式"}, ensure_ascii=False))
        logger.debug("进入沉默模式")
    except AsrResultNoneError:
        await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr结果为空"}, ensure_ascii=False))
        logger.error("ASR结果为空")
    except AbnormalLLMFrame as e:
        await ws.send_text(json.dumps({"type":"error","code":502,"msg":str(e)}, ensure_ascii=False))
        logger.error(f"LLM模型返回异常,错误信息:{str(e)}")
    except SideNoiseError:
        await ws.send_text(json.dumps({"type":"error","code":503,"msg":"检测为噪声"}, ensure_ascii=False))
        logger.debug("检测为噪声")
    except SessionNotFoundError:
        await ws.send_text(json.dumps({"type":"error","code":504,"msg":"session不存在"}, ensure_ascii=False))
        logger.error("session不存在")
    except UnknownVolcEngineModelError:
        await ws.send_text(json.dumps({"type":"error","code":505,"msg":"未知的火山引擎模型"}, ensure_ascii=False))
        logger.error("未知的火山引擎模型")
    except LlmResultNoneError:
        await ws.send_text(json.dumps({"type":"error","code":506,"msg":"llm结果返回为空"}, ensure_ascii=False))
        logger.error("LLM结果返回为空")
    except asyncio.TimeoutError:
        await ws.send_text(json.dumps({"type":"error","code":507,"msg":"接收超时"}, ensure_ascii=False))
        logger.error("接收超时")
    logger.debug("WebSocket连接断开")
    logger.debug("")
    await ws.close()
|
||
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||
|
||
# Register CORS handling on the application.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive combination; confirm it is intended before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow every origin; a specific list may be given instead
    allow_credentials=True,
    allow_methods=["*"],  # allow every HTTP method
    allow_headers=["*"],  # allow every request header
)

# Start the server only when this file is executed as a script, so that
# importing the module (tests, `uvicorn module:app`, tooling) does not block
# on a second server start.
if __name__ == "__main__":
    uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)