# FunASR/funasr/datasets/audio_datasets/espnet_samplers.py

import logging
import math
import random

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import BatchSampler, Dataset, DistributedSampler, Sampler

from funasr.register import tables


@tables.register("batch_sampler_classes", "EspnetStyleBatchSampler")
def EspnetStyleBatchSampler_fn(dataset, **kwargs):
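    """Build DataLoader kwargs with an EspnetStyleBatchSampler attached.

    Returns a dict of arguments (batch_sampler, num_workers, pin_memory)
    intended to be splatted into a torch DataLoader constructor.
    """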
    dataloader_args = {}
    batch_sampler = EspnetStyleBatchSampler(dataset, **kwargs)
    dataloader_args["batch_sampler"] = batch_sampler
    dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
    dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
    return dataloader_args


class EspnetStyleBatchSampler(DistributedSampler):
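    """Length-aware batch sampler in the ESPnet style.

    Indices are sorted by source length and packed greedily so that the padded
    batch size (max sample length in the batch times the number of samples)
    stays within ``batch_size``, counted in tokens or examples depending on
    ``batch_type``. The resulting list of batches is then split evenly across
    distributed ranks.
    """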
    def __init__(
        self,
        dataset,
        batch_size,
        batch_type="token",
        rank=None,
        num_replicas=None,
        rank_split=False,
        shuffle=True,
        drop_last=False,
        is_training: bool = True,
        sort_size: int = 1024,
        start_step: int = 0,
        **kwargs,
    ):
        # Infer rank / world size from torch.distributed; fall back to a
        # single-process setup when no process group is initialized.
        try:
            rank = dist.get_rank()
            num_replicas = dist.get_world_size()
        except Exception:
            rank = 0
            num_replicas = 1
        # if rank_split:
        #     logging.info(f"Warning, rank_split: {rank_split}, batch and shuffle data in local rank")
        #     rank = 0
        #     num_replicas = 1
        self.rank = rank
        self.num_replicas = num_replicas
        self.dataset = dataset
        self.batch_size = batch_size
        self.batch_type = batch_type
        self.is_training = is_training
        self.shuffle = shuffle and is_training
        self.drop_last = drop_last
        self.total_size = len(self.dataset)
        self.num_samples = int(math.ceil(self.total_size / self.num_replicas))
        self.epoch = 0
        self.sort_size = sort_size * num_replicas
        self.max_token_length = kwargs.get("max_token_length", 2048)
        self.min_token_length = kwargs.get("min_token_length", 0)
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
        self.start_step = start_step
        if self.start_step > 0:
            logging.info(f"Warning, start_step > 0, dataloader start from step: {self.start_step}")
        # super().__init__(dataset, num_replicas=num_replicas, rank=rank,
        #                  shuffle=shuffle, drop_last=drop_last)

    def __iter__(self):
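        """Sort, pack, and shard: return an iterator over index lists, one per
        batch, for the current rank and epoch."""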
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            random.seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Sort indices by sample length
        sorted_indices = sorted(indices, key=lambda idx: self.dataset.get_source_len(idx))

        # Organize batches based on the batch type ('token' or 'example')
        buffer_batches = []
        batch = []
        max_len_in_batch = 0  # Tracks the max sample length within the current batch
        for idx in sorted_indices:
            # original_sample_length = self.dataset.get_source_len(idx)
            # if (
            #     original_sample_length < self.min_token_length
            #     or original_sample_length > self.max_token_length
            # ):  # Skip samples that exceed the max length
            #     continue
            # sample_length = 1 if self.batch_type == "example" else original_sample_length

            # Set sample_length based on the batch type
            if self.batch_type == "example":
                sample_length = 1
            elif self.batch_type == "token":
                sample_length = self.dataset.get_source_len(idx) + int(
                    self.dataset.get_target_len(idx) * 1.2
                )
            else:
                sample_length = self.dataset.get_source_len(idx)

            # Calculate the potential padded batch size with the new sample added
            potential_batch_length = max(max_len_in_batch, sample_length) * (len(batch) + 1)
            # Add index to batch if it doesn't exceed the batch size limit
            if potential_batch_length <= self.batch_size:
                batch.append(idx)
                max_len_in_batch = max(max_len_in_batch, sample_length)
            else:
                # Save the current batch and start a new one
                buffer_batches.append(batch)
                batch = [idx]
                max_len_in_batch = sample_length

        # Add the last batch if it shouldn't be dropped
        if batch and (not self.drop_last or len(batch) * max_len_in_batch == self.batch_size):
            buffer_batches.append(batch)

        # Shuffle the list of batches
        if self.shuffle:
            random.seed(self.epoch)
            random.shuffle(buffer_batches)

        # Ensure each rank gets the same number of batches
        batches_per_rank = int(math.ceil(len(buffer_batches) / self.num_replicas))
        total_batches_needed = batches_per_rank * self.num_replicas
        extra_batches = total_batches_needed - len(buffer_batches)
        # Pad with extra batches chosen at random, if needed
        buffer_batches += random.choices(buffer_batches, k=extra_batches)

        # Allocate the batches to the current rank, skipping the first start_step batches
        start_idx = self.rank * batches_per_rank
        end_idx = start_idx + batches_per_rank
        rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
        logging.info(
            f"rank: {self.rank}, dataloader start from step: {self.start_step}, "
            f"batch_num: {end_idx - start_idx}, batch_num_after_step: {len(rank_batches)}"
        )

        # Return an iterator over the batches for the current rank
        return iter(rank_batches)

    def __len__(self):
        # The number of batches per rank is only determined when __iter__ packs
        # them; a constant placeholder is returned here.
        return 1

    def set_epoch(self, epoch):
        # Set the epoch used to seed shuffling
        self.epoch = epoch
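

# A minimal usage sketch, not part of the FunASR code base: it assumes only the
# interface the sampler relies on above (a dataset exposing __len__,
# get_source_len(idx), and get_target_len(idx)). ToyDataset and the lengths
# below are illustrative, hypothetical names.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    class ToyDataset:
        """Toy dataset whose items are simply their source lengths."""

        def __init__(self, lengths):
            self.lengths = lengths

        def __len__(self):
            return len(self.lengths)

        def __getitem__(self, idx):
            return self.lengths[idx]

        def get_source_len(self, idx):
            return self.lengths[idx]

        def get_target_len(self, idx):
            return 0

    toy = ToyDataset([120, 80, 400, 300, 50, 220])
    # batch_size acts as a token budget because batch_type="token"
    args = EspnetStyleBatchSampler_fn(toy, batch_size=600, batch_type="token", num_workers=0)
    loader = DataLoader(toy, **args)
    for lengths in loader:
        print(lengths)  # each printed tensor is one length-grouped batch of source lengths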