diff --git a/app/__init__.py b/app/__init__.py index 66d25fb..38297a7 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,12 +1,16 @@ from fastapi import FastAPI, Depends from fastapi.middleware.cors import CORSMiddleware +from contextlib import asynccontextmanager from sqlalchemy import create_engine +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.triggers.cron import CronTrigger from .models import Base from .routes.character_route import router as character_router from .routes.user_route import router as user_router from .routes.session_route import router as session_router from .routes.chat_route import router as chat_router from .dependencies.logger import get_logger +from .background_tasks import updating_redis_cache from config import get_config @@ -28,8 +32,20 @@ logger.info("数据库初始化完成") #------------------------------------------------------ +#--------------------设置定时任务----------------------- +scheduler = AsyncIOScheduler() +scheduler.add_job(updating_redis_cache, CronTrigger.from_crontab("0 4 * * *")) +@asynccontextmanager +async def lifespan(app:FastAPI): + scheduler.start() #启动定时任务 + yield + scheduler.shutdown() #关闭定时任务 +logger.info("定时任务初始化完成") +#------------------------------------------------------ + + #--------------------创建FastAPI实例-------------------- -app = FastAPI() +app = FastAPI(lifespan=lifespan) logger.info("FastAPI实例创建完成") #------------------------------------------------------ @@ -42,6 +58,7 @@ app.include_router(chat_router) logger.info("路由初始化完成") #------------------------------------------------------- + #-------------------设置跨域中间件----------------------- app.add_middleware( CORSMiddleware, diff --git a/app/background_tasks.py b/app/background_tasks.py new file mode 100644 index 0000000..f757494 --- /dev/null +++ b/app/background_tasks.py @@ -0,0 +1,31 @@ +from .dependencies.database import get_db +from .dependencies.redis import get_redis +from .dependencies.logger import get_logger +from .models import Session +from datetime import datetime, timedelta + +logger = get_logger() + +def updating_redis_cache(): + db = next(get_db()) + redis = get_redis() + current_time = datetime.now() + two_days_ago = current_time - timedelta(days=2) + db.begin() + try: + expired_sessions = db.query(Session.id).filter(Session.last_activity < two_days_ago).all() + for session_id_tuple in expired_sessions: + session_id = session_id_tuple[0] + session_content = redis.get(session_id) + if session_content is None: + logger.error(f"Session {session_id} not found in Redis cache") + continue + redis.delete(session_id) + db.query(Session).filter(Session.id == session_id).update({'content':session_content}) + db.commit() + except Exception as e: + db.rollback() + raise e + finally: + db.close() + logger.info("Redis cache updated successfully") \ No newline at end of file diff --git a/app/controllers/character_controller.py b/app/controllers/character_controller.py index cba4666..04a3cce 100644 --- a/app/controllers/character_controller.py +++ b/app/controllers/character_controller.py @@ -42,7 +42,6 @@ async def update_character_handler(character_id: int, character: CharacterUpdate existing_character.dialogues = character.dialogues try: - db.add(existing_character) db.commit() except Exception as e: db.rollback() diff --git a/app/controllers/chat_controller.py b/app/controllers/chat_controller.py index 8f8c8f5..4d06359 100644 --- a/app/controllers/chat_controller.py +++ b/app/controllers/chat_controller.py @@ -89,13 +89,20 @@ def vad_preprocess(audio): return ('A'*1280) return audio[:1280],audio[1280:] +#更新session活跃时间 +def update_session_activity(session_id,db): + try: + db.query(Session).filter(Session.id == session_id).update({"last_activity": datetime.now()}) + db.commit() + except Exception as e: + db.roolback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) #-------------------------------------------------------- # 创建新聊天 async def create_chat_handler(chat: ChatCreateRequest, db, redis): # 创建新的UserCharacter记录 new_chat = UserCharacter(user_id=chat.user_id, character_id=chat.character_id) - try: db.add(new_chat) db.commit() @@ -279,8 +286,6 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis logger.debug(f"llm返回结果: {llm_response}") await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False)) is_end = False #重置is_end标志位 - session_content = get_session_content(session_id,redis,db) - messages = json.loads(session_content["messages"]) messages.append({'role': 'assistant', "content": llm_response}) session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session @@ -304,6 +309,7 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis): session_id = await future_session_id #获取session_id + update_session_activity(session_id,db) response_type = await future_response_type #获取返回类型 tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"]) llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"]) @@ -425,8 +431,6 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False)) is_end = False - session_content = get_session_content(session_id,redis,db) - messages = json.loads(session_content["messages"]) messages.append({'role': 'assistant', "content": llm_response}) session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session @@ -454,6 +458,7 @@ async def streaming_chat_lasting_handler(ws,db,redis): asyncio.create_task(scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event)) session_id = await future_session_id #获取session_id + update_session_activity(session_id,db) response_type = await future_response_type #获取返回类型 tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"]) llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"]) @@ -575,8 +580,6 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False)) is_end = False - session_content = get_session_content(session_id,redis,db) - messages = json.loads(session_content["messages"]) messages.append({'role': 'assistant', "content": llm_response}) session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session @@ -622,6 +625,7 @@ async def voice_call_handler(ws, db, redis): #获取session内容 session_id = await future #获取session_id + update_session_activity(session_id,db) tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"]) llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"]) diff --git a/app/controllers/user_controller.py b/app/controllers/user_controller.py index 43da3ce..a726d0b 100644 --- a/app/controllers/user_controller.py +++ b/app/controllers/user_controller.py @@ -35,7 +35,6 @@ async def update_user_handler(user_id:int, user:UserUpdateRequest, db: Session): existing_user.tags = user.tags existing_user.persona = user.persona try: - db.add(existing_user) db.commit() db.refresh(existing_user) except Exception as e: @@ -117,7 +116,6 @@ async def change_bind_hardware_handler(hardware_id, user, db): if existing_hardware is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在") existing_hardware.user_id = user.user_id - db.add(existing_hardware) db.commit() db.refresh(existing_hardware) except Exception as e: @@ -136,7 +134,6 @@ async def update_hardware_handler(hardware_id, hardware, db): existing_hardware.mac = hardware.mac existing_hardware.firmware = hardware.firmware existing_hardware.model = hardware.model - db.add(existing_hardware) db.commit() db.refresh(existing_hardware) except Exception as e: diff --git a/app/routes/user_route.py b/app/routes/user_route.py index 995f615..a0bc6a7 100644 --- a/app/routes/user_route.py +++ b/app/routes/user_route.py @@ -8,28 +8,28 @@ from ..dependencies.database import get_db router = APIRouter() -#角色创建接口 +#用户创建接口 @router.post('/users', response_model=UserCrateResponse) async def create_user(user: UserCrateRequest,db: Session = Depends(get_db)): response = await create_user_handler(user,db) return response -#角色更新接口 +#用户更新接口 @router.put('/users/{user_id}', response_model=UserUpdateResponse) async def update_user(user_id: int, user: UserUpdateRequest, db: Session = Depends(get_db)): response = await update_user_handler(user_id, user, db) return response -#角色删除接口 +#用户删除接口 @router.delete('/users/{user_id}', response_model=UserDeleteResponse) async def delete_user(user_id: int, db: Session = Depends(get_db)): response = await delete_user_handler(user_id, db) return response -#角色查询接口 +#用户查询接口 @router.get('/users/{user_id}', response_model=UserQueryResponse) async def get_user(user_id: int, db: Session = Depends(get_db)): response = await get_user_handler(user_id, db) diff --git a/requirements.txt b/requirements.txt index 171ef6a..ffb082f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,4 @@ torch numba soundfile webrtcvad +apscheduler \ No newline at end of file