# pun_emo_speaker_utils/takway/stt/emotion_utils.py

import io
import numpy as np
import base64
import wave
from funasr import AutoModel
import time
"""
Base模型
不能进行情绪分类,只能用作特征提取
"""
FUNASRBASE = {
"model_type": "funasr",
"model_path": "iic/emotion2vec_base",
"model_revision": "v2.0.4"
}
"""
Finetune模型
输出分类结果
"""
FUNASRFINETUNE = {
"model_type": "funasr",
"model_path": "iic/emotion2vec_base_finetuned"
}
def decode_str2bytes(data):
    """Decode a Base64-encoded string into raw bytes.

    Returns None when *data* is None so callers can pass absent
    payloads straight through.
    """
    if data is None:
        return None
    # b64decode accepts str directly (it ASCII-encodes internally),
    # so the explicit .encode('utf-8') round-trip is unnecessary.
    return base64.b64decode(data)
class Emotion:
    """Speech-emotion recognition wrapper around FunASR emotion2vec models.

    Only the ``funasr`` backend is supported; any other ``model_type``
    raises ``NotImplementedError``.
    """

    def __init__(self,
                 model_type="funasr",
                 model_path="iic/emotion2vec_base",
                 device="cuda",
                 model_revision="v2.0.4",
                 **kwargs):
        self.model_type = model_type
        self.initialize(model_type, model_path, device, model_revision, **kwargs)

    # Model initialization
    def initialize(self,
                   model_type,
                   model_path,
                   device,
                   model_revision,
                   **kwargs):
        """Load the underlying model.

        Raises:
            NotImplementedError: if *model_type* is not ``"funasr"``.
        """
        if model_type == "funasr":
            self.emotion_model = AutoModel(model=model_path,
                                           device=device,
                                           model_revision=model_revision,
                                           **kwargs)
        else:
            raise NotImplementedError(
                f"unsupported model type [{model_type}]. only [funasr] expected.")

    # Input-type normalization
    def check_audio_type(self, audio_data):
        """Normalize *audio_data* to a float32 numpy array.

        Accepts bytes (int16 PCM), a list of byte chunks, a Base64 string,
        an ``io.BytesIO`` holding a WAV stream, or a numpy array.
        Non-int16 arrays are first truncated to int16 (matching the raw-PCM
        path) before the final float32 conversion required by layernorm.
        """
        if isinstance(audio_data, (bytes, np.ndarray)):
            pass
        elif isinstance(audio_data, list):
            audio_data = b''.join(audio_data)
        elif isinstance(audio_data, str):
            audio_data = decode_str2bytes(audio_data)
        elif isinstance(audio_data, io.BytesIO):
            # Context manager so the wave reader is always closed
            # (the original leaked the open reader).
            with wave.open(audio_data, 'rb') as wf:
                audio_data = wf.readframes(wf.getnframes())
        else:
            raise TypeError(f"audio_data must be bytes, list, str, \
                            io.BytesIO or numpy array, but got {type(audio_data)}")
        if isinstance(audio_data, bytes):
            audio_data = np.frombuffer(audio_data, dtype=np.int16)
        elif audio_data.dtype != np.int16:
            audio_data = audio_data.astype(np.int16)
        # The model requires float32 input.
        return audio_data.astype(np.float32)

    def process(self,
                audio_data,
                granularity="utterance",
                extract_embedding=False,
                output_dir=None,
                only_score=True):
        """Run emotion recognition on *audio_data*.

        audio_data: any type accepted by ``check_audio_type``; only float32
            reaches the model because of layernorm.
        extract_embedding: save the embedding if True.
        output_dir: save path for the embedding.
        only_score: if True, strip everything but labels & scores.

        Returns the first result dict produced by the model.
        """
        audio_data = self.check_audio_type(audio_data)
        if self.model_type == 'funasr':
            result = self.emotion_model.generate(audio_data,
                                                 output_dir=output_dir,
                                                 granularity=granularity,
                                                 extract_embedding=extract_embedding)
        else:
            # Bug fix: the original fell through silently and then hit a
            # NameError on the undefined `result`. Fail loudly instead,
            # consistent with initialize().
            raise NotImplementedError(
                f"unsupported model type [{self.model_type}]. only [funasr] expected.")
        # Keep only labels and scores
        if only_score:
            maintain_key = {"labels", "scores"}
            for res in result:
                for k in [k for k in res if k not in maintain_key]:
                    res.pop(k)
        return result[0]
# only for test
# only for test
def load_audio_file(wav_file):
    """Read a WAV file and return its raw PCM frames as bytes."""
    with wave.open(wav_file, 'rb') as wf:
        frames = wf.readframes(wf.getnframes())
    print("Audio file loaded.")
    return frames
if __name__ == "__main__":
inputs = r".\example\test.wav"
inputs = load_audio_file(inputs)
device = "cuda"
# FUNASRBASE.update({"device": device})
FUNASRFINETUNE.update({"deivce": device})
emotion_model = Emotion(**FUNASRFINETUNE)
s = time.time()
result = emotion_model.process(inputs)
t = time.time()
print(t - s)
print(result)