diff --git a/requirements.txt b/requirements.txt
index 1b78a58f564c273d6597c6c3d6a83b6d35716ec8..ae958e82bec53efd47bf3cd99cd600e7a388d08f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,5 @@ torch>=1.6
 torchaudio>=0.9.0
 tqdm>=4.65.0
 tensorboard>=2.13.0
-Pillow==9.2.0
\ No newline at end of file
+Pillow==9.2.0
+torchvision>=0.15.2
\ No newline at end of file
diff --git a/src/data/lagrangian_datatools.py b/src/data/lagrangian_datatools.py
index deda09744e4eed35b77b1ebf9c746d861fe1d80a..d60356566d9eeefdf57dd086bfaeee5696defc7f 100644
--- a/src/data/lagrangian_datatools.py
+++ b/src/data/lagrangian_datatools.py
@@ -14,6 +14,14 @@ class StandardScaler(object):
     def __call__(self, sample):
         return (sample - self.mean) / self.std
     
+class ScaleDiffusionRange(object):
+    """Linearly maps data from [0, 1] to [-1, 1]; assumes input is already scaled to [0, 1]."""
+    def __init__(self):
+        pass
+
+    def __call__(self, sample):
+        return (sample * 2) - 1
+
 class ParticleDataset(Dataset):
     def __init__(self, path, transform=None):
 
diff --git a/src/diffwave/dataset.py b/src/diffwave/dataset.py
index 31c7b148d62e19494b2e9c14fd03a556513840c9..5304b1370590612e1c029f25418e2153c42df7ed 100644
--- a/src/diffwave/dataset.py
+++ b/src/diffwave/dataset.py
@@ -153,7 +153,7 @@ def from_path(args, params, is_distributed=False):
       # in an exploration notebook
       dataset = ParticleDatasetVx(path = data_dirs[0], 
                                   transform=transforms.Compose([StandardScaler(mean=-0.0003, std=1.7358), 
-                                                                transforms.Lambda(lambda t: (t * 2) - 1),
+                                                                ScaleDiffusionRange(),
                                                                 ToDiffwaveTensor()]))
     else: #with condition
       dataset = ConditionalDataset(data_dirs)
@@ -162,7 +162,7 @@ def from_path(args, params, is_distributed=False):
       batch_size=params.batch_size,
       collate_fn= LagrangianCollator().collate if 'trajectories' in args.data_type else Collator(params).collate,
       shuffle=not is_distributed,
-      num_workers= 2, #os.cpu_count(),
+      num_workers= os.cpu_count(), # previously hard-coded to 2; use 2 when debugging on CPU
       sampler=DistributedSampler(dataset) if is_distributed else None,
       pin_memory=True,
       drop_last=True)