# Source: pun_emo_speaker_utils/takway/stt/punctuation_utils.py (119 lines, 4.2 KiB, Python)

from funasr import AutoModel
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
# Characters regarded as punctuation when inspecting the end of a string:
# the four ASCII marks followed by their full-width (CJK) counterparts.
PUNCTUATION_MARK = [",", ".", "?", "!", ",", "。", "?", "!"]

# FUNASR preset
#   model size : 1G
#   quality    : fairly good
#   input type : a single string only; a list is not supported as one text --
#                each list element is processed as an independent string.
FUNASR = {
    "model_type": "funasr",
    "model_path": "ct-punc",
    "model_revision": "v2.0.4",
}

# CTTRANSFORMER preset
#   model size : 275M
#   quality    : comparatively poor
#   input type : accepts both a string and a list, and also accepts a cache.
CTTRANSFORMER = {
    "model_type": "ct-transformer",
    "model_path": "iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
    "model_revision": "v2.0.4",
}
class Punctuation:
    """Restore punctuation in text with a FunASR or ModelScope model.

    ``model_type`` selects the backend:

    * ``"funasr"``         -- ``funasr.AutoModel``; called via ``.generate(text)``
      with a single string.
    * ``"ct-transformer"`` -- a ModelScope ``pipeline``; called as
      ``model(text_list, cache=...)``.

    Both backends are expected to return an iterable of mappings that carry
    the punctuated text under the ``'text'`` key.
    """

    # Trailing characters that mean "this text already ends with punctuation":
    # the ASCII marks and their full-width forms. Kept on the class so the
    # check in ``process`` is self-contained.
    _TRAILING_MARKS = (",", ".", "?", "!", ",", "。", "?", "!")
    # Mark appended by ``process(..., append_period=True)``.
    _PERIOD = "。"

    def __init__(self,
                 model_type="funasr",  # funasr | ct-transformer
                 model_path="ct-punc",
                 device="cuda",
                 model_revision="v2.0.4",
                 **kwargs):
        """Remember the backend type and load the model.

        Args:
            model_type: ``"funasr"`` or ``"ct-transformer"``.
            model_path: model name / hub id handed to the backend.
            device: device string handed to the backend.
            model_revision: model revision handed to the backend.
            **kwargs: forwarded unchanged to the backend constructor.

        Raises:
            NotImplementedError: if ``model_type`` is not supported.
        """
        self.model_type = model_type
        self.initialize(model_type, model_path, device, model_revision, **kwargs)

    def initialize(self,
                   model_type,
                   model_path,
                   device,
                   model_revision,
                   **kwargs):
        """Build ``self.punc_model`` for the requested backend.

        Raises:
            NotImplementedError: if ``model_type`` is neither ``'funasr'``
                nor ``'ct-transformer'``.
        """
        if model_type == 'funasr':
            self.punc_model = AutoModel(model=model_path,
                                        device=device,
                                        model_revision=model_revision,
                                        **kwargs)
        elif model_type == 'ct-transformer':
            # ``device`` is forwarded here as well; it used to be dropped
            # silently, so the pipeline always ran on its default device.
            self.punc_model = pipeline(task=Tasks.punctuation,
                                       model=model_path,
                                       device=device,
                                       model_revision=model_revision,
                                       **kwargs)
        else:
            raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr|ct-transformer] expected.")

    def check_text_type(self, text_data):
        """Normalise ``text_data`` to the input shape the backend expects.

        * ``funasr``: returns a single string (a list is concatenated).
        * ``ct-transformer``: returns a list (a string is wrapped in one).
        * any other ``model_type``: returns ``text_data`` untouched.

        Raises:
            TypeError: if a supported backend receives something that is
                neither ``str`` nor ``list``.
        """
        if self.model_type not in ('funasr', 'ct-transformer'):
            return text_data

        is_str = isinstance(text_data, str)
        if not is_str and not isinstance(text_data, list):
            # Report the type of the offending value, not of the ``list`` builtin.
            raise TypeError(f"text must be str or list, but got {type(text_data)}")

        if self.model_type == 'funasr':
            # funasr only accepts one string, so a list is joined into one.
            return text_data if is_str else ''.join(text_data)

        # ct-transformer accepts list input.
        # TODO: verify whether splitting the string improves efficiency.
        return [text_data] if is_str else text_data

    def generate_cache(self, cache):
        """Collapse ``cache['text']`` into the ``{'pre_text': ...}`` form.

        Empty entries contribute nothing, so joining every entry yields the
        concatenation of the non-empty ones.

        Raises:
            KeyError: if ``cache`` has no ``'text'`` entry.
        """
        return {'pre_text': ''.join(cache['text'])}

    def process(self,
                text,
                append_period=False,
                cache=None):
        """Punctuate ``text`` and return the result as one string.

        Args:
            text: a string or a list of strings.
            append_period: when true and the non-empty result does not already
                end with a punctuation mark, append a full-width period.
            cache: only used by ``ct-transformer``. A non-empty mapping must
                hold the previous text pieces under ``'text'``; it is
                converted with ``generate_cache`` before being handed to the
                model. ``None`` or ``{}`` means "no history".

        Returns:
            The concatenated ``'text'`` fields of the model output, or ``''``
            for empty input.

        Raises:
            TypeError: if ``text`` is neither ``str`` nor ``list``.
        """
        if text == '':
            return ''
        text = self.check_text_type(text)
        if not text:
            # Empty list (or a list of empty strings for funasr): nothing to do.
            return ''

        if self.model_type == 'funasr':
            result = self.punc_model.generate(text)
        elif self.model_type == 'ct-transformer':
            # Hand the model a fresh dict on every call: a shared default
            # would leak state between calls if the model writes into it.
            model_cache = self.generate_cache(cache) if cache else {}
            result = self.punc_model(text, cache=model_cache)

        punced_text = ''.join(res['text'] for res in result)

        # If the text does not end with punctuation, add it manually.
        if (append_period
                and punced_text
                and not punced_text.endswith(self._TRAILING_MARKS)):
            punced_text += self._PERIOD
        return punced_text
if __name__ == "__main__":
    # Demo: punctuate a sentence that has been pre-split into segments.
    # Feeding a list of segments is only meaningful for the ct-transformer
    # backend; for funasr the list is joined back into a single string
    # before it reaches the model.
    inputs = "把字符串拆分为list只|适用于ct-transformer模型|在数据处理部分|已经把list转为单个字符串"
    # Presumably the expected punctuated output, in English: "Splitting the
    # string into a list only applies to the ct-transformer model; in the
    # data-processing step the list has already been turned into a single
    # string."
    device = "cuda"
    vads = inputs.split("|")

    CTTRANSFORMER["device"] = device
    puct_model = Punctuation(**CTTRANSFORMER)
    result = puct_model.process(vads)
    print(result)

    # To try the funasr backend instead:
    # FUNASR.update({"device":"cuda"})
    # puct_model = Punctuation(**FUNASR)
    # result = puct_model.process(vads)
    # print(result)