feat: 用户表添加selected_audio_id字段,添加用户音频绑定接口
This commit is contained in:
parent
773b48471a
commit
387c277c28
|
@ -166,7 +166,10 @@ async def upload_audio_handler(user_id, audio, db):
|
||||||
raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
|
raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
|
||||||
numpy_data = np.frombuffer(raw_data, dtype=np.int32)
|
numpy_data = np.frombuffer(raw_data, dtype=np.int32)
|
||||||
emb_data = tts.audio2emb(numpy_data,rate=44100,vad=True).tobytes()
|
emb_data = tts.audio2emb(numpy_data,rate=44100,vad=True).tobytes()
|
||||||
new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_data)
|
new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_data) #创建音频
|
||||||
|
db.flush()
|
||||||
|
existing_user = db.query(User).filter(User.id == user_id).first()
|
||||||
|
existing_user.selected_audio_id = new_audio.id #绑定音频到用户
|
||||||
db.add(new_audio)
|
db.add(new_audio)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(new_audio)
|
db.refresh(new_audio)
|
||||||
|
@ -214,16 +217,38 @@ async def download_audio_handler(audio_id, db):
|
||||||
async def delete_audio_handler(audio_id, db):
|
async def delete_audio_handler(audio_id, db):
|
||||||
try:
|
try:
|
||||||
existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
|
existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
|
||||||
|
existing_user = db.query(User).filter(User.selected_audio_id == audio_id).first()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||||
if existing_audio is None:
|
if existing_audio is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
|
||||||
try:
|
try:
|
||||||
|
if existing_user.selected_audio_id == audio_id:
|
||||||
|
existing_user.selected_audio_id = None
|
||||||
db.delete(existing_audio)
|
db.delete(existing_audio)
|
||||||
db.commit()
|
db.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||||
audio_delete_data = AudioDeleteData(deletedAt=datetime.now().isoformat())
|
audio_delete_data = AudioDeleteData(deletedAt=datetime.now().isoformat())
|
||||||
return AudioDeleteResponse(status="success", message="用户删除音频成功", data=audio_delete_data)
|
return AudioDeleteResponse(status="success", message="用户删除音频成功", data=audio_delete_data)
|
||||||
|
|
||||||
|
|
||||||
|
#用户绑定音频
|
||||||
|
async def bind_audio_handler(bind_req, db):
|
||||||
|
try:
|
||||||
|
existing_user = db.query(User).filter(User.id == bind_req.user_id).first()
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||||
|
if existing_user is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
|
||||||
|
try:
|
||||||
|
existing_user.selected_audio_id = bind_req.audio_id
|
||||||
|
db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||||
|
audio_bind_data = AudioBindData(bindedAt=datetime.now().isoformat())
|
||||||
|
return AudioBindResponse(status="success", message="用户绑定音频成功", data=audio_bind_data)
|
|
@ -36,6 +36,7 @@ class User(Base):
|
||||||
avatar_id = Column(String(36), nullable=True)
|
avatar_id = Column(String(36), nullable=True)
|
||||||
tags = Column(JSON)
|
tags = Column(JSON)
|
||||||
persona = Column(JSON)
|
persona = Column(JSON)
|
||||||
|
selected_audio_id = Column(Integer, nullable=True)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<User(id={self.id}, tags={self.tags})>"
|
return f"<User(id={self.id}, tags={self.tags})>"
|
||||||
|
@ -81,6 +82,7 @@ class Session(Base):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<Session(id={self.id}, user_character_id={self.user_character_id})>"
|
return f"<Session(id={self.id}, user_character_id={self.user_character_id})>"
|
||||||
|
|
||||||
|
#音频表定义
|
||||||
class Audio(Base):
|
class Audio(Base):
|
||||||
__tablename__ = 'audio'
|
__tablename__ = 'audio'
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
|
|
@ -96,4 +96,11 @@ async def download_audio(audio_id:int, db: Session = Depends(get_db)):
|
||||||
@router.delete('/users/audio/{audio_id}',response_model=AudioDeleteResponse)
|
@router.delete('/users/audio/{audio_id}',response_model=AudioDeleteResponse)
|
||||||
async def delete_audio(audio_id:int, db: Session = Depends(get_db)):
|
async def delete_audio(audio_id:int, db: Session = Depends(get_db)):
|
||||||
response = await delete_audio_handler(audio_id, db)
|
response = await delete_audio_handler(audio_id, db)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
#用户绑定音频
|
||||||
|
@router.post('/users/audio/bind',response_model=AudioBindResponse)
|
||||||
|
async def bind_audio(bind_req:AudioBindRequest, db: Session = Depends(get_db)):
|
||||||
|
response = await bind_audio_handler(bind_req, db)
|
||||||
return response
|
return response
|
|
@ -165,4 +165,18 @@ class AudioDeleteData(BaseModel):
|
||||||
|
|
||||||
class AudioDeleteResponse(BaseResponse):
|
class AudioDeleteResponse(BaseResponse):
|
||||||
data: Optional[AudioDeleteData]
|
data: Optional[AudioDeleteData]
|
||||||
#-------------------------------------------------------------------------------
|
#-------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
#-------------------------------用户音频绑定-------------------------------------
|
||||||
|
class AudioBindRequest(BaseModel):
|
||||||
|
audio_id: int
|
||||||
|
user_id: int
|
||||||
|
|
||||||
|
class AudioBindData(BaseModel):
|
||||||
|
bindedAt: str
|
||||||
|
|
||||||
|
class AudioBindResponse(BaseResponse):
|
||||||
|
data: Optional[AudioBindData]
|
||||||
|
#-------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
|
@ -155,6 +155,21 @@ class UserServiceTest:
|
||||||
print("音频上传测试成功")
|
print("音频上传测试成功")
|
||||||
else:
|
else:
|
||||||
raise Exception("音频上传测试失败")
|
raise Exception("音频上传测试失败")
|
||||||
|
|
||||||
|
def test_bind_audio(self):
|
||||||
|
url = f"{self.socket}/users/audio/bind"
|
||||||
|
payload = json.dumps({
|
||||||
|
"user_id":self.id,
|
||||||
|
"audio_id":self.audio_id
|
||||||
|
})
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
response = requests.request("POST", url, headers=headers, data=payload)
|
||||||
|
if response.status_code == 200:
|
||||||
|
print("音频绑定测试成功")
|
||||||
|
else:
|
||||||
|
raise Exception("音频绑定测试失败")
|
||||||
|
|
||||||
def test_audio_download(self):
|
def test_audio_download(self):
|
||||||
url = f"{self.socket}/users/audio/{self.audio_id}"
|
url = f"{self.socket}/users/audio/{self.audio_id}"
|
||||||
|
@ -185,6 +200,7 @@ def user_test():
|
||||||
user_service_test.test_hardware_unbind()
|
user_service_test.test_hardware_unbind()
|
||||||
user_service_test.test_upload_audio()
|
user_service_test.test_upload_audio()
|
||||||
user_service_test.test_update_audio()
|
user_service_test.test_update_audio()
|
||||||
|
user_service_test.test_bind_audio()
|
||||||
user_service_test.test_audio_download()
|
user_service_test.test_audio_download()
|
||||||
user_service_test.test_audio_delete()
|
user_service_test.test_audio_delete()
|
||||||
user_service_test.test_user_delete()
|
user_service_test.test_user_delete()
|
||||||
|
|
Loading…
Reference in New Issue