Skip to content
Snippets Groups Projects
lagrangian_datatools.py 2.52 KiB
Newer Older
import numpy as np
import torch
from torch.utils.data import Dataset
from glob import glob


class StandardScaler(object):
    """Standardize a sample as ``(sample - mean) / std``.

    Args:
        mean: value subtracted from every sample.
        std: value each centred sample is divided by.

    ``mean`` and ``std`` are stored exactly as given and combined with the
    sample using plain ``-`` and ``/``, so any values supporting those
    operators against the sample can be used.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, sample):
        shifted = sample - self.mean
        return shifted / self.std
    
class ScaleDiffusionRange(object):
    """Map data linearly via ``2 * sample - 1``.

    This sends [0, 1] onto [-1, 1], so the input is expected to already
    lie in [0, 1]. Values outside that range are mapped with the same
    linear rule and will fall outside [-1, 1].
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        return (sample * 2) - 1


class ParticleDataset(Dataset):
    """Dataset of particle trajectories stored in a single ``.npy`` file.

    NOTE(review): the class header and attribute assignments for this
    dataset were missing in the reviewed copy, where its methods had been
    swallowed by ``ScaleDiffusionRange``. The class name is a
    reconstruction; confirm it against the callers (e.g. learner.py).

    Args:
        path: path to a ``.npy`` file whose array is indexed as
            ``data[idx, :, :]``. Presumably the layout is
            (n_particles, n_steps, n_components); TODO confirm.
        transform: optional callable applied to each sample before it is
            returned.
    """

    def __init__(self, path, transform=None):
        self.npy_filepath = path
        self.transform = transform
        # Memory-map so the whole array is not read into RAM up front.
        # 'r+' keeps returned slices writable. NOTE(review): an in-place
        # transform would therefore also modify the file on disk.
        self.data = np.load(self.npy_filepath, encoding="ASCII",
                            allow_pickle=True, mmap_mode='r+')

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        # Samplers may hand over a tensor index; NumPy needs plain ints/lists.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        part_traj = self.data[idx, :, :]
        if self.transform:
            part_traj = self.transform(part_traj)
        return part_traj


# Maybe this class is not needed: we could keep a single class for both
# cases and receive the coordinate as an argument, but for now it is kept
# as a separate class.
class ParticleDatasetVx(Dataset):
    """Like ``ParticleDataset``, but keeps only index 0 of the last axis.

    NOTE(review): the class header was missing in the reviewed copy, and
    the name is a reconstruction; confirm it against the callers.

    Args:
        path: path to a ``.npy`` file holding a 3-D array; only
            ``data[:, :, 0]`` is kept.
        transform: optional callable applied to each sample before it is
            returned. It now defaults to ``None``, consistent with
            ``ParticleDataset``.
    """

    def __init__(self, path, transform=None):
        self.npy_filepath = path
        self.transform = transform
        # Basic slicing keeps this a view onto the memory-mapped file.
        self.data = np.load(self.npy_filepath, encoding="ASCII",
                            allow_pickle=True, mmap_mode='r+')[:, :, 0]

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        # Samplers may hand over a tensor index; NumPy needs plain ints/lists.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        part_traj = self.data[idx, :]
        if self.transform:
            part_traj = self.transform(part_traj)

        return part_traj

# We could do without this class if we modified the train() method in
# learner.py. That would probably require a new parameter in the model
# class telling it whether it is training on trajectories or not;
# currently a command-line argument is used to decide this.

class CollatorForDiffwave:
    """Combine a list of sample dicts into one batch dict.

    Every record must provide an ``'audio'`` entry. These entries are
    stacked along a new leading axis with ``np.stack`` and converted with
    ``torch.from_numpy``. ``'spectrogram'`` is always ``None``, because no
    conditioning input is produced.
    """

    def __init__(self):
        pass

    def collate(self, minibatch):
        """Return ``{'audio': stacked_tensor, 'spectrogram': None}`` for *minibatch*."""
        stacked = np.stack([item['audio'] for item in minibatch])
        return dict(audio=torch.from_numpy(stacked), spectrogram=None)
    

class ToDiffwaveTensor(object):
    """Convert a NumPy sample into an ``'audio'``/``'spectrogram'`` dict.

    The array is wrapped with ``torch.from_numpy`` and cast to float32 via
    ``.float()``. ``'spectrogram'`` is always ``None``.
    """

    def __call__(self, sample):
        audio = torch.from_numpy(sample).float()
        return dict(audio=audio, spectrogram=None)