Skip to content
Snippets Groups Projects

Inference

Merged Maria Guaranda-Cabezas requested to merge inference into main
1 file
+ 3
1
Compare changes
  • Side-by-side
  • Inline
+ 4
2
@@ -103,6 +103,8 @@ class DiffWaveLearner:
while True:
for features in tqdm(self.dataset, desc=f'Epoch {self.step // len(self.dataset)}') if self.is_master else self.dataset:
if max_steps is not None and self.step >= max_steps:
# Save final checkpoint.
self.save_to_checkpoint()
return
features = _nested_map(features, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
loss = self.train_step(features)
@@ -145,11 +147,11 @@ class DiffWaveLearner:
def _write_summary(self, step, features, loss):
writer = self.summary_writer or SummaryWriter(self.model_dir, purge_step=step)
writer.add_audio('feature/audio', features['audio'][0], step, sample_rate=self.params.sample_rate)
# the following line will print a warning if the audio amplitude is out of range
# writer.add_audio('feature/audio', features['audio'][0], step, sample_rate=self.params.sample_rate)
if not self.params.unconditional:
writer.add_image('feature/spectrogram', torch.flip(features['spectrogram'][:1], [1]), step)
writer.add_scalar('train/loss', loss, step)
# the following line will print a warning if the audio amplitude is out of range
writer.add_scalar('train/grad_norm', self.grad_norm, step)
writer.flush()
self.summary_writer = writer
Loading