Skip to content
Snippets Groups Projects
Commit 46d5fedd authored by Maria Guaranda-Cabezas's avatar Maria Guaranda-Cabezas
Browse files

Merge branch 'uses_trajectories' into 'main'

Uses trajectories

See merge request !2
parents e8d715df eaba14a0
No related branches found
No related tags found
1 merge request!2Uses trajectories
......@@ -3,4 +3,5 @@ torch>=1.6
torchaudio>=0.9.0
tqdm>=4.65.0
tensorboard>=2.13.0
Pillow==9.2.0
\ No newline at end of file
Pillow==9.2.0
torchvision>=0.15.2
\ No newline at end of file
......@@ -14,6 +14,14 @@ class StandardScaler(object):
def __call__(self, sample):
return (sample - self.mean) / self.std
class ScaleDiffusionRange(object):
"""Scales data to be in range [-1,1]"""
def __init__(self):
pass
def __call__(self, sample):
return (sample * 2) - 1
class ParticleDataset(Dataset):
def __init__(self, path, transform=None):
......
......@@ -153,7 +153,7 @@ def from_path(args, params, is_distributed=False):
# in an exploration notebook
dataset = ParticleDatasetVx(path = data_dirs[0],
transform=transforms.Compose([StandardScaler(mean=-0.0003, std=1.7358),
transforms.Lambda(lambda t: (t * 2) - 1),
ScaleDiffusionRange(),
ToDiffwaveTensor()]))
else: #with condition
dataset = ConditionalDataset(data_dirs)
......@@ -162,7 +162,7 @@ def from_path(args, params, is_distributed=False):
batch_size=params.batch_size,
collate_fn= LagrangianCollator().collate if 'trajectories' in args.data_type else Collator(params).collate,
shuffle=not is_distributed,
num_workers= 2, #os.cpu_count(),
num_workers= os.cpu_count(), #2 for cpu
sampler=DistributedSampler(dataset) if is_distributed else None,
pin_memory=True,
drop_last=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment