diff --git a/requirements.txt b/requirements.txt index 1b78a58f564c273d6597c6c3d6a83b6d35716ec8..ae958e82bec53efd47bf3cd99cd600e7a388d08f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ torch>=1.6 torchaudio>=0.9.0 tqdm>=4.65.0 tensorboard>=2.13.0 -Pillow==9.2.0 \ No newline at end of file +Pillow==9.2.0 +torchvision>=0.15.2 \ No newline at end of file diff --git a/src/data/lagrangian_datatools.py b/src/data/lagrangian_datatools.py index deda09744e4eed35b77b1ebf9c746d861fe1d80a..d60356566d9eeefdf57dd086bfaeee5696defc7f 100644 --- a/src/data/lagrangian_datatools.py +++ b/src/data/lagrangian_datatools.py @@ -14,6 +14,14 @@ class StandardScaler(object): def __call__(self, sample): return (sample - self.mean) / self.std +class ScaleDiffusionRange(object): + """Scales data to be in range [-1,1]""" + def __init__(self): + pass + + def __call__(self, sample): + return (sample * 2) - 1 + class ParticleDataset(Dataset): def __init__(self, path, transform=None): diff --git a/src/diffwave/dataset.py b/src/diffwave/dataset.py index 31c7b148d62e19494b2e9c14fd03a556513840c9..5304b1370590612e1c029f25418e2153c42df7ed 100644 --- a/src/diffwave/dataset.py +++ b/src/diffwave/dataset.py @@ -153,7 +153,7 @@ def from_path(args, params, is_distributed=False): # in an exploration notebook dataset = ParticleDatasetVx(path = data_dirs[0], transform=transforms.Compose([StandardScaler(mean=-0.0003, std=1.7358), - transforms.Lambda(lambda t: (t * 2) - 1), + ScaleDiffusionRange(), ToDiffwaveTensor()])) else: #with condition dataset = ConditionalDataset(data_dirs) @@ -162,7 +162,7 @@ def from_path(args, params, is_distributed=False): batch_size=params.batch_size, collate_fn= LagrangianCollator().collate if 'trajectories' in args.data_type else Collator(params).collate, shuffle=not is_distributed, - num_workers= 2, #os.cpu_count(), + num_workers= os.cpu_count(), #2 for cpu sampler=DistributedSampler(dataset) if is_distributed else None, pin_memory=True, drop_last=True)