forked from killua/TakwayPlatform
1.添加定时任务(每天凌晨四点处理redis记录的删除,存储入mysql持久化操作)
2.修复了prompt中没有用户输入信息的bug 3.优化了更新类接口代码(删除了db.add()代码)
This commit is contained in:
parent
de370f77ac
commit
065528f509
|
@ -1,12 +1,16 @@
|
||||||
from fastapi import FastAPI, Depends
|
from fastapi import FastAPI, Depends
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
from .models import Base
|
from .models import Base
|
||||||
from .routes.character_route import router as character_router
|
from .routes.character_route import router as character_router
|
||||||
from .routes.user_route import router as user_router
|
from .routes.user_route import router as user_router
|
||||||
from .routes.session_route import router as session_router
|
from .routes.session_route import router as session_router
|
||||||
from .routes.chat_route import router as chat_router
|
from .routes.chat_route import router as chat_router
|
||||||
from .dependencies.logger import get_logger
|
from .dependencies.logger import get_logger
|
||||||
|
from .background_tasks import updating_redis_cache
|
||||||
from config import get_config
|
from config import get_config
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,8 +32,20 @@ logger.info("数据库初始化完成")
|
||||||
#------------------------------------------------------
|
#------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
#--------------------设置定时任务-----------------------
|
||||||
|
scheduler = AsyncIOScheduler()
|
||||||
|
scheduler.add_job(updating_redis_cache, CronTrigger.from_crontab("0 4 * * *"))
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app:FastAPI):
|
||||||
|
scheduler.start() #启动定时任务
|
||||||
|
yield
|
||||||
|
scheduler.shutdown() #关闭定时任务
|
||||||
|
logger.info("定时任务初始化完成")
|
||||||
|
#------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
#--------------------创建FastAPI实例--------------------
|
#--------------------创建FastAPI实例--------------------
|
||||||
app = FastAPI()
|
app = FastAPI(lifespan=lifespan)
|
||||||
logger.info("FastAPI实例创建完成")
|
logger.info("FastAPI实例创建完成")
|
||||||
#------------------------------------------------------
|
#------------------------------------------------------
|
||||||
|
|
||||||
|
@ -42,6 +58,7 @@ app.include_router(chat_router)
|
||||||
logger.info("路由初始化完成")
|
logger.info("路由初始化完成")
|
||||||
#-------------------------------------------------------
|
#-------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
#-------------------设置跨域中间件-----------------------
|
#-------------------设置跨域中间件-----------------------
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
from .dependencies.database import get_db
|
||||||
|
from .dependencies.redis import get_redis
|
||||||
|
from .dependencies.logger import get_logger
|
||||||
|
from .models import Session
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
def updating_redis_cache():
|
||||||
|
db = next(get_db())
|
||||||
|
redis = get_redis()
|
||||||
|
current_time = datetime.now()
|
||||||
|
two_days_ago = current_time - timedelta(days=2)
|
||||||
|
db.begin()
|
||||||
|
try:
|
||||||
|
expired_sessions = db.query(Session.id).filter(Session.last_activity < two_days_ago).all()
|
||||||
|
for session_id_tuple in expired_sessions:
|
||||||
|
session_id = session_id_tuple[0]
|
||||||
|
session_content = redis.get(session_id)
|
||||||
|
if session_content is None:
|
||||||
|
logger.error(f"Session {session_id} not found in Redis cache")
|
||||||
|
continue
|
||||||
|
redis.delete(session_id)
|
||||||
|
db.query(Session).filter(Session.id == session_id).update({'content':session_content})
|
||||||
|
db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
logger.info("Redis cache updated successfully")
|
|
@ -42,7 +42,6 @@ async def update_character_handler(character_id: int, character: CharacterUpdate
|
||||||
existing_character.dialogues = character.dialogues
|
existing_character.dialogues = character.dialogues
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db.add(existing_character)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
|
|
|
@ -89,13 +89,20 @@ def vad_preprocess(audio):
|
||||||
return ('A'*1280)
|
return ('A'*1280)
|
||||||
return audio[:1280],audio[1280:]
|
return audio[:1280],audio[1280:]
|
||||||
|
|
||||||
|
#更新session活跃时间
|
||||||
|
def update_session_activity(session_id,db):
|
||||||
|
try:
|
||||||
|
db.query(Session).filter(Session.id == session_id).update({"last_activity": datetime.now()})
|
||||||
|
db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
db.roolback()
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||||
#--------------------------------------------------------
|
#--------------------------------------------------------
|
||||||
|
|
||||||
# 创建新聊天
|
# 创建新聊天
|
||||||
async def create_chat_handler(chat: ChatCreateRequest, db, redis):
|
async def create_chat_handler(chat: ChatCreateRequest, db, redis):
|
||||||
# 创建新的UserCharacter记录
|
# 创建新的UserCharacter记录
|
||||||
new_chat = UserCharacter(user_id=chat.user_id, character_id=chat.character_id)
|
new_chat = UserCharacter(user_id=chat.user_id, character_id=chat.character_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db.add(new_chat)
|
db.add(new_chat)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
@ -279,8 +286,6 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
logger.debug(f"llm返回结果: {llm_response}")
|
logger.debug(f"llm返回结果: {llm_response}")
|
||||||
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
|
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
|
||||||
is_end = False #重置is_end标志位
|
is_end = False #重置is_end标志位
|
||||||
session_content = get_session_content(session_id,redis,db)
|
|
||||||
messages = json.loads(session_content["messages"])
|
|
||||||
messages.append({'role': 'assistant', "content": llm_response})
|
messages.append({'role': 'assistant', "content": llm_response})
|
||||||
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
||||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||||
|
@ -304,6 +309,7 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
|
||||||
|
|
||||||
|
|
||||||
session_id = await future_session_id #获取session_id
|
session_id = await future_session_id #获取session_id
|
||||||
|
update_session_activity(session_id,db)
|
||||||
response_type = await future_response_type #获取返回类型
|
response_type = await future_response_type #获取返回类型
|
||||||
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
|
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
|
||||||
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
|
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
|
||||||
|
@ -425,8 +431,6 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
|
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
|
||||||
is_end = False
|
is_end = False
|
||||||
|
|
||||||
session_content = get_session_content(session_id,redis,db)
|
|
||||||
messages = json.loads(session_content["messages"])
|
|
||||||
messages.append({'role': 'assistant', "content": llm_response})
|
messages.append({'role': 'assistant', "content": llm_response})
|
||||||
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
||||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||||
|
@ -454,6 +458,7 @@ async def streaming_chat_lasting_handler(ws,db,redis):
|
||||||
asyncio.create_task(scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event))
|
asyncio.create_task(scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event))
|
||||||
|
|
||||||
session_id = await future_session_id #获取session_id
|
session_id = await future_session_id #获取session_id
|
||||||
|
update_session_activity(session_id,db)
|
||||||
response_type = await future_response_type #获取返回类型
|
response_type = await future_response_type #获取返回类型
|
||||||
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
|
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
|
||||||
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
|
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
|
||||||
|
@ -575,8 +580,6 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
|
||||||
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
|
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
|
||||||
is_end = False
|
is_end = False
|
||||||
|
|
||||||
session_content = get_session_content(session_id,redis,db)
|
|
||||||
messages = json.loads(session_content["messages"])
|
|
||||||
messages.append({'role': 'assistant', "content": llm_response})
|
messages.append({'role': 'assistant', "content": llm_response})
|
||||||
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
||||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||||
|
@ -622,6 +625,7 @@ async def voice_call_handler(ws, db, redis):
|
||||||
|
|
||||||
#获取session内容
|
#获取session内容
|
||||||
session_id = await future #获取session_id
|
session_id = await future #获取session_id
|
||||||
|
update_session_activity(session_id,db)
|
||||||
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
|
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
|
||||||
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
|
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,6 @@ async def update_user_handler(user_id:int, user:UserUpdateRequest, db: Session):
|
||||||
existing_user.tags = user.tags
|
existing_user.tags = user.tags
|
||||||
existing_user.persona = user.persona
|
existing_user.persona = user.persona
|
||||||
try:
|
try:
|
||||||
db.add(existing_user)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(existing_user)
|
db.refresh(existing_user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -117,7 +116,6 @@ async def change_bind_hardware_handler(hardware_id, user, db):
|
||||||
if existing_hardware is None:
|
if existing_hardware is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
|
||||||
existing_hardware.user_id = user.user_id
|
existing_hardware.user_id = user.user_id
|
||||||
db.add(existing_hardware)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(existing_hardware)
|
db.refresh(existing_hardware)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -136,7 +134,6 @@ async def update_hardware_handler(hardware_id, hardware, db):
|
||||||
existing_hardware.mac = hardware.mac
|
existing_hardware.mac = hardware.mac
|
||||||
existing_hardware.firmware = hardware.firmware
|
existing_hardware.firmware = hardware.firmware
|
||||||
existing_hardware.model = hardware.model
|
existing_hardware.model = hardware.model
|
||||||
db.add(existing_hardware)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(existing_hardware)
|
db.refresh(existing_hardware)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -8,28 +8,28 @@ from ..dependencies.database import get_db
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
#角色创建接口
|
#用户创建接口
|
||||||
@router.post('/users', response_model=UserCrateResponse)
|
@router.post('/users', response_model=UserCrateResponse)
|
||||||
async def create_user(user: UserCrateRequest,db: Session = Depends(get_db)):
|
async def create_user(user: UserCrateRequest,db: Session = Depends(get_db)):
|
||||||
response = await create_user_handler(user,db)
|
response = await create_user_handler(user,db)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
#角色更新接口
|
#用户更新接口
|
||||||
@router.put('/users/{user_id}', response_model=UserUpdateResponse)
|
@router.put('/users/{user_id}', response_model=UserUpdateResponse)
|
||||||
async def update_user(user_id: int, user: UserUpdateRequest, db: Session = Depends(get_db)):
|
async def update_user(user_id: int, user: UserUpdateRequest, db: Session = Depends(get_db)):
|
||||||
response = await update_user_handler(user_id, user, db)
|
response = await update_user_handler(user_id, user, db)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
#角色删除接口
|
#用户删除接口
|
||||||
@router.delete('/users/{user_id}', response_model=UserDeleteResponse)
|
@router.delete('/users/{user_id}', response_model=UserDeleteResponse)
|
||||||
async def delete_user(user_id: int, db: Session = Depends(get_db)):
|
async def delete_user(user_id: int, db: Session = Depends(get_db)):
|
||||||
response = await delete_user_handler(user_id, db)
|
response = await delete_user_handler(user_id, db)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
#角色查询接口
|
#用户查询接口
|
||||||
@router.get('/users/{user_id}', response_model=UserQueryResponse)
|
@router.get('/users/{user_id}', response_model=UserQueryResponse)
|
||||||
async def get_user(user_id: int, db: Session = Depends(get_db)):
|
async def get_user(user_id: int, db: Session = Depends(get_db)):
|
||||||
response = await get_user_handler(user_id, db)
|
response = await get_user_handler(user_id, db)
|
||||||
|
|
|
@ -18,3 +18,4 @@ torch
|
||||||
numba
|
numba
|
||||||
soundfile
|
soundfile
|
||||||
webrtcvad
|
webrtcvad
|
||||||
|
apscheduler
|
Loading…
Reference in New Issue