350 lines
14 KiB
Python
350 lines
14 KiB
Python
import requests
|
||
import base64
|
||
import wave
|
||
import json
|
||
import os
|
||
import uuid
|
||
import asyncio
|
||
import websockets
|
||
|
||
|
||
|
||
class ChatServiceTest:
    """End-to-end smoke tests for the chat service's HTTP and websocket APIs.

    The tests share state through ``self`` and must run in order:
    ``test_create_chat`` stores ``user_id``, ``character_id``, ``audio_id``,
    ``session_id`` and ``user_character_id``; the session/chat tests read
    them; ``test_chat_delete`` deletes the created resources. Each test prints
    a success message or raises ``Exception``.
    """

    # Seconds to wait for each HTTP request; without a timeout a hung server
    # would block the whole test run forever.
    request_timeout = 30

    def __init__(self, socket="http://127.0.0.1:7878"):
        # Base URL of the HTTP API ("scheme://host:port"). The websocket
        # endpoints are derived from the same host/port (see _ws_url).
        self.socket = socket

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _asset_path(filename):
        """Return the path of ``filename`` in the ``assets`` directory.

        ``assets`` is looked up in the parent of the directory that contains
        this file.
        """
        here = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(os.path.dirname(here), 'assets', filename)

    @staticmethod
    def _read_in_chunks(path, chunk_size):
        """Yield the raw bytes of the file at ``path``, at most ``chunk_size`` bytes at a time."""
        with open(path, 'rb') as f:
            while True:
                chunk = f.read(chunk_size)
                if not chunk:
                    return
                yield chunk

    def _ws_url(self, path):
        """Build the websocket URL for ``path`` on the host/port of ``self.socket``.

        ``http://`` becomes ``ws://`` and ``https://`` becomes ``wss://``; a base
        without any scheme is prefixed with ``ws://``; other schemes are kept.
        """
        base = self.socket.rstrip('/')
        if base.startswith('https://'):
            base = 'wss://' + base[len('https://'):]
        elif base.startswith('http://'):
            base = 'ws://' + base[len('http://'):]
        elif '://' not in base:
            base = 'ws://' + base
        return base + path

    def _request(self, method, path, **kwargs):
        """Send an HTTP ``method`` request to ``self.socket + path`` with the class timeout."""
        kwargs.setdefault('timeout', self.request_timeout)
        return requests.request(method, f"{self.socket}{path}", **kwargs)

    def _send_json(self, method, path, body):
        """Send ``body`` serialized with ``json.dumps`` as an ``application/json`` request."""
        return self._request(
            method,
            path,
            headers={'Content-Type': 'application/json'},
            data=json.dumps(body),
        )

    @staticmethod
    async def _send_audio_chunk(websocket, payload, chunk):
        """Put base64-encoded ``chunk`` into ``payload["audio"]`` and send ``payload`` as JSON."""
        payload["audio"] = base64.b64encode(chunk).decode('utf-8')
        await websocket.send(json.dumps(payload))

    async def _stream_audio_file(self, websocket, payload, path):
        """Send the file at ``path`` as 2048-byte audio frames, then an end-of-stream frame.

        ``payload`` is mutated in place: ``meta_info.is_end`` is False while the
        file is streamed, then set to True for a final frame with empty audio.
        The whole file is sent byte-for-byte (including any WAV header).
        NOTE(review): confirm the server tolerates the header with encoding "raw".
        """
        payload["meta_info"]["is_end"] = False
        for chunk in self._read_in_chunks(path, 2048):
            await self._send_audio_chunk(websocket, payload, chunk)
            await asyncio.sleep(0.01)
        # An empty chunk with is_end=True marks the end of this audio stream.
        payload["meta_info"]["is_end"] = True
        await self._send_audio_chunk(websocket, payload, b'')

    @staticmethod
    async def _collect_until_close(websocket):
        """Receive frames until a JSON object with ``"type": "close"`` arrives.

        Frames that fail to parse as JSON are treated as audio: binary ones are
        accumulated and returned, non-JSON text frames are skipped. JSON frames
        other than the close message are ignored.
        """
        audio_bytes = b''
        while True:
            frame = await websocket.recv()
            try:
                message = json.loads(frame)
            except ValueError:
                # JSONDecodeError / UnicodeDecodeError: raw audio, not a control message.
                if isinstance(frame, bytes):
                    audio_bytes += frame
                continue
            if isinstance(message, dict) and message.get("type") == "close":
                return audio_bytes

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_create_chat(self):
        """Create a user, a character, a voice sample and a chat between them.

        Stores ``user_id``, ``character_id``, ``audio_id``, ``session_id`` and
        ``user_character_id`` on ``self`` for the later tests.

        Raises:
            Exception: if any creation/upload request does not return HTTP 200.
        """
        # Create a user.
        response = self._send_json("POST", "/users", {
            "open_id": str(uuid.uuid4()),
            "username": "test_user",
            "avatar_id": "0",
            "tags": "[]",
            "persona": "{}",
        })
        if response.status_code != 200:
            raise Exception("创建聊天时,用户创建失败")
        print("用户创建成功")
        self.user_id = response.json()['data']['user_id']

        # Create a character.
        response = self._send_json("POST", "/characters", {
            "voice_id": 97,
            "avatar_id": "49c838c5ffb211ee9de9a036bc278b4c",
            "background_ids": "185c554affaf11eebd72a036bc278b4c,1b0e2d8bffaf11eebd72a036bc278b4c,20158587ffaf11eebd72a036bc278b4c,2834472affaf11eebd72a036bc278b4c,2c6ddb0affaf11eebd72a036bc278b4c,fd631ec4ffb011ee9b1aa036bc278b4c",
            "name": "test",
            "wakeup_words": "你好啊,海绵宝宝",
            "world_scenario": "海绵宝宝住在深海的大菠萝里面",
            "description": "厨师,做汉堡",
            "emojis": "大笑,微笑",
            "dialogues": "我准备好了",
        })
        if response.status_code != 200:
            raise Exception("创建聊天时,角色创建失败")
        print("角色创建成功")
        self.character_id = response.json()['data']['character_id']

        # Upload an audio sample for the user (used for voice cloning).
        mp3_file_path = self._asset_path('demo_speaker0.mp3')
        with open(mp3_file_path, 'rb') as audio_file:
            # Send only the base name: the multipart filename should not leak
            # the local absolute path to the server.
            files = {'audio_file': (os.path.basename(mp3_file_path), audio_file, 'audio/mpeg')}
            response = self._request("POST", f"/users/audio?user_id={self.user_id}", files=files)
        if response.status_code != 200:
            raise Exception("音频上传失败")
        self.audio_id = response.json()['data']['audio_id']
        print("音频上传成功")

        # Create the chat between the user and the character.
        response = self._send_json("POST", "/chats", {
            "user_id": self.user_id,
            "character_id": self.character_id,
        })
        if response.status_code != 200:
            raise Exception("对话创建测试失败")
        print("对话创建成功")
        chat = response.json()['data']
        self.session_id = chat['session_id']
        self.user_character_id = chat['user_character_id']

    def test_session_id_query(self):
        """Look up the session for (user_id, character_id); expect HTTP 200."""
        response = self._request(
            "GET", f"/sessions?user_id={self.user_id}&character_id={self.character_id}")
        if response.status_code != 200:
            raise Exception("session_id查询测试失败")
        print("session_id查询测试成功")

    def test_session_content_query(self):
        """Fetch the content of ``self.session_id``; expect HTTP 200."""
        response = self._request("GET", f"/sessions/{self.session_id}")
        if response.status_code != 200:
            raise Exception("session内容查询测试失败")
        print("session内容查询测试成功")

    def test_session_update(self):
        """PUT a new system prompt, user/TTS/LLM settings and token count to the session."""
        # The nested settings are sent as JSON-encoded strings, not objects.
        payload = {
            "user_id": self.user_id,
            "messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\\n\"}]",
            "user_info": "{\"character\": \"\", \"events\": [] }",
            "tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2, \"speed\":1.0}",
            "llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}",
            "token": 0,
        }
        response = self._send_json("PUT", f"/sessions/{self.session_id}", payload)
        if response.status_code != 200:
            raise Exception("Session更新测试失败")
        print("Session更新测试成功")

    async def test_chat_temporary(self):
        """Stream one recorded utterance to /chat/streaming/temporary and wait for "close"."""
        wav_file_path = self._asset_path('example_recording.wav')
        payload = {
            "text": "",
            "audio": "",
            "meta_info": {
                "session_id": self.session_id,
                "stream": True,
                "voice_synthesize": True,
                "is_end": False,
                "encoding": "raw",
            },
        }
        async with websockets.connect(self._ws_url('/chat/streaming/temporary')) as websocket:
            await self._stream_audio_file(websocket, payload, wav_file_path)
            await self._collect_until_close(websocket)
            print("单次聊天测试成功")
            # Short pause before the context manager closes the connection.
            await asyncio.sleep(0.04)

    async def test_chat_lasting(self):
        """Send two utterances over one /chat/streaming/lasting connection, then close it."""
        wav_file_path = self._asset_path('example_recording.wav')
        payload = {
            "text": "",
            "audio": "",
            "meta_info": {
                "session_id": self.session_id,
                "stream": True,
                "voice_synthesize": True,
                "is_end": False,
                "encoding": "raw",
            },
            "is_close": False,
        }
        async with websockets.connect(self._ws_url('/chat/streaming/lasting')) as websocket:
            # First utterance.
            await self._stream_audio_file(websocket, payload, wav_file_path)
            await asyncio.sleep(3)  # simulate a pause between utterances
            # Second utterance (_stream_audio_file resets is_end to False first).
            await self._stream_audio_file(websocket, payload, wav_file_path)
            # Ask the server to end the conversation.
            payload["is_close"] = True
            await self._send_audio_chunk(websocket, payload, b'')

            await self._collect_until_close(websocket)
            print("持续聊天测试成功")
            # Pause 0.5 s before the context manager closes the connection.
            await asyncio.sleep(0.5)

    async def test_voice_call(self):
        """Stream voice_call.wav to /chat/voice_call, request close, wait for any reply."""
        chunk_size = 480
        frames_per_buffer = chunk_size // 2
        # Only full buffers of exactly this many bytes are sent; the trailing
        # partial buffer is dropped. NOTE(review): 960 bytes per 240 frames
        # implies 4 bytes/frame in the asset — confirm against the WAV format.
        expected_chunk_bytes = 960
        file_path = self._asset_path('voice_call.wav')
        ws_data = {
            "audio": "",
            "meta_info": {
                "session_id": self.session_id,
                "encoding": 'raw',
            },
            "is_close": False,
        }
        async with websockets.connect(self._ws_url('/chat/voice_call')) as websocket:
            with wave.open(file_path, 'rb') as wf:
                data = wf.readframes(frames_per_buffer)
                while len(data) == expected_chunk_bytes:
                    ws_data['audio'] = base64.b64encode(data).decode('utf-8')
                    await websocket.send(json.dumps(ws_data))
                    data = wf.readframes(frames_per_buffer)
            await asyncio.sleep(3)
            # Empty audio with is_close=True ends the call.
            ws_data['audio'] = ""
            ws_data['is_close'] = True
            await websocket.send(json.dumps(ws_data))
            # Any non-empty reply counts as success.
            while True:
                data_ws = await websocket.recv()
                if data_ws:
                    print("语音电话测试成功")
                    break
            await asyncio.sleep(3)

    def test_chat_delete(self):
        """Delete the chat, then the audio sample, user and character created earlier."""
        response = self._request("DELETE", f"/chats/{self.user_character_id}")
        if response.status_code != 200:
            raise Exception("聊天删除测试失败")
        print("聊天删除测试成功")

        # Remaining fixtures from test_create_chat, deleted in this order;
        # the first failure raises and stops the cleanup.
        cleanups = (
            (f"/users/audio/{self.audio_id}", "音频删除测试失败"),
            (f"/users/{self.user_id}", "用户删除测试失败"),
            (f"/characters/{self.character_id}", "角色删除测试失败"),
        )
        for path, error_message in cleanups:
            if self._request("DELETE", path).status_code != 200:
                raise Exception(error_message)
def chat_test():
    """Run every ChatServiceTest step, in order, against the default local server."""
    tester = ChatServiceTest()

    # Synchronous HTTP steps: create the fixtures, then inspect/update the session.
    http_steps = (
        tester.test_create_chat,
        tester.test_session_id_query,
        tester.test_session_content_query,
        tester.test_session_update,
    )
    for step in http_steps:
        step()

    # Websocket steps are coroutines; each one runs in its own event loop.
    websocket_steps = (
        tester.test_chat_temporary,
        tester.test_chat_lasting,
        tester.test_voice_call,
    )
    for make_coroutine in websocket_steps:
        asyncio.run(make_coroutine())

    # Remove everything created by test_create_chat.
    tester.test_chat_delete()
# Script entry point: run the full smoke-test sequence only when executed directly.
if __name__ == '__main__':
    chat_test()