Skip to content
Snippets Groups Projects
lagrangian_datatools.py 2.34 KiB
Newer Older
  • Learn to ignore specific revisions
  • import numpy as np
    import torch
    from torch.utils.data import Dataset
    from glob import glob
    
    
    
    class StandardScaler:
        """Standardize the data with a fixed mean and standard deviation.

        Calling an instance maps ``sample`` to ``(sample - mean) / std``; any
        operand types that support ``-`` and ``/`` with ``mean``/``std``
        (scalars, NumPy arrays, tensors) work, broadcasting as those types do.
        """

        def __init__(self, mean, std):
            self.mean, self.std = mean, std

        def __call__(self, sample):
            centred = sample - self.mean
            return centred / self.std
        
    
        def __init__(self, path, transform=None):
            """Memory-map a ``.npy`` array of trajectories.

            Args:
                path: Path to the ``.npy`` file; passed unchanged to ``np.load``.
                transform: Optional callable applied to every item returned by
                    ``__getitem__``; ``None`` (or any falsy value) disables it.
            """
            # Both attributes were read (below and in __getitem__) but never
            # assigned, so construction failed with AttributeError.
            self.npy_filepath = path
            self.transform = transform
            # mmap_mode='r+' maps the existing file for reading and writing instead
            # of loading it into memory; it needs write permission on the file.
            # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
            # which can execute arbitrary code -- only load trusted files.
            self.data = np.load(self.npy_filepath, encoding="ASCII", allow_pickle=True, mmap_mode='r+')
    
    
        def __len__(self):
            return self.data.shape[0]
    
        def __getitem__(self, idx):
            if torch.is_tensor(idx):
                idx = idx.tolist()
    
            part_traj = self.data[idx, :, :]
            if self.transform:
                part_traj = self.transform(part_traj)
            
    
    # This class may be unnecessary: a single class could handle both cases
    # by taking the coordinate to extract as an argument. It is kept separate
    # for now.
    
        def __init__(self, path, transform):
    
            self.data = np.load(self.npy_filepath, encoding="ASCII", allow_pickle=True, mmap_mode='r+')[:,:,0]
    
    
        def __len__(self):
            return self.data.shape[0]
    
        def __getitem__(self, idx):
            if torch.is_tensor(idx):
                idx = idx.tolist()
    
            part_traj = self.data[idx, :]
            if self.transform:
                part_traj = self.transform(part_traj)
    
    
            return part_traj
    
    # This class could be dropped if the train() method in learner.py were
    # modified; that would likely require a new parameter in the model class
    # indicating whether we are training on trajectories. For now, a
    # command-line argument carries that information.
    
    class CollatorForDiffwave:
        """Batch individual ``{'audio': ...}`` samples into one DiffWave batch.

        Only each record's ``'audio'`` entry is used; the batch's
        ``'spectrogram'`` slot is always ``None``.
        """

        def __init__(self):
            """No configuration is needed."""

        def collate(self, minibatch):
            """Stack every record's ``'audio'`` along a new leading axis.

            Args:
                minibatch: Sequence of dicts, each holding an ``'audio'``
                    array-like accepted by ``np.stack``.

            Returns:
                ``{'audio': tensor of shape (len(minibatch), ...), 'spectrogram': None}``.
            """
            stacked = np.stack([sample['audio'] for sample in minibatch])
            batch = {'audio': torch.from_numpy(stacked), 'spectrogram': None}
            return batch
        
    
    class ToDiffwaveTensor:
        """Wrap a NumPy trajectory as a float32 tensor in DiffWave's sample format."""

        def __call__(self, sample):
            """Return ``{'audio': float32 tensor of sample, 'spectrogram': None}``."""
            audio = torch.from_numpy(sample).float()
            return {'audio': audio, 'spectrogram': None}