35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
|
import os
|
|||
|
import json
|
|||
|
import torch
|
|||
|
import logging
|
|||
|
import concurrent.futures
|
|||
|
import librosa
|
|||
|
import torch.distributed as dist
|
|||
|
from typing import Collection
|
|||
|
import torch
|
|||
|
import torchaudio
|
|||
|
from torch import nn
|
|||
|
import random
|
|||
|
import re
|
|||
|
import string
|
|||
|
from funasr.tokenizer.cleaner import TextCleaner
|
|||
|
from funasr.register import tables
|
|||
|
|
|||
|
|
|||
|
@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
|
|||
|
class TextPreprocessRemovePunctuation(nn.Module):
|
|||
|
def __init__(self, **kwargs):
|
|||
|
super().__init__()
|
|||
|
|
|||
|
def forward(self, text, **kwargs):
|
|||
|
# 定义英文标点符号
|
|||
|
en_punct = string.punctuation
|
|||
|
# 定义中文标点符号(部分常用的)
|
|||
|
cn_punct = "。?!,、;:“”‘’()《》【】…—~·"
|
|||
|
# 合并英文和中文标点符号
|
|||
|
all_punct = en_punct + cn_punct
|
|||
|
# 创建正则表达式模式,匹配任何在all_punct中的字符
|
|||
|
punct_pattern = re.compile("[{}]".format(re.escape(all_punct)))
|
|||
|
# 使用正则表达式的sub方法替换掉这些字符
|
|||
|
return punct_pattern.sub("", text)
|