1
0
Fork 0
TakwayPlatform/utils/stt/punctuation_utils.py

119 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from funasr import AutoModel
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
# Characters treated as "the text already ends with punctuation".
# The last four entries were empty strings in the previous revision, which can
# never equal a single character; they are restored here as the full-width
# counterparts of the four ASCII marks.
# NOTE(review): the exact original full-width characters were lost — confirm.
PUNCTUATION_MARK = [",", ".", "?", "!", ",", "。", "?", "!"]

# FUNASR preset
#   model size : ~1G
#   quality    : fairly good
#   input type : plain strings only, lists are not supported; if a list is
#                passed, its items are processed as independent strings.
FUNASR = {
    "model_type": "funasr",
    "model_path": "ct-punc",
    "model_revision": "v2.0.4",
}

# CTTRANSFORMER preset
#   model size : ~275M
#   quality    : fairly poor
#   input type : both strings and lists are supported; a cache can also be
#                passed in.
CTTRANSFORMER = {
    "model_type": "ct-transformer",
    "model_path": "iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
    "model_revision": "v2.0.4",
}
class Punctuation:
    """Restore punctuation in text using one of two model back-ends.

    ``model_type`` selects the back-end:

    * ``"funasr"``         -- the model is built with ``AutoModel`` and
      invoked through ``punc_model.generate(text)``.
    * ``"ct-transformer"`` -- the model is built with ``pipeline`` for
      ``Tasks.punctuation`` and invoked as ``punc_model(text, cache=...)``.
    """

    # Full-width marks that count as terminal punctuation in addition to the
    # module-level PUNCTUATION_MARK list.
    _FULL_WIDTH_MARKS = (",", "。", "?", "!")

    # Mark appended by ``process(append_period=True)``.
    # NOTE(review): the previous code appended an empty string (a no-op); the
    # Chinese full stop is assumed to be the intended character -- confirm.
    _DEFAULT_PERIOD = "。"

    def __init__(self,
                 model_type="funasr",  # funasr | ct-transformer
                 model_path="ct-punc",
                 device="cuda",
                 model_revision="v2.0.4",
                 **kwargs):
        """Remember the back-end type and load the model.

        Args:
            model_type: ``"funasr"`` or ``"ct-transformer"``.
            model_path: model name/path handed to the back-end constructor.
            device: device string; only forwarded for ``"funasr"``.
            model_revision: model revision handed to the back-end constructor.
            **kwargs: extra keyword arguments forwarded to the constructor.

        Raises:
            NotImplementedError: if ``model_type`` is not supported.
        """
        self.model_type = model_type
        self.initialize(model_type, model_path, device, model_revision, **kwargs)

    def initialize(self,
                   model_type,
                   model_path,
                   device,
                   model_revision,
                   **kwargs):
        """Build ``self.punc_model`` for the requested back-end.

        Raises:
            NotImplementedError: if ``model_type`` is not supported.
        """
        if model_type == 'funasr':
            self.punc_model = AutoModel(model=model_path, device=device,
                                        model_revision=model_revision, **kwargs)
        elif model_type == 'ct-transformer':
            # NOTE(review): ``device`` is not forwarded to ``pipeline`` here;
            # confirm whether that is intentional.
            self.punc_model = pipeline(task=Tasks.punctuation, model=model_path,
                                       model_revision=model_revision, **kwargs)
        else:
            raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr|ct-transformer] expected.")

    def check_text_type(self, text_data):
        """Normalise ``text_data`` to the input shape the back-end expects.

        * ``funasr``: a list is joined into a single string, a string is
          returned unchanged.
        * ``ct-transformer``: a string is wrapped in a one-element list, a
          list is returned unchanged.
        * any other ``model_type``: the value is returned unchanged.

        Raises:
            TypeError: if ``text_data`` is neither ``str`` nor ``list`` for a
                supported back-end.
        """
        if self.model_type not in ('funasr', 'ct-transformer'):
            return text_data

        if not isinstance(text_data, (str, list)):
            # Report the type of the offending value (previously this
            # formatted ``type(list)``, which is always ``<class 'type'>``).
            raise TypeError(f"text must be str or list, but got {type(text_data)}")

        if self.model_type == 'funasr':
            # funasr only accepts a single string, so merge list items.
            return ''.join(text_data) if isinstance(text_data, list) else text_data

        # ct-transformer accepts a list of strings.
        # TODO: check whether splitting the string improves throughput.
        return [text_data] if isinstance(text_data, str) else text_data

    def generate_cache(self, cache):
        """Build the ``{'pre_text': ...}`` dict from ``cache['text']``.

        All non-empty entries of ``cache['text']`` are concatenated in order.
        """
        return {'pre_text': ''.join(piece for piece in cache['text'] if piece != '')}

    def process(self,
                text,
                append_period=False,
                cache=None):
        """Run the punctuation model on ``text`` and return the result string.

        Args:
            text: a string or a list of strings.
            append_period: when true and the non-empty result does not already
                end with a punctuation mark, append ``_DEFAULT_PERIOD``.
            cache: optional dict with a ``'text'`` entry (iterable of strings);
                only used by the ``ct-transformer`` back-end. ``None`` and
                ``{}`` both mean "no cache".

        Returns:
            The concatenation of the ``'text'`` fields of the model results,
            or ``''`` when there is nothing to process.

        Raises:
            TypeError: if ``text`` is neither ``str`` nor ``list``.
            NotImplementedError: if ``self.model_type`` is not supported.
        """
        if text == '':
            return ''

        text = self.check_text_type(text)
        if not text:
            # Empty list, or a list that joined to an empty string: nothing
            # to send to the model.
            return ''

        if self.model_type == 'funasr':
            result = self.punc_model.generate(text)
        elif self.model_type == 'ct-transformer':
            # A fresh dict per call avoids sharing one mutable default
            # between calls.
            if cache is None:
                cache = {}
            if cache != {}:
                cache = self.generate_cache(cache)
            result = self.punc_model(text, cache=cache)
        else:
            raise NotImplementedError(f"unsupported model type [{self.model_type}]. only [funasr|ct-transformer] expected.")

        punced_text = ''.join(res['text'] for res in result)

        # If the text does not end with punctuation, add it manually. The
        # emptiness check prevents an IndexError on an empty model output.
        if append_period and punced_text:
            last_char = punced_text[-1]
            if not (last_char in self._FULL_WIDTH_MARKS or last_char in PUNCTUATION_MARK):
                punced_text += self._DEFAULT_PERIOD
        return punced_text
if __name__ == "__main__":
    # Manual smoke test: punctuate a "|"-separated sample with ct-transformer.
    inputs = "把字符串拆分为list只|适用于ct-transformer模型|在数据处理部分|已经把list转为单个字符串"

    # Splitting the string into a list only matters for the ct-transformer
    # model; for funasr, check_text_type joins the list back into one string.
    segments = inputs.split("|")

    device = "cuda"
    CTTRANSFORMER["device"] = device
    punctuator = Punctuation(**CTTRANSFORMER)
    print(punctuator.process(segments))

    # Alternative run with the funasr back-end:
    # FUNASR.update({"device":"cuda"})
    # puct_model = Punctuation(**FUNASR)
    # result = puct_model.process(vads)
    # print(result)