forked from killua/TakwayPlatform
108 lines
4.8 KiB
Python
108 lines
4.8 KiB
Python
from ..schemas.session_schema import *
|
|
from ..dependencies.logger import get_logger
|
|
from fastapi import HTTPException, status
|
|
from ..models import Session
|
|
from ..models import UserCharacter
|
|
from datetime import datetime
|
|
import json
|
|
|
|
# Obtain the logger via dependency injection
|
|
logger = get_logger()
|
|
|
|
|
|
# Get the session ID
|
|
async def get_session_id_handler(user_id: int, character_id: int, db):
    """Look up the ID of the session bound to a user/character pair.

    Args:
        user_id: Value matched against ``UserCharacter.user_id``.
        character_id: Value matched against ``UserCharacter.character_id``.
        db: Database session object providing ``query`` and ``rollback``.

    Returns:
        A ``SessionIdQueryResponse`` whose data carries the session ID.

    Raises:
        HTTPException: 404 when no matching UserCharacter row or no Session
            row exists; 400 when either query raises.
    """
    try:
        user_character_record = (
            db.query(UserCharacter)
            .filter(
                UserCharacter.user_id == user_id,
                UserCharacter.character_id == character_id,
            )
            .first()
        )
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

    if not user_character_record:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User Character not found")

    try:
        session_record = (
            db.query(Session)
            .filter(Session.user_character_id == user_character_record.id)
            .first()
        )
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

    # The row check lives outside the try block: dereferencing a missing row
    # there would be caught by the broad handler and reported as a 400 with
    # an internal AttributeError message instead of a clean 404.
    if session_record is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")

    session_id_query_data = SessionIdQueryData(session_id=session_record.id)
    return SessionIdQueryResponse(status="success", message="Session ID 获取成功", data=session_id_query_data)
|
|
|
|
# Query session information
|
|
async def get_session_handler(session_id: int, db, redis):
    """Return the stored content of a session, reading Redis first, then the DB.

    Args:
        session_id: Used both as the Redis key and as ``Session.id``.
        db: Database session object providing ``query`` and ``rollback``.
        redis: Client object providing ``get`` and ``set``.

    Returns:
        A ``SessionQueryResponse`` built from the decoded session JSON.

    Raises:
        HTTPException: 404 when neither Redis nor the database holds
            non-empty content for the session; 400 when the query raises.
    """
    # One GET instead of EXISTS followed by GET: a single round trip, and no
    # window in which the key can disappear between the two calls.
    session_str = redis.get(session_id)

    if not session_str:
        try:
            session_record = db.query(Session).filter(Session.id == session_id).first()
            session_str = session_record.content if session_record is not None else None
        except Exception as e:
            db.rollback()
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

        if not session_str:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")

        # Populate the cache only once a real payload has been found, so an
        # empty value is never stored under this key.
        redis.set(session_id, session_str)

    session = json.loads(session_str)
    session_query_data = SessionQueryData(
        user_id=session["user_id"],
        messages=session["messages"],
        user_info=session["user_info"],
        tts_info=session["tts_info"],
        llm_info=session["llm_info"],
        token=session["token"],
    )
    return SessionQueryResponse(status="success", message="Session 查询成功", data=session_query_data)
|
|
|
|
# Update session information
|
|
async def update_session_handler(session_id, session_data: SessionUpdateRequest, db, redis):
    """Overwrite the fields of a stored session and persist the result.

    The current content is read from Redis, falling back to the database.
    The fields ``user_id``, ``messages``, ``user_info``, ``tts_info``,
    ``llm_info`` and ``token`` are replaced with the values from
    ``session_data``, then the JSON is written to the database and Redis.

    Args:
        session_id: Used both as the Redis key and as ``Session.id``.
        session_data: Request object supplying the new field values.
        db: Database session object providing ``query``, ``commit`` and
            ``rollback``.
        redis: Client object providing ``get`` and ``set``.

    Returns:
        A ``SessionUpdateResponse`` carrying the update timestamp.

    Raises:
        HTTPException: 404 when no non-empty content exists for the session;
            400 when a database operation raises.
    """
    # One GET instead of EXISTS followed by GET: a single round trip, and no
    # window in which the key can disappear between the two calls.
    existing_session = redis.get(session_id)

    if not existing_session:
        try:
            session_record = db.query(Session).filter(Session.id == session_id).first()
            # A missing row must lead to the 404 below, not an AttributeError.
            existing_session = session_record.content if session_record is not None else None
        except Exception as e:
            db.rollback()
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

    if not existing_session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")

    # Update the session fields
    session = json.loads(existing_session)
    session["user_id"] = session_data.user_id
    session["messages"] = session_data.messages
    session["user_info"] = session_data.user_info
    session["tts_info"] = session_data.tts_info
    session["llm_info"] = session_data.llm_info
    session["token"] = session_data.token

    # Store the session
    session_str = json.dumps(session, ensure_ascii=False)
    try:
        db.query(Session).filter(Session.id == session_id).update({"content": session_str})
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

    # Refresh the cache only after the commit succeeded, so a failed write
    # cannot leave unpersisted content in Redis.
    redis.set(session_id, session_str)

    session_update_data = SessionUpdateData(updatedAt=datetime.now().isoformat())
    return SessionUpdateResponse(status="success", message="Session 更新成功", data=session_update_data)
|
|
|
|
# Update the speaker ID stored in the session
|
|
async def update_session_speaker_id_handler(session_id, session_data, db, redis):
    """Replace the ``speaker_id`` inside a stored session's ``llm_info``.

    The current content is read from Redis, falling back to the database.
    ``llm_info`` is itself a JSON string inside the session JSON: it is
    decoded, its ``speaker_id`` is set from ``session_data.speaker_id``, and
    it is re-encoded before the session is written to the database and Redis.

    Args:
        session_id: Used both as the Redis key and as ``Session.id``.
        session_data: Object exposing the new ``speaker_id``.
        db: Database session object providing ``query``, ``commit`` and
            ``rollback``.
        redis: Client object providing ``get`` and ``set``.

    Returns:
        A ``SessionSpeakerUpdateResponse`` carrying the update timestamp.

    Raises:
        HTTPException: 404 when no non-empty content exists for the session;
            400 when a database operation raises.
    """
    # One GET instead of EXISTS followed by GET: a single round trip, and no
    # window in which the key can disappear between the two calls.
    existing_session = redis.get(session_id)

    if not existing_session:
        try:
            session_record = db.query(Session).filter(Session.id == session_id).first()
            # A missing row must lead to the 404 below, not an AttributeError.
            existing_session = session_record.content if session_record is not None else None
        except Exception as e:
            db.rollback()
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

    if not existing_session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")

    # Update the session field; llm_info is stored as a nested JSON string.
    session = json.loads(existing_session)
    session_llm_info = json.loads(session["llm_info"])
    session_llm_info["speaker_id"] = session_data.speaker_id
    session["llm_info"] = json.dumps(session_llm_info, ensure_ascii=False)

    # Store the session
    session_str = json.dumps(session, ensure_ascii=False)
    try:
        db.query(Session).filter(Session.id == session_id).update({"content": session_str})
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

    # Refresh the cache only after the commit succeeded, so a failed write
    # cannot leave unpersisted content in Redis.
    redis.set(session_id, session_str)

    session_update_data = SessionSpeakerUpdateData(updatedAt=datetime.now().isoformat())
    return SessionSpeakerUpdateResponse(status="success", message="Session SpeakID更新成功", data=session_update_data)