"""PyTorch datasets that serve particle trajectories stored in ``.npy`` files."""

import os
from glob import glob

import numpy as np
import torch
from torch.utils.data import Dataset


def _find_npy_files(path):
    """Return the ``.npy`` files to load for *path*, in sorted order.

    *path* may be a directory (searched recursively) or a single file.

    Raises:
        FileNotFoundError: if *path* is a directory tree without ``.npy`` files.
    """
    if os.path.isfile(path):
        return [path]
    # glob() gives no ordering guarantee; sort so sample indices are reproducible.
    files = sorted(glob(f'{path}/**/*.npy', recursive=True))
    if not files:
        raise FileNotFoundError(f"no .npy files found under {path!r}")
    return files


def _load_trajectories(files, component=None):
    """Load *files* and concatenate them along axis 0 into a float32 tensor.

    Args:
        files: paths of the ``.npy`` arrays to load.
        component: when given, only ``array[:, :, component]`` of every file
            is kept, which drops the last axis.

    Raises:
        ValueError: if the arrays cannot be concatenated along axis 0
            (raised by ``np.concatenate``).
    """
    chunks = []
    for filename in files:
        # Read-only memory map: the data is copied by the concatenation below,
        # so write access to the files is never needed. Pickled object arrays
        # are refused rather than unpickled.
        array = np.load(filename, mmap_mode='r', allow_pickle=False)
        if component is not None:
            array = array[:, :, component]
        chunks.append(array)
    stacked = np.concatenate(chunks, axis=0)
    return torch.as_tensor(stacked, dtype=torch.float32)


class _TrajectoryDataset(Dataset):
    """Shared indexing logic; subclasses decide how ``self.data`` is built."""

    def __init__(self, path, transform=None, for_diffwave=False):
        """Collect the ``.npy`` files for *path* and load them into ``self.data``.

        Args:
            path: directory searched recursively for ``.npy`` files, or a
                single ``.npy`` file.
            transform: optional callable applied to every sample before the
                last two axes are swapped.
            for_diffwave: when true, ``__getitem__`` returns a dict with the
                keys ``'audio'`` and ``'spectrogram'`` instead of a tensor.

        Raises:
            FileNotFoundError: if no ``.npy`` file is found.
        """
        self.npy_filepath = _find_npy_files(path)
        self.transform = transform
        self.data = self._load(self.npy_filepath)
        self.for_diffwave = for_diffwave

    def _load(self, files):
        """Return the 3-D tensor whose first axis indexes the samples."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of samples (size of the first axis of ``self.data``)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return sample *idx* with its last two axes swapped.

        Returns:
            The swapped tensor, or ``{'audio': tensor, 'spectrogram': None}``
            when ``for_diffwave`` is set.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        part_traj = self.data[idx]

        if self.transform is not None:
            part_traj = self.transform(part_traj)

        # Equivalent to permute(1, 0) for a single 2-D sample, but also works
        # when idx is a list and a leading batch axis is present.
        part_traj = part_traj.transpose(-1, -2)

        if self.for_diffwave:
            return {
                'audio': part_traj,
                'spectrogram': None,
            }
        return part_traj


class ParticleDataset(_TrajectoryDataset):
    """Dataset of full trajectories.

    The stored arrays are indexed as ``(N, T, C)``; every sample is returned
    as ``(C, T)`` because 1D convolutions in PyTorch take input in the form
    (batch, channels, length).
    N: number of trajectories, T: number of timesteps, C: coordinates
    (presumably the 3 x, y, z components -- confirm against the data files).
    """

    def _load(self, files):
        """Load every file in full and stack them along the trajectory axis."""
        return _load_trajectories(files)


class ParticleDatasetVx(_TrajectoryDataset):
    """Dataset that keeps only component 0 of the last axis of every trajectory.

    ``self.data`` has shape ``(N, 1, T)``, so a sample is ``(1, T)`` before
    the axis swap and ``(T, 1)`` after it.

    NOTE(review): the returned ``(T, 1)`` layout is length-first, unlike the
    channel-first ``(C, T)`` samples of ``ParticleDataset``. The original
    behaviour is preserved here; confirm that consumers expect this layout.
    """

    def _load(self, files):
        """Load component 0 of every file and insert a singleton axis at dim 1."""
        return _load_trajectories(files, component=0).unsqueeze(1)