Skip to content
Snippets Groups Projects
Commit d872c448 authored by Maria Guaranda-Cabezas's avatar Maria Guaranda-Cabezas
Browse files

adds comments to code

parent 4b64fcec
No related branches found
No related tags found
1 merge request: !5 Inference
...@@ -75,6 +75,7 @@ class CollatorForDiffwave: ...@@ -75,6 +75,7 @@ class CollatorForDiffwave:
pass pass
def collate(self, minibatch): def collate(self, minibatch):
# shape is (batch_size, trajectory_length)
trajectories = np.stack([record ['audio']for record in minibatch]) trajectories = np.stack([record ['audio']for record in minibatch])
return { return {
'audio': torch.from_numpy(trajectories), 'audio': torch.from_numpy(trajectories),
......
...@@ -28,6 +28,17 @@ from params import AttrDict ...@@ -28,6 +28,17 @@ from params import AttrDict
def _nested_map(struct, map_fn): def _nested_map(struct, map_fn):
'''
This function recursively descends into a nested structure until it finds a
tensor, and then applies map_fn to it (for example, sending it to a device).
Example:
if struct is a dict like:
x = {"audio": Tensor(64,22000),
"spectrogram": Tensor(64,1024,128)}
and map_fn is a function that sends a tensor to a device, then the result is
x = {"audio": Tensor(64,22000).to(device),
"spectrogram": Tensor(64,1024,128).to(device)}
'''
if isinstance(struct, tuple): if isinstance(struct, tuple):
return tuple(_nested_map(x, map_fn) for x in struct) return tuple(_nested_map(x, map_fn) for x in struct)
if isinstance(struct, list): if isinstance(struct, list):
...@@ -101,6 +112,8 @@ class DiffWaveLearner: ...@@ -101,6 +112,8 @@ class DiffWaveLearner:
def train(self, max_steps=None): def train(self, max_steps=None):
device = next(self.model.parameters()).device device = next(self.model.parameters()).device
while True: while True:
# number of epochs = max_steps / num_batches
# e.g. for max_steps = 100000 and num_batches = 1000, we have 100 epochs
for features in tqdm(self.dataset, desc=f'Epoch {self.step // len(self.dataset)}') if self.is_master else self.dataset: for features in tqdm(self.dataset, desc=f'Epoch {self.step // len(self.dataset)}') if self.is_master else self.dataset:
if max_steps is not None and self.step >= max_steps: if max_steps is not None and self.step >= max_steps:
# Save final checkpoint. # Save final checkpoint.
......
...@@ -37,6 +37,9 @@ def silu(x): ...@@ -37,6 +37,9 @@ def silu(x):
class DiffusionEmbedding(nn.Module): class DiffusionEmbedding(nn.Module):
'''
Sinusoidal embedding for diffusion step.
'''
def __init__(self, max_steps): def __init__(self, max_steps):
super().__init__() super().__init__()
self.register_buffer('embedding', self._build_embedding(max_steps), persistent=False) self.register_buffer('embedding', self._build_embedding(max_steps), persistent=False)
...@@ -147,7 +150,8 @@ class DiffWave(nn.Module): ...@@ -147,7 +150,8 @@ class DiffWave(nn.Module):
def forward(self, audio, diffusion_step, spectrogram=None): def forward(self, audio, diffusion_step, spectrogram=None):
assert (spectrogram is None and self.spectrogram_upsampler is None) or \ assert (spectrogram is None and self.spectrogram_upsampler is None) or \
(spectrogram is not None and self.spectrogram_upsampler is not None) (spectrogram is not None and self.spectrogram_upsampler is not None)
x = audio.unsqueeze(1) # watch out for this, we can leave this to the dataloader actually # watch out for this, we can leave this to the dataloader actually
x = audio.unsqueeze(1) # shape is (batch_size, 1, trajectory_length)
x = self.input_projection(x) x = self.input_projection(x)
x = F.relu(x) x = F.relu(x)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment