TakwayPlatform/app/controllers/session.py

80 lines
3.5 KiB
Python
Raw Normal View History

2024-05-01 17:18:30 +08:00
from ..schemas.session import *
from ..dependencies.logger import get_logger
from fastapi import HTTPException, status
from ..models import Session
from ..models import UserCharacter
from datetime import datetime
import json
# Module-level logger, obtained through the project's logger dependency helper.
# NOTE(review): none of the handlers below reference `logger`; confirm it is
# still needed.
logger = get_logger()
async def get_session_id_handler(user_id: int, character_id: int, db):
    """Return the ID of the session bound to a (user, character) pair.

    Args:
        user_id: ID of the user.
        character_id: ID of the character.
        db: Database session exposing ``query``/``rollback`` (SQLAlchemy-style
            usage is assumed from the calls below).

    Returns:
        SessionIdQueryResponse wrapping a SessionIdQueryData with the session ID.

    Raises:
        HTTPException: 400 if a database query fails (the transaction is rolled
            back first); 404 if no UserCharacter record exists for the pair, or
            if no Session exists for that record.
    """
    try:
        user_character_record = (
            db.query(UserCharacter)
            .filter(
                UserCharacter.user_id == user_id,
                UserCharacter.character_id == character_id,
            )
            .first()
        )
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
    if not user_character_record:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User Character not found")

    try:
        session_record = (
            db.query(Session)
            .filter(Session.user_character_id == user_character_record.id)
            .first()
        )
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
    # .first() yields None when nothing matches. Reading `.id` on it used to
    # raise AttributeError, which the except above reported as a 400.
    # A missing session is a 404.
    if session_record is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")

    session_id_query_data = SessionIdQueryData(session_id=session_record.id)
    return SessionIdQueryResponse(status="success", message="Session ID 获取成功", data=session_id_query_data)
async def get_session_handler(session_id: int, db, redis):
    """Return a session's stored state, preferring the Redis cache.

    The session is a JSON string kept in Redis under ``session_id`` and in the
    ``content`` column of the ``Session`` table. On a cache miss the database
    value is read and, if non-empty, written back to Redis.

    Args:
        session_id: ID of the session (also used as the Redis key).
        db: Database session exposing ``query``/``rollback``.
        redis: Redis client exposing ``get``/``set``.
            NOTE(review): assumes a synchronous redis-py style client whose
            ``get`` returns None for a missing key.

    Returns:
        SessionQueryResponse built from the decoded JSON's ``user_id``,
        ``messages``, ``user_info``, ``tts_info``, ``llm_info`` and ``token``.

    Raises:
        HTTPException: 400 if the database query fails (rolled back first);
            404 if the session does not exist or its stored content is empty.
    """
    # One GET instead of EXISTS+GET: a single round trip, and no race where the
    # key expires between the two calls.
    session_str = redis.get(session_id)
    if session_str is None:
        try:
            session_record = db.query(Session).filter(Session.id == session_id).first()
        except Exception as e:
            db.rollback()
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
        # A missing row previously raised AttributeError here, which was
        # reported as 400. It now reaches the 404 check below.
        session_str = session_record.content if session_record is not None else None
        # Only cache real content; never store None/empty values in Redis.
        if session_str:
            redis.set(session_id, session_str)
    if not session_str:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")

    session = json.loads(session_str)
    session_query_data = SessionQueryData(
        user_id=session["user_id"],
        messages=session["messages"],
        user_info=session["user_info"],
        tts_info=session["tts_info"],
        llm_info=session["llm_info"],
        token=session["token"],
    )
    return SessionQueryResponse(status="success", message="Session 查询成功", data=session_query_data)
async def update_session_handler(session_id, session_data: SessionUpdateRequest, db, redis):
    """Overwrite a session's fields and persist them to the database and Redis.

    The existing JSON content is loaded (Redis first, then the ``Session``
    table). The six fields carried by ``session_data`` are overwritten, and any
    other keys in the stored JSON are preserved. The result is committed to the
    database and then cached in Redis.

    Args:
        session_id: ID of the session (also used as the Redis key).
        session_data: New values for user_id, messages, user_info, tts_info,
            llm_info and token.
        db: Database session exposing ``query``/``commit``/``rollback``.
        redis: Redis client exposing ``get``/``set``.
            NOTE(review): assumes a redis-py style client whose ``get``
            returns None for a missing key.

    Returns:
        SessionUpdateResponse whose data carries the local update timestamp.

    Raises:
        HTTPException: 400 if a database operation fails (rolled back first);
            404 if the session does not exist or its stored content is empty.
    """
    existing_session = redis.get(session_id)
    if existing_session is None:
        # Guarded like the sibling handlers. Previously a missing row raised
        # an unhandled AttributeError (a 500), and DB errors were not rolled
        # back.
        try:
            session_record = db.query(Session).filter(Session.id == session_id).first()
        except Exception as e:
            db.rollback()
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
        existing_session = session_record.content if session_record is not None else None
    if not existing_session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")

    # Overwrite the updatable fields and keep any other keys already stored.
    session = json.loads(existing_session)
    for field in ("user_id", "messages", "user_info", "tts_info", "llm_info", "token"):
        session[field] = getattr(session_data, field)
    session_str = json.dumps(session, ensure_ascii=False)

    # Persist first, then refresh the cache. A failed commit must not leave
    # Redis serving data that was never stored in the database.
    try:
        db.query(Session).filter(Session.id == session_id).update({"content": session_str})
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
    redis.set(session_id, session_str)

    session_update_data = SessionUpdateData(updatedAt=datetime.now().isoformat())
    return SessionUpdateResponse(status="success", message="Session 更新成功", data=session_update_data)