from fastapi import FastAPI, Depends, WebSocket, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from config import Config
from app.concrete import Agent
from app.model import Assistant, User, get_db, get_db_context
from app.schemas import *
from app.dependency import get_logger
from app.exception import *
import asyncio
import uvicorn
import uuid
import json
import time


# Shared helpers -----------------------------------------------------------------------------------

def update_messages(assistant, kid_text, llm_text):
    """Append one user/assistant exchange to an assistant's stored chat history.

    ``assistant.messages`` is parsed as JSON, the user turn and the assistant
    turn are appended, and the result is written back (together with the
    current ``assistant.token``) inside a ``get_db_context()`` block.

    Args:
        assistant: Assistant row; its ``messages``, ``token`` and ``id`` are read.
        kid_text: Recognized user text for this turn.
        llm_text: Model reply text for this turn.

    Raises:
        AsrResultNoneError: if ``kid_text`` is empty.
        LlmResultNoneError: if ``llm_text`` is empty.
    """
    history = json.loads(assistant.messages)
    if not kid_text:
        raise AsrResultNoneError()
    if not llm_text:
        raise LlmResultNoneError()
    history.extend(
        [
            {"role": "user", "content": kid_text},
            {"role": "assistant", "content": llm_text},
        ]
    )
    with get_db_context() as session:
        row_filter = Assistant.id == assistant.id
        changes = {
            "messages": json.dumps(history, ensure_ascii=False),
            "token": assistant.token,
        }
        session.query(Assistant).filter(row_filter).update(changes)
        session.commit()

# --------------------------------------------------------------------------------------------------

# Module-wide logger
logger = get_logger()

# FastAPI application instance
app = FastAPI()


# Assistant CRUD endpoints -------------------------------------------------------------------------

@app.post("/api/assistants")
async def create_assistant(request: create_assistant_request,db=Depends(get_db)):
    """Create an assistant whose history is seeded with its system prompt; return its new id."""
    new_id = str(uuid.uuid4())
    seed_history = [{"role": "system", "content": request.system_prompt}]
    encoded_history = json.dumps(seed_history, ensure_ascii=False)
    record = Assistant(
        id=new_id,
        user_id=request.user_id,
        name=request.name,
        system_prompt=request.system_prompt,
        messages=encoded_history,
        user_info=request.user_info,
        llm_info=request.llm_info,
        tts_info=request.tts_info,
        token=0,
    )
    db.add(record)
    db.commit()
    return {"code": 200, "msg": "success", "data": {"id": new_id}}


@app.delete("/api/assistants/{id}")
async def delete_assistant(id: str,db=Depends(get_db)):
    """Delete an assistant; answers with a code-404 body (HTTP 200) when it does not exist."""
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    db.delete(target)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}


@app.put("/api/assistants/{id}")
async def update_assistant(id: str,request: update_assistant_request,db=Depends(get_db)):
    """Overwrite an assistant's editable fields with the values from the request."""
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    # Copy the request fields one by one, in the same order as the original.
    for field_name in ("name", "system_prompt", "messages", "user_info", "llm_info", "tts_info"):
        setattr(target, field_name, getattr(request, field_name))
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}


@app.get("/api/assistants/{id}")
async def get_assistant(id: str,db=Depends(get_db)):
    """Return the full assistant row, or a code-404 body when it does not exist."""
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    return {"code": 200, "msg": "success", "data": target}


@app.get("/api/assistants")
async def get_all_assistants_name_id(db=Depends(get_db)):
    """Return the id and name of every assistant."""
    rows = db.query(Assistant.id, Assistant.name).all()
    summaries = [{"id": row.id, "name": row.name} for row in rows]
    return {"code": 200, "msg": "success", "data": summaries}


@app.post("/api/assistants/{id}/reset_msg")
async def reset_assistant_msg(id: str,db=Depends(get_db)):
    """Reset an assistant's history to a single system message built from its system prompt."""
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    seed_history = [{"role": "system", "content": target.system_prompt}]
    target.messages = json.dumps(seed_history, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}


# Update an assistant's system_prompt
@app.put("/api/assistants/{id}/system_prompt")
async def update_assistant_system_prompt(id: str,request: update_assistant_system_prompt_request,db=Depends(get_db)):
    """Replace an assistant's system prompt and reset its history to just that prompt."""
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    target.system_prompt = request.system_prompt
    # The history restarts from the new system message; prior turns are discarded.
    fresh_history = [{"role": "system", "content": target.system_prompt}]
    target.messages = json.dumps(fresh_history, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}


# Update individual model / TTS / platform parameters
@app.put("/api/assistants/{id}/deatil_params")
async def update_assistant_deatil_params(id: str,request: update_assistant_deatil_params_request,db=Depends(get_db)):
    """Patch selected keys inside the assistant's JSON-encoded llm_info, tts_info and user_info.

    Only ``model``/``temperature`` (llm_info), ``speaker_id``/``length_scale``
    (tts_info) and ``llm_type`` (user_info, taken from ``request.platform``) are
    overwritten. All other keys in those JSON objects are preserved.
    """
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}

    llm_cfg = json.loads(target.llm_info)
    tts_cfg = json.loads(target.tts_info)
    user_cfg = json.loads(target.user_info)

    llm_cfg["model"] = request.model
    llm_cfg["temperature"] = request.temperature
    tts_cfg["speaker_id"] = request.speaker_id
    tts_cfg["length_scale"] = request.length_scale
    user_cfg["llm_type"] = request.platform

    target.llm_info = json.dumps(llm_cfg, ensure_ascii=False)
    target.tts_info = json.dumps(tts_cfg, ensure_ascii=False)
    target.user_info = json.dumps(user_cfg, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}


# Update max_tokens
@app.put("/api/assistants/{id}/max_tokens")
async def update_assistant_max_tokens(id: str,request: update_assistant_max_tokens_request,db=Depends(get_db)):
    """Set ``max_tokens`` inside the assistant's JSON-encoded llm_info, keeping the other keys."""
    target = db.query(Assistant).filter(Assistant.id == id).first()
    if not target:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    llm_cfg = json.loads(target.llm_info)
    llm_cfg["max_tokens"] = request.max_tokens
    target.llm_info = json.dumps(llm_cfg, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}

# --------------------------------------------------------------------------------------------------
# User CRUD endpoints ------------------------------------------------------------------------------
# Unlike the assistant endpoints (which answer HTTP 200 with "code": 404 in the body), these raise
# HTTPException(404) for missing users.
# NOTE(review): passwords are stored exactly as received (no hashing visible here), and get_user /
# get_all_users return whole User rows, which presumably include the password column — verify.

@app.post("/api/users")
async def create_user(request: create_user_request,db=Depends(get_db)):
    """Create a user and return its freshly generated id."""
    user_id = str(uuid.uuid4())
    db.add(User(id=user_id, name=request.name, email=request.email, password=request.password))
    db.commit()
    return {"code": 200, "msg": "success", "data": {"id": user_id}}


@app.delete("/api/users/{id}")
async def delete_user(id: str,db=Depends(get_db)):
    """Delete a user; raises HTTP 404 when it does not exist."""
    found = db.query(User).filter(User.id == id).first()
    if not found:
        raise HTTPException(status_code=404, detail="user not found")
    db.delete(found)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}


@app.get("/api/users/{id}")
async def get_user(id: str,db=Depends(get_db)):
    """Return a single user row; raises HTTP 404 when it does not exist."""
    found = db.query(User).filter(User.id == id).first()
    if not found:
        raise HTTPException(status_code=404, detail="user not found")
    return {"code": 200, "msg": "success", "data": found}


@app.get("/api/users")
async def get_all_users(db=Depends(get_db)):
    """Return every user row."""
    every_user = db.query(User).all()
    return {"code": 200, "msg": "success", "data": every_user}


@app.put("/api/users/{id}")
async def update_user(id: str,request: update_user_request,db=Depends(get_db)):
    """Overwrite a user's name, email and password; raises HTTP 404 when it does not exist."""
    found = db.query(User).filter(User.id == id).first()
    if not found:
        raise HTTPException(status_code=404, detail="user not found")
    for field_name in ("name", "email", "password"):
        setattr(found, field_name, getattr(request, field_name))
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}

# --------------------------------------------------------------------------------------------------

# Streaming chat WebSocket endpoint
# --------------------------------------------------------------------------------------------------
@app.websocket("/api/chat/streaming/temporary")
async def streaming_chat(ws: WebSocket):
    """Run one ASR -> LLM -> TTS round trip over a WebSocket connection.

    Flow, as implemented below:
      1. The client sends JSON text frames, and each frame must arrive within 1 s.
         The first frame's ``meta_info.session_id`` selects the Assistant row.
         That row's ``user_info`` (asr_type / llm_type / tts_type) builds the
         ``Agent``. Frames are fed to ``agent.stream_recognize`` until it returns
         a non-empty result list.
      2. The prompt built from the ASR results is sent to ``agent.chat``. Each
         processed LLM text fragment is synthesized and sent back as a binary
         frame produced by ``agent.encode``.
      3. A ``{"type": "close", "code": 200}`` text frame marks the end of the
         reply. The exchange is then persisted with ``update_messages`` and the
         agent's recording is saved.

    Known failures are reported to the client as a JSON text frame. Silence mode
    sends type "info", code 201. Errors send type "error", codes 501-507.
    The socket is then closed. Other exceptions propagate unhandled.
    """
    await ws.accept()
    logger.debug("WebSocket连接成功")
    # Both are created lazily on the first received frame.
    agent = None
    assistant = None
    try:
        asr_results = []
        llm_text = ""
        logger.debug("开始进行ASR识别")
        while len(asr_results) == 0:
            chunk = json.loads(await asyncio.wait_for(ws.receive_text(), timeout=1))
            if assistant is None:
                # The context manager releases the DB connection when the block exits.
                with get_db_context() as db:
                    assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
                    # Only dereference after confirming the row exists. Logging
                    # assistant.name unconditionally used to raise AttributeError
                    # for an unknown session, so SessionNotFoundError was never sent.
                    if assistant is not None:
                        logger.debug(f"接收到{assistant.name}的请求")
                if assistant is None:
                    raise SessionNotFoundError()
                user_info = json.loads(assistant.user_info)
            if not agent:
                agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type'])
                agent.init_recorder(assistant.user_id)
            chunk["audio"] = agent.user_audio_process(chunk["audio"])
            asr_results = await agent.stream_recognize(assistant, chunk)

        # asr_results[0] is taken as the main speaker's (the kid's) recognition result.
        kid_text = asr_results[0]['text']
        prompt = agent.prompt_process(asr_results)
        agent.recorder.input_text = prompt

        logger.debug("开始调用大模型")
        # Start the clock before the chat call so the first-frame latency below
        # includes the request itself, whether agent.chat awaits the request or
        # returns a lazy iterator.
        start_time = time.time()
        llm_frames = await agent.chat(assistant, prompt)
        is_first_response = True
        for llm_frame in llm_frames:
            if is_first_response:
                end_time = time.time()
                logger.debug(f"第一帧返回耗时:{round(end_time-start_time,3)}s")
                is_first_response = False
            resp_msgs = agent.llm_msg_process(llm_frame)
            for resp_msg in resp_msgs:
                llm_text += resp_msg
                tts_audio = agent.synthetize(assistant, resp_msg)
                agent.tts_audio_process(tts_audio)
                await ws.send_bytes(agent.encode(resp_msg, tts_audio))
                logger.debug(f'websocket返回:{resp_msg}')
        logger.debug(f"大模型返回结束,返回结果为:{llm_text}")

        await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
        logger.debug("结束帧发送完毕")
        # update_messages raises AsrResultNoneError / LlmResultNoneError for empty
        # texts. Because the close frame was already sent, the client can then
        # receive an error frame after it.
        update_messages(assistant, kid_text, llm_text)
        logger.debug("聊天更新成功")
        agent.recorder.output_text = llm_text
        agent.save()
        logger.debug("音频保存成功")
    except EnterSlienceMode:
        tts_audio = agent.synthetize(assistant, "已进入沉默模式")
        await ws.send_bytes(agent.encode("已进入沉默模式", tts_audio))
        await ws.send_text(json.dumps({"type":"info","code":201,"msg":"进入沉默模式"}, ensure_ascii=False))
        logger.debug("进入沉默模式")
    except AsrResultNoneError:
        await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr结果为空"}, ensure_ascii=False))
        logger.error("ASR结果为空")
    except AbnormalLLMFrame as e:
        await ws.send_text(json.dumps({"type":"error","code":502,"msg":str(e)}, ensure_ascii=False))
        logger.error(f"LLM模型返回异常,错误信息:{str(e)}")
    except SideNoiseError:
        await ws.send_text(json.dumps({"type":"error","code":503,"msg":"检测为噪声"}, ensure_ascii=False))
        logger.debug("检测为噪声")
    except SessionNotFoundError:
        await ws.send_text(json.dumps({"type":"error","code":504,"msg":"session不存在"}, ensure_ascii=False))
        logger.error("session不存在")
    except UnknownVolcEngineModelError:
        await ws.send_text(json.dumps({"type":"error","code":505,"msg":"未知的火山引擎模型"}, ensure_ascii=False))
        logger.error("未知的火山引擎模型")
    except LlmResultNoneError:
        await ws.send_text(json.dumps({"type":"error","code":506,"msg":"llm结果返回为空"}, ensure_ascii=False))
        logger.error("LLM结果返回为空")
    except asyncio.TimeoutError:
        # Raised by asyncio.wait_for when no frame arrives within the 1 s timeout.
        await ws.send_text(json.dumps({"type":"error","code":507,"msg":"接收超时"}, ensure_ascii=False))
        logger.error("接收超时")
    logger.debug("WebSocket连接断开")
    logger.debug("")
    await ws.close()

# --------------------------------------------------------------------------------------------------

# NOTE(review): a wildcard origin combined with allow_credentials=True lets any site make
# credentialed cross-origin requests. Consider listing the trusted origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow every origin; list specific origins here to restrict
    allow_credentials=True,
    allow_methods=["*"],  # allow every HTTP method
    allow_headers=["*"],  # allow every request header
)

# Start the server.
# NOTE(review): this runs at import time (there is no `if __name__ == "__main__":` guard), so
# importing this module starts serving on Config.UVICORN.HOST / PORT.
uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)