utils for punctuation and emotion and speaker ver
This commit is contained in: commit 31017e1ec4

@@ -0,0 +1 @@
from .base_stt import *
Binary file not shown.
@@ -0,0 +1,65 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import wave
import io
import os
import logging
from ..common_utils import decode_str2bytes


class STTBase:
    def __init__(self, RATE=16000, cfg_path=None, debug=False):
        self.RATE = RATE
        self.debug = debug
        self.asr_cfg = self.parse_json(cfg_path)

    def parse_json(self, cfg_path):
        cfg = None
        self.hotwords = None
        if cfg_path is not None:
            with open(cfg_path, 'r', encoding='utf-8') as f:
                cfg = json.load(f)
            self.hotwords = cfg.get('hot_words', None)
            logging.info(f"load STT config file: {cfg_path}")
            logging.info(f"Hot words: {self.hotwords}")
        else:
            logging.warning("No STT config file provided, using default config.")
        return cfg

    def add_hotword(self, hotword):
        """Add a hotword (or a list of hotwords) to the hotword string."""
        if self.hotwords is None:
            self.hotwords = ""
        if isinstance(hotword, str):
            self.hotwords = self.hotwords + " " + hotword
        elif isinstance(hotword, (list, tuple)):
            # join the hotwords into a single space-separated string
            self.hotwords = self.hotwords + " " + " ".join(hotword)
        else:
            raise TypeError("hotword must be str or list")

    def check_audio_type(self, audio_data):
        """Check the audio data type and convert it to bytes if necessary."""
        if isinstance(audio_data, bytes):
            pass
        elif isinstance(audio_data, list):
            audio_data = b''.join(audio_data)
        elif isinstance(audio_data, str):
            audio_data = decode_str2bytes(audio_data)
        elif isinstance(audio_data, io.BytesIO):
            wf = wave.open(audio_data, 'rb')
            audio_data = wf.readframes(wf.getnframes())
        else:
            raise TypeError(f"audio_data must be bytes, str or io.BytesIO, but got {type(audio_data)}")
        return audio_data

    def text_postprecess(self, result, data_id='text'):
        """Postprocess the recognized result."""
        text = result[data_id]
        if isinstance(text, list):
            text = ''.join(text)
        return text.replace(' ', '')

    def recognize(self, audio_data, queue=None):
        """Recognize audio data to text."""
        pass
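The only key parse_json reads from the config file is hot_words, a space-separated string of hotwords. A minimal sketch of such a config follows; the filename is illustrative, and the hotword is the one used in the vosk demo at the end of this commit:

# minimal sketch: write and load an STT config with the single field STTBase reads
import json

cfg = {"hot_words": "你好"}                               # space-separated hotword string
with open("stt_cfg.json", "w", encoding="utf-8") as f:    # illustrative path
    json.dump(cfg, f, ensure_ascii=False)

# stt = STTBase(cfg_path="stt_cfg.json")   # then stt.hotwords == "你好"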
@@ -0,0 +1,142 @@
import io
import numpy as np
import base64
import wave
from funasr import AutoModel
import time

"""
Base model:
cannot do emotion classification, only feature extraction.
"""
FUNASRBASE = {
    "model_type": "funasr",
    "model_path": "iic/emotion2vec_base",
    "model_revision": "v2.0.4"
}


"""
Finetuned model:
outputs classification results.
"""
FUNASRFINETUNE = {
    "model_type": "funasr",
    "model_path": "iic/emotion2vec_base_finetuned"
}


def decode_str2bytes(data):
    # decode a base64-encoded string into raw bytes
    if data is None:
        return None
    return base64.b64decode(data.encode('utf-8'))


class Emotion:
    def __init__(self,
                 model_type="funasr",
                 model_path="iic/emotion2vec_base",
                 device="cuda",
                 model_revision="v2.0.4",
                 **kwargs):

        self.model_type = model_type
        self.initialize(model_type, model_path, device, model_revision, **kwargs)

    # initialize the model
    def initialize(self,
                   model_type,
                   model_path,
                   device,
                   model_revision,
                   **kwargs):
        if model_type == "funasr":
            self.emotion_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
        else:
            raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr] expected.")

    # check the input type
    def check_audio_type(self, audio_data):
        """check audio data type and convert it to a float32 numpy array if necessary."""
        if isinstance(audio_data, bytes):
            pass
        elif isinstance(audio_data, list):
            audio_data = b''.join(audio_data)
        elif isinstance(audio_data, str):
            audio_data = decode_str2bytes(audio_data)
        elif isinstance(audio_data, io.BytesIO):
            wf = wave.open(audio_data, 'rb')
            audio_data = wf.readframes(wf.getnframes())
        elif isinstance(audio_data, np.ndarray):
            pass
        else:
            raise TypeError(f"audio_data must be bytes, list, str, io.BytesIO or numpy array, but got {type(audio_data)}")

        if isinstance(audio_data, bytes):
            audio_data = np.frombuffer(audio_data, dtype=np.int16)
        elif isinstance(audio_data, np.ndarray):
            if audio_data.dtype != np.int16:
                audio_data = audio_data.astype(np.int16)
        else:
            raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")

        # the model input must be float32
        if isinstance(audio_data, np.ndarray):
            audio_data = audio_data.astype(np.float32)
        else:
            raise TypeError(f"audio_data must be numpy array, but got {type(audio_data)}")
        return audio_data

    def process(self,
                audio_data,
                granularity="utterance",
                extract_embedding=False,
                output_dir=None,
                only_score=True):
        """
        audio_data: only float32 is expected because of the layer norm
        extract_embedding: save the embedding if True
        output_dir: save path for the embedding
        only_score: only return labels & scores if True
        """
        audio_data = self.check_audio_type(audio_data)
        if self.model_type == 'funasr':
            result = self.emotion_model.generate(audio_data, output_dir=output_dir, granularity=granularity, extract_embedding=extract_embedding)
        else:
            pass

        # keep only labels and scores
        if only_score:
            maintain_key = ["labels", "scores"]
            for res in result:
                keys_to_remove = [k for k in res.keys() if k not in maintain_key]
                for k in keys_to_remove:
                    res.pop(k)
        return result[0]


# only for test
def load_audio_file(wav_file):
    with wave.open(wav_file, 'rb') as wf:
        params = wf.getparams()
        frames = wf.readframes(params.nframes)
    print("Audio file loaded.")
    # Audio parameters
    # print("Channels:", params.nchannels)
    # print("Sample width:", params.sampwidth)
    # print("Frame rate:", params.framerate)
    # print("Number of frames:", params.nframes)
    # print("Compression type:", params.comptype)
    return frames


if __name__ == "__main__":
    inputs = r".\example\test.wav"
    inputs = load_audio_file(inputs)
    device = "cuda"
    # FUNASRBASE.update({"device": device})
    FUNASRFINETUNE.update({"device": device})
    emotion_model = Emotion(**FUNASRFINETUNE)
    s = time.time()
    result = emotion_model.process(inputs)
    t = time.time()
    print(t - s)
    print(result)
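Callers may also hand in audio as a base64 string; decode_str2bytes above turns it back into PCM bytes, which check_audio_type then converts to the float32 array that process expects. A minimal round-trip sketch, using a synthetic buffer:

# minimal sketch of the str -> bytes -> float32 path handled by check_audio_type
import base64
import numpy as np

pcm = np.zeros(16000, dtype=np.int16)                     # one second of silence at 16 kHz
encoded = base64.b64encode(pcm.tobytes()).decode('utf-8')

raw = decode_str2bytes(encoded)                           # bytes, as decoded in this module
samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32)
print(samples.shape)                                      # (16000,)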
@@ -0,0 +1,186 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ####################################################### #
# FunAutoSpeechRecognizer: https://github.com/alibaba-damo-academy/FunASR
# ####################################################### #
import io
import wave
import time
import numpy as np
from takway.common_utils import decode_str2bytes
from funasr import AutoModel

from takway.stt.base_stt import STTBase


class FunAutoSpeechRecognizer(STTBase):
    def __init__(self,
                 model_path="paraformer-zh-streaming",
                 device="cuda",
                 RATE=16000,
                 cfg_path=None,
                 debug=False,
                 chunk_ms=480,
                 encoder_chunk_look_back=4,
                 decoder_chunk_look_back=1,
                 **kwargs):
        super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)

        self.asr_model = AutoModel(model=model_path, device=device, **kwargs)

        self.encoder_chunk_look_back = encoder_chunk_look_back  # number of chunks to look back for encoder self-attention
        self.decoder_chunk_look_back = decoder_chunk_look_back  # number of encoder chunks to look back for decoder cross-attention

        # [0, 8, 4] -> 480 ms, [0, 10, 5] -> 600 ms
        if chunk_ms == 480:
            self.chunk_size = [0, 8, 4]
        elif chunk_ms == 600:
            self.chunk_size = [0, 10, 5]
        else:
            raise ValueError("`chunk_ms` should be 480 or 600, and type is int.")
        # samples per streaming step: chunk_size[1] frames x 960 samples (60 ms at 16 kHz) per frame
        self.chunk_partial_size = self.chunk_size[1] * 960
        self.audio_cache = None
        self.asr_cache = {}

        self._init_asr()

    def check_audio_type(self, audio_data):
        """check audio data type and convert it to an int16 numpy array if necessary."""
        if isinstance(audio_data, bytes):
            pass
        elif isinstance(audio_data, list):
            audio_data = b''.join(audio_data)
        elif isinstance(audio_data, str):
            audio_data = decode_str2bytes(audio_data)
        elif isinstance(audio_data, io.BytesIO):
            wf = wave.open(audio_data, 'rb')
            audio_data = wf.readframes(wf.getnframes())
        elif isinstance(audio_data, np.ndarray):
            pass
        else:
            raise TypeError(f"audio_data must be bytes, list, str, io.BytesIO or numpy array, but got {type(audio_data)}")

        if isinstance(audio_data, bytes):
            audio_data = np.frombuffer(audio_data, dtype=np.int16)
        elif isinstance(audio_data, np.ndarray):
            if audio_data.dtype != np.int16:
                audio_data = audio_data.astype(np.int16)
        else:
            raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
        return audio_data

    def _init_asr(self):
        # warm up the streaming model with a random segment of audio data
        init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
        self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
        self.audio_cache = None
        self.asr_cache = {}
        print("init ASR model done.")

    def recognize(self, audio_data):
        """recognize audio data to text"""
        audio_data = self.check_audio_type(audio_data)
        result = self.asr_model.generate(input=audio_data,
                                         batch_size_s=300,
                                         hotword=self.hotwords)

        # print(result)
        text = ''
        for res in result:
            text += res['text']
        return text

    def streaming_recognize(self,
                            audio_data,
                            is_end=False,
                            auto_det_end=False):
        """recognize partial result

        Args:
            audio_data: bytes or numpy array, partial audio data
            is_end: bool, whether the audio data is the end of a sentence
            auto_det_end: bool, whether to automatically detect the end of the audio data
        """
        text_dict = dict(text=[], is_end=is_end)

        audio_data = self.check_audio_type(audio_data)
        if self.audio_cache is None:
            self.audio_cache = audio_data
        else:
            # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
            if self.audio_cache.shape[0] > 0:
                self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)

        if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
            return text_dict

        total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)

        if is_end:
            # if the audio data is the end of a sentence,
            # we need to add one more chunk to the end to
            # ensure the end of the sentence is recognized correctly.
            auto_det_end = True

        if auto_det_end:
            total_chunk_num += 1

        # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
        end_idx = None
        for i in range(total_chunk_num):
            if auto_det_end:
                is_end = i == total_chunk_num - 1
            start_idx = i*self.chunk_partial_size
            if auto_det_end:
                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
            else:
                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
            # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
            # t_stamp = time.time()

            speech_chunk = self.audio_cache[start_idx:end_idx]

            # TODO: exception handling
            try:
                res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
            except ValueError as e:
                print(f"ValueError: {e}")
                continue
            text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
            # print(f"each chunk time: {time.time()-t_stamp}")

        if is_end:
            self.audio_cache = None
            self.asr_cache = {}
        else:
            if end_idx:
                self.audio_cache = self.audio_cache[end_idx:]  # cut the processed part from audio_cache
        text_dict['is_end'] = is_end

        # print(f"text_dict: {text_dict}")
        return text_dict


if __name__ == '__main__':
    from takway.audio_utils import BaseAudio
    rec = BaseAudio(input=True, CHUNK=3840)

    # return_type = 'bytes'
    file_path = 'my_recording.wav'
    data = rec.load_audio_file(file_path)

    asr = FunAutoSpeechRecognizer()

    # asr.recognize(data)
    total_chunk_num = int((len(data)-1)/rec.CHUNK+1)
    print(f"total_chunk_num: {total_chunk_num}")
    for i in range(total_chunk_num):
        is_end = i == total_chunk_num - 1
        speech_chunk = data[i*rec.CHUNK:(i+1)*rec.CHUNK]
        text_dict = asr.streaming_recognize(speech_chunk, is_end)
    '''
    asr.streaming_recognize(data, auto_det_end=True)
    '''
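A quick check of the streaming step size configured above; the 960-sample frame length is taken from the chunk_partial_size line:

# chunk_size = [0, 8, 4]: 8 frames of 960 samples are consumed per streaming call
RATE = 16000                                 # Hz
frame_samples = 960                          # 960 / 16000 = 60 ms per frame
print(8 * frame_samples / RATE * 1000)       # 480.0 ms -> chunk_ms=480
print(10 * frame_samples / RATE * 1000)      # 600.0 ms -> chunk_ms=600, chunk_size=[0, 10, 5]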
@@ -0,0 +1,168 @@
from takway.stt.funasr_utils import FunAutoSpeechRecognizer
from takway.stt.punctuation_utils import CTTRANSFORMER, Punctuation
from takway.stt.emotion_utils import FUNASRFINETUNE, Emotion
from takway.stt.speaker_ver_utils import ERES2NETV2, DEFALUT_SAVE_PATH, speaker_verfication
import os
import pdb
import numpy as np


class ModifiedRecognizer(FunAutoSpeechRecognizer):
    def __init__(self,
                 use_punct=True,
                 use_emotion=False,
                 use_speaker_ver=True):
        super().__init__(
            model_path="paraformer-zh-streaming",
            device="cuda",
            RATE=16000,
            cfg_path=None,
            debug=False,
            chunk_ms=480,
            encoder_chunk_look_back=4,
            decoder_chunk_look_back=1)
        self.use_punct = use_punct
        self.use_emotion = use_emotion
        self.use_speaker_ver = use_speaker_ver

        if use_punct:
            self.puctuation_model = Punctuation(**CTTRANSFORMER)
        if use_emotion:
            self.emotion_model = Emotion(**FUNASRFINETUNE)
        if use_speaker_ver:
            self.speaker_ver_model = speaker_verfication(**ERES2NETV2)

    def initialize_speaker(self, speaker_1_wav):
        """Register the reference speaker from a .wav file or a saved .npy embedding."""
        if not self.use_speaker_ver:
            raise NotImplementedError("no access")
        if speaker_1_wav.endswith(".npy"):
            self.save_speaker_path = speaker_1_wav
        elif speaker_1_wav.endswith('.wav'):
            self.save_speaker_path = os.path.join(DEFALUT_SAVE_PATH,
                                                  os.path.basename(speaker_1_wav).replace(".wav", ".npy"))
            # self.save_speaker_path = DEFALUT_SAVE_PATH
            self.speaker_ver_model.wav2embeddings(speaker_1_wav, self.save_speaker_path)
        else:
            raise TypeError("only support [.npy] or [.wav].")

    def speaker_ver(self, speaker_2_wav):
        """Return True if speaker_2_wav matches the registered reference speaker."""
        if not self.use_speaker_ver:
            raise NotImplementedError("no access")
        if not hasattr(self, "save_speaker_path"):
            raise NotImplementedError("please initialize speaker first")
        # pdb.set_trace()
        return self.speaker_ver_model.verfication(base_emb=self.save_speaker_path,
                                                  speaker_2_wav=speaker_2_wav) == 'yes'

    def recognize(self, audio_data):
        audio_data = self.check_audio_type(audio_data)

        if self.use_speaker_ver:
            if self.speaker_ver_model.verfication(self.save_speaker_path,
                                                  speaker_2_wav=audio_data) == 'no':
                return "Other People"

        result = self.asr_model.generate(input=audio_data,
                                         batch_size_s=300,
                                         hotword=self.hotwords)
        text = ''
        for res in result:
            text += res['text']
        if self.use_punct:
            text = self.puctuation_model.process(text+'#', append_period=False).replace('#', '')

        return text

    def recognize_emotion(self, audio_data):
        audio_data = self.check_audio_type(audio_data)

        if self.use_speaker_ver:
            if self.speaker_ver_model.verfication(self.save_speaker_path,
                                                  speaker_2_wav=audio_data) == 'no':
                return "Other People"

        if self.use_emotion:
            return self.emotion_model.process(audio_data)
        else:
            raise NotImplementedError("no access")

    def streaming_recognize(self, audio_data, is_end=False, auto_det_end=False):
        """recognize partial result

        Args:
            audio_data: bytes or numpy array, partial audio data
            is_end: bool, whether the audio data is the end of a sentence
            auto_det_end: bool, whether to automatically detect the end of the audio data
        """
        audio_data = self.check_audio_type(audio_data)

        if self.use_speaker_ver:
            if self.speaker_ver_model.verfication(self.save_speaker_path,
                                                  speaker_2_wav=audio_data) == 'no':
                return "Other People"

        text_dict = dict(text=[], is_end=is_end)

        if self.audio_cache is None:
            self.audio_cache = audio_data
        else:
            # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
            if self.audio_cache.shape[0] > 0:
                self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)

        if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
            return text_dict

        total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)

        if is_end:
            # if the audio data is the end of a sentence,
            # we need to add one more chunk to the end to
            # ensure the end of the sentence is recognized correctly.
            auto_det_end = True

        if auto_det_end:
            total_chunk_num += 1

        # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
        end_idx = None
        for i in range(total_chunk_num):
            if auto_det_end:
                is_end = i == total_chunk_num - 1
            start_idx = i*self.chunk_partial_size
            if auto_det_end:
                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
            else:
                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
            # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
            # t_stamp = time.time()

            speech_chunk = self.audio_cache[start_idx:end_idx]

            # TODO: exception handling
            try:
                res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
            except ValueError as e:
                print(f"ValueError: {e}")
                continue
            if self.use_punct:
                text_dict['text'].append(self.puctuation_model.process(self.text_postprecess(res[0], data_id='text'), cache=text_dict))
            else:
                text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))

            # print(f"each chunk time: {time.time()-t_stamp}")

        if is_end:
            self.audio_cache = None
            self.asr_cache = {}
        else:
            if end_idx:
                self.audio_cache = self.audio_cache[end_idx:]  # cut the processed part from audio_cache
        text_dict['is_end'] = is_end

        if self.use_punct and is_end:
            text_dict['text'].append(self.puctuation_model.process('#', cache=text_dict).replace('#', ''))

        # print(f"text_dict: {text_dict}")
        return text_dict
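Unlike the other modules, this file has no __main__ demo. A minimal usage sketch follows; the module path and the two wav filenames are assumptions, not part of the commit:

# minimal sketch; module path and wav filenames are placeholders
import wave
from takway.stt.modified_funasr import ModifiedRecognizer   # assumed module name

rec = ModifiedRecognizer(use_punct=True, use_emotion=True, use_speaker_ver=True)
rec.initialize_speaker("speaker1_reference.wav")   # builds and saves the reference embedding as .npy

with wave.open("query.wav", "rb") as wf:           # 16 kHz, 16-bit mono assumed
    audio = wf.readframes(wf.getnframes())

print(rec.recognize(audio))            # punctuated text, or "Other People" if the speaker differs
print(rec.recognize_emotion(audio))    # labels/scores dict from the finetuned emotion2vec model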
@@ -0,0 +1,119 @@
from funasr import AutoModel
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

PUNCTUATION_MARK = [",", ".", "?", "!", ",", "。", "?", "!"]
"""
FUNASR
model size: 1 GB
quality: good
input type: only a single string is supported; a list is treated as independent strings
"""
FUNASR = {
    "model_type": "funasr",
    "model_path": "ct-punc",
    "model_revision": "v2.0.4"
}
"""
CTTRANSFORMER
model size: 275 MB
quality: worse
input type: both string and list are supported, and a cache can be passed in
"""
CTTRANSFORMER = {
    "model_type": "ct-transformer",
    "model_path": "iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
    "model_revision": "v2.0.4"
}


class Punctuation:
    def __init__(self,
                 model_type="funasr",  # funasr | ct-transformer
                 model_path="ct-punc",
                 device="cuda",
                 model_revision="v2.0.4",
                 **kwargs):

        self.model_type = model_type
        self.initialize(model_type, model_path, device, model_revision, **kwargs)

    def initialize(self,
                   model_type,
                   model_path,
                   device,
                   model_revision,
                   **kwargs):
        if model_type == 'funasr':
            self.punc_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
        elif model_type == 'ct-transformer':
            self.punc_model = pipeline(task=Tasks.punctuation, model=model_path, model_revision=model_revision, **kwargs)
        else:
            raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr|ct-transformer] expected.")

    def check_text_type(self,
                        text_data):
        # funasr only supports a single str input, not a list, so a list is joined into one string here
        if self.model_type == 'funasr':
            if isinstance(text_data, str):
                pass
            elif isinstance(text_data, list):
                text_data = ''.join(text_data)
            else:
                raise TypeError(f"text must be str or list, but got {type(text_data)}")
        # ct-transformer supports list input
        # TODO: verify whether splitting the string improves efficiency
        elif self.model_type == 'ct-transformer':
            if isinstance(text_data, str):
                text_data = [text_data]
            elif isinstance(text_data, list):
                pass
            else:
                raise TypeError(f"text must be str or list, but got {type(text_data)}")
        else:
            pass
        return text_data

    def generate_cache(self, cache):
        new_cache = {'pre_text': ""}
        for text in cache['text']:
            if text != '':
                new_cache['pre_text'] = new_cache['pre_text']+text
        return new_cache

    def process(self,
                text,
                append_period=False,
                cache={}):
        if text == '':
            return ''
        text = self.check_text_type(text)
        if self.model_type == 'funasr':
            result = self.punc_model.generate(text)
        elif self.model_type == 'ct-transformer':
            if cache != {}:
                cache = self.generate_cache(cache)
            result = self.punc_model(text, cache=cache)
        punced_text = ''
        for res in result:
            punced_text += res['text']
        # if the text does not end with a punctuation mark, append one manually
        if append_period and not punced_text[-1] in PUNCTUATION_MARK:
            punced_text += "。"
        return punced_text


if __name__ == "__main__":
    inputs = "把字符串拆分为list只|适用于ct-transformer模型|在数据处理部分|已经把list转为单个字符串"
    """
    Splitting the string into a list only applies to the ct-transformer model;
    in the data-processing step the list has already been converted to a single string.
    """
    vads = inputs.split("|")
    device = "cuda"
    CTTRANSFORMER.update({"device": device})
    puct_model = Punctuation(**CTTRANSFORMER)
    result = puct_model.process(vads)
    print(result)
    # FUNASR.update({"device": "cuda"})
    # puct_model = Punctuation(**FUNASR)
    # result = puct_model.process(vads)
    # print(result)
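For streaming use, process takes the running text_dict as its cache; generate_cache concatenates the already-punctuated pieces into the pre_text context for the ct-transformer pipeline. A minimal sketch of that flow, with placeholder chunks:

# minimal sketch of incremental punctuation with a running cache
punct = Punctuation(**CTTRANSFORMER)

text_dict = {'text': [], 'is_end': False}
for piece in ["今天天气", "怎么样"]:          # placeholder ASR chunks
    text_dict['text'].append(punct.process(piece, cache=text_dict))
# flush with a trailing '#', as ModifiedRecognizer does at the end of a sentence
text_dict['text'].append(punct.process('#', cache=text_dict).replace('#', ''))
print(''.join(text_dict['text']))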
@@ -0,0 +1,86 @@
from modelscope.pipelines import pipeline
import numpy as np
import os
import pdb
ERES2NETV2 = {
    "task": 'speaker-verification',
    "model_name": 'damo/speech_eres2netv2_sv_zh-cn_16k-common',
    "model_revision": 'v1.0.1',
    "save_embeddings": False
}

# path where speaker embeddings are saved
DEFALUT_SAVE_PATH = r"D:\python\irving\takway_base-main\examples"

class speaker_verfication:
    def __init__(self,
                 task='speaker-verification',
                 model_name='damo/speech_eres2netv2_sv_zh-cn_16k-common',
                 model_revision='v1.0.1',
                 device="cuda",
                 save_embeddings=False):
        self.pipeline = pipeline(
            task=task,
            model=model_name,
            model_revision=model_revision,
            device=device)
        self.save_embeddings = save_embeddings

    def wav2embeddings(self, speaker_1_wav, save_path=None):
        result = self.pipeline([speaker_1_wav], output_emb=True)
        speaker_1_emb = result['embs'][0]
        if save_path is not None:
            np.save(save_path, speaker_1_emb)
        return speaker_1_emb

    def _verifaction(self, speaker_1_wav, speaker_2_wav, threshold, save_path):
        if not self.save_embeddings:
            result = self.pipeline([speaker_1_wav, speaker_2_wav], thr=threshold)
            return result["text"]
        else:
            result = self.pipeline([speaker_1_wav, speaker_2_wav], thr=threshold, output_emb=True)
            speaker1_emb = result["embs"][0]
            speaker2_emb = result["embs"][1]
            np.save(os.path.join(save_path, "speaker_1.npy"), speaker1_emb)
            return result['outputs']["text"]

    def _verifaction_from_embedding(self, base_emb, speaker_2_wav, threshold):
        base_emb = np.load(base_emb)
        result = self.pipeline([speaker_2_wav], output_emb=True)
        speaker2_emb = result["embs"][0]
        # cosine similarity between the stored embedding and the new utterance
        similarity = np.dot(base_emb, speaker2_emb) / (np.linalg.norm(base_emb) * np.linalg.norm(speaker2_emb))
        if similarity > threshold:
            return "yes"
        else:
            return "no"

    def verfication(self,
                    base_emb=None,
                    speaker_1_wav=None,
                    speaker_2_wav=None,
                    threshold=0.333,
                    save_path=None):
        if base_emb is not None and speaker_1_wav is not None:
            raise ValueError("Only need one of them, base_emb or speaker_1_wav")
        if base_emb is not None and speaker_2_wav is not None:
            return self._verifaction_from_embedding(base_emb, speaker_2_wav, threshold)
        elif speaker_1_wav is not None and speaker_2_wav is not None:
            return self._verifaction(speaker_1_wav, speaker_2_wav, threshold, save_path)
        else:
            raise NotImplementedError


if __name__ == '__main__':
    verifier = speaker_verfication(**ERES2NETV2)

    verifier = speaker_verfication(save_embeddings=False)
    result = verifier.verfication(base_emb=None, speaker_1_wav=r"C:\Users\bing\Downloads\speaker1_a_cn_16k.wav",
                                  speaker_2_wav=r"C:\Users\bing\Downloads\speaker2_a_cn_16k.wav",
                                  threshold=0.333,
                                  save_path=r"D:\python\irving\takway_base-main\savePath"
                                  )
    print("---")
    print(result)
    print(verifier.verfication(r"D:\python\irving\takway_base-main\savePath\speaker_1.npy",
                               speaker_2_wav=r"C:\Users\bing\Downloads\speaker1_b_cn_16k.wav",
                               threshold=0.333,
                               ))
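The embedding path above is a plain cosine similarity compared against the fixed 0.333 threshold; a toy check of that formula, with made-up vectors:

import numpy as np

a = np.array([1.0, 0.0, 1.0])      # made-up embeddings, just to exercise the formula
b = np.array([1.0, 1.0, 0.0])
cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
print(cos)                          # 0.5 -> "yes" under the 0.333 threshold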
@@ -0,0 +1,120 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ####################################################### #
# VOSKAutoSpeechRecognizer
# ####################################################### #
import json
import wave
import io
import os
from vosk import Model, KaldiRecognizer, SetLogLevel
from .base_stt import STTBase
from ..common_utils import decode_str2bytes


class VOSKAutoSpeechRecognizer(STTBase):
    def __init__(self, model_path="vosk-model-small-cn-0.22", RATE=16000, cfg_path=None, efficent_mode=True, debug=False):
        super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
        # build the vosk recognizer used by recognize/partial_recognize below
        self.asr = KaldiRecognizer(Model(model_path=model_path), RATE)
        # self.apply_asr_config(self.asr_cfg)  # no such helper in this commit; config is parsed by STTBase

    def recognize_keywords(self, audio_data, partial_size=None, queue=None):
        """recognize keywords in audio data"""
        audio_data = self.check_audio_type(audio_data)
        if partial_size is None:
            rec_result = self.recognize(audio_data, queue)
            rec_text = self.text_postprecess(rec_result)
        else:
            rec_result = self.partial_recognize(audio_data, partial_size, queue)
            rec_text = self.text_postprecess(rec_result, 'partial')
            print(f"rec_text: {rec_text}")
        if rec_text != '':
            print(f"rec_text: {rec_text}")
            # hotwords are stored by STTBase as a space-separated string
            keywords = self.hotwords.split() if self.hotwords else []
            if any(keyword in rec_text for keyword in keywords):
                print("Keyword detected.")
                return True, rec_text
        return False, None

    def recognize(self, audio_data, queue=None):
        """recognize audio data to text"""
        audio_data = self.check_audio_type(audio_data)
        self.asr.AcceptWaveform(audio_data)
        result = json.loads(self.asr.FinalResult())
        # TODO: put result to queue
        return result

    def partial_recognize(self, audio_data, partial_size=1024, queue=None):
        """recognize partial result"""
        audio_data = self.check_audio_type(audio_data)
        text_dict = dict(
            text=[],
            partial=[],
            final=[],
            is_end=False)
        # feed the audio data to the recognizer piece by piece
        for i in range(0, len(audio_data), partial_size):
            # print(f"partial data: {i} - {i+partial_size}")
            data = audio_data[i:i+partial_size]
            if len(data) == 0:
                break
            if self.asr.AcceptWaveform(data):
                result = json.loads(self.asr.Result())
                if result['text'] != '':
                    text_dict['text'].append(result['text'])
                    if queue is not None:
                        queue.put(('stt_info', text_dict))
                    # print(f"text result: {result}")
            else:
                result = json.loads(self.asr.PartialResult())
                if result['partial'] != '':
                    # text_dict['partial'].append(result['partial'])
                    text_dict['partial'] = [result['partial']]
                    if queue is not None:
                        queue.put(('stt_info', text_dict))
                    # print(f"partial result: {result}")

        # final recognize
        final_result = json.loads(self.asr.FinalResult())
        if final_result['text'] != '':
            text_dict['final'].append(final_result['text'])
            text_dict['text'].append(final_result['text'])

        text_dict['is_end'] = True

        print(f"final dict: {text_dict}")
        if queue is not None:
            queue.put(('stt_info', text_dict))
        return text_dict


if __name__ == "__main__":
    '''
    wav_file_path = "recording.wav"

    # You can set log level to -1 to disable debug messages
    SetLogLevel(0)

    model = Model(model_path="vosk-model-small-cn-0.22")

    # record audio
    # record_audio(wav_file_path)
    data = record_audio()

    # transcribe the audio
    result = audio_to_text(data, model)

    print("-------------")
    print(result)
    '''
    from takway.audio_utils import Recorder
    rec = Recorder()

    return_type = 'bytes'
    data = rec.record(return_type)
    print(type(data))

    asr = VOSKAutoSpeechRecognizer()
    # asr.recognize(data)
    asr.add_hotword("你好")
    asr.recognize_keywords(data)