feat: 增加了用户音频的增删查改

This commit is contained in:
killua4396 2024-05-16 13:28:47 +08:00
parent 4322b03418
commit 4a256fa506
6 changed files with 132 additions and 184 deletions

View File

@ -106,7 +106,6 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
try: try:
db.add(new_chat) db.add(new_chat)
db.commit() db.commit()
db.refresh(new_chat)
except Exception as e: except Exception as e:
db.rollback() db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@ -154,7 +153,6 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
# 将Session记录存入 # 将Session记录存入
db.add(new_session) db.add(new_session)
db.commit() db.commit()
db.refresh(new_session)
redis.set(session_id, json.dumps(content, ensure_ascii=False)) redis.set(session_id, json.dumps(content, ensure_ascii=False))
chat_create_data = ChatCreateData(user_character_id=new_chat.id, session_id=session_id, createdAt=datetime.now().isoformat()) chat_create_data = ChatCreateData(user_character_id=new_chat.id, session_id=session_id, createdAt=datetime.now().isoformat())

View File

@ -1,6 +1,6 @@
from ..schemas.user_schema import * from ..schemas.user_schema import *
from ..dependencies.logger import get_logger from ..dependencies.logger import get_logger
from ..models import User, Hardware from ..models import User, Hardware, Audio
from datetime import datetime from datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from fastapi import HTTPException, status from fastapi import HTTPException, status
@ -36,7 +36,6 @@ async def update_user_handler(user_id:int, user:UserUpdateRequest, db: Session):
existing_user.persona = user.persona existing_user.persona = user.persona
try: try:
db.commit() db.commit()
db.refresh(existing_user)
except Exception as e: except Exception as e:
db.rollback() db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@ -117,7 +116,6 @@ async def change_bind_hardware_handler(hardware_id, user, db):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
existing_hardware.user_id = user.user_id existing_hardware.user_id = user.user_id
db.commit() db.commit()
db.refresh(existing_hardware)
except Exception as e: except Exception as e:
db.rollback() db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@ -135,7 +133,6 @@ async def update_hardware_handler(hardware_id, hardware, db):
existing_hardware.firmware = hardware.firmware existing_hardware.firmware = hardware.firmware
existing_hardware.model = hardware.model existing_hardware.model = hardware.model
db.commit() db.commit()
db.refresh(existing_hardware)
except Exception as e: except Exception as e:
db.rollback() db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@ -154,3 +151,66 @@ async def get_hardware_handler(hardware_id, db):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
hardware_query_data = HardwareQueryData(mac=existing_hardware.mac, user_id=existing_hardware.user_id, firmware=existing_hardware.firmware, model=existing_hardware.model) hardware_query_data = HardwareQueryData(mac=existing_hardware.mac, user_id=existing_hardware.user_id, firmware=existing_hardware.firmware, model=existing_hardware.model)
return HardwareQueryResponse(status="success", message="查询硬件信息成功", data=hardware_query_data) return HardwareQueryResponse(status="success", message="查询硬件信息成功", data=hardware_query_data)
#用户上传音频
async def upload_audio_handler(user_id, audio, db):
    """Persist an uploaded audio file as a new ``Audio`` row for a user.

    Args:
        user_id: Value stored in the new row's ``user_id`` column.
        audio: Upload object; its ``file`` attribute is read in full.
        db: Database session used to add, commit and refresh the row.

    Returns:
        AudioUploadResponse carrying the new row id and an ISO-formatted
        timestamp taken from ``datetime.now()``.

    Raises:
        HTTPException: 400 with the underlying error text if reading the
            file or any database step fails (the session is rolled back).
    """
    try:
        # The whole upload is read into memory before being stored.
        payload = audio.file.read()
        record = Audio(user_id=user_id, audio_data=payload)
        db.add(record)
        db.commit()
        # Refresh so the database-generated id is available on the object.
        db.refresh(record)
    except Exception as exc:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc))
    return AudioUploadResponse(
        status="success",
        message="用户上传音频成功",
        data=AudioUploadData(audio_id=record.id, uploadedAt=datetime.now().isoformat()),
    )
#用户更新音频
async def update_audio_handler(audio_id, audio_file, db):
    """Replace the stored bytes of an existing ``Audio`` row.

    Args:
        audio_id: Primary key of the audio row to update.
        audio_file: Upload object; its ``file`` attribute is read in full.
        db: Database session used for the lookup and the commit.

    Returns:
        AudioUpdateResponse carrying an ISO-formatted timestamp taken from
        ``datetime.now()``.

    Raises:
        HTTPException: 404 if no row has ``audio_id``; 400 with the
            underlying error text if the lookup, the file read or the
            commit fails (the session is rolled back in the 400 cases).
    """
    try:
        existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    # Raise the 404 outside any ``except Exception`` scope: previously it was
    # raised inside the try block, caught by the generic handler and reported
    # to the client as a 400. This matches the download/delete handlers.
    if existing_audio is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
    try:
        existing_audio.audio_data = audio_file.file.read()
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    audio_update_data = AudioUpdateData(updatedAt=datetime.now().isoformat())
    return AudioUpdateResponse(status="success", message="用户更新音频成功", data=audio_update_data)
#用户查询音频
async def download_audio_handler(audio_id, db):
    """Fetch the raw stored bytes of one ``Audio`` row.

    Args:
        audio_id: Primary key of the audio row to read.
        db: Database session used for the lookup.

    Returns:
        The row's ``audio_data`` value.

    Raises:
        HTTPException: 400 with the underlying error text if the query
            fails (the session is rolled back); 404 if no row matches.
    """
    try:
        record = db.query(Audio).filter(Audio.id == audio_id).first()
    except Exception as exc:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc))
    else:
        # The not-found check lives in ``else`` so its 404 is never
        # intercepted by the generic handler above.
        if record is None:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
    return record.audio_data
#用户删除音频
async def delete_audio_handler(audio_id, db):
    """Delete one ``Audio`` row by primary key.

    Args:
        audio_id: Primary key of the audio row to remove.
        db: Database session used for the lookup, delete and commit.

    Returns:
        AudioDeleteResponse carrying an ISO-formatted timestamp taken from
        ``datetime.now()``.

    Raises:
        HTTPException: 404 if no row matches; 400 with the underlying error
            text if the lookup or the delete/commit fails (the session is
            rolled back in the 400 cases).
    """
    try:
        target = db.query(Audio).filter(Audio.id == audio_id).first()
    except Exception as lookup_error:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(lookup_error))
    else:
        # Checked outside the except scope so the 404 reaches the caller intact.
        if target is None:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")

    try:
        db.delete(target)
        db.commit()
    except Exception as delete_error:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(delete_error))

    return AudioDeleteResponse(
        status="success",
        message="用户删除音频成功",
        data=AudioDeleteData(deletedAt=datetime.now().isoformat()),
    )

View File

@ -1,4 +1,4 @@
from sqlalchemy import Column, Integer, String, JSON, Text, ForeignKey, DateTime, Boolean, CHAR from sqlalchemy import Column, Integer, String, JSON, Text, ForeignKey, DateTime, Boolean, CHAR, LargeBinary
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base() Base = declarative_base()
@ -80,3 +80,9 @@ class Session(Base):
def __repr__(self): def __repr__(self):
return f"<Session(id={self.id}, user_character_id={self.user_character_id})>" return f"<Session(id={self.id}, user_character_id={self.user_character_id})>"
class Audio(Base):
    """ORM model for the ``audio`` table: one stored audio blob per row."""
    __tablename__ = 'audio'
    # Auto-incrementing integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Foreign key to ``user.id``.
    user_id = Column(Integer, ForeignKey('user.id'))
    # Raw bytes of the audio file.
    audio_data = Column(LargeBinary)

    def __repr__(self):
        # Mirrors the Session model's repr style; the binary payload is
        # deliberately left out so logs and debuggers stay readable.
        return f"<Audio(id={self.id}, user_id={self.user_id})>"

View File

@ -1,4 +1,4 @@
from fastapi import APIRouter, HTTPException, status from fastapi import APIRouter, UploadFile, File, Response
from ..controllers.user_controller import * from ..controllers.user_controller import *
from fastapi import Depends from fastapi import Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -69,3 +69,31 @@ async def update_hardware_info(hardware_id: int, hardware: HardwareUpdateRequest
async def get_hardware(hardware_id: int, db: Session = Depends(get_db)): async def get_hardware(hardware_id: int, db: Session = Depends(get_db)):
response = await get_hardware_handler(hardware_id, db) response = await get_hardware_handler(hardware_id, db)
return response return response
#用户音频上传
# POST /users/audio: accepts ``user_id`` plus a required uploaded file and
# delegates to upload_audio_handler, returning its response unchanged.
# (Documented with comments rather than a docstring so the generated
# OpenAPI description of the route is not altered.)
@router.post('/users/audio', response_model=AudioUploadResponse)
async def upload_audio(
    user_id: int,
    audio_file: UploadFile = File(...),
    db: Session = Depends(get_db),
):
    return await upload_audio_handler(user_id, audio_file, db)
#用户音频修改
# PUT /users/audio/{audio_id}: replaces the stored audio with the required
# uploaded file by delegating to update_audio_handler.
# (Documented with comments rather than a docstring so the generated
# OpenAPI description of the route is not altered.)
@router.put('/users/audio/{audio_id}', response_model=AudioUpdateResponse)
async def update_audio(
    audio_id: int,
    audio_file: UploadFile = File(...),
    db: Session = Depends(get_db),
):
    return await update_audio_handler(audio_id, audio_file, db)
#用户音频下载
# GET /users/audio/{audio_id}: returns the stored bytes as a generic binary
# attachment. No filename or audio-specific media type is set.
# (Documented with comments rather than a docstring so the generated
# OpenAPI description of the route is not altered.)
@router.get('/users/audio/{audio_id}')
async def download_audio(audio_id: int, db: Session = Depends(get_db)):
    payload = await download_audio_handler(audio_id, db)
    return Response(
        content=payload,
        media_type='application/octet-stream',
        headers={"Content-Disposition": "attachment"},
    )
#用户音频删除
# DELETE /users/audio/{audio_id}: removes the audio row by delegating to
# delete_audio_handler and returns its response unchanged.
# (Documented with comments rather than a docstring so the generated
# OpenAPI description of the route is not altered.)
@router.delete('/users/audio/{audio_id}', response_model=AudioDeleteResponse)
async def delete_audio(audio_id: int, db: Session = Depends(get_db)):
    return await delete_audio_handler(audio_id, db)

View File

@ -3,7 +3,6 @@ from typing import Optional
from .base_schema import BaseResponse from .base_schema import BaseResponse
#---------------------------------用户创建---------------------------------- #---------------------------------用户创建----------------------------------
#用户创建请求类 #用户创建请求类
class UserCrateRequest(BaseModel): class UserCrateRequest(BaseModel):
@ -138,3 +137,32 @@ class HardwareQueryData(BaseModel):
class HardwareQueryResponse(BaseResponse): class HardwareQueryResponse(BaseResponse):
data: Optional[HardwareQueryData] data: Optional[HardwareQueryData]
#------------------------------------------------------------------------------ #------------------------------------------------------------------------------
#-------------------------------用户音频上传-------------------------------------
# Payload returned after a successful audio upload.
class AudioUploadData(BaseModel):
    # Primary key of the newly created audio row.
    audio_id: int
    # Timestamp string; the upload handler fills it with datetime.now().isoformat().
    uploadedAt: str
# Response envelope for the audio upload endpoint. The handler also passes
# ``status`` and ``message``, presumably declared on BaseResponse — confirm there.
class AudioUploadResponse(BaseResponse):
    # Upload result payload; typed as optional.
    data: Optional[AudioUploadData]
#-------------------------------------------------------------------------------
#-------------------------------用户音频修改-------------------------------------
# Payload returned after a successful audio update.
class AudioUpdateData(BaseModel):
    # Timestamp string; the update handler fills it with datetime.now().isoformat().
    updatedAt: str
# Response envelope for the audio update endpoint. The handler also passes
# ``status`` and ``message``, presumably declared on BaseResponse — confirm there.
class AudioUpdateResponse(BaseResponse):
    # Update result payload; typed as optional.
    data: Optional[AudioUpdateData]
#-------------------------------------------------------------------------------
#-------------------------------用户音频删除-------------------------------------
# Payload returned after a successful audio deletion.
class AudioDeleteData(BaseModel):
    # Timestamp string; the delete handler fills it with datetime.now().isoformat().
    deletedAt: str
# Response envelope for the audio delete endpoint. The handler also passes
# ``status`` and ``message``, presumably declared on BaseResponse — confirm there.
class AudioDeleteResponse(BaseResponse):
    # Deletion result payload; typed as optional.
    data: Optional[AudioDeleteData]
#-------------------------------------------------------------------------------

View File

@ -242,175 +242,3 @@ class FunAutoSpeechRecognizer(STTBase):
self.audio_cache[session_id] = audio_cache self.audio_cache[session_id] = audio_cache
self.asr_cache[session_id] = asr_cache self.asr_cache[session_id] = asr_cache
return text_dict return text_dict
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ####################################################### #
# FunAutoSpeechRecognizer: https://github.com/alibaba-damo-academy/FunASR
# ####################################################### #
# import io
# import numpy as np
# import base64
# import wave
# from funasr import AutoModel
# from .base_stt import STTBase
# def decode_str2bytes(data):
# # 将Base64编码的字节串解码为字节串
# if data is None:
# return None
# return base64.b64decode(data.encode('utf-8'))
# class FunAutoSpeechRecognizer(STTBase):
# def __init__(self,
# model_path="paraformer-zh-streaming",
# device="cuda",
# RATE=16000,
# cfg_path=None,
# debug=False,
# chunk_ms=480,
# encoder_chunk_look_back=4,
# decoder_chunk_look_back=1,
# **kwargs):
# super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
# self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
# self.encoder_chunk_look_back = encoder_chunk_look_back #number of chunks to lookback for encoder self-attention
# self.decoder_chunk_look_back = decoder_chunk_look_back #number of encoder chunks to lookback for decoder cross-attention
# #[0, 8, 4] 480ms, [0, 10, 5] 600ms
# if chunk_ms == 480:
# self.chunk_size = [0, 8, 4]
# elif chunk_ms == 600:
# self.chunk_size = [0, 10, 5]
# else:
# raise ValueError("`chunk_ms` should be 480 or 600, and type is int.")
# self.chunk_partial_size = self.chunk_size[1] * 960
# self.audio_cache = None
# self.asr_cache = {}
# self._init_asr()
# def check_audio_type(self, audio_data):
# """check audio data type and convert it to bytes if necessary."""
# if isinstance(audio_data, bytes):
# pass
# elif isinstance(audio_data, list):
# audio_data = b''.join(audio_data)
# elif isinstance(audio_data, str):
# audio_data = decode_str2bytes(audio_data)
# elif isinstance(audio_data, io.BytesIO):
# wf = wave.open(audio_data, 'rb')
# audio_data = wf.readframes(wf.getnframes())
# elif isinstance(audio_data, np.ndarray):
# pass
# else:
# raise TypeError(f"audio_data must be bytes, list, str, \
# io.BytesIO or numpy array, but got {type(audio_data)}")
# if isinstance(audio_data, bytes):
# audio_data = np.frombuffer(audio_data, dtype=np.int16)
# elif isinstance(audio_data, np.ndarray):
# if audio_data.dtype != np.int16:
# audio_data = audio_data.astype(np.int16)
# else:
# raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
# return audio_data
# def _init_asr(self):
# # 随机初始化一段音频数据
# init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
# self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
# self.audio_cache = None
# self.asr_cache = {}
# # print("init ASR model done.")
# def recognize(self, audio_data):
# """recognize audio data to text"""
# audio_data = self.check_audio_type(audio_data)
# result = self.asr_model.generate(input=audio_data,
# batch_size_s=300,
# hotword=self.hotwords)
# # print(result)
# text = ''
# for res in result:
# text += res['text']
# return text
# def streaming_recognize(self,
# audio_data,
# is_end=False,
# auto_det_end=False):
# """recognize partial result
# Args:
# audio_data: bytes or numpy array, partial audio data
# is_end: bool, whether the audio data is the end of a sentence
# auto_det_end: bool, whether to automatically detect the end of a audio data
# """
# text_dict = dict(text=[], is_end=is_end)
# audio_data = self.check_audio_type(audio_data)
# if self.audio_cache is None:
# self.audio_cache = audio_data
# else:
# # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
# if self.audio_cache.shape[0] > 0:
# self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)
# if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
# return text_dict
# total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
# if is_end:
# # if the audio data is the end of a sentence, \
# # we need to add one more chunk to the end to \
# # ensure the end of the sentence is recognized correctly.
# auto_det_end = True
# if auto_det_end:
# total_chunk_num += 1
# # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
# end_idx = None
# for i in range(total_chunk_num):
# if auto_det_end:
# is_end = i == total_chunk_num - 1
# start_idx = i*self.chunk_partial_size
# if auto_det_end:
# end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
# else:
# end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
# # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
# # t_stamp = time.time()
# speech_chunk = self.audio_cache[start_idx:end_idx]
# # TODO: exceptions processes
# try:
# res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
# except ValueError as e:
# print(f"ValueError: {e}")
# continue
# text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
# # print(f"each chunk time: {time.time()-t_stamp}")
# if is_end:
# self.audio_cache = None
# self.asr_cache = {}
# else:
# if end_idx:
# self.audio_cache = self.audio_cache[end_idx:] # cut the processed part from audio_cache
# text_dict['is_end'] = is_end
# # print(f"text_dict: {text_dict}")
# return text_dict