# Provenance: this module was recovered from a git patch whose line breaks
# had been collapsed.  Original patch metadata:
#   Commit:  ee3dbb04f0264afeab82a28b0aeb7e6739bebda3
#   From:    bing <2524698668@qq.com>
#   Date:    Wed, 4 Sep 2024 17:09:44 +0800
#   Subject: [PATCH] Add the EmotionRecognition model
#            (translated from the MIME-encoded Chinese subject line)
#   File:    utils/emotion2vec_utils.py (new file, index 0000000..92564ea,
#            37 insertions)
"""Thin wrapper around the ModelScope emotion2vec emotion-recognition pipeline."""

import json
from typing import Union

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Keyword arguments forwarded verbatim to ``modelscope.pipelines.pipeline``.
model_args = {
    "task": Tasks.emotion_recognition,
    # Alternatives: iic/emotion2vec_plus_seed, iic/emotion2vec_plus_base,
    # iic/emotion2vec_plus_large and iic/emotion2vec_base_finetuned.
    "model": "iic/emotion2vec_base_finetuned",
    # "device": per the original author's note it does not need to be
    # specified and defaults to gpu -- TODO confirm against the installed
    # modelscope version.
}


class EmotionRecognition:
    """Load an emotion-recognition pipeline once and run it on audio inputs."""

    def __init__(self) -> None:
        """Build the inference pipeline from the module-level ``model_args``."""
        self.initialize(model_args=model_args)

    def initialize(self, model_args=model_args):
        """(Re)create ``self.inference_pipeline``.

        Args:
            model_args: Mapping of keyword arguments passed straight to
                ``pipeline(...)``.  Defaults to the module-level ``model_args``.
        """
        self.inference_pipeline = pipeline(**model_args)

    def emotion_recognition(
        self,
        audio: Union[bytes, str],
        granularity="utterance",
        extract_embedding=False,
        output_dir="./outputs",
    ) -> str:
        """Run emotion recognition on ``audio`` and return the scores as JSON.

        Args:
            audio: Audio input handed to the pipeline unchanged.  The original
                annotation was ``bytes``, but the demo in this module passes a
                URL string, so both are accepted here.
            granularity: Dimension of the intermediate features, per the
                original author's note: ``"utterance"`` -> [*768],
                ``"frame"`` -> [T*768].
            extract_embedding: Whether to keep the extracted features.
                ``False`` keeps only the final result, not the intermediate
                features.
            output_dir: Where intermediate features are saved (per the
                original note, only effective when ``extract_embedding`` is
                true).

        Returns:
            A JSON array string of ``{"emotion": ..., "weight": ...}`` objects,
            one per label of the first pipeline result, or ``"[]"`` when the
            pipeline returns nothing.
        """
        rec_result = self.inference_pipeline(
            audio,
            granularity=granularity,
            extract_embedding=extract_embedding,
            output_dir=output_dir,
        )

        # Guard against an empty result so callers get valid JSON rather than
        # an IndexError from ``rec_result[0]``.
        if not rec_result:
            return json.dumps([])

        first = rec_result[0]
        return json.dumps(self._format_scores(first["labels"], first["scores"]))

    @staticmethod
    def _format_scores(labels, scores):
        """Pair each label with its score, rounded to 4 decimals.

        Only the text after the last ``/`` of a label is kept (labels
        presumably look like ``"<native>/<english>"`` -- verify against the
        model card); labels without a ``/`` are kept whole.
        """
        return [
            # float() turns NumPy scalars into plain floats so json.dumps can
            # serialise them; it is a no-op for ordinary Python floats.
            {"emotion": label.split("/")[-1], "weight": round(float(score), 4)}
            for label, score in zip(labels, scores)
        ]


def main() -> None:
    """Demo: classify a sample audio clip fetched by URL and print the JSON."""
    model = EmotionRecognition()
    recognize_result = model.emotion_recognition(
        "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
    )
    print(recognize_result)


if __name__ == "__main__":
    main()