TakwayBoard/takway/clients/client_utils.py

192 lines
6.2 KiB
Python
Raw Normal View History

2024-05-23 01:27:51 +08:00
import os
import json
import time
import datetime
import requests
2024-05-23 20:00:40 +08:00
import struct
2024-05-23 01:27:51 +08:00
from takway.common_utils import encode_bytes2str, decode_str2bytes
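'''
Reference layout of a server JSON message. Values are placeholders; the
trailing comments give each field's expected type.
(Client.gen_request_data below builds requests with a different layout:
top-level "is_end"/"is_bgn", "DATA" and "META_INFO".)
{
    "RESPONSE_INFO": {
        "status": "success/error",  # string
        "message": "xxxxx",  # string
    },
    "DATA": {
        "Audio": {
            "data": "xxxxx",  # base64-encoded data
            "metadata": {
                "rate": <rate>,  # int
                "channels": <channels>,  # int
                "format": <format>,  # int
            }
        },
        "Text": {
            "data": "xxxxx",  # base64-encoded data
            "metadata": {
                "is_end": True/False,  # bool
            }
        },
        "Image": {
            "data": "xxxxx",  # base64-encoded data
            "metadata": {
                "width": <width>,  # int
                "height": <height>,  # int
                "format": <format>,  # string
            }
        }
    }
}
'''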
class Client:
    """HTTP client that POSTs one JSON request per call to ``server_url``.

    The request body is built by :meth:`gen_request_data` and sent by
    :meth:`send_data_to_server`.
    """

    def __init__(self, server_url):
        """Remember the URL that requests are POSTed to."""
        self.server_url = server_url

    def gen_request_data(self, **kwargs):
        """Serialize one request into a newline-terminated JSON string.

        Keyword Args:
            audio_data (dict | None): ``frames`` (bytes) is encoded with
                ``encode_bytes2str``; ``is_end``, ``is_bgn``, ``frames_size``
                and ``chunk_size`` are copied as-is. A missing or ``None``
                value is treated as an empty dict, and a missing ``frames``
                entry is encoded as empty bytes.
            text_data (dict | None): ``text``, ``chat_status`` and
                ``chat_history`` are copied as-is. A missing or ``None``
                value is treated as an empty dict.
            character (str): character name, default ``"Klee"``.
            speaker_id (int): speaker id, default ``113``.
            wakeup_words: wake-up phrase; defaults to a built-in string.

        Returns:
            str: the JSON document followed by a newline.
        """
        # ``or {}`` also covers an explicit ``None``; previously a missing
        # audio_data crashed on the first ``.get`` call.
        audio_data = kwargs.get("audio_data") or {}
        text_data = kwargs.get("text_data") or {}

        payload = {
            "is_end": audio_data.get("is_end"),
            "is_bgn": audio_data.get("is_bgn"),
            "DATA": {
                "Audio": {
                    # Raw frames are transported as a string produced by
                    # encode_bytes2str.
                    "data": encode_bytes2str(audio_data.get("frames", b"")),
                    "metadata": {
                        "frames_size": audio_data.get("frames_size"),
                        "chunk_size": audio_data.get("chunk_size"),
                        "is_end": audio_data.get("is_end"),
                    }
                },
                "Text": {
                    "data": text_data.get("text"),
                    "metadata": {
                        "chat_status": text_data.get("chat_status"),
                        "chat_history": text_data.get("chat_history"),
                    }
                },
            },
            "META_INFO": {
                "character": {
                    "name": kwargs.get("character", "Klee"),
                    "speaker_id": kwargs.get("speaker_id", 113),
                    "wakeup_words": kwargs.get("wakeup_words", "可莉来啦"),
                }
            }
        }
        # The trailing newline delimits one request in the HTTP body.
        return json.dumps(payload) + '\n'

    def send_data_to_server(self, *, timeout=None, **kwargs):
        """POST the request built from ``kwargs`` and return the response.

        Args:
            timeout: forwarded to ``requests.post``. ``None`` (the default)
                waits indefinitely, which was the only behavior before.
            **kwargs: forwarded unchanged to :meth:`gen_request_data`.

        Returns:
            requests.Response: opened with ``stream=True`` so the caller can
            consume the body incrementally.
        """
        return requests.post(self.server_url,
                             data=self.gen_request_data(**kwargs),
                             stream=True,
                             timeout=timeout)
# ############################################ #
# ############ WebSocket Client ############# #
def check_audio_type(data, return_type='base64'):
    """Return audio ``data`` in the representation named by ``return_type``.

    Args:
        data (bytes | str): raw audio bytes, or an already-encoded string.
        return_type (str): ``'base64'`` to obtain a str (bytes are converted
            with ``encode_bytes2str``) or ``'bytes'`` to obtain bytes (a str
            is converted with ``decode_str2bytes``). Data that is already in
            the requested form is returned unchanged.

    Returns:
        str | bytes: ``data`` in the requested representation.

    Raises:
        ValueError: if ``return_type`` is unsupported, or ``data`` is neither
            bytes nor str.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if return_type not in ('bytes', 'base64'):
        raise ValueError('Invalid return type: {}.'.format(return_type))

    if return_type == 'base64':
        if isinstance(data, bytes):
            return encode_bytes2str(data)
        if isinstance(data, str):
            # Already encoded: pass it through instead of returning None.
            return data
    else:
        if isinstance(data, str):
            return decode_str2bytes(data)
        if isinstance(data, bytes):
            # Already raw bytes: nothing to convert.
            return data

    raise ValueError('Invalid data type: {}.'.format(type(data)))
import websocket
from websocket import create_connection
class BaseWebSocketClient:
    """Synchronous WebSocket client that tags every message with a session id.

    Call :meth:`wakeup_client` to open the connection before sending or
    receiving, and :meth:`close_client` when finished.
    """

    # Binary frames start with two unsigned 32-bit big-endian integers:
    # the byte length of the UTF-8 JSON text part, then of the audio part.
    _FRAME_HEADER = struct.Struct('!II')

    def __init__(self, server_url, session_id):
        """Store the server URL and the session id sent in every message."""
        self.server_url = server_url
        self.session_id = session_id

    def wakeup_client(self):
        """Open the WebSocket connection to ``server_url``."""
        self.websocket = create_connection(self.server_url)

    def close_client(self):
        """Close the connection; does nothing if it was never opened."""
        conn = getattr(self, 'websocket', None)
        if conn is not None:
            conn.close()

    def send_per_data(self,
                      text: str = '',
                      audio: bytes = b'',
                      stream: bool = True,
                      voice_synthesize: bool = False,
                      is_end: bool = False,
                      encoding: str = 'base64',
                      ):
        """Send one JSON message to the server.

        Args:
            text: text payload, sent as-is.
            audio: audio payload, converted by ``check_audio_type`` to the
                representation named by ``encoding``.
            stream: copied into ``meta_info``.
            voice_synthesize: copied into ``meta_info``.
            is_end: copied into ``meta_info``.
            encoding: ``'base64'`` or ``'bytes'``; also copied into
                ``meta_info``. NOTE(review): with ``'bytes'`` the audio value
                is bytes, which ``json.dumps`` cannot serialize — confirm
                whether that mode is meant to be used here.
        """
        self.websocket.send(json.dumps({
            "text": text,
            "audio": check_audio_type(audio, return_type=encoding),
            "meta_info": {
                "session_id": self.session_id,
                "stream": stream,
                "voice_synthesize": voice_synthesize,
                "is_end": is_end,
                "encoding": encoding,
            }}))

    @classmethod
    def _split_binary_frame(cls, frame):
        """Split a length-prefixed binary frame into (JSON text object, audio bytes)."""
        header_size = cls._FRAME_HEADER.size
        if len(frame) < header_size:
            raise ValueError(
                f"Binary frame too short for header: {len(frame)} < {header_size} bytes.")
        text_length, audio_length = cls._FRAME_HEADER.unpack_from(frame)
        text_end = header_size + text_length
        audio_end = text_end + audio_length
        # Detect truncation instead of silently returning a shorter audio slice.
        if len(frame) < audio_end:
            raise ValueError(
                f"Binary frame truncated: declares {audio_end} bytes, got {len(frame)}.")
        # UnicodeDecodeError / JSONDecodeError are both ValueError subclasses.
        text = json.loads(frame[header_size:text_end].decode('utf-8'))
        return text, frame[text_end:audio_end]

    def receive_per_data(self):
        """Receive and decode one message from the server.

        Returns:
            tuple: one of
              * ``([text, audio], list)`` for a binary frame, where ``text``
                is the decoded JSON text part and ``audio`` the raw audio part;
              * ``(obj, type(obj))`` for a text frame holding valid JSON;
              * ``(None, None)`` if the connection is closed or a text frame
                cannot be parsed as JSON.

        Raises:
            ValueError: if a binary frame is shorter than its header or its
                declared lengths, or its text part is not UTF-8 JSON.
        """
        try:
            recv_data = self.websocket.recv()
        except websocket.WebSocketConnectionClosedException:
            return None, None

        # Dispatch on the frame type explicitly rather than relying on
        # struct.unpack raising TypeError for str input.
        if isinstance(recv_data, (bytes, bytearray)):
            text, audio = self._split_binary_frame(recv_data)
            print("Received text:", text)
            print("Received audio(length):", len(audio))
            return [text, audio], list

        try:
            recv_data = json.loads(recv_data)
        except (json.JSONDecodeError, TypeError):
            return None, None
        print(f"json: {recv_data}")
        return recv_data, type(recv_data)
2024-05-23 20:00:40 +08:00