Skip to content
Snippets Groups Projects
Commit 2cc0b801 authored by Maria Guaranda-Cabezas's avatar Maria Guaranda-Cabezas
Browse files

Supports inference of 1D Lagrangian trajectories

parent 66e8fdbf
No related branches found
No related tags found
1 merge request: !3 Inference
......@@ -17,6 +17,7 @@ import numpy as np
import os
import torch
import torchaudio
from torchvision.transforms import Compose, Lambda, ToPILImage
from argparse import ArgumentParser
......@@ -90,8 +91,9 @@ def predict(spectrogram=None, model_dir=None, params=None, device=torch.device('
noise = torch.randn_like(audio)
sigma = ((1.0 - alpha_cum[n-1]) / (1.0 - alpha_cum[n]) * beta[n])**0.5
audio += sigma * noise
audio = torch.clamp(audio, -1.0, 1.0)
return audio, model.params.sample_rate
audio = torch.clamp(audio, -1.0, 1.0)
#audio = torch.clamp(audio, -1.0, 1.0) if(params.audio_len != 2000) else audio
return audio, model.params.sample_rate
def main(args):
......@@ -99,8 +101,31 @@ def main(args):
spectrogram = torch.from_numpy(np.load(args.spectrogram_path))
else:
spectrogram = None
audio, sr = predict(spectrogram, model_dir=args.model_dir, fast_sampling=args.fast, params=base_params, device=torch.device('cpu' if args.cpu else 'cuda'))
torchaudio.save(args.output, audio.cpu(), sample_rate=sr)
samples = []
for i in range(args.num_samples):
audio, sr = predict(spectrogram, model_dir=args.model_dir, fast_sampling=args.fast, params=base_params, device=torch.device('cpu' if args.cpu else 'cuda'))
if base_params.audio_len !=2000:
samples.append(audio.cpu())
else:
# This is a Lagrangian trajectory: we have to apply the inverse of
# the transformations used during preprocessing.
reverse_transform = Compose([
Lambda(lambda t: (t + 1) / 2),
Lambda(lambda t: t.numpy(force=True).astype(np.float64).transpose()),
])
trajectory = reverse_transform(audio)
samples.append(trajectory)
if base_params.audio_len !=2000:
for audio in samples:
torchaudio.save(args.output, audio, sample_rate=sr)
else:
# stack all the trajectories along a new leading axis
trajectories = np.stack(samples, axis=0)
print(trajectories.shape)
with open(args.output, 'wb') as f:
np.save(f, trajectories)
if __name__ == '__main__':
......@@ -115,4 +140,6 @@ if __name__ == '__main__':
help='fast sampling procedure')
parser.add_argument('--cpu', action='store_true',
help='use cpu instead of cuda')
parser.add_argument('--num_samples', default=1, type=int,
help='number of samples to generate')
main(parser.parse_args())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment