from takway.stt.funasr_utils import FunAutoSpeechRecognizer
from takway.stt.punctuation_utils import CTTRANSFORMER, Punctuation
from takway.stt.emotion_utils import FUNASRFINETUNE, Emotion
from takway.stt.speaker_ver_utils import ERES2NETV2, DEFALUT_SAVE_PATH, speaker_verfication
import os
import numpy as np


class ModifiedRecognizer(FunAutoSpeechRecognizer):
    def __init__(self,
                 use_punct=True,
                 use_emotion=False,
                 use_speaker_ver=True):
        # Create the base funasr model for speech recognition;
        # it outputs sentences without punctuation.
        super().__init__(
            model_path="paraformer-zh-streaming",
            device="cuda",
            RATE=16000,
            cfg_path=None,
            debug=False,
            chunk_ms=480,
            encoder_chunk_look_back=4,
            decoder_chunk_look_back=1)

        # Record which optional features are enabled.
        self.use_punct = use_punct
        self.use_emotion = use_emotion
        self.use_speaker_ver = use_speaker_ver

        # Punctuation restoration model.
        if use_punct:
            self.punctuation_model = Punctuation(**CTTRANSFORMER)

        # Emotion recognition model.
        if use_emotion:
            self.emotion_model = Emotion(**FUNASRFINETUNE)

        # Speaker verification model.
        if use_speaker_ver:
            self.speaker_ver_model = speaker_verfication(**ERES2NETV2)
    def initialize_speaker(self, speaker_1_wav):
        """
        Register the input audio (speaker_1_wav) as the target speaker for
        speaker verification and save its embedding locally.
        """
        if not self.use_speaker_ver:
            raise NotImplementedError("no access")
        if speaker_1_wav.endswith(".npy"):
            self.save_speaker_path = speaker_1_wav
        elif speaker_1_wav.endswith('.wav'):
            self.save_speaker_path = os.path.join(
                DEFALUT_SAVE_PATH,
                os.path.basename(speaker_1_wav).replace(".wav", ".npy"))
            self.speaker_ver_model.wav2embeddings(speaker_1_wav, self.save_speaker_path)
        else:
            raise TypeError("only support [.npy] or [.wav].")
    def speaker_ver(self, speaker_2_wav):
        """
        Speaker verification: return True if the input audio matches the
        target speaker, False otherwise.
        """
        if not self.use_speaker_ver:
            raise NotImplementedError("no access")
        if not hasattr(self, "save_speaker_path"):
            raise NotImplementedError("please initialize speaker first")
        # self.speaker_ver_model.verfication returns the string 'yes' / 'no'.
        return self.speaker_ver_model.verfication(
            base_emb=self.save_speaker_path,
            speaker_2_wav=speaker_2_wav) == 'yes'
    def recognize(self, audio_data):
        """
        Non-streaming speech recognition. Returns the recognized text (str),
        or the string "Other People" if speaker verification rejects the audio.
        """
        audio_data = self.check_audio_type(audio_data)

        # Speaker verification: reject audio from non-target speakers.
        if self.use_speaker_ver:
            if self.speaker_ver_model.verfication(self.save_speaker_path,
                                                  speaker_2_wav=audio_data) == 'no':
                return "Other People"

        # Speech recognition.
        result = self.asr_model.generate(input=audio_data,
                                         batch_size_s=300,
                                         hotword=self.hotwords)
        text = ''
        for res in result:
            text += res['text']

        # Punctuation restoration; the trailing '#' is a sentinel stripped
        # from the output.
        if self.use_punct:
            text = self.punctuation_model.process(text + '#', append_period=False).replace('#', '')

        return text
    def recognize_emotion(self, audio_data):
        """
        Emotion recognition. Returns:
        1. the string "Other People" if the speaker is not the target speaker;
        2. a dict {"Labels": List[str], "scores": List[int]} if the speaker
           is the target speaker.
        """
        audio_data = self.check_audio_type(audio_data)

        # Speaker verification: reject audio from non-target speakers.
        if self.use_speaker_ver:
            if self.speaker_ver_model.verfication(self.save_speaker_path,
                                                  speaker_2_wav=audio_data) == 'no':
                return "Other People"

        if self.use_emotion:
            return self.emotion_model.process(audio_data)
        else:
            raise NotImplementedError("no access")
    def streaming_recognize(self, audio_data, is_end=False, auto_det_end=False):
        """Recognize a partial (streaming) result.

        Args:
            audio_data: bytes or numpy array, partial audio data
            is_end: bool, whether the audio data is the end of a sentence
            auto_det_end: bool, whether to automatically detect the end of the audio data

        Returns:
            1. the string "Other People" if the speaker is not the target speaker;
            2. a dict {"text": List[str], "is_end": bool} if the speaker is
               the target speaker.
        """
        audio_data = self.check_audio_type(audio_data)

        # Speaker verification: reject audio from non-target speakers.
        if self.use_speaker_ver:
            if self.speaker_ver_model.verfication(self.save_speaker_path,
                                                  speaker_2_wav=audio_data) == 'no':
                return "Other People"

        # Speech recognition.
        text_dict = dict(text=[], is_end=is_end)

        # Accumulate incoming audio in the cache until a full chunk is available.
        if self.audio_cache is None:
            self.audio_cache = audio_data
        else:
            # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
            if self.audio_cache.shape[0] > 0:
                self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)

        if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
            return text_dict

        total_chunk_num = int((len(self.audio_cache) - 1) / self.chunk_partial_size)

        if is_end:
            # If the audio data is the end of a sentence, add one more chunk
            # at the end so the end of the sentence is recognized correctly.
            auto_det_end = True

        if auto_det_end:
            total_chunk_num += 1

        # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
        end_idx = None
        for i in range(total_chunk_num):
            if auto_det_end:
                is_end = i == total_chunk_num - 1
            start_idx = i * self.chunk_partial_size
            if auto_det_end:
                end_idx = (i + 1) * self.chunk_partial_size if i < total_chunk_num - 1 else -1
            else:
                end_idx = (i + 1) * self.chunk_partial_size if i < total_chunk_num else -1
            # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
            # t_stamp = time.time()

            speech_chunk = self.audio_cache[start_idx:end_idx]

            # TODO: exception handling
            try:
                res = self.asr_model.generate(input=speech_chunk,
                                              cache=self.asr_cache,
                                              is_final=is_end,
                                              chunk_size=self.chunk_size,
                                              encoder_chunk_look_back=self.encoder_chunk_look_back,
                                              decoder_chunk_look_back=self.decoder_chunk_look_back)
            except ValueError as e:
                print(f"ValueError: {e}")
                continue

            # Add punctuation to the partial result.
            if self.use_punct:
                text_dict['text'].append(self.punctuation_model.process(
                    self.text_postprecess(res[0], data_id='text'), cache=text_dict))
            else:
                text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))

            # print(f"each chunk time: {time.time()-t_stamp}")

        if is_end:
            # Sentence finished: reset the audio and ASR caches.
            self.audio_cache = None
            self.asr_cache = {}
        else:
            if end_idx:
                # Cut the processed part from audio_cache.
                self.audio_cache = self.audio_cache[end_idx:]
        text_dict['is_end'] = is_end

        if self.use_punct and is_end:
            # Flush the punctuation model with a sentinel at sentence end.
            text_dict['text'].append(self.punctuation_model.process('#', cache=text_dict).replace('#', ''))

        # print(f"text_dict: {text_dict}")
        return text_dict
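
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only): the file paths below are
# hypothetical placeholders, and raw 16 kHz 16-bit mono PCM input is assumed,
# matching the RATE and chunk_ms passed to the base recognizer above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    recognizer = ModifiedRecognizer(use_punct=True,
                                    use_emotion=False,
                                    use_speaker_ver=True)

    # Register the target speaker once; a saved .npy embedding also works.
    recognizer.initialize_speaker("target_speaker.wav")  # hypothetical path

    # Non-streaming recognition on a whole utterance.
    with open("utterance.pcm", "rb") as f:  # hypothetical raw PCM file
        audio_bytes = f.read()
    print(recognizer.recognize(audio_bytes))

    # Streaming recognition: feed fixed-size chunks, flag the last one.
    chunk_bytes = 16000 * 2 * 480 // 1000  # 480 ms of 16 kHz 16-bit mono
    chunks = [audio_bytes[i:i + chunk_bytes]
              for i in range(0, len(audio_bytes), chunk_bytes)]
    for i, chunk in enumerate(chunks):
        result = recognizer.streaming_recognize(chunk, is_end=(i == len(chunks) - 1))
        if result == "Other People":
            break
        print(result["text"], result["is_end"])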