# Datasets, transforms and collate helpers for particle-trajectory data.
import numpy as np
import torch
from torch.utils.data import Dataset
from glob import glob
class StandardScaler(object):
    """Callable transform computing ``(sample - mean) / std``."""

    def __init__(self, mean, std):
        """Store the statistics used for standardization.

        Args:
            mean: Value subtracted from every sample.
            std: Value every centered sample is divided by.
        """
        self.mean, self.std = mean, std

    def __call__(self, sample):
        """Return ``sample`` centered by ``mean`` and scaled by ``std``."""
        centered = sample - self.mean
        return centered / self.std
class ScaleDiffusionRange(object):
    """Callable transform applying ``x -> x * 2 - 1``.

    Values in [0, 1] are mapped onto [-1, 1]. Inputs are not clipped, so
    values outside [0, 1] land outside [-1, 1].
    """

    def __init__(self):
        """Nothing to configure; the transform keeps no state."""

    def __call__(self, sample):
        """Return ``sample`` doubled and shifted down by one."""
        doubled = sample * 2
        return doubled - 1
class ParticleDataset(Dataset):
    """Map-style dataset over trajectories stored in a single ``.npy`` file.

    The array is memory-mapped rather than read into RAM. Item ``idx`` is
    ``data[idx, :, :]``, so the stored array needs at least three dimensions
    and the first one indexes the samples.
    """

    def __init__(self, path, transform=None, mmap_mode='c'):
        """Open the array file.

        Args:
            path: Location of the ``.npy`` file.
            transform: Optional callable applied to every item before it is
                returned.
            mmap_mode: Memory-map mode forwarded to ``np.load``. The default
                ``'c'`` (copy-on-write) keeps returned arrays writable but
                never writes back to disk and does not need write permission
                on the file. Pass ``'r+'`` to get write-through behaviour.
        """
        self.npy_filepath = path
        self.transform = transform
        # A read-write mapping ('r+') would let any in-place operation on a
        # returned item -- or on a tensor sharing its memory through
        # torch.from_numpy -- silently modify the dataset file.
        # SECURITY NOTE(review): allow_pickle=True permits np.load to unpickle
        # file contents; only point this at trusted files.
        self.data = np.load(
            self.npy_filepath,
            encoding="ASCII",
            allow_pickle=True,
            mmap_mode=mmap_mode,
        )

    def __len__(self):
        """Return the number of samples (size of the first axis)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return sample ``idx``, passed through ``transform`` when one is set."""
        # Samplers may hand over tensor indices; NumPy needs plain ints/lists.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        part_traj = self.data[idx, :, :]
        if self.transform:
            part_traj = self.transform(part_traj)
        return part_traj
# TODO: this class may be unnecessary. A single dataset class that receives
# the coordinate as an argument could cover both cases, but the two are
# kept separate for now.
class ParticleDatasetVx(Dataset):
    """Map-style dataset exposing one coordinate of each stored trajectory.

    The ``.npy`` file is memory-mapped and sliced as ``[:, :, coordinate]``;
    item ``idx`` is row ``idx`` of that slice. With the default
    ``coordinate=0`` this is the first component of the last axis
    (presumably the x velocity, going by the class name -- confirm against
    the data layout).
    """

    def __init__(self, path, transform=None, coordinate=0, mmap_mode='c'):
        """Open the array file and select one coordinate.

        Args:
            path: Location of the ``.npy`` file.
            transform: Optional callable applied to every item before it is
                returned.
            coordinate: Index along the third axis to keep.
            mmap_mode: Memory-map mode forwarded to ``np.load``. The default
                ``'c'`` (copy-on-write) keeps returned arrays writable but
                never writes back to disk and does not need write permission
                on the file. Pass ``'r+'`` to get write-through behaviour.
        """
        self.npy_filepath = path
        self.transform = transform
        # A read-write mapping ('r+') would let any in-place operation on a
        # returned item -- or on a tensor sharing its memory through
        # torch.from_numpy -- silently modify the dataset file.
        # SECURITY NOTE(review): allow_pickle=True permits np.load to unpickle
        # file contents; only point this at trusted files.
        trajectories = np.load(
            self.npy_filepath,
            encoding="ASCII",
            allow_pickle=True,
            mmap_mode=mmap_mode,
        )
        self.data = trajectories[:, :, coordinate]

    def __len__(self):
        """Return the number of samples (size of the first axis)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return sample ``idx``, passed through ``transform`` when one is set."""
        # Samplers may hand over tensor indices; NumPy needs plain ints/lists.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        part_traj = self.data[idx, :]
        if self.transform:
            part_traj = self.transform(part_traj)
        return part_traj
# NOTE: This class could be dropped if the train() method in learner.py
# were modified. That would probably require a new parameter on the model
# class indicating whether we are training on trajectories; at the moment
# a command-line argument conveys this.
class CollatorForDiffwave:
    """Collate per-sample records into a single batch dictionary."""

    def __init__(self):
        """Nothing to configure; the collator keeps no state."""

    def collate(self, minibatch):
        """Stack the ``'audio'`` entry of every record along a new first axis.

        Args:
            minibatch: Sequence of mappings, each holding an ``'audio'`` entry
                that ``np.stack`` can combine with the others.

        Returns:
            A dict with the stacked batch as a tensor under ``'audio'`` and
            ``None`` under ``'spectrogram'``.
        """
        per_sample = [item['audio'] for item in minibatch]
        stacked = np.stack(per_sample)
        batch = {
            'audio': torch.from_numpy(stacked),
            'spectrogram': None,
        }
        return batch
class ToDiffwaveTensor(object):
    """Wrap an ndarray sample in the record dictionary used for batching."""

    def __call__(self, sample):
        """Convert ``sample`` to a float32 tensor and package it.

        Returns:
            A dict with the tensor under ``'audio'`` and ``None`` under
            ``'spectrogram'``.
        """
        as_tensor = torch.from_numpy(sample)
        record = {
            'audio': as_tensor.float(),
            'spectrogram': None
        }
        return record