diff --git a/app/controllers/chat_controller.py b/app/controllers/chat_controller.py index d56ae10..5b5a8b9 100644 --- a/app/controllers/chat_controller.py +++ b/app/controllers/chat_controller.py @@ -325,6 +325,7 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis response_bytes = json.dumps(response_message, ensure_ascii=False).encode('utf-8') header = struct.pack('!II',len(response_bytes),len(audio)) message_bytes = header + response_bytes + audio + await ws.send_bytes(message_bytes) logger.debug(f"websocket返回: {sentence}") if is_end: logger.debug(f"llm返回结果: {llm_response}") diff --git a/app/controllers/session_controller.py b/app/controllers/session_controller.py index bac9e4a..507fe6b 100644 --- a/app/controllers/session_controller.py +++ b/app/controllers/session_controller.py @@ -77,3 +77,32 @@ async def update_session_handler(session_id, session_data:SessionUpdateRequest, raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) session_update_data = SessionUpdateData(updatedAt=datetime.now().isoformat()) return SessionUpdateResponse(status="success",message="Session 更新成功",data=session_update_data) + +#更新Session中的Speaker Id信息 +async def update_session_speaker_id_handler(session_id, session_data, db, redis): + existing_session = "" + if redis.exists(session_id): + existing_session = redis.get(session_id) + else: + existing_session = db.query(Session).filter(Session.id == session_id).first().content + if not existing_session: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found") + + #更新Session字段 + session = json.loads(existing_session) + session_llm_info = json.loads(session["llm_info"]) + session_llm_info["speaker_id"] = session_data.speaker_id + session["llm_info"] = json.dumps(session_llm_info,ensure_ascii=False) + + #存储Session + + session_str = json.dumps(session,ensure_ascii=False) + redis.set(session_id, session_str) + try: + db.query(Session).filter(Session.id == session_id).update({"content": session_str}) + db.commit() + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + session_update_data = SessionSpeakerUpdateData(updatedAt=datetime.now().isoformat()) + return SessionSpeakerUpdateResponse(status="success",message="Session SpeakID更新成功",data=session_update_data) \ No newline at end of file diff --git a/app/routes/session_route.py b/app/routes/session_route.py index 6dd26c1..5500565 100644 --- a/app/routes/session_route.py +++ b/app/routes/session_route.py @@ -26,4 +26,11 @@ async def get_session(session_id: str, db=Depends(get_db), redis=Depends(get_red @router.put("/sessions/{session_id}", response_model=SessionUpdateResponse) async def update_session(session_id: str, session_data: SessionUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)): response = await update_session_handler(session_id, session_data, db, redis) + return response + + +#session声音信息更新接口 +@router.put("/sessions/tts_info/speaker_id/{session_id}", response_model=SessionSpeakerUpdateResponse) +async def update_session_speaker_id(session_id: str, session_data: SessionSpeakerUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)): + response = await update_session_speaker_id_handler(session_id, session_data, db, redis) return response \ No newline at end of file diff --git a/app/schemas/session_schema.py b/app/schemas/session_schema.py index 3cfa60d..f0256dd 100644 --- a/app/schemas/session_schema.py +++ b/app/schemas/session_schema.py @@ -55,4 +55,16 @@ class SessionUpdateData(BaseModel): #session修改响应类 class SessionUpdateResponse(BaseResponse): data: Optional[SessionUpdateData] +#-------------------------------------------------------------------------- + + +#------------------------------Session Speaker Id修改---------------------- +class SessionSpeakerUpdateRequest(BaseModel): + speaker_id: int + +class SessionSpeakerUpdateData(BaseModel): + updatedAt:str + +class SessionSpeakerUpdateResponse(BaseResponse): + data: Optional[SessionSpeakerUpdateData] #-------------------------------------------------------------------------- \ No newline at end of file diff --git a/tests/unit_test/chat_test.py b/tests/unit_test/chat_test.py index df26cc0..3c00824 100644 --- a/tests/unit_test/chat_test.py +++ b/tests/unit_test/chat_test.py @@ -63,7 +63,7 @@ class ChatServiceTest: current_file_path = os.path.abspath(__file__) current_file_path = os.path.dirname(current_file_path) tests_dir = os.path.dirname(current_file_path) - mp3_file_path = os.path.join(tests_dir, 'assets', 'BarbieDollsVoice.mp3') + mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3') with open(mp3_file_path, 'rb') as audio_file: files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')} response = requests.post(url,files=files) @@ -135,7 +135,7 @@ class ChatServiceTest: "user_id": self.user_id, "messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\\n\"}]", "user_info": "{\"character\": \"\", \"events\": [] }", - "tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2, \"speed\":1.0}", + "tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2, \"speed\": 1.0}", "llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}", "token": 0} ) @@ -148,6 +148,19 @@ class ChatServiceTest: else: raise Exception("Session更新测试失败") + def test_session_speakerid_update(self): + url = f"{self.socket}/sessions/tts_info/speaker_id/{self.session_id}" + payload = json.dumps({ + "speaker_id" :37 + }) + headers = { + 'Content-Type': 'application/json' + } + response = requests.request("PUT", url, headers=headers, data=payload) + if response.status_code == 200: + print("Session SpeakerId更新测试成功") + else: + raise Exception("Session SpeakerId更新测试失败") #测试单次聊天 async def test_chat_temporary(self): @@ -355,10 +368,11 @@ def chat_test(): chat_service_test.test_create_chat() chat_service_test.test_session_id_query() chat_service_test.test_session_content_query() - # chat_service_test.test_session_update() + chat_service_test.test_session_update() + chat_service_test.test_session_speakerid_update() asyncio.run(chat_service_test.test_chat_temporary()) - # asyncio.run(chat_service_test.test_chat_lasting()) - # asyncio.run(chat_service_test.test_voice_call()) + asyncio.run(chat_service_test.test_chat_lasting()) + asyncio.run(chat_service_test.test_voice_call()) chat_service_test.test_chat_delete()