Commit 82b61dfb authored by Maria Guaranda-Cabezas

loads input data correctly; still have to address normalization

parent 21ae7efa
Merge request: !1 Uses trajectories
@@ -4,13 +4,22 @@ from torch.utils.data import Dataset
 from glob import glob

+class StandardScaler(object):
+    """Standardize the data"""
+    def __init__(self, mean, std):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, sample):
+        return (sample - self.mean) / self.std
+
 class ParticleDataset(Dataset):
-    def __init__(self, path, transform=None, for_diffwave = False):
+    def __init__(self, path, transform=None):
-        self.npy_filepath = glob(f'{path}/**/*.npy', recursive=True)
+        self.npy_filepath = path
         self.transform = transform
-        self.data = torch.Tensor(np.load(self.npy_filepath, encoding="ASCII", allow_pickle=True, mmap_mode='r+'))
-        self.for_diffwave = for_diffwave
+        self.data = np.load(self.npy_filepath, encoding="ASCII", allow_pickle=True, mmap_mode='r+')

     def __len__(self):
         return self.data.shape[0]
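Note on the change above: dropping the eager torch.Tensor(np.load(...)) keeps the trajectories memory-mapped on disk, so a sample is only read when indexed; conversion to tensors now happens per sample in the transform/collator. A minimal usage sketch, assuming the .npy file holds an (N, T, 3) float array (the file name here is hypothetical):

    import numpy as np
    from data.lagrangian_datatools import ParticleDataset

    dataset = ParticleDataset(path='trajectories.npy')  # path is now a single .npy file, not a directory
    print(len(dataset))       # N, the number of trajectories
    print(dataset[0].shape)   # (T, 3): one trajectory, x/y/z coords per timestep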
@@ -23,24 +32,17 @@ class ParticleDataset(Dataset):
         if self.transform:
             part_traj = self.transform(part_traj)
-        # 1D convolutions in PyTorch take input of shape (batch, channels, length),
-        # so we need to permute the dimensions of every sample from (T, 3) to (3, T).
-        # N: number of trajectories, T: number of timesteps, 3: x, y, z coords
-        if self.for_diffwave:
-            return {
-                'audio': part_traj.permute(1, 0),
-                'spectrogram': None
-            }
-        return part_traj.permute(1, 0)
+        return part_traj

 # Maybe this class is not needed and we can keep a single class for both
 # cases, receiving the coordinate as an argument; for now I'll keep it like this.
 class ParticleDatasetVx(Dataset):
-    def __init__(self, path, transform=None, for_diffwave = False):
+    def __init__(self, path, transform):
-        self.npy_filepath = glob(f'{path}/**/*.npy', recursive=True)
+        self.npy_filepath = path
         self.transform = transform
-        self.data = torch.Tensor(np.load(self.npy_filepath, encoding="ASCII", allow_pickle=True, mmap_mode='r+')[:,:,0]).unsqueeze(1)
-        self.for_diffwave = for_diffwave
+        self.data = np.load(self.npy_filepath, encoding="ASCII", allow_pickle=True, mmap_mode='r+')[:,:,0]

     def __len__(self):
         return self.data.shape[0]
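A sketch of what the single-coordinate dataset now holds (shapes assumed): slicing [:, :, 0] keeps only the x-component of every trajectory, giving an (N, T) array, so each sample is a 1-D signal like the audio DiffWave was built for.

    import numpy as np

    data = np.random.randn(5, 2000, 3)   # N=5 trajectories, T=2000 timesteps, 3 coords
    vx = data[:, :, 0]
    print(vx.shape)                      # (5, 2000)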
@@ -53,14 +55,33 @@ class ParticleDatasetVx(Dataset):
         if self.transform:
             part_traj = self.transform(part_traj)
-        # 1D convolutions in PyTorch take input of shape (batch, channels, length),
-        # so we need to permute the dimensions of every sample from (T, 3) to (3, T).
-        # N: number of trajectories, T: number of timesteps, 3: x, y, z coords
-        if self.for_diffwave:
-            return {
-                'audio': part_traj.permute(1, 0),
-                'spectrogram': None
-            }
-        return part_traj.permute(1, 0)
\ No newline at end of file
+        return part_traj
+
+# We could drop this class if we modified the train() method in learner.py;
+# for that we may need a new parameter in the model class to know whether
+# we are training on trajectories or not. We currently use a command-line
+# argument for this.
+class CollatorForDiffwave:
+    def __init__(self):
+        pass
+
+    def collate(self, minibatch):
+        trajectories = np.stack([record['audio'] for record in minibatch])
+        return {
+            'audio': torch.from_numpy(trajectories),
+            'spectrogram': None,
+        }
+
+class ToDiffwaveTensor(object):
+    """Convert ndarrays in a sample to Tensors."""
+    def __call__(self, sample):
+        trajectory = sample
+        return {
+            'audio': torch.from_numpy(trajectory).float(),
+            'spectrogram': None
+        }
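How the new pieces compose, as a sketch with assumed sizes (T=2000): ToDiffwaveTensor wraps each trajectory in the {'audio', 'spectrogram'} dict the DiffWave learner expects, and CollatorForDiffwave stacks a minibatch of those dicts. Note the tensor-to-numpy-to-tensor round-trip inside collate: np.stack converts the per-sample tensors back to one ndarray before torch.from_numpy.

    import numpy as np
    from data.lagrangian_datatools import CollatorForDiffwave, ToDiffwaveTensor

    to_tensor = ToDiffwaveTensor()
    samples = [to_tensor(np.random.randn(2000).astype(np.float32)) for _ in range(4)]
    batch = CollatorForDiffwave().collate(samples)
    print(batch['audio'].shape)   # torch.Size([4, 2000])
    print(batch['spectrogram'])   # None: unconditional training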
@@ -21,7 +21,6 @@ from torch.multiprocessing import spawn
 import os
 import sys

 SOURCE_DIR = os.path.abspath(os.path.join(os.path.dirname( __file__ ), '../..', 'src'))
-print(SOURCE_DIR)
 sys.path.append(SOURCE_DIR)
 from learner import train, train_distributed
@@ -22,10 +22,12 @@ import random
 import torch
 import torch.nn.functional as F
 import torchaudio
+from torchvision import transforms
 from glob import glob
 from torch.utils.data.distributed import DistributedSampler

-from data.lagrangian_datatools import ParticleDataset, ParticleDatasetVx
+from data.lagrangian_datatools import *
+from data.lagrangian_datatools import CollatorForDiffwave as LagrangianCollator

 class ConditionalDataset(torch.utils.data.Dataset):
     def __init__(self, paths):
@@ -60,7 +62,6 @@ class UnconditionalDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         audio_filename = self.filenames[idx]
-        spec_filename = f'{audio_filename}.spec.npy'
         signal, _ = torchaudio.load(audio_filename)
         return {
             'audio': signal[0],
@@ -85,7 +86,7 @@ class Collator:
                 start = random.randint(0, record['audio'].shape[-1] - self.params.audio_len)
                 end = start + self.params.audio_len
                 record['audio'] = record['audio'][start:end]
-                record['audio'] = np.pad(record['audio'], (0, (end - start) - len(record['audio'])), mode='constant')
+                record['audio'] = np.pad(record['audio'], (0, (end - start) - len(record['audio'])), mode='constant')  # IDK why this is needed
             else:
                 # Filter out records that aren't long enough.
                 if len(record['spectrogram']) < self.params.crop_mel_frames:
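On the "IDK" note above: given the randint bounds two lines earlier, start + audio_len never exceeds the record length, so the slice already has exactly audio_len samples and the pad adds nothing; it appears to be purely defensive. A worked check with assumed values:

    import numpy as np
    import random

    audio_len = 8
    record = np.arange(10)
    start = random.randint(0, len(record) - audio_len)  # start in {0, 1, 2}
    end = start + audio_len
    crop = record[start:end]                            # always exactly audio_len samples
    crop = np.pad(crop, (0, (end - start) - len(crop)), mode='constant')  # pads 0 elements here
    assert crop.shape == (audio_len,)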
@@ -142,21 +143,26 @@ class Collator:

 def from_path(args, params, is_distributed=False):
     data_dirs = args.data_dirs
-    if params.unconditional:
+    if params.unconditional and 'trajectories' not in args.data_type:
         dataset = UnconditionalDataset(data_dirs)
     else:
         if args.data_type == 'trajectories':
-            dataset = ParticleDataset(path = data_dirs)
+            dataset = ParticleDataset(path = data_dirs[0])
         elif args.data_type == 'trajectories_x':
-            dataset = ParticleDatasetVx(path = data_dirs)
+            # the mean and standard deviation were previously computed
+            # in an exploration notebook
+            dataset = ParticleDatasetVx(path = data_dirs[0],
+                                        transform=transforms.Compose([
+                                            StandardScaler(mean=-0.0003, std=1.7358),
+                                            transforms.Lambda(lambda t: (t * 2) - 1),
+                                            ToDiffwaveTensor()]))
         else:  # with condition
             dataset = ConditionalDataset(data_dirs)
     return torch.utils.data.DataLoader(
         dataset,
         batch_size=params.batch_size,
-        collate_fn=None if args.data_type == 'trajectories' else Collator(params).collate,
+        collate_fn=LagrangianCollator().collate if 'trajectories' in args.data_type else Collator(params).collate,
         shuffle=not is_distributed,
-        num_workers=os.cpu_count(),
+        num_workers=2,  # os.cpu_count()
         sampler=DistributedSampler(dataset) if is_distributed else None,
         pin_memory=True,
         drop_last=True)
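A sketch of the composed input transform and its inverse. The mean and std are the values from this commit; the Lambda step maps [0, 1] onto [-1, 1], but standardized data is roughly zero-mean in [-3, 3], so here it mostly doubles the spread; this is presumably part of the normalization still "to be addressed" per the commit message. The inverse map a sampler would need is an assumption, not code from this repo:

    mean, std = -0.0003, 1.7358    # from the exploration notebook

    def normalize(v):
        z = (v - mean) / std       # StandardScaler
        return z * 2 - 1           # transforms.Lambda(lambda t: (t * 2) - 1)

    def denormalize(t):            # hypothetical inverse for generated samples
        return ((t + 1) / 2) * std + mean

    x = 0.5
    assert abs(denormalize(normalize(x)) - x) < 1e-12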
@@ -173,3 +179,5 @@ def from_gtzan(params, is_distributed=False):
         sampler=DistributedSampler(dataset) if is_distributed else None,
         pin_memory=True,
         drop_last=True)
@@ -26,6 +26,7 @@ from dataset import from_path, from_gtzan
 from model import DiffWave
 from params import AttrDict

+CHECKPOINTS_HOP = 10000  # save a checkpoint every 10k steps

 def _nested_map(struct, map_fn):
     if isinstance(struct, tuple):
@@ -111,7 +112,7 @@ class DiffWaveLearner:
             if self.is_master:
                 if self.step % 50 == 0:
                     self._write_summary(self.step, features, loss)
-                if self.step % len(self.dataset) == 0:
+                if self.step % CHECKPOINTS_HOP == 0:
                     self.save_to_checkpoint()
             self.step += 1
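The old condition tied checkpoints to len(self.dataset), the number of batches the loader yields per epoch, so the cadence shifted with dataset and batch size; the constant decouples the two. A toy check of the new cadence:

    CHECKPOINTS_HOP = 10000
    saved = [step for step in range(25001) if step % CHECKPOINTS_HOP == 0]
    print(saved)   # [0, 10000, 20000]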
@@ -122,7 +123,7 @@ class DiffWaveLearner:
         audio = features['audio']
         spectrogram = features['spectrogram']

-        N, T = audio.shape
+        N, _ = audio.shape
         device = audio.device
         self.noise_level = self.noise_level.to(device)
@@ -149,6 +150,7 @@ class DiffWaveLearner:
         if not self.params.unconditional:
             writer.add_image('feature/spectrogram', torch.flip(features['spectrogram'][:1], [1]), step)
         writer.add_scalar('train/loss', loss, step)
+        # the following line will print a warning if the audio amplitude is out of range
         writer.add_scalar('train/grad_norm', self.grad_norm, step)
         writer.flush()
         self.summary_writer = writer
@@ -147,7 +147,7 @@ class DiffWave(nn.Module):
     def forward(self, audio, diffusion_step, spectrogram=None):
         assert (spectrogram is None and self.spectrogram_upsampler is None) or \
                (spectrogram is not None and self.spectrogram_upsampler is not None)
-        x = audio.unsqueeze(1)
+        x = audio.unsqueeze(1)  # watch out for this; we could leave it to the dataloader instead
         x = self.input_projection(x)
         x = F.relu(x)
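The flagged line is the (batch, length) to (batch, channels, length) fix-up that nn.Conv1d needs: with the trajectory collator the model receives (N, T), so the channel axis has to be inserted here (or, as the comment suggests, by the dataloader). Shape sketch with assumed sizes:

    import torch

    audio = torch.randn(16, 2000)   # (N, T) from CollatorForDiffwave
    x = audio.unsqueeze(1)          # (16, 1, 2000): one input channel for Conv1d
    print(x.shape)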
@@ -31,7 +31,7 @@ class AttrDict(dict):
             raise NotImplementedError
         return self

-"""
 params = AttrDict(
     # Training params
     batch_size=16,
@@ -40,25 +40,26 @@ params = AttrDict(
     # Data params
     sample_rate=2000,
     # these are actually not used
     n_mels=80,
     n_fft=1024,
     hop_samples=256,
     crop_mel_frames=62,  # Probably an error in paper.

     # Model params
-    residual_layers=10,
-    residual_channels=64,
-    dilation_cycle_length=10, # with this config and residual_layers=10, we get r=2047
+    residual_layers=8,
+    residual_channels=32,
+    dilation_cycle_length=8, # with this config and residual_layers=8, we get r=511
     unconditional = True,
-    noise_schedule=np.linspace(1e-4, 0.01, 10).tolist(), # last param is num_timesteps
-    inference_noise_schedule=[0.0001, 0.001, 0.01, 0.05, 0.2, 0.5], # for fast sampling
+    noise_schedule=np.linspace(1e-4, 0.01, 200).tolist(), # last param is num_timesteps
+    inference_noise_schedule=[1e-4, 0.001, 0.01, 0.05, 0.2, 0.5], # for fast sampling
     audio_len = 2000 # length of generated samples
 )
-"""
 # This is the original params dictionary
-params = AttrDict(
+""" params = AttrDict(
@@ -75,10 +76,10 @@ params = AttrDict(
     residual_layers=30,
     residual_channels=64, # 64 or 128 for a larger model (DiffWave large)
     dilation_cycle_length=10,
-    unconditional = False,
+    unconditional = True,
     noise_schedule=np.linspace(1e-4, 0.05, 50).tolist(),
     inference_noise_schedule=[0.0001, 0.001, 0.01, 0.05, 0.2, 0.5],
     # unconditional sample len
     audio_len = 22050*5, # unconditional_synthesis_samples
-)
+) """