Skip to content
Snippets Groups Projects
Commit 0e0af639 authored by Maria Guaranda-Cabezas's avatar Maria Guaranda-Cabezas
Browse files

adapts inference to be run in cpu too

parent 57fbc5bf
No related branches found
No related tags found
No related merge requests found
...@@ -20,6 +20,11 @@ import torchaudio ...@@ -20,6 +20,11 @@ import torchaudio
from argparse import ArgumentParser from argparse import ArgumentParser
import sys
import os
module_path = os.path.abspath(os.path.join('./'))
if module_path not in sys.path:
sys.path.append(module_path+'/src')
from diffwave.params import AttrDict, params as base_params from diffwave.params import AttrDict, params as base_params
from diffwave.model import DiffWave from diffwave.model import DiffWave
...@@ -32,9 +37,9 @@ def predict(spectrogram=None, model_dir=None, params=None, device=torch.device(' ...@@ -32,9 +37,9 @@ def predict(spectrogram=None, model_dir=None, params=None, device=torch.device('
if os.path.exists(f'{model_dir}/weights.pt'): if os.path.exists(f'{model_dir}/weights.pt'):
checkpoint = torch.load(f'{model_dir}/weights.pt') checkpoint = torch.load(f'{model_dir}/weights.pt')
else: else:
checkpoint = torch.load(model_dir) checkpoint = torch.load(model_dir, map_location=device)
model = DiffWave(AttrDict(base_params)).to(device) model = DiffWave(AttrDict(base_params)).to(device)
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model']) # if the params settings do not match with the checkpoint, this will fail
model.eval() model.eval()
models[model_dir] = model models[model_dir] = model
...@@ -94,7 +99,7 @@ def main(args): ...@@ -94,7 +99,7 @@ def main(args):
spectrogram = torch.from_numpy(np.load(args.spectrogram_path)) spectrogram = torch.from_numpy(np.load(args.spectrogram_path))
else: else:
spectrogram = None spectrogram = None
audio, sr = predict(spectrogram, model_dir=args.model_dir, fast_sampling=args.fast, params=base_params) audio, sr = predict(spectrogram, model_dir=args.model_dir, fast_sampling=args.fast, params=base_params, device=torch.device('cpu' if args.cpu else 'cuda'))
torchaudio.save(args.output, audio.cpu(), sample_rate=sr) torchaudio.save(args.output, audio.cpu(), sample_rate=sr)
...@@ -108,4 +113,6 @@ if __name__ == '__main__': ...@@ -108,4 +113,6 @@ if __name__ == '__main__':
help='output file name') help='output file name')
parser.add_argument('--fast', '-f', action='store_true', parser.add_argument('--fast', '-f', action='store_true',
help='fast sampling procedure') help='fast sampling procedure')
parser.add_argument('--cpu', action='store_true',
help='use cpu instead of cuda')
main(parser.parse_args()) main(parser.parse_args())
...@@ -171,7 +171,7 @@ def train(args, params): ...@@ -171,7 +171,7 @@ def train(args, params):
if args.data_type == 'trajectories': if args.data_type == 'trajectories':
dataset = from_gtzan(params) dataset = from_gtzan(params)
dataset = from_path(args.data_dirs, params) dataset = from_path(args.data_dirs, params)
model = DiffWave(params) model = DiffWave(params).to(device='cuda' if torch.cuda.is_available() else 'cpu')
_train_impl(0, model, dataset, args, params) _train_impl(0, model, dataset, args, params)
......
...@@ -31,7 +31,7 @@ class AttrDict(dict): ...@@ -31,7 +31,7 @@ class AttrDict(dict):
raise NotImplementedError raise NotImplementedError
return self return self
"""
params = AttrDict( params = AttrDict(
# Training params # Training params
batch_size=16, batch_size=16,
...@@ -52,14 +52,12 @@ params = AttrDict( ...@@ -52,14 +52,12 @@ params = AttrDict(
unconditional = True, unconditional = True,
noise_schedule=np.linspace(1e-4, 0.01, 10).tolist(), # last param is num_timesteps noise_schedule=np.linspace(1e-4, 0.01, 10).tolist(), # last param is num_timesteps
inference_noise_schedule=[0.0001, 0.001, 0.01, 0.05, 0.2, 0.5], # for fast sampling inference_noise_schedule=[0.0001, 0.001, 0.01, 0.05, 0.2, 0.5], # for fast sampling
audio_len = 2000 # length of generated samples
# unconditional sample len
audio_len = 2000, # length of generated samples
) )
"""
# This is the original params dictionary # This is the original params dictionary
"""
params = AttrDict( params = AttrDict(
# Training params # Training params
batch_size=16, batch_size=16,
...@@ -84,4 +82,3 @@ params = AttrDict( ...@@ -84,4 +82,3 @@ params = AttrDict(
# unconditional sample len # unconditional sample len
audio_len = 22050*5, # unconditional_synthesis_samples audio_len = 22050*5, # unconditional_synthesis_samples
) )
"""
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment