From 54d13fba87433cc64171c93de2af5c688f4f82b8 Mon Sep 17 00:00:00 2001
From: killua4396 <1223086337@qq.com>
Date: Sat, 18 May 2024 17:10:09 +0800
Subject: [PATCH] feat: add emotion detection and punctuation recognition
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/controllers/chat_controller.py |  35 ++++++++-
 utils/stt/emotion_utils.py         | 132 ++++++++++++++++++++++++++++++
 utils/stt/modified_funasr.py       |  29 ++++++
 utils/stt/punctuation_utils.py     | 119 +++++++++++++++++++++++++
 utils/stt/speaker_ver_utils.py     |  74 +++++++++++++++
 5 files changed, 386 insertions(+), 3 deletions(-)
 create mode 100644 utils/stt/emotion_utils.py
 create mode 100644 utils/stt/modified_funasr.py
 create mode 100644 utils/stt/punctuation_utils.py
 create mode 100644 utils/stt/speaker_ver_utils.py

diff --git a/app/controllers/chat_controller.py b/app/controllers/chat_controller.py
index 2c42bfb..5a5e887 100644
--- a/app/controllers/chat_controller.py
+++ b/app/controllers/chat_controller.py
@@ -16,9 +16,9 @@ import aiohttp
 logger = get_logger()
 
 # -------------------- initialize local ASR -----------------------
-from utils.stt.funasr_utils import FunAutoSpeechRecognizer
+from utils.stt.modified_funasr import ModifiedRecognizer
 
-asr = FunAutoSpeechRecognizer()
+asr = ModifiedRecognizer()
 logger.info("Local ASR initialized successfully")
 # -------------------------------------------------------
 
@@ -60,6 +60,10 @@ def parseChunkDelta(chunk):
         return "end"
     except KeyError:
         logger.error(f"error chunk: {chunk}")
+        return ""
+    except json.JSONDecodeError:
+        logger.error(f"error chunk: {chunk}")
+        return ""
 
 # sentence segmentation helper
 def split_string_with_punctuation(current_sentence,text,is_first,is_end):
@@ -224,6 +228,7 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
 async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event):
     logger.debug("Speech recognition task started")
     is_signup = False
+    audio = b""
     try:
         current_message = ""
         while not (user_input_finish_event.is_set() and user_input_q.empty()):
@@ -231,10 +236,16 @@ async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_
                 asr.session_signup(session_id)
                 is_signup = True
             audio_data = await user_input_q.get()
+            audio += audio_data
             asr_result = asr.streaming_recognize(session_id,audio_data)
             current_message += ''.join(asr_result['text'])
         asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
         current_message += ''.join(asr_result['text'])
+        current_message = asr.punctuation_correction(current_message)
+        emotion_dict = asr.emotion_recognition(audio)  # emotion recognition
+        if not isinstance(emotion_dict, str):
+            max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
+            current_message = f"{current_message}, current speaker's emotion: {emotion_dict['labels'][max_index]}"
         await llm_input_q.put(current_message)
         asr.session_signout(session_id)
     except Exception as e:
@@ -371,6 +382,7 @@ async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_eve
     logger.debug("Speech recognition task started")
     is_signup = False
     current_message = ""
+    audio = b""
     while not (input_finished_event.is_set() and user_input_q.empty()):
         try:
             audio_frame = await asyncio.wait_for(user_input_q.get(),timeout=3)
@@ -380,15 +392,23 @@ async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_eve
             if audio_frame['is_end']:
                 asr_result = asr.streaming_recognize(session_id,audio_frame['audio'], is_end=True)
                 current_message += ''.join(asr_result['text'])
+                audio += audio_frame['audio']
+                current_message = asr.punctuation_correction(current_message)
+                emotion_dict = asr.emotion_recognition(audio)  # emotion recognition
+                if not isinstance(emotion_dict, str):
+                    max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
+                    current_message = f"{current_message}, current speaker's emotion: {emotion_dict['labels'][max_index]}"
                 await llm_input_q.put(current_message)
                 logger.debug(f"Received user message: {current_message}")
+                current_message = ""
+                audio = b""
             else:
                 asr_result = asr.streaming_recognize(session_id,audio_frame['audio'])
+                audio += audio_frame['audio']
                 current_message += ''.join(asr_result['text'])
         except asyncio.TimeoutError:
             continue
         except Exception as e:
-            asr.session_signout(session_id)
             logger.error(f"Speech recognition task error: {str(e)}")
             break
     asr.session_signout(session_id)
@@ -523,6 +543,7 @@ async def voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_fin
     current_message = ""
     vad_count = 0
     is_signup = False
+    audio = b""
     while not (input_finished_event.is_set() and audio_q.empty()):
         try:
             if not is_signup:
@@ -533,14 +554,22 @@ async def voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_fin
             if vad_count > 0:
                 vad_count -= 1
             asr_result = asr.streaming_recognize(session_id, audio_data)
+            audio += audio_data
             current_message += ''.join(asr_result['text'])
         else:
             vad_count += 1
             if vad_count >= 25:  # 25 consecutive non-speech frames mean the utterance is over
                 asr_result = asr.streaming_recognize(session_id, audio_data, is_end=True)
                 if current_message:
+                    current_message = asr.punctuation_correction(current_message)
+                    audio += audio_data
+                    emotion_dict = asr.emotion_recognition(audio)  # emotion recognition
+                    if not isinstance(emotion_dict, str):
+                        max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
+                        current_message = f"{current_message}, current speaker's emotion: {emotion_dict['labels'][max_index]}"
                     logger.debug(f"Silence detected; user input: {current_message}")
                     await asr_result_q.put(current_message)
+                    audio = b""
                     text_response = {"type": "user_text", "code": 200, "msg": current_message}
                     await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # send the recognized text back
                     current_message = ""
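Reviewer note: the punctuate-then-tag-emotion block above is now repeated in sct_asr_handler, scl_asr_handler and voice_call_audio_consumer. A follow-up could hoist it into a shared helper; a minimal sketch under this patch's own API (the helper name annotate_user_message is hypothetical):

    def annotate_user_message(recognizer, current_message: str, audio: bytes) -> str:
        """Punctuate the ASR transcript and append the dominant emotion label."""
        current_message = recognizer.punctuation_correction(current_message)
        emotion_dict = recognizer.emotion_recognition(audio)
        # mirror the patch's guard: skip tagging if a plain string came back
        if not isinstance(emotion_dict, str):
            max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
            current_message = f"{current_message}, current speaker's emotion: {emotion_dict['labels'][max_index]}"
        return current_message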
diff --git a/utils/stt/emotion_utils.py b/utils/stt/emotion_utils.py
new file mode 100644
index 0000000..103e33b
--- /dev/null
+++ b/utils/stt/emotion_utils.py
@@ -0,0 +1,132 @@
+import io
+import numpy as np
+import base64
+import wave
+from funasr import AutoModel
+import time
+
+"""
+Base model:
+    cannot classify emotions; usable only as a feature extractor.
+"""
+FUNASRBASE = {
+    "model_type": "funasr",
+    "model_path": "iic/emotion2vec_base",
+    "model_revision": "v2.0.4"
+}
+
+"""
+Finetuned model:
+    outputs classification results.
+"""
+FUNASRFINETUNE = {
+    "model_type": "funasr",
+    "model_path": "iic/emotion2vec_base_finetuned"
+}
+
+def decode_str2bytes(data):
+    # decode a Base64-encoded string into raw bytes
+    if data is None:
+        return None
+    return base64.b64decode(data.encode('utf-8'))
+
+class Emotion:
+    def __init__(self,
+                 model_type="funasr",
+                 model_path="iic/emotion2vec_base",
+                 device="cuda",
+                 model_revision="v2.0.4",
+                 **kwargs):
+        self.model_type = model_type
+        self.initialize(model_type, model_path, device, model_revision, **kwargs)
+
+    # initialize the model
+    def initialize(self,
+                   model_type,
+                   model_path,
+                   device,
+                   model_revision,
+                   **kwargs):
+        if model_type == "funasr":
+            self.emotion_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
+        else:
+            raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr] expected.")
+
+    # check the input type
+    def check_audio_type(self, audio_data):
+        """Check the audio data type and convert it to a float32 numpy array."""
+        if isinstance(audio_data, bytes):
+            pass
+        elif isinstance(audio_data, list):
+            audio_data = b''.join(audio_data)
+        elif isinstance(audio_data, str):
+            audio_data = decode_str2bytes(audio_data)
+        elif isinstance(audio_data, io.BytesIO):
+            wf = wave.open(audio_data, 'rb')
+            audio_data = wf.readframes(wf.getnframes())
+        elif isinstance(audio_data, np.ndarray):
+            pass
+        else:
+            raise TypeError(f"audio_data must be bytes, list, str, io.BytesIO or numpy array, but got {type(audio_data)}")
+
+        if isinstance(audio_data, bytes):
+            audio_data = np.frombuffer(audio_data, dtype=np.int16)
+        elif audio_data.dtype != np.int16:
+            audio_data = audio_data.astype(np.int16)
+
+        # the model input must be float32 (layer norm requires it)
+        return audio_data.astype(np.float32)
+
+    def process(self,
+                audio_data,
+                granularity="utterance",
+                extract_embedding=False,
+                output_dir=None,
+                only_score=True):
+        """
+        audio_data: only float32 is expected (because of layer norm)
+        extract_embedding: save the embedding if True
+        output_dir: save path for the embedding
+        only_score: only return labels & scores if True
+        """
+        audio_data = self.check_audio_type(audio_data)
+        if self.model_type == 'funasr':
+            result = self.emotion_model.generate(audio_data, output_dir=output_dir, granularity=granularity, extract_embedding=extract_embedding)
+        else:
+            raise NotImplementedError(f"unsupported model type [{self.model_type}]. only [funasr] expected.")
+
+        # keep only labels and scores
+        if only_score:
+            maintain_key = ["labels", "scores"]
+            for res in result:
+                keys_to_remove = [k for k in res.keys() if k not in maintain_key]
+                for k in keys_to_remove:
+                    res.pop(k)
+        return result[0]
+
+# only for test
+def load_audio_file(wav_file):
+    with wave.open(wav_file, 'rb') as wf:
+        params = wf.getparams()
+        frames = wf.readframes(params.nframes)
+        print("Audio file loaded.")
+        # audio parameters (kept for debugging):
+        # print("Channels:", params.nchannels)
+        # print("Sample width:", params.sampwidth)
+        # print("Frame rate:", params.framerate)
+        # print("Number of frames:", params.nframes)
+        # print("Compression type:", params.comptype)
+    return frames
+
+if __name__ == "__main__":
+    inputs = r".\example\test.wav"
+    inputs = load_audio_file(inputs)
+    device = "cuda"
+    # FUNASRBASE.update({"device": device})
+    FUNASRFINETUNE.update({"device": device})
+    emotion_model = Emotion(**FUNASRFINETUNE)
+    s = time.time()
+    result = emotion_model.process(inputs)
+    t = time.time()
+    print(t - s)
+    print(result)
\ No newline at end of file
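Usage sketch for the Emotion wrapper, mirroring the file's own __main__ (assumes a 16 kHz mono WAV at .\example\test.wav and a CUDA device; swap in "cpu" otherwise):

    from utils.stt.emotion_utils import Emotion, FUNASRFINETUNE, load_audio_file

    FUNASRFINETUNE.update({"device": "cuda"})        # or "cpu"
    model = Emotion(**FUNASRFINETUNE)
    frames = load_audio_file(r".\example\test.wav")  # raw int16 PCM bytes
    result = model.process(frames)                   # {'labels': [...], 'scores': [...]}
    top = result['labels'][result['scores'].index(max(result['scores']))]
    print("dominant emotion:", top)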
diff --git a/utils/stt/modified_funasr.py b/utils/stt/modified_funasr.py
new file mode 100644
index 0000000..ea8f0dc
--- /dev/null
+++ b/utils/stt/modified_funasr.py
@@ -0,0 +1,29 @@
+from .funasr_utils import FunAutoSpeechRecognizer
+from .punctuation_utils import FUNASR, Punctuation
+from .emotion_utils import FUNASRFINETUNE, Emotion
+
+class ModifiedRecognizer:
+    def __init__(self):
+        # speech recognition model
+        self.asr_model = FunAutoSpeechRecognizer()
+
+        # punctuation restoration model
+        self.punctuation_model = Punctuation(**FUNASR)
+
+        # emotion recognition model
+        self.emotion_model = Emotion(**FUNASRFINETUNE)
+
+    def session_signup(self, session_id):
+        self.asr_model.session_signup(session_id)
+
+    def session_signout(self, session_id):
+        self.asr_model.session_signout(session_id)
+
+    def streaming_recognize(self, session_id, audio_data, is_end=False):
+        return self.asr_model.streaming_recognize(session_id, audio_data, is_end=is_end)
+
+    def punctuation_correction(self, sentence):
+        return self.punctuation_model.process(sentence)
+
+    def emotion_recognition(self, audio):
+        return self.emotion_model.process(audio)
\ No newline at end of file
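A sketch of the session lifecycle the controller relies on, outside the WebSocket plumbing (the test.pcm capture and the ~100 ms chunking are assumptions for illustration):

    from utils.stt.modified_funasr import ModifiedRecognizer

    recognizer = ModifiedRecognizer()
    recognizer.session_signup("session-1")
    with open("test.pcm", "rb") as f:  # hypothetical 16 kHz, 16-bit mono PCM capture
        pcm = f.read()
    chunks = [pcm[i:i + 3200] for i in range(0, len(pcm), 3200)]  # ~100 ms frames
    text, audio = "", b""
    for chunk in chunks:
        audio += chunk
        text += ''.join(recognizer.streaming_recognize("session-1", chunk)['text'])
    text += ''.join(recognizer.streaming_recognize("session-1", b'', is_end=True)['text'])
    text = recognizer.punctuation_correction(text)
    emotion = recognizer.emotion_recognition(audio)  # dict with 'labels' and 'scores'
    recognizer.session_signout("session-1")
    print(text, emotion)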
diff --git a/utils/stt/punctuation_utils.py b/utils/stt/punctuation_utils.py
new file mode 100644
index 0000000..01d2b6d
--- /dev/null
+++ b/utils/stt/punctuation_utils.py
@@ -0,0 +1,119 @@
+from funasr import AutoModel
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+PUNCTUATION_MARK = [",", ".", "?", "!", ",", "。", "?", "!"]
+
+"""
+FUNASR
+    model size: 1 GB
+    quality: better
+    input type: a single string only; a list is treated as independent strings
+"""
+FUNASR = {
+    "model_type": "funasr",
+    "model_path": "ct-punc",
+    "model_revision": "v2.0.4"
+}
+
+"""
+CTTRANSFORMER
+    model size: 275 MB
+    quality: worse
+    input type: string or list; also accepts an input cache
+"""
+CTTRANSFORMER = {
+    "model_type": "ct-transformer",
+    "model_path": "iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
+    "model_revision": "v2.0.4"
+}
+
+class Punctuation:
+    def __init__(self,
+                 model_type="funasr",  # funasr | ct-transformer
+                 model_path="ct-punc",
+                 device="cuda",
+                 model_revision="v2.0.4",
+                 **kwargs):
+        self.model_type = model_type
+        self.initialize(model_type, model_path, device, model_revision, **kwargs)
+
+    def initialize(self,
+                   model_type,
+                   model_path,
+                   device,
+                   model_revision,
+                   **kwargs):
+        if model_type == 'funasr':
+            self.punc_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
+        elif model_type == 'ct-transformer':
+            self.punc_model = pipeline(task=Tasks.punctuation, model=model_path, model_revision=model_revision, **kwargs)
+        else:
+            raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr|ct-transformer] expected.")
+
+    def check_text_type(self, text_data):
+        # funasr only accepts a single str (not a list), so join a list into one string
+        if self.model_type == 'funasr':
+            if isinstance(text_data, str):
+                pass
+            elif isinstance(text_data, list):
+                text_data = ''.join(text_data)
+            else:
+                raise TypeError(f"text must be str or list, but got {type(text_data)}")
+        # ct-transformer accepts a list
+        # TODO: check whether splitting the string improves efficiency
+        elif self.model_type == 'ct-transformer':
+            if isinstance(text_data, str):
+                text_data = [text_data]
+            elif isinstance(text_data, list):
+                pass
+            else:
+                raise TypeError(f"text must be str or list, but got {type(text_data)}")
+        else:
+            pass
+        return text_data
+
+    def generate_cache(self, cache):
+        new_cache = {'pre_text': ""}
+        for text in cache['text']:
+            if text != '':
+                new_cache['pre_text'] = new_cache['pre_text'] + text
+        return new_cache
+
+    def process(self,
+                text,
+                append_period=False,
+                cache={}):
+        if text == '':
+            return ''
+        text = self.check_text_type(text)
+        if self.model_type == 'funasr':
+            result = self.punc_model.generate(text)
+        elif self.model_type == 'ct-transformer':
+            if cache != {}:
+                cache = self.generate_cache(cache)
+            result = self.punc_model(text, cache=cache)
+        punced_text = ''
+        for res in result:
+            punced_text += res['text']
+        # if the text does not end with a punctuation mark, append one manually
+        if append_period and punced_text[-1] not in PUNCTUATION_MARK:
+            punced_text += "。"
+        return punced_text
+
+if __name__ == "__main__":
+    inputs = "把字符串拆分为list只|适用于ct-transformer模型|在数据处理部分|已经把list转为单个字符串"
+    """
+    Splitting the string into a list only applies to the ct-transformer model;
+    in the data-processing step the list has already been joined into a single string.
+    """
+    vads = inputs.split("|")
+    device = "cuda"
+    CTTRANSFORMER.update({"device": device})
+    puct_model = Punctuation(**CTTRANSFORMER)
+    result = puct_model.process(vads)
+    print(result)
+    # FUNASR.update({"device": "cuda"})
+    # puct_model = Punctuation(**FUNASR)
+    # result = puct_model.process(vads)
+    # print(result)
\ No newline at end of file
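The ct-punc path used by ModifiedRecognizer can be exercised standalone, mirroring the commented-out block in __main__ (a minimal sketch; assumes the ct-punc weights download on first use):

    from utils.stt.punctuation_utils import FUNASR, Punctuation

    FUNASR.update({"device": "cuda"})  # or "cpu"
    punc = Punctuation(**FUNASR)
    print(punc.process("把字符串拆分为list只适用于ct-transformer模型"))
    # append_period=True adds a trailing 。 only when the model left the tail unpunctuated
    print(punc.process("今天天气不错", append_period=True))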
diff --git a/utils/stt/speaker_ver_utils.py b/utils/stt/speaker_ver_utils.py
new file mode 100644
index 0000000..684692a
--- /dev/null
+++ b/utils/stt/speaker_ver_utils.py
@@ -0,0 +1,74 @@
+from modelscope.pipelines import pipeline
+import numpy as np
+import os
+
+ERES2NETV2 = {
+    "task": 'speaker-verification',
+    "model_name": 'damo/speech_eres2netv2_sv_zh-cn_16k-common',
+    "model_revision": 'v1.0.1',
+    "save_embeddings": False
+}
+
+# default path for saved embeddings
+DEFAULT_SAVE_PATH = r".\takway\savePath"
+
+class speaker_verification:
+    def __init__(self,
+                 task='speaker-verification',
+                 model_name='damo/speech_eres2netv2_sv_zh-cn_16k-common',
+                 model_revision='v1.0.1',
+                 device="cuda",
+                 save_embeddings=False):
+        self.pipeline = pipeline(
+            task=task,
+            model=model_name,
+            model_revision=model_revision,
+            device=device)
+        self.save_embeddings = save_embeddings
+
+    def wav2embeddings(self, speaker_1_wav):
+        result = self.pipeline([speaker_1_wav], output_emb=True)
+        speaker_1_emb = result['embs'][0]
+        return speaker_1_emb
+
+    def _verification(self, speaker_1_wav, speaker_2_wav, threshold, save_path):
+        if not self.save_embeddings:
+            result = self.pipeline([speaker_1_wav, speaker_2_wav], thr=threshold)
+            return result["text"]
+        else:
+            result = self.pipeline([speaker_1_wav, speaker_2_wav], thr=threshold, output_emb=True)
+            speaker1_emb = result["embs"][0]
+            speaker2_emb = result["embs"][1]
+            np.save(os.path.join(save_path, "speaker_1.npy"), speaker1_emb)
+            return result['outputs']["text"]
+
+    def _verification_from_embedding(self, base_emb, speaker_2_wav, threshold):
+        base_emb = np.load(base_emb)
+        result = self.pipeline([speaker_2_wav], output_emb=True)
+        speaker2_emb = result["embs"][0]
+        # cosine similarity between the stored and the new embedding
+        similarity = np.dot(base_emb, speaker2_emb) / (np.linalg.norm(base_emb) * np.linalg.norm(speaker2_emb))
+        if similarity > threshold:
+            return "yes"
+        else:
+            return "no"
+
+    def verification(self,
+                     base_emb,
+                     speaker_emb,
+                     threshold=0.333):
+        return np.dot(base_emb, speaker_emb) / (np.linalg.norm(base_emb) * np.linalg.norm(speaker_emb)) > threshold
+
+if __name__ == '__main__':
+    # save_embeddings=True so the first call writes speaker_1.npy for the second test
+    ERES2NETV2.update({"save_embeddings": True})
+    verifier = speaker_verification(**ERES2NETV2)
+    result = verifier._verification(speaker_1_wav=r"C:\Users\bing\Downloads\speaker1_a_cn_16k.wav",
+                                    speaker_2_wav=r"C:\Users\bing\Downloads\speaker2_a_cn_16k.wav",
+                                    threshold=0.333,
+                                    save_path=r"D:\python\irving\takway_base-main\savePath")
+    print("---")
+    print(result)
+    print(verifier._verification_from_embedding(r"D:\python\irving\takway_base-main\savePath\speaker_1.npy",
+                                                speaker_2_wav=r"C:\Users\bing\Downloads\speaker1_b_cn_16k.wav",
+                                                threshold=0.333))
\ No newline at end of file
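speaker_ver_utils.py is added but not yet wired into the controller. A sketch of the intended enroll-once, verify-later flow via cached embeddings (the WAV paths are placeholders):

    from utils.stt.speaker_ver_utils import ERES2NETV2, speaker_verification

    verifier = speaker_verification(**ERES2NETV2)
    base_emb = verifier.wav2embeddings("enroll.wav")   # enroll the speaker once
    probe_emb = verifier.wav2embeddings("probe.wav")   # a later utterance
    # cosine similarity against the enrolled embedding, thresholded at 0.333
    print(verifier.verification(base_emb, probe_emb, threshold=0.333))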