import json
import logging
import re
import string
import sys
from collections import defaultdict, namedtuple
from typing import Dict, List, Optional, Set, Tuple
from unicodedata import category

EOS_TYPE = "EOS"
PUNCT_TYPE = "PUNCT"
PLAIN_TYPE = "PLAIN"
Instance = namedtuple("Instance", "token_type un_normalized normalized")
known_types = [
    "PLAIN",
    "DATE",
    "CARDINAL",
    "LETTERS",
    "VERBATIM",
    "MEASURE",
    "DECIMAL",
    "ORDINAL",
    "DIGIT",
    "MONEY",
    "TELEPHONE",
    "ELECTRONIC",
    "FRACTION",
    "TIME",
    "ADDRESS",
]


def _load_kaggle_text_norm_file(file_path: str) -> List[Instance]:
    """
    https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish
    Loads a text file in the Kaggle Google text normalization file format:
    <semiotic class>\t<unnormalized text>\t<`self` if trivial class, otherwise normalized text>
    E.g.
    PLAIN Brillantaisia <self>
    PLAIN is <self>
    PLAIN a <self>
    PLAIN genus <self>
    PLAIN of <self>
    PLAIN plant <self>
    PLAIN in <self>
    PLAIN family <self>
    PLAIN Acanthaceae <self>
    PUNCT . sil
    <eos> <eos>

    Args:
        file_path: file path to text file

    Returns: flat list of instances
    """
    res = []
    with open(file_path, "r") as fp:
        for line in fp:
            parts = line.strip().split("\t")
            if parts[0] == "<eos>":
                res.append(Instance(token_type=EOS_TYPE, un_normalized="", normalized=""))
            else:
                l_type, l_token, l_normalized = parts
                l_token = l_token.lower()
                l_normalized = l_normalized.lower()

                if l_type == PLAIN_TYPE:
                    # trivial class: the token normalizes to itself
                    res.append(Instance(token_type=l_type, un_normalized=l_token, normalized=l_token))
                elif l_type != PUNCT_TYPE:
                    # punctuation tokens are dropped; all other classes keep their normalization
                    res.append(Instance(token_type=l_type, un_normalized=l_token, normalized=l_normalized))
    return res

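# Illustrative usage (a sketch, not part of the original module; the file name
# "train.tsv" is hypothetical): load a Kaggle-format file into a flat Instance list.
#
#     instances = _load_kaggle_text_norm_file("train.tsv")
#     instances[0]
#     # e.g. Instance(token_type='PLAIN', un_normalized='brillantaisia', normalized='brillantaisia')
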
def load_files(file_paths: List[str], load_func=_load_kaggle_text_norm_file) -> List[Instance]:
    """
    Loads the given list of text files using the `load_func` function.

    Args:
        file_paths: list of file paths
        load_func: loading function

    Returns: flat list of instances
    """
    res = []
    for file_path in file_paths:
        res.extend(load_func(file_path=file_path))
    return res


def clean_generic(text: str) -> str:
    """
    Cleans text without affecting semiotic classes.

    Args:
        text: string

    Returns: cleaned string
    """
    text = text.strip()
    text = text.lower()
    return text


def evaluate(
    preds: List[str], labels: List[str], input: Optional[List[str]] = None, verbose: bool = True
) -> float:
    """
    Evaluates accuracy given predictions and labels.

    Args:
        preds: predictions
        labels: labels
        input: optional, only needed for verbosity
        verbose: if True, prints [input], gold labels, and predictions for mismatches

    Returns: accuracy
    """
    acc = 0
    nums = len(preds)
    for i in range(nums):
        pred_norm = clean_generic(preds[i])
        label_norm = clean_generic(labels[i])
        if pred_norm == label_norm:
            acc += 1
        elif verbose:
            if input:
                print(f"inpu: {json.dumps(input[i])}")
            print(f"gold: {json.dumps(label_norm)}")
            print(f"pred: {json.dumps(pred_norm)}")
    return acc / nums

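# Illustrative usage (a sketch): accuracy counts exact matches after
# `clean_generic` lower-cases and strips both sides, so case and surrounding
# whitespace differences do not count as errors.
#
#     evaluate(preds=["Twelve ", "13"], labels=["twelve", "thirteen"], verbose=False)
#     # -> 0.5  ("Twelve " matches "twelve"; "13" does not match "thirteen")
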
def training_data_to_tokens(
    data: List[Instance], category: Optional[str] = None
) -> Dict[str, Tuple[List[str], List[str]]]:
    """
    Filters the instance list by category, if provided, and converts it into a map
    from token type to lists of un_normalized and normalized strings.

    Args:
        data: list of instances
        category: optional semiotic class category name

    Returns: Dict: token type -> (list of un_normalized strings, list of normalized strings)
    """
    result = defaultdict(lambda: ([], []))
    for instance in data:
        if instance.token_type != EOS_TYPE:
            if category is None or instance.token_type == category:
                result[instance.token_type][0].append(instance.un_normalized)
                result[instance.token_type][1].append(instance.normalized)
    return result

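# Illustrative usage (a sketch with made-up instances): EOS markers are dropped,
# and each token type maps to parallel un_normalized/normalized lists.
#
#     data = [
#         Instance("PLAIN", "genus", "genus"),
#         Instance("CARDINAL", "12", "twelve"),
#         Instance(EOS_TYPE, "", ""),
#     ]
#     dict(training_data_to_tokens(data))
#     # -> {'PLAIN': (['genus'], ['genus']), 'CARDINAL': (['12'], ['twelve'])}
#     dict(training_data_to_tokens(data, category="CARDINAL"))
#     # -> {'CARDINAL': (['12'], ['twelve'])}
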
def training_data_to_sentences(data: List[Instance]) -> Tuple[List[str], List[str], List[Set[str]]]:
    """
    Takes an instance list and creates lists of sentences split at EOS tokens.

    Args:
        data: list of instances

    Returns: (list of unnormalized sentences, list of normalized sentences, list of sets of categories in a sentence)
    """
    # split data at EOS boundaries
    sentences = []
    sentence = []
    categories = []
    sentence_categories = set()

    for instance in data:
        if instance.token_type == EOS_TYPE:
            sentences.append(sentence)
            sentence = []
            categories.append(sentence_categories)
            sentence_categories = set()
        else:
            sentence.append(instance)
            sentence_categories.add(instance.token_type)
    un_normalized = [" ".join([instance.un_normalized for instance in sentence]) for sentence in sentences]
    normalized = [" ".join([instance.normalized for instance in sentence]) for sentence in sentences]
    return un_normalized, normalized, categories

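# Illustrative usage (a sketch): instances are grouped into one sentence per
# EOS marker, and the third return value records which semiotic classes
# occurred in each sentence.
#
#     data = [
#         Instance("PLAIN", "on", "on"),
#         Instance("DATE", "2 may 2013", "the second of may twenty thirteen"),
#         Instance(EOS_TYPE, "", ""),
#     ]
#     training_data_to_sentences(data)
#     # -> (['on 2 may 2013'], ['on the second of may twenty thirteen'], [{'PLAIN', 'DATE'}])
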
def post_process_punctuation(text: str) -> str:
    """
    Normalizes quotes and spaces.

    Args:
        text: text

    Returns: text with normalized spaces and quotes
    """
    text = (
        text.replace("( ", "(")
        .replace(" )", ")")
        .replace("{ ", "{")
        .replace(" }", "}")
        .replace("[ ", "[")
        .replace(" ]", "]")
        .replace("  ", " ")
        .replace("”", '"')
        .replace("’", "'")
        .replace("»", '"')
        .replace("«", '"')
        .replace("\\", "")
        .replace("„", '"')
        .replace("´", "'")
        .replace("“", '"')
        .replace("‘", "'")
        .replace("`", "'")
        .replace("- -", "--")
    )

    for punct in "!,.:;?":
        text = text.replace(f" {punct}", punct)
    return text.strip()

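# Illustrative usage (a sketch): unicode quotes are mapped to their ASCII
# counterparts and detokenizer-style spaces before punctuation are removed.
#
#     post_process_punctuation('she said “hello” , twice !')
#     # -> 'she said "hello", twice!'
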
def pre_process(text: str) -> str:
    """
    Optional text preprocessing before normalization (part of the TTS TN pipeline).

    Args:
        text: string that may include semiotic classes

    Returns: text with spaces around square brackets and collapsed whitespace
    """
    space_both = "[]"
    for punct in space_both:
        text = text.replace(punct, " " + punct + " ")

    # remove extra space
    text = re.sub(r" +", " ", text)
    return text

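# Illustrative usage (a sketch): only square brackets are padded with spaces,
# and runs of spaces are then collapsed.
#
#     pre_process("a[b]c")
#     # -> 'a [ b ] c'
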
def load_file(file_path: str) -> List[str]:
    """
    Loads a text file into a list of strings, one entry per line.

    Args:
        file_path: file path

    Returns: flat list of strings (trailing newlines are kept)
    """
    res = []
    with open(file_path, "r") as fp:
        for line in fp:
            res.append(line)
    return res


def write_file(file_path: str, data: List[str]):
    """
    Writes a list of strings to file, one entry per line.

    Args:
        file_path: file path
        data: list of strings
    """
    with open(file_path, "w") as fp:
        for line in data:
            fp.write(line + "\n")


def post_process_punct(input: str, normalized_text: str, add_unicode_punct: bool = False) -> str:
    """
    Post-processing of the normalized output to match the input in terms of spaces around punctuation marks.
    After NN normalization, Moses detokenization puts a space after punctuation marks and attaches an
    opening quote "'" to the word to its right. E.g., the input to the TN NN model is "12 test' example";
    after normalization and detokenization it becomes "twelve test 'example" (the quote is treated as an
    opening quote, which doesn't match the input and can cause issues during TTS voice generation).
    This function matches the punctuation and spaces of the normalized text to the input sequence:
    "12 test' example" -> "twelve test 'example" -> "twelve test' example" (the quote is shifted to match the input).

    Args:
        input: input text (original input to the NN, before normalization or tokenization)
        normalized_text: output text (output of the TN NN model)
        add_unicode_punct: set to True to handle unicode punctuation marks in addition to the default
            string.punctuation (increases post-processing time)

    Returns: normalized text with spaces around punctuation adjusted to match the input
    """
    # In the post-processing WFST graph "``" is replaced with '"' (otherwise single "`" quotes
    # won't be handled correctly). This function fixes spaces around quotes based on the input
    # sequence, so the same double-quote replacement is applied to the input here to keep the two in sync.
    if "``" in input and "``" not in normalized_text:
        input = input.replace("``", '"')
    # operate on character lists so spaces can be inserted or deleted in place
    input = [x for x in input]
    normalized_text = [x for x in normalized_text]
    punct_marks = [x for x in string.punctuation if x in input]

    if add_unicode_punct:
        punct_unicode = [
            chr(i)
            for i in range(sys.maxunicode)
            if category(chr(i)).startswith("P") and chr(i) not in punct_marks and chr(i) in input
        ]
        punct_marks = punct_marks + punct_unicode

    for punct in punct_marks:
        try:
            equal = True
            if input.count(punct) != normalized_text.count(punct):
                equal = False
            idx_in, idx_out = 0, 0
            while punct in input[idx_in:]:
                idx_out = normalized_text.index(punct, idx_out)
                idx_in = input.index(punct, idx_in)

                def _is_valid(idx_out, idx_in, normalized_text, input):
                    """Check if the previous or next word matches (for cases when punctuation marks are part
                    of a semiotic token, i.e. some punctuation can be missing in the normalized text)"""
                    return (
                        idx_out > 0 and idx_in > 0 and normalized_text[idx_out - 1] == input[idx_in - 1]
                    ) or (
                        idx_out < len(normalized_text) - 1
                        and idx_in < len(input) - 1
                        and normalized_text[idx_out + 1] == input[idx_in + 1]
                    )

                if not equal and not _is_valid(idx_out, idx_in, normalized_text, input):
                    idx_in += 1
                    continue
                if idx_in > 0 and idx_out > 0:
                    if normalized_text[idx_out - 1] == " " and input[idx_in - 1] != " ":
                        normalized_text[idx_out - 1] = ""
                    elif normalized_text[idx_out - 1] != " " and input[idx_in - 1] == " ":
                        normalized_text[idx_out - 1] += " "

                if idx_in < len(input) - 1 and idx_out < len(normalized_text) - 1:
                    if normalized_text[idx_out + 1] == " " and input[idx_in + 1] != " ":
                        normalized_text[idx_out + 1] = ""
                    elif normalized_text[idx_out + 1] != " " and input[idx_in + 1] == " ":
                        normalized_text[idx_out] = normalized_text[idx_out] + " "
                idx_out += 1
                idx_in += 1
        except ValueError:
            # .index() failed: the punctuation mark is missing from the normalized text
            logging.debug(f"Skipping post-processing of {''.join(normalized_text)} for '{punct}'")

    normalized_text = "".join(normalized_text)
    return re.sub(r" +", " ", normalized_text)
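
# Illustrative usage (a sketch of the docstring example): the apostrophe that
# detokenization attached to the following word is shifted back so the spacing
# matches the raw input.
#
#     post_process_punct(input="12 test' example", normalized_text="twelve test 'example")
#     # -> "twelve test' example"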