Skip to content
Snippets Groups Projects
Commit 66e8fdbf authored by Maria Guaranda-Cabezas's avatar Maria Guaranda-Cabezas
Browse files

saves last checkpoint

parent 46d5fedd
No related branches found
No related tags found
1 merge request!3Inference
......@@ -103,6 +103,8 @@ class DiffWaveLearner:
while True:
for features in tqdm(self.dataset, desc=f'Epoch {self.step // len(self.dataset)}') if self.is_master else self.dataset:
if max_steps is not None and self.step >= max_steps:
# Save final checkpoint.
self.save_to_checkpoint()
return
features = _nested_map(features, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
loss = self.train_step(features)
......@@ -145,11 +147,11 @@ class DiffWaveLearner:
def _write_summary(self, step, features, loss):
writer = self.summary_writer or SummaryWriter(self.model_dir, purge_step=step)
# the following line will print a warning if the audio amplitude is out of range
writer.add_audio('feature/audio', features['audio'][0], step, sample_rate=self.params.sample_rate)
if not self.params.unconditional:
writer.add_image('feature/spectrogram', torch.flip(features['spectrogram'][:1], [1]), step)
writer.add_scalar('train/loss', loss, step)
# the following line will print a warning if the audio amplitude is out of range
writer.add_scalar('train/grad_norm', self.grad_norm, step)
writer.flush()
self.summary_writer = writer
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment