commit 31017e1ec483e7d83c0d42070d0585cf6348e7fc
Author: bing <2524698668@qq.com>
Date:   Sat May 11 22:02:52 2024 +0800

    utils for punctuation and emotion and speaker ver

diff --git a/takway/stt/__init__.py b/takway/stt/__init__.py
new file mode 100644
index 0000000..413fa1f
--- /dev/null
+++ b/takway/stt/__init__.py
@@ -0,0 +1 @@
+from .base_stt import *
\ No newline at end of file
diff --git a/takway/stt/__pycache__/__init__.cpython-39.pyc b/takway/stt/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000..22e74bb
Binary files /dev/null and b/takway/stt/__pycache__/__init__.cpython-39.pyc differ
diff --git a/takway/stt/__pycache__/base_stt.cpython-39.pyc b/takway/stt/__pycache__/base_stt.cpython-39.pyc
new file mode 100644
index 0000000..ac43531
Binary files /dev/null and b/takway/stt/__pycache__/base_stt.cpython-39.pyc differ
diff --git a/takway/stt/__pycache__/emotion_utils.cpython-39.pyc b/takway/stt/__pycache__/emotion_utils.cpython-39.pyc
new file mode 100644
index 0000000..0e19cdb
Binary files /dev/null and b/takway/stt/__pycache__/emotion_utils.cpython-39.pyc differ
diff --git a/takway/stt/__pycache__/funasr_utils.cpython-39.pyc b/takway/stt/__pycache__/funasr_utils.cpython-39.pyc
new file mode 100644
index 0000000..de72d32
Binary files /dev/null and b/takway/stt/__pycache__/funasr_utils.cpython-39.pyc differ
diff --git a/takway/stt/__pycache__/modified_funasr.cpython-39.pyc b/takway/stt/__pycache__/modified_funasr.cpython-39.pyc
new file mode 100644
index 0000000..c6a7b29
Binary files /dev/null and b/takway/stt/__pycache__/modified_funasr.cpython-39.pyc differ
diff --git a/takway/stt/__pycache__/punctuation_utils.cpython-39.pyc b/takway/stt/__pycache__/punctuation_utils.cpython-39.pyc
new file mode 100644
index 0000000..1655d9e
Binary files /dev/null and b/takway/stt/__pycache__/punctuation_utils.cpython-39.pyc differ
diff --git a/takway/stt/__pycache__/speaker_ver_utils.cpython-39.pyc b/takway/stt/__pycache__/speaker_ver_utils.cpython-39.pyc
new file mode 100644
index 0000000..95b3df3
Binary files /dev/null and b/takway/stt/__pycache__/speaker_ver_utils.cpython-39.pyc differ
diff --git a/takway/stt/base_stt.py b/takway/stt/base_stt.py
new file mode 100644
index 0000000..4763446
--- /dev/null
+++ b/takway/stt/base_stt.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import json
+import wave
+import io
+import os
+import logging
+from ..common_utils import decode_str2bytes
+
+class STTBase:
+    def __init__(self, RATE=16000, cfg_path=None, debug=False):
+        self.RATE = RATE
+        self.debug = debug
+        self.asr_cfg = self.parse_json(cfg_path)
+
+    def parse_json(self, cfg_path):
+        cfg = None
+        self.hotwords = None
+        if cfg_path is not None:
+            with open(cfg_path, 'r', encoding='utf-8') as f:
+                cfg = json.load(f)
+            self.hotwords = cfg.get('hot_words', None)
+            logging.info(f"load STT config file: {cfg_path}")
+            logging.info(f"Hot words: {self.hotwords}")
+        else:
+            logging.warning("No STT config file provided, using default config.")
+        return cfg
+
+    def add_hotword(self, hotword):
+        """add hotword to list"""
+        if self.hotwords is None:
+            self.hotwords = ""
+        if isinstance(hotword, str):
+            self.hotwords = self.hotwords + " " + hotword
+        elif isinstance(hotword, (list, tuple)):
+            # join the hotwords into a single space-separated string
+            self.hotwords = self.hotwords + " " + " ".join(hotword)
+        else:
+            raise TypeError("hotword must be str or list")
+
+    def check_audio_type(self, audio_data):
+        """check audio data type and convert it to bytes if necessary."""
+        if isinstance(audio_data, bytes):
+            pass
+        elif isinstance(audio_data, list):
+            audio_data = b''.join(audio_data)
+        elif isinstance(audio_data, str):
+            audio_data = decode_str2bytes(audio_data)
+        elif isinstance(audio_data, io.BytesIO):
+            wf = wave.open(audio_data, 'rb')
+            audio_data = wf.readframes(wf.getnframes())
+        else:
+            raise TypeError(f"audio_data must be bytes, str or io.BytesIO, but got {type(audio_data)}")
+        return audio_data
+
+    def text_postprecess(self, result, data_id='text'):
+        """post-process the recognized result."""
+        text = result[data_id]
+        if isinstance(text, list):
+            text = ''.join(text)
+        return text.replace(' ', '')
+
+    def recognize(self, audio_data, queue=None):
+        """recognize audio data to text"""
+        pass
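Below is a minimal, hypothetical sketch of how the STTBase helpers behave (hotword accumulation and audio-type normalization). It is not part of the patch; the hotwords and byte strings are arbitrary placeholders.

# Minimal illustration of STTBase helpers; the values are arbitrary.
from takway.stt.base_stt import STTBase

stt = STTBase(RATE=16000, cfg_path=None)   # logs a warning: no config file provided
stt.add_hotword("takway")
stt.add_hotword(["你好", "小助手"])
print(stt.hotwords)                        # " takway 你好 小助手"

pcm = stt.check_audio_type([b"\x00\x01", b"\x02\x03"])  # list of bytes -> joined bytes
print(type(pcm), len(pcm))                 # <class 'bytes'> 4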
diff --git a/takway/stt/emotion_utils.py b/takway/stt/emotion_utils.py
new file mode 100644
index 0000000..8d38423
--- /dev/null
+++ b/takway/stt/emotion_utils.py
@@ -0,0 +1,142 @@
+import io
+import numpy as np
+import base64
+import wave
+from funasr import AutoModel
+import time
+"""
+    Base model
+    Cannot perform emotion classification; only usable for feature extraction.
+"""
+FUNASRBASE = {
+    "model_type": "funasr",
+    "model_path": "iic/emotion2vec_base",
+    "model_revision": "v2.0.4"
+}
+
+
+"""
+    Finetuned model
+    Outputs classification results.
+"""
+FUNASRFINETUNE = {
+    "model_type": "funasr",
+    "model_path": "iic/emotion2vec_base_finetuned"
+}
+
+def decode_str2bytes(data):
+    # decode a Base64-encoded string back into bytes
+    if data is None:
+        return None
+    return base64.b64decode(data.encode('utf-8'))
+
+class Emotion:
+    def __init__(self,
+                 model_type="funasr",
+                 model_path="iic/emotion2vec_base",
+                 device="cuda",
+                 model_revision="v2.0.4",
+                 **kwargs):
+
+        self.model_type = model_type
+        self.initialize(model_type, model_path, device, model_revision, **kwargs)
+
+    # initialize the model
+    def initialize(self,
+                   model_type,
+                   model_path,
+                   device,
+                   model_revision,
+                   **kwargs):
+        if model_type == "funasr":
+            self.emotion_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
+        else:
+            raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr] expected.")
+
+    # check the input type
+    def check_audio_type(self,
+                         audio_data):
+        """check audio data type and convert it to bytes if necessary."""
+        if isinstance(audio_data, bytes):
+            pass
+        elif isinstance(audio_data, list):
+            audio_data = b''.join(audio_data)
+        elif isinstance(audio_data, str):
+            audio_data = decode_str2bytes(audio_data)
+        elif isinstance(audio_data, io.BytesIO):
+            wf = wave.open(audio_data, 'rb')
+            audio_data = wf.readframes(wf.getnframes())
+        elif isinstance(audio_data, np.ndarray):
+            pass
+        else:
+            raise TypeError(f"audio_data must be bytes, list, str, \
+                            io.BytesIO or numpy array, but got {type(audio_data)}")
+
+        if isinstance(audio_data, bytes):
+            audio_data = np.frombuffer(audio_data, dtype=np.int16)
+        elif isinstance(audio_data, np.ndarray):
+            if audio_data.dtype != np.int16:
+                audio_data = audio_data.astype(np.int16)
+        else:
+            raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
+
+        # the model expects float32 input
+        if isinstance(audio_data, np.ndarray):
+            audio_data = audio_data.astype(np.float32)
+        else:
+            raise TypeError(f"audio_data must be numpy array, but got {type(audio_data)}")
+        return audio_data
+
+    def process(self,
+                audio_data,
+                granularity="utterance",
+                extract_embedding=False,
+                output_dir=None,
+                only_score=True):
+        """
+        audio_data: only float32 is expected because of the layer norm
+        extract_embedding: save the embedding if True
+        output_dir: save path for the embedding
+        only_score: only return labels & scores if True
+        """
+        audio_data = self.check_audio_type(audio_data)
+        if self.model_type == 'funasr':
+            result = self.emotion_model.generate(audio_data, output_dir=output_dir, granularity=granularity, extract_embedding=extract_embedding)
+        else:
+            pass
+
+        # keep only labels and scores
+        if only_score:
+            maintain_key = ["labels", "scores"]
+            for res in result:
+                keys_to_remove = [k for k in res.keys() if k not in maintain_key]
+                for k in keys_to_remove:
+                    res.pop(k)
+        return result[0]
+
+# only for test
+def load_audio_file(wav_file):
+    with wave.open(wav_file, 'rb') as wf:
+        params = wf.getparams()
+        frames = wf.readframes(params.nframes)
+        print("Audio file loaded.")
+        # Audio Parameters
+        # print("Channels:", params.nchannels)
+        # print("Sample width:", params.sampwidth)
+        # print("Frame rate:", params.framerate)
+        # print("Number of frames:", params.nframes)
+        # print("Compression type:", params.comptype)
+    return frames
+
+if __name__ == "__main__":
+    inputs = r".\example\test.wav"
+    inputs = load_audio_file(inputs)
+    device = "cuda"
+    # FUNASRBASE.update({"device": device})
+    FUNASRFINETUNE.update({"device": device})
+    emotion_model = Emotion(**FUNASRFINETUNE)
+    s = time.time()
+    result = emotion_model.process(inputs)
+    t = time.time()
+    print(t - s)
+    print(result)
\ No newline at end of file
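A minimal usage sketch for the Emotion wrapper added above, assuming the finetuned emotion2vec model can be fetched through ModelScope and that example/test.wav is a 16 kHz mono recording; the path and device choice are assumptions, not part of the patch.

# Hypothetical usage of Emotion; the wav path and device are assumptions.
from takway.stt.emotion_utils import Emotion, FUNASRFINETUNE, load_audio_file

FUNASRFINETUNE.update({"device": "cuda"})      # or "cpu" if no GPU is available
emotion_model = Emotion(**FUNASRFINETUNE)

frames = load_audio_file("example/test.wav")   # raw int16 PCM bytes
result = emotion_model.process(frames)         # {"labels": [...], "scores": [...]}
print(result["labels"], result["scores"])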
diff --git a/takway/stt/funasr_utils.py b/takway/stt/funasr_utils.py
new file mode 100644
index 0000000..92eb41b
--- /dev/null
+++ b/takway/stt/funasr_utils.py
@@ -0,0 +1,186 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# ####################################################### #
+# FunAutoSpeechRecognizer: https://github.com/alibaba-damo-academy/FunASR
+# ####################################################### #
+import io
+import wave
+import time
+import numpy as np
+from takway.common_utils import decode_str2bytes
+from funasr import AutoModel
+
+from takway.stt.base_stt import STTBase
+
+class FunAutoSpeechRecognizer(STTBase):
+    def __init__(self,
+                 model_path="paraformer-zh-streaming",
+                 device="cuda",
+                 RATE=16000,
+                 cfg_path=None,
+                 debug=False,
+                 chunk_ms=480,
+                 encoder_chunk_look_back=4,
+                 decoder_chunk_look_back=1,
+                 **kwargs):
+        super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
+
+        self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
+
+        self.encoder_chunk_look_back = encoder_chunk_look_back  # number of chunks to look back for encoder self-attention
+        self.decoder_chunk_look_back = decoder_chunk_look_back  # number of encoder chunks to look back for decoder cross-attention
+
+        # [0, 8, 4] is 480 ms, [0, 10, 5] is 600 ms
+        if chunk_ms == 480:
+            self.chunk_size = [0, 8, 4]
+        elif chunk_ms == 600:
+            self.chunk_size = [0, 10, 5]
+        else:
+            raise ValueError("`chunk_ms` should be 480 or 600, and type is int.")
+        self.chunk_partial_size = self.chunk_size[1] * 960
+        self.audio_cache = None
+        self.asr_cache = {}
+
+        self._init_asr()
+
+    def check_audio_type(self, audio_data):
+        """check audio data type and convert it to bytes if necessary."""
+        if isinstance(audio_data, bytes):
+            pass
+        elif isinstance(audio_data, list):
+            audio_data = b''.join(audio_data)
+        elif isinstance(audio_data, str):
+            audio_data = decode_str2bytes(audio_data)
+        elif isinstance(audio_data, io.BytesIO):
+            wf = wave.open(audio_data, 'rb')
+            audio_data = wf.readframes(wf.getnframes())
+        elif isinstance(audio_data, np.ndarray):
+            pass
+        else:
+            raise TypeError(f"audio_data must be bytes, list, str, \
+                            io.BytesIO or numpy array, but got {type(audio_data)}")
+
+        if isinstance(audio_data, bytes):
+            audio_data = np.frombuffer(audio_data, dtype=np.int16)
+        elif isinstance(audio_data, np.ndarray):
+            if audio_data.dtype != np.int16:
+                audio_data = audio_data.astype(np.int16)
+        else:
+            raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
+        return audio_data
+
+    def _init_asr(self):
+        # warm up the model with a chunk of random audio data
+        init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
+        self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
+        self.audio_cache = None
+        self.asr_cache = {}
+        print("init ASR model done.")
+
+    def recognize(self, audio_data):
+        """recognize audio data to text"""
+        audio_data = self.check_audio_type(audio_data)
+        result = self.asr_model.generate(input=audio_data,
+                                         batch_size_s=300,
+                                         hotword=self.hotwords)
+
+        # print(result)
+        text = ''
+        for res in result:
+            text += res['text']
+        return text
+
+    def streaming_recognize(self,
+                            audio_data,
+                            is_end=False,
+                            auto_det_end=False):
+        """recognize partial result
+
+        Args:
+            audio_data: bytes or numpy array, partial audio data
+            is_end: bool, whether the audio data is the end of a sentence
+            auto_det_end: bool, whether to automatically detect the end of the audio data
+        """
+        text_dict = dict(text=[], is_end=is_end)
+
+        audio_data = self.check_audio_type(audio_data)
+        if self.audio_cache is None:
+            self.audio_cache = audio_data
+        else:
+            # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
+            if self.audio_cache.shape[0] > 0:
+                self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)
+
+        if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
+            return text_dict
+
+        total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
+
+        if is_end:
+            # if the audio data is the end of a sentence, \
+            # we need to add one more chunk to the end to \
+            # ensure the end of the sentence is recognized correctly.
+            auto_det_end = True
+
+        if auto_det_end:
+            total_chunk_num += 1
+
+        # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
+        end_idx = None
+        for i in range(total_chunk_num):
+            if auto_det_end:
+                is_end = i == total_chunk_num - 1
+            start_idx = i*self.chunk_partial_size
+            if auto_det_end:
+                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
+            else:
+                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
+            # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
+            # t_stamp = time.time()
+
+            speech_chunk = self.audio_cache[start_idx:end_idx]
+
+            # TODO: exception handling
+            try:
+                res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
+            except ValueError as e:
+                print(f"ValueError: {e}")
+                continue
+            text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
+            # print(f"each chunk time: {time.time()-t_stamp}")
+
+        if is_end:
+            self.audio_cache = None
+            self.asr_cache = {}
+        else:
+            if end_idx:
+                self.audio_cache = self.audio_cache[end_idx:]  # cut the processed part from audio_cache
+        text_dict['is_end'] = is_end
+
+        # print(f"text_dict: {text_dict}")
+        return text_dict
+
+
+
+if __name__ == '__main__':
+    from takway.audio_utils import BaseAudio
+    rec = BaseAudio(input=True, CHUNK=3840)
+
+    # return_type = 'bytes'
+    file_path = 'my_recording.wav'
+    data = rec.load_audio_file(file_path)
+
+    asr = FunAutoSpeechRecognizer()
+
+    # asr.recognize(data)
+
+    total_chunk_num = int((len(data)-1)/rec.CHUNK+1)
+    print(f"total_chunk_num: {total_chunk_num}")
+    for i in range(total_chunk_num):
+        is_end = i == total_chunk_num - 1
+        speech_chunk = data[i*rec.CHUNK:(i+1)*rec.CHUNK]
+        text_dict = asr.streaming_recognize(speech_chunk, is_end)
+
+    '''
+    asr.streaming_recognize(data, auto_det_end=True)
+    '''
+    
\ No newline at end of file
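A short, hypothetical driver for the streaming recognizer above. It assumes the paraformer-zh-streaming model is available and that a GPU is present (device defaults to cuda); the silent buffer stands in for real microphone audio, and the 7680-sample chunk matches the 480 ms chunk size configured in the class.

# Hypothetical driver for FunAutoSpeechRecognizer; the synthetic audio is a placeholder.
import numpy as np
from takway.stt.funasr_utils import FunAutoSpeechRecognizer

asr = FunAutoSpeechRecognizer()                  # loads paraformer-zh-streaming
audio = np.zeros(16000 * 3, dtype=np.int16)      # stand-in for 3 s of 16 kHz speech

chunk = 7680                                     # 480 ms at 16 kHz (8 * 960 samples)
n = (len(audio) - 1) // chunk + 1
for i in range(n):
    text_dict = asr.streaming_recognize(audio[i*chunk:(i+1)*chunk], is_end=(i == n - 1))
    print(text_dict['text'], text_dict['is_end'])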
diff --git a/takway/stt/modified_funasr.py b/takway/stt/modified_funasr.py
new file mode 100644
index 0000000..5628aad
--- /dev/null
+++ b/takway/stt/modified_funasr.py
@@ -0,0 +1,168 @@
+from takway.stt.funasr_utils import FunAutoSpeechRecognizer
+from takway.stt.punctuation_utils import CTTRANSFORMER, Punctuation
+from takway.stt.emotion_utils import FUNASRFINETUNE, Emotion
+from takway.stt.speaker_ver_utils import ERES2NETV2, DEFALUT_SAVE_PATH, speaker_verfication
+import os
+import pdb
+import numpy as np
+class ModifiedRecognizer(FunAutoSpeechRecognizer):
+    def __init__(self,
+                 use_punct=True,
+                 use_emotion=False,
+                 use_speaker_ver=True):
+        super().__init__(
+            model_path="paraformer-zh-streaming",
+            device="cuda",
+            RATE=16000,
+            cfg_path=None,
+            debug=False,
+            chunk_ms=480,
+            encoder_chunk_look_back=4,
+            decoder_chunk_look_back=1)
+        self.use_punct = use_punct
+        self.use_emotion = use_emotion
+        self.use_speaker_ver = use_speaker_ver
+
+        if use_punct:
+            self.puctuation_model = Punctuation(**CTTRANSFORMER)
+        if use_emotion:
+            self.emotion_model = Emotion(**FUNASRFINETUNE)
+        if use_speaker_ver:
+            self.speaker_ver_model = speaker_verfication(**ERES2NETV2)
+
+    def initialize_speaker(self, speaker_1_wav):
+        if not self.use_speaker_ver:
+            raise NotImplementedError("speaker verification is not enabled")
+        if speaker_1_wav.endswith(".npy"):
+            self.save_speaker_path = speaker_1_wav
+        elif speaker_1_wav.endswith('.wav'):
+            self.save_speaker_path = os.path.join(DEFALUT_SAVE_PATH,
+                                                  os.path.basename(speaker_1_wav).replace(".wav", ".npy"))
+            # self.save_speaker_path = DEFALUT_SAVE_PATH
+            self.speaker_ver_model.wav2embeddings(speaker_1_wav, self.save_speaker_path)
+        else:
+            raise TypeError("only [.npy] or [.wav] is supported.")
+
+
+    def speaker_ver(self, speaker_2_wav):
+        if not self.use_speaker_ver:
+            raise NotImplementedError("speaker verification is not enabled")
+        if not hasattr(self, "save_speaker_path"):
+            raise NotImplementedError("please initialize the speaker first")
+        # pdb.set_trace()
+        return self.speaker_ver_model.verfication(base_emb=self.save_speaker_path,
+                                                  speaker_2_wav=speaker_2_wav) == 'yes'
+
+
+    def recognize(self, audio_data):
+        audio_data = self.check_audio_type(audio_data)
+
+        if self.use_speaker_ver:
+            if self.speaker_ver_model.verfication(self.save_speaker_path,
+                                                  speaker_2_wav=audio_data) == 'no':
+                return "Other People"
+
+        result = self.asr_model.generate(input=audio_data,
+                                         batch_size_s=300,
+                                         hotword=self.hotwords)
+        text = ''
+        for res in result:
+            text += res['text']
+        if self.use_punct:
+            text = self.puctuation_model.process(text+'#', append_period=False).replace('#', '')
+
+        return text
+
+    def recognize_emotion(self, audio_data):
+        audio_data = self.check_audio_type(audio_data)
+
+        if self.use_speaker_ver:
+            if self.speaker_ver_model.verfication(self.save_speaker_path,
+                                                  speaker_2_wav=audio_data) == 'no':
+                return "Other People"
+
+        if self.use_emotion:
+            return self.emotion_model.process(audio_data)
+        else:
+            raise NotImplementedError("emotion recognition is not enabled")
+
+    def streaming_recognize(self, audio_data, is_end=False, auto_det_end=False):
+        """recognize partial result
+
+        Args:
+            audio_data: bytes or numpy array, partial audio data
+            is_end: bool, whether the audio data is the end of a sentence
+            auto_det_end: bool, whether to automatically detect the end of the audio data
+        """
+        audio_data = self.check_audio_type(audio_data)
+
+        if self.use_speaker_ver:
+            if self.speaker_ver_model.verfication(self.save_speaker_path,
+                                                  speaker_2_wav=audio_data) == 'no':
+                return "Other People"
+
+        text_dict = dict(text=[], is_end=is_end)
+
+        if self.audio_cache is None:
+            self.audio_cache = audio_data
+        else:
+            # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
+            if self.audio_cache.shape[0] > 0:
+                self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)
+
+        if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
+            return text_dict
+
+        total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
+
+        if is_end:
+            # if the audio data is the end of a sentence, \
+            # we need to add one more chunk to the end to \
+            # ensure the end of the sentence is recognized correctly.
+            auto_det_end = True
+
+        if auto_det_end:
+            total_chunk_num += 1
+
+        # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
+        end_idx = None
+        for i in range(total_chunk_num):
+            if auto_det_end:
+                is_end = i == total_chunk_num - 1
+            start_idx = i*self.chunk_partial_size
+            if auto_det_end:
+                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
+            else:
+                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
+            # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
+            # t_stamp = time.time()
+
+            speech_chunk = self.audio_cache[start_idx:end_idx]
+
+            # TODO: exception handling
+            try:
+                res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
+            except ValueError as e:
+                print(f"ValueError: {e}")
+                continue
+            if self.use_punct:
+                text_dict['text'].append(self.puctuation_model.process(self.text_postprecess(res[0], data_id='text'), cache=text_dict))
+            else:
+                text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
+
+            # print(f"each chunk time: {time.time()-t_stamp}")
+
+        if is_end:
+            self.audio_cache = None
+            self.asr_cache = {}
+        else:
+            if end_idx:
+                self.audio_cache = self.audio_cache[end_idx:]  # cut the processed part from audio_cache
+        text_dict['is_end'] = is_end
+
+        if self.use_punct and is_end:
+            text_dict['text'].append(self.puctuation_model.process('#', cache=text_dict).replace('#', ''))
+
+        # print(f"text_dict: {text_dict}")
+        return text_dict
\ No newline at end of file
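A possible end-to-end sketch of ModifiedRecognizer, combining speaker verification, the streaming ASR model and punctuation restoration. The wav/pcm file names are placeholders and not files from the repository; whether the verification pipeline accepts an in-memory array depends on the installed ModelScope version.

# Hypothetical wiring of ModifiedRecognizer; the file paths are assumptions.
from takway.stt.modified_funasr import ModifiedRecognizer

rec = ModifiedRecognizer(use_punct=True, use_emotion=False, use_speaker_ver=True)
rec.initialize_speaker("enroll_speaker.wav")   # builds and caches the speaker embedding

with open("query_16k.pcm", "rb") as f:         # raw 16 kHz int16 PCM
    audio = f.read()

text = rec.recognize(audio)                    # returns "Other People" if verification fails
print(text)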
diff --git a/takway/stt/punctuation_utils.py b/takway/stt/punctuation_utils.py
new file mode 100644
index 0000000..9e038e0
--- /dev/null
+++ b/takway/stt/punctuation_utils.py
@@ -0,0 +1,119 @@
+from funasr import AutoModel
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+PUNCTUATION_MARK = [",", ".", "?", "!", ",", "。", "?", "!"]
+"""
+FUNASR
+    model size: 1 GB
+    quality: better
+    input type: only accepts a single string, not a list; a list is treated as independent strings
+"""
+FUNASR = {
+    "model_type": "funasr",
+    "model_path": "ct-punc",
+    "model_revision": "v2.0.4"
+}
+"""
+CTTRANSFORMER
+    model size: 275 MB
+    quality: worse
+    input type: accepts both a string and a list, and also supports a cache input
+"""
+CTTRANSFORMER = {
+    "model_type": "ct-transformer",
+    "model_path": "iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
+    "model_revision": "v2.0.4"
+}
+
+class Punctuation:
+    def __init__(self,
+                 model_type="funasr",  # funasr | ct-transformer
+                 model_path="ct-punc",
+                 device="cuda",
+                 model_revision="v2.0.4",
+                 **kwargs):
+
+        self.model_type = model_type
+        self.initialize(model_type, model_path, device, model_revision, **kwargs)
+
+    def initialize(self,
+                   model_type,
+                   model_path,
+                   device,
+                   model_revision,
+                   **kwargs):
+        if model_type == 'funasr':
+            self.punc_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
+        elif model_type == 'ct-transformer':
+            self.punc_model = pipeline(task=Tasks.punctuation, model=model_path, model_revision=model_revision, **kwargs)
+        else:
+            raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr|ct-transformer] expected.")
+
+    def check_text_type(self,
+                        text_data):
+        # funasr only accepts a single str, not a list, so join a list into one string here
+        if self.model_type == 'funasr':
+            if isinstance(text_data, str):
+                pass
+            elif isinstance(text_data, list):
+                text_data = ''.join(text_data)
+            else:
+                raise TypeError(f"text must be str or list, but got {type(text_data)}")
+        # ct-transformer accepts list input
+        # TODO: check whether splitting the string improves efficiency
+        elif self.model_type == 'ct-transformer':
+            if isinstance(text_data, str):
+                text_data = [text_data]
+            elif isinstance(text_data, list):
+                pass
+            else:
+                raise TypeError(f"text must be str or list, but got {type(text_data)}")
+        else:
+            pass
+        return text_data
+
+    def generate_cache(self, cache):
+        new_cache = {'pre_text': ""}
+        for text in cache['text']:
+            if text != '':
+                new_cache['pre_text'] = new_cache['pre_text']+text
+        return new_cache
+
+    def process(self,
+                text,
+                append_period=False,
+                cache={}):
+        if text == '':
+            return ''
+        text = self.check_text_type(text)
+        if self.model_type == 'funasr':
+            result = self.punc_model.generate(text)
+        elif self.model_type == 'ct-transformer':
+            if cache != {}:
+                cache = self.generate_cache(cache)
+            result = self.punc_model(text, cache=cache)
+        punced_text = ''
+        for res in result:
+            punced_text += res['text']
+        # if the text does not end with a punctuation mark, append a period manually
+        if append_period and not punced_text[-1] in PUNCTUATION_MARK:
+            punced_text += "。"
+        return punced_text
+
+if __name__ == "__main__":
+    inputs = "把字符串拆分为list只|适用于ct-transformer模型|在数据处理部分|已经把list转为单个字符串"
+    """
+    Splitting the string into a list only applies to the ct-transformer model;
+    in the data-processing step the list has already been joined into a single string.
+    """
+    vads = inputs.split("|")
+    device = "cuda"
+    CTTRANSFORMER.update({"device": device})
+    puct_model = Punctuation(**CTTRANSFORMER)
+    result = puct_model.process(vads)
+    print(result)
+    # FUNASR.update({"device":"cuda"})
+    # puct_model = Punctuation(**FUNASR)
+    # result = puct_model.process(vads)
+    # print(result)
\ No newline at end of file
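A quick sketch of the Punctuation wrapper in its offline funasr mode, which the patch itself only exercises in the commented-out block above; the sample sentence is arbitrary and the device is an assumption.

# Hypothetical usage of the funasr punctuation backend; the sentence is arbitrary.
from takway.stt.punctuation_utils import FUNASR, Punctuation

FUNASR.update({"device": "cuda"})
punc = Punctuation(**FUNASR)
print(punc.process("今天天气怎么样", append_period=True))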
diff --git a/takway/stt/speaker_ver_utils.py b/takway/stt/speaker_ver_utils.py
new file mode 100644
index 0000000..838393f
--- /dev/null
+++ b/takway/stt/speaker_ver_utils.py
@@ -0,0 +1,86 @@
+from modelscope.pipelines import pipeline
+import numpy as np
+import os
+import pdb
+ERES2NETV2 = {
+    "task": 'speaker-verification',
+    "model_name": 'damo/speech_eres2netv2_sv_zh-cn_16k-common',
+    "model_revision": 'v1.0.1',
+    "save_embeddings": False
+}
+
+# path where the speaker embeddings are saved
+DEFALUT_SAVE_PATH = r"D:\python\irving\takway_base-main\examples"
+
+class speaker_verfication:
+    def __init__(self,
+                 task='speaker-verification',
+                 model_name='damo/speech_eres2netv2_sv_zh-cn_16k-common',
+                 model_revision='v1.0.1',
+                 device="cuda",
+                 save_embeddings=False):
+        self.pipeline = pipeline(
+            task=task,
+            model=model_name,
+            model_revision=model_revision,
+            device=device)
+        self.save_embeddings = save_embeddings
+
+    def wav2embeddings(self, speaker_1_wav, save_path=None):
+        result = self.pipeline([speaker_1_wav], output_emb=True)
+        speaker_1_emb = result['embs'][0]
+        if save_path is not None:
+            np.save(save_path, speaker_1_emb)
+        return speaker_1_emb
+
+    def _verifaction(self, speaker_1_wav, speaker_2_wav, threshold, save_path):
+        if not self.save_embeddings:
+            result = self.pipeline([speaker_1_wav, speaker_2_wav], thr=threshold)
+            return result["text"]
+        else:
+            result = self.pipeline([speaker_1_wav, speaker_2_wav], thr=threshold, output_emb=True)
+            speaker1_emb = result["embs"][0]
+            speaker2_emb = result["embs"][1]
+            np.save(os.path.join(save_path, "speaker_1.npy"), speaker1_emb)
+            return result['outputs']["text"]
+
+    def _verifaction_from_embedding(self, base_emb, speaker_2_wav, threshold):
+        base_emb = np.load(base_emb)
+        result = self.pipeline([speaker_2_wav], output_emb=True)
+        speaker2_emb = result["embs"][0]
+        similarity = np.dot(base_emb, speaker2_emb) / (np.linalg.norm(base_emb) * np.linalg.norm(speaker2_emb))
+        if similarity > threshold:
+            return "yes"
+        else:
+            return "no"
+
+    def verfication(self,
+                    base_emb=None,
+                    speaker_1_wav=None,
+                    speaker_2_wav=None,
+                    threshold=0.333,
+                    save_path=None):
+        if base_emb is not None and speaker_1_wav is not None:
+            raise ValueError("Only one of base_emb or speaker_1_wav is needed")
+        if base_emb is not None and speaker_2_wav is not None:
+            return self._verifaction_from_embedding(base_emb, speaker_2_wav, threshold)
+        elif speaker_1_wav is not None and speaker_2_wav is not None:
+            return self._verifaction(speaker_1_wav, speaker_2_wav, threshold, save_path)
+        else:
+            raise NotImplementedError
+
+if __name__ == '__main__':
+    verifier = speaker_verfication(**ERES2NETV2)
+
+    verifier = speaker_verfication(save_embeddings=False)
+    result = verifier.verfication(base_emb=None,
+                                  speaker_1_wav=r"C:\Users\bing\Downloads\speaker1_a_cn_16k.wav",
+                                  speaker_2_wav=r"C:\Users\bing\Downloads\speaker2_a_cn_16k.wav",
+                                  threshold=0.333,
+                                  save_path=r"D:\python\irving\takway_base-main\savePath"
+                                  )
+    print("---")
+    print(result)
+    print(verifier.verfication(r"D:\python\irving\takway_base-main\savePath\speaker_1.npy",
+                               speaker_2_wav=r"C:\Users\bing\Downloads\speaker1_b_cn_16k.wav",
+                               threshold=0.333,
+                               ))
\ No newline at end of file
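A brief sketch of the embedding-based flow above: enroll a speaker once, save the embedding, then compare later utterances against it. The file names are placeholders, not files shipped with the patch.

# Hypothetical enrollment/verification flow; the wav and npy paths are assumptions.
from takway.stt.speaker_ver_utils import ERES2NETV2, speaker_verfication

verifier = speaker_verfication(**ERES2NETV2)
verifier.wav2embeddings("enroll.wav", save_path="enroll.npy")   # one-time enrollment

answer = verifier.verfication(base_emb="enroll.npy",
                              speaker_2_wav="query.wav",
                              threshold=0.333)
print(answer)  # "yes" if the cosine similarity exceeds the threshold, else "no"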
diff --git a/takway/stt/vosk_utils.py b/takway/stt/vosk_utils.py
new file mode 100644
index 0000000..b67cfa5
--- /dev/null
+++ b/takway/stt/vosk_utils.py
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# ####################################################### #
+# VOSKAutoSpeechRecognizer
+# ####################################################### #
+import json
+import wave
+import io
+import os
+from vosk import Model, KaldiRecognizer, SetLogLevel
+from .base_stt import STTBase
+from ..common_utils import decode_str2bytes
+
+class VOSKAutoSpeechRecognizer(STTBase):
+    def __init__(self, model_path="vosk-model-small-cn-0.22", RATE=16000, cfg_path=None, efficent_mode=True, debug=False):
+        super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
+        self.asr = KaldiRecognizer(Model(model_path=model_path), RATE)
+
+        # self.apply_asr_config(self.asr_cfg)  # NOTE: apply_asr_config is not defined on STTBase
+
+    def recognize_keywords(self, audio_data, partial_size=None, queue=None):
+        """recognize keywords in audio data"""
+        audio_data = self.check_audio_type(audio_data)
+        if partial_size is None:
+            rec_result = self.recognize(audio_data, queue)
+            rec_text = self.result_postprecess(rec_result)
+        else:
+            rec_result = self.partial_recognize(audio_data, partial_size, queue)
+            rec_text = self.result_postprecess(rec_result, 'partial')
+            print(f"rec_text: {rec_text}")
+        if rec_text != '':
+            print(f"rec_text: {rec_text}")
+            if any(keyword in rec_text for keyword in self.keywords):
+                print("Keyword detected.")
+                return True, rec_text
+        else:
+            return False, None
+
+    def recognize(self, audio_data, queue=None):
+        """recognize audio data to text"""
+        audio_data = self.check_audio_type(audio_data)
+        self.asr.AcceptWaveform(audio_data)
+        result = json.loads(self.asr.FinalResult())
+        # TODO: put result to queue
+        return result
+
+    def partial_recognize(self, audio_data, partial_size=1024, queue=None):
+        """recognize partial result"""
+        audio_data = self.check_audio_type(audio_data)
+        text_dict = dict(
+            text=[],
+            partial=[],
+            final=[],
+            is_end=False)
+        # split the audio data into pieces and recognize them one by one
+        for i in range(0, len(audio_data), partial_size):
+            # print(f"partial data: {i} - {i+partial_size}")
+            data = audio_data[i:i+partial_size]
+            if len(data) == 0:
+                break
+            if self.asr.AcceptWaveform(data):
+                result = json.loads(self.asr.Result())
+                if result['text'] != '':
+                    text_dict['text'].append(result['text'])
+                    if queue is not None:
+                        queue.put(('stt_info', text_dict))
+                    # print(f"text result: {result}")
+            else:
+                result = json.loads(self.asr.PartialResult())
+                if result['partial'] != '':
+                    # text_dict['partial'].append(result['partial'])
+                    text_dict['partial'] = [result['partial']]
+                    if queue is not None:
+                        queue.put(('stt_info', text_dict))
+                    # print(f"partial result: {result}")
+
+        # final recognize
+        final_result = json.loads(self.asr.FinalResult())
+        if final_result['text'] != '':
+            text_dict['final'].append(final_result['text'])
+            text_dict['text'].append(final_result['text'])
+
+        text_dict['is_end'] = True
+
+        print(f"final dict: {text_dict}")
+        if queue is not None:
+            queue.put(('stt_info', text_dict))
+        return text_dict
+
+
+if __name__ == "__main__":
+    '''
+    wav_file_path = "recording.wav"
+
+    # You can set log level to -1 to disable debug messages
+    SetLogLevel(0)
+
+    model = Model(model_path="vosk-model-small-cn-0.22")
+
+    # record the audio
+    # record_audio(wav_file_path)
+    data = record_audio()
+
+    # transcribe the audio
+    result = audio_to_text(data, model)
+
+    print("-------------")
+    print(result)
+    '''
+    from takway.audio_utils import Recorder
+    rec = Recorder()
+
+    return_type = 'bytes'
+    data = rec.record(return_type)
+    print(type(data))
+
+    asr = VOSKAutoSpeechRecognizer()
+    # asr.recognize(data)
+    asr.add_keyword("你好")
+    asr.recognize_keywords(data)
\ No newline at end of file
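Finally, a hypothetical offline check of the VOSK recognizer. It assumes the vosk package is installed, the vosk-model-small-cn-0.22 directory has been downloaded, and recording.wav is a 16 kHz 16-bit mono file; none of these are provided by the patch.

# Hypothetical offline check of VOSKAutoSpeechRecognizer; paths and model are assumptions.
import wave
from takway.stt.vosk_utils import VOSKAutoSpeechRecognizer

asr = VOSKAutoSpeechRecognizer(model_path="vosk-model-small-cn-0.22")

with wave.open("recording.wav", "rb") as wf:   # 16 kHz, 16-bit mono
    frames = wf.readframes(wf.getnframes())

text_dict = asr.partial_recognize(frames, partial_size=4000)
print(text_dict['final'])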