feat: add emotion detection and punctuation recognition
parent 017997a33e
commit 54d13fba87
@@ -16,9 +16,9 @@ import aiohttp
 logger = get_logger()
 
 # -------------------- Initialize local ASR -----------------------
-from utils.stt.funasr_utils import FunAutoSpeechRecognizer
+from utils.stt.modified_funasr import ModifiedRecognizer
 
-asr = FunAutoSpeechRecognizer()
+asr = ModifiedRecognizer()
 logger.info("本地ASR初始化成功")
 # -------------------------------------------------------
 
@@ -60,6 +60,10 @@ def parseChunkDelta(chunk):
             return "end"
     except KeyError:
         logger.error(f"error chunk: {chunk}")
+        return ""
+    except json.JSONDecodeError:
+        logger.error(f"error chunk: {chunk}")
+        return ""
 
 # Sentence segmentation helper
 def split_string_with_punctuation(current_sentence,text,is_first,is_end):
@@ -224,6 +228,7 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
 async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event):
     logger.debug("语音识别函数启动")
     is_signup = False
+    audio = ""
     try:
         current_message = ""
         while not (user_input_finish_event.is_set() and user_input_q.empty()):
@@ -231,10 +236,16 @@ async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_
                 asr.session_signup(session_id)
                 is_signup = True
             audio_data = await user_input_q.get()
+            audio += audio_data
             asr_result = asr.streaming_recognize(session_id,audio_data)
             current_message += ''.join(asr_result['text'])
         asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
         current_message += ''.join(asr_result['text'])
+        current_message = asr.punctuation_correction(current_message)
+        emotion_dict = asr.emtion_recognition(audio)  # emotion recognition
+        if not isinstance(emotion_dict, str):
+            max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
+            current_message = f"{current_message},当前说话人的情绪:{emotion_dict['labels'][max_index]}"
         await llm_input_q.put(current_message)
         asr.session_signout(session_id)
     except Exception as e:
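For reference, a minimal sketch of the result shape the handler assumes when it indexes `emotion_dict['scores']` and `emotion_dict['labels']` above. The concrete label set is an assumption taken from typical emotion2vec finetuned output, not from this diff.

# Hypothetical illustration only: Emotion.process() (added below in emotion_utils)
# returns a dict with parallel 'labels' and 'scores' lists, and the handler keeps
# the label with the highest score.
emotion_dict = {
    "labels": ["angry", "happy", "neutral", "sad"],   # example labels, not the full set
    "scores": [0.05, 0.72, 0.18, 0.05],
}
max_index = emotion_dict["scores"].index(max(emotion_dict["scores"]))
print(emotion_dict["labels"][max_index])  # -> "happy"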
@@ -371,6 +382,7 @@ async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_eve
     logger.debug("语音识别函数启动")
     is_signup = False
     current_message = ""
+    audio = ""
     while not (input_finished_event.is_set() and user_input_q.empty()):
         try:
             aduio_frame = await asyncio.wait_for(user_input_q.get(),timeout=3)
@@ -380,15 +392,24 @@ async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_eve
             if aduio_frame['is_end']:
                 asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'], is_end=True)
                 current_message += ''.join(asr_result['text'])
+                current_message = asr.punctuation_correction(current_message)
+                audio += aduio_frame['audio']
+                emotion_dict = asr.emtion_recognition(audio)  # emotion recognition
+                if not isinstance(emotion_dict, str):
+                    max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
+                    current_message = f"{current_message}当前说话人的情绪:{emotion_dict['labels'][max_index]}"
+                print(current_message)
                 await llm_input_q.put(current_message)
+                current_message = ""
+                audio = ""
                 logger.debug(f"接收到用户消息: {current_message}")
             else:
                 asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'])
+                audio += aduio_frame['audio']
                 current_message += ''.join(asr_result['text'])
         except asyncio.TimeoutError:
             continue
         except Exception as e:
-            asr.session_signout(session_id)
             logger.error(f"语音识别函数发生错误: {str(e)}")
             break
     asr.session_signout(session_id)
@@ -523,6 +544,7 @@ async def voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_fin
     current_message = ""
     vad_count = 0
     is_signup = False
+    audio = ""
     while not (input_finished_event.is_set() and audio_q.empty()):
         try:
             if not is_signup:
@@ -533,14 +555,22 @@ async def voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_fin
                 if vad_count > 0:
                     vad_count -= 1
                 asr_result = asr.streaming_recognize(session_id, audio_data)
+                audio += audio_data
                 current_message += ''.join(asr_result['text'])
             else:
                 vad_count += 1
                 if vad_count >= 25:  # 25 consecutive frames without speech mark the end of the utterance
                     asr_result = asr.streaming_recognize(session_id, audio_data, is_end=True)
                     if current_message:
+                        current_message = asr.punctuation_correction(current_message)
+                        audio += audio_data
+                        emotion_dict = asr.emtion_recognition(audio)  # emotion recognition
+                        if not isinstance(emotion_dict, str):
+                            max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
+                            current_message = f"{current_message}当前说话人的情绪:{emotion_dict['labels'][max_index]}"
                         logger.debug(f"检测到静默,用户输入为:{current_message}")
                         await asr_result_q.put(current_message)
+                        audio = ""
                         text_response = {"type": "user_text", "code": 200, "msg": current_message}
                         await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # send the text back to the client
                         current_message = ""
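The silence gate used above, reduced to a self-contained sketch. The 25-frame threshold mirrors the handler; the frame stream and the helper function are illustrative only.

# Hypothetical, self-contained illustration of the vad_count gate used by
# voice_call_audio_consumer: speech frames shrink the counter, silent frames
# grow it, and 25 consecutive silent frames finalize the utterance.
SILENCE_FRAMES_TO_END = 25

def feed_frame(state, is_speech: bool) -> bool:
    """Update state for one frame; return True when the utterance should be finalized."""
    if is_speech:
        if state["vad_count"] > 0:
            state["vad_count"] -= 1
        return False
    state["vad_count"] += 1
    return state["vad_count"] >= SILENCE_FRAMES_TO_END

state = {"vad_count": 0}
frames = [True] * 10 + [False] * 30           # 10 speech frames, then silence
ended = [feed_frame(state, f) for f in frames]
print(ended.index(True))                      # index of the frame that triggers finalization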
@@ -0,0 +1,142 @@
+import io
+import numpy as np
+import base64
+import wave
+from funasr import AutoModel
+import time
+
+"""
+Base model
+Cannot do emotion classification; only usable for feature extraction.
+"""
+FUNASRBASE = {
+    "model_type": "funasr",
+    "model_path": "iic/emotion2vec_base",
+    "model_revision": "v2.0.4"
+}
+
+
+"""
+Finetuned model
+Outputs classification results.
+"""
+FUNASRFINETUNE = {
+    "model_type": "funasr",
+    "model_path": "iic/emotion2vec_base_finetuned"
+}
+
+
+def decode_str2bytes(data):
+    # Decode a Base64-encoded string into bytes.
+    if data is None:
+        return None
+    return base64.b64decode(data.encode('utf-8'))
+
+
+class Emotion:
+    def __init__(self,
+                 model_type="funasr",
+                 model_path="iic/emotion2vec_base",
+                 device="cuda",
+                 model_revision="v2.0.4",
+                 **kwargs):
+        self.model_type = model_type
+        self.initialize(model_type, model_path, device, model_revision, **kwargs)
+
+    # Initialize the model
+    def initialize(self,
+                   model_type,
+                   model_path,
+                   device,
+                   model_revision,
+                   **kwargs):
+        if model_type == "funasr":
+            self.emotion_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
+        else:
+            raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr] expected.")
+
+    # Check the input type
+    def check_audio_type(self, audio_data):
+        """Check the audio data type and convert it to a float32 numpy array."""
+        if isinstance(audio_data, bytes):
+            pass
+        elif isinstance(audio_data, list):
+            audio_data = b''.join(audio_data)
+        elif isinstance(audio_data, str):
+            audio_data = decode_str2bytes(audio_data)
+        elif isinstance(audio_data, io.BytesIO):
+            wf = wave.open(audio_data, 'rb')
+            audio_data = wf.readframes(wf.getnframes())
+        elif isinstance(audio_data, np.ndarray):
+            pass
+        else:
+            raise TypeError(f"audio_data must be bytes, list, str, \
+                io.BytesIO or numpy array, but got {type(audio_data)}")
+
+        if isinstance(audio_data, bytes):
+            audio_data = np.frombuffer(audio_data, dtype=np.int16)
+        elif isinstance(audio_data, np.ndarray):
+            if audio_data.dtype != np.int16:
+                audio_data = audio_data.astype(np.int16)
+        else:
+            raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
+
+        # The model input must be float32
+        if isinstance(audio_data, np.ndarray):
+            audio_data = audio_data.astype(np.float32)
+        else:
+            raise TypeError(f"audio_data must be numpy array, but got {type(audio_data)}")
+        return audio_data
+
+    def process(self,
+                audio_data,
+                granularity="utterance",
+                extract_embedding=False,
+                output_dir=None,
+                only_score=True):
+        """
+        audio_data: only float32 is expected because of layer norm
+        extract_embedding: save the embedding if True
+        output_dir: save path for the embedding
+        only_score: only return labels & scores if True
+        """
+        audio_data = self.check_audio_type(audio_data)
+        if self.model_type == 'funasr':
+            result = self.emotion_model.generate(audio_data, output_dir=output_dir, granularity=granularity, extract_embedding=extract_embedding)
+        else:
+            pass
+
+        # Keep only labels and scores
+        if only_score:
+            maintain_key = ["labels", "scores"]
+            for res in result:
+                keys_to_remove = [k for k in res.keys() if k not in maintain_key]
+                for k in keys_to_remove:
+                    res.pop(k)
+        return result[0]
+
+
+# only for test
+def load_audio_file(wav_file):
+    with wave.open(wav_file, 'rb') as wf:
+        params = wf.getparams()
+        frames = wf.readframes(params.nframes)
+        print("Audio file loaded.")
+        # Audio Parameters
+        # print("Channels:", params.nchannels)
+        # print("Sample width:", params.sampwidth)
+        # print("Frame rate:", params.framerate)
+        # print("Number of frames:", params.nframes)
+        # print("Compression type:", params.comptype)
+    return frames
+
+
+if __name__ == "__main__":
+    inputs = r".\example\test.wav"
+    inputs = load_audio_file(inputs)
+    device = "cuda"
+    # FUNASRBASE.update({"device": device})
+    FUNASRFINETUNE.update({"device": device})
+    emotion_model = Emotion(**FUNASRFINETUNE)
+    s = time.time()
+    result = emotion_model.process(inputs)
+    t = time.time()
+    print(t - s)
+    print(result)
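A minimal usage sketch for the Emotion wrapper above, assuming the new module lives at utils/stt/emotion_utils.py (which the imports in ModifiedRecognizer below suggest); the one-second silent buffer is only a placeholder.

# Hypothetical smoke test: feed one second of silent 16 kHz int16 PCM bytes
# into the finetuned emotion2vec wrapper and pick the top-scoring label.
import numpy as np
from utils.stt.emotion_utils import Emotion, FUNASRFINETUNE

emotion_model = Emotion(**FUNASRFINETUNE)               # may download the model on first use
pcm_bytes = np.zeros(16000, dtype=np.int16).tobytes()   # placeholder audio buffer
result = emotion_model.process(pcm_bytes)               # {'labels': [...], 'scores': [...]}
top = result['labels'][result['scores'].index(max(result['scores']))]
print(top)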
@@ -0,0 +1,29 @@
+from .funasr_utils import FunAutoSpeechRecognizer
+from .punctuation_utils import FUNASR, Punctuation
+from .emotion_utils import FUNASRFINETUNE, Emotion
+
+class ModifiedRecognizer():
+    def __init__(self):
+        # Speech recognition model
+        self.asr_model = FunAutoSpeechRecognizer()
+
+        # Punctuation restoration model
+        self.puctuation_model = Punctuation(**FUNASR)
+
+        # Emotion recognition model
+        self.emotion_model = Emotion(**FUNASRFINETUNE)
+
+    def session_signup(self, session_id):
+        self.asr_model.session_signup(session_id)
+
+    def session_signout(self, session_id):
+        self.asr_model.session_signout(session_id)
+
+    def streaming_recognize(self, session_id, audio_data, is_end=False):
+        return self.asr_model.streaming_recognize(session_id, audio_data, is_end=is_end)
+
+    def punctuation_correction(self, sentence):
+        return self.puctuation_model.process(sentence)
+
+    def emtion_recognition(self, audio):
+        return self.emotion_model.process(audio)
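A sketch of the per-session flow the handlers follow with this wrapper. The session id and PCM chunks are placeholders; the method names (including punctuation_correction and emtion_recognition) are taken verbatim from the class above.

# Hypothetical end-to-end flow for one session, mirroring sct_asr_handler:
# stream chunks, flush with is_end=True, then restore punctuation and tag emotion.
from utils.stt.modified_funasr import ModifiedRecognizer

asr = ModifiedRecognizer()
session_id = "demo-session"                 # placeholder id
chunks = [b"\x00" * 3200, b"\x00" * 3200]   # placeholder 16-bit PCM chunks

asr.session_signup(session_id)
text, audio = "", b""
for chunk in chunks:
    audio += chunk
    text += ''.join(asr.streaming_recognize(session_id, chunk)['text'])
text += ''.join(asr.streaming_recognize(session_id, b'', is_end=True)['text'])
text = asr.punctuation_correction(text)
emotion = asr.emtion_recognition(audio)     # dict with parallel 'labels'/'scores' lists
asr.session_signout(session_id)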
@@ -0,0 +1,119 @@
+from funasr import AutoModel
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+PUNCTUATION_MARK = [",", ".", "?", "!", ",", "。", "?", "!"]
+"""
+FUNASR
+Model size: ~1 GB
+Quality: better
+Input type: a single string only; a list is treated as independent strings
+"""
+FUNASR = {
+    "model_type": "funasr",
+    "model_path": "ct-punc",
+    "model_revision": "v2.0.4"
+}
+"""
+CTTRANSFORMER
+Model size: 275 MB
+Quality: worse
+Input type: string or list; also supports a cache input
+"""
+CTTRANSFORMER = {
+    "model_type": "ct-transformer",
+    "model_path": "iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
+    "model_revision": "v2.0.4"
+}
+
+class Punctuation:
+    def __init__(self,
+                 model_type="funasr",  # funasr | ct-transformer
+                 model_path="ct-punc",
+                 device="cuda",
+                 model_revision="v2.0.4",
+                 **kwargs):
+        self.model_type = model_type
+        self.initialize(model_type, model_path, device, model_revision, **kwargs)
+
+    def initialize(self,
+                   model_type,
+                   model_path,
+                   device,
+                   model_revision,
+                   **kwargs):
+        if model_type == 'funasr':
+            self.punc_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
+        elif model_type == 'ct-transformer':
+            self.punc_model = pipeline(task=Tasks.punctuation, model=model_path, model_revision=model_revision, **kwargs)
+        else:
+            raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr|ct-transformer] expected.")
+
+    def check_text_type(self, text_data):
+        # funasr only accepts a single str, so join a list into one string
+        if self.model_type == 'funasr':
+            if isinstance(text_data, str):
+                pass
+            elif isinstance(text_data, list):
+                text_data = ''.join(text_data)
+            else:
+                raise TypeError(f"text must be str or list, but got {type(text_data)}")
+        # ct-transformer accepts a list
+        # TODO: verify whether splitting the string improves throughput
+        elif self.model_type == 'ct-transformer':
+            if isinstance(text_data, str):
+                text_data = [text_data]
+            elif isinstance(text_data, list):
+                pass
+            else:
+                raise TypeError(f"text must be str or list, but got {type(text_data)}")
+        else:
+            pass
+        return text_data
+
+    def generate_cache(self, cache):
+        new_cache = {'pre_text': ""}
+        for text in cache['text']:
+            if text != '':
+                new_cache['pre_text'] = new_cache['pre_text'] + text
+        return new_cache
+
+    def process(self,
+                text,
+                append_period=False,
+                cache={}):
+        if text == '':
+            return ''
+        text = self.check_text_type(text)
+        if self.model_type == 'funasr':
+            result = self.punc_model.generate(text)
+        elif self.model_type == 'ct-transformer':
+            if cache != {}:
+                cache = self.generate_cache(cache)
+            result = self.punc_model(text, cache=cache)
+        punced_text = ''
+        for res in result:
+            punced_text += res['text']
+        # If the text does not end with a punctuation mark, append a period.
+        if append_period and not punced_text[-1] in PUNCTUATION_MARK:
+            punced_text += "。"
+        return punced_text
+
+
+if __name__ == "__main__":
+    inputs = "把字符串拆分为list只|适用于ct-transformer模型|在数据处理部分|已经把list转为单个字符串"
+    """
+    Splitting the string into a list only applies to the ct-transformer model;
+    in the data-processing step the list has already been joined into a single string.
+    """
+    vads = inputs.split("|")
+    device = "cuda"
+    CTTRANSFORMER.update({"device": device})
+    puct_model = Punctuation(**CTTRANSFORMER)
+    result = puct_model.process(vads)
+    print(result)
+    # FUNASR.update({"device": "cuda"})
+    # puct_model = Punctuation(**FUNASR)
+    # result = puct_model.process(vads)
+    # print(result)
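A sketch of how the ct-transformer cache path could be exercised. The `{'text': [...]}` cache shape is inferred from generate_cache above, the module path from the package imports, and the sample sentences are illustrative.

# Hypothetical incremental use of the ct-transformer model: earlier punctuated
# output is passed back in and folded into 'pre_text' by generate_cache().
from utils.stt.punctuation_utils import Punctuation, CTTRANSFORMER

punc = Punctuation(**CTTRANSFORMER)
first = punc.process("今天天气怎么样")                            # no history yet
second = punc.process("我们去公园吧", cache={"text": [first]})    # prior text as context
print(first, second)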
@@ -0,0 +1,75 @@
+from modelscope.pipelines import pipeline
+import numpy as np
+import os
+import pdb
+
+ERES2NETV2 = {
+    "task": 'speaker-verification',
+    "model_name": 'damo/speech_eres2netv2_sv_zh-cn_16k-common',
+    "model_revision": 'v1.0.1',
+    "save_embeddings": False
+}
+
+# Path for saving embeddings
+DEFALUT_SAVE_PATH = r".\takway\savePath"
+
+class speaker_verfication:
+    def __init__(self,
+                 task='speaker-verification',
+                 model_name='damo/speech_eres2netv2_sv_zh-cn_16k-common',
+                 model_revision='v1.0.1',
+                 device="cuda",
+                 save_embeddings=False):
+        self.pipeline = pipeline(
+            task=task,
+            model=model_name,
+            model_revision=model_revision,
+            device=device)
+        self.save_embeddings = save_embeddings
+
+    def wav2embeddings(self, speaker_1_wav):
+        result = self.pipeline([speaker_1_wav], output_emb=True)
+        speaker_1_emb = result['embs'][0]
+        return speaker_1_emb
+
+    def _verifaction(self, speaker_1_wav, speaker_2_wav, threshold, save_path):
+        if not self.save_embeddings:
+            result = self.pipeline([speaker_1_wav, speaker_2_wav], thr=threshold)
+            return result["text"]
+        else:
+            result = self.pipeline([speaker_1_wav, speaker_2_wav], thr=threshold, output_emb=True)
+            speaker1_emb = result["embs"][0]
+            speaker2_emb = result["embs"][1]
+            np.save(os.path.join(save_path, "speaker_1.npy"), speaker1_emb)
+            return result['outputs']["text"]
+
+    def _verifaction_from_embedding(self, base_emb, speaker_2_wav, threshold):
+        base_emb = np.load(base_emb)
+        result = self.pipeline([speaker_2_wav], output_emb=True)
+        speaker2_emb = result["embs"][0]
+        similarity = np.dot(base_emb, speaker2_emb) / (np.linalg.norm(base_emb) * np.linalg.norm(speaker2_emb))
+        if similarity > threshold:
+            return "yes"
+        else:
+            return "no"
+
+    def verfication(self,
+                    base_emb,
+                    speaker_emb,
+                    threshold=0.333):
+        return np.dot(base_emb, speaker_emb) / (np.linalg.norm(base_emb) * np.linalg.norm(speaker_emb)) > threshold
+
+if __name__ == '__main__':
+    verifier = speaker_verfication(**ERES2NETV2)
+    verifier = speaker_verfication(save_embeddings=False)
+    result = verifier._verifaction(speaker_1_wav=r"C:\Users\bing\Downloads\speaker1_a_cn_16k.wav",
+                                   speaker_2_wav=r"C:\Users\bing\Downloads\speaker2_a_cn_16k.wav",
+                                   threshold=0.333,
+                                   save_path=r"D:\python\irving\takway_base-main\savePath")
+    print("---")
+    print(result)
+    print(verifier._verifaction_from_embedding(r"D:\python\irving\takway_base-main\savePath\speaker_1.npy",
+                                               speaker_2_wav=r"C:\Users\bing\Downloads\speaker1_b_cn_16k.wav",
+                                               threshold=0.333))
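A sketch of an enroll-then-verify flow built on the class above. The module name, file paths, and threshold are assumptions; wav2embeddings and verfication are the methods defined in the file.

# Hypothetical enroll-and-verify flow: cache a reference embedding once,
# then compare later utterances against it by cosine similarity.
import numpy as np
from speaker_verification_utils import speaker_verfication, ERES2NETV2  # assumed module name

verifier = speaker_verfication(**ERES2NETV2)
ref_emb = verifier.wav2embeddings("enroll_sample_16k.wav")          # placeholder enrollment audio
np.save("reference_speaker.npy", ref_emb)                           # persist for later sessions

probe_emb = verifier.wav2embeddings("incoming_utterance_16k.wav")   # placeholder probe audio
same_speaker = verifier.verfication(np.load("reference_speaker.npy"), probe_emb, threshold=0.333)
print("same speaker" if same_speaker else "different speaker")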