import logging

import kaldiio
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


def custom_collate(batch):
    """Collate variable-length diarization samples into lists of tensors."""
    keys, speech, speaker_labels, orders = zip(*batch)
    # Copy the kaldiio-backed arrays so each tensor owns its own memory.
    speech = [torch.from_numpy(np.copy(sph)).to(torch.float32) for sph in speech]
    speaker_labels = [torch.from_numpy(np.copy(spk)).to(torch.float32) for spk in speaker_labels]
    orders = [torch.from_numpy(np.copy(o)).to(torch.int64) for o in orders]
    batch = dict(speech=speech, speaker_labels=speaker_labels, orders=orders)
    return keys, batch


class EENDOLADataset(Dataset):
    """Dataset for EEND-OLA training.

    Each line of ``data_file`` holds three whitespace-separated fields:
    an utterance key, a kaldiio path to the speech features, and a kaldiio
    path to the frame-level speaker labels.
    """

    def __init__(self, data_file):
        self.data_file = data_file
        with open(data_file) as f:
            lines = f.readlines()
        self.samples = [line.strip().split() for line in lines]
        logging.info("total samples: %d", len(self.samples))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        key, speech_path, speaker_label_path = self.samples[idx]
        speech = kaldiio.load_mat(speech_path)
        # Reshape the flat label matrix to (num_frames, num_speakers).
        speaker_label = kaldiio.load_mat(speaker_label_path).reshape(speech.shape[0], -1)
        # Random permutation of the frame indices for this utterance.
        order = np.arange(speech.shape[0])
        np.random.shuffle(order)
        return key, speech, speaker_label, order


class EENDOLADataLoader:
    """Thin wrapper that builds a PyTorch DataLoader over EENDOLADataset."""

    def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
        dataset = EENDOLADataset(data_file)
        self.data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=custom_collate,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    def build_iter(self, epoch):
        # The epoch index is accepted for interface compatibility; the
        # underlying DataLoader is returned unchanged.
        return self.data_loader
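

if __name__ == "__main__":
    # Minimal usage sketch, assuming a hypothetical "data.list" whose lines
    # follow the "<key> <speech_path> <speaker_label_path>" format expected
    # by EENDOLADataset, with paths readable by kaldiio.load_mat.
    loader = EENDOLADataLoader("data.list", batch_size=8, shuffle=True, num_workers=0)
    keys, batch = next(iter(loader.build_iter(epoch=0)))
    # speech, speaker_labels, and orders are lists of per-utterance tensors of
    # varying length; custom_collate applies no padding.
    print(keys[0], batch["speech"][0].shape, batch["speaker_labels"][0].shape)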