119 lines
4.3 KiB
Python
119 lines
4.3 KiB
Python
from funasr import AutoModel
|
||
from modelscope.pipelines import pipeline
|
||
from modelscope.utils.constant import Tasks
|
||
|
||
# Characters treated as "already punctuated" when Punctuation.process decides
# whether a trailing full stop must be appended.
# The first four entries are the ASCII (half-width) marks; the last four are
# their CJK counterparts. The CJK marks are written as escapes so they cannot
# be silently folded into the ASCII ones by Unicode (NFKC) normalization,
# which would leave the list with duplicates and no full-width coverage.
PUNCTUATION_MARK = [
    ",",
    ".",
    "?",
    "!",
    "\uff0c",  # FULLWIDTH COMMA
    "\u3002",  # IDEOGRAPHIC FULL STOP
    "\uff1f",  # FULLWIDTH QUESTION MARK
    "\uff01",  # FULLWIDTH EXCLAMATION MARK
]
|
||
"""
|
||
FUNASR
|
||
模型大小: 1G
|
||
效果: 较好
|
||
输入类型: 仅支持字符串不支持list, 输入list会将list视为彼此独立的字符串处理
|
||
"""
|
||
# Preset constructor arguments for the FunASR backend; the keys mirror the
# parameters of Punctuation.__init__, so it can be unpacked as
# Punctuation(**FUNASR).
FUNASR = dict(
    model_type="funasr",
    model_path="ct-punc",
    model_revision="v2.0.4",
)
|
||
"""
|
||
CTTRANSFORMER
|
||
模型大小: 275M
|
||
效果:较差
|
||
输入类型: 支持字符串与list, 同时支持输入cache
|
||
"""
|
||
# Preset constructor arguments for the ct-transformer backend; the keys mirror
# the parameters of Punctuation.__init__, so it can be unpacked as
# Punctuation(**CTTRANSFORMER). Kept as a plain mutable dict because the
# script section of this file adds a "device" entry to it.
CTTRANSFORMER = dict(
    model_type="ct-transformer",
    model_path="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
    model_revision="v2.0.4",
)
|
||
|
||
class Punctuation:
    """Add punctuation to unpunctuated text using one of two model backends.

    ``model_type`` selects the backend:

    * ``"funasr"`` -- built with ``funasr.AutoModel``; it is always called
      with a single string, so list input is concatenated first.
    * ``"ct-transformer"`` -- built with ``modelscope.pipelines.pipeline``;
      it is always called with a list of strings plus a ``cache`` dict.
    """

    def __init__(self,
                 model_type="funasr",  # funasr | ct-transformer
                 model_path="ct-punc",
                 device="cuda",
                 model_revision="v2.0.4",
                 **kwargs):
        """Remember the backend type and load the model.

        Args:
            model_type: ``"funasr"`` or ``"ct-transformer"``.
            model_path: Model name or path handed to the backend loader.
            device: Device string handed to ``AutoModel`` (funasr backend).
            model_revision: Model revision handed to the backend loader.
            **kwargs: Extra keyword arguments forwarded to the backend loader.

        Raises:
            NotImplementedError: If ``model_type`` is not supported.
        """
        self.model_type = model_type
        self.initialize(model_type, model_path, device, model_revision, **kwargs)

    def initialize(self,
                   model_type,
                   model_path,
                   device,
                   model_revision,
                   **kwargs):
        """Load the backend model and store it as ``self.punc_model``.

        Args:
            model_type: ``"funasr"`` or ``"ct-transformer"``.
            model_path: Model name or path handed to the backend loader.
            device: Device string; only forwarded for the funasr backend.
            model_revision: Model revision handed to the backend loader.
            **kwargs: Extra keyword arguments forwarded to the backend loader.

        Raises:
            NotImplementedError: If ``model_type`` is not supported.
        """
        if model_type == 'funasr':
            self.punc_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
        elif model_type == 'ct-transformer':
            # NOTE(review): `device` is not forwarded to the ModelScope
            # pipeline here, so the pipeline picks its own default device.
            # Confirm whether that is intentional before changing it.
            self.punc_model = pipeline(task=Tasks.punctuation, model=model_path, model_revision=model_revision, **kwargs)
        else:
            raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr|ct-transformer] expected.")

    def check_text_type(self,
                        text_data):
        """Normalize ``text_data`` to the input shape the backend is called with.

        Args:
            text_data: A string or a list of strings.

        Returns:
            A single string for the funasr backend (lists are concatenated),
            a list for the ct-transformer backend (a string is wrapped in a
            one-element list), or ``text_data`` unchanged for any other
            ``self.model_type``.

        Raises:
            TypeError: If ``text_data`` is neither ``str`` nor ``list`` for a
                supported backend.
        """
        # funasr only accepts a single str, not a list, so a list is joined
        # into one string here.
        if self.model_type == 'funasr':
            if isinstance(text_data, str):
                return text_data
            if isinstance(text_data, list):
                return ''.join(text_data)
            # Report the type of the offending value (not of the `list` builtin).
            raise TypeError(f"text must be str or list, but got {type(text_data)}")
        # ct-transformer accepts list input.
        # TODO: verify whether splitting the string improves efficiency.
        if self.model_type == 'ct-transformer':
            if isinstance(text_data, str):
                return [text_data]
            if isinstance(text_data, list):
                return text_data
            raise TypeError(f"text must be str or list, but got {type(text_data)}")
        return text_data

    def generate_cache(self, cache):
        """Build the cache dict passed to the ct-transformer pipeline.

        Args:
            cache: Mapping whose ``'text'`` entry is an iterable of strings.

        Returns:
            ``{'pre_text': <all entries of cache['text'] concatenated>}``.

        Raises:
            KeyError: If ``cache`` has no ``'text'`` entry.
        """
        # Empty entries contribute nothing to the concatenation, so a plain
        # join is equivalent to skipping them one by one.
        return {'pre_text': ''.join(cache['text'])}

    def process(self,
                text,
                append_period=False,
                cache=None):
        """Run the punctuation model on ``text`` and return the result.

        Args:
            text: A string or a list of strings to punctuate.
            append_period: When true, append ``"。"`` if the non-empty result
                does not already end with a mark from ``PUNCTUATION_MARK``.
            cache: Optional mapping with a ``'text'`` list of preceding text;
                only used by the ct-transformer backend. ``None`` (the
                default) means "no cache" and yields a fresh empty dict for
                every call, so no state is shared between calls.

        Returns:
            The concatenated ``'text'`` fields of the model output, or ``''``
            when there is nothing to punctuate.

        Raises:
            TypeError: If ``text`` is neither ``str`` nor ``list``.
            NotImplementedError: If ``self.model_type`` is not supported.
        """
        if text == '':
            return ''
        text = self.check_text_type(text)
        # An empty list (or a list that joins to an empty string) has nothing
        # to punctuate; skip the model call just like for an empty string.
        if not text:
            return ''
        if self.model_type == 'funasr':
            result = self.punc_model.generate(text)
        elif self.model_type == 'ct-transformer':
            if cache is None:
                cache = {}
            elif cache != {}:
                cache = self.generate_cache(cache)
            result = self.punc_model(text, cache=cache)
        else:
            raise NotImplementedError(f"unsupported model type [{self.model_type}]. only [funasr|ct-transformer] expected.")
        punced_text = ''.join(res['text'] for res in result)
        # If the text does not end with a punctuation mark, append one
        # manually. An empty result is left empty instead of being indexed.
        if append_period and punced_text and punced_text[-1] not in PUNCTUATION_MARK:
            punced_text += "。"
        return punced_text
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: punctuate a sentence that has been pre-split on "|"
    # using the ct-transformer preset.
    raw_text = "把字符串拆分为list只|适用于ct-transformer模型|在数据处理部分|已经把list转为单个字符串"
    # The bare string below is the original author's note; presumably it is
    # the punctuated output expected from the model.
    """
    把字符串拆分为list只适用于ct-transformer模型,
    在数据处理部分,已经把list转为单个字符串
    """
    segments = raw_text.split("|")
    device = "cuda"
    CTTRANSFORMER["device"] = device
    punctuator = Punctuation(**CTTRANSFORMER)
    print(punctuator.process(segments))
    # Equivalent run with the FunASR preset:
    # FUNASR.update({"device":"cuda"})
    # puct_model = Punctuation(**FUNASR)
    # result = puct_model.process(vads)
    # print(result)