# lagrangian_datatools.py — PyTorch Dataset wrappers for particle trajectories stored as .npy files.
import os
from glob import glob

import numpy as np
import torch
from torch.utils.data import Dataset


class ParticleDataset(Dataset):
    """PyTorch dataset of particle trajectories loaded from ``.npy`` files.

    The loaded array is indexed as ``data[idx, :, :]`` and is expected to be
    laid out as ``(N, T, 3)``: N trajectories, T timesteps, x/y/z coordinates.
    Items are returned channel-first, ``(3, T)``, which is the
    ``(channels, length)`` layout used by PyTorch 1-D convolutions.

    Args:
        path: Either a single ``.npy`` file, or a directory searched
            recursively for ``.npy`` files. All matches are loaded in sorted
            order and concatenated along the first (trajectory) axis.
        transform: Optional callable applied to each item *before* the last
            two axes are swapped (i.e. it sees the ``(T, 3)`` layout for an
            integer index).
        for_diffwave: If true, ``__getitem__`` returns a DiffWave-style dict
            ``{'audio': traj, 'spectrogram': None}`` instead of the tensor.

    Raises:
        FileNotFoundError: If no ``.npy`` file exists at / under ``path``.
    """

    def __init__(self, path, transform=None, for_diffwave=False):
        # List of the .npy files that make up the dataset (sorted for a
        # deterministic trajectory order).
        self.npy_filepath = self._find_npy_files(path)
        self.transform = transform
        self.data = self._load_tensor(self.npy_filepath)
        self.for_diffwave = for_diffwave

    @staticmethod
    def _find_npy_files(path):
        """Return ``[path]`` for a file, else the sorted ``.npy`` files under it."""
        if os.path.isfile(path):
            return [path]
        files = sorted(glob(f'{path}/**/*.npy', recursive=True))
        if not files:
            raise FileNotFoundError(f'No .npy files found under {path!r}')
        return files

    @staticmethod
    def _load_tensor(filepaths):
        """Load every file and concatenate them into one float32 tensor."""
        # Read-only mapping: no write permission is required and the files can
        # never be modified through the map. Pickle loading stays disabled
        # (numpy's default): object arrays cannot be memory-mapped anyway, and
        # unpickling an untrusted file can execute arbitrary code.
        arrays = [np.asarray(np.load(fp, mmap_mode='r')) for fp in filepaths]
        merged = arrays[0] if len(arrays) == 1 else np.concatenate(arrays, axis=0)
        # np.array copies into a private, writable float32 buffer, so the
        # tensor never aliases the memory-mapped file.
        return torch.from_numpy(np.array(merged, dtype=np.float32))

    def __len__(self):
        """Number of trajectories (size of the first axis)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return trajectory ``idx`` as ``(3, T)``, or ``(B, 3, T)`` for a batch of indices."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        part_traj = self.data[idx, :, :]
        if self.transform is not None:
            part_traj = self.transform(part_traj)

        # 1D convolutions in PyTorch take (batch, channels, length), so swap
        # the last two axes: (T, 3) -> (3, T). Using negative axes keeps this
        # correct when a list/tensor of indices yields a (B, T, 3) batch.
        part_traj = part_traj.transpose(-2, -1)
        if self.for_diffwave:
            return {
                'audio': part_traj,
                'spectrogram': None
            }
        return part_traj

class ParticleDatasetVx(Dataset):
    """Single-channel variant of the trajectory dataset (component 0 only).

    Only index 0 of the last axis of the loaded array is kept (presumably the
    x component, judging by the class name — confirm against the data
    producer). The stored tensor has shape ``(N, 1, T)``, so each item is
    already channel-first, ``(1, T)``, which is the ``(channels, length)``
    layout used by PyTorch 1-D convolutions.

    Args:
        path: Either a single ``.npy`` file, or a directory searched
            recursively for ``.npy`` files. All matches are loaded in sorted
            order and concatenated along the first (trajectory) axis. Each
            array must be at least 3-D, since ``[:, :, 0]`` is taken from it.
        transform: Optional callable applied to each ``(1, T)`` item.
        for_diffwave: If true, ``__getitem__`` returns a DiffWave-style dict
            ``{'audio': traj, 'spectrogram': None}`` instead of the tensor.

    Raises:
        FileNotFoundError: If no ``.npy`` file exists at / under ``path``.
    """

    def __init__(self, path, transform=None, for_diffwave=False):
        # List of the .npy files that make up the dataset (sorted for a
        # deterministic trajectory order).
        self.npy_filepath = self._find_npy_files(path)
        self.transform = transform
        self.data = self._load_first_component(self.npy_filepath)  # (N, 1, T)
        self.for_diffwave = for_diffwave

    @staticmethod
    def _find_npy_files(path):
        """Return ``[path]`` for a file, else the sorted ``.npy`` files under it."""
        if os.path.isfile(path):
            return [path]
        files = sorted(glob(f'{path}/**/*.npy', recursive=True))
        if not files:
            raise FileNotFoundError(f'No .npy files found under {path!r}')
        return files

    @staticmethod
    def _load_first_component(filepaths):
        """Load ``[:, :, 0]`` of every file into one ``(N, 1, T)`` float32 tensor."""
        # Read-only mapping: no write permission is required and the files can
        # never be modified through the map. Pickle loading stays disabled
        # (numpy's default): object arrays cannot be memory-mapped anyway, and
        # unpickling an untrusted file can execute arbitrary code.
        arrays = [np.asarray(np.load(fp, mmap_mode='r'))[:, :, 0] for fp in filepaths]
        merged = arrays[0] if len(arrays) == 1 else np.concatenate(arrays, axis=0)
        # np.array copies into a private, writable float32 buffer, so the
        # tensor never aliases the memory-mapped file.
        return torch.from_numpy(np.array(merged, dtype=np.float32)).unsqueeze(1)

    def __len__(self):
        """Number of trajectories (size of the first axis)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return trajectory ``idx`` as ``(1, T)``, or ``(B, 1, T)`` for a batch of indices."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        part_traj = self.data[idx, :]
        if self.transform is not None:
            part_traj = self.transform(part_traj)

        # self.data is stored as (N, 1, T), so the item is already in the
        # (channels, length) layout PyTorch 1D convolutions expect. Permuting
        # it here would produce (T, 1), i.e. T channels of length 1.
        if self.for_diffwave:
            return {
                'audio': part_traj,
                'spectrogram': None
            }
        return part_traj