from typing import Collection, Dict, List, Tuple, Union

import numpy as np
import torch

from funasr.models.transformer.utils.nets_utils import pad_list, pad_list_all_dim


class CommonCollateFn:
    """Functor class of common_collate_fn()"""

    def __init__(
        self,
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
        max_sample_size=None,
    ):
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        # Stored but unused here; kept for interface compatibility with the
        # clipping collate function below.
        self.max_sample_size = max_sample_size

    def __repr__(self):
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value})"
        )

    def __call__(
        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        return common_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
        )


def common_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Concatenate ndarray-list to an array and convert to torch.Tensor."""
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in data[0]
    ), f"*_lengths is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        # Integer arrays (e.g. token ids) use int_pad_value so padding cannot
        # collide with a real id; everything else uses float_pad_value.
        if data[0][key].dtype.kind == "i":
            pad_value = int_pad_value
        else:
            pad_value = float_pad_value

        array_list = [d[key] for d in data]
        tensor_list = [torch.from_numpy(a) for a in array_list]
        # Pad along the first (time) dimension to the longest sample in the batch.
        tensor = pad_list(tensor_list, pad_value)
        output[key] = tensor

        if key not in not_sequence:
            lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
            output[key + "_lengths"] = lens

    output = (uttids, output)
    return output
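
# Usage sketch (illustrative only; the keys and shapes below are assumptions
# for demonstration, not fixed by the API):
# >>> batch = [
# ...     ("utt1", {"speech": np.zeros((7, 80), dtype=np.float32),
# ...               "text": np.array([1, 2, 3], dtype=np.int64)}),
# ...     ("utt2", {"speech": np.zeros((5, 80), dtype=np.float32),
# ...               "text": np.array([4, 5], dtype=np.int64)}),
# ... ]
# >>> uttids, out = common_collate_fn(batch)   # or CommonCollateFn()(batch)
# >>> out["speech"].shape       # torch.Size([2, 7, 80]); float-padded with 0.0
# >>> out["text"]               # tensor([[1, 2, 3], [4, 5, -32768]])
# >>> out["speech_lengths"]     # tensor([7, 5])

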
class DiarCollateFn:
    """Functor class of diar_collate_fn()"""

    def __init__(
        self,
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
        max_sample_size=None,
    ):
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        self.max_sample_size = max_sample_size

    def __repr__(self):
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value})"
        )

    def __call__(
        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        return diar_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
        )


def diar_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Concatenate ndarray-list to an array and convert to torch.Tensor.

    Differs from common_collate_fn() only in using pad_list_all_dim, which
    (as the name suggests) pads every dimension rather than only the first,
    e.g. for diarization labels that vary in both length and speaker count.
    """
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in data[0]
    ), f"*_lengths is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        if data[0][key].dtype.kind == "i":
            pad_value = int_pad_value
        else:
            pad_value = float_pad_value

        array_list = [d[key] for d in data]
        tensor_list = [torch.from_numpy(a) for a in array_list]
        # Pad all dimensions (not just the first) to the batch maximum.
        tensor = pad_list_all_dim(tensor_list, pad_value)
        output[key] = tensor

        if key not in not_sequence:
            lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
            output[key + "_lengths"] = lens

    output = (uttids, output)
    return output
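
# Illustrative difference from common_collate_fn (assuming pad_list_all_dim
# pads every dimension to the per-batch maximum, as its name suggests):
# >>> batch = [("a", {"label": np.zeros((10, 2), dtype=np.int64)}),
# ...          ("b", {"label": np.zeros((8, 3), dtype=np.int64)})]
# >>> _, out = diar_collate_fn(batch)
# >>> out["label"].shape        # torch.Size([2, 10, 3]); both dims padded
# >>> out["label_lengths"]      # tensor([10, 8]); lengths track dim 0 only

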
def crop_to_max_size(feature, target_size):
    """Randomly crop `feature` along its first axis to `target_size` frames.

    Inputs already no longer than `target_size` are returned unchanged.
    Works on anything sliceable with a first-axis length (ndarray or Tensor).
    """
    size = len(feature)
    diff = size - target_size
    if diff <= 0:
        return feature

    # Pick a random start offset; end - start == target_size.
    start = np.random.randint(0, diff + 1)
    end = size - diff + start
    return feature[start:end]
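
# Example (illustrative): a (6, 2) feature cropped to 4 frames yields a random
# contiguous (4, 2) window; shorter inputs pass through untouched.
# >>> crop_to_max_size(torch.arange(12).reshape(6, 2), 4).shape   # (4, 2)
# >>> crop_to_max_size(torch.arange(6).reshape(3, 2), 4).shape    # (3, 2)

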
def clipping_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    max_sample_size=None,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Crop all samples to a common size and batch them (mainly for pre-training).

    Instead of padding up to the longest sample, every sample is randomly
    cropped down to the shortest one in the batch (capped by max_sample_size).
    """
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in data[0]
    ), f"*_lengths is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        array_list = [d[key] for d in data]
        tensor_list = [torch.from_numpy(a) for a in array_list]
        sizes = [len(s) for s in tensor_list]
        if max_sample_size is None:
            target_size = min(sizes)
        else:
            target_size = min(min(sizes), max_sample_size)
        # Note: assumes 2-D features of shape (time, feat_dim).
        tensor = tensor_list[0].new_zeros(
            len(tensor_list), target_size, tensor_list[0].shape[1]
        )
        for i, (source, size) in enumerate(zip(tensor_list, sizes)):
            diff = size - target_size
            if diff == 0:
                tensor[i] = source
            else:
                tensor[i] = crop_to_max_size(source, target_size)
        output[key] = tensor

        if key not in not_sequence:
            # After cropping, every row of `tensor` has exactly target_size
            # frames, so all lengths equal target_size.
            lens = torch.tensor([source.shape[0] for source in tensor], dtype=torch.long)
            output[key + "_lengths"] = lens

    output = (uttids, output)
    return output
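
# Usage sketch (illustrative; keys and shapes are assumptions for demonstration):
# >>> batch = [("a", {"speech": np.random.randn(9, 80).astype(np.float32)}),
# ...          ("b", {"speech": np.random.randn(7, 80).astype(np.float32)})]
# >>> _, out = clipping_collate_fn(batch, max_sample_size=8)
# >>> out["speech"].shape       # torch.Size([2, 7, 80]); cropped, not padded
# >>> out["speech_lengths"]     # tensor([7, 7])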