1
0
Fork 0

1.添加定时任务(每天凌晨四点处理redis记录的删除,存储入mysql持久化操作)

2.修复了prompt中没有用户输入信息的bug
3.优化了更新类接口代码(删除了db.add()代码)
This commit is contained in:
killua4396 2024-05-07 11:26:59 +08:00
parent de370f77ac
commit 065528f509
7 changed files with 65 additions and 16 deletions

View File

@ -1,12 +1,16 @@
from fastapi import FastAPI, Depends
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from sqlalchemy import create_engine
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from .models import Base
from .routes.character_route import router as character_router
from .routes.user_route import router as user_router
from .routes.session_route import router as session_router
from .routes.chat_route import router as chat_router
from .dependencies.logger import get_logger
from .background_tasks import updating_redis_cache
from config import get_config
@ -28,8 +32,20 @@ logger.info("数据库初始化完成")
#------------------------------------------------------
#--------------------设置定时任务-----------------------
scheduler = AsyncIOScheduler()
scheduler.add_job(updating_redis_cache, CronTrigger.from_crontab("0 4 * * *"))
@asynccontextmanager
async def lifespan(app:FastAPI):
scheduler.start() #启动定时任务
yield
scheduler.shutdown() #关闭定时任务
logger.info("定时任务初始化完成")
#------------------------------------------------------
#--------------------创建FastAPI实例--------------------
app = FastAPI()
app = FastAPI(lifespan=lifespan)
logger.info("FastAPI实例创建完成")
#------------------------------------------------------
@ -42,6 +58,7 @@ app.include_router(chat_router)
logger.info("路由初始化完成")
#-------------------------------------------------------
#-------------------设置跨域中间件-----------------------
app.add_middleware(
CORSMiddleware,

31
app/background_tasks.py Normal file
View File

@ -0,0 +1,31 @@
from .dependencies.database import get_db
from .dependencies.redis import get_redis
from .dependencies.logger import get_logger
from .models import Session
from datetime import datetime, timedelta
logger = get_logger()
def updating_redis_cache():
db = next(get_db())
redis = get_redis()
current_time = datetime.now()
two_days_ago = current_time - timedelta(days=2)
db.begin()
try:
expired_sessions = db.query(Session.id).filter(Session.last_activity < two_days_ago).all()
for session_id_tuple in expired_sessions:
session_id = session_id_tuple[0]
session_content = redis.get(session_id)
if session_content is None:
logger.error(f"Session {session_id} not found in Redis cache")
continue
redis.delete(session_id)
db.query(Session).filter(Session.id == session_id).update({'content':session_content})
db.commit()
except Exception as e:
db.rollback()
raise e
finally:
db.close()
logger.info("Redis cache updated successfully")

View File

@ -42,7 +42,6 @@ async def update_character_handler(character_id: int, character: CharacterUpdate
existing_character.dialogues = character.dialogues
try:
db.add(existing_character)
db.commit()
except Exception as e:
db.rollback()

View File

@ -89,13 +89,20 @@ def vad_preprocess(audio):
return ('A'*1280)
return audio[:1280],audio[1280:]
#更新session活跃时间
def update_session_activity(session_id,db):
try:
db.query(Session).filter(Session.id == session_id).update({"last_activity": datetime.now()})
db.commit()
except Exception as e:
db.roolback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
#--------------------------------------------------------
# 创建新聊天
async def create_chat_handler(chat: ChatCreateRequest, db, redis):
# 创建新的UserCharacter记录
new_chat = UserCharacter(user_id=chat.user_id, character_id=chat.character_id)
try:
db.add(new_chat)
db.commit()
@ -279,8 +286,6 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
logger.debug(f"llm返回结果: {llm_response}")
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
is_end = False #重置is_end标志位
session_content = get_session_content(session_id,redis,db)
messages = json.loads(session_content["messages"])
messages.append({'role': 'assistant', "content": llm_response})
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
@ -304,6 +309,7 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
session_id = await future_session_id #获取session_id
update_session_activity(session_id,db)
response_type = await future_response_type #获取返回类型
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
@ -425,8 +431,6 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
is_end = False
session_content = get_session_content(session_id,redis,db)
messages = json.loads(session_content["messages"])
messages.append({'role': 'assistant', "content": llm_response})
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
@ -454,6 +458,7 @@ async def streaming_chat_lasting_handler(ws,db,redis):
asyncio.create_task(scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event))
session_id = await future_session_id #获取session_id
update_session_activity(session_id,db)
response_type = await future_response_type #获取返回类型
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
@ -575,8 +580,6 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
is_end = False
session_content = get_session_content(session_id,redis,db)
messages = json.loads(session_content["messages"])
messages.append({'role': 'assistant', "content": llm_response})
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
@ -622,6 +625,7 @@ async def voice_call_handler(ws, db, redis):
#获取session内容
session_id = await future #获取session_id
update_session_activity(session_id,db)
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])

View File

@ -35,7 +35,6 @@ async def update_user_handler(user_id:int, user:UserUpdateRequest, db: Session):
existing_user.tags = user.tags
existing_user.persona = user.persona
try:
db.add(existing_user)
db.commit()
db.refresh(existing_user)
except Exception as e:
@ -117,7 +116,6 @@ async def change_bind_hardware_handler(hardware_id, user, db):
if existing_hardware is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
existing_hardware.user_id = user.user_id
db.add(existing_hardware)
db.commit()
db.refresh(existing_hardware)
except Exception as e:
@ -136,7 +134,6 @@ async def update_hardware_handler(hardware_id, hardware, db):
existing_hardware.mac = hardware.mac
existing_hardware.firmware = hardware.firmware
existing_hardware.model = hardware.model
db.add(existing_hardware)
db.commit()
db.refresh(existing_hardware)
except Exception as e:

View File

@ -8,28 +8,28 @@ from ..dependencies.database import get_db
router = APIRouter()
#角色创建接口
#用户创建接口
@router.post('/users', response_model=UserCrateResponse)
async def create_user(user: UserCrateRequest,db: Session = Depends(get_db)):
response = await create_user_handler(user,db)
return response
#角色更新接口
#用户更新接口
@router.put('/users/{user_id}', response_model=UserUpdateResponse)
async def update_user(user_id: int, user: UserUpdateRequest, db: Session = Depends(get_db)):
response = await update_user_handler(user_id, user, db)
return response
#角色删除接口
#用户删除接口
@router.delete('/users/{user_id}', response_model=UserDeleteResponse)
async def delete_user(user_id: int, db: Session = Depends(get_db)):
response = await delete_user_handler(user_id, db)
return response
#角色查询接口
#用户查询接口
@router.get('/users/{user_id}', response_model=UserQueryResponse)
async def get_user(user_id: int, db: Session = Depends(get_db)):
response = await get_user_handler(user_id, db)

View File

@ -18,3 +18,4 @@ torch
numba
soundfile
webrtcvad
apscheduler