from funasr import AutoModel
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Characters treated as "the text already ends with punctuation" when
# ``Punctuation.process(..., append_period=True)`` decides whether to append
# a trailing "。". Both ASCII and full-width (CJK) variants are listed; the
# previous list repeated the ASCII ",", "?" and "!" where the full-width
# forms were evidently intended.
PUNCTUATION_MARK = [",", ".", "?", "!", "，", "。", "？", "！"]

# FUNASR
#   Model size: ~1G
#   Quality: better
#   Input type: a single string only, not a list. A list input would be
#   treated as independent strings, so ``Punctuation.check_text_type`` joins
#   lists into one string for this backend.
FUNASR = {
    "model_type": "funasr",
    "model_path": "ct-punc",
    "model_revision": "v2.0.4",
}

# CTTRANSFORMER
#   Model size: ~275M
#   Quality: worse
#   Input type: a string or a list of strings; also accepts a ``cache``.
CTTRANSFORMER = {
    "model_type": "ct-transformer",
    "model_path": "iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
    "model_revision": "v2.0.4",
}


class Punctuation:
    """Restore punctuation in text using a FunASR or ModelScope model.

    Two backends are supported, selected by ``model_type``:

    * ``"funasr"``: a ``funasr.AutoModel``; input is passed as one string.
    * ``"ct-transformer"``: a ModelScope punctuation ``pipeline``; input is
      passed as a list of strings, optionally with a cache.
    """

    def __init__(self,
                 model_type: str = "funasr",  # funasr | ct-transformer
                 model_path: str = "ct-punc",
                 device: str = "cuda",
                 model_revision: str = "v2.0.4",
                 **kwargs):
        """Store the backend type and load the model.

        Args:
            model_type: ``"funasr"`` or ``"ct-transformer"``.
            model_path: Model id or local path handed to the backend.
            device: Device string handed to the backend (e.g. ``"cuda"``).
            model_revision: Model revision handed to the backend.
            **kwargs: Extra keyword arguments forwarded to the backend
                constructor unchanged.

        Raises:
            NotImplementedError: If ``model_type`` is not supported.
        """
        self.model_type = model_type
        self.initialize(model_type, model_path, device, model_revision, **kwargs)

    def initialize(self, model_type: str, model_path: str, device: str,
                   model_revision: str, **kwargs) -> None:
        """Instantiate the backend model and store it as ``self.punc_model``.

        Raises:
            NotImplementedError: If ``model_type`` is not supported.
        """
        if model_type == "funasr":
            self.punc_model = AutoModel(model=model_path,
                                        device=device,
                                        model_revision=model_revision,
                                        **kwargs)
        elif model_type == "ct-transformer":
            # ``device`` was previously dropped for this backend, so the
            # pipeline always ran on its own default device.
            self.punc_model = pipeline(task=Tasks.punctuation,
                                       model=model_path,
                                       model_revision=model_revision,
                                       device=device,
                                       **kwargs)
        else:
            raise NotImplementedError(
                f"unsupported model type [{model_type}]. "
                f"only [funasr|ct-transformer] expected.")

    def check_text_type(self, text_data):
        """Normalize ``text_data`` into the input form the backend expects.

        * funasr: only a single ``str`` is accepted, so a list is joined into
          one string (with no separator).
        * ct-transformer: a ``str`` is wrapped in a one-element list; a list
          is passed through.
        * Any other ``model_type``: ``text_data`` is returned unchanged.

        Raises:
            TypeError: If ``text_data`` is neither ``str`` nor ``list`` for a
                known backend.
        """
        if self.model_type == "funasr":
            if isinstance(text_data, str):
                return text_data
            if isinstance(text_data, list):
                return "".join(text_data)
            raise TypeError(f"text must be str or list, but got {type(text_data)}")
        if self.model_type == "ct-transformer":
            # TODO: verify whether splitting the string into a list improves
            # throughput.
            if isinstance(text_data, str):
                return [text_data]
            if isinstance(text_data, list):
                return text_data
            raise TypeError(f"text must be str or list, but got {type(text_data)}")
        return text_data

    def generate_cache(self, cache):
        """Build the ct-transformer cache from previously seen text.

        Concatenates the strings in ``cache['text']`` into
        ``{'pre_text': <concatenation>}``. Empty strings contribute nothing.
        """
        return {"pre_text": "".join(cache["text"])}

    def process(self, text, append_period=False, cache=None):
        """Return ``text`` with punctuation restored.

        Args:
            text: A string or a list of strings.
            append_period: If true and the result does not already end with
                one of ``PUNCTUATION_MARK``, append "。".
            cache: ct-transformer only. A dict whose ``'text'`` entry holds
                the preceding strings (see ``generate_cache``). ``None`` or an
                empty dict means no cache. Ignored by the funasr backend.

        Returns:
            The punctuated text. Empty input yields ``""``.

        Raises:
            TypeError: If ``text`` is neither ``str`` nor ``list``.
        """
        if isinstance(text, (str, list)) and not text:
            return ""
        text = self.check_text_type(text)

        if self.model_type == "funasr":
            if not text:  # e.g. a list of empty strings joined to ""
                return ""
            result = self.punc_model.generate(text)
        elif self.model_type == "ct-transformer":
            # Build a fresh dict per call. The old shared ``cache={}`` default
            # was passed straight to the pipeline, so any state the pipeline
            # stored in it could leak into later calls.
            model_cache = self.generate_cache(cache) if cache else {}
            result = self.punc_model(text, cache=model_cache)
        else:
            raise NotImplementedError(
                f"unsupported model type [{self.model_type}]. "
                f"only [funasr|ct-transformer] expected.")

        punced_text = "".join(res["text"] for res in result)

        # Add a trailing full stop if the text does not already end with
        # punctuation. Guard against an empty result before indexing [-1].
        if append_period and punced_text and punced_text[-1] not in PUNCTUATION_MARK:
            punced_text += "。"
        return punced_text


if __name__ == "__main__":
    # Splitting the string into a list only matters for the ct-transformer
    # model; for funasr the list is joined back into a single string during
    # preprocessing.
    inputs = "把字符串拆分为list只|适用于ct-transformer模型|在数据处理部分|已经把list转为单个字符串"
    vads = inputs.split("|")

    device = "cuda"
    puct_model = Punctuation(**CTTRANSFORMER, device=device)
    result = puct_model.process(vads)
    print(result)

    # To use the funasr backend instead:
    #     puct_model = Punctuation(**FUNASR, device=device)