feat: 用户表添加selected_audio_id字段,添加用户音频绑定接口

This commit is contained in:
killua4396 2024-05-23 09:57:53 +08:00
parent 773b48471a
commit 387c277c28
5 changed files with 67 additions and 3 deletions

View File

@ -166,7 +166,10 @@ async def upload_audio_handler(user_id, audio, db):
raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
numpy_data = np.frombuffer(raw_data, dtype=np.int32) numpy_data = np.frombuffer(raw_data, dtype=np.int32)
emb_data = tts.audio2emb(numpy_data,rate=44100,vad=True).tobytes() emb_data = tts.audio2emb(numpy_data,rate=44100,vad=True).tobytes()
new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_data) new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_data) #创建音频
db.flush()
existing_user = db.query(User).filter(User.id == user_id).first()
existing_user.selected_audio_id = new_audio.id #绑定音频到用户
db.add(new_audio) db.add(new_audio)
db.commit() db.commit()
db.refresh(new_audio) db.refresh(new_audio)
@ -214,12 +217,15 @@ async def download_audio_handler(audio_id, db):
async def delete_audio_handler(audio_id, db): async def delete_audio_handler(audio_id, db):
try: try:
existing_audio = db.query(Audio).filter(Audio.id == audio_id).first() existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
existing_user = db.query(User).filter(User.selected_audio_id == audio_id).first()
except Exception as e: except Exception as e:
db.rollback() db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
if existing_audio is None: if existing_audio is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
try: try:
if existing_user.selected_audio_id == audio_id:
existing_user.selected_audio_id = None
db.delete(existing_audio) db.delete(existing_audio)
db.commit() db.commit()
except Exception as e: except Exception as e:
@ -227,3 +233,22 @@ async def delete_audio_handler(audio_id, db):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
audio_delete_data = AudioDeleteData(deletedAt=datetime.now().isoformat()) audio_delete_data = AudioDeleteData(deletedAt=datetime.now().isoformat())
return AudioDeleteResponse(status="success", message="用户删除音频成功", data=audio_delete_data) return AudioDeleteResponse(status="success", message="用户删除音频成功", data=audio_delete_data)
#用户绑定音频
async def bind_audio_handler(bind_req, db):
try:
existing_user = db.query(User).filter(User.id == bind_req.user_id).first()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
if existing_user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
try:
existing_user.selected_audio_id = bind_req.audio_id
db.commit()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
audio_bind_data = AudioBindData(bindedAt=datetime.now().isoformat())
return AudioBindResponse(status="success", message="用户绑定音频成功", data=audio_bind_data)

View File

@ -36,6 +36,7 @@ class User(Base):
avatar_id = Column(String(36), nullable=True) avatar_id = Column(String(36), nullable=True)
tags = Column(JSON) tags = Column(JSON)
persona = Column(JSON) persona = Column(JSON)
selected_audio_id = Column(Integer, nullable=True)
def __repr__(self): def __repr__(self):
return f"<User(id={self.id}, tags={self.tags})>" return f"<User(id={self.id}, tags={self.tags})>"
@ -81,6 +82,7 @@ class Session(Base):
def __repr__(self): def __repr__(self):
return f"<Session(id={self.id}, user_character_id={self.user_character_id})>" return f"<Session(id={self.id}, user_character_id={self.user_character_id})>"
#音频表定义
class Audio(Base): class Audio(Base):
__tablename__ = 'audio' __tablename__ = 'audio'
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)

View File

@ -97,3 +97,10 @@ async def download_audio(audio_id:int, db: Session = Depends(get_db)):
async def delete_audio(audio_id:int, db: Session = Depends(get_db)): async def delete_audio(audio_id:int, db: Session = Depends(get_db)):
response = await delete_audio_handler(audio_id, db) response = await delete_audio_handler(audio_id, db)
return response return response
#用户绑定音频
@router.post('/users/audio/bind',response_model=AudioBindResponse)
async def bind_audio(bind_req:AudioBindRequest, db: Session = Depends(get_db)):
response = await bind_audio_handler(bind_req, db)
return response

View File

@ -166,3 +166,17 @@ class AudioDeleteData(BaseModel):
class AudioDeleteResponse(BaseResponse): class AudioDeleteResponse(BaseResponse):
data: Optional[AudioDeleteData] data: Optional[AudioDeleteData]
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
#-------------------------------用户音频绑定-------------------------------------
class AudioBindRequest(BaseModel):
audio_id: int
user_id: int
class AudioBindData(BaseModel):
bindedAt: str
class AudioBindResponse(BaseResponse):
data: Optional[AudioBindData]
#-------------------------------------------------------------------------------

View File

@ -156,6 +156,21 @@ class UserServiceTest:
else: else:
raise Exception("音频上传测试失败") raise Exception("音频上传测试失败")
def test_bind_audio(self):
url = f"{self.socket}/users/audio/bind"
payload = json.dumps({
"user_id":self.id,
"audio_id":self.audio_id
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
print("音频绑定测试成功")
else:
raise Exception("音频绑定测试失败")
def test_audio_download(self): def test_audio_download(self):
url = f"{self.socket}/users/audio/{self.audio_id}" url = f"{self.socket}/users/audio/{self.audio_id}"
response = requests.request("GET", url) response = requests.request("GET", url)
@ -185,6 +200,7 @@ def user_test():
user_service_test.test_hardware_unbind() user_service_test.test_hardware_unbind()
user_service_test.test_upload_audio() user_service_test.test_upload_audio()
user_service_test.test_update_audio() user_service_test.test_update_audio()
user_service_test.test_bind_audio()
user_service_test.test_audio_download() user_service_test.test_audio_download()
user_service_test.test_audio_delete() user_service_test.test_audio_delete()
user_service_test.test_user_delete() user_service_test.test_user_delete()