
feat: add punctuation, emotion recognition and speaker recognition utils, with examples

bing 2024-05-13 12:55:44 +08:00
parent 4974570f20
commit a776258f8b
9 changed files with 880 additions and 0 deletions

Binary file not shown.

284
examples/audio_utils.py Normal file

@@ -0,0 +1,284 @@
import os
import io
import numpy as np
import pyaudio
import wave
import base64
"""
audio utils for modified_funasr_demo.py
"""
def decode_str2bytes(data):
# decode a Base64-encoded string back into raw bytes
if data is None:
return None
return base64.b64decode(data.encode('utf-8'))
class BaseAudio:
def __init__(self,
filename=None,
input=False,
output=False,
CHUNK=1024,
FORMAT=pyaudio.paInt16,
CHANNELS=1,
RATE=16000,
input_device_index=None,
output_device_index=None,
**kwargs):
self.CHUNK = CHUNK
self.FORMAT = FORMAT
self.CHANNELS = CHANNELS
self.RATE = RATE
self.filename = filename
assert input != output, "exactly one of input and output must be True, \
but got input={} and output={}.".format(input, output)
print("------------------------------------------")
print(f"{'Input' if input else 'Output'} Audio Initialization: ")
print(f"CHUNK: {self.CHUNK} \nFORMAT: {self.FORMAT} \nCHANNELS: {self.CHANNELS} \nRATE: {self.RATE} \ninput_device_index: {input_device_index} \noutput_device_index: {output_device_index}")
print("------------------------------------------")
self.p = pyaudio.PyAudio()
self.stream = self.p.open(format=FORMAT,
channels=CHANNELS,
rate=RATE,
input=input,
output=output,
input_device_index=input_device_index,
output_device_index=output_device_index,
**kwargs)
def load_audio_file(self, wav_file):
with wave.open(wav_file, 'rb') as wf:
params = wf.getparams()
frames = wf.readframes(params.nframes)
print("Audio file loaded.")
# Audio Parameters
# print("Channels:", params.nchannels)
# print("Sample width:", params.sampwidth)
# print("Frame rate:", params.framerate)
# print("Number of frames:", params.nframes)
# print("Compression type:", params.comptype)
return frames
def check_audio_type(self, audio_data, return_type=None):
assert return_type in ['bytes', 'io', None], \
"return_type should be 'bytes', 'io' or None."
if isinstance(audio_data, str):
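# heuristic: strings longer than 50 characters are assumed to be base64-encoded audio, shorter ones are treated as file paths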
if len(audio_data) > 50:
audio_data = decode_str2bytes(audio_data)
else:
assert os.path.isfile(audio_data), \
"audio_data should be a file path or a bytes object."
wf = wave.open(audio_data, 'rb')
audio_data = wf.readframes(wf.getnframes())
elif isinstance(audio_data, np.ndarray):
if audio_data.dtype == np.dtype('float32'):
audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
audio_data = audio_data.tobytes()
elif isinstance(audio_data, bytes):
pass
else:
raise TypeError(f"audio_data must be bytes, numpy.ndarray or str, \
but got {type(audio_data)}")
if return_type is None:
return audio_data
return self.write_wave(None, [audio_data], return_type)
def write_wave(self, filename, frames, return_type='io'):
"""Write audio data to a file."""
if isinstance(frames, bytes):
frames = [frames]
if not isinstance(frames, list):
raise TypeError("frames should be \
a list of bytes or a bytes object, \
but got {}.".format(type(frames)))
if return_type == 'io':
if filename is None:
filename = io.BytesIO()
if self.filename:
filename = self.filename
return self.write_wave_io(filename, frames)
elif return_type == 'bytes':
return self.write_wave_bytes(frames)
def write_wave_io(self, filename, frames):
"""
Write audio data to a file-like object.
Args:
filename: [string or file-like object], file path or file-like object to write
frames: list of bytes, audio data to write
"""
wf = wave.open(filename, 'wb')
# set the WAV file parameters
wf.setnchannels(self.CHANNELS)
wf.setsampwidth(self.p.get_sample_size(self.FORMAT))
wf.setframerate(self.RATE)
wf.writeframes(b''.join(frames))
wf.close()
if isinstance(filename, io.BytesIO):
filename.seek(0) # reset file pointer to beginning
return filename
def write_wave_bytes(self, frames):
"""Write audio data to a bytes object."""
return b''.join(frames)
class BaseRecorder(BaseAudio):
def __init__(self,
input=True,
base_chunk_size=None,
RATE=16000,
**kwargs):
super().__init__(input=input, RATE=RATE, **kwargs)
self.base_chunk_size = base_chunk_size
if base_chunk_size is None:
self.base_chunk_size = self.CHUNK
def record(self,
filename,
duration=5,
return_type='io',
logger=None):
if logger is not None:
logger.info("Recording started.")
else:
print("Recording started.")
frames = []
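# RATE / CHUNK reads per second, for `duration` seconds in total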
for i in range(0, int(self.RATE / self.CHUNK * duration)):
data = self.stream.read(self.CHUNK, exception_on_overflow=False)
frames.append(data)
if logger is not None:
logger.info("Recording stopped.")
else:
print("Recording stopped.")
return self.write_wave(filename, frames, return_type)
def record_chunk_voice(self,
return_type='bytes',
CHUNK=None,
exception_on_overflow=True,
queue=None):
data = self.stream.read(self.CHUNK if CHUNK is None else CHUNK,
exception_on_overflow=exception_on_overflow)
if return_type is not None:
return self.write_wave(None, [data], return_type)
return data

39
examples/modified_funasr_demo.py Normal file

@@ -0,0 +1,39 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from audio_utils import BaseRecorder
from utils.stt.modified_funasr import ModifiedRecognizer
def asr_file_stream(file_path=r'.\assets\example_recording.wav'):
# load the audio file
rec = BaseRecorder()
data = rec.load_audio_file(file_path)
# create the recognizer
asr = ModifiedRecognizer(use_punct=True, use_emotion=True, use_speaker_ver=True)
asr.session_signup("test")
# register the target speaker
asr.initialize_speaker(r".\assets\example_recording.wav")
# speech recognition
print("===============================================")
text_dict = asr.streaming_recognize("test", data, auto_det_end=True)
print(f"text_dict: {text_dict}")
if not isinstance(text_dict, str):
print("".join(text_dict['text']))
# emotion recognition
print("===============================================")
emotion_dict = asr.recognize_emotion(data)
print(f"emotion_dict: {emotion_dict}")
if not isinstance(emotion_dict, str):
max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
print("emotion: " +emotion_dict['labels'][max_index])
asr_file_stream()
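# A minimal sketch (not part of the original demo) of feeding live microphone chunks
# into the same streaming API. It assumes a working input device and the default
# BaseRecorder chunk settings; num_chunks and the session id "mic" are illustrative.
def asr_mic_stream(num_chunks=100):
    rec = BaseRecorder()  # input stream, 16 kHz, CHUNK=1024
    asr = ModifiedRecognizer(use_punct=True, use_emotion=False, use_speaker_ver=True)
    asr.session_signup("mic")
    asr.initialize_speaker(r".\assets\example_recording.wav")
    for i in range(num_chunks):
        data = rec.record_chunk_voice(return_type=None, exception_on_overflow=False)
        text_dict = asr.streaming_recognize("mic", data, is_end=(i == num_chunks - 1))
        if not isinstance(text_dict, str):  # a str return value means "Other People"
            print("".join(text_dict['text']))
# asr_mic_stream()  # requires a microphone, so it is not called by default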


@@ -0,0 +1 @@
Stores the target speaker's voice embeddings. To change the path, edit DEFALUT_SAVE_PATH in utils/stt/speaker_ver_utils.

Binary file not shown.

142
utils/stt/emotion_utils.py Normal file

@@ -0,0 +1,142 @@
import io
import numpy as np
import base64
import wave
from funasr import AutoModel
import time
"""
Base model
Cannot perform emotion classification; only usable for feature extraction.
"""
FUNASRBASE = {
"model_type": "funasr",
"model_path": "iic/emotion2vec_base",
"model_revision": "v2.0.4"
}
"""
Finetuned model
Outputs classification results.
"""
FUNASRFINETUNE = {
"model_type": "funasr",
"model_path": "iic/emotion2vec_base_finetuned"
}
def decode_str2bytes(data):
# decode a Base64-encoded string back into raw bytes
if data is None:
return None
return base64.b64decode(data.encode('utf-8'))
class Emotion:
def __init__(self,
model_type="funasr",
model_path="iic/emotion2vec_base",
device="cuda",
model_revision="v2.0.4",
**kwargs):
self.model_type = model_type
self.initialize(model_type, model_path, device, model_revision, **kwargs)
# initialize the model
def initialize(self,
model_type,
model_path,
device,
model_revision,
**kwargs):
if model_type == "funasr":
self.emotion_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
else:
raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr] expected.")
# check and normalize the input type
def check_audio_type(self,
audio_data):
"""check audio data type and convert it to bytes if necessary."""
if isinstance(audio_data, bytes):
pass
elif isinstance(audio_data, list):
audio_data = b''.join(audio_data)
elif isinstance(audio_data, str):
audio_data = decode_str2bytes(audio_data)
elif isinstance(audio_data, io.BytesIO):
wf = wave.open(audio_data, 'rb')
audio_data = wf.readframes(wf.getnframes())
elif isinstance(audio_data, np.ndarray):
pass
else:
raise TypeError(f"audio_data must be bytes, list, str, \
io.BytesIO or numpy array, but got {type(audio_data)}")
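# first normalize the input to an int16 sample array, then cast to float32 for the model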
if isinstance(audio_data, bytes):
audio_data = np.frombuffer(audio_data, dtype=np.int16)
elif isinstance(audio_data, np.ndarray):
if audio_data.dtype != np.int16:
audio_data = audio_data.astype(np.int16)
else:
raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
# the model expects float32 input
if isinstance(audio_data, np.ndarray):
audio_data = audio_data.astype(np.float32)
else:
raise TypeError(f"audio_data must be numpy array, but got {type(audio_data)}")
return audio_data
def process(self,
audio_data,
granularity="utterance",
extract_embedding=False,
output_dir=None,
only_score=True):
"""
audio_data: only float32 is expected because of the layer norm
extract_embedding: save the embedding if True
output_dir: save path for the embedding
only_score: only return labels & scores if True
"""
audio_data = self.check_audio_type(audio_data)
if self.model_type == 'funasr':
result = self.emotion_model.generate(audio_data, output_dir=output_dir, granularity=granularity, extract_embedding=extract_embedding)
else:
pass
# keep only the labels and scores
if only_score:
maintain_key = ["labels", "scores"]
for res in result:
keys_to_remove = [k for k in res.keys() if k not in maintain_key]
for k in keys_to_remove:
res.pop(k)
return result[0]
# only for test
def load_audio_file(wav_file):
with wave.open(wav_file, 'rb') as wf:
params = wf.getparams()
frames = wf.readframes(params.nframes)
print("Audio file loaded.")
# Audio Parameters
# print("Channels:", params.nchannels)
# print("Sample width:", params.sampwidth)
# print("Frame rate:", params.framerate)
# print("Number of frames:", params.nframes)
# print("Compression type:", params.comptype)
return frames
if __name__ == "__main__":
inputs = r".\example\test.wav"
inputs = load_audio_file(inputs)
device = "cuda"
# FUNASRBASE.update({"device": device})
FUNASRFINETUNE.update({"deivce": device})
emotion_model = Emotion(**FUNASRFINETUNE)
s = time.time()
result = emotion_model.process(inputs)
t = time.time()
print(t - s)
print(result)

209
utils/stt/modified_funasr.py Normal file

@@ -0,0 +1,209 @@
from .funasr_utils import FunAutoSpeechRecognizer
from .punctuation_utils import CTTRANSFORMER, Punctuation
from .emotion_utils import FUNASRFINETUNE, Emotion
from .speaker_ver_utils import ERES2NETV2, DEFALUT_SAVE_PATH, speaker_verfication
import os
import numpy as np
class ModifiedRecognizer(FunAutoSpeechRecognizer):
def __init__(self,
use_punct=True,
use_emotion=False,
use_speaker_ver=True):
# create the base funasr streaming model, which recognizes text without punctuation
super().__init__(
model_path="paraformer-zh-streaming",
device="cuda",
RATE=16000,
cfg_path=None,
debug=False,
chunk_ms=480,
encoder_chunk_look_back=4,
decoder_chunk_look_back=1)
# record which optional capabilities are enabled
self.use_punct = use_punct
self.use_emotion = use_emotion
self.use_speaker_ver = use_speaker_ver
# punctuation model
if use_punct:
self.puctuation_model = Punctuation(**CTTRANSFORMER)
# emotion recognition model
if use_emotion:
self.emotion_model = Emotion(**FUNASRFINETUNE)
# speaker verification model
if use_speaker_ver:
self.speaker_ver_model = speaker_verfication(**ERES2NETV2)
def initialize_speaker(self, speaker_1_wav):
"""
For speaker verification: register the input audio (speaker_1_wav) as the target speaker and save its embedding locally.
"""
if not self.use_speaker_ver:
raise NotImplementedError("no access")
if speaker_1_wav.endswith(".npy"):
self.save_speaker_path = speaker_1_wav
elif speaker_1_wav.endswith('.wav'):
self.save_speaker_path = os.path.join(DEFALUT_SAVE_PATH,
os.path.basename(speaker_1_wav).replace(".wav", ".npy"))
# self.save_speaker_path = DEFALUT_SAVE_PATH
self.speaker_ver_model.wav2embeddings(speaker_1_wav, self.save_speaker_path)
else:
raise TypeError("only support [.npy] or [.wav].")
def speaker_ver(self, speaker_2_wav):
"""
For speaker verification: check whether the input audio comes from the target speaker.
Returns True if it does, False otherwise.
"""
if not self.use_speaker_ver:
raise NotImplementedError("no access")
if not hasattr(self, "save_speaker_path"):
raise NotImplementedError("please initialize speaker first")
# self.speaker_ver_model.verfication returns the string 'yes' or 'no'
return self.speaker_ver_model.verfication(base_emb=self.save_speaker_path,
speaker_2_wav=speaker_2_wav) == 'yes'
def recognize(self, audio_data):
"""
Non-streaming speech recognition. Returns the recognized text as a str.
"""
audio_data = self.check_audio_type(audio_data)
# speaker verification
if self.use_speaker_ver:
if self.speaker_ver_model.verfication(self.save_speaker_path,
speaker_2_wav=audio_data) == 'no':
return "Other People"
# speech recognition
result = self.asr_model.generate(input=audio_data,
batch_size_s=300,
hotword=self.hotwords)
text = ''
for res in result:
text += res['text']
# add punctuation
if self.use_punct:
text = self.puctuation_model.process(text+'#', append_period=False).replace('#', '')
return text
def recognize_emotion(self, audio_data):
"""
Emotion recognition. Returns:
1. the string "Other People" if the speaker is not the target speaker;
2. a dict {"labels": List[str], "scores": List[float]} if the speaker is the target speaker.
"""
audio_data = self.check_audio_type(audio_data)
if self.use_speaker_ver:
if self.speaker_ver_model.verfication(self.save_speaker_path,
speaker_2_wav=audio_data) == 'no':
return "Other People"
if self.use_emotion:
return self.emotion_model.process(audio_data)
else:
raise NotImplementedError("no access")
def streaming_recognize(self, session_id, audio_data, is_end=False, auto_det_end=False):
"""recognize partial result
Args:
audio_data: bytes or numpy array, partial audio data
is_end: bool, whether the audio data is the end of a sentence
auto_det_end: bool, whether to automatically detect the end of a audio data
流式语音识别返回值为
1. 如果说话人非目标说话人返回字符串 "Other People"
2. 如果说话人为目标说话人返回字典{"test": List[str], "is_end": boolean}
"""
audio_cache = self.audio_cache[session_id]
asr_cache = self.asr_cache[session_id]
text_dict = dict(text=[], is_end=is_end)
audio_data = self.check_audio_type(audio_data)
# speaker verification
if self.use_speaker_ver:
if self.speaker_ver_model.verfication(self.save_speaker_path,
speaker_2_wav=audio_data) == 'no':
return "Other People"
# speech recognition
if audio_cache is None:
audio_cache = audio_data
else:
# print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
if audio_cache.shape[0] > 0:
audio_cache = np.concatenate([audio_cache, audio_data], axis=0)
if not is_end and audio_cache.shape[0] < self.chunk_partial_size:
self.audio_cache[session_id] = audio_cache
return text_dict
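# number of full chunk_partial_size windows available in the buffered audio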
total_chunk_num = int((len(audio_cache)-1)/self.chunk_partial_size)
if is_end:
# if the audio data is the end of a sentence, \
# we need to add one more chunk to the end to \
# ensure the end of the sentence is recognized correctly.
auto_det_end = True
if auto_det_end:
total_chunk_num += 1
# print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
end_idx = None
for i in range(total_chunk_num):
if auto_det_end:
is_end = i == total_chunk_num - 1
start_idx = i*self.chunk_partial_size
if auto_det_end:
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
else:
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
# print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
# t_stamp = time.time()
speech_chunk = audio_cache[start_idx:end_idx]
# TODO: exception handling
try:
res = self.asr_model.generate(input=speech_chunk, cache=asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
except ValueError as e:
print(f"ValueError: {e}")
continue
# add punctuation
if self.use_punct:
text_dict['text'].append(self.puctuation_model.process(self.text_postprecess(res[0], data_id='text'), cache=text_dict))
else:
text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
# print(f"each chunk time: {time.time()-t_stamp}")
if is_end:
audio_cache = None
asr_cache = {}
else:
if end_idx:
audio_cache = audio_cache[end_idx:] # cut the processed part from audio_cache
text_dict['is_end'] = is_end
if self.use_punct and is_end:
text_dict['text'].append(self.puctuation_model.process('#', cache=text_dict).replace('#', ''))
self.audio_cache[session_id] = audio_cache
self.asr_cache[session_id] = asr_cache
# print(f"text_dict: {text_dict}")
return text_dict

119
utils/stt/punctuation_utils.py Normal file

@@ -0,0 +1,119 @@
from funasr import AutoModel
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
PUNCTUATION_MARK = [",", ".", "?", "!", "，", "。", "？", "！"]
"""
FUNASR
Model size: 1 GB
Quality: relatively good
Input type: only a single string is supported (not a list); a list is treated as independent strings
"""
FUNASR = {
"model_type": "funasr",
"model_path": "ct-punc",
"model_revision": "v2.0.4"
}
"""
CTTRANSFORMER
Model size: 275 MB
Quality: relatively poor
Input type: supports both str and list, and also accepts an incremental cache
"""
CTTRANSFORMER = {
"model_type": "ct-transformer",
"model_path": "iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
"model_revision": "v2.0.4"
}
class Punctuation:
def __init__(self,
model_type="funasr", # funasr | ct-transformer
model_path="ct-punc",
device="cuda",
model_revision="v2.0.4",
**kwargs):
self.model_type = model_type
self.initialize(model_type, model_path, device, model_revision, **kwargs)
def initialize(self,
model_type,
model_path,
device,
model_revision,
**kwargs):
if model_type == 'funasr':
self.punc_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
elif model_type == 'ct-transformer':
self.punc_model = pipeline(task=Tasks.punctuation, model=model_path, model_revision=model_revision, **kwargs)
else:
raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr|ct-transformer] expected.")
def check_text_type(self,
text_data):
# funasr only accepts a single str (not a list), so join a list into one string
if self.model_type == 'funasr':
if isinstance(text_data, str):
pass
elif isinstance(text_data, list):
text_data = ''.join(text_data)
else:
raise TypeError(f"text must be str or list, but got {type(list)}")
# ct-transformer accepts list input
# TODO: check whether splitting the string improves efficiency
elif self.model_type == 'ct-transformer':
if isinstance(text_data, str):
text_data = [text_data]
elif isinstance(text_data, list):
pass
else:
raise TypeError(f"text must be str or list, but got {type(list)}")
else:
pass
return text_data
def generate_cache(self, cache):
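# build the incremental cache for ct-transformer: concatenate all previously recognized text into 'pre_text'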
new_cache = {'pre_text': ""}
for text in cache['text']:
if text != '':
new_cache['pre_text'] = new_cache['pre_text']+text
return new_cache
def process(self,
text,
append_period=False,
cache={}):
if text == '':
return ''
text = self.check_text_type(text)
if self.model_type == 'funasr':
result = self.punc_model.generate(text)
elif self.model_type == 'ct-transformer':
if cache != {}:
cache = self.generate_cache(cache)
result = self.punc_model(text, cache=cache)
punced_text = ''
for res in result:
punced_text += res['text']
# if the text does not end with a punctuation mark, append one manually
if append_period and punced_text[-1] not in PUNCTUATION_MARK:
punced_text += "。"
return punced_text
if __name__ == "__main__":
inputs = "把字符串拆分为list只|适用于ct-transformer模型|在数据处理部分|已经把list转为单个字符串"
"""
把字符串拆分为list只适用于ct-transformer模型,
在数据处理部分,已经把list转为单个字符串
"""
vads = inputs.split("|")
device = "cuda"
CTTRANSFORMER.update({"device": device})
puct_model = Punctuation(**CTTRANSFORMER)
result = puct_model.process(vads)
print(result)
# FUNASR.update({"device":"cuda"})
# puct_model = Punctuation(**FUNASR)
# result = puct_model.process(vads)
# print(result)

86
utils/stt/speaker_ver_utils.py Normal file

@@ -0,0 +1,86 @@
from modelscope.pipelines import pipeline
import numpy as np
import os
ERES2NETV2 = {
"task": 'speaker-verification',
"model_name": 'damo/speech_eres2netv2_sv_zh-cn_16k-common',
"model_revision": 'v1.0.1',
"save_embeddings": False
}
# default save path for speaker embeddings
DEFALUT_SAVE_PATH = os.path.join(os.path.dirname(os.path.dirname(__name__)), "speaker_embedding")
class speaker_verfication:
def __init__(self,
task='speaker-verification',
model_name='damo/speech_eres2netv2_sv_zh-cn_16k-common',
model_revision='v1.0.1',
device="cuda",
save_embeddings=False):
self.pipeline = pipeline(
task=task,
model=model_name,
model_revision=model_revision,
device=device)
self.save_embeddings = save_embeddings
def wav2embeddings(self, speaker_1_wav, save_path=None):
result = self.pipeline([speaker_1_wav], output_emb=True)
speaker_1_emb = result['embs'][0]
if save_path is not None:
np.save(save_path, speaker_1_emb)
return speaker_1_emb
def _verifaction(self, speaker_1_wav, speaker_2_wav, threshold, save_path):
if not self.save_embeddings:
result = self.pipeline([speaker_1_wav, speaker_2_wav], thr=threshold)
return result["text"]
else:
result = self.pipeline([speaker_1_wav, speaker_2_wav], thr=threshold, output_emb=True)
speaker1_emb = result["embs"][0]
speaker2_emb = result["embs"][1]
np.save(os.path.join(save_path, "speaker_1.npy"), speaker1_emb)
return result['outputs']["text"]
def _verifaction_from_embedding(self, base_emb, speaker_2_wav, threshold):
base_emb = np.load(base_emb)
result = self.pipeline([speaker_2_wav], output_emb=True)
speaker2_emb = result["embs"][0]
similarity = np.dot(base_emb, speaker2_emb) / (np.linalg.norm(base_emb) * np.linalg.norm(speaker2_emb))
if similarity > threshold:
return "yes"
else:
return "no"
def verfication(self,
base_emb=None,
speaker_1_wav=None,
speaker_2_wav=None,
threshold=0.333,
save_path=None):
if base_emb is not None and speaker_1_wav is not None:
raise ValueError("Only need one of them, base_emb or speaker_1_wav")
if base_emb is not None and speaker_2_wav is not None:
return self._verifaction_from_embedding(base_emb, speaker_2_wav, threshold)
elif speaker_1_wav is not None and speaker_2_wav is not None:
return self._verifaction(speaker_1_wav, speaker_2_wav, threshold, save_path)
else:
raise NotImplementedError
if __name__ == '__main__':
verifier = speaker_verfication(**ERES2NETV2)
verifier = speaker_verfication(save_embeddings=False)
result = verifier.verfication(base_emb=None, speaker_1_wav=r"C:\Users\bing\Downloads\speaker1_a_cn_16k.wav",
speaker_2_wav=r"C:\Users\bing\Downloads\speaker2_a_cn_16k.wav",
threshold=0.333,
save_path=r"D:\python\irving\takway_base-main\savePath"
)
print("---")
print(result)
print(verifier.verfication(r"D:\python\irving\takway_base-main\savePath\speaker_1.npy",
speaker_2_wav=r"C:\Users\bing\Downloads\speaker1_b_cn_16k.wav",
threshold=0.333,
))