diff --git a/app/controllers/chat_controller.py b/app/controllers/chat_controller.py
index 02b34f5..ff54226 100644
--- a/app/controllers/chat_controller.py
+++ b/app/controllers/chat_controller.py
@@ -106,7 +106,6 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
     try:
         db.add(new_chat)
         db.commit()
-        db.refresh(new_chat)
     except Exception as e:
         db.rollback()
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@@ -154,7 +153,6 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
     # Store the Session record
     db.add(new_session)
     db.commit()
-    db.refresh(new_session)
     redis.set(session_id, json.dumps(content, ensure_ascii=False))
 
     chat_create_data = ChatCreateData(user_character_id=new_chat.id, session_id=session_id, createdAt=datetime.now().isoformat())
diff --git a/app/controllers/user_controller.py b/app/controllers/user_controller.py
index a726d0b..4ccd404 100644
--- a/app/controllers/user_controller.py
+++ b/app/controllers/user_controller.py
@@ -1,6 +1,6 @@
 from ..schemas.user_schema import *
 from ..dependencies.logger import get_logger
-from ..models import User, Hardware
+from ..models import User, Hardware, Audio
 from datetime import datetime
 from sqlalchemy.orm import Session
 from fastapi import HTTPException, status
@@ -36,7 +36,6 @@ async def update_user_handler(user_id:int, user:UserUpdateRequest, db: Session):
         existing_user.persona = user.persona
     try:
         db.commit()
-        db.refresh(existing_user)
     except Exception as e:
         db.rollback()
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@@ -117,7 +116,6 @@ async def change_bind_hardware_handler(hardware_id, user, db):
             raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
         existing_hardware.user_id = user.user_id
         db.commit()
-        db.refresh(existing_hardware)
     except Exception as e:
         db.rollback()
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@@ -135,7 +133,6 @@ async def update_hardware_handler(hardware_id, hardware, db):
         existing_hardware.firmware = hardware.firmware
         existing_hardware.model = hardware.model
         db.commit()
-        db.refresh(existing_hardware)
     except Exception as e:
         db.rollback()
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@@ -153,4 +150,67 @@ async def get_hardware_handler(hardware_id, db):
     if existing_hardware is None:
         raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
     hardware_query_data = HardwareQueryData(mac=existing_hardware.mac, user_id=existing_hardware.user_id, firmware=existing_hardware.firmware, model=existing_hardware.model)
-    return HardwareQueryResponse(status="success", message="查询硬件信息成功", data=hardware_query_data)
\ No newline at end of file
+    return HardwareQueryResponse(status="success", message="查询硬件信息成功", data=hardware_query_data)
+
+
+# User uploads audio
+async def upload_audio_handler(user_id, audio, db):
+    try:
+        audio_data = audio.file.read()
+        new_audio = Audio(user_id=user_id, audio_data=audio_data)
+        db.add(new_audio)
+        db.commit()
+        db.refresh(new_audio)
+    except Exception as e:
+        db.rollback()
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+    audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat())
+    return AudioUploadResponse(status="success", message="用户上传音频成功", data=audio_upload_data)
+
+
+# User updates audio
+async def update_audio_handler(audio_id, audio_file, db):
+    try:
+        existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
+        if existing_audio is None:
+            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
+        audio_data = audio_file.file.read()
+        existing_audio.audio_data = audio_data
+        db.commit()
+    except Exception as e:
+        db.rollback()
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+    audio_update_data = AudioUpdateData(updatedAt=datetime.now().isoformat())
+    return AudioUpdateResponse(status="success", message="用户更新音频成功", data=audio_update_data)
+
+
+# User queries (downloads) audio
+async def download_audio_handler(audio_id, db):
+    try:
+        existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
+    except Exception as e:
+        db.rollback()
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+    if existing_audio is None:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
+    audio_data = existing_audio.audio_data
+    return audio_data
+
+
+# User deletes audio
+async def delete_audio_handler(audio_id, db):
+    try:
+        existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
+    except Exception as e:
+        db.rollback()
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+    if existing_audio is None:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
+    try:
+        db.delete(existing_audio)
+        db.commit()
+    except Exception as e:
+        db.rollback()
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+    audio_delete_data = AudioDeleteData(deletedAt=datetime.now().isoformat())
+    return AudioDeleteResponse(status="success", message="用户删除音频成功", data=audio_delete_data)
\ No newline at end of file
diff --git a/app/models/models.py b/app/models/models.py
index baa390a..ef9a4e4 100644
--- a/app/models/models.py
+++ b/app/models/models.py
@@ -1,4 +1,4 @@
-from sqlalchemy import Column, Integer, String, JSON, Text, ForeignKey, DateTime, Boolean, CHAR
+from sqlalchemy import Column, Integer, String, JSON, Text, ForeignKey, DateTime, Boolean, CHAR, LargeBinary
 from sqlalchemy.ext.declarative import declarative_base
 
 Base = declarative_base()
@@ -80,3 +80,9 @@ class Session(Base):
     def __repr__(self):
         return f""
 
+
+class Audio(Base):
+    __tablename__ = 'audio'
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    user_id = Column(Integer, ForeignKey('user.id'))
+    audio_data = Column(LargeBinary)
\ No newline at end of file
diff --git a/app/routes/user_route.py b/app/routes/user_route.py
index a0bc6a7..1a42bc1 100644
--- a/app/routes/user_route.py
+++ b/app/routes/user_route.py
@@ -1,4 +1,4 @@
-from fastapi import APIRouter, HTTPException, status
+from fastapi import APIRouter, UploadFile, File, Response
 from ..controllers.user_controller import *
 from fastapi import Depends
 from sqlalchemy.orm import Session
@@ -68,4 +68,32 @@ async def update_hardware_info(hardware_id: int, hardware: HardwareUpdateRequest
 @router.get('/users/hardware/{hardware_id}',response_model=HardwareQueryResponse)
 async def get_hardware(hardware_id: int, db: Session = Depends(get_db)):
     response = await get_hardware_handler(hardware_id, db)
-    return response
\ No newline at end of file
+    return response
+
+
+# User audio upload
+@router.post('/users/audio',response_model=AudioUploadResponse)
+async def upload_audio(user_id:int, audio_file:UploadFile=File(...), db: Session = Depends(get_db)):
+    response = await upload_audio_handler(user_id, audio_file, db)
+    return response
+
+
+# User audio update
+@router.put('/users/audio/{audio_id}',response_model=AudioUpdateResponse)
+async def update_audio(audio_id:int, audio_file:UploadFile=File(...), db: Session = Depends(get_db)):
+    response = await update_audio_handler(audio_id, audio_file, db)
+    return response
+
+
+# User audio download
+@router.get('/users/audio/{audio_id}')
+async def download_audio(audio_id:int, db: Session = Depends(get_db)):
+    audio_data = await download_audio_handler(audio_id, db)
+    return Response(content=audio_data,media_type='application/octet-stream',headers={"Content-Disposition": "attachment"})
+
+
+# User audio delete
+@router.delete('/users/audio/{audio_id}',response_model=AudioDeleteResponse)
+async def delete_audio(audio_id:int, db: Session = Depends(get_db)):
+    response = await delete_audio_handler(audio_id, db)
+    return response
\ No newline at end of file
diff --git a/app/schemas/user_schema.py b/app/schemas/user_schema.py
index 4157397..53adc42 100644
--- a/app/schemas/user_schema.py
+++ b/app/schemas/user_schema.py
@@ -3,7 +3,6 @@ from typing import Optional
 from .base_schema import BaseResponse
 
 
-
 #---------------------------------User creation----------------------------------
 # User creation request schema
 class UserCrateRequest(BaseModel):
@@ -137,4 +136,33 @@ class HardwareQueryData(BaseModel):
 
 class HardwareQueryResponse(BaseResponse):
     data: Optional[HardwareQueryData]
-#------------------------------------------------------------------------------
\ No newline at end of file
+#------------------------------------------------------------------------------
+
+
+
+#-------------------------------User audio upload--------------------------------
+class AudioUploadData(BaseModel):
+    audio_id: int
+    uploadedAt: str
+
+class AudioUploadResponse(BaseResponse):
+    data: Optional[AudioUploadData]
+#-------------------------------------------------------------------------------
+
+
+#-------------------------------User audio update--------------------------------
+class AudioUpdateData(BaseModel):
+    updatedAt: str
+
+class AudioUpdateResponse(BaseResponse):
+    data: Optional[AudioUpdateData]
+#-------------------------------------------------------------------------------
+
+
+#-------------------------------User audio delete--------------------------------
+class AudioDeleteData(BaseModel):
+    deletedAt: str
+
+class AudioDeleteResponse(BaseResponse):
+    data: Optional[AudioDeleteData]
+#-------------------------------------------------------------------------------
\ No newline at end of file
diff --git a/utils/stt/funasr_utils.py b/utils/stt/funasr_utils.py
index 9b7eec7..8300ad2 100644
--- a/utils/stt/funasr_utils.py
+++ b/utils/stt/funasr_utils.py
@@ -241,176 +241,4 @@ class FunAutoSpeechRecognizer(STTBase):
             self.audio_cache[session_id] = audio_cache
             self.asr_cache[session_id] = asr_cache
 
-        return text_dict
-
-
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# ####################################################### #
-# FunAutoSpeechRecognizer: https://github.com/alibaba-damo-academy/FunASR
-# ####################################################### #
-# import io
-# import numpy as np
-# import base64
-# import wave
-# from funasr import AutoModel
-# from .base_stt import STTBase
-
-# def decode_str2bytes(data):
-#     # Decode the Base64-encoded string back into bytes
-#     if data is None:
-#         return None
-#     return base64.b64decode(data.encode('utf-8'))
-
-# class FunAutoSpeechRecognizer(STTBase):
-#     def __init__(self,
-#                  model_path="paraformer-zh-streaming",
-#                  device="cuda",
-#                  RATE=16000,
-#                  cfg_path=None,
-#                  debug=False,
-#                  chunk_ms=480,
-#                  encoder_chunk_look_back=4,
-#                  decoder_chunk_look_back=1,
-#                  **kwargs):
-#         super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
-
-#         self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
-
-#         self.encoder_chunk_look_back = encoder_chunk_look_back #number of chunks to lookback for encoder self-attention
-#         self.decoder_chunk_look_back = decoder_chunk_look_back #number of encoder chunks to lookback for decoder cross-attention
-
-#         #[0, 8, 4] 480ms, [0, 10, 5] 600ms
-#         if chunk_ms == 480:
-#             self.chunk_size = [0, 8, 4]
-#         elif chunk_ms == 600:
-#             self.chunk_size = [0, 10, 5]
-#         else:
-#             raise ValueError("`chunk_ms` should be 480 or 600, and type is int.")
-#         self.chunk_partial_size = self.chunk_size[1] * 960
-#         self.audio_cache = None
-#         self.asr_cache = {}
-
-
-
-#         self._init_asr()
-
-#     def check_audio_type(self, audio_data):
-#         """check audio data type and convert it to bytes if necessary."""
-#         if isinstance(audio_data, bytes):
-#             pass
-#         elif isinstance(audio_data, list):
-#             audio_data = b''.join(audio_data)
-#         elif isinstance(audio_data, str):
-#             audio_data = decode_str2bytes(audio_data)
-#         elif isinstance(audio_data, io.BytesIO):
-#             wf = wave.open(audio_data, 'rb')
-#             audio_data = wf.readframes(wf.getnframes())
-#         elif isinstance(audio_data, np.ndarray):
-#             pass
-#         else:
-#             raise TypeError(f"audio_data must be bytes, list, str, \
-#                 io.BytesIO or numpy array, but got {type(audio_data)}")
-
-#         if isinstance(audio_data, bytes):
-#             audio_data = np.frombuffer(audio_data, dtype=np.int16)
-#         elif isinstance(audio_data, np.ndarray):
-#             if audio_data.dtype != np.int16:
-#                 audio_data = audio_data.astype(np.int16)
-#         else:
-#             raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
-#         return audio_data
-
-#     def _init_asr(self):
-#         # Randomly initialize a chunk of audio data
-#         init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
-#         self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
-#         self.audio_cache = None
-#         self.asr_cache = {}
-#         # print("init ASR model done.")
-
-#     def recognize(self, audio_data):
-#         """recognize audio data to text"""
-#         audio_data = self.check_audio_type(audio_data)
-#         result = self.asr_model.generate(input=audio_data,
-#                                          batch_size_s=300,
-#                                          hotword=self.hotwords)
-
-#         # print(result)
-#         text = ''
-#         for res in result:
-#             text += res['text']
-#         return text
-
-#     def streaming_recognize(self,
-#                             audio_data,
-#                             is_end=False,
-#                             auto_det_end=False):
-#         """recognize partial result
-
-#         Args:
-#             audio_data: bytes or numpy array, partial audio data
-#             is_end: bool, whether the audio data is the end of a sentence
-#             auto_det_end: bool, whether to automatically detect the end of a audio data
-#         """
-#         text_dict = dict(text=[], is_end=is_end)
-
-#         audio_data = self.check_audio_type(audio_data)
-#         if self.audio_cache is None:
-#             self.audio_cache = audio_data
-#         else:
-#             # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
-#             if self.audio_cache.shape[0] > 0:
-#                 self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)
-
-#         if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
-#             return text_dict
-
-#         total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
-
-#         if is_end:
-#             # if the audio data is the end of a sentence, \
-#             # we need to add one more chunk to the end to \
-#             # ensure the end of the sentence is recognized correctly.
-#             auto_det_end = True
-
-#         if auto_det_end:
-#             total_chunk_num += 1
-
-#         # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
-#         end_idx = None
-#         for i in range(total_chunk_num):
-#             if auto_det_end:
-#                 is_end = i == total_chunk_num - 1
-#             start_idx = i*self.chunk_partial_size
-#             if auto_det_end:
-#                 end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
-#             else:
-#                 end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
-#             # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
-#             # t_stamp = time.time()
-
-#             speech_chunk = self.audio_cache[start_idx:end_idx]
-
-#             # TODO: exceptions processes
-#             try:
-#                 res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
-#             except ValueError as e:
-#                 print(f"ValueError: {e}")
-#                 continue
-#             text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
-#             # print(f"each chunk time: {time.time()-t_stamp}")
-
-#         if is_end:
-#             self.audio_cache = None
-#             self.asr_cache = {}
-#         else:
-#             if end_idx:
-#                 self.audio_cache = self.audio_cache[end_idx:] # cut the processed part from audio_cache
-#             text_dict['is_end'] = is_end
-
-#         # print(f"text_dict: {text_dict}")
-#         return text_dict
-
-
-
\ No newline at end of file
+        return text_dict
\ No newline at end of file
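
Note: the diff above adds four audio endpoints (POST /users/audio and PUT/GET/DELETE /users/audio/{audio_id}). A minimal client-side sketch of how they could be exercised follows; it assumes the API is served locally at http://localhost:8000, that the router is mounted without an extra prefix, that a user with id 1 exists, and that a local sample.wav is available — none of these details are part of the diff itself.

    import requests

    BASE = "http://localhost:8000"  # assumed local dev server

    # Upload: user_id is passed as a query parameter, the file in the multipart field "audio_file".
    with open("sample.wav", "rb") as f:
        r = requests.post(f"{BASE}/users/audio", params={"user_id": 1}, files={"audio_file": f})
    audio_id = r.json()["data"]["audio_id"]

    # Replace the stored audio for the same record.
    with open("sample.wav", "rb") as f:
        requests.put(f"{BASE}/users/audio/{audio_id}", files={"audio_file": f})

    # Download the raw bytes (served as application/octet-stream), then delete the record.
    audio_bytes = requests.get(f"{BASE}/users/audio/{audio_id}").content
    requests.delete(f"{BASE}/users/audio/{audio_id}")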