diff --git a/app/controllers/chat_controller.py b/app/controllers/chat_controller.py index 9d31bbb..b5007bf 100644 --- a/app/controllers/chat_controller.py +++ b/app/controllers/chat_controller.py @@ -15,6 +15,7 @@ import uuid import json import asyncio import aiohttp +import io # 依赖注入获取logger logger = get_logger() @@ -107,10 +108,10 @@ def get_emb(session_id,db): try: session_record = db.query(Session).filter(Session.id == session_id).first() user_character_record = db.query(UserCharacter).filter(UserCharacter.id == session_record.user_character_id).first() - audio_record = db.query(Audio).filter(Audio.user_id == user_character_record.user_id).first() - emb_npy = np.frombuffer(audio_record.emb_data,dtype=np.int32) - emb_npy_3d = np.reshape(emb_npy,(1,256,1)) - return emb_npy_3d + user_record = db.query(User).filter(User.id == user_character_record.user_id).first() + audio_record = db.query(Audio).filter(Audio.id == user_record.selected_audio_id).first() + emb_npy = np.load(io.BytesIO(audio_record.emb_data)) + return emb_npy except Exception as e: logger.error("未找到音频:"+str(e)) return np.array([]) @@ -149,7 +150,7 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis): "noise_scale": 0.1, "noise_scale_w":0.668, "length_scale": 1.2, - "speend":1 + "speed":1 } llm_info = { "model": "abab5.5-chat", @@ -243,7 +244,7 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f logger.error(f"用户输入处理函数发生错误: {str(e)}") #语音识别 -async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event): +async def sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event): logger.debug("语音识别函数启动") is_signup = False audio = "" @@ -259,6 +260,9 @@ async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_ current_message += ''.join(asr_result['text']) asr_result = asr.streaming_recognize(session_id,b'',is_end=True) current_message += ''.join(asr_result['text']) + if current_message == "": + await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False)) + return current_message = asr.punctuation_correction(current_message) emotion_dict = asr.emtion_recognition(audio) #情感辨识 if not isinstance(emotion_dict, str): @@ -366,7 +370,7 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis): session_id = await future_session_id #获取session_id update_session_activity(session_id,db) response_type = await future_response_type #获取返回类型 - asyncio.create_task(sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event)) + asyncio.create_task(sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event)) tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"]) llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"]) diff --git a/app/controllers/user_controller.py b/app/controllers/user_controller.py index aee1da3..9788675 100644 --- a/app/controllers/user_controller.py +++ b/app/controllers/user_controller.py @@ -163,19 +163,17 @@ async def get_hardware_handler(hardware_id, db): async def upload_audio_handler(user_id, audio, db): try: audio_data = await audio.read() - raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data - numpy_data = np.frombuffer(raw_data, dtype=np.int32) - emb_data = tts.audio2emb(numpy_data,rate=44100,vad=True).tobytes() - new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_data) #创建音频 - db.flush() - existing_user = db.query(User).filter(User.id == user_id).first() - existing_user.selected_audio_id = new_audio.id #绑定音频到用户 + emb_data = tts.audio2emb(np.frombuffer(AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data, dtype=np.int32),rate=44100,vad=True) + out = io.BytesIO() + np.save(out, emb_data) + out.seek(0) + emb_binary = out.read() + new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_binary) #创建音频 db.add(new_audio) db.commit() db.refresh(new_audio) except Exception as e: db.rollback() - print(str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat()) return AudioUploadResponse(status="success", message="用户上传音频成功", data=audio_upload_data) diff --git a/tests/unit_test/chat_test.py b/tests/unit_test/chat_test.py index 97225de..df26cc0 100644 --- a/tests/unit_test/chat_test.py +++ b/tests/unit_test/chat_test.py @@ -63,7 +63,7 @@ class ChatServiceTest: current_file_path = os.path.abspath(__file__) current_file_path = os.path.dirname(current_file_path) tests_dir = os.path.dirname(current_file_path) - mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3') + mp3_file_path = os.path.join(tests_dir, 'assets', 'BarbieDollsVoice.mp3') with open(mp3_file_path, 'rb') as audio_file: files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')} response = requests.post(url,files=files) @@ -73,6 +73,22 @@ class ChatServiceTest: else: raise Exception("音频上传失败") + #绑定音频 + url = f"{self.socket}/users/audio/bind" + payload = json.dumps({ + "user_id":self.user_id, + "audio_id":self.audio_id + }) + headers = { + 'Content-Type': 'application/json' + } + response = requests.request("POST", url, headers=headers, data=payload) + if response.status_code == 200: + print("音频绑定测试成功") + else: + raise Exception("音频绑定测试失败") + + #创建一个对话 url = f"{self.socket}/chats" payload = json.dumps({ @@ -339,10 +355,10 @@ def chat_test(): chat_service_test.test_create_chat() chat_service_test.test_session_id_query() chat_service_test.test_session_content_query() - chat_service_test.test_session_update() + # chat_service_test.test_session_update() asyncio.run(chat_service_test.test_chat_temporary()) - asyncio.run(chat_service_test.test_chat_lasting()) - asyncio.run(chat_service_test.test_voice_call()) + # asyncio.run(chat_service_test.test_chat_lasting()) + # asyncio.run(chat_service_test.test_voice_call()) chat_service_test.test_chat_delete()