Compare commits

3 commits: master ... bing-patch

Author | SHA1 | Date
---|---|---
 | 513ab861af |
 | 25b99502fd |
 | ee3dbb04f0 |

README.md (48 lines changed)
@@ -1,48 +0,0 @@
-# Deployment steps
-
-1. Clone this repository locally.
-2. Create and activate a virtual environment:
-
-``` shell
-conda create -n takway python=3.9
-conda activate takway
-```
-
-3. cd into the repository directory and install the dependencies:
-
-``` shell
-cd ~/TakwayDisplayPlatform
-pip install -r requirements.txt
-```
-
-4. Create a vits_model folder under the ./utils/ directory.
-
-Download vits_model from the [link](https://huggingface.co/spaces/zomehwh/vits-uma-genshin-honkai/tree/main/model) and place it in that folder; only config.json and G_953000.pth are needed.
-
-5. Create bert and slm folders under the ./utils/bert_vits2/ directory.
-
-Download the models from [link 1](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large), [link 2](https://huggingface.co/microsoft/deberta-v3-large) and [link 3](https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm), and place them in the bert directory.
-
-Download the model from [link 4](https://huggingface.co/microsoft/wavlm-base-plus) and place it in the slm directory.
-
-Create a bert_models.json file in the ./utils/bert_vits2/bert directory and fill in the following:
-
-``` json
-{
-    "deberta-v2-large-japanese-char-wwm": {
-        "repo_id": "ku-nlp/deberta-v2-large-japanese-char-wwm",
-        "files": ["pytorch_model.bin"]
-    },
-    "chinese-roberta-wwm-ext-large": {
-        "repo_id": "hfl/chinese-roberta-wwm-ext-large",
-        "files": ["pytorch_model.bin"]
-    },
-    "deberta-v3-large": {
-        "repo_id": "microsoft/deberta-v3-large",
-        "files": ["spm.model", "pytorch_model.bin"]
-    }
-}
-```
-
-6. Create a models folder under the ./utils/bert_vits2/data/mix/ directory and place the pretrained model 250000_G.pth in it.
-7. Return to the root directory and start the program with `python main.py`.
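Step 5 of the removed README maps local model folders to Hugging Face repos via bert_models.json, which is enough to script the downloads. The sketch below is illustrative only and not part of this patch; it assumes the huggingface_hub package, which the repository does not declare:

```python
# Illustrative sketch, not part of this patch: drive the step-5 model
# downloads from bert_models.json. Assumes `pip install huggingface_hub`.
import json
from huggingface_hub import hf_hub_download

with open("utils/bert_vits2/bert/bert_models.json") as f:
    bert_models = json.load(f)

for folder, spec in bert_models.items():
    for filename in spec["files"]:
        # hf_hub_download fetches a single file into the local HF cache
        path = hf_hub_download(repo_id=spec["repo_id"], filename=filename)
        print(f"{folder}: {filename} -> {path}")
```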
@@ -0,0 +1,23 @@
+# Deployment steps
+
+1. Clone this repository locally.
+2. Create and activate a virtual environment:
+
+``` shell
+conda create -n takway python=3.9
+conda activate takway
+```
+
+3. cd into the repository directory and install the dependencies:
+
+``` shell
+cd ~/TakwayDisplayPlatform
+pip install -r requirements.txt
+```
+
+4. Create a vits_model folder under the utils/ directory.
+
+Download vits_model from the [link](https://huggingface.co/spaces/zomehwh/vits-uma-genshin-honkai/tree/main/model) and place it in that folder; only config.json and G_953000.pth are needed.
+
+5. Return to the repository root and start the program with `python main.py`.
+
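Step 4 points at files inside a Hugging Face Space. As a hedged aside, the two required files can be fetched programmatically; this sketch assumes huggingface_hub is installed and that the Space keeps the files under model/:

```python
# Illustrative sketch, not part of this patch: fetch config.json and
# G_953000.pth for step 4 of the README above.
import shutil
from pathlib import Path
from huggingface_hub import hf_hub_download

target = Path("utils/vits_model")
target.mkdir(parents=True, exist_ok=True)

for name in ("model/config.json", "model/G_953000.pth"):
    cached = hf_hub_download(
        repo_id="zomehwh/vits-uma-genshin-honkai",
        repo_type="space",  # the source is a HF Space, not a model repo
        filename=name,
    )
    shutil.copy(cached, target / Path(name).name)
```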
@@ -6,8 +6,7 @@ from .abstract import *
 from .public import *
 from .exception import *
 from .dependency import get_logger
-from utils.vits_utils import TextToSpeech as VITS_TextToSpeech
-from utils.bert_vits2_utils import TextToSpeech as BertVits_TextToSpeech
+from utils.vits_utils import TextToSpeech
 from config import Config
 import threading
 import requests
@@ -18,11 +17,7 @@ import time
 import json
 
 # ----------- initialize vits ----------- #
-vits = VITS_TextToSpeech()
+vits = TextToSpeech()
 # ---------------------------------------- #
-
-# -------- initialize bert-vits ---------- #
-bert_vits = BertVits_TextToSpeech()
-# ---------------------------------------- #
 
 # ---------- initialize logger ----------- #
@@ -299,14 +294,6 @@ class VITS_TTS(TTS):
     def synthetize(self, assistant, text):
         tts_info = json.loads(assistant.tts_info)
         return vits.synthesize(text, tts_info)
 
-
-class BertVits_TTS(TTS):
-    def __init__(self):
-        pass
-
-    def synthetize(self, assistant, text):
-        tts_info = json.loads(assistant.tts_info)
-        return bert_vits.synthesize(text, tts_info)
 # --------------------------------- #
 
@@ -332,8 +319,6 @@ class TTSFactory:
     def create_tts(self,tts_type:str) -> TTS:
         if tts_type == 'VITS':
             return VITS_TTS()
-        if tts_type == 'BertVits':
-            return BertVits_TTS()
 # --------------------------------- #
 
 
@@ -435,12 +420,7 @@ class Agent():
         self.tts_audio_service_chain.add_service(TTSAudioRecordService())
 
     def init_recorder(self,user_id):
-        input_sr = 16000
-        if isinstance(self.tts, BertVits_TTS):
-            output_sr = 44100
-        elif isinstance(self.tts, VITS_TTS):
-            output_sr = 22050
-        self.recorder = Recorder(user_id,input_sr,output_sr)
+        self.recorder = Recorder(user_id)
 
     # preprocess the audio the user sends in
     def user_audio_process(self, audio):
@@ -30,12 +30,12 @@ class SentenceSegmentation():
         return self.__sentenceSegmentation(llm_chunk)
 
 class Recorder:
-    def __init__(self, user_id, input_sr, output_sr):
+    def __init__(self, user_id):
         self.input_wav_path = 'storage/wav/'+ datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'i.wav'
         self.output_wav_path = 'storage/wav/'+ datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'o.wav'
         self.out_put_text_path = 'storage/record/'+ datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'o.txt'
-        self.input_sr = input_sr
-        self.output_sr = output_sr
+        self.input_sr = 16000
+        self.output_sr = 22050
         self.user_audio = b''
         self.tts_audio = b''
         self.input_text = ""
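With this change the Recorder always records microphone input at 16 kHz and TTS output at 22.05 kHz (the VITS rate). For illustration only, raw 16-bit mono PCM at those rates can be flushed to the Recorder's wav paths with the standard wave module; the repository may do this elsewhere:

```python
# Illustrative helper, not part of this patch: write raw 16-bit mono PCM
# to one of the Recorder's wav paths at the fixed sample rates above.
import wave

def write_pcm(path: str, pcm: bytes, sample_rate: int) -> None:
    with wave.open(path, 'wb') as wf:
        wf.setnchannels(1)       # mono
        wf.setsampwidth(2)       # 16-bit samples
        wf.setframerate(sample_rate)
        wf.writeframes(pcm)

# e.g. write_pcm(recorder.input_wav_path, recorder.user_audio, recorder.input_sr)
```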
@@ -33,13 +33,8 @@ class update_assistant_deatil_params_request(BaseModel):
     platform:str
     model :str
     temperature :float
-    tts_engine:str
     speaker_id:int
     length_scale:float
-    language:str
-    style_text:str
-    style_weight:float
 
 
 class update_assistant_max_tokens_request(BaseModel):
     max_tokens:int
@@ -1,5 +1,8 @@
 class Config:
     SQLITE_URL = 'sqlite:///takway.db'
+    ASR = "XF" # select the speech recognition engine here
+    LLM = "MINIMAX" # select the large language model here
+    TTS = "VITS" # select the speech synthesis engine here
     LOG_LEVEL = "DEBUG"
     class UVICORN:
         HOST = '0.0.0.0'
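The three new switches are read by the engine factories. A minimal wiring sketch follows; only TTSFactory appears in this diff, and the import path here is an assumption:

```python
# Minimal sketch of how Config.TTS reaches the factory shown earlier in
# this compare; the `app` import path is hypothetical.
from config import Config
from app import TTSFactory  # hypothetical import path

tts = TTSFactory().create_tts(Config.TTS)  # "VITS" -> VITS_TTS()
```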
main.py (17 lines changed)
@@ -122,16 +122,7 @@ async def update_assistant_deatil_params(id: str,request: update_assistant_deati
     llm_info['temperature'] = request.temperature
     tts_info['speaker_id'] = request.speaker_id
     tts_info['length_scale'] = request.length_scale
-    tts_info['language'] = request.language
-    tts_info['style_text'] = request.style_text
-    tts_info['style_weight'] = request.style_weight
-    tts_info['sdp_ratio'] = 0.5
-    tts_info['opt_cut_by_send'] = False
-    tts_info['interval_between_para'] = 1.0
-    tts_info['interval_between_sent'] = 0.2
-    tts_info['en_ratio'] = 1.0
     user_info['llm_type'] = request.platform
-    user_info['tts_type'] = request.tts_engine
     assistant.llm_info = json.dumps(llm_info, ensure_ascii=False)
     assistant.tts_info = json.dumps(tts_info, ensure_ascii=False)
     assistant.user_info = json.dumps(user_info, ensure_ascii=False)
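After this change the endpoint only consumes the reduced field set of update_assistant_deatil_params_request. A hypothetical request body, with placeholder values, would look like:

```python
# Hypothetical request body for the trimmed endpoint; every value below
# is a placeholder, not taken from the repository.
body = {
    "platform": "MINIMAX",    # becomes user_info['llm_type']
    "model": "<model-name>",
    "temperature": 0.7,
    "speaker_id": 0,          # becomes tts_info['speaker_id']
    "length_scale": 1.0,      # becomes tts_info['length_scale']
}
```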
@@ -236,7 +227,15 @@ async def streaming_chat(ws: WebSocket):
         agent.recorder.input_text = prompt
         logger.debug("start calling the LLM")
         llm_frames = await agent.chat(assistant, prompt)
+
+        start_time = time.time()
+        is_first_response = True
+
         for llm_frame in llm_frames:
+            if is_first_response:
+                end_time = time.time()
+                logger.debug(f"time to first frame: {round(end_time-start_time,3)}s")
+                is_first_response = False
             resp_msgs = agent.llm_msg_process(llm_frame)
             for resp_msg in resp_msgs:
                 llm_text += resp_msg
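The added instrumentation measures time to the first LLM frame with a sentinel flag. As an aside, not part of the patch, the same measurement can be packaged as a reusable generator wrapper:

```python
# Aside: the first-frame timing above, expressed as a generator wrapper.
import time

def timed_first(frames, logger):
    start = time.time()
    first = True
    for frame in frames:
        if first:
            logger.debug(f"time to first frame: {round(time.time() - start, 3)}s")
            first = False
        yield frame
```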
@@ -1,23 +0,0 @@
-import paho.mqtt.client as mqtt
-
-# MQTT broker info
-broker = '127.0.0.1'
-port = 1883
-topic = 'audio/test'
-
-# audio file path
-audio_file_path = 'tmp2.wav'
-
-def publish_audio():
-    client = mqtt.Client()
-    client.connect(broker, port)
-
-    with open(audio_file_path, 'rb') as audio_file:
-        audio_data = audio_file.read()
-
-    client.publish(topic, audio_data)
-    client.disconnect()
-
-if __name__ == "__main__":
-    publish_audio()
-    print("Audio published successfully.")
@@ -1,33 +0,0 @@
-import paho.mqtt.client as mqtt
-import threading
-import time
-
-# MQTT broker info
-broker = '127.0.0.1'
-port = 1883
-topic = 'audio/test'
-
-# list of audio file paths
-audio_file_paths = ['tmp2.wav', 'tmp3.wav', 'tmp4.wav'] # add multiple audio file paths
-
-def publish_audio(file_path):
-    client = mqtt.Client()
-    client.connect(broker, port)
-
-    with open(file_path, 'rb') as audio_file:
-        audio_data = audio_file.read()
-
-    client.publish(topic, audio_data)
-    client.disconnect()
-    print(f"Audio from {file_path} published successfully.")
-
-if __name__ == "__main__":
-    threads = []
-    for file_path in audio_file_paths:
-        thread = threading.Thread(target=publish_audio, args=(file_path,))
-        thread.start()
-        threads.append(thread)
-        time.sleep(1) # optional: give each thread some spacing in time
-
-    for thread in threads:
-        thread.join()
@@ -1,33 +0,0 @@
-import paho.mqtt.client as mqtt
-from pydub import AudioSegment
-import io
-
-# MQTT broker info
-broker = '127.0.0.1'
-port = 1883
-topic = 'audio/test'
-output_file_path = 'received_audio.wav'
-
-def on_connect(client, userdata, flags, rc):
-    print("Connected with result code " + str(rc))
-    client.subscribe(topic)
-
-def on_message(client, userdata, msg):
-    print("Audio received")
-    audio_data = msg.payload
-    audio = AudioSegment.from_file(io.BytesIO(audio_data), format="wav")
-    # save the audio to a file
-    with open(output_file_path, 'wb') as f:
-        f.write(audio_data)
-    print(f"Audio saved as {output_file_path}")
-
-def subscribe_audio():
-    client = mqtt.Client()
-    client.on_connect = on_connect
-    client.on_message = on_message
-
-    client.connect(broker, port)
-    client.loop_forever()
-
-if __name__ == "__main__":
-    subscribe_audio()
@@ -1,42 +0,0 @@
-import paho.mqtt.client as mqtt
-from pydub import AudioSegment
-import io
-import threading
-
-# MQTT broker info
-broker = '127.0.0.1'
-port = 1883
-topic = 'audio/test'
-output_file_path = 'received_audio.wav'
-
-def on_connect(client, userdata, flags, rc):
-    print("Connected with result code " + str(rc))
-    client.subscribe(topic)
-
-def handle_audio(audio_data):
-    # save the audio to a file
-    with open(output_file_path, 'wb') as f:
-        f.write(audio_data)
-    print(f"Audio saved as {output_file_path}")
-
-def on_message(client, userdata, msg):
-    print("Audio received")
-    audio_data = msg.payload
-    # handle the audio in a thread
-    audio_thread = threading.Thread(target=handle_audio, args=(audio_data,))
-    audio_thread.start()
-
-def subscribe_audio():
-    client = mqtt.Client()
-    client.on_connect = on_connect
-    client.on_message = on_message
-
-    client.connect(broker, port)
-    client.loop_forever()
-
-def mqtt_thread():
-    subscribe_audio()
-
-if __name__ == "__main__":
-    mqtt_thread = threading.Thread(target=mqtt_thread)
-    mqtt_thread.start()
@@ -13,5 +13,4 @@ numba
 librosa
 aiohttp
 'volcengine-python-sdk[ark]'
 zhipuai
-pyopenjtalk
@@ -0,0 +1,13 @@
+from utils.bert_vits2_utils import TextToSpeech
+import soundfile as sf
+tts = TextToSpeech()
+tts.print_speakers_info()
+
+audio, sample_rate = tts.synthesize("你好,我好开心", # text
+                                    0, # speaker id
+                                    style_text="我很难过!!!!呜呜呜!!!", # emotion prompt; only takes effect when language=="ZH"
+                                    style_weight=0.9, # weight of the emotion prompt
+                                    language="mix", # language type: "ZH", "EN" or "mix"
+                                    en_ratio=1.) # English speed in "mix" mode; larger is slower
+save_path = "./tmp2.wav"
+sf.write(save_path, audio, sample_rate)
@@ -395,28 +395,29 @@ class TextToSpeech:
 
     def synthesize(self,
                    text,
-                   tts_info,
+                   speaker_idx=0, # index into self.speakers; selects the speaker
+                   sdp_ratio=0.5,
+                   noise_scale=0.6,
+                   noise_scale_w=0.9,
+                   length_scale=1.0, # larger means slower speech
+                   language="mix", # one of ["ZH", "EN", "mix"]
+                   opt_cut_by_send=False, # split by sentence, on top of splitting by paragraph
+                   interval_between_para=1.0, # pause between paragraphs (seconds); must exceed the sentence pause to take effect
+                   interval_between_sent=0.2, # pause between sentences (seconds); only applies when splitting by sentence
+                   audio_prompt=None,
+                   text_prompt="",
+                   prompt_mode="Text prompts",
+                   style_text="", # "use the semantics of an auxiliary text to steer the generated speech (same language as the main text)"
+                   # "note: do not use imperative text (e.g. 'happy'); use strongly emotional text (e.g. 'I am so happy!!!')"
+                   # "the effect is subtle; leave empty to disable this feature"
+                   style_weight=0.7, # BERT mixing ratio between main and auxiliary text; 0 = main text only, 1 = auxiliary text only
+                   en_ratio=1.0 # for mixed Chinese/English, controls English speed; larger is slower
                    ):
         """
         return: audio, sample_rate
         """
-        speaker_id = tts_info['speaker_id'] # index into self.speakers; selects the speaker
-        sdp_ratio = tts_info['sdp_ratio']
-        noise_scale = tts_info['noise_scale']
-        noise_scale_w = tts_info['noise_scale_w']
-        length_scale = tts_info['length_scale']
-        language = tts_info['language'] # one of ["ZH", "EN", "mix"]
-        opt_cut_by_send = tts_info['opt_cut_by_send']
-        interval_between_para = tts_info['interval_between_para'] # pause between paragraphs (seconds); must exceed the sentence pause to take effect
-        interval_between_sent = tts_info['interval_between_sent'] # pause between sentences (seconds); only applies when splitting by sentence
-        audio_prompt = None
-        text_prompt = ""
-        prompt_mode = "Text prompts"
-        style_text = tts_info['style_text']
-        style_weight = tts_info['style_weight']
-        en_ratio = tts_info['en_ratio']
-
-        speaker = self.speakers[speaker_id]
+        speaker = self.speakers[speaker_idx]
 
         if language == "mix":
             language, text = self.format_utils(text, speaker)
@@ -454,17 +455,9 @@ class TextToSpeech:
             style_weight
         )
 
-        return self.convert_numpy_to_bytes(audio_output[1])
+        # return text_output, audio_output
+        return audio_output[1], audio_output[0]
 
     def print_speakers_info(self):
         for i, speaker in enumerate(self.speakers):
            print(f"id: {i}, speaker: {speaker}")
-
-    def convert_numpy_to_bytes(self, audio_data):
-        if isinstance(audio_data, np.ndarray):
-            if audio_data.dtype == np.dtype('float32'):
-                audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
-            audio_data = audio_data.tobytes()
-            return audio_data
-        else:
-            raise TypeError("audio_data must be a numpy array")
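synthesize now returns the pair (audio, sample_rate) instead of raw bytes, which is exactly what the test script added earlier in this compare unpacks. A minimal usage sketch, assuming soundfile is installed:

```python
# Minimal usage sketch for the new (audio, sample_rate) return value;
# mirrors the added test script above.
import soundfile as sf

audio, sample_rate = tts.synthesize("你好", speaker_idx=0, language="ZH")
sf.write("out.wav", audio, sample_rate)
```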
@@ -0,0 +1,54 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+import json
+
+model_args = {
+    "task": Tasks.emotion_recognition,
+    "model": "iic/emotion2vec_base_finetuned" # alternatives: iic/emotion2vec_plus_seed, iic/emotion2vec_plus_base, iic/emotion2vec_plus_large and iic/emotion2vec_base_finetuned
+    # "device": no need to set; defaults to GPU
+}
+
+class EmotionRecognition:
+    """
+    9-class emotion
+    0: angry
+    1: disgusted
+    2: fearful
+    3: happy
+    4: neutral
+    5: other
+    6: sad
+    7: surprised
+    8: <unk>, i.e. unknown
+    return:
+    [{"emotion": "angry", "weight": },
+     {"emotion": "disgusted", "weight": },
+     ...
+    ]
+    """
+    def __init__(self) -> None:
+        self.initialize(model_args=model_args)
+
+    # initialize the model
+    def initialize(self, model_args=model_args):
+        self.inference_pipeline = pipeline(**model_args)
+
+    def emotion_recognition(self,
+                            audio:bytes,
+                            granularity="utterance", # dimension of the intermediate features: "utterance": [*768], "frame": [T*768]
+                            extract_embedding=False, # whether to keep the extracted features; False keeps only the final result
+                            output_dir="./outputs" # where intermediate features are saved (only used when extract_embedding is True)
+                            ):
+        rec_result = self.inference_pipeline(audio, granularity=granularity, extract_embedding=extract_embedding, output_dir=output_dir)
+
+        # assemble the result
+        json_list = []
+        for emotion, score in zip(rec_result[0]["labels"], rec_result[0]["scores"]):
+            json_list.append({"emotion": emotion.split("/")[-1], "weight": round(score,4)})
+        recognize_result = json.dumps(json_list)
+        return recognize_result
+
+if __name__ == "__main__":
+    model = EmotionRecognition()
+    recognize_result = model.emotion_recognition("https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
+    print(recognize_result)
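Note that emotion_recognition returns a JSON string rather than Python objects. A short sketch of picking the dominant label from that string:

```python
# Sketch: consume the JSON string returned by EmotionRecognition above.
import json

scores = json.loads(recognize_result)  # [{"emotion": "...", "weight": ...}, ...]
top = max(scores, key=lambda item: item["weight"])
print(top["emotion"], top["weight"])
```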