Skip to content
Snippets Groups Projects
Commit 6db8ebe8 authored by Maria Guaranda-Cabezas's avatar Maria Guaranda-Cabezas
Browse files

Merge branch 'inference' into 'main'

Inference

See merge request !3
parents 46d5fedd 45b2a05c
No related branches found
No related tags found
1 merge request: !3 "Inference"
......@@ -153,7 +153,6 @@ def from_path(args, params, is_distributed=False):
# in an exploration notebook
dataset = ParticleDatasetVx(path = data_dirs[0],
transform=transforms.Compose([StandardScaler(mean=-0.0003, std=1.7358),
ScaleDiffusionRange(),
ToDiffwaveTensor()]))
else: #with condition
dataset = ConditionalDataset(data_dirs)
......
......@@ -17,6 +17,7 @@ import numpy as np
import os
import torch
import torchaudio
from torchvision.transforms import Compose, Lambda, ToPILImage
from argparse import ArgumentParser
......@@ -90,8 +91,9 @@ def predict(spectrogram=None, model_dir=None, params=None, device=torch.device('
noise = torch.randn_like(audio)
sigma = ((1.0 - alpha_cum[n-1]) / (1.0 - alpha_cum[n]) * beta[n])**0.5
audio += sigma * noise
audio = torch.clamp(audio, -1.0, 1.0)
return audio, model.params.sample_rate
audio = torch.clamp(audio, -1.0, 1.0)
#audio = torch.clamp(audio, -1.0, 1.0) if(params.audio_len != 2000) else audio
return audio, model.params.sample_rate
def main(args):
......@@ -99,8 +101,30 @@ def main(args):
spectrogram = torch.from_numpy(np.load(args.spectrogram_path))
else:
spectrogram = None
audio, sr = predict(spectrogram, model_dir=args.model_dir, fast_sampling=args.fast, params=base_params, device=torch.device('cpu' if args.cpu else 'cuda'))
torchaudio.save(args.output, audio.cpu(), sample_rate=sr)
samples = []
for i in range(args.num_samples):
audio, sr = predict(spectrogram, model_dir=args.model_dir, fast_sampling=args.fast, params=base_params, device=torch.device('cpu' if args.cpu else 'cuda'))
if base_params.audio_len !=2000:
samples.append(audio.cpu())
else:
# this is a lagrangian trajectory, we have to apply the inverse of
# the transformations used when preprocessing
reverse_transform = Compose([
Lambda(lambda t: t.numpy(force=True).astype(np.float64).transpose()),
])
trajectory = reverse_transform(audio)
samples.append(trajectory)
if base_params.audio_len !=2000:
for audio in samples:
torchaudio.save(args.output, audio, sample_rate=sr)
else:
# vertically stack all the trajectories
trajectories = np.stack(samples, axis=0)
print(trajectories.shape)
with open(args.output, 'wb') as f:
np.save(f, trajectories)
if __name__ == '__main__':
......@@ -115,4 +139,6 @@ if __name__ == '__main__':
help='fast sampling procedure')
parser.add_argument('--cpu', action='store_true',
help='use cpu instead of cuda')
parser.add_argument('--num_samples', default=1, type=int,
help='number of samples to generate')
main(parser.parse_args())
......@@ -103,6 +103,8 @@ class DiffWaveLearner:
while True:
for features in tqdm(self.dataset, desc=f'Epoch {self.step // len(self.dataset)}') if self.is_master else self.dataset:
if max_steps is not None and self.step >= max_steps:
# Save final checkpoint.
self.save_to_checkpoint()
return
features = _nested_map(features, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
loss = self.train_step(features)
......@@ -145,11 +147,11 @@ class DiffWaveLearner:
def _write_summary(self, step, features, loss):
writer = self.summary_writer or SummaryWriter(self.model_dir, purge_step=step)
writer.add_audio('feature/audio', features['audio'][0], step, sample_rate=self.params.sample_rate)
# the following line will print a warning if the audio amplitude is out of range
# writer.add_audio('feature/audio', features['audio'][0], step, sample_rate=self.params.sample_rate)
if not self.params.unconditional:
writer.add_image('feature/spectrogram', torch.flip(features['spectrogram'][:1], [1]), step)
writer.add_scalar('train/loss', loss, step)
# the following line will print a warning if the audio amplitude is out of range
writer.add_scalar('train/grad_norm', self.grad_norm, step)
writer.flush()
self.summary_writer = writer
......
......@@ -34,8 +34,8 @@ class AttrDict(dict):
params = AttrDict(
# Training params
batch_size=16,
learning_rate=2e-4,
batch_size=64,
learning_rate=1e-5,
max_grad_norm=None,
# Data params
......
0% — Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment