# FunASR/funasr/datasets/large_datasets/collate_fn.py
#
# 195 lines
# 6.0 KiB
# Python

from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
import numpy as np
import torch
from funasr.models.transformer.utils.nets_utils import pad_list, pad_list_all_dim
class CommonCollateFn:
    """Functor class of common_collate_fn().

    Keeps the padding configuration on the instance so that the object can be
    called with a batch as its only argument.

    Args:
        float_pad_value: Padding value used for arrays whose dtype is not a
            signed integer.
        int_pad_value: Padding value used for signed-integer arrays.
        not_sequence: Keys for which no ``<key>_lengths`` entry is produced.
        max_sample_size: Stored on the instance only; it is not forwarded to
            common_collate_fn() by this class.
    """

    def __init__(
        self,
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
        max_sample_size=None,
    ):
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        # Stored as a set for cheap membership tests during collation.
        self.not_sequence = set(not_sequence)
        self.max_sample_size = max_sample_size

    def __repr__(self):
        # Report the integer pad value itself (it used to echo the float one).
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value})"
        )

    def __call__(
        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        """Collate ``data`` with common_collate_fn() using the stored settings."""
        return common_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
        )
def common_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Collate ``(uttid, {key: ndarray})`` samples into padded batch tensors.

    For every key, the arrays of all samples are converted to tensors and
    combined with ``pad_list``. Signed-integer arrays are padded with
    ``int_pad_value``; everything else is padded with ``float_pad_value``.
    Unless the key is listed in ``not_sequence``, a ``<key>_lengths`` tensor
    holding each sample's first-axis size is added as well.

    Args:
        data: Samples as ``(utterance id, dict of arrays)`` pairs. All dicts
            must share the same keys, and no key may end with ``_lengths``.
        float_pad_value: Padding value for non-integer arrays.
        int_pad_value: Padding value for signed-integer arrays.
        not_sequence: Keys that get no ``<key>_lengths`` entry.

    Returns:
        A tuple ``(uttids, batch)`` with the utterance ids in input order and
        the dict of collated tensors.

    Raises:
        AssertionError: If the samples have differing keys, or a key already
            ends with ``_lengths``.
    """
    uttids = [uttid for uttid, _ in data]
    samples = [sample for _, sample in data]

    reference = samples[0]
    assert all(set(reference) == set(sample) for sample in samples), "dict-keys mismatching"
    assert all(
        not name.endswith("_lengths") for name in reference
    ), f"*_lengths is reserved: {list(reference)}"

    batch = {}
    for key, first_array in reference.items():
        # The first sample's dtype decides which pad value the whole key uses.
        pad_value = int_pad_value if first_array.dtype.kind == "i" else float_pad_value

        tensors = [torch.from_numpy(sample[key]) for sample in samples]
        batch[key] = pad_list(tensors, pad_value)

        if key in not_sequence:
            continue
        # Original (unpadded) first-axis size of every sample.
        batch[key + "_lengths"] = torch.tensor(
            [sample[key].shape[0] for sample in samples], dtype=torch.long
        )

    return uttids, batch
class DiarCollateFn:
    """Functor class of diar_collate_fn().

    Keeps the padding configuration on the instance so that the object can be
    called with a batch as its only argument.

    Args:
        float_pad_value: Padding value used for arrays whose dtype is not a
            signed integer.
        int_pad_value: Padding value used for signed-integer arrays.
        not_sequence: Keys for which no ``<key>_lengths`` entry is produced.
        max_sample_size: Stored on the instance only; it is not forwarded to
            diar_collate_fn() by this class.
    """

    def __init__(
        self,
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
        max_sample_size=None,
    ):
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        # Stored as a set for cheap membership tests during collation.
        self.not_sequence = set(not_sequence)
        self.max_sample_size = max_sample_size

    def __repr__(self):
        # Report the integer pad value itself (it used to echo the float one).
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value})"
        )

    def __call__(
        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        """Collate ``data`` with diar_collate_fn() using the stored settings."""
        return diar_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
        )
def diar_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Collate ``(uttid, {key: ndarray})`` samples using ``pad_list_all_dim``.

    For every key, the arrays of all samples are converted to tensors and
    combined with ``pad_list_all_dim`` (presumably padding along every
    dimension - confirm against that helper). Signed-integer arrays are padded
    with ``int_pad_value``; everything else is padded with ``float_pad_value``.
    Unless the key is listed in ``not_sequence``, a ``<key>_lengths`` tensor
    holding each sample's first-axis size is added as well.

    Args:
        data: Samples as ``(utterance id, dict of arrays)`` pairs. All dicts
            must share the same keys, and no key may end with ``_lengths``.
        float_pad_value: Padding value for non-integer arrays.
        int_pad_value: Padding value for signed-integer arrays.
        not_sequence: Keys that get no ``<key>_lengths`` entry.

    Returns:
        A tuple ``(uttids, batch)`` with the utterance ids in input order and
        the dict of collated tensors.

    Raises:
        AssertionError: If the samples have differing keys, or a key already
            ends with ``_lengths``.
    """
    uttids = [uttid for uttid, _ in data]
    samples = [sample for _, sample in data]

    reference = samples[0]
    assert all(set(reference) == set(sample) for sample in samples), "dict-keys mismatching"
    assert all(
        not name.endswith("_lengths") for name in reference
    ), f"*_lengths is reserved: {list(reference)}"

    batch = {}
    for key, first_array in reference.items():
        # The first sample's dtype decides which pad value the whole key uses.
        pad_value = int_pad_value if first_array.dtype.kind == "i" else float_pad_value

        tensors = [torch.from_numpy(sample[key]) for sample in samples]
        batch[key] = pad_list_all_dim(tensors, pad_value)

        if key in not_sequence:
            continue
        # Original (unpadded) first-axis size of every sample.
        batch[key + "_lengths"] = torch.tensor(
            [sample[key].shape[0] for sample in samples], dtype=torch.long
        )

    return uttids, batch
def crop_to_max_size(feature, target_size):
    """Return a randomly placed window of ``target_size`` items of ``feature``.

    Args:
        feature: Any sliceable object supporting ``len()``.
        target_size: Number of leading-axis items to keep.

    Returns:
        ``feature`` itself when it is not longer than ``target_size``;
        otherwise the slice ``feature[start:start + target_size]`` where
        ``start`` is drawn uniformly with ``np.random.randint``.
    """
    excess = len(feature) - target_size
    if excess <= 0:
        # Already short enough: hand back the input untouched.
        return feature

    # randint's upper bound is exclusive, so +1 allows the last valid offset.
    offset = np.random.randint(0, excess + 1)
    return feature[offset : offset + target_size]
def clipping_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    max_sample_size=None,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Crop every sample to a common length and stack the results per key.

    Mainly for pre-training. For each key the target length is the shortest
    sample in the batch, additionally capped by ``max_sample_size`` when that
    is given. Longer samples are cropped with ``crop_to_max_size`` (random
    window); no padding is performed.

    Args:
        data: Samples as ``(utterance id, dict of arrays)`` pairs. All dicts
            must share the same keys, and no key may end with ``_lengths``.
            Arrays may have any number of trailing dimensions (1-D inputs are
            supported); the trailing shape is taken from the first sample.
        max_sample_size: Optional upper bound for the cropped length.
        not_sequence: Keys that get no ``<key>_lengths`` entry.

    Returns:
        A tuple ``(uttids, batch)``. ``batch[key]`` has shape
        ``(batch_size, target_size, *trailing_dims)`` and, unless the key is in
        ``not_sequence``, ``batch[key + "_lengths"]`` holds ``target_size`` for
        every sample.

    Raises:
        AssertionError: If the samples have differing keys, or a key already
            ends with ``_lengths``.
    """
    # mainly for pre-training
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in data[0]
    ), f"*_lengths is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        tensor_list = [torch.from_numpy(d[key]) for d in data]
        sizes = [len(t) for t in tensor_list]

        target_size = min(sizes)
        if max_sample_size is not None:
            target_size = min(target_size, max_sample_size)

        # Use the first sample's trailing dims instead of a hard-coded
        # ``shape[1]`` so 1-D inputs (e.g. raw waveforms) work as well.
        first = tensor_list[0]
        tensor = first.new_zeros((len(tensor_list), target_size, *first.shape[1:]))

        for i, (source, size) in enumerate(zip(tensor_list, sizes)):
            if size == target_size:
                tensor[i] = source
            else:
                tensor[i] = crop_to_max_size(source, target_size)
        output[key] = tensor

        if key not in not_sequence:
            # Every row was cropped to ``target_size``; build the lengths
            # directly (iterating rows and reading ``shape[0]`` breaks for 1-D).
            output[key + "_lengths"] = torch.full(
                (len(tensor_list),), target_size, dtype=torch.long
            )

    return uttids, output