forked from killua/TakwayPlatform
feat: 用户表添加selected_audio_id字段,添加用户音频绑定接口
This commit is contained in:
parent
773b48471a
commit
387c277c28
|
@ -166,7 +166,10 @@ async def upload_audio_handler(user_id, audio, db):
|
|||
raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
|
||||
numpy_data = np.frombuffer(raw_data, dtype=np.int32)
|
||||
emb_data = tts.audio2emb(numpy_data,rate=44100,vad=True).tobytes()
|
||||
new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_data)
|
||||
new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_data) #创建音频
|
||||
db.flush()
|
||||
existing_user = db.query(User).filter(User.id == user_id).first()
|
||||
existing_user.selected_audio_id = new_audio.id #绑定音频到用户
|
||||
db.add(new_audio)
|
||||
db.commit()
|
||||
db.refresh(new_audio)
|
||||
|
@ -214,12 +217,15 @@ async def download_audio_handler(audio_id, db):
|
|||
async def delete_audio_handler(audio_id, db):
|
||||
try:
|
||||
existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
|
||||
existing_user = db.query(User).filter(User.selected_audio_id == audio_id).first()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
if existing_audio is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
|
||||
try:
|
||||
if existing_user.selected_audio_id == audio_id:
|
||||
existing_user.selected_audio_id = None
|
||||
db.delete(existing_audio)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
|
@ -227,3 +233,22 @@ async def delete_audio_handler(audio_id, db):
|
|||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
audio_delete_data = AudioDeleteData(deletedAt=datetime.now().isoformat())
|
||||
return AudioDeleteResponse(status="success", message="用户删除音频成功", data=audio_delete_data)
|
||||
|
||||
|
||||
#用户绑定音频
|
||||
async def bind_audio_handler(bind_req, db):
|
||||
try:
|
||||
existing_user = db.query(User).filter(User.id == bind_req.user_id).first()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
if existing_user is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
|
||||
try:
|
||||
existing_user.selected_audio_id = bind_req.audio_id
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
audio_bind_data = AudioBindData(bindedAt=datetime.now().isoformat())
|
||||
return AudioBindResponse(status="success", message="用户绑定音频成功", data=audio_bind_data)
|
|
@ -36,6 +36,7 @@ class User(Base):
|
|||
avatar_id = Column(String(36), nullable=True)
|
||||
tags = Column(JSON)
|
||||
persona = Column(JSON)
|
||||
selected_audio_id = Column(Integer, nullable=True)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User(id={self.id}, tags={self.tags})>"
|
||||
|
@ -81,6 +82,7 @@ class Session(Base):
|
|||
def __repr__(self):
|
||||
return f"<Session(id={self.id}, user_character_id={self.user_character_id})>"
|
||||
|
||||
#音频表定义
|
||||
class Audio(Base):
|
||||
__tablename__ = 'audio'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
|
|
@ -97,3 +97,10 @@ async def download_audio(audio_id:int, db: Session = Depends(get_db)):
|
|||
async def delete_audio(audio_id:int, db: Session = Depends(get_db)):
|
||||
response = await delete_audio_handler(audio_id, db)
|
||||
return response
|
||||
|
||||
|
||||
#用户绑定音频
|
||||
@router.post('/users/audio/bind',response_model=AudioBindResponse)
|
||||
async def bind_audio(bind_req:AudioBindRequest, db: Session = Depends(get_db)):
|
||||
response = await bind_audio_handler(bind_req, db)
|
||||
return response
|
|
@ -166,3 +166,17 @@ class AudioDeleteData(BaseModel):
|
|||
class AudioDeleteResponse(BaseResponse):
|
||||
data: Optional[AudioDeleteData]
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
#-------------------------------用户音频绑定-------------------------------------
|
||||
class AudioBindRequest(BaseModel):
|
||||
audio_id: int
|
||||
user_id: int
|
||||
|
||||
class AudioBindData(BaseModel):
|
||||
bindedAt: str
|
||||
|
||||
class AudioBindResponse(BaseResponse):
|
||||
data: Optional[AudioBindData]
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
|
|
@ -156,6 +156,21 @@ class UserServiceTest:
|
|||
else:
|
||||
raise Exception("音频上传测试失败")
|
||||
|
||||
def test_bind_audio(self):
|
||||
url = f"{self.socket}/users/audio/bind"
|
||||
payload = json.dumps({
|
||||
"user_id":self.id,
|
||||
"audio_id":self.audio_id
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("音频绑定测试成功")
|
||||
else:
|
||||
raise Exception("音频绑定测试失败")
|
||||
|
||||
def test_audio_download(self):
|
||||
url = f"{self.socket}/users/audio/{self.audio_id}"
|
||||
response = requests.request("GET", url)
|
||||
|
@ -185,6 +200,7 @@ def user_test():
|
|||
user_service_test.test_hardware_unbind()
|
||||
user_service_test.test_upload_audio()
|
||||
user_service_test.test_update_audio()
|
||||
user_service_test.test_bind_audio()
|
||||
user_service_test.test_audio_download()
|
||||
user_service_test.test_audio_delete()
|
||||
user_service_test.test_user_delete()
|
||||
|
|
Loading…
Reference in New Issue