From 387c277c2801e155fa6e810e722d1df5ec33c388 Mon Sep 17 00:00:00 2001 From: killua4396 <1223086337@qq.com> Date: Thu, 23 May 2024 09:57:53 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E7=94=A8=E6=88=B7=E8=A1=A8=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0selected=5Faudio=5Fid=E5=AD=97=E6=AE=B5=EF=BC=8C?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E7=94=A8=E6=88=B7=E9=9F=B3=E9=A2=91=E7=BB=91?= =?UTF-8?q?=E5=AE=9A=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/controllers/user_controller.py | 29 +++++++++++++++++++++++++++-- app/models/models.py | 2 ++ app/routes/user_route.py | 7 +++++++ app/schemas/user_schema.py | 16 +++++++++++++++- tests/unit_test/user_test.py | 16 ++++++++++++++++ 5 files changed, 67 insertions(+), 3 deletions(-) diff --git a/app/controllers/user_controller.py b/app/controllers/user_controller.py index bb30fcb..aee1da3 100644 --- a/app/controllers/user_controller.py +++ b/app/controllers/user_controller.py @@ -166,7 +166,10 @@ async def upload_audio_handler(user_id, audio, db): raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data numpy_data = np.frombuffer(raw_data, dtype=np.int32) emb_data = tts.audio2emb(numpy_data,rate=44100,vad=True).tobytes() - new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_data) + new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_data) #创建音频 + db.flush() + existing_user = db.query(User).filter(User.id == user_id).first() + existing_user.selected_audio_id = new_audio.id #绑定音频到用户 db.add(new_audio) db.commit() db.refresh(new_audio) @@ -214,16 +217,38 @@ async def download_audio_handler(audio_id, db): async def delete_audio_handler(audio_id, db): try: existing_audio = db.query(Audio).filter(Audio.id == audio_id).first() + existing_user = db.query(User).filter(User.selected_audio_id == audio_id).first() except Exception as e: db.rollback() raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) if existing_audio is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在") try: + if existing_user.selected_audio_id == audio_id: + existing_user.selected_audio_id = None db.delete(existing_audio) db.commit() except Exception as e: db.rollback() raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) audio_delete_data = AudioDeleteData(deletedAt=datetime.now().isoformat()) - return AudioDeleteResponse(status="success", message="用户删除音频成功", data=audio_delete_data) \ No newline at end of file + return AudioDeleteResponse(status="success", message="用户删除音频成功", data=audio_delete_data) + + +#用户绑定音频 +async def bind_audio_handler(bind_req, db): + try: + existing_user = db.query(User).filter(User.id == bind_req.user_id).first() + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + if existing_user is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在") + try: + existing_user.selected_audio_id = bind_req.audio_id + db.commit() + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + audio_bind_data = AudioBindData(bindedAt=datetime.now().isoformat()) + return AudioBindResponse(status="success", message="用户绑定音频成功", data=audio_bind_data) \ No newline at end of file diff --git a/app/models/models.py b/app/models/models.py index ca114fe..9bc1084 100644 --- a/app/models/models.py +++ b/app/models/models.py @@ -36,6 +36,7 @@ class User(Base): avatar_id = Column(String(36), nullable=True) tags = Column(JSON) persona = Column(JSON) + selected_audio_id = Column(Integer, nullable=True) def __repr__(self): return f"" @@ -81,6 +82,7 @@ class Session(Base): def __repr__(self): return f"" +#音频表定义 class Audio(Base): __tablename__ = 'audio' id = Column(Integer, primary_key=True, autoincrement=True) diff --git a/app/routes/user_route.py b/app/routes/user_route.py index 1a42bc1..b37c542 100644 --- a/app/routes/user_route.py +++ b/app/routes/user_route.py @@ -96,4 +96,11 @@ async def download_audio(audio_id:int, db: Session = Depends(get_db)): @router.delete('/users/audio/{audio_id}',response_model=AudioDeleteResponse) async def delete_audio(audio_id:int, db: Session = Depends(get_db)): response = await delete_audio_handler(audio_id, db) + return response + + +#用户绑定音频 +@router.post('/users/audio/bind',response_model=AudioBindResponse) +async def bind_audio(bind_req:AudioBindRequest, db: Session = Depends(get_db)): + response = await bind_audio_handler(bind_req, db) return response \ No newline at end of file diff --git a/app/schemas/user_schema.py b/app/schemas/user_schema.py index 53adc42..372fa56 100644 --- a/app/schemas/user_schema.py +++ b/app/schemas/user_schema.py @@ -165,4 +165,18 @@ class AudioDeleteData(BaseModel): class AudioDeleteResponse(BaseResponse): data: Optional[AudioDeleteData] -#------------------------------------------------------------------------------- \ No newline at end of file +#------------------------------------------------------------------------------- + + +#-------------------------------用户音频绑定------------------------------------- +class AudioBindRequest(BaseModel): + audio_id: int + user_id: int + +class AudioBindData(BaseModel): + bindedAt: str + +class AudioBindResponse(BaseResponse): + data: Optional[AudioBindData] +#------------------------------------------------------------------------------- + diff --git a/tests/unit_test/user_test.py b/tests/unit_test/user_test.py index 40b5821..45b0d7d 100644 --- a/tests/unit_test/user_test.py +++ b/tests/unit_test/user_test.py @@ -155,6 +155,21 @@ class UserServiceTest: print("音频上传测试成功") else: raise Exception("音频上传测试失败") + + def test_bind_audio(self): + url = f"{self.socket}/users/audio/bind" + payload = json.dumps({ + "user_id":self.id, + "audio_id":self.audio_id + }) + headers = { + 'Content-Type': 'application/json' + } + response = requests.request("POST", url, headers=headers, data=payload) + if response.status_code == 200: + print("音频绑定测试成功") + else: + raise Exception("音频绑定测试失败") def test_audio_download(self): url = f"{self.socket}/users/audio/{self.audio_id}" @@ -185,6 +200,7 @@ def user_test(): user_service_test.test_hardware_unbind() user_service_test.test_upload_audio() user_service_test.test_update_audio() + user_service_test.test_bind_audio() user_service_test.test_audio_download() user_service_test.test_audio_delete() user_service_test.test_user_delete()