feat: 增加修改speaker_id接口
This commit is contained in:
parent
ce033dca2b
commit
30fdb9c6bd
|
@ -325,6 +325,7 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
response_bytes = json.dumps(response_message, ensure_ascii=False).encode('utf-8')
|
response_bytes = json.dumps(response_message, ensure_ascii=False).encode('utf-8')
|
||||||
header = struct.pack('!II',len(response_bytes),len(audio))
|
header = struct.pack('!II',len(response_bytes),len(audio))
|
||||||
message_bytes = header + response_bytes + audio
|
message_bytes = header + response_bytes + audio
|
||||||
|
await ws.send_bytes(message_bytes)
|
||||||
logger.debug(f"websocket返回: {sentence}")
|
logger.debug(f"websocket返回: {sentence}")
|
||||||
if is_end:
|
if is_end:
|
||||||
logger.debug(f"llm返回结果: {llm_response}")
|
logger.debug(f"llm返回结果: {llm_response}")
|
||||||
|
|
|
@ -77,3 +77,32 @@ async def update_session_handler(session_id, session_data:SessionUpdateRequest,
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||||
session_update_data = SessionUpdateData(updatedAt=datetime.now().isoformat())
|
session_update_data = SessionUpdateData(updatedAt=datetime.now().isoformat())
|
||||||
return SessionUpdateResponse(status="success",message="Session 更新成功",data=session_update_data)
|
return SessionUpdateResponse(status="success",message="Session 更新成功",data=session_update_data)
|
||||||
|
|
||||||
|
#更新Session中的Speaker Id信息
|
||||||
|
async def update_session_speaker_id_handler(session_id, session_data, db, redis):
|
||||||
|
existing_session = ""
|
||||||
|
if redis.exists(session_id):
|
||||||
|
existing_session = redis.get(session_id)
|
||||||
|
else:
|
||||||
|
existing_session = db.query(Session).filter(Session.id == session_id).first().content
|
||||||
|
if not existing_session:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
|
||||||
|
|
||||||
|
#更新Session字段
|
||||||
|
session = json.loads(existing_session)
|
||||||
|
session_llm_info = json.loads(session["llm_info"])
|
||||||
|
session_llm_info["speaker_id"] = session_data.speaker_id
|
||||||
|
session["llm_info"] = json.dumps(session_llm_info,ensure_ascii=False)
|
||||||
|
|
||||||
|
#存储Session
|
||||||
|
|
||||||
|
session_str = json.dumps(session,ensure_ascii=False)
|
||||||
|
redis.set(session_id, session_str)
|
||||||
|
try:
|
||||||
|
db.query(Session).filter(Session.id == session_id).update({"content": session_str})
|
||||||
|
db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||||
|
session_update_data = SessionSpeakerUpdateData(updatedAt=datetime.now().isoformat())
|
||||||
|
return SessionSpeakerUpdateResponse(status="success",message="Session SpeakID更新成功",data=session_update_data)
|
|
@ -26,4 +26,11 @@ async def get_session(session_id: str, db=Depends(get_db), redis=Depends(get_red
|
||||||
@router.put("/sessions/{session_id}", response_model=SessionUpdateResponse)
|
@router.put("/sessions/{session_id}", response_model=SessionUpdateResponse)
|
||||||
async def update_session(session_id: str, session_data: SessionUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
|
async def update_session(session_id: str, session_data: SessionUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
|
||||||
response = await update_session_handler(session_id, session_data, db, redis)
|
response = await update_session_handler(session_id, session_data, db, redis)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
#session声音信息更新接口
|
||||||
|
@router.put("/sessions/tts_info/speaker_id/{session_id}", response_model=SessionSpeakerUpdateResponse)
|
||||||
|
async def update_session_speaker_id(session_id: str, session_data: SessionSpeakerUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
|
||||||
|
response = await update_session_speaker_id_handler(session_id, session_data, db, redis)
|
||||||
return response
|
return response
|
|
@ -55,4 +55,16 @@ class SessionUpdateData(BaseModel):
|
||||||
#session修改响应类
|
#session修改响应类
|
||||||
class SessionUpdateResponse(BaseResponse):
|
class SessionUpdateResponse(BaseResponse):
|
||||||
data: Optional[SessionUpdateData]
|
data: Optional[SessionUpdateData]
|
||||||
|
#--------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
#------------------------------Session Speaker Id修改----------------------
|
||||||
|
class SessionSpeakerUpdateRequest(BaseModel):
|
||||||
|
speaker_id: int
|
||||||
|
|
||||||
|
class SessionSpeakerUpdateData(BaseModel):
|
||||||
|
updatedAt:str
|
||||||
|
|
||||||
|
class SessionSpeakerUpdateResponse(BaseResponse):
|
||||||
|
data: Optional[SessionSpeakerUpdateData]
|
||||||
#--------------------------------------------------------------------------
|
#--------------------------------------------------------------------------
|
|
@ -63,7 +63,7 @@ class ChatServiceTest:
|
||||||
current_file_path = os.path.abspath(__file__)
|
current_file_path = os.path.abspath(__file__)
|
||||||
current_file_path = os.path.dirname(current_file_path)
|
current_file_path = os.path.dirname(current_file_path)
|
||||||
tests_dir = os.path.dirname(current_file_path)
|
tests_dir = os.path.dirname(current_file_path)
|
||||||
mp3_file_path = os.path.join(tests_dir, 'assets', 'BarbieDollsVoice.mp3')
|
mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
|
||||||
with open(mp3_file_path, 'rb') as audio_file:
|
with open(mp3_file_path, 'rb') as audio_file:
|
||||||
files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')}
|
files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')}
|
||||||
response = requests.post(url,files=files)
|
response = requests.post(url,files=files)
|
||||||
|
@ -135,7 +135,7 @@ class ChatServiceTest:
|
||||||
"user_id": self.user_id,
|
"user_id": self.user_id,
|
||||||
"messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\\n\"}]",
|
"messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\\n\"}]",
|
||||||
"user_info": "{\"character\": \"\", \"events\": [] }",
|
"user_info": "{\"character\": \"\", \"events\": [] }",
|
||||||
"tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2, \"speed\":1.0}",
|
"tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2, \"speed\": 1.0}",
|
||||||
"llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}",
|
"llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}",
|
||||||
"token": 0}
|
"token": 0}
|
||||||
)
|
)
|
||||||
|
@ -148,6 +148,19 @@ class ChatServiceTest:
|
||||||
else:
|
else:
|
||||||
raise Exception("Session更新测试失败")
|
raise Exception("Session更新测试失败")
|
||||||
|
|
||||||
|
def test_session_speakerid_update(self):
|
||||||
|
url = f"{self.socket}/sessions/tts_info/speaker_id/{self.session_id}"
|
||||||
|
payload = json.dumps({
|
||||||
|
"speaker_id" :37
|
||||||
|
})
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
response = requests.request("PUT", url, headers=headers, data=payload)
|
||||||
|
if response.status_code == 200:
|
||||||
|
print("Session SpeakerId更新测试成功")
|
||||||
|
else:
|
||||||
|
raise Exception("Session SpeakerId更新测试失败")
|
||||||
|
|
||||||
#测试单次聊天
|
#测试单次聊天
|
||||||
async def test_chat_temporary(self):
|
async def test_chat_temporary(self):
|
||||||
|
@ -355,10 +368,11 @@ def chat_test():
|
||||||
chat_service_test.test_create_chat()
|
chat_service_test.test_create_chat()
|
||||||
chat_service_test.test_session_id_query()
|
chat_service_test.test_session_id_query()
|
||||||
chat_service_test.test_session_content_query()
|
chat_service_test.test_session_content_query()
|
||||||
# chat_service_test.test_session_update()
|
chat_service_test.test_session_update()
|
||||||
|
chat_service_test.test_session_speakerid_update()
|
||||||
asyncio.run(chat_service_test.test_chat_temporary())
|
asyncio.run(chat_service_test.test_chat_temporary())
|
||||||
# asyncio.run(chat_service_test.test_chat_lasting())
|
asyncio.run(chat_service_test.test_chat_lasting())
|
||||||
# asyncio.run(chat_service_test.test_voice_call())
|
asyncio.run(chat_service_test.test_voice_call())
|
||||||
chat_service_test.test_chat_delete()
|
chat_service_test.test_chat_delete()
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue