diff --git a/examples/assets/example_recording.wav b/examples/assets/example_recording.wav
new file mode 100644
index 0000000..2f16668
Binary files /dev/null and b/examples/assets/example_recording.wav differ
diff --git a/examples/audio_utils.py b/examples/audio_utils.py
new file mode 100644
index 0000000..37fac5e
--- /dev/null
+++ b/examples/audio_utils.py
@@ -0,0 +1,284 @@
+import os
+import io
+import numpy as np
+import pyaudio
+import wave
+import base64
+"""
+    audio utils for modified_funasr_demo.py
+"""
+
+def decode_str2bytes(data):
+    # decode a Base64-encoded string back into raw bytes
+    if data is None:
+        return None
+    return base64.b64decode(data.encode('utf-8'))
+
+class BaseAudio:
+    def __init__(self,
+                 filename=None,
+                 input=False,
+                 output=False,
+                 CHUNK=1024,
+                 FORMAT=pyaudio.paInt16,
+                 CHANNELS=1,
+                 RATE=16000,
+                 input_device_index=None,
+                 output_device_index=None,
+                 **kwargs):
+        self.CHUNK = CHUNK
+        self.FORMAT = FORMAT
+        self.CHANNELS = CHANNELS
+        self.RATE = RATE
+        self.filename = filename
+        assert input != output, "input and output cannot be the same, \
+            but got input={} and output={}.".format(input, output)
+        print("------------------------------------------")
+        print(f"{'Input' if input else 'Output'} Audio Initialization: ")
+        print(f"CHUNK: {self.CHUNK} \nFORMAT: {self.FORMAT} \nCHANNELS: {self.CHANNELS} \nRATE: {self.RATE} \ninput_device_index: {input_device_index} \noutput_device_index: {output_device_index}")
+        print("------------------------------------------")
+        self.p = pyaudio.PyAudio()
+        self.stream = self.p.open(format=FORMAT,
+                                  channels=CHANNELS,
+                                  rate=RATE,
+                                  input=input,
+                                  output=output,
+                                  input_device_index=input_device_index,
+                                  output_device_index=output_device_index,
+                                  **kwargs)
+
+    def load_audio_file(self, wav_file):
+        with wave.open(wav_file, 'rb') as wf:
+            params = wf.getparams()
+            frames = wf.readframes(params.nframes)
+            print("Audio file loaded.")
+            # Audio Parameters
+            # print("Channels:", params.nchannels)
+            # print("Sample width:", params.sampwidth)
+            # print("Frame rate:", params.framerate)
+            # print("Number of frames:", params.nframes)
+            # print("Compression type:", params.comptype)
+        return frames
+
+    def check_audio_type(self, audio_data, return_type=None):
+        assert return_type in ['bytes', 'io', None], \
+            "return_type should be 'bytes', 'io' or None."
+ if isinstance(audio_data, str): + if len(audio_data) > 50: + audio_data = decode_str2bytes(audio_data) + else: + assert os.path.isfile(audio_data), \ + "audio_data should be a file path or a bytes object." + wf = wave.open(audio_data, 'rb') + audio_data = wf.readframes(wf.getnframes()) + elif isinstance(audio_data, np.ndarray): + if audio_data.dtype == np.dtype('float32'): + audio_data = np.int16(audio_data * np.iinfo(np.int16).max) + audio_data = audio_data.tobytes() + elif isinstance(audio_data, bytes): + pass + else: + raise TypeError(f"audio_data must be bytes, numpy.ndarray or str, \ + but got {type(audio_data)}") + + if return_type == None: + return audio_data + return self.write_wave(None, [audio_data], return_type) + + def write_wave(self, filename, frames, return_type='io'): + """Write audio data to a file.""" + if isinstance(frames, bytes): + frames = [frames] + if not isinstance(frames, list): + raise TypeError("frames should be \ + a list of bytes or a bytes object, \ + but got {}.".format(type(frames))) + + if return_type == 'io': + if filename is None: + filename = io.BytesIO() + if self.filename: + filename = self.filename + return self.write_wave_io(filename, frames) + elif return_type == 'bytes': + return self.write_wave_bytes(frames) + + + def write_wave_io(self, filename, frames): + """ + Write audio data to a file-like object. + + Args: + filename: [string or file-like object], file path or file-like object to write + frames: list of bytes, audio data to write + """ + wf = wave.open(filename, 'wb') + + # 设置WAV文件的参数 + wf.setnchannels(self.CHANNELS) + wf.setsampwidth(self.p.get_sample_size(self.FORMAT)) + wf.setframerate(self.RATE) + wf.writeframes(b''.join(frames)) + wf.close() + if isinstance(filename, io.BytesIO): + filename.seek(0) # reset file pointer to beginning + return filename + + def write_wave_bytes(self, frames): + """Write audio data to a bytes object.""" + return b''.join(frames) + + +class BaseRecorder(BaseAudio): + def __init__(self, + input=True, + base_chunk_size=None, + RATE=16000, + **kwargs): + super().__init__(input=input, RATE=RATE, **kwargs) + self.base_chunk_size = base_chunk_size + if base_chunk_size is None: + self.base_chunk_size = self.CHUNK + + def record(self, + filename, + duration=5, + return_type='io', + logger=None): + if logger is not None: + logger.info("Recording started.") + else: + print("Recording started.") + frames = [] + for i in range(0, int(self.RATE / self.CHUNK * duration)): + data = self.stream.read(self.CHUNK, exception_on_overflow=False) + frames.append(data) + if logger is not None: + logger.info("Recording stopped.") + else: + print("Recording stopped.") + return self.write_wave(filename, frames, return_type) + + def record_chunk_voice(self, + return_type='bytes', + CHUNK=None, + exception_on_overflow=True, + queue=None): + data = self.stream.read(self.CHUNK if CHUNK is None else CHUNK, + exception_on_overflow=exception_on_overflow) + if return_type is not None: + return self.write_wave(None, [data], return_type) + return data diff --git a/examples/modified_funasr_demo.py b/examples/modified_funasr_demo.py new file mode 100644 index 0000000..6ce79c4 --- /dev/null +++ b/examples/modified_funasr_demo.py @@ -0,0 +1,39 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(__file__))) + +from audio_utils import BaseRecorder +from utils.stt.modified_funasr import ModifiedRecognizer + + + +def asr_file_stream(file_path=r'.\assets\example_recording.wav'): + # 读入音频文件 + rec = BaseRecorder() + data = 
rec.load_audio_file(file_path)
+
+    # create the model
+    asr = ModifiedRecognizer(use_punct=True, use_emotion=True, use_speaker_ver=True)
+    asr.session_signup("test")
+
+    # register the target speaker
+    asr.initialize_speaker(r".\assets\example_recording.wav")
+
+    # speech recognition
+    print("===============================================")
+    text_dict = asr.streaming_recognize("test", data, auto_det_end=True)
+    print(f"text_dict: {text_dict}")
+
+    if not isinstance(text_dict, str):
+        print("".join(text_dict['text']))
+
+    # emotion recognition
+    print("===============================================")
+    emotion_dict = asr.recognize_emotion(data)
+    print(f"emotion_dict: {emotion_dict}")
+    if not isinstance(emotion_dict, str):
+        max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
+        print("emotion: " + emotion_dict['labels'][max_index])
+
+
+asr_file_stream()
\ No newline at end of file
diff --git a/examples/speaker_embedding/README.md b/examples/speaker_embedding/README.md
new file mode 100644
index 0000000..cf6570b
--- /dev/null
+++ b/examples/speaker_embedding/README.md
@@ -0,0 +1 @@
+Stores the voice embeddings of the target speaker. To change the save location, edit DEFALUT_SAVE_PATH in utils/stt/speaker_ver_utils.
\ No newline at end of file
diff --git a/examples/speaker_embedding/example_recording.npy b/examples/speaker_embedding/example_recording.npy
new file mode 100644
index 0000000..5bb6c85
Binary files /dev/null and b/examples/speaker_embedding/example_recording.npy differ
diff --git a/utils/stt/emotion_utils.py b/utils/stt/emotion_utils.py
new file mode 100644
index 0000000..8d38423
--- /dev/null
+++ b/utils/stt/emotion_utils.py
@@ -0,0 +1,142 @@
+import io
+import numpy as np
+import base64
+import wave
+from funasr import AutoModel
+import time
+"""
+    Base model:
+    cannot classify emotions, only usable for feature extraction
+"""
+FUNASRBASE = {
+    "model_type": "funasr",
+    "model_path": "iic/emotion2vec_base",
+    "model_revision": "v2.0.4"
+}
+
+
+"""
+    Finetuned model:
+    outputs classification results
+"""
+FUNASRFINETUNE = {
+    "model_type": "funasr",
+    "model_path": "iic/emotion2vec_base_finetuned"
+}
+
+def decode_str2bytes(data):
+    # decode a Base64-encoded string back into raw bytes
+    if data is None:
+        return None
+    return base64.b64decode(data.encode('utf-8'))
+
+class Emotion:
+    def __init__(self,
+                 model_type="funasr",
+                 model_path="iic/emotion2vec_base",
+                 device="cuda",
+                 model_revision="v2.0.4",
+                 **kwargs):
+
+        self.model_type = model_type
+        self.initialize(model_type, model_path, device, model_revision, **kwargs)
+
+    # initialize the model
+    def initialize(self,
+                   model_type,
+                   model_path,
+                   device,
+                   model_revision,
+                   **kwargs):
+        if model_type == "funasr":
+            self.emotion_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
+        else:
+            raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr] expected.")
+
+    # check the input type
+    def check_audio_type(self,
+                         audio_data):
+        """check audio data type and convert it to bytes if necessary."""
+        if isinstance(audio_data, bytes):
+            pass
+        elif isinstance(audio_data, list):
+            audio_data = b''.join(audio_data)
+        elif isinstance(audio_data, str):
+            audio_data = decode_str2bytes(audio_data)
+        elif isinstance(audio_data, io.BytesIO):
+            wf = wave.open(audio_data, 'rb')
+            audio_data = wf.readframes(wf.getnframes())
+        elif isinstance(audio_data, np.ndarray):
+            pass
+        else:
+            raise TypeError(f"audio_data must be bytes, list, str, \
+                io.BytesIO or numpy array, but got {type(audio_data)}")
+
+        if isinstance(audio_data, bytes):
+            audio_data = np.frombuffer(audio_data, dtype=np.int16)
+        elif isinstance(audio_data, np.ndarray):
+            if audio_data.dtype != np.int16:
+                audio_data = audio_data.astype(np.int16)
+        else:
+            raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
+
+        # the model input must be float32
+        if isinstance(audio_data, np.ndarray):
+            audio_data = audio_data.astype(np.float32)
+        else:
+            raise TypeError(f"audio_data must be numpy array, but got {type(audio_data)}")
+        return audio_data
+
+    def process(self,
+                audio_data,
+                granularity="utterance",
+                extract_embedding=False,
+                output_dir=None,
+                only_score=True):
+        """
+        audio_data: only float32 expected because of layernorm
+        extract_embedding: save embedding if true
+        output_dir: save path for embedding
+        only_score: only return labels & scores if true
+        """
+        audio_data = self.check_audio_type(audio_data)
+        if self.model_type == 'funasr':
+            result = self.emotion_model.generate(audio_data, output_dir=output_dir, granularity=granularity, extract_embedding=extract_embedding)
+        else:
+            pass
+
+        # keep only labels and scores
+        if only_score:
+            maintain_key = ["labels", "scores"]
+            for res in result:
+                keys_to_remove = [k for k in res.keys() if k not in maintain_key]
+                for k in keys_to_remove:
+                    res.pop(k)
+        return result[0]
+
+# only for test
+def load_audio_file(wav_file):
+    with wave.open(wav_file, 'rb') as wf:
+        params = wf.getparams()
+        frames = wf.readframes(params.nframes)
+        print("Audio file loaded.")
+        # Audio Parameters
+        # print("Channels:", params.nchannels)
+        # print("Sample width:", params.sampwidth)
+        # print("Frame rate:", params.framerate)
+        # print("Number of frames:", params.nframes)
+        # print("Compression type:", params.comptype)
+    return frames
+
+if __name__ == "__main__":
+    inputs = r".\example\test.wav"
+    inputs = load_audio_file(inputs)
+    device = "cuda"
+    # FUNASRBASE.update({"device": device})
+    FUNASRFINETUNE.update({"device": device})
+    emotion_model = Emotion(**FUNASRFINETUNE)
+    s = time.time()
+    result = emotion_model.process(inputs)
+    t = time.time()
+    print(t - s)
+    print(result)
\ No newline at end of file
diff --git a/utils/stt/modified_funasr.py b/utils/stt/modified_funasr.py
new file mode 100644
index 0000000..3fd6ca8
--- /dev/null
+++ b/utils/stt/modified_funasr.py
@@ -0,0 +1,209 @@
+from .funasr_utils import FunAutoSpeechRecognizer
+from .punctuation_utils import CTTRANSFORMER, Punctuation
+from .emotion_utils import FUNASRFINETUNE, Emotion
+from .speaker_ver_utils import ERES2NETV2, DEFALUT_SAVE_PATH, speaker_verfication
+import os
+
+import numpy as np
+class ModifiedRecognizer(FunAutoSpeechRecognizer):
+    def __init__(self,
+                 use_punct=True,
+                 use_emotion=False,
+                 use_speaker_ver=True):
+
+        # build the base funasr model for speech recognition (produces unpunctuated text)
+        super().__init__(
+            model_path="paraformer-zh-streaming",
+            device="cuda",
+            RATE=16000,
+            cfg_path=None,
debug=False, + chunk_ms=480, + encoder_chunk_look_back=4, + decoder_chunk_look_back=1) + + # 记录是否具备附加功能 + self.use_punct = use_punct + self.use_emotion = use_emotion + self.use_speaker_ver = use_speaker_ver + + # 增加标点模型 + if use_punct: + self.puctuation_model = Punctuation(**CTTRANSFORMER) + + # 情绪识别模型 + if use_emotion: + self.emotion_model = Emotion(**FUNASRFINETUNE) + + # 说话人识别模型 + if use_speaker_ver: + self.speaker_ver_model = speaker_verfication(**ERES2NETV2) + + def initialize_speaker(self, speaker_1_wav): + """ + 用于说话人识别,将输入的音频(speaker_1_wav)设立为目标说话人,并将其特征保存本地 + """ + if not self.use_speaker_ver: + raise NotImplementedError("no access") + if speaker_1_wav.endswith(".npy"): + self.save_speaker_path = speaker_1_wav + elif speaker_1_wav.endswith('.wav'): + self.save_speaker_path = os.path.join(DEFALUT_SAVE_PATH, + os.path.basename(speaker_1_wav).replace(".wav", ".npy")) + # self.save_speaker_path = DEFALUT_SAVE_PATH + self.speaker_ver_model.wav2embeddings(speaker_1_wav, self.save_speaker_path) + else: + raise TypeError("only support [.npy] or [.wav].") + + + def speaker_ver(self, speaker_2_wav): + """ + 用于说话人识别,判断输入音频是否为目标说话人, + 是返回True,不是返回False + """ + if not self.use_speaker_ver: + raise NotImplementedError("no access") + if not hasattr(self, "save_speaker_path"): + raise NotImplementedError("please initialize speaker first") + + # self.speaker_ver_model.verfication 返回值为字符串 'yes' / 'no' + return self.speaker_ver_model.verfication(base_emb=self.save_speaker_path, + speaker_2_wav=speaker_2_wav) == 'yes' + + + def recognize(self, audio_data): + """ + 非流式语音识别,返回识别出的文本,返回值类型 str + """ + audio_data = self.check_audio_type(audio_data) + + # 说话人识别 + if self.use_speaker_ver: + if self.speaker_ver_model.verfication(self.save_speaker_path, + speaker_2_wav=audio_data) == 'no': + return "Other People" + + # 语音识别 + result = self.asr_model.generate(input=audio_data, + batch_size_s=300, + hotword=self.hotwords) + text = '' + for res in result: + text += res['text'] + + # 添加标点 + if self.use_punct: + text = self.puctuation_model.process(text+'#', append_period=False).replace('#', '') + + return text + + def recognize_emotion(self, audio_data): + """ + 情感识别,返回值为: + 1. 如果说话人非目标说话人,返回字符串 "Other People" + 2. 如果说话人为目标说话人,返回字典{"Labels": List[str], "scores": List[int]} + """ + audio_data = self.check_audio_type(audio_data) + + if self.use_speaker_ver: + if self.speaker_ver_model.verfication(self.save_speaker_path, + speaker_2_wav=audio_data) == 'no': + return "Other People" + + if self.use_emotion: + return self.emotion_model.process(audio_data) + else: + raise NotImplementedError("no access") + + def streaming_recognize(self, session_id, audio_data, is_end=False, auto_det_end=False): + """recognize partial result + + Args: + audio_data: bytes or numpy array, partial audio data + is_end: bool, whether the audio data is the end of a sentence + auto_det_end: bool, whether to automatically detect the end of a audio data + + 流式语音识别,返回值为: + 1. 如果说话人非目标说话人,返回字符串 "Other People" + 2. 
如果说话人为目标说话人,返回字典{"text": List[str], "is_end": boolean}
+        """
+        audio_cache = self.audio_cache[session_id]
+        asr_cache = self.asr_cache[session_id]
+        text_dict = dict(text=[], is_end=is_end)
+        audio_data = self.check_audio_type(audio_data)
+
+        # speaker verification
+        if self.use_speaker_ver:
+            if self.speaker_ver_model.verfication(self.save_speaker_path,
+                                                  speaker_2_wav=audio_data) == 'no':
+                return "Other People"
+
+        # speech recognition
+        if audio_cache is None:
+            audio_cache = audio_data
+        else:
+            # print(f"audio_data: {audio_data.shape}, audio_cache: {audio_cache.shape}")
+            if audio_cache.shape[0] > 0:
+                audio_cache = np.concatenate([audio_cache, audio_data], axis=0)
+
+        if not is_end and audio_cache.shape[0] < self.chunk_partial_size:
+            self.audio_cache[session_id] = audio_cache
+            return text_dict
+
+        total_chunk_num = int((len(audio_cache)-1)/self.chunk_partial_size)
+
+        if is_end:
+            # if the audio data is the end of a sentence, \
+            # we need to add one more chunk to the end to \
+            # ensure the end of the sentence is recognized correctly.
+            auto_det_end = True
+
+        if auto_det_end:
+            total_chunk_num += 1
+
+        # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(audio_cache)}")
+        end_idx = None
+        for i in range(total_chunk_num):
+            if auto_det_end:
+                is_end = i == total_chunk_num - 1
+            start_idx = i*self.chunk_partial_size
+            if auto_det_end:
+                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
+            else:
+                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
+            # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
+            # t_stamp = time.time()
+
+            speech_chunk = audio_cache[start_idx:end_idx]
+
+            # TODO: exceptions processes
+            try:
+                res = self.asr_model.generate(input=speech_chunk, cache=asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
+            except ValueError as e:
+                print(f"ValueError: {e}")
+                continue
+
+            # add punctuation
+            if self.use_punct:
+                text_dict['text'].append(self.puctuation_model.process(self.text_postprecess(res[0], data_id='text'), cache=text_dict))
+            else:
+                text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
+
+
+            # print(f"each chunk time: {time.time()-t_stamp}")
+
+        if is_end:
+            audio_cache = None
+            asr_cache = {}
+        else:
+            if end_idx:
+                audio_cache = audio_cache[end_idx:]  # cut the processed part from audio_cache
+        text_dict['is_end'] = is_end
+
+        if self.use_punct and is_end:
+            text_dict['text'].append(self.puctuation_model.process('#', cache=text_dict).replace('#', ''))
+
+        self.audio_cache[session_id] = audio_cache
+        self.asr_cache[session_id] = asr_cache
+        # print(f"text_dict: {text_dict}")
+        return text_dict
\ No newline at end of file
diff --git a/utils/stt/punctuation_utils.py b/utils/stt/punctuation_utils.py
new file mode 100644
index 0000000..9e038e0
--- /dev/null
+++ b/utils/stt/punctuation_utils.py
@@ -0,0 +1,119 @@
+from funasr import AutoModel
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+PUNCTUATION_MARK = [",", ".", "?", "!", ",", "。", "?", "!"]
+"""
+FUNASR
+    model size: 1G
+    quality: better
+    input type: only accepts a single string, not a list; a list is treated as independent strings
+"""
+FUNASR = {
+    "model_type": "funasr",
+    "model_path": "ct-punc",
+    "model_revision": "v2.0.4"
+}
+"""
+CTTRANSFORMER
+    model size: 275M
+    quality: worse
+    input type: accepts both string and list, and also supports passing a cache
+"""
+CTTRANSFORMER = {
+    "model_type": 
"ct-transformer", + "model_path": "iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", + "model_revision": "v2.0.4" +} + +class Punctuation: + def __init__(self, + model_type="funasr", # funasr | ct-transformer + model_path="ct-punc", + device="cuda", + model_revision="v2.0.4", + **kwargs): + + self.model_type=model_type + self.initialize(model_type, model_path, device, model_revision, **kwargs) + + def initialize(self, + model_type, + model_path, + device, + model_revision, + **kwargs): + if model_type == 'funasr': + self.punc_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs) + elif model_type == 'ct-transformer': + self.punc_model = pipeline(task=Tasks.punctuation, model=model_path, model_revision=model_revision, **kwargs) + else: + raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr|ct-transformer] expected.") + + def check_text_type(self, + text_data): + # funasr只支持单个str输入,不支持list输入,此处将list转化为字符串 + if self.model_type == 'funasr': + if isinstance(text_data, str): + pass + elif isinstance(text_data, list): + text_data = ''.join(text_data) + else: + raise TypeError(f"text must be str or list, but got {type(list)}") + # ct-transformer支持list输入 + # TODO 验证拆分字符串能否提高效率 + elif self.model_type == 'ct-transformer': + if isinstance(text_data, str): + text_data = [text_data] + elif isinstance(text_data, list): + pass + else: + raise TypeError(f"text must be str or list, but got {type(list)}") + else: + pass + return text_data + + def generate_cache(self, cache): + new_cache = {'pre_text': ""} + for text in cache['text']: + if text != '': + new_cache['pre_text'] = new_cache['pre_text']+text + return new_cache + + def process(self, + text, + append_period=False, + cache={}): + if text == '': + return '' + text = self.check_text_type(text) + if self.model_type == 'funasr': + result = self.punc_model.generate(text) + elif self.model_type == 'ct-transformer': + if cache != {}: + cache = self.generate_cache(cache) + result = self.punc_model(text, cache=cache) + punced_text = '' + for res in result: + punced_text += res['text'] + # 如果最后没有标点符号,手动加上。 + if append_period and not punced_text[-1] in PUNCTUATION_MARK: + punced_text += "。" + return punced_text + +if __name__ == "__main__": + inputs = "把字符串拆分为list只|适用于ct-transformer模型|在数据处理部分|已经把list转为单个字符串" + """ + 把字符串拆分为list只适用于ct-transformer模型, + 在数据处理部分,已经把list转为单个字符串 + """ + vads = inputs.split("|") + device = "cuda" + CTTRANSFORMER.update({"device": device}) + puct_model = Punctuation(**CTTRANSFORMER) + result = puct_model.process(vads) + print(result) + # FUNASR.update({"device":"cuda"}) + # puct_model = Punctuation(**FUNASR) + # result = puct_model.process(vads) + # print(result) \ No newline at end of file diff --git a/utils/stt/speaker_ver_utils.py b/utils/stt/speaker_ver_utils.py new file mode 100644 index 0000000..4372bd0 --- /dev/null +++ b/utils/stt/speaker_ver_utils.py @@ -0,0 +1,86 @@ +from modelscope.pipelines import pipeline +import numpy as np +import os + +ERES2NETV2 = { + "task": 'speaker-verification', + "model_name": 'damo/speech_eres2netv2_sv_zh-cn_16k-common', + "model_revision": 'v1.0.1', + "save_embeddings": False +} + +# 保存 embedding 的路径 +DEFALUT_SAVE_PATH = os.path.join(os.path.dirname(os.path.dirname(__name__)), "speaker_embedding") + +class speaker_verfication: + def __init__(self, + task='speaker-verification', + model_name='damo/speech_eres2netv2_sv_zh-cn_16k-common', + model_revision='v1.0.1', + device="cuda", + save_embeddings=False): + 
self.pipeline = pipeline( + task=task, + model=model_name, + model_revision=model_revision, + device=device) + self.save_embeddings = save_embeddings + + def wav2embeddings(self, speaker_1_wav, save_path=None): + result = self.pipeline([speaker_1_wav], output_emb=True) + speaker_1_emb = result['embs'][0] + if save_path is not None: + np.save(save_path, speaker_1_emb) + return speaker_1_emb + + def _verifaction(self, speaker_1_wav, speaker_2_wav, threshold, save_path): + if not self.save_embeddings: + result = self.pipeline([speaker_1_wav, speaker_2_wav], thr=threshold) + return result["text"] + else: + result = self.pipeline([speaker_1_wav, speaker_2_wav], thr=threshold, output_emb=True) + speaker1_emb = result["embs"][0] + speaker2_emb = result["embs"][1] + np.save(os.path.join(save_path, "speaker_1.npy"), speaker1_emb) + return result['outputs']["text"] + + def _verifaction_from_embedding(self, base_emb, speaker_2_wav, threshold): + base_emb = np.load(base_emb) + result = self.pipeline([speaker_2_wav], output_emb=True) + speaker2_emb = result["embs"][0] + similarity = np.dot(base_emb, speaker2_emb) / (np.linalg.norm(base_emb) * np.linalg.norm(speaker2_emb)) + if similarity > threshold: + return "yes" + else: + return "no" + + def verfication(self, + base_emb=None, + speaker_1_wav=None, + speaker_2_wav=None, + threshold=0.333, + save_path=None): + if base_emb is not None and speaker_1_wav is not None: + raise ValueError("Only need one of them, base_emb or speaker_1_wav") + if base_emb is not None and speaker_2_wav is not None: + return self._verifaction_from_embedding(base_emb, speaker_2_wav, threshold) + elif speaker_1_wav is not None and speaker_2_wav is not None: + return self._verifaction(speaker_1_wav, speaker_2_wav, threshold, save_path) + else: + raise NotImplementedError + +if __name__ == '__main__': + verifier = speaker_verfication(**ERES2NETV2) + + verifier = speaker_verfication(save_embeddings=False) + result = verifier.verfication(base_emb=None, speaker_1_wav=r"C:\Users\bing\Downloads\speaker1_a_cn_16k.wav", + speaker_2_wav=r"C:\Users\bing\Downloads\speaker2_a_cn_16k.wav", + threshold=0.333, + save_path=r"D:\python\irving\takway_base-main\savePath" + ) + print("---") + print(result) + print(verifier.verfication(r"D:\python\irving\takway_base-main\savePath\speaker_1.npy", + speaker_2_wav=r"C:\Users\bing\Downloads\speaker1_b_cn_16k.wav", + threshold=0.333, + )) \ No newline at end of file
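For reference, the demo above exercises the new modules against a prerecorded WAV file. The sketch below shows the same pipeline driven by live microphone input instead, built only from the classes added in this patch. It is a minimal sketch, not part of the patch itself: it assumes it is run from examples/ (so the relative asset path and sys.path tweak resolve), that the parent FunAutoSpeechRecognizer provides session_signup() and the audio-type conversion used by ModifiedRecognizer, and the session id "demo" and the ten-chunk loop are illustrative choices.

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from audio_utils import BaseRecorder
from utils.stt.modified_funasr import ModifiedRecognizer

recorder = BaseRecorder()                        # 16 kHz, mono, int16 microphone input stream
asr = ModifiedRecognizer(use_punct=True, use_emotion=True, use_speaker_ver=True)
asr.session_signup("demo")                       # allocate per-session audio/ASR caches
asr.initialize_speaker(r".\assets\example_recording.wav")   # register the target speaker

for i in range(10):                              # ten microphone chunks, then flush
    chunk = recorder.record_chunk_voice(return_type='bytes', exception_on_overflow=False)
    result = asr.streaming_recognize("demo", chunk, is_end=(i == 9))
    if isinstance(result, str):                  # "Other People": not the registered speaker
        print(result)
        break
    print("".join(result['text']))               # partial transcript, punctuated if use_punct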