119 lines
4.2 KiB
Python
119 lines
4.2 KiB
Python
|
from funasr import AutoModel
|
|||
|
from modelscope.pipelines import pipeline
|
|||
|
from modelscope.utils.constant import Tasks
|
|||
|
|
|||
|
# Characters that count as "the text already ends with punctuation".
# The fullwidth (CJK) marks are written as escapes so that they cannot be
# silently folded into their ASCII look-alikes by a text normaliser, which
# would leave the list with duplicates and without the CJK forms.
PUNCTUATION_MARK = [
    ",", ".", "?", "!",
    "\uff0c",  # fullwidth comma
    "\u3002",  # ideographic full stop
    "\uff1f",  # fullwidth question mark
    "\uff01",  # fullwidth exclamation mark
]

# FUNASR preset
#   model size : 1G
#   quality    : fairly good
#   input type : strings only, lists are not supported; when given a list the
#                model treats its items as independent strings
FUNASR = {
    "model_type": "funasr",
    "model_path": "ct-punc",
    "model_revision": "v2.0.4",
}

# CTTRANSFORMER preset
#   model size : 275M
#   quality    : relatively poor
#   input type : both strings and lists are supported, and an input cache
#                is supported as well
CTTRANSFORMER = {
    "model_type": "ct-transformer",
    "model_path": "iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
    "model_revision": "v2.0.4",
}
|
|||
|
|
|||
|
class Punctuation:
    """Add punctuation to text with either a FunASR or a ModelScope model.

    The backend is chosen by ``model_type``:

    * ``"funasr"`` -- built with ``AutoModel``; every call receives one
      string, so a list input is concatenated first.
    * ``"ct-transformer"`` -- built with ``pipeline``; every call receives a
      list of strings together with a ``cache`` dict.
    """

    def __init__(self,
                 model_type="funasr",  # funasr | ct-transformer
                 model_path="ct-punc",
                 device="cuda",
                 model_revision="v2.0.4",
                 **kwargs):
        """Remember the backend type and load the model.

        Args:
            model_type: ``"funasr"`` or ``"ct-transformer"``.
            model_path: Model name or path handed to the backend loader.
            device: Device string handed to the backend loader.
            model_revision: Model revision handed to the backend loader.
            **kwargs: Extra keyword arguments forwarded to the backend loader.

        Raises:
            NotImplementedError: If ``model_type`` is not supported.
        """
        self.model_type = model_type
        self.initialize(model_type, model_path, device, model_revision, **kwargs)

    def initialize(self,
                   model_type,
                   model_path,
                   device,
                   model_revision,
                   **kwargs):
        """Build ``self.punc_model`` for the requested backend.

        Raises:
            NotImplementedError: If ``model_type`` is neither ``"funasr"``
                nor ``"ct-transformer"``.
        """
        if model_type == 'funasr':
            self.punc_model = AutoModel(model=model_path,
                                        device=device,
                                        model_revision=model_revision,
                                        **kwargs)
        elif model_type == 'ct-transformer':
            # ``device`` used to be dropped on this branch, so the caller's
            # choice was ignored and the pipeline's own default was used.
            self.punc_model = pipeline(task=Tasks.punctuation,
                                       model=model_path,
                                       device=device,
                                       model_revision=model_revision,
                                       **kwargs)
        else:
            raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr|ct-transformer] expected.")

    def check_text_type(self, text_data):
        """Normalise ``text_data`` to the input shape the backend is called with.

        * funasr: a ``str`` is returned as is, a ``list`` is joined into one
          string.
        * ct-transformer: a ``str`` is wrapped into a one-element list, a
          ``list`` is returned as is.
        * any other ``self.model_type``: ``text_data`` is returned untouched.

        Raises:
            TypeError: If the backend is known and ``text_data`` is neither
                ``str`` nor ``list``.
        """
        if self.model_type not in ('funasr', 'ct-transformer'):
            return text_data

        if not isinstance(text_data, (str, list)):
            # Report the type of the offending value, not ``type(list)``.
            raise TypeError(f"text must be str or list, but got {type(text_data)}")

        is_single_string = isinstance(text_data, str)
        if self.model_type == 'funasr':
            return text_data if is_single_string else ''.join(text_data)
        # TODO: verify whether splitting the string improves efficiency.
        return [text_data] if is_single_string else text_data

    def generate_cache(self, cache):
        """Collapse ``cache['text']`` into the ``{'pre_text': ...}`` dict.

        Args:
            cache: Mapping whose ``'text'`` entry is an iterable of strings.

        Returns:
            A new dict with the single key ``'pre_text'`` holding the
            concatenation of all entries of ``cache['text']``.

        Raises:
            KeyError: If ``cache`` has no ``'text'`` entry.
        """
        return {'pre_text': ''.join(cache['text'])}

    def process(self,
                text,
                append_period=False,
                cache=None):
        """Punctuate ``text`` and return the result as one string.

        Args:
            text: A string or a list of strings.
            append_period: When true and the non-empty result does not end
                with a character from ``PUNCTUATION_MARK``, append "。".
            cache: Optional dict with a ``'text'`` entry (see
                ``generate_cache``). Only used by the ct-transformer backend.
                ``None`` or an empty dict means "no previous context".

        Returns:
            The concatenated ``'text'`` fields of the model's results, or
            ``''`` when there is nothing to punctuate.

        Raises:
            TypeError: If ``text`` is neither ``str`` nor ``list``.
            NotImplementedError: If ``self.model_type`` is not supported.
        """
        if text == '':
            return ''

        text = self.check_text_type(text)

        if self.model_type == 'funasr':
            # A list of empty strings joins to '' as well.
            if not text:
                return ''
            result = self.punc_model.generate(text)
        elif self.model_type == 'ct-transformer':
            # Nothing to do for [] or a list holding only empty strings.
            if not any(text):
                return ''
            # Always pass a fresh dict: a shared default dict would carry
            # whatever the model writes into it over to the next call.
            context = self.generate_cache(cache) if cache else {}
            result = self.punc_model(text, cache=context)
        else:
            raise NotImplementedError(f"unsupported model type [{self.model_type}]. only [funasr|ct-transformer] expected.")

        punced_text = ''.join(res['text'] for res in result)

        # Add a final full stop when the text does not already end with
        # punctuation; an empty result is returned unchanged.
        if append_period and punced_text and punced_text[-1] not in PUNCTUATION_MARK:
            punced_text += "。"
        return punced_text
|
|||
|
|
|||
|
if __name__ == "__main__":
    # Demo: splitting the text into a list of segments only applies to the
    # ct-transformer model; for funasr the data-processing step has already
    # turned the list back into a single string.
    inputs = "把字符串拆分为list只|适用于ct-transformer模型|在数据处理部分|已经把list转为单个字符串"
    device = "cuda"
    segments = inputs.split("|")

    CTTRANSFORMER["device"] = device
    punctuator = Punctuation(**CTTRANSFORMER)
    print(punctuator.process(segments))

    # Same demo with the funasr preset:
    # FUNASR.update({"device":"cuda"})
    # puct_model = Punctuation(**FUNASR)
    # result = puct_model.process(vads)
    # print(result)
|