feat: 增加修改speaker_id接口
This commit is contained in:
parent
ce033dca2b
commit
30fdb9c6bd
|
@ -325,6 +325,7 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
|||
response_bytes = json.dumps(response_message, ensure_ascii=False).encode('utf-8')
|
||||
header = struct.pack('!II',len(response_bytes),len(audio))
|
||||
message_bytes = header + response_bytes + audio
|
||||
await ws.send_bytes(message_bytes)
|
||||
logger.debug(f"websocket返回: {sentence}")
|
||||
if is_end:
|
||||
logger.debug(f"llm返回结果: {llm_response}")
|
||||
|
|
|
@ -77,3 +77,32 @@ async def update_session_handler(session_id, session_data:SessionUpdateRequest,
|
|||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
session_update_data = SessionUpdateData(updatedAt=datetime.now().isoformat())
|
||||
return SessionUpdateResponse(status="success",message="Session 更新成功",data=session_update_data)
|
||||
|
||||
#更新Session中的Speaker Id信息
|
||||
async def update_session_speaker_id_handler(session_id, session_data, db, redis):
|
||||
existing_session = ""
|
||||
if redis.exists(session_id):
|
||||
existing_session = redis.get(session_id)
|
||||
else:
|
||||
existing_session = db.query(Session).filter(Session.id == session_id).first().content
|
||||
if not existing_session:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
|
||||
|
||||
#更新Session字段
|
||||
session = json.loads(existing_session)
|
||||
session_llm_info = json.loads(session["llm_info"])
|
||||
session_llm_info["speaker_id"] = session_data.speaker_id
|
||||
session["llm_info"] = json.dumps(session_llm_info,ensure_ascii=False)
|
||||
|
||||
#存储Session
|
||||
|
||||
session_str = json.dumps(session,ensure_ascii=False)
|
||||
redis.set(session_id, session_str)
|
||||
try:
|
||||
db.query(Session).filter(Session.id == session_id).update({"content": session_str})
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
session_update_data = SessionSpeakerUpdateData(updatedAt=datetime.now().isoformat())
|
||||
return SessionSpeakerUpdateResponse(status="success",message="Session SpeakID更新成功",data=session_update_data)
|
|
@ -26,4 +26,11 @@ async def get_session(session_id: str, db=Depends(get_db), redis=Depends(get_red
|
|||
@router.put("/sessions/{session_id}", response_model=SessionUpdateResponse)
|
||||
async def update_session(session_id: str, session_data: SessionUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
|
||||
response = await update_session_handler(session_id, session_data, db, redis)
|
||||
return response
|
||||
|
||||
|
||||
#session声音信息更新接口
|
||||
@router.put("/sessions/tts_info/speaker_id/{session_id}", response_model=SessionSpeakerUpdateResponse)
|
||||
async def update_session_speaker_id(session_id: str, session_data: SessionSpeakerUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
|
||||
response = await update_session_speaker_id_handler(session_id, session_data, db, redis)
|
||||
return response
|
|
@ -55,4 +55,16 @@ class SessionUpdateData(BaseModel):
|
|||
#session修改响应类
|
||||
class SessionUpdateResponse(BaseResponse):
|
||||
data: Optional[SessionUpdateData]
|
||||
#--------------------------------------------------------------------------
|
||||
|
||||
|
||||
#------------------------------Session Speaker Id修改----------------------
|
||||
class SessionSpeakerUpdateRequest(BaseModel):
|
||||
speaker_id: int
|
||||
|
||||
class SessionSpeakerUpdateData(BaseModel):
|
||||
updatedAt:str
|
||||
|
||||
class SessionSpeakerUpdateResponse(BaseResponse):
|
||||
data: Optional[SessionSpeakerUpdateData]
|
||||
#--------------------------------------------------------------------------
|
|
@ -63,7 +63,7 @@ class ChatServiceTest:
|
|||
current_file_path = os.path.abspath(__file__)
|
||||
current_file_path = os.path.dirname(current_file_path)
|
||||
tests_dir = os.path.dirname(current_file_path)
|
||||
mp3_file_path = os.path.join(tests_dir, 'assets', 'BarbieDollsVoice.mp3')
|
||||
mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
|
||||
with open(mp3_file_path, 'rb') as audio_file:
|
||||
files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')}
|
||||
response = requests.post(url,files=files)
|
||||
|
@ -135,7 +135,7 @@ class ChatServiceTest:
|
|||
"user_id": self.user_id,
|
||||
"messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\\n\"}]",
|
||||
"user_info": "{\"character\": \"\", \"events\": [] }",
|
||||
"tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2, \"speed\":1.0}",
|
||||
"tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2, \"speed\": 1.0}",
|
||||
"llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}",
|
||||
"token": 0}
|
||||
)
|
||||
|
@ -148,6 +148,19 @@ class ChatServiceTest:
|
|||
else:
|
||||
raise Exception("Session更新测试失败")
|
||||
|
||||
def test_session_speakerid_update(self):
|
||||
url = f"{self.socket}/sessions/tts_info/speaker_id/{self.session_id}"
|
||||
payload = json.dumps({
|
||||
"speaker_id" :37
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("PUT", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("Session SpeakerId更新测试成功")
|
||||
else:
|
||||
raise Exception("Session SpeakerId更新测试失败")
|
||||
|
||||
#测试单次聊天
|
||||
async def test_chat_temporary(self):
|
||||
|
@ -355,10 +368,11 @@ def chat_test():
|
|||
chat_service_test.test_create_chat()
|
||||
chat_service_test.test_session_id_query()
|
||||
chat_service_test.test_session_content_query()
|
||||
# chat_service_test.test_session_update()
|
||||
chat_service_test.test_session_update()
|
||||
chat_service_test.test_session_speakerid_update()
|
||||
asyncio.run(chat_service_test.test_chat_temporary())
|
||||
# asyncio.run(chat_service_test.test_chat_lasting())
|
||||
# asyncio.run(chat_service_test.test_voice_call())
|
||||
asyncio.run(chat_service_test.test_chat_lasting())
|
||||
asyncio.run(chat_service_test.test_voice_call())
|
||||
chat_service_test.test_chat_delete()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue