# pun_emo_speaker_utils/takway/stt/emotion_utils.py

import io
import numpy as np
import base64
import wave
from funasr import AutoModel
import time
"""
Base模型
不能进行情绪分类,只能用作特征提取
"""
FUNASRBASE = {
"model_type": "funasr",
"model_path": "iic/emotion2vec_base",
"model_revision": "v2.0.4"
}
"""
Finetune模型
输出分类结果
"""
FUNASRFINETUNE = {
"model_type": "funasr",
"model_path": "iic/emotion2vec_base_finetuned"
}
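
# Either config dict is meant to be unpacked into the Emotion constructor below
# (a minimal usage sketch; the __main__ demo at the bottom of this file does
# the same thing):
#   emotion_model = Emotion(**FUNASRFINETUNE)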

def decode_str2bytes(data):
    # Decode a Base64-encoded string back into raw bytes.
    if data is None:
        return None
    return base64.b64decode(data.encode('utf-8'))
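
# Round-trip sketch with a hypothetical payload: a client sends base64-encoded
# PCM bytes as a string, and this helper recovers the raw bytes:
#   encoded = base64.b64encode(b'\x00\x01\x02\x03').decode('utf-8')
#   decode_str2bytes(encoded)  # -> b'\x00\x01\x02\x03'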

class Emotion:
    def __init__(self,
                 model_type="funasr",
                 model_path="iic/emotion2vec_base",
                 device="cuda",
                 model_revision="v2.0.4",
                 **kwargs):
        self.model_type = model_type
        self.initialize(model_type, model_path, device, model_revision, **kwargs)

    # Initialize the underlying model.
    def initialize(self,
                   model_type,
                   model_path,
                   device,
                   model_revision,
                   **kwargs):
        if model_type == "funasr":
            self.emotion_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
        else:
            raise NotImplementedError(f"unsupported model type [{model_type}]; only [funasr] is supported.")

    # Check the input type and normalize it.
    def check_audio_type(self, audio_data):
        """Check the audio data type and convert it to a float32 numpy array."""
        if isinstance(audio_data, bytes):
            pass
        elif isinstance(audio_data, list):
            audio_data = b''.join(audio_data)
        elif isinstance(audio_data, str):
            audio_data = decode_str2bytes(audio_data)
        elif isinstance(audio_data, io.BytesIO):
            with wave.open(audio_data, 'rb') as wf:
                audio_data = wf.readframes(wf.getnframes())
        elif isinstance(audio_data, np.ndarray):
            pass
        else:
            raise TypeError(f"audio_data must be bytes, list, str, "
                            f"io.BytesIO or numpy array, but got {type(audio_data)}")

        if isinstance(audio_data, bytes):
            audio_data = np.frombuffer(audio_data, dtype=np.int16)
        elif isinstance(audio_data, np.ndarray):
            if audio_data.dtype != np.int16:
                audio_data = audio_data.astype(np.int16)
        else:
            raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")

        # The model input must be float32.
        if isinstance(audio_data, np.ndarray):
            audio_data = audio_data.astype(np.float32)
        else:
            raise TypeError(f"audio_data must be numpy array, but got {type(audio_data)}")
        return audio_data
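
    # Conversion sketch with hypothetical inputs; every accepted type ends up
    # as a float32 numpy array:
    #   self.check_audio_type(b'\x00\x01' * 8000)               # raw int16 PCM bytes
    #   self.check_audio_type(np.zeros(16000, dtype=np.int16))  # int16 array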

    def process(self,
                audio_data,
                granularity="utterance",
                extract_embedding=False,
                output_dir=None,
                only_score=True):
        """
        audio_data: only float32 is expected (because of layer norm)
        extract_embedding: save the embedding if True
        output_dir: save path for the embedding
        only_score: only return labels & scores if True
        """
        audio_data = self.check_audio_type(audio_data)
        if self.model_type == 'funasr':
            result = self.emotion_model.generate(audio_data, output_dir=output_dir, granularity=granularity, extract_embedding=extract_embedding)
        else:
            pass

        # Keep only labels and scores.
        if only_score:
            maintain_key = ["labels", "scores"]
            for res in result:
                keys_to_remove = [k for k in res.keys() if k not in maintain_key]
                for k in keys_to_remove:
                    res.pop(k)
        return result[0]
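
# End-to-end sketch (hypothetical WAV path; the finetuned config is needed for
# classification, since the base model only extracts features):
#   model = Emotion(**FUNASRFINETUNE)
#   model.process(load_audio_file("example/test.wav"))
#   # -> e.g. {'labels': [...], 'scores': [...]}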

# Only for testing.
def load_audio_file(wav_file):
    with wave.open(wav_file, 'rb') as wf:
        params = wf.getparams()
        frames = wf.readframes(params.nframes)
        print("Audio file loaded.")
        # Audio parameters:
        # print("Channels:", params.nchannels)
        # print("Sample width:", params.sampwidth)
        # print("Frame rate:", params.framerate)
        # print("Number of frames:", params.nframes)
        # print("Compression type:", params.comptype)
    return frames

if __name__ == "__main__":
    inputs = r".\example\test.wav"
    inputs = load_audio_file(inputs)
    device = "cuda"
    # FUNASRBASE.update({"device": device})
    FUNASRFINETUNE.update({"device": device})
    emotion_model = Emotion(**FUNASRFINETUNE)
    s = time.time()
    result = emotion_model.process(inputs)
    t = time.time()
    print(t - s)
    print(result)