Skip to content
Snippets Groups Projects
Commit 776298f9 authored by Maria Guaranda-Cabezas's avatar Maria Guaranda-Cabezas
Browse files

adds parameter to control frequency of checkpoints

parent f4ec0965
No related branches found
No related tags found
1 merge request!1Uses trajectories
...@@ -26,7 +26,6 @@ from dataset import from_path, from_gtzan ...@@ -26,7 +26,6 @@ from dataset import from_path, from_gtzan
from model import DiffWave from model import DiffWave
from params import AttrDict from params import AttrDict
CHECKPOINS_HOP = 10000 # saves checkopints every 10k steps
def _nested_map(struct, map_fn): def _nested_map(struct, map_fn):
if isinstance(struct, tuple): if isinstance(struct, tuple):
...@@ -112,7 +111,7 @@ class DiffWaveLearner: ...@@ -112,7 +111,7 @@ class DiffWaveLearner:
if self.is_master: if self.is_master:
if self.step % 50 == 0: if self.step % 50 == 0:
self._write_summary(self.step, features, loss) self._write_summary(self.step, features, loss)
if self.step % CHECKPOINS_HOP == 0: if self.step % self.params.checkpoints_hop == 0:
self.save_to_checkpoint() self.save_to_checkpoint()
self.step += 1 self.step += 1
......
...@@ -53,7 +53,8 @@ params = AttrDict( ...@@ -53,7 +53,8 @@ params = AttrDict(
unconditional = True, unconditional = True,
noise_schedule=np.linspace(1e-4, 0.01, 200).tolist(), # last param is num_timesteps noise_schedule=np.linspace(1e-4, 0.01, 200).tolist(), # last param is num_timesteps
inference_noise_schedule=[1e-4, 0.001, 0.01, 0.05, 0.2, 0.5], # for fast sampling inference_noise_schedule=[1e-4, 0.001, 0.01, 0.05, 0.2, 0.5], # for fast sampling
audio_len = 2000 # length of generated samples audio_len = 2000, # length of generated samples
checkpoints_hop = 50000 # how often to save checkpoints
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment