Commit 7591e1b0 authored by Maria Guaranda-Cabezas's avatar Maria Guaranda-Cabezas

adapts code to train with trajectories in 1D

parent 82b61dfb
Merge request !1: Uses trajectories
# DiffWave
**Note**
This code is an adaptation of the original work by the LMNT team ([original repo](https://github.com/lmnt-com/diffwave)); the version used is 0.17.
**Part of the original README follows here:**
```
tensorboard --logdir /path/to/model/dir --bind_all
```
You should expect to hear intelligible (but noisy) speech by ~8k steps (~1.5h on a 2080 Ti).
#### Training with trajectories
You don't need to run any preprocessing.
```
python src/diffwave/__main__.py /path/to/model/dir /path/to/file/containing/trajectories --data_type trajectories_x  # or "trajectories" to use 3D data
```
#### Multi-GPU training
By default, this implementation uses as many GPUs in parallel as returned by [`torch.cuda.device_count()`](https://pytorch.org/docs/stable/cuda.html#torch.cuda.device_count). You can specify which GPUs to use by setting the [`CUDA_VISIBLE_DEVICES`](https://developer.nvidia.com/blog/cuda-pro-tip-control-gpu-visibility-cuda_visible_devices/) environment variable before running the training module.
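For example, a minimal sketch (the choice of GPU indices 0 and 1 is arbitrary):

```
export CUDA_VISIBLE_DEVICES=0,1
echo "Visible GPUs: $CUDA_VISIBLE_DEVICES"
```

Run the training command from the section above in the same shell so it only sees those two GPUs.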
```
params = AttrDict(
    max_grad_norm=None,
    # Data params
    # these are actually not used for trajectories
    sample_rate=2000,
    # these are actually not used
    n_mels=80,
    n_fft=1024,
    hop_samples=256,
```
```
    # Model params
    residual_layers=8,
    residual_channels=32,
    dilation_cycle_length=8,  # with this config and residual layers = 8, we get r=511*T
    unconditional=True,
    noise_schedule=np.linspace(1e-4, 0.01, 200).tolist(),  # last param is num_timesteps
    inference_noise_schedule=[1e-4, 0.001, 0.01, 0.05, 0.2, 0.5],  # for fast sampling
```
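The receptive field `r` mentioned in the comment can be checked with a short sketch, assuming (as in the original DiffWave implementation) that each residual layer uses a kernel size of 3 and dilation `2 ** (i % dilation_cycle_length)`:

```
# Assumption: kernel size 3, dilations 2**(i % cycle), as in DiffWave.
residual_layers = 8
dilation_cycle_length = 8
kernel_size = 3
dilations = [2 ** (i % dilation_cycle_length) for i in range(residual_layers)]
# Receptive field of a stack of dilated convolutions:
receptive_field = (kernel_size - 1) * sum(dilations) + 1
print(receptive_field)  # 511
```

With 8 layers and a cycle length of 8, the dilations are 1, 2, 4, ..., 128, which sum to 255, giving r = 2 * 255 + 1 = 511.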