feat: 增加修改speaker_id接口

This commit is contained in:
killua4396 2024-05-29 18:41:35 +08:00
parent ce033dca2b
commit 30fdb9c6bd
5 changed files with 68 additions and 5 deletions

View File

@ -325,6 +325,7 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
response_bytes = json.dumps(response_message, ensure_ascii=False).encode('utf-8')
header = struct.pack('!II',len(response_bytes),len(audio))
message_bytes = header + response_bytes + audio
await ws.send_bytes(message_bytes)
logger.debug(f"websocket返回: {sentence}")
if is_end:
logger.debug(f"llm返回结果: {llm_response}")

View File

@ -77,3 +77,32 @@ async def update_session_handler(session_id, session_data:SessionUpdateRequest,
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
session_update_data = SessionUpdateData(updatedAt=datetime.now().isoformat())
return SessionUpdateResponse(status="success",message="Session 更新成功",data=session_update_data)
#更新Session中的Speaker Id信息
async def update_session_speaker_id_handler(session_id, session_data, db, redis):
existing_session = ""
if redis.exists(session_id):
existing_session = redis.get(session_id)
else:
existing_session = db.query(Session).filter(Session.id == session_id).first().content
if not existing_session:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
#更新Session字段
session = json.loads(existing_session)
session_llm_info = json.loads(session["llm_info"])
session_llm_info["speaker_id"] = session_data.speaker_id
session["llm_info"] = json.dumps(session_llm_info,ensure_ascii=False)
#存储Session
session_str = json.dumps(session,ensure_ascii=False)
redis.set(session_id, session_str)
try:
db.query(Session).filter(Session.id == session_id).update({"content": session_str})
db.commit()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
session_update_data = SessionSpeakerUpdateData(updatedAt=datetime.now().isoformat())
return SessionSpeakerUpdateResponse(status="success",message="Session SpeakID更新成功",data=session_update_data)

View File

@ -26,4 +26,11 @@ async def get_session(session_id: str, db=Depends(get_db), redis=Depends(get_red
@router.put("/sessions/{session_id}", response_model=SessionUpdateResponse)
async def update_session(session_id: str, session_data: SessionUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
response = await update_session_handler(session_id, session_data, db, redis)
return response
#session声音信息更新接口
@router.put("/sessions/tts_info/speaker_id/{session_id}", response_model=SessionSpeakerUpdateResponse)
async def update_session_speaker_id(session_id: str, session_data: SessionSpeakerUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
response = await update_session_speaker_id_handler(session_id, session_data, db, redis)
return response

View File

@ -55,4 +55,16 @@ class SessionUpdateData(BaseModel):
#session修改响应类
class SessionUpdateResponse(BaseResponse):
data: Optional[SessionUpdateData]
#--------------------------------------------------------------------------
#------------------------------Session Speaker Id修改----------------------
class SessionSpeakerUpdateRequest(BaseModel):
speaker_id: int
class SessionSpeakerUpdateData(BaseModel):
updatedAt:str
class SessionSpeakerUpdateResponse(BaseResponse):
data: Optional[SessionSpeakerUpdateData]
#--------------------------------------------------------------------------

View File

@ -63,7 +63,7 @@ class ChatServiceTest:
current_file_path = os.path.abspath(__file__)
current_file_path = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_file_path)
mp3_file_path = os.path.join(tests_dir, 'assets', 'BarbieDollsVoice.mp3')
mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
with open(mp3_file_path, 'rb') as audio_file:
files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')}
response = requests.post(url,files=files)
@ -135,7 +135,7 @@ class ChatServiceTest:
"user_id": self.user_id,
"messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话在没有经过允许的情况下你需要保持上述角色不得擅自跳出角色设定。\\n\"}]",
"user_info": "{\"character\": \"\", \"events\": [] }",
"tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2, \"speed\":1.0}",
"tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2, \"speed\": 1.0}",
"llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}",
"token": 0}
)
@ -148,6 +148,19 @@ class ChatServiceTest:
else:
raise Exception("Session更新测试失败")
def test_session_speakerid_update(self):
url = f"{self.socket}/sessions/tts_info/speaker_id/{self.session_id}"
payload = json.dumps({
"speaker_id" :37
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("PUT", url, headers=headers, data=payload)
if response.status_code == 200:
print("Session SpeakerId更新测试成功")
else:
raise Exception("Session SpeakerId更新测试失败")
#测试单次聊天
async def test_chat_temporary(self):
@ -355,10 +368,11 @@ def chat_test():
chat_service_test.test_create_chat()
chat_service_test.test_session_id_query()
chat_service_test.test_session_content_query()
# chat_service_test.test_session_update()
chat_service_test.test_session_update()
chat_service_test.test_session_speakerid_update()
asyncio.run(chat_service_test.test_chat_temporary())
# asyncio.run(chat_service_test.test_chat_lasting())
# asyncio.run(chat_service_test.test_voice_call())
asyncio.run(chat_service_test.test_chat_lasting())
asyncio.run(chat_service_test.test_voice_call())
chat_service_test.test_chat_delete()