forked from killua/TakwayPlatform
1.添加定时任务(每天凌晨四点处理redis记录的删除,存储入mysql持久化操作)
2.修复了prompt中没有用户输入信息的bug 3.优化了更新类接口代码(删除了db.add()代码)
This commit is contained in:
parent
de370f77ac
commit
065528f509
|
@ -1,12 +1,16 @@
|
|||
from fastapi import FastAPI, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy import create_engine
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from .models import Base
|
||||
from .routes.character_route import router as character_router
|
||||
from .routes.user_route import router as user_router
|
||||
from .routes.session_route import router as session_router
|
||||
from .routes.chat_route import router as chat_router
|
||||
from .dependencies.logger import get_logger
|
||||
from .background_tasks import updating_redis_cache
|
||||
from config import get_config
|
||||
|
||||
|
||||
|
@ -28,8 +32,20 @@ logger.info("数据库初始化完成")
|
|||
#------------------------------------------------------
|
||||
|
||||
|
||||
#--------------------设置定时任务-----------------------
|
||||
scheduler = AsyncIOScheduler()
|
||||
scheduler.add_job(updating_redis_cache, CronTrigger.from_crontab("0 4 * * *"))
|
||||
@asynccontextmanager
|
||||
async def lifespan(app:FastAPI):
|
||||
scheduler.start() #启动定时任务
|
||||
yield
|
||||
scheduler.shutdown() #关闭定时任务
|
||||
logger.info("定时任务初始化完成")
|
||||
#------------------------------------------------------
|
||||
|
||||
|
||||
#--------------------创建FastAPI实例--------------------
|
||||
app = FastAPI()
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
logger.info("FastAPI实例创建完成")
|
||||
#------------------------------------------------------
|
||||
|
||||
|
@ -42,6 +58,7 @@ app.include_router(chat_router)
|
|||
logger.info("路由初始化完成")
|
||||
#-------------------------------------------------------
|
||||
|
||||
|
||||
#-------------------设置跨域中间件-----------------------
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
from .dependencies.database import get_db
|
||||
from .dependencies.redis import get_redis
|
||||
from .dependencies.logger import get_logger
|
||||
from .models import Session
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
def updating_redis_cache():
|
||||
db = next(get_db())
|
||||
redis = get_redis()
|
||||
current_time = datetime.now()
|
||||
two_days_ago = current_time - timedelta(days=2)
|
||||
db.begin()
|
||||
try:
|
||||
expired_sessions = db.query(Session.id).filter(Session.last_activity < two_days_ago).all()
|
||||
for session_id_tuple in expired_sessions:
|
||||
session_id = session_id_tuple[0]
|
||||
session_content = redis.get(session_id)
|
||||
if session_content is None:
|
||||
logger.error(f"Session {session_id} not found in Redis cache")
|
||||
continue
|
||||
redis.delete(session_id)
|
||||
db.query(Session).filter(Session.id == session_id).update({'content':session_content})
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise e
|
||||
finally:
|
||||
db.close()
|
||||
logger.info("Redis cache updated successfully")
|
|
@ -42,7 +42,6 @@ async def update_character_handler(character_id: int, character: CharacterUpdate
|
|||
existing_character.dialogues = character.dialogues
|
||||
|
||||
try:
|
||||
db.add(existing_character)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
|
|
|
@ -89,13 +89,20 @@ def vad_preprocess(audio):
|
|||
return ('A'*1280)
|
||||
return audio[:1280],audio[1280:]
|
||||
|
||||
#更新session活跃时间
|
||||
def update_session_activity(session_id,db):
|
||||
try:
|
||||
db.query(Session).filter(Session.id == session_id).update({"last_activity": datetime.now()})
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.roolback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
#--------------------------------------------------------
|
||||
|
||||
# 创建新聊天
|
||||
async def create_chat_handler(chat: ChatCreateRequest, db, redis):
|
||||
# 创建新的UserCharacter记录
|
||||
new_chat = UserCharacter(user_id=chat.user_id, character_id=chat.character_id)
|
||||
|
||||
try:
|
||||
db.add(new_chat)
|
||||
db.commit()
|
||||
|
@ -279,8 +286,6 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
|||
logger.debug(f"llm返回结果: {llm_response}")
|
||||
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
|
||||
is_end = False #重置is_end标志位
|
||||
session_content = get_session_content(session_id,redis,db)
|
||||
messages = json.loads(session_content["messages"])
|
||||
messages.append({'role': 'assistant', "content": llm_response})
|
||||
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||
|
@ -304,6 +309,7 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
|
|||
|
||||
|
||||
session_id = await future_session_id #获取session_id
|
||||
update_session_activity(session_id,db)
|
||||
response_type = await future_response_type #获取返回类型
|
||||
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
|
||||
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
|
||||
|
@ -425,8 +431,6 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
|||
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
|
||||
is_end = False
|
||||
|
||||
session_content = get_session_content(session_id,redis,db)
|
||||
messages = json.loads(session_content["messages"])
|
||||
messages.append({'role': 'assistant', "content": llm_response})
|
||||
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||
|
@ -454,6 +458,7 @@ async def streaming_chat_lasting_handler(ws,db,redis):
|
|||
asyncio.create_task(scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event))
|
||||
|
||||
session_id = await future_session_id #获取session_id
|
||||
update_session_activity(session_id,db)
|
||||
response_type = await future_response_type #获取返回类型
|
||||
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
|
||||
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
|
||||
|
@ -575,8 +580,6 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
|
|||
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
|
||||
is_end = False
|
||||
|
||||
session_content = get_session_content(session_id,redis,db)
|
||||
messages = json.loads(session_content["messages"])
|
||||
messages.append({'role': 'assistant', "content": llm_response})
|
||||
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||
|
@ -622,6 +625,7 @@ async def voice_call_handler(ws, db, redis):
|
|||
|
||||
#获取session内容
|
||||
session_id = await future #获取session_id
|
||||
update_session_activity(session_id,db)
|
||||
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
|
||||
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
|
||||
|
||||
|
|
|
@ -35,7 +35,6 @@ async def update_user_handler(user_id:int, user:UserUpdateRequest, db: Session):
|
|||
existing_user.tags = user.tags
|
||||
existing_user.persona = user.persona
|
||||
try:
|
||||
db.add(existing_user)
|
||||
db.commit()
|
||||
db.refresh(existing_user)
|
||||
except Exception as e:
|
||||
|
@ -117,7 +116,6 @@ async def change_bind_hardware_handler(hardware_id, user, db):
|
|||
if existing_hardware is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
|
||||
existing_hardware.user_id = user.user_id
|
||||
db.add(existing_hardware)
|
||||
db.commit()
|
||||
db.refresh(existing_hardware)
|
||||
except Exception as e:
|
||||
|
@ -136,7 +134,6 @@ async def update_hardware_handler(hardware_id, hardware, db):
|
|||
existing_hardware.mac = hardware.mac
|
||||
existing_hardware.firmware = hardware.firmware
|
||||
existing_hardware.model = hardware.model
|
||||
db.add(existing_hardware)
|
||||
db.commit()
|
||||
db.refresh(existing_hardware)
|
||||
except Exception as e:
|
||||
|
|
|
@ -8,28 +8,28 @@ from ..dependencies.database import get_db
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
#角色创建接口
|
||||
#用户创建接口
|
||||
@router.post('/users', response_model=UserCrateResponse)
|
||||
async def create_user(user: UserCrateRequest,db: Session = Depends(get_db)):
|
||||
response = await create_user_handler(user,db)
|
||||
return response
|
||||
|
||||
|
||||
#角色更新接口
|
||||
#用户更新接口
|
||||
@router.put('/users/{user_id}', response_model=UserUpdateResponse)
|
||||
async def update_user(user_id: int, user: UserUpdateRequest, db: Session = Depends(get_db)):
|
||||
response = await update_user_handler(user_id, user, db)
|
||||
return response
|
||||
|
||||
|
||||
#角色删除接口
|
||||
#用户删除接口
|
||||
@router.delete('/users/{user_id}', response_model=UserDeleteResponse)
|
||||
async def delete_user(user_id: int, db: Session = Depends(get_db)):
|
||||
response = await delete_user_handler(user_id, db)
|
||||
return response
|
||||
|
||||
|
||||
#角色查询接口
|
||||
#用户查询接口
|
||||
@router.get('/users/{user_id}', response_model=UserQueryResponse)
|
||||
async def get_user(user_id: int, db: Session = Depends(get_db)):
|
||||
response = await get_user_handler(user_id, db)
|
||||
|
|
|
@ -18,3 +18,4 @@ torch
|
|||
numba
|
||||
soundfile
|
||||
webrtcvad
|
||||
apscheduler
|
Loading…
Reference in New Issue