forked from killua/TakwayDisplayPlatform
Compare commits
4 Commits
bing-patch
...
master
Author | SHA1 | Date |
---|---|---|
|
21f9c86c46 | |
|
18dbabdd19 | |
|
c58c7c9a5b | |
|
2b98752db1 |
|
@ -0,0 +1,48 @@
|
||||||
|
# 部署步骤
|
||||||
|
|
||||||
|
1. 将该仓库clone到本地
|
||||||
|
2. 创建虚拟环境并启动
|
||||||
|
|
||||||
|
``` shell
|
||||||
|
conda create -n takway python=3.9
|
||||||
|
conda activate takway
|
||||||
|
```
|
||||||
|
|
||||||
|
3. cd进入仓库目录下,安装依赖
|
||||||
|
|
||||||
|
``` shell
|
||||||
|
cd ~/TakwayDisplayPlatform
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
4. 在./utils/目录下创建vits_model文件夹
|
||||||
|
|
||||||
|
从[链接](https://huggingface.co/spaces/zomehwh/vits-uma-genshin-honkai/tree/main/model)下载vits_model并放入该文件夹下,只需下载config.json和G_953000.pth即可
|
||||||
|
|
||||||
|
5. 在./utils/bert_vits2/目录下创建bert,slm文件夹
|
||||||
|
|
||||||
|
从[链接1](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large),[链接2](https://huggingface.co/microsoft/deberta-v3-large),[链接3](https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm)下载模型,放入bert目录下
|
||||||
|
|
||||||
|
从[链接4](https://huggingface.co/microsoft/wavlm-base-plus)下载模型,放入slm目录下
|
||||||
|
|
||||||
|
在./utils/bert_vits2/bert目录下创建bert_models.json文件,填入如下内容
|
||||||
|
|
||||||
|
``` json
|
||||||
|
{
|
||||||
|
"deberta-v2-large-japanese-char-wwm": {
|
||||||
|
"repo_id": "ku-nlp/deberta-v2-large-japanese-char-wwm",
|
||||||
|
"files": ["pytorch_model.bin"]
|
||||||
|
},
|
||||||
|
"chinese-roberta-wwm-ext-large": {
|
||||||
|
"repo_id": "hfl/chinese-roberta-wwm-ext-large",
|
||||||
|
"files": ["pytorch_model.bin"]
|
||||||
|
},
|
||||||
|
"deberta-v3-large": {
|
||||||
|
"repo_id": "microsoft/deberta-v3-large",
|
||||||
|
"files": ["spm.model", "pytorch_model.bin"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
6. 在 ./utils/bert_vits2/data/mix/目录下创建models文件夹,并放入预训练模型250000_G.pth
|
||||||
|
7. 回到根目录,`python main.py`启动程序
|
23
README.md
23
README.md
|
@ -1,23 +0,0 @@
|
||||||
# 部署步骤
|
|
||||||
|
|
||||||
1. 将该仓库clone到本地
|
|
||||||
2. 创建虚拟环境并启动
|
|
||||||
|
|
||||||
``` shell
|
|
||||||
conda create -n takway python=3.9
|
|
||||||
conda activate takway
|
|
||||||
```
|
|
||||||
|
|
||||||
3. cd进入仓库目录下,安装依赖
|
|
||||||
|
|
||||||
``` shell
|
|
||||||
cd ~/TakwayDisplayPlatform
|
|
||||||
pip install -r requirements.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
4. 在utils/目录下创建vits_model文件夹
|
|
||||||
|
|
||||||
从[链接](https://huggingface.co/spaces/zomehwh/vits-uma-genshin-honkai/tree/main/model)下载vits_model并放入该文件夹下,只需下载config.json和G_953000.pth即可
|
|
||||||
|
|
||||||
5. 回到仓库根目录下运行`python main.py`启动程序
|
|
||||||
|
|
|
@ -6,7 +6,8 @@ from .abstract import *
|
||||||
from .public import *
|
from .public import *
|
||||||
from .exception import *
|
from .exception import *
|
||||||
from .dependency import get_logger
|
from .dependency import get_logger
|
||||||
from utils.vits_utils import TextToSpeech
|
from utils.vits_utils import TextToSpeech as VITS_TextToSpeech
|
||||||
|
from utils.bert_vits2_utils import TextToSpeech as BertVits_TextToSpeech
|
||||||
from config import Config
|
from config import Config
|
||||||
import threading
|
import threading
|
||||||
import requests
|
import requests
|
||||||
|
@ -17,7 +18,11 @@ import time
|
||||||
import json
|
import json
|
||||||
|
|
||||||
# ----------- 初始化vits ----------- #
|
# ----------- 初始化vits ----------- #
|
||||||
vits = TextToSpeech()
|
vits = VITS_TextToSpeech()
|
||||||
|
# ---------------------------------- #
|
||||||
|
|
||||||
|
# -------- 初始化bert-vits --------- #
|
||||||
|
bert_vits = BertVits_TextToSpeech()
|
||||||
# ---------------------------------- #
|
# ---------------------------------- #
|
||||||
|
|
||||||
# ---------- 初始化logger ---------- #
|
# ---------- 初始化logger ---------- #
|
||||||
|
@ -294,6 +299,14 @@ class VITS_TTS(TTS):
|
||||||
def synthetize(self, assistant, text):
|
def synthetize(self, assistant, text):
|
||||||
tts_info = json.loads(assistant.tts_info)
|
tts_info = json.loads(assistant.tts_info)
|
||||||
return vits.synthesize(text, tts_info)
|
return vits.synthesize(text, tts_info)
|
||||||
|
|
||||||
|
class BertVits_TTS(TTS):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def synthetize(self, assistant, text):
|
||||||
|
tts_info = json.loads(assistant.tts_info)
|
||||||
|
return bert_vits.synthesize(text, tts_info)
|
||||||
# --------------------------------- #
|
# --------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
@ -319,6 +332,8 @@ class TTSFactory:
|
||||||
def create_tts(self,tts_type:str) -> TTS:
|
def create_tts(self,tts_type:str) -> TTS:
|
||||||
if tts_type == 'VITS':
|
if tts_type == 'VITS':
|
||||||
return VITS_TTS()
|
return VITS_TTS()
|
||||||
|
if tts_type == 'BertVits':
|
||||||
|
return BertVits_TTS()
|
||||||
# --------------------------------- #
|
# --------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
@ -420,7 +435,12 @@ class Agent():
|
||||||
self.tts_audio_service_chain.add_service(TTSAudioRecordService())
|
self.tts_audio_service_chain.add_service(TTSAudioRecordService())
|
||||||
|
|
||||||
def init_recorder(self,user_id):
|
def init_recorder(self,user_id):
|
||||||
self.recorder = Recorder(user_id)
|
input_sr = 16000
|
||||||
|
if isinstance(self.tts, BertVits_TTS):
|
||||||
|
output_sr = 44100
|
||||||
|
elif isinstance(self.tts, VITS_TTS):
|
||||||
|
output_sr = 22050
|
||||||
|
self.recorder = Recorder(user_id,input_sr,output_sr)
|
||||||
|
|
||||||
# 对用户输入的音频进行预处理
|
# 对用户输入的音频进行预处理
|
||||||
def user_audio_process(self, audio):
|
def user_audio_process(self, audio):
|
||||||
|
|
|
@ -30,12 +30,12 @@ class SentenceSegmentation():
|
||||||
return self.__sentenceSegmentation(llm_chunk)
|
return self.__sentenceSegmentation(llm_chunk)
|
||||||
|
|
||||||
class Recorder:
|
class Recorder:
|
||||||
def __init__(self, user_id):
|
def __init__(self, user_id, input_sr, output_sr):
|
||||||
self.input_wav_path = 'storage/wav/'+ datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'i.wav'
|
self.input_wav_path = 'storage/wav/'+ datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'i.wav'
|
||||||
self.output_wav_path = 'storage/wav/'+ datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'o.wav'
|
self.output_wav_path = 'storage/wav/'+ datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'o.wav'
|
||||||
self.out_put_text_path = 'storage/record/'+ datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'o.txt'
|
self.out_put_text_path = 'storage/record/'+ datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'o.txt'
|
||||||
self.input_sr = 16000
|
self.input_sr = input_sr
|
||||||
self.output_sr = 22050
|
self.output_sr = output_sr
|
||||||
self.user_audio = b''
|
self.user_audio = b''
|
||||||
self.tts_audio = b''
|
self.tts_audio = b''
|
||||||
self.input_text = ""
|
self.input_text = ""
|
||||||
|
|
|
@ -33,8 +33,13 @@ class update_assistant_deatil_params_request(BaseModel):
|
||||||
platform:str
|
platform:str
|
||||||
model :str
|
model :str
|
||||||
temperature :float
|
temperature :float
|
||||||
|
tts_engine:str
|
||||||
speaker_id:int
|
speaker_id:int
|
||||||
length_scale:float
|
length_scale:float
|
||||||
|
language:str
|
||||||
|
style_text:str
|
||||||
|
style_weight:float
|
||||||
|
|
||||||
|
|
||||||
class update_assistant_max_tokens_request(BaseModel):
|
class update_assistant_max_tokens_request(BaseModel):
|
||||||
max_tokens:int
|
max_tokens:int
|
|
@ -1,8 +1,5 @@
|
||||||
class Config:
|
class Config:
|
||||||
SQLITE_URL = 'sqlite:///takway.db'
|
SQLITE_URL = 'sqlite:///takway.db'
|
||||||
ASR = "XF" #在此处选择语音识别引擎
|
|
||||||
LLM = "MINIMAX" #在此处选择大模型
|
|
||||||
TTS = "VITS" #在此处选择语音合成引擎
|
|
||||||
LOG_LEVEL = "DEBUG"
|
LOG_LEVEL = "DEBUG"
|
||||||
class UVICORN:
|
class UVICORN:
|
||||||
HOST = '0.0.0.0'
|
HOST = '0.0.0.0'
|
||||||
|
|
17
main.py
17
main.py
|
@ -122,7 +122,16 @@ async def update_assistant_deatil_params(id: str,request: update_assistant_deati
|
||||||
llm_info['temperature'] = request.temperature
|
llm_info['temperature'] = request.temperature
|
||||||
tts_info['speaker_id'] = request.speaker_id
|
tts_info['speaker_id'] = request.speaker_id
|
||||||
tts_info['length_scale'] = request.length_scale
|
tts_info['length_scale'] = request.length_scale
|
||||||
|
tts_info['language'] = request.language
|
||||||
|
tts_info['style_text'] = request.style_text
|
||||||
|
tts_info['style_weight'] = request.style_weight
|
||||||
|
tts_info['sdp_ratio'] = 0.5
|
||||||
|
tts_info['opt_cut_by_send'] = False
|
||||||
|
tts_info['interval_between_para'] = 1.0
|
||||||
|
tts_info['interval_between_sent'] = 0.2
|
||||||
|
tts_info['en_ratio'] = 1.0
|
||||||
user_info['llm_type'] = request.platform
|
user_info['llm_type'] = request.platform
|
||||||
|
user_info['tts_type'] = request.tts_engine
|
||||||
assistant.llm_info = json.dumps(llm_info, ensure_ascii=False)
|
assistant.llm_info = json.dumps(llm_info, ensure_ascii=False)
|
||||||
assistant.tts_info = json.dumps(tts_info, ensure_ascii=False)
|
assistant.tts_info = json.dumps(tts_info, ensure_ascii=False)
|
||||||
assistant.user_info = json.dumps(user_info, ensure_ascii=False)
|
assistant.user_info = json.dumps(user_info, ensure_ascii=False)
|
||||||
|
@ -227,15 +236,7 @@ async def streaming_chat(ws: WebSocket):
|
||||||
agent.recorder.input_text = prompt
|
agent.recorder.input_text = prompt
|
||||||
logger.debug("开始调用大模型")
|
logger.debug("开始调用大模型")
|
||||||
llm_frames = await agent.chat(assistant, prompt)
|
llm_frames = await agent.chat(assistant, prompt)
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
is_first_response = True
|
|
||||||
|
|
||||||
for llm_frame in llm_frames:
|
for llm_frame in llm_frames:
|
||||||
if is_first_response:
|
|
||||||
end_time = time.time()
|
|
||||||
logger.debug(f"第一帧返回耗时:{round(end_time-start_time,3)}s")
|
|
||||||
is_first_response = False
|
|
||||||
resp_msgs = agent.llm_msg_process(llm_frame)
|
resp_msgs = agent.llm_msg_process(llm_frame)
|
||||||
for resp_msg in resp_msgs:
|
for resp_msg in resp_msgs:
|
||||||
llm_text += resp_msg
|
llm_text += resp_msg
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
import paho.mqtt.client as mqtt
|
||||||
|
|
||||||
|
# MQTT Broker信息
|
||||||
|
broker = '127.0.0.1'
|
||||||
|
port = 1883
|
||||||
|
topic = 'audio/test'
|
||||||
|
|
||||||
|
# 音频文件路径
|
||||||
|
audio_file_path = 'tmp2.wav'
|
||||||
|
|
||||||
|
def publish_audio():
|
||||||
|
client = mqtt.Client()
|
||||||
|
client.connect(broker, port)
|
||||||
|
|
||||||
|
with open(audio_file_path, 'rb') as audio_file:
|
||||||
|
audio_data = audio_file.read()
|
||||||
|
|
||||||
|
client.publish(topic, audio_data)
|
||||||
|
client.disconnect()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
publish_audio()
|
||||||
|
print("Audio published successfully.")
|
|
@ -0,0 +1,33 @@
|
||||||
|
import paho.mqtt.client as mqtt
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
# MQTT Broker信息
|
||||||
|
broker = '127.0.0.1'
|
||||||
|
port = 1883
|
||||||
|
topic = 'audio/test'
|
||||||
|
|
||||||
|
# 音频文件路径列表
|
||||||
|
audio_file_paths = ['tmp2.wav', 'tmp3.wav', 'tmp4.wav'] # 添加多个音频文件路径
|
||||||
|
|
||||||
|
def publish_audio(file_path):
|
||||||
|
client = mqtt.Client()
|
||||||
|
client.connect(broker, port)
|
||||||
|
|
||||||
|
with open(file_path, 'rb') as audio_file:
|
||||||
|
audio_data = audio_file.read()
|
||||||
|
|
||||||
|
client.publish(topic, audio_data)
|
||||||
|
client.disconnect()
|
||||||
|
print(f"Audio from {file_path} published successfully.")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
threads = []
|
||||||
|
for file_path in audio_file_paths:
|
||||||
|
thread = threading.Thread(target=publish_audio, args=(file_path,))
|
||||||
|
thread.start()
|
||||||
|
threads.append(thread)
|
||||||
|
time.sleep(1) # 可选:给每个线程一些间隔时间
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
|
@ -0,0 +1,33 @@
|
||||||
|
import paho.mqtt.client as mqtt
|
||||||
|
from pydub import AudioSegment
|
||||||
|
import io
|
||||||
|
|
||||||
|
# MQTT Broker信息
|
||||||
|
broker = '127.0.0.1'
|
||||||
|
port = 1883
|
||||||
|
topic = 'audio/test'
|
||||||
|
output_file_path = 'received_audio.wav'
|
||||||
|
|
||||||
|
def on_connect(client, userdata, flags, rc):
|
||||||
|
print("Connected with result code " + str(rc))
|
||||||
|
client.subscribe(topic)
|
||||||
|
|
||||||
|
def on_message(client, userdata, msg):
|
||||||
|
print("Audio received")
|
||||||
|
audio_data = msg.payload
|
||||||
|
audio = AudioSegment.from_file(io.BytesIO(audio_data), format="wav")
|
||||||
|
# 将音频保存为文件
|
||||||
|
with open(output_file_path, 'wb') as f:
|
||||||
|
f.write(audio_data)
|
||||||
|
print(f"Audio saved as {output_file_path}")
|
||||||
|
|
||||||
|
def subscribe_audio():
|
||||||
|
client = mqtt.Client()
|
||||||
|
client.on_connect = on_connect
|
||||||
|
client.on_message = on_message
|
||||||
|
|
||||||
|
client.connect(broker, port)
|
||||||
|
client.loop_forever()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
subscribe_audio()
|
|
@ -0,0 +1,42 @@
|
||||||
|
import paho.mqtt.client as mqtt
|
||||||
|
from pydub import AudioSegment
|
||||||
|
import io
|
||||||
|
import threading
|
||||||
|
|
||||||
|
# MQTT Broker信息
|
||||||
|
broker = '127.0.0.1'
|
||||||
|
port = 1883
|
||||||
|
topic = 'audio/test'
|
||||||
|
output_file_path = 'received_audio.wav'
|
||||||
|
|
||||||
|
def on_connect(client, userdata, flags, rc):
|
||||||
|
print("Connected with result code " + str(rc))
|
||||||
|
client.subscribe(topic)
|
||||||
|
|
||||||
|
def handle_audio(audio_data):
|
||||||
|
# 将音频保存为文件
|
||||||
|
with open(output_file_path, 'wb') as f:
|
||||||
|
f.write(audio_data)
|
||||||
|
print(f"Audio saved as {output_file_path}")
|
||||||
|
|
||||||
|
def on_message(client, userdata, msg):
|
||||||
|
print("Audio received")
|
||||||
|
audio_data = msg.payload
|
||||||
|
# 使用线程来处理音频
|
||||||
|
audio_thread = threading.Thread(target=handle_audio, args=(audio_data,))
|
||||||
|
audio_thread.start()
|
||||||
|
|
||||||
|
def subscribe_audio():
|
||||||
|
client = mqtt.Client()
|
||||||
|
client.on_connect = on_connect
|
||||||
|
client.on_message = on_message
|
||||||
|
|
||||||
|
client.connect(broker, port)
|
||||||
|
client.loop_forever()
|
||||||
|
|
||||||
|
def mqtt_thread():
|
||||||
|
subscribe_audio()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mqtt_thread = threading.Thread(target=mqtt_thread)
|
||||||
|
mqtt_thread.start()
|
|
@ -14,3 +14,4 @@ librosa
|
||||||
aiohttp
|
aiohttp
|
||||||
'volcengine-python-sdk[ark]'
|
'volcengine-python-sdk[ark]'
|
||||||
zhipuai
|
zhipuai
|
||||||
|
pyopenjtalk
|
13
test.py
13
test.py
|
@ -1,13 +0,0 @@
|
||||||
from utils.bert_vits2_utils import TextToSpeech
|
|
||||||
import soundfile as sf
|
|
||||||
tts = TextToSpeech()
|
|
||||||
tts.print_speakers_info()
|
|
||||||
|
|
||||||
audio, sample_rate= tts.synthesize("你好,我好开心", # 文本
|
|
||||||
0, # 说话人 id
|
|
||||||
style_text="我很难过!!!!呜呜呜!!!", # 情绪prompt,当language=="ZH" 才有效
|
|
||||||
style_weight=0.9, # 情绪prompt权重
|
|
||||||
language="mix", # 语言类型,包括 "ZH" "EN" "mix"
|
|
||||||
en_ratio=1.) # mix语言类型下,英文文本速度,越大速度越慢
|
|
||||||
save_path = "./tmp2.wav"
|
|
||||||
sf.write(save_path, audio, sample_rate)
|
|
|
@ -395,29 +395,28 @@ class TextToSpeech:
|
||||||
|
|
||||||
def synthesize(self,
|
def synthesize(self,
|
||||||
text,
|
text,
|
||||||
speaker_idx=0, # self.speakers 的 index,指定说话
|
tts_info,
|
||||||
sdp_ratio=0.5,
|
|
||||||
noise_scale=0.6,
|
|
||||||
noise_scale_w=0.9,
|
|
||||||
length_scale=1.0, # 越大语速越慢
|
|
||||||
language="mix", # ["ZH", "EN", "mix"] 三选一
|
|
||||||
opt_cut_by_send=False, # 按句切分 在按段落切分的基础上再按句子切分文本
|
|
||||||
interval_between_para=1.0, # 段间停顿(秒),需要大于句间停顿才有效
|
|
||||||
interval_between_sent=0.2, # 句间停顿(秒),勾选按句切分才生效
|
|
||||||
audio_prompt=None,
|
|
||||||
text_prompt="",
|
|
||||||
prompt_mode="Text prompts",
|
|
||||||
style_text="", # "使用辅助文本的语意来辅助生成对话(语言保持与主文本相同)\n\n"
|
|
||||||
# "**注意**:不要使用**指令式文本**(如:开心),要使用**带有强烈情感的文本**(如:我好快乐!!!)\n\n"
|
|
||||||
# "效果较不明确,留空即为不使用该功能"
|
|
||||||
style_weight=0.7, # "主文本和辅助文本的bert混合比率,0表示仅主文本,1表示仅辅助文本
|
|
||||||
en_ratio=1.0 # 中英混合时,英文速度控制,越大英文速度越慢
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
return: audio, sample_rate
|
return: audio, sample_rate
|
||||||
"""
|
"""
|
||||||
|
speaker_id = tts_info['speaker_id'] # self.speakers 的 index,指定说话
|
||||||
|
sdp_ratio = tts_info['sdp_ratio']
|
||||||
|
noise_scale = tts_info['noise_scale']
|
||||||
|
noise_scale_w = tts_info['noise_scale_w']
|
||||||
|
length_scale = tts_info['length_scale']
|
||||||
|
language = tts_info['language'] # ["ZH", "EN", "mix"] 三选一
|
||||||
|
opt_cut_by_send = tts_info['opt_cut_by_send']
|
||||||
|
interval_between_para = tts_info['interval_between_para'] # 段间停顿(秒),需要大于句间停顿才有效
|
||||||
|
interval_between_sent = tts_info['interval_between_sent'] # 句间停顿(秒),勾选按句切分才生效
|
||||||
|
audio_prompt = None
|
||||||
|
text_prompt = ""
|
||||||
|
prompt_mode = "Text prompts"
|
||||||
|
style_text = tts_info['style_text']
|
||||||
|
style_weight = tts_info['style_weight']
|
||||||
|
en_ratio = tts_info['en_ratio']
|
||||||
|
|
||||||
speaker = self.speakers[speaker_idx]
|
speaker = self.speakers[speaker_id]
|
||||||
|
|
||||||
if language == "mix":
|
if language == "mix":
|
||||||
language, text = self.format_utils(text, speaker)
|
language, text = self.format_utils(text, speaker)
|
||||||
|
@ -455,9 +454,17 @@ class TextToSpeech:
|
||||||
style_weight
|
style_weight
|
||||||
)
|
)
|
||||||
|
|
||||||
# return text_output, audio_output
|
return self.convert_numpy_to_bytes(audio_output[1])
|
||||||
return audio_output[1], audio_output[0]
|
|
||||||
|
|
||||||
def print_speakers_info(self):
|
def print_speakers_info(self):
|
||||||
for i, speaker in enumerate(self.speakers):
|
for i, speaker in enumerate(self.speakers):
|
||||||
print(f"id: {i}, speaker: {speaker}")
|
print(f"id: {i}, speaker: {speaker}")
|
||||||
|
|
||||||
|
def convert_numpy_to_bytes(self, audio_data):
|
||||||
|
if isinstance(audio_data, np.ndarray):
|
||||||
|
if audio_data.dtype == np.dtype('float32'):
|
||||||
|
audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
|
||||||
|
audio_data = audio_data.tobytes()
|
||||||
|
return audio_data
|
||||||
|
else:
|
||||||
|
raise TypeError("audio_data must be a numpy array")
|
||||||
|
|
Loading…
Reference in New Issue