From 065528f509784af498a5f6fc8a83348e062c66fe Mon Sep 17 00:00:00 2001 From: killua4396 <1223086337@qq.com> Date: Tue, 7 May 2024 11:26:59 +0800 Subject: [PATCH] =?UTF-8?q?1.=E6=B7=BB=E5=8A=A0=E5=AE=9A=E6=97=B6=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1(=E6=AF=8F=E5=A4=A9=E5=87=8C=E6=99=A8=E5=9B=9B?= =?UTF-8?q?=E7=82=B9=E5=A4=84=E7=90=86redis=E8=AE=B0=E5=BD=95=E7=9A=84?= =?UTF-8?q?=E5=88=A0=E9=99=A4=EF=BC=8C=E5=AD=98=E5=82=A8=E5=85=A5mysql?= =?UTF-8?q?=E6=8C=81=E4=B9=85=E5=8C=96=E6=93=8D=E4=BD=9C)=202.=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E4=BA=86prompt=E4=B8=AD=E6=B2=A1=E6=9C=89=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E8=BE=93=E5=85=A5=E4=BF=A1=E6=81=AF=E7=9A=84bug=203.?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BA=86=E6=9B=B4=E6=96=B0=E7=B1=BB=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E4=BB=A3=E7=A0=81(=E5=88=A0=E9=99=A4=E4=BA=86db.add()?= =?UTF-8?q?=E4=BB=A3=E7=A0=81)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/__init__.py | 19 ++++++++++++++- app/background_tasks.py | 31 +++++++++++++++++++++++++ app/controllers/character_controller.py | 1 - app/controllers/chat_controller.py | 18 ++++++++------ app/controllers/user_controller.py | 3 --- app/routes/user_route.py | 8 +++---- requirements.txt | 1 + 7 files changed, 65 insertions(+), 16 deletions(-) create mode 100644 app/background_tasks.py diff --git a/app/__init__.py b/app/__init__.py index 66d25fb..38297a7 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,12 +1,16 @@ from fastapi import FastAPI, Depends from fastapi.middleware.cors import CORSMiddleware +from contextlib import asynccontextmanager from sqlalchemy import create_engine +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.triggers.cron import CronTrigger from .models import Base from .routes.character_route import router as character_router from .routes.user_route import router as user_router from .routes.session_route import router as session_router from .routes.chat_route import router as chat_router from .dependencies.logger import get_logger +from .background_tasks import updating_redis_cache from config import get_config @@ -28,8 +32,20 @@ logger.info("数据库初始化完成") #------------------------------------------------------ +#--------------------设置定时任务----------------------- +scheduler = AsyncIOScheduler() +scheduler.add_job(updating_redis_cache, CronTrigger.from_crontab("0 4 * * *")) +@asynccontextmanager +async def lifespan(app:FastAPI): + scheduler.start() #启动定时任务 + yield + scheduler.shutdown() #关闭定时任务 +logger.info("定时任务初始化完成") +#------------------------------------------------------ + + #--------------------创建FastAPI实例-------------------- -app = FastAPI() +app = FastAPI(lifespan=lifespan) logger.info("FastAPI实例创建完成") #------------------------------------------------------ @@ -42,6 +58,7 @@ app.include_router(chat_router) logger.info("路由初始化完成") #------------------------------------------------------- + #-------------------设置跨域中间件----------------------- app.add_middleware( CORSMiddleware, diff --git a/app/background_tasks.py b/app/background_tasks.py new file mode 100644 index 0000000..f757494 --- /dev/null +++ b/app/background_tasks.py @@ -0,0 +1,31 @@ +from .dependencies.database import get_db +from .dependencies.redis import get_redis +from .dependencies.logger import get_logger +from .models import Session +from datetime import datetime, timedelta + +logger = get_logger() + +def updating_redis_cache(): + db = next(get_db()) + redis = get_redis() + current_time = datetime.now() + two_days_ago = current_time - timedelta(days=2) + db.begin() + try: + expired_sessions = db.query(Session.id).filter(Session.last_activity < two_days_ago).all() + for session_id_tuple in expired_sessions: + session_id = session_id_tuple[0] + session_content = redis.get(session_id) + if session_content is None: + logger.error(f"Session {session_id} not found in Redis cache") + continue + redis.delete(session_id) + db.query(Session).filter(Session.id == session_id).update({'content':session_content}) + db.commit() + except Exception as e: + db.rollback() + raise e + finally: + db.close() + logger.info("Redis cache updated successfully") \ No newline at end of file diff --git a/app/controllers/character_controller.py b/app/controllers/character_controller.py index cba4666..04a3cce 100644 --- a/app/controllers/character_controller.py +++ b/app/controllers/character_controller.py @@ -42,7 +42,6 @@ async def update_character_handler(character_id: int, character: CharacterUpdate existing_character.dialogues = character.dialogues try: - db.add(existing_character) db.commit() except Exception as e: db.rollback() diff --git a/app/controllers/chat_controller.py b/app/controllers/chat_controller.py index 8f8c8f5..4d06359 100644 --- a/app/controllers/chat_controller.py +++ b/app/controllers/chat_controller.py @@ -89,13 +89,20 @@ def vad_preprocess(audio): return ('A'*1280) return audio[:1280],audio[1280:] +#更新session活跃时间 +def update_session_activity(session_id,db): + try: + db.query(Session).filter(Session.id == session_id).update({"last_activity": datetime.now()}) + db.commit() + except Exception as e: + db.roolback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) #-------------------------------------------------------- # 创建新聊天 async def create_chat_handler(chat: ChatCreateRequest, db, redis): # 创建新的UserCharacter记录 new_chat = UserCharacter(user_id=chat.user_id, character_id=chat.character_id) - try: db.add(new_chat) db.commit() @@ -279,8 +286,6 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis logger.debug(f"llm返回结果: {llm_response}") await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False)) is_end = False #重置is_end标志位 - session_content = get_session_content(session_id,redis,db) - messages = json.loads(session_content["messages"]) messages.append({'role': 'assistant', "content": llm_response}) session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session @@ -304,6 +309,7 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis): session_id = await future_session_id #获取session_id + update_session_activity(session_id,db) response_type = await future_response_type #获取返回类型 tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"]) llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"]) @@ -425,8 +431,6 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False)) is_end = False - session_content = get_session_content(session_id,redis,db) - messages = json.loads(session_content["messages"]) messages.append({'role': 'assistant', "content": llm_response}) session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session @@ -454,6 +458,7 @@ async def streaming_chat_lasting_handler(ws,db,redis): asyncio.create_task(scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event)) session_id = await future_session_id #获取session_id + update_session_activity(session_id,db) response_type = await future_response_type #获取返回类型 tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"]) llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"]) @@ -575,8 +580,6 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False)) is_end = False - session_content = get_session_content(session_id,redis,db) - messages = json.loads(session_content["messages"]) messages.append({'role': 'assistant', "content": llm_response}) session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session @@ -622,6 +625,7 @@ async def voice_call_handler(ws, db, redis): #获取session内容 session_id = await future #获取session_id + update_session_activity(session_id,db) tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"]) llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"]) diff --git a/app/controllers/user_controller.py b/app/controllers/user_controller.py index 43da3ce..a726d0b 100644 --- a/app/controllers/user_controller.py +++ b/app/controllers/user_controller.py @@ -35,7 +35,6 @@ async def update_user_handler(user_id:int, user:UserUpdateRequest, db: Session): existing_user.tags = user.tags existing_user.persona = user.persona try: - db.add(existing_user) db.commit() db.refresh(existing_user) except Exception as e: @@ -117,7 +116,6 @@ async def change_bind_hardware_handler(hardware_id, user, db): if existing_hardware is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在") existing_hardware.user_id = user.user_id - db.add(existing_hardware) db.commit() db.refresh(existing_hardware) except Exception as e: @@ -136,7 +134,6 @@ async def update_hardware_handler(hardware_id, hardware, db): existing_hardware.mac = hardware.mac existing_hardware.firmware = hardware.firmware existing_hardware.model = hardware.model - db.add(existing_hardware) db.commit() db.refresh(existing_hardware) except Exception as e: diff --git a/app/routes/user_route.py b/app/routes/user_route.py index 995f615..a0bc6a7 100644 --- a/app/routes/user_route.py +++ b/app/routes/user_route.py @@ -8,28 +8,28 @@ from ..dependencies.database import get_db router = APIRouter() -#角色创建接口 +#用户创建接口 @router.post('/users', response_model=UserCrateResponse) async def create_user(user: UserCrateRequest,db: Session = Depends(get_db)): response = await create_user_handler(user,db) return response -#角色更新接口 +#用户更新接口 @router.put('/users/{user_id}', response_model=UserUpdateResponse) async def update_user(user_id: int, user: UserUpdateRequest, db: Session = Depends(get_db)): response = await update_user_handler(user_id, user, db) return response -#角色删除接口 +#用户删除接口 @router.delete('/users/{user_id}', response_model=UserDeleteResponse) async def delete_user(user_id: int, db: Session = Depends(get_db)): response = await delete_user_handler(user_id, db) return response -#角色查询接口 +#用户查询接口 @router.get('/users/{user_id}', response_model=UserQueryResponse) async def get_user(user_id: int, db: Session = Depends(get_db)): response = await get_user_handler(user_id, db) diff --git a/requirements.txt b/requirements.txt index 171ef6a..ffb082f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,4 @@ torch numba soundfile webrtcvad +apscheduler \ No newline at end of file