Skip to content
Snippets Groups Projects
lagrangian_datatools.py 2.24 KiB
Newer Older
  • Learn to ignore specific revisions
  • import numpy as np
    import torch
    from torch.utils.data import Dataset
    from glob import glob
    
    
    class ParticleDataset(Dataset):
        """Particle trajectories read from every ``.npy`` file under a directory.

        Each file is indexed as ``(N, T, C)`` (N trajectories, T timesteps,
        C coordinate channels); files are concatenated along the first axis.
        A sample is returned channel-first, i.e. with shape ``(C, T)``.

        Args:
            path: Directory searched recursively for ``*.npy`` files.
            transform: Optional callable applied to the ``(T, C)`` sample
                before it is permuted.
            for_diffwave: When true, ``__getitem__`` returns a dict with the
                keys ``'audio'`` (the sample) and ``'spectrogram'`` (``None``)
                instead of the bare tensor.

        Raises:
            FileNotFoundError: If no ``.npy`` file is found under ``path``.
        """

        def __init__(self, path, transform=None, for_diffwave = False):
            # glob() returns a list, which np.load() cannot open directly, so
            # each match is loaded on its own. Sorting makes the sample order
            # independent of the filesystem's listing order.
            self.npy_filepath = sorted(glob(f'{path}/**/*.npy', recursive=True))
            if not self.npy_filepath:
                raise FileNotFoundError(f'no .npy files found under {path!r}')
            self.transform = transform
            self.for_diffwave = for_diffwave

            # NOTE(security): allow_pickle=True lets np.load unpickle object
            # arrays, which can execute arbitrary code. Only point this class
            # at trusted files.
            arrays = [
                np.load(filepath, encoding="ASCII", allow_pickle=True, mmap_mode='r+')
                for filepath in self.npy_filepath
            ]
            # A single file is handed over as-is; several files are joined
            # along the trajectory axis.
            stacked = arrays[0] if len(arrays) == 1 else np.concatenate(arrays, axis=0)
            self.data = torch.Tensor(stacked)

        def __len__(self):
            """Return the number of trajectories."""
            return self.data.shape[0]

        def __getitem__(self, idx):
            """Return trajectory ``idx`` as a ``(C, T)`` tensor (or diffwave dict)."""
            if torch.is_tensor(idx):
                idx = idx.tolist()

            part_traj = self.data[idx, :, :]
            if self.transform:
                part_traj = self.transform(part_traj)

            # 1D convolutions in PyTorch take (batch, channels, length), so the
            # sample is permuted from (T, 3) to (3, T).
            # T: number of timesteps, 3: x, y, z coords
            part_traj = part_traj.permute(1, 0)
            if self.for_diffwave:
                return {
                    'audio': part_traj,
                    'spectrogram': None
                }
            return part_traj
    
    class ParticleDatasetVx(Dataset):
        """First-component particle trajectories read from ``.npy`` files.

        Every ``.npy`` file under ``path`` is indexed as ``(N, T, C)``; only
        channel 0 is kept, files are concatenated along the first axis, and a
        singleton axis is inserted so that ``self.data`` is ``(N, 1, T)``.

        Args:
            path: Directory searched recursively for ``*.npy`` files.
            transform: Optional callable applied to the ``(1, T)`` sample
                before it is permuted.
            for_diffwave: When true, ``__getitem__`` returns a dict with the
                keys ``'audio'`` (the sample) and ``'spectrogram'`` (``None``)
                instead of the bare tensor.

        Raises:
            FileNotFoundError: If no ``.npy`` file is found under ``path``.
        """

        def __init__(self, path, transform=None, for_diffwave = False):
            # glob() returns a list, which np.load() cannot open directly, so
            # each match is loaded on its own. Sorting makes the sample order
            # independent of the filesystem's listing order.
            self.npy_filepath = sorted(glob(f'{path}/**/*.npy', recursive=True))
            if not self.npy_filepath:
                raise FileNotFoundError(f'no .npy files found under {path!r}')
            self.transform = transform
            self.for_diffwave = for_diffwave

            # NOTE(security): allow_pickle=True lets np.load unpickle object
            # arrays, which can execute arbitrary code. Only point this class
            # at trusted files.
            arrays = [
                np.load(filepath, encoding="ASCII", allow_pickle=True, mmap_mode='r+')[:, :, 0]
                for filepath in self.npy_filepath
            ]
            # A single file is handed over as-is; several files are joined
            # along the trajectory axis.
            first_component = arrays[0] if len(arrays) == 1 else np.concatenate(arrays, axis=0)
            self.data = torch.Tensor(first_component).unsqueeze(1)

        def __len__(self):
            """Return the number of trajectories."""
            return self.data.shape[0]

        def __getitem__(self, idx):
            """Return trajectory ``idx`` as a ``(T, 1)`` tensor (or diffwave dict)."""
            if torch.is_tensor(idx):
                idx = idx.tolist()

            part_traj = self.data[idx, :]
            if self.transform:
                part_traj = self.transform(part_traj)

            # self.data is (N, 1, T) after the unsqueeze in __init__, so the
            # indexed sample is (1, T) and the permute below returns (T, 1).
            # NOTE(review): the (channels, length) layout used for 1D
            # convolutions would be (1, T); confirm that downstream code
            # really expects (T, 1) before changing this.
            part_traj = part_traj.permute(1, 0)
            if self.for_diffwave:
                return {
                    'audio': part_traj,
                    'spectrogram': None
                }

            return part_traj