pun_emo_speaker_utils/takway/stt/modified_funasr.py

import os

import numpy as np

from takway.stt.funasr_utils import FunAutoSpeechRecognizer
from takway.stt.punctuation_utils import CTTRANSFORMER, Punctuation
from takway.stt.emotion_utils import FUNASRFINETUNE, Emotion
from takway.stt.speaker_ver_utils import ERES2NETV2, DEFALUT_SAVE_PATH, speaker_verfication

class ModifiedRecognizer(FunAutoSpeechRecognizer):
    def __init__(self,
                 use_punct=True,
                 use_emotion=False,
                 use_speaker_ver=True):
        # Build the base funasr streaming model, which recognizes speech
        # into unpunctuated text.
        super().__init__(
            model_path="paraformer-zh-streaming",
            device="cuda",
            RATE=16000,
            cfg_path=None,
            debug=False,
            chunk_ms=480,
            encoder_chunk_look_back=4,
            decoder_chunk_look_back=1)

        # Record which optional capabilities are enabled.
        self.use_punct = use_punct
        self.use_emotion = use_emotion
        self.use_speaker_ver = use_speaker_ver

        # Punctuation restoration model
        if use_punct:
            self.punctuation_model = Punctuation(**CTTRANSFORMER)

        # Emotion recognition model
        if use_emotion:
            self.emotion_model = Emotion(**FUNASRFINETUNE)

        # Speaker verification model
        if use_speaker_ver:
            self.speaker_ver_model = speaker_verfication(**ERES2NETV2)

    def initialize_speaker(self, speaker_1_wav):
        """
        Register the input audio (speaker_1_wav) as the target speaker for
        speaker verification, and save its embedding locally.
        """
        if not self.use_speaker_ver:
            raise NotImplementedError("no access")
        if speaker_1_wav.endswith(".npy"):
            self.save_speaker_path = speaker_1_wav
        elif speaker_1_wav.endswith('.wav'):
            self.save_speaker_path = os.path.join(
                DEFALUT_SAVE_PATH,
                os.path.basename(speaker_1_wav).replace(".wav", ".npy"))
            self.speaker_ver_model.wav2embeddings(speaker_1_wav, self.save_speaker_path)
        else:
            raise TypeError("only support [.npy] or [.wav].")

    def speaker_ver(self, speaker_2_wav):
        """
        Speaker verification: check whether the input audio comes from the
        target speaker. Returns True if it does, False otherwise.
        """
        if not self.use_speaker_ver:
            raise NotImplementedError("no access")
        if not hasattr(self, "save_speaker_path"):
            raise NotImplementedError("please initialize speaker first")
        # self.speaker_ver_model.verfication returns the string 'yes' / 'no'.
        return self.speaker_ver_model.verfication(
            base_emb=self.save_speaker_path,
            speaker_2_wav=speaker_2_wav) == 'yes'
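
    # A hypothetical usage sketch of the verification flow; the file names
    # below are placeholders, not assets shipped with this repo:
    #
    #   recognizer = ModifiedRecognizer(use_speaker_ver=True)
    #   recognizer.initialize_speaker("target_speaker.wav")  # register target
    #   recognizer.speaker_ver("unknown_speaker.wav")        # -> True / False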

    def recognize(self, audio_data):
        """
        Non-streaming speech recognition. Returns the recognized text as a str.
        """
        audio_data = self.check_audio_type(audio_data)

        # Speaker verification
        if self.use_speaker_ver:
            if self.speaker_ver_model.verfication(
                    self.save_speaker_path,
                    speaker_2_wav=audio_data) == 'no':
                return "Other People"

        # Speech recognition
        result = self.asr_model.generate(input=audio_data,
                                         batch_size_s=300,
                                         hotword=self.hotwords)
        text = ''
        for res in result:
            text += res['text']

        # Punctuation restoration: process with a trailing '#' placeholder,
        # then strip it from the output.
        if self.use_punct:
            text = self.punctuation_model.process(text + '#', append_period=False).replace('#', '')
        return text

    def recognize_emotion(self, audio_data):
        """
        Emotion recognition. Returns:
        1. the string "Other People" if the speaker is not the target speaker;
        2. otherwise a dict {"Labels": List[str], "scores": List[int]}.
        """
        audio_data = self.check_audio_type(audio_data)

        if self.use_speaker_ver:
            if self.speaker_ver_model.verfication(
                    self.save_speaker_path,
                    speaker_2_wav=audio_data) == 'no':
                return "Other People"

        if self.use_emotion:
            return self.emotion_model.process(audio_data)
        else:
            raise NotImplementedError("no access")
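
    # Hypothetical shape of the emotion result; the label set and score type
    # depend on the fine-tuned model behind FUNASRFINETUNE (scores are
    # commonly floats in practice):
    #   {"Labels": ["happy", "neutral"], "scores": [0.87, 0.13]}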

    def streaming_recognize(self, audio_data, is_end=False, auto_det_end=False):
        """Recognize a partial (streaming) result.

        Args:
            audio_data: bytes or numpy array, partial audio data
            is_end: bool, whether this audio data is the end of a sentence
            auto_det_end: bool, whether to automatically detect the end of the audio data

        Returns:
            1. the string "Other People" if the speaker is not the target speaker;
            2. otherwise a dict {"text": List[str], "is_end": bool}.
        """
        audio_data = self.check_audio_type(audio_data)

        # Speaker verification
        if self.use_speaker_ver:
            if self.speaker_ver_model.verfication(
                    self.save_speaker_path,
                    speaker_2_wav=audio_data) == 'no':
                return "Other People"

        # Speech recognition: accumulate incoming audio in the cache until at
        # least one full partial chunk is available.
        text_dict = dict(text=[], is_end=is_end)

        if self.audio_cache is None:
            self.audio_cache = audio_data
        else:
            if self.audio_cache.shape[0] > 0:
                self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)

        if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
            return text_dict

        total_chunk_num = int((len(self.audio_cache) - 1) / self.chunk_partial_size)

        if is_end:
            # If this audio data ends the sentence, add one more chunk at the
            # end so the tail of the sentence is recognized correctly.
            auto_det_end = True
        if auto_det_end:
            total_chunk_num += 1

        end_idx = None
        for i in range(total_chunk_num):
            if auto_det_end:
                is_end = i == total_chunk_num - 1
            start_idx = i * self.chunk_partial_size
            if auto_det_end:
                end_idx = (i + 1) * self.chunk_partial_size if i < total_chunk_num - 1 else -1
            else:
                end_idx = (i + 1) * self.chunk_partial_size if i < total_chunk_num else -1
            speech_chunk = self.audio_cache[start_idx:end_idx]

            # TODO: handle more exception types
            try:
                res = self.asr_model.generate(input=speech_chunk,
                                              cache=self.asr_cache,
                                              is_final=is_end,
                                              chunk_size=self.chunk_size,
                                              encoder_chunk_look_back=self.encoder_chunk_look_back,
                                              decoder_chunk_look_back=self.decoder_chunk_look_back)
            except ValueError as e:
                print(f"ValueError: {e}")
                continue

            # Punctuation restoration
            if self.use_punct:
                text_dict['text'].append(self.punctuation_model.process(
                    self.text_postprecess(res[0], data_id='text'), cache=text_dict))
            else:
                text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))

        if is_end:
            # Sentence finished: reset the audio and ASR caches.
            self.audio_cache = None
            self.asr_cache = {}
        elif end_idx:
            # Drop the processed part from the audio cache.
            self.audio_cache = self.audio_cache[end_idx:]

        text_dict['is_end'] = is_end

        if self.use_punct and is_end:
            text_dict['text'].append(self.punctuation_model.process('#', cache=text_dict).replace('#', ''))

        return text_dict
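

# A minimal end-to-end sketch, assuming 16 kHz mono int16 WAV input and the
# model assets referenced above being available locally. "target.wav" and
# "query.wav" are hypothetical file names used only for illustration.
if __name__ == "__main__":
    import wave

    def read_pcm(path):
        # Read raw int16 PCM samples from a 16 kHz mono WAV file.
        with wave.open(path, "rb") as f:
            return np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)

    recognizer = ModifiedRecognizer(use_punct=True,
                                    use_emotion=False,
                                    use_speaker_ver=True)
    recognizer.initialize_speaker("target.wav")

    # Non-streaming recognition of a whole utterance.
    print(recognizer.recognize(read_pcm("query.wav")))

    # Streaming recognition: feed fixed-size chunks and mark the last one.
    audio = read_pcm("query.wav")
    step = recognizer.chunk_partial_size  # samples per partial chunk (set by the base class)
    for start in range(0, len(audio), step):
        piece = audio[start:start + step]
        result = recognizer.streaming_recognize(piece, is_end=start + step >= len(audio))
        if isinstance(result, dict):
            print(''.join(result['text']), end='', flush=True)
        else:
            print(result)  # "Other People"
            break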