TakwayDisplayPlatform/utils/emotion2vec_utils.py

37 lines
1.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import json
# Keyword arguments forwarded to modelscope's `pipeline(...)` factory.
# Other checkpoints usable for "model": iic/emotion2vec_plus_seed,
# iic/emotion2vec_plus_base, iic/emotion2vec_plus_large (plus the current
# iic/emotion2vec_base_finetuned).
# "device" is intentionally not set; per the original author's note the
# pipeline then defaults to GPU.
model_args = dict(
    task=Tasks.emotion_recognition,
    model="iic/emotion2vec_base_finetuned",
)
class EmotionRecognition:
    """Speech emotion recognition backed by a modelscope emotion2vec pipeline.

    The pipeline is built once at construction time (via ``initialize``) and
    reused by every call to ``emotion_recognition``.
    """

    def __init__(self) -> None:
        # Build the pipeline from the module-level default configuration.
        self.initialize()

    @staticmethod
    def _default_model_args():
        """Return the module-level ``model_args`` configuration, looked up at call time."""
        return model_args

    # Initialize the model
    def initialize(self, model_args=None):
        """(Re)build the inference pipeline.

        Args:
            model_args: keyword arguments forwarded to modelscope's
                ``pipeline``. ``None`` (the default) means the module-level
                ``model_args`` as it is at call time, so ``__init__`` and a bare
                ``initialize()`` always agree even if that config is rebound.
        """
        if model_args is None:
            model_args = self._default_model_args()
        self.inference_pipeline = pipeline(**model_args)

    def emotion_recognition(self,
                            audio: "str | bytes",
                            granularity="utterance",
                            extract_embedding=False,
                            output_dir="./outputs",
                            ) -> str:
        """Classify the emotion of one audio input and return a JSON string.

        Args:
            audio: input forwarded unchanged to the pipeline. The ``__main__``
                demo passes a URL string. Raw bytes were the original type
                hint. Other accepted types depend on the modelscope pipeline.
            granularity: granularity of the intermediate features:
                "utterance" -> one 768-dim vector ([*768]),
                "frame" -> per-frame features ([T*768]).
            extract_embedding: whether to keep the extracted features. False
                keeps only the final result, not the intermediate features.
            output_dir: where intermediate features are saved. Only relevant
                when ``extract_embedding`` is True.

        Returns:
            A JSON array string of ``{"emotion": ..., "weight": ...}`` objects,
            one per label of the first pipeline result. Each label is reduced
            to the text after its last "/". Each score is rounded to 4 decimal
            places.
        """
        rec_result = self.inference_pipeline(
            audio,
            granularity=granularity,
            extract_embedding=extract_embedding,
            output_dir=output_dir,
        )
        # Only the first entry is used. Presumably the pipeline returns one
        # entry per input, and a single input is passed here (TODO confirm).
        first = rec_result[0]
        json_list = [
            # float() normalizes numpy scalar scores (e.g. numpy.float32),
            # which round() would leave as numpy types that json.dumps rejects.
            {"emotion": label.split("/")[-1], "weight": round(float(score), 4)}
            for label, score in zip(first["labels"], first["scores"])
        ]
        return json.dumps(json_list)
if __name__ == "__main__":
    # Smoke test: load the default model, classify a public sample clip,
    # and print the resulting JSON string.
    demo_audio_url = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
    recognizer = EmotionRecognition()
    print(recognizer.emotion_recognition(demo_audio_url))