import io
import numpy as np
import base64
import wave
from funasr import AutoModel
import time
"""
|
|
Base模型
|
|
不能进行情绪分类,只能用作特征提取
|
|
"""
|
|
FUNASRBASE = {
|
|
"model_type": "funasr",
|
|
"model_path": "iic/emotion2vec_base",
|
|
"model_revision": "v2.0.4"
|
|
}
|
|
|
|
|
|
"""
|
|
Finetune模型
|
|
输出分类结果
|
|
"""
|
|
FUNASRFINETUNE = {
|
|
"model_type": "funasr",
|
|
"model_path": "iic/emotion2vec_base_finetuned"
|
|
}
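
# Usage note (added): either dict is meant to be unpacked into the Emotion
# constructor defined below, e.g. Emotion(**FUNASRFINETUNE). Only the finetuned
# checkpoint yields emotion labels and scores; the base checkpoint is limited
# to feature extraction, as the comments above state.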


def decode_str2bytes(data):
    # Decode a Base64-encoded string into raw bytes.
    if data is None:
        return None
    return base64.b64decode(data.encode('utf-8'))
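

def _decode_str2bytes_example():
    # Illustrative sketch (added; `pcm_bytes` is a hypothetical payload): a
    # client can send raw PCM as a Base64 string, and this helper restores the
    # original bytes on the receiving side.
    pcm_bytes = b'\x00\x01\x02\x03'
    encoded = base64.b64encode(pcm_bytes).decode('utf-8')
    assert decode_str2bytes(encoded) == pcm_bytes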


class Emotion:
    def __init__(self,
                 model_type="funasr",
                 model_path="iic/emotion2vec_base",
                 device="cuda",
                 model_revision="v2.0.4",
                 **kwargs):
        self.model_type = model_type
        self.initialize(model_type, model_path, device, model_revision, **kwargs)

    # Initialize the underlying model
    def initialize(self,
                   model_type,
                   model_path,
                   device,
                   model_revision,
                   **kwargs):
        if model_type == "funasr":
            self.emotion_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
        else:
            raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr] expected.")

    # Check the input type
    def check_audio_type(self,
                         audio_data):
        """Check the audio data type and convert it to a float32 numpy array if necessary."""
        if isinstance(audio_data, bytes):
            pass
        elif isinstance(audio_data, list):
            audio_data = b''.join(audio_data)
        elif isinstance(audio_data, str):
            audio_data = decode_str2bytes(audio_data)
        elif isinstance(audio_data, io.BytesIO):
            wf = wave.open(audio_data, 'rb')
            audio_data = wf.readframes(wf.getnframes())
        elif isinstance(audio_data, np.ndarray):
            pass
        else:
            raise TypeError(f"audio_data must be bytes, list, str, \
                io.BytesIO or numpy array, but got {type(audio_data)}")

        if isinstance(audio_data, bytes):
            audio_data = np.frombuffer(audio_data, dtype=np.int16)
        elif isinstance(audio_data, np.ndarray):
            if audio_data.dtype != np.int16:
                audio_data = audio_data.astype(np.int16)
        else:
            raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")

        # The model input must be float32
        if isinstance(audio_data, np.ndarray):
            audio_data = audio_data.astype(np.float32)
        else:
            raise TypeError(f"audio_data must be numpy array, but got {type(audio_data)}")
        return audio_data
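
    # Illustrative sketch (added): every accepted input form is normalized to a
    # float32 numpy array before reaching the model, e.g. for an Emotion
    # instance `emo` (hypothetical name):
    #
    #     raw = np.zeros(16000, dtype=np.int16).tobytes()   # hypothetical 1 s of 16 kHz silence
    #     emo.check_audio_type(raw).dtype                                       # float32
    #     emo.check_audio_type(base64.b64encode(raw).decode('utf-8')).dtype     # float32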

    def process(self,
                audio_data,
                granularity="utterance",
                extract_embedding=False,
                output_dir=None,
                only_score=True):
        """
        audio_data: only float32 is expected because of layernorm
        extract_embedding: save the embedding if True
        output_dir: save path for the embedding
        only_score: only return labels & scores if True
        """
        audio_data = self.check_audio_type(audio_data)
        if self.model_type == 'funasr':
            result = self.emotion_model.generate(audio_data, output_dir=output_dir, granularity=granularity, extract_embedding=extract_embedding)
        else:
            raise NotImplementedError(f"unsupported model type [{self.model_type}]. only [funasr] expected.")

        # Keep only labels and scores
        if only_score:
            maintain_key = ["labels", "scores"]
            for res in result:
                keys_to_remove = [k for k in res.keys() if k not in maintain_key]
                for k in keys_to_remove:
                    res.pop(k)
        return result[0]
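
    # Illustrative sketch (added): with only_score=True the return value is the
    # first generate() result reduced to its "labels" and "scores" entries
    # (actual values depend on the checkpoint and the input audio):
    #
    #     emo = Emotion(**FUNASRFINETUNE)
    #     out = emo.process(np.zeros(16000, dtype=np.float32))
    #     # out -> {"labels": [...], "scores": [...]}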


# only for test
def load_audio_file(wav_file):
    with wave.open(wav_file, 'rb') as wf:
        params = wf.getparams()
        frames = wf.readframes(params.nframes)
        print("Audio file loaded.")
        # Audio Parameters
        # print("Channels:", params.nchannels)
        # print("Sample width:", params.sampwidth)
        # print("Frame rate:", params.framerate)
        # print("Number of frames:", params.nframes)
        # print("Compression type:", params.comptype)
    return frames
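
# Note (added): load_audio_file returns the raw PCM frames as bytes and does no
# resampling or channel handling, so the test WAV is assumed to already be
# 16-bit mono at the sample rate the emotion2vec checkpoint expects.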


if __name__ == "__main__":
    inputs = r".\example\test.wav"
    inputs = load_audio_file(inputs)
    device = "cuda"
    # FUNASRBASE.update({"device": device})
    FUNASRFINETUNE.update({"device": device})
    emotion_model = Emotion(**FUNASRFINETUNE)
    s = time.time()
    result = emotion_model.process(inputs)
    t = time.time()
    print(t - s)
    print(result)
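
    # Alternative input path (added sketch): process() also accepts a Base64
    # string, which check_audio_type() decodes through decode_str2bytes, e.g.
    #
    #     encoded = base64.b64encode(load_audio_file(r".\example\test.wav")).decode('utf-8')
    #     print(emotion_model.process(encoded))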