├── .gitignore
├── LICENSE
├── README.md
├── benchmark.py
├── configs
│   └── default.json
├── data.py
├── filelists
│   ├── test.txt
│   └── train.txt
├── generated_samples
│   ├── 1000iters.wav
│   ├── 100iters.wav
│   ├── 12iters.wav
│   ├── 25iters.wav
│   ├── 50iters.wav
│   ├── 6iters.wav
│   ├── 7iters.wav
│   └── denoising.gif
├── inference.py
├── logger.py
├── model
│   ├── __init__.py
│   ├── base.py
│   ├── diffusion_process.py
│   ├── downsampling.py
│   ├── interpolation.py
│   ├── layers.py
│   ├── linear_modulation.py
│   ├── nn.py
│   └── upsampling.py
├── notebooks
│   └── inference.ipynb
├── requirements.txt
├── runs
│   ├── inference.sh
│   └── train.sh
├── schedules
│   └── pretrained
│       ├── 1000iters.pt
│       ├── 100iters.pt
│       ├── 12iters.pt
│       ├── 25iters.pt
│       ├── 50iters.pt
│       ├── 6iters.pt
│       └── 7iters.pt
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
**/*test.ipynb
**/*logs

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2020, Ivan Vovk
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
![alt-text-1](generated_samples/denoising.gif "denoising")

# WaveGrad
Implementation (PyTorch) of Google Brain's high-fidelity WaveGrad vocoder ([paper](https://arxiv.org/pdf/2009.00713.pdf)). The first implementation on GitHub with high-quality 6-iteration generation.

## **Status**

- [x] Documented API.
- [x] High-fidelity generation.
- [x] Multi-iteration inference support (**stable for low iterations**).
- [x] Stable and fast training with **mixed-precision** support.
- [x] **Distributed training** support.
- [x] Training also successfully runs on a single 12GB GPU with batch size 96.
- [x] CLI inference support.
- [x] Flexible architecture configuration for your own data.
- [x] Estimated RTF on popular GPU and CPU devices (see below).
- [x] 100- and lower-iteration inference is faster than real-time on an RTX 2080 Ti. **6-iteration inference is faster than the one reported in the paper**.
- [x] **Parallel grid search for the best noise schedule**.
- [x] Uploaded generated samples for different numbers of iterations (see the `generated_samples` folder).
- [x] [Pretrained checkpoint](https://drive.google.com/file/d/1X_AquK11C0j7U1lLxMDBdK5guabHaoXB/view?usp=sharing) on the 22kHz LJSpeech dataset **with noise schedules**.

#### Real-time factor (RTF)

**Number of parameters**: 15,810,401

| Model           | Stable | RTX 2080 Ti | Tesla K80 | Intel Xeon 2.3GHz* |
|-----------------|--------|-------------|-----------|--------------------|
| 1000 iterations | **+**  | 9.59        | -         | -                  |
| 100 iterations  | **+**  | 0.94        | 5.85      | -                  |
| 50 iterations   | **+**  | 0.45        | 2.92      | -                  |
| 25 iterations   | **+**  | 0.22        | 1.45      | -                  |
| 12 iterations   | **+**  | 0.10        | 0.69      | 4.55               |
| 6 iterations    | **+**  | 0.04        | 0.33      | 2.09               |

***Note**: measured on an older version of the Intel Xeon CPU.
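For reference, the numbers above are computed by `compute_rtf` from `benchmark.py`: RTF is the time spent generating divided by the duration of the generated audio, so values below 1.0 mean faster-than-real-time synthesis.

```python
def compute_rtf(sample, generation_time, sample_rate=22050):
    """RTF = seconds spent on generation / seconds of generated audio."""
    total_length = sample.shape[-1]  # number of output waveform samples
    return float(generation_time * sample_rate / total_length)

# E.g. spending 0.5 s to generate a 5 s clip gives RTF = 0.1.
```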
___
## About

WaveGrad is a conditional model for waveform generation that estimates gradients of the data density, with sampling quality comparable to WaveNet. **This vocoder is neither a GAN, nor a normalizing flow, nor a classical autoregressive model**. The vocoder is based on *Denoising Diffusion Probabilistic Models* (DDPM), which utilize the *Langevin dynamics* and *score matching* frameworks. Furthermore, compared to the classic DDPM, WaveGrad converges extremely fast with respect to the iterative Langevin-dynamics sampling scheme (6 iterations, and probably fewer).

___
## Installation

1. Clone this repo:

```bash
git clone https://github.com/ivanvovk/WaveGrad.git
cd WaveGrad
```

2. Install requirements:
```bash
pip install -r requirements.txt
```

___
## Training

#### 1 Preparing data

1. Make train and test filelists of your audio data like the ones included in the `filelists` folder.
2. Make a configuration file* in the `configs` folder.

***Note:** if you are going to change `hop_length` for the STFT, make sure that the product of your upsampling `factors` in the config equals your new `hop_length`. For example, the default `factors` `[5, 5, 3, 2, 2]` multiply to 300, the default hop length.

#### 2 Single and Distributed GPU training

1. Open the `runs/train.sh` script and specify visible GPU devices and the path to your configuration file. If you specify more than one GPU, training will run in distributed mode.
2. Run `sh runs/train.sh`

#### 3 Tensorboard and logging

To track your training process, run TensorBoard via `tensorboard --logdir=logs/YOUR_LOGDIR_FOLDER`. All logging information and checkpoints are stored in `logs/YOUR_LOGDIR_FOLDER`. `logdir` is specified in the config file.

#### 4 Noise schedule grid search

Once the model is trained, run a grid search for the best schedule* for the required number of iterations in [`notebooks/inference.ipynb`](notebooks/inference.ipynb). The code supports parallelism, so you can specify more than one job to accelerate the search (see the sketch below).

***Note**: the grid search is necessary only for a small number of iterations (like 6 or 7). For a larger number, just try the Fibonacci-sequence initialization `benchmark.fibonacci(...)`: I used it for 25 iterations and it works well. From a good 25-iteration schedule you can, for example, build a higher-order schedule by repeating its elements.
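Both search strategies are exposed by `benchmark.py`. A minimal sketch, assuming `model` is a trained `WaveGrad` instance on GPU and `config` is the parsed `ConfigWrapper` (the output path is a placeholder):

```python
from benchmark import iters_schedule_grid_search, fibonacci

# Parallel grid search for a short schedule (GPU strongly recommended)
best_betas, stats = iters_schedule_grid_search(
    model, config,
    n_iter=6,
    path_to_store_schedule='schedules/my6iters.pt',
    n_jobs=4
)

# For longer schedules a Fibonacci initialization is usually enough
betas_25 = fibonacci(b1=1e-6, b2=9e-6, n_iter=25)
```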
##### Noise schedules for pretrained model

* The 6-iteration schedule was obtained with the grid search. Afterwards, starting from the obtained scheme, I found a slightly better approximation by hand.
* The 7-iteration schedule was obtained in the same way.
* The 12-iteration schedule was obtained in the same way.
* The 25-iteration schedule was obtained with the Fibonacci-sequence initialization `benchmark.fibonacci(...)`.
* The 50-iteration schedule was obtained by repeating elements from the 25-iteration scheme.
* The 100-iteration schedule was obtained in the same way.
* The 1000-iteration schedule was obtained in the same way.

___
## Inference

#### CLI

Put your mel-spectrograms in some folder and make a filelist. Then run this command with your own arguments:

```bash
sh runs/inference.sh -c <config> -ch <checkpoint> -ns <noise-schedule> -m <mel-filelist> -v "yes"
```

#### Jupyter Notebook

More inference details are provided in [`notebooks/inference.ipynb`](notebooks/inference.ipynb). There you can also find how to set a noise schedule for the model and run the grid search for the best scheme. A minimal Python version of the CLI pipeline is sketched below.
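The CLI is a thin wrapper around the steps in `inference.py`; here is a sketch of the same pipeline (the checkpoint and mel-spectrogram paths are placeholders):

```python
import json
import torch
import torchaudio

from model import WaveGrad
from utils import ConfigWrapper

with open('configs/default.json') as f:
    config = ConfigWrapper(**json.load(f))

model = WaveGrad(config)
model.load_state_dict(torch.load('checkpoint.pt')['model'], strict=False)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# A schedule file is just a torch.Tensor of betas of shape [n_iter]
noise_schedule = torch.load('schedules/pretrained/6iters.pt')
model.set_new_noise_schedule(
    init=lambda **kwargs: noise_schedule,
    init_kwargs={'steps': noise_schedule.shape[-1]}
)

with torch.no_grad():
    mel = torch.load('mel.pt').unsqueeze(0).to(device)  # [1, n_mels, T // hop_length]
    audio = model.forward(mel, store_intermediate_states=False)
torchaudio.save('predicted.wav', audio.cpu(), sample_rate=config.data_config.sample_rate)
```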
___
## Other

#### Generated audios

Examples of generated audio are provided in the [`generated_samples`](generated_samples/) folder. The quality degradation between 1000-iteration and 6-iteration inference is not noticeable if the best schedule has been found for the latter.

#### Pretrained checkpoints

You can find a pretrained checkpoint file* on LJSpeech (22kHz) via [this](https://drive.google.com/file/d/1X_AquK11C0j7U1lLxMDBdK5guabHaoXB/view?usp=sharing) Google Drive link.

***Note**: the uploaded checkpoint is a `dict` with a single key `'model'`.

___
## Important details, issues and comments

* During training, WaveGrad uses a default noise schedule with 1000 iterations and linearly spaced betas in the range (1e-6, 0.01). For inference you can set another schedule with fewer iterations (see the sketch after this list). Tune betas carefully: the output quality strongly depends on them.
* By default the model runs with mixed precision. The batch size is modified compared to the paper (256 -> 96), since the authors trained their model on TPUs.
* After ~10k training iterations (1-2 hours) on a single GPU, the model already generates well with 50-iteration inference. Total training time is about 1-2 days (to full convergence).
* At some point training might become unstable (the loss explodes), so I have introduced learning rate (LR) scheduling and gradient clipping. If the loss explodes on your data, try decreasing the LR scheduler gamma a bit. It should help.
* By default the STFT hop length is 300 (and thus so is the total upsampling factor). Other settings are not tested, but you can try them. Remember that the total upsampling factor should still equal your new hop length.
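Both the training default and any custom schedule go through `set_new_noise_schedule`. A minimal sketch, assuming `model` is a loaded `WaveGrad` instance:

```python
import torch

# The default training-style schedule: 1000 iterations,
# linearly spaced betas in the range (1e-6, 0.01)
model.set_new_noise_schedule(
    init=torch.linspace,
    init_kwargs={'steps': 1000, 'start': 1e-6, 'end': 0.01}
)

# A shorter 50-iteration linear schedule for faster inference
model.set_new_noise_schedule(
    init=torch.linspace,
    init_kwargs={'steps': 50, 'start': 1e-6, 'end': 0.01}
)
```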
___
## History of updates

* (**NEW**: 10/24/2020) Huge update. Distributed training and mixed-precision support. More correct positional encoding. CLI support for inference. Parallel grid search. Model size significantly decreased.
* New RTF info for the NVIDIA Tesla K80 GPU (popular in the Google Colab service) and the Intel Xeon 2.3GHz CPU.
* Huge update. New well-generated 6-iteration sample. New noise-schedule-setting API. Added the grid search code for the best schedule.
* Improved training by introducing a smarter learning rate scheduler. Obtained high-fidelity synthesis.
* Stable training and multi-iteration inference. 6-iteration noise scheduling is supported.
* Stable training and fixed-iteration inference, with significant background static noise left. All positional encoding issues solved.
* Stable training of 25-, 50- and 1000-fixed-iteration models. Found the missing linear scaling (C=5000 from the paper) of the positional encoding (a bug).
* Stable training of 25-, 50- and 1000-fixed-iteration models. Fixed the positional encoding downscaling. Parallel segment sampling replaced by full-mel sampling.
* (**RELEASE, first on GitHub**). Parallel segment sampling and broken positional encoding downscaling. Bad quality with clicks from concatenation of parallel-segment generation.

___
## References

* Nanxin Chen et al., [WaveGrad: Estimating Gradients for Waveform Generation](https://arxiv.org/pdf/2009.00713.pdf)
* Jonathan Ho et al., [Denoising Diffusion Probabilistic Models](https://arxiv.org/pdf/2006.11239.pdf)
* [Denoising Diffusion Probabilistic Models repository](https://github.com/hojonathanho/diffusion) (TensorFlow implementation), from which the diffusion calculations have been adopted

--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
import os
import itertools
import numpy as np

import torch

from datetime import datetime
from tqdm import tqdm
from functools import partial
from multiprocessing.dummy import Pool as ThreadPool

from data import AudioDataset, MelSpectrogramFixed
from utils import show_message


def compute_rtf(sample, generation_time, sample_rate=22050):
    """
    Computes RTF for a given sample.
    """
    total_length = sample.shape[-1]
    return float(generation_time * sample_rate / total_length)


def estimate_average_rtf_on_filelist(filelist_path, config, model, verbose=True):
    """
    Runs RTF estimation on a filelist of audios and computes statistics.
    :param filelist_path (str): path to a filelist with the needed audios
    :param config (utils.ConfigWrapper): configuration dict
    :param model (torch.nn.Module): WaveGrad model
    :param verbose (bool, optional): verbosity level
    :return stats: statistics dict
    """
    device = next(model.parameters()).device
    config.training_config.test_filelist_path = filelist_path
    dataset = AudioDataset(config, training=False)
    mel_fn = MelSpectrogramFixed(
        sample_rate=config.data_config.sample_rate,
        n_fft=config.data_config.n_fft,
        win_length=config.data_config.win_length,
        hop_length=config.data_config.hop_length,
        f_min=config.data_config.f_min,
        f_max=config.data_config.f_max,
        n_mels=config.data_config.n_mels,
        window_fn=torch.hann_window
    ).to(device)
    rtfs = []
    for i in (tqdm(range(len(dataset))) if verbose else range(len(dataset))):
        datapoint = dataset[i].to(device)
        mel = mel_fn(datapoint)[None]
        start = datetime.now()
        sample = model.forward(mel, store_intermediate_states=False)
        end = datetime.now()
        generation_time = (end - start).total_seconds()
        rtf = compute_rtf(
            sample, generation_time, sample_rate=config.data_config.sample_rate
        )
        rtfs.append(rtf)
    average_rtf = np.mean(rtfs)
    std_rtf = np.std(rtfs)

    show_message(f'DEVICE: {device}. average_rtf={average_rtf}, std={std_rtf}', verbose=verbose)

    rtf_stats = {
        'rtfs': rtfs,
        'average': average_rtf,
        'std': std_rtf
    }
    return rtf_stats


def _betas_estimate(betas, model, mels, mel_fn):
    n_iter = len(betas)
    init_fn = lambda **kwargs: torch.FloatTensor(betas)
    model.set_new_noise_schedule(init=init_fn, init_kwargs={'steps': n_iter})

    outputs = model.forward(mels, store_intermediate_states=False)
    test_pred_mels = mel_fn(outputs)

    loss = torch.nn.L1Loss()(test_pred_mels, mels).item()
    return loss


def generate_betas_grid(n_iter, betas_range, verbose):
    betas_range = torch.FloatTensor(betas_range).log10()
    exp_step = (betas_range[1] - betas_range[0]) / (n_iter - 1)
    exponents = 10**torch.arange(betas_range[0], betas_range[1] + exp_step, step=exp_step)

    grid = []
    state = int(''.join(['1'] * n_iter))  # initial state
    final_state = 9**n_iter
    max_grid_size = 9**5
    step = int(np.ceil(final_state / (max_grid_size)))
    for _ in range(max_grid_size):
        multipliers = list(map(int, str(state)))
        if 0 in multipliers:
            state += step
            continue
        betas = [mult * exp for mult, exp in zip(multipliers, exponents)]
        grid.append(betas)
        state += step
    return grid


def iters_schedule_grid_search(model, config,
                               n_iter=6,
                               betas_range=(1e-6, 1e-2),
                               test_batch_size=2,
                               step=1,
                               path_to_store_schedule=None,
                               save_stats_for_grid=True,
                               verbose=True,
                               n_jobs=1):
    """
    Performs a grid search for the best noise schedule. Run it only on GPU and only for a small number of iterations!
    :param model (torch.nn.Module): WaveGrad model
    :param config (ConfigWrapper): model configuration
    :param n_iter (int, optional): number of iterations to search for
    :param test_batch_size (int, optional): number of one-second samples on which the grid sets are tested
    :param path_to_store_schedule (str, optional): path to store the best schedule. If not specified, the schedule will not be saved and will just be returned.
    :param save_stats_for_grid (bool, optional): flag to save stats for the whole grid or not
    :param verbose (bool, optional): print the search progress
    :param n_jobs (int, optional): number of parallel threads to use
    :return betas (list): the betas that give the lowest log10-mel-spectrogram absolute error
    :return stats (dict): dict of type {index: (betas, loss)} for the whole grid
    """
    device = next(model.parameters()).device
    if 'cpu' in str(device):
        show_message('WARNING: running grid search on CPU will be slow.')

    show_message('Initializing betas grid...', verbose=verbose)
    grid = generate_betas_grid(n_iter, betas_range, verbose=verbose)[::step]

    show_message('Initializing utils...', verbose=verbose)
    mel_fn = MelSpectrogramFixed(
        sample_rate=config.data_config.sample_rate,
        n_fft=config.data_config.n_fft,
        win_length=config.data_config.win_length,
        hop_length=config.data_config.hop_length,
        f_min=config.data_config.f_min,
        f_max=config.data_config.f_max,
        n_mels=config.data_config.n_mels,
        window_fn=torch.hann_window
    ).to(device)
    dataset = AudioDataset(config, training=True)
    idx = np.random.choice(range(len(dataset)), size=test_batch_size, replace=False)
    test_batch = torch.stack([dataset[i] for i in idx]).to(device)
    test_mels = mel_fn(test_batch)

    show_message('Starting search...', verbose=verbose)
    with ThreadPool(processes=n_jobs) as pool:
        process_fn = partial(_betas_estimate, model=model, mels=test_mels, mel_fn=mel_fn)
        stats = list(tqdm(pool.imap(process_fn, grid), total=len(grid)))
    stats = {i: (grid[i], stats[i]) for i in range(len(stats))}

    if save_stats_for_grid:
        tmp_stats_path = f'{os.path.dirname(path_to_store_schedule)}/{n_iter}stats.pt'
        show_message(f'Saving tmp stats for the whole grid to `{tmp_stats_path}`...', verbose=verbose)
        torch.save(stats, tmp_stats_path)

    best_idx = np.argmin([loss for _, loss in stats.values()])
    best_betas = grid[best_idx]

    if not isinstance(path_to_store_schedule, type(None)):
        show_message(f'Saving the best schedule to `{path_to_store_schedule}`...', verbose=verbose)
        torch.save(best_betas, path_to_store_schedule)

    return best_betas, stats


def fibonacci(b1=1e-6, b2=9e-6, n_iter=25):
    betas = [b1, b2]
    for _ in range(n_iter - 2):
        betas.append(sum(betas[-2:]))
    return betas[:n_iter]

--------------------------------------------------------------------------------
/configs/default.json:
--------------------------------------------------------------------------------
{
    "model_config": {
        "factors": [5, 5, 3, 2, 2],
        "upsampling_preconv_out_channels": 768,
        "upsampling_out_channels": [512, 512, 256, 128, 128],
        "upsampling_dilations": [
            [1, 2, 1, 2],
            [1, 2, 1, 2],
            [1, 2, 4, 8],
            [1, 2, 4, 8],
            [1, 2, 4, 8]
        ],
        "downsampling_preconv_out_channels": 32,
        "downsampling_out_channels": [128, 128, 256, 512],
        "downsampling_dilations": [
            [1, 2, 4], [1, 2, 4], [1, 2, 4], [1, 2, 4]
        ]
    },
    "data_config": {
        "sample_rate": 22050,
        "n_fft": 1024,
        "win_length": 1024,
        "hop_length": 300,
        "f_min": 80.0,
        "f_max": 8000,
        "n_mels": 80
    },
    "training_config": {
        "logdir": "logs/default",
        "continue_training": false,
        "train_filelist_path": "filelists/train.txt",
"filelists/train.txt", 32 | "test_filelist_path": "filelists/test.txt", 33 | "batch_size": 96, 34 | "segment_length": 7200, 35 | "lr": 1e-3, 36 | "grad_clip_threshold": 1, 37 | "scheduler_step_size": 1, 38 | "scheduler_gamma": 0.9, 39 | "n_epoch": 100000000, 40 | "n_samples_to_test": 4, 41 | "test_interval": 1, 42 | "use_fp16": true, 43 | 44 | "training_noise_schedule": { 45 | "n_iter": 1000, 46 | "betas_range": [1e-6, 0.01] 47 | }, 48 | "test_noise_schedule": { 49 | "n_iter": 50, 50 | "betas_range": [1e-6, 0.01] 51 | } 52 | }, 53 | "dist_config": { 54 | "MASTER_ADDR": "localhost", 55 | "MASTER_PORT": "600010" 56 | } 57 | } -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | np.random.seed(1234) 3 | 4 | import torch 5 | import torchaudio 6 | from torchaudio.transforms import MelSpectrogram 7 | 8 | from utils import parse_filelist 9 | 10 | 11 | class AudioDataset(torch.utils.data.Dataset): 12 | """ 13 | Provides dataset management for given filelist. 14 | """ 15 | def __init__(self, config, training=True): 16 | super(AudioDataset, self).__init__() 17 | self.config = config 18 | self.hop_length = config.data_config.hop_length 19 | self.training = training 20 | 21 | if self.training: 22 | self.segment_length = config.training_config.segment_length 23 | self.sample_rate = config.data_config.sample_rate 24 | 25 | self.filelist_path = config.training_config.train_filelist_path \ 26 | if self.training else config.training_config.test_filelist_path 27 | self.audio_paths = parse_filelist(self.filelist_path) 28 | 29 | def load_audio_to_torch(self, audio_path): 30 | audio, sample_rate = torchaudio.load(audio_path) 31 | # To ensure upsampling/downsampling will be processed in a right way for full signals 32 | if not self.training: 33 | p = (audio.shape[-1] // self.hop_length + 1) * self.hop_length - audio.shape[-1] 34 | audio = torch.nn.functional.pad(audio, (0, p), mode='constant').data 35 | return audio.squeeze(), sample_rate 36 | 37 | def __getitem__(self, index): 38 | audio_path = self.audio_paths[index] 39 | audio, sample_rate = self.load_audio_to_torch(audio_path) 40 | 41 | assert sample_rate == self.sample_rate, \ 42 | f"""Got path to audio of sampling rate {sample_rate}, \ 43 | but required {self.sample_rate} according config.""" 44 | 45 | if not self.training: # If test 46 | return audio 47 | # Take segment of audio for training 48 | if audio.shape[-1] > self.segment_length: 49 | max_audio_start = audio.shape[-1] - self.segment_length 50 | audio_start = np.random.randint(0, max_audio_start) 51 | segment = audio[audio_start:audio_start+self.segment_length] 52 | else: 53 | segment = torch.nn.functional.pad( 54 | audio, (0, self.segment_length - audio.shape[-1]), 'constant' 55 | ).data 56 | return segment 57 | 58 | def __len__(self): 59 | return len(self.audio_paths) 60 | 61 | def sample_test_batch(self, size): 62 | idx = np.random.choice(range(len(self)), size=size, replace=False) 63 | test_batch = [] 64 | for index in idx: 65 | test_batch.append(self.__getitem__(index)) 66 | return test_batch 67 | 68 | 69 | class MelSpectrogramFixed(torch.nn.Module): 70 | """In order to remove padding of torchaudio package + add log10 scale.""" 71 | def __init__(self, **kwargs): 72 | super(MelSpectrogramFixed, self).__init__() 73 | self.torchaudio_backend = MelSpectrogram(**kwargs) 74 | 75 | def forward(self, x): 76 | outputs = self.torchaudio_backend(x).log10() 77 
| mask = torch.isinf(outputs) 78 | outputs[mask] = 0 79 | return outputs[..., :-1] 80 | -------------------------------------------------------------------------------- /filelists/test.txt: -------------------------------------------------------------------------------- 1 | YOUR_PATH/LJ033-0112.wav 2 | YOUR_PATH/LJ047-0089.wav 3 | YOUR_PATH/LJ048-0040.wav 4 | YOUR_PATH/LJ010-0096.wav 5 | YOUR_PATH/LJ022-0092.wav 6 | YOUR_PATH/LJ027-0160.wav 7 | YOUR_PATH/LJ048-0023.wav 8 | YOUR_PATH/LJ002-0235.wav 9 | YOUR_PATH/LJ015-0055.wav 10 | YOUR_PATH/LJ006-0203.wav 11 | YOUR_PATH/LJ045-0056.wav 12 | YOUR_PATH/LJ018-0380.wav 13 | YOUR_PATH/LJ040-0023.wav 14 | YOUR_PATH/LJ003-0331.wav 15 | YOUR_PATH/LJ032-0051.wav 16 | YOUR_PATH/LJ002-0194.wav 17 | YOUR_PATH/LJ009-0273.wav 18 | YOUR_PATH/LJ006-0096.wav 19 | YOUR_PATH/LJ018-0353.wav 20 | YOUR_PATH/LJ030-0116.wav 21 | YOUR_PATH/LJ019-0124.wav 22 | YOUR_PATH/LJ011-0207.wav 23 | YOUR_PATH/LJ034-0094.wav 24 | YOUR_PATH/LJ005-0273.wav 25 | YOUR_PATH/LJ018-0179.wav 26 | YOUR_PATH/LJ004-0221.wav 27 | YOUR_PATH/LJ003-0127.wav 28 | YOUR_PATH/LJ019-0239.wav 29 | YOUR_PATH/LJ048-0261.wav 30 | YOUR_PATH/LJ037-0149.wav 31 | YOUR_PATH/LJ028-0217.wav 32 | YOUR_PATH/LJ029-0130.wav 33 | YOUR_PATH/LJ001-0178.wav 34 | YOUR_PATH/LJ010-0060.wav 35 | YOUR_PATH/LJ020-0028.wav 36 | YOUR_PATH/LJ008-0156.wav 37 | YOUR_PATH/LJ048-0076.wav 38 | YOUR_PATH/LJ037-0035.wav 39 | YOUR_PATH/LJ004-0107.wav 40 | YOUR_PATH/LJ012-0097.wav 41 | YOUR_PATH/LJ017-0234.wav 42 | YOUR_PATH/LJ047-0048.wav 43 | YOUR_PATH/LJ008-0210.wav 44 | YOUR_PATH/LJ049-0137.wav 45 | YOUR_PATH/LJ008-0065.wav 46 | YOUR_PATH/LJ036-0125.wav 47 | YOUR_PATH/LJ002-0094.wav 48 | YOUR_PATH/LJ041-0026.wav 49 | YOUR_PATH/LJ030-0082.wav 50 | YOUR_PATH/LJ036-0149.wav 51 | YOUR_PATH/LJ050-0277.wav 52 | YOUR_PATH/LJ042-0132.wav 53 | YOUR_PATH/LJ028-0381.wav 54 | YOUR_PATH/LJ027-0128.wav 55 | YOUR_PATH/LJ048-0054.wav 56 | YOUR_PATH/LJ010-0062.wav 57 | YOUR_PATH/LJ030-0071.wav 58 | YOUR_PATH/LJ032-0058.wav 59 | YOUR_PATH/LJ005-0129.wav 60 | YOUR_PATH/LJ049-0106.wav 61 | YOUR_PATH/LJ031-0042.wav 62 | YOUR_PATH/LJ007-0113.wav 63 | YOUR_PATH/LJ032-0167.wav 64 | YOUR_PATH/LJ028-0467.wav 65 | YOUR_PATH/LJ019-0329.wav 66 | YOUR_PATH/LJ010-0241.wav 67 | YOUR_PATH/LJ011-0146.wav 68 | YOUR_PATH/LJ030-0100.wav 69 | YOUR_PATH/LJ037-0113.wav 70 | YOUR_PATH/LJ045-0158.wav 71 | YOUR_PATH/LJ006-0184.wav 72 | YOUR_PATH/LJ046-0237.wav 73 | YOUR_PATH/LJ008-0209.wav 74 | YOUR_PATH/LJ050-0100.wav 75 | YOUR_PATH/LJ026-0115.wav 76 | YOUR_PATH/LJ030-0224.wav 77 | YOUR_PATH/LJ032-0137.wav 78 | YOUR_PATH/LJ049-0046.wav 79 | YOUR_PATH/LJ005-0057.wav 80 | YOUR_PATH/LJ030-0106.wav 81 | YOUR_PATH/LJ012-0120.wav 82 | YOUR_PATH/LJ027-0037.wav 83 | YOUR_PATH/LJ019-0182.wav 84 | YOUR_PATH/LJ017-0270.wav 85 | YOUR_PATH/LJ033-0122.wav 86 | YOUR_PATH/LJ025-0152.wav 87 | YOUR_PATH/LJ019-0254.wav 88 | YOUR_PATH/LJ046-0171.wav 89 | YOUR_PATH/LJ001-0167.wav 90 | YOUR_PATH/LJ024-0089.wav 91 | YOUR_PATH/LJ011-0208.wav 92 | YOUR_PATH/LJ016-0317.wav 93 | YOUR_PATH/LJ032-0043.wav 94 | YOUR_PATH/LJ042-0210.wav 95 | YOUR_PATH/LJ011-0280.wav 96 | YOUR_PATH/LJ035-0197.wav 97 | YOUR_PATH/LJ033-0030.wav 98 | YOUR_PATH/LJ015-0099.wav 99 | YOUR_PATH/LJ042-0106.wav 100 | YOUR_PATH/LJ046-0238.wav -------------------------------------------------------------------------------- /generated_samples/1000iters.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ivanvovk/WaveGrad/721c37c216132a2ef0a16adc38439f993998e0b7/generated_samples/1000iters.wav
--------------------------------------------------------------------------------
/generated_samples/100iters.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ivanvovk/WaveGrad/721c37c216132a2ef0a16adc38439f993998e0b7/generated_samples/100iters.wav
--------------------------------------------------------------------------------
/generated_samples/12iters.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ivanvovk/WaveGrad/721c37c216132a2ef0a16adc38439f993998e0b7/generated_samples/12iters.wav
--------------------------------------------------------------------------------
/generated_samples/25iters.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ivanvovk/WaveGrad/721c37c216132a2ef0a16adc38439f993998e0b7/generated_samples/25iters.wav
--------------------------------------------------------------------------------
/generated_samples/50iters.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ivanvovk/WaveGrad/721c37c216132a2ef0a16adc38439f993998e0b7/generated_samples/50iters.wav
--------------------------------------------------------------------------------
/generated_samples/6iters.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ivanvovk/WaveGrad/721c37c216132a2ef0a16adc38439f993998e0b7/generated_samples/6iters.wav
--------------------------------------------------------------------------------
/generated_samples/7iters.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ivanvovk/WaveGrad/721c37c216132a2ef0a16adc38439f993998e0b7/generated_samples/7iters.wav
--------------------------------------------------------------------------------
/generated_samples/denoising.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ivanvovk/WaveGrad/721c37c216132a2ef0a16adc38439f993998e0b7/generated_samples/denoising.gif
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import os
import json
import argparse
import numpy as np

import torch
import torchaudio

from tqdm import tqdm
from datetime import datetime

from model import WaveGrad
from benchmark import compute_rtf
from utils import ConfigWrapper, show_message, str2bool, parse_filelist


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-c', '--config', required=True,
        type=str, help='configuration file path'
    )
    parser.add_argument(
        '-ch', '--checkpoint_path',
        required=True, type=str, help='checkpoint path'
    )
    parser.add_argument(
        '-ns', '--noise_schedule_path', required=True, type=str,
        help='noise schedule, should be just a torch.Tensor array of shape [n_iter]'
    )
    parser.add_argument(
        '-m', '--mel_filelist', required=True, type=str,
        help='mel-spectrogram filelist, files of which should be just a torch.Tensor array of shape [n_mels, T]'
    )
parser.add_argument( 36 | '-v', '--verbose', required=False, type=str2bool, 37 | nargs='?', const=True, default=True, help='verbosity level' 38 | ) 39 | args = parser.parse_args() 40 | 41 | # Initialize config 42 | with open(args.config) as f: 43 | config = ConfigWrapper(**json.load(f)) 44 | 45 | # Initialize the model 46 | model = WaveGrad(config) 47 | model.load_state_dict(torch.load(args.checkpoint_path)['model'], strict=False) 48 | 49 | # Set noise schedule 50 | noise_schedule = torch.load(args.noise_schedule_path) 51 | n_iter = noise_schedule.shape[-1] 52 | init_fn = lambda **kwargs: noise_schedule 53 | init_kwargs = {'steps': n_iter} 54 | model.set_new_noise_schedule(init_fn, init_kwargs) 55 | 56 | # Trying to run inference on GPU 57 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 58 | model = model.to(device) 59 | 60 | # Inference 61 | filelist = parse_filelist(args.mel_filelist) 62 | rtfs = [] 63 | for mel_path in (tqdm(filelist, leave=False) if args.verbose else filelist): 64 | with torch.no_grad(): 65 | mel = torch.load(mel_path).unsqueeze(0).to(device) 66 | 67 | start = datetime.now() 68 | outputs = model.forward(mel, store_intermediate_states=False) 69 | end = datetime.now() 70 | 71 | outputs = outputs.cpu().squeeze() 72 | baseidx = os.path.basename(os.path.abspath(mel_path)).split('_')[-1].replace('.pt', '') 73 | save_path = f'{os.path.dirname(os.path.abspath(mel_path))}/predicted_{baseidx}.wav' 74 | torchaudio.save( 75 | save_path, outputs, sample_rate=config.data_config.sample_rate 76 | ) 77 | 78 | inference_time = (end - start).total_seconds() 79 | rtf = compute_rtf(outputs, inference_time, sample_rate=config.data_config.sample_rate) 80 | rtfs.append(rtf) 81 | 82 | show_message(f'Done. RTF estimate: {np.mean(rtfs)} ± {np.std(rtfs)}', verbose=args.verbose) 83 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | from torch.utils.tensorboard import SummaryWriter 6 | 7 | from utils import show_message, load_latest_checkpoint, plot_tensor_to_numpy 8 | 9 | 10 | class Logger(object): 11 | def __init__(self, config, rank=0): 12 | self.rank = rank 13 | self.summary_writer = None 14 | self.continue_training = config.training_config.continue_training 15 | self.logdir = config.training_config.logdir 16 | self.sample_rate = config.data_config.sample_rate 17 | 18 | if self.rank == 0: 19 | if not self.continue_training and os.path.exists(self.logdir): 20 | raise RuntimeError( 21 | f"You're trying to run training from scratch, " 22 | f"but logdir `{self.logdir} already exists. 
Remove it or specify new one.`" 23 | ) 24 | if not self.continue_training: 25 | os.makedirs(self.logdir) 26 | self.summary_writer = SummaryWriter(config.training_config.logdir) 27 | self.save_model_config(config) 28 | 29 | def _log_losses(self, iteration, loss_stats: dict): 30 | for key, value in loss_stats.items(): 31 | self.summary_writer.add_scalar(key, value, iteration) 32 | 33 | def log_training(self, iteration, stats, verbose=True): 34 | if self.rank != 0: return 35 | stats = {f'training/{key}': value for key, value in stats.items()} 36 | self._log_losses(iteration, loss_stats=stats) 37 | show_message( 38 | f'Iteration: {iteration} | Losses: {[value for value in stats.values()]}', 39 | verbose=verbose 40 | ) 41 | 42 | def log_test(self, iteration, stats, verbose=True): 43 | if self.rank != 0: return 44 | stats = {f'test/{key}': value for key, value in stats.items()} 45 | self._log_losses(iteration, loss_stats=stats) 46 | show_message( 47 | f'Iteration: {iteration} | Losses: {[value for value in stats.values()]}', 48 | verbose=verbose 49 | ) 50 | 51 | def log_audios(self, iteration, audios: dict): 52 | if self.rank != 0: return 53 | for key, audio in audios.items(): 54 | self.summary_writer.add_audio(key, audio, iteration, sample_rate=self.sample_rate) 55 | 56 | def log_specs(self, iteration, specs: dict): 57 | if self.rank != 0: return 58 | for key, image in specs.items(): 59 | self.summary_writer.add_image(key, plot_tensor_to_numpy(image), iteration, dataformats='HWC') 60 | 61 | def save_model_config(self, config): 62 | if self.rank != 0: return 63 | with open(f'{self.logdir}/config.json', 'w') as f: 64 | json.dump(config.to_dict_type(), f) 65 | 66 | def save_checkpoint(self, iteration, model, optimizer=None): 67 | if self.rank != 0: return 68 | d = {} 69 | d['iteration'] = iteration 70 | d['model'] = model.state_dict() 71 | if not isinstance(optimizer, type(None)): 72 | d['optimizer'] = optimizer.state_dict() 73 | filename = f'{self.summary_writer.log_dir}/checkpoint_{iteration}.pt' 74 | torch.save(d, filename) 75 | 76 | def load_latest_checkpoint(self, model, optimizer=None): 77 | if not self.continue_training: 78 | raise RuntimeError( 79 | f"Trying to load the latest checkpoint from logdir {self.logdir}, " 80 | "but did not set `continue_training=true` in configuration." 
81 | ) 82 | model, optimizer, iteration = load_latest_checkpoint(self.logdir, model, optimizer) 83 | return model, optimizer, iteration 84 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion_process import WaveGrad 2 | -------------------------------------------------------------------------------- /model/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModule(torch.nn.Module): 5 | def __init__(self): 6 | super(BaseModule, self).__init__() 7 | 8 | @property 9 | def nparams(self): 10 | return sum(p.numel() for p in self.parameters() if p.requires_grad) 11 | -------------------------------------------------------------------------------- /model/diffusion_process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | from model.base import BaseModule 6 | from model.nn import WaveGradNN 7 | 8 | 9 | class WaveGrad(BaseModule): 10 | """ 11 | WaveGrad diffusion process as described in WaveGrad paper 12 | (link: https://arxiv.org/pdf/2009.00713.pdf). 13 | Implementation adopted from `Denoising Diffusion Probabilistic Models` 14 | repository (link: https://github.com/hojonathanho/diffusion, 15 | paper: https://arxiv.org/pdf/2006.11239.pdf). 16 | """ 17 | def __init__(self, config): 18 | super(WaveGrad, self).__init__() 19 | # Setup noise schedule 20 | self.noise_schedule_is_set = False 21 | 22 | # Backbone neural network to model noise 23 | self.total_factor = np.product(config.model_config.factors) 24 | assert self.total_factor == config.data_config.hop_length, \ 25 | """Total factor-product should be equal to the hop length of STFT.""" 26 | self.nn = WaveGradNN(config) 27 | 28 | def set_new_noise_schedule( 29 | self, 30 | init=torch.linspace, 31 | init_kwargs={'steps': 50, 'start': 1e-6, 'end': 1e-2} 32 | ): 33 | """ 34 | Sets sampling noise schedule. Authors in the paper showed 35 | that WaveGrad supports variable noise schedules during inference. 36 | Thanks to the continuous noise level conditioning. 37 | :param init (callable function, optional): function which initializes betas 38 | :param init_kwargs (dict, optional): dict of arguments to be pushed to `init` function. 39 | Should always contain the key `steps` corresponding to the number of iterations to be done by the model. 40 | This is done so because `torch.linspace` has this argument named as `steps`. 41 | """ 42 | assert 'steps' in list(init_kwargs.keys()), \ 43 | '`init_kwargs` should always contain the key `steps` corresponding to the number of iterations to be done by the model.' 
44 | n_iter = init_kwargs['steps'] 45 | 46 | betas = init(**init_kwargs) 47 | alphas = 1 - betas 48 | alphas_cumprod = alphas.cumprod(dim=0) 49 | alphas_cumprod_prev = torch.cat([torch.FloatTensor([1]), alphas_cumprod[:-1]]) 50 | alphas_cumprod_prev_with_last = torch.cat([torch.FloatTensor([1]), alphas_cumprod]) 51 | self.register_buffer('betas', betas) 52 | self.register_buffer('alphas', alphas) 53 | self.register_buffer('alphas_cumprod', alphas_cumprod) 54 | self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) 55 | 56 | # Calculations for posterior q(y_n|y_0) 57 | sqrt_alphas_cumprod = alphas_cumprod.sqrt() 58 | # For WaveGrad special continuous noise level conditioning 59 | self.sqrt_alphas_cumprod_prev = alphas_cumprod_prev_with_last.sqrt().numpy() 60 | sqrt_recip_alphas_cumprod = (1 / alphas_cumprod).sqrt() 61 | sqrt_alphas_cumprod_m1 = (1 - alphas_cumprod).sqrt() * sqrt_recip_alphas_cumprod 62 | self.register_buffer('sqrt_alphas_cumprod', sqrt_alphas_cumprod) 63 | self.register_buffer('sqrt_recip_alphas_cumprod', sqrt_recip_alphas_cumprod) 64 | self.register_buffer('sqrt_alphas_cumprod_m1', sqrt_alphas_cumprod_m1) 65 | 66 | # Calculations for posterior q(y_{t-1} | y_t, y_0) 67 | posterior_variance = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod) 68 | posterior_variance = torch.stack([posterior_variance, torch.FloatTensor([1e-20] * n_iter)]) 69 | posterior_log_variance_clipped = posterior_variance.max(dim=0).values.log() 70 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 71 | posterior_mean_coef1 = betas * alphas_cumprod_prev.sqrt() / (1 - alphas_cumprod) 72 | posterior_mean_coef2 = (1 - alphas_cumprod_prev) * alphas.sqrt() / (1 - alphas_cumprod) 73 | self.register_buffer('posterior_log_variance_clipped', posterior_log_variance_clipped) 74 | self.register_buffer('posterior_mean_coef1', posterior_mean_coef1) 75 | self.register_buffer('posterior_mean_coef2', posterior_mean_coef2) 76 | 77 | self.n_iter = n_iter 78 | self.noise_schedule_kwargs = {'init': init, 'init_kwargs': init_kwargs} 79 | self.noise_schedule_is_set = True 80 | 81 | def sample_continuous_noise_level(self, batch_size, device): 82 | """ 83 | Samples continuous noise level sqrt(alpha_cumprod). 84 | This is what makes WaveGrad different from other Denoising Diffusion Probabilistic Models. 85 | """ 86 | s = np.random.choice(range(1, self.n_iter + 1), size=batch_size) 87 | continuous_sqrt_alpha_cumprod = torch.FloatTensor( 88 | np.random.uniform( 89 | self.sqrt_alphas_cumprod_prev[s-1], 90 | self.sqrt_alphas_cumprod_prev[s], 91 | size=batch_size 92 | ) 93 | ).to(device) 94 | return continuous_sqrt_alpha_cumprod.unsqueeze(-1) 95 | 96 | def q_sample(self, y_0, continuous_sqrt_alpha_cumprod=None, eps=None): 97 | """ 98 | Efficiently computes diffusion version y_t from y_0 using a closed form expression: 99 | y_t = sqrt(alpha_cumprod)_t * y_0 + sqrt(1 - alpha_cumprod_t) * eps, 100 | where eps is sampled from a standard Gaussian. 
101 | """ 102 | batch_size = y_0.shape[0] 103 | continuous_sqrt_alpha_cumprod \ 104 | = self.sample_continuous_noise_level(batch_size, device=y_0.device) \ 105 | if isinstance(eps, type(None)) else continuous_sqrt_alpha_cumprod 106 | if isinstance(eps, type(None)): 107 | eps = torch.randn_like(y_0) 108 | # Closed form signal diffusion 109 | outputs = continuous_sqrt_alpha_cumprod * y_0 + (1 - continuous_sqrt_alpha_cumprod**2).sqrt() * eps 110 | return outputs 111 | 112 | def q_posterior(self, y_start, y, t): 113 | """ 114 | Computes reverse (denoising) process posterior q(y_{t-1}|y_0, y_t, x) 115 | parameters: mean and variance. 116 | """ 117 | posterior_mean = self.posterior_mean_coef1[t] * y_start + self.posterior_mean_coef2[t] * y 118 | posterior_log_variance_clipped = self.posterior_log_variance_clipped[t] 119 | return posterior_mean, posterior_log_variance_clipped 120 | 121 | def predict_start_from_noise(self, y, t, eps): 122 | """ 123 | Computes y_0 from given y_t and reconstructed noise. 124 | Is needed to reconstruct the reverse (denoising) 125 | process posterior q(y_{t-1}|y_0, y_t, x). 126 | """ 127 | return self.sqrt_recip_alphas_cumprod[t] * y - self.sqrt_alphas_cumprod_m1[t] * eps 128 | 129 | def p_mean_variance(self, mels, y, t, clip_denoised: bool): 130 | """ 131 | Computes Gaussian transitions of Markov chain at step t 132 | for further computation of y_{t-1} given current state y_t and features. 133 | """ 134 | batch_size = mels.shape[0] 135 | noise_level = torch.FloatTensor([self.sqrt_alphas_cumprod_prev[t+1]]).repeat(batch_size, 1).to(mels) 136 | eps_recon = self.nn(mels, y, noise_level) 137 | y_recon = self.predict_start_from_noise(y, t, eps_recon) 138 | 139 | if clip_denoised: 140 | y_recon.clamp_(-1.0, 1.0) 141 | 142 | model_mean, posterior_log_variance = self.q_posterior(y_start=y_recon, y=y, t=t) 143 | return model_mean, posterior_log_variance 144 | 145 | def compute_inverse_dynamics(self, mels, y, t, clip_denoised=True): 146 | """ 147 | Computes reverse (denoising) process dynamics. Closely related to the idea of Langevin dynamics. 148 | :param mels (torch.Tensor): mel-spectrograms acoustic features of shape [B, n_mels, T//hop_length] 149 | :param y (torch.Tensor): previous state from dynamics trajectory 150 | :param clip_denoised (bool, optional): clip signal to [-1, 1] 151 | :return (torch.Tensor): next state 152 | """ 153 | model_mean, model_log_variance = self.p_mean_variance(mels, y, t, clip_denoised) 154 | eps = torch.randn_like(y) if t > 0 else torch.zeros_like(y) 155 | return model_mean + eps * (0.5 * model_log_variance).exp() 156 | 157 | def sample(self, mels, store_intermediate_states=False): 158 | """ 159 | Samples speech waveform via progressive denoising of white noise with guidance of mels-epctrogram. 
        :param mels (torch.Tensor): mel-spectrogram acoustic features of shape [B, n_mels, T//hop_length]
        :param store_intermediate_states (bool, optional): whether to store the dynamics trajectory or not
        :return ys (list of torch.Tensor) (if store_intermediate_states=True)
            or y_0 (torch.Tensor): predicted signals on every dynamics iteration, of shape [B, T]
        """
        with torch.no_grad():
            device = next(self.parameters()).device
            batch_size, T = mels.shape[0], mels.shape[-1]
            ys = [torch.randn(batch_size, T*self.total_factor, dtype=torch.float32).to(device)]
            t = self.n_iter - 1
            while t >= 0:
                y_t = self.compute_inverse_dynamics(mels, y=ys[-1], t=t)
                ys.append(y_t)
                t -= 1
            return ys if store_intermediate_states else ys[-1]

    def compute_loss(self, mels, y_0):
        """
        Computes the loss between the GT Gaussian noise and the noise reconstructed by the model from the diffusion process.
        :param mels (torch.Tensor): mel-spectrogram acoustic features of shape [B, n_mels, T//hop_length]
        :param y_0 (torch.Tensor): GT speech signals
        :return loss (torch.Tensor): loss of the diffusion model
        """
        self._verify_noise_schedule_existence()

        # Sample continuous noise level
        batch_size = y_0.shape[0]
        continuous_sqrt_alpha_cumprod \
            = self.sample_continuous_noise_level(batch_size, device=y_0.device)
        eps = torch.randn_like(y_0)

        # Diffuse the signal
        y_noisy = self.q_sample(y_0, continuous_sqrt_alpha_cumprod, eps)

        # Reconstruct the added noise
        eps_recon = self.nn(mels, y_noisy, continuous_sqrt_alpha_cumprod)
        loss = torch.nn.L1Loss()(eps_recon, eps)
        return loss

    def forward(self, mels, store_intermediate_states=False):
        """
        Generates speech from a given mel-spectrogram.
        :param mels (torch.Tensor): mel-spectrogram tensor of shape [1, n_mels, T//hop_length]
        :param store_intermediate_states (bool, optional):
            flag to return all the states of the denoising process
        """
        self._verify_noise_schedule_existence()

        return self.sample(
            mels, store_intermediate_states
        )

    def _verify_noise_schedule_existence(self):
        if not self.noise_schedule_is_set:
            raise RuntimeError(
                'No noise schedule is found. Specify your noise schedule '
                'by passing arguments into the `set_new_noise_schedule(...)` method. '
                'For example: '
                "`wavegrad.set_new_noise_schedule(init=torch.linspace, init_kwargs={'steps': 50, 'start': 1e-6, 'end': 1e-2})`."
219 | ) 220 | -------------------------------------------------------------------------------- /model/downsampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from model.base import BaseModule 4 | from model.interpolation import InterpolationBlock 5 | from model.layers import Conv1dWithInitialization 6 | 7 | 8 | class ConvolutionBlock(BaseModule): 9 | def __init__(self, in_channels, out_channels, dilation): 10 | super(ConvolutionBlock, self).__init__() 11 | self.leaky_relu = torch.nn.LeakyReLU(0.2) 12 | self.convolution = Conv1dWithInitialization( 13 | in_channels=in_channels, 14 | out_channels=out_channels, 15 | kernel_size=3, 16 | stride=1, 17 | padding=dilation, 18 | dilation=dilation 19 | ) 20 | 21 | def forward(self, x): 22 | outputs = self.leaky_relu(x) 23 | outputs = self.convolution(outputs) 24 | return outputs 25 | 26 | 27 | class DownsamplingBlock(BaseModule): 28 | def __init__(self, in_channels, out_channels, factor, dilations): 29 | super(DownsamplingBlock, self).__init__() 30 | in_sizes = [in_channels] + [out_channels for _ in range(len(dilations) - 1)] 31 | out_sizes = [out_channels for _ in range(len(in_sizes))] 32 | self.main_branch = torch.nn.Sequential(*([ 33 | InterpolationBlock( 34 | scale_factor=factor, 35 | mode='linear', 36 | align_corners=False, 37 | downsample=True 38 | ) 39 | ] + [ 40 | ConvolutionBlock(in_size, out_size, dilation) 41 | for in_size, out_size, dilation in zip(in_sizes, out_sizes, dilations) 42 | ])) 43 | self.residual_branch = torch.nn.Sequential(*[ 44 | Conv1dWithInitialization( 45 | in_channels=in_channels, 46 | out_channels=out_channels, 47 | kernel_size=1, 48 | stride=1 49 | ), 50 | InterpolationBlock( 51 | scale_factor=factor, 52 | mode='linear', 53 | align_corners=False, 54 | downsample=True 55 | ) 56 | ]) 57 | 58 | def forward(self, x): 59 | outputs = self.main_branch(x) 60 | outputs = outputs + self.residual_branch(x) 61 | return outputs 62 | -------------------------------------------------------------------------------- /model/interpolation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from model.base import BaseModule 4 | 5 | 6 | class InterpolationBlock(BaseModule): 7 | def __init__(self, scale_factor, mode='linear', align_corners=False, downsample=False): 8 | super(InterpolationBlock, self).__init__() 9 | self.downsample = downsample 10 | self.scale_factor = scale_factor 11 | self.mode = mode 12 | self.align_corners = align_corners 13 | 14 | def forward(self, x): 15 | outputs = torch.nn.functional.interpolate( 16 | x, 17 | size=x.shape[-1] * self.scale_factor \ 18 | if not self.downsample else x.shape[-1] // self.scale_factor, 19 | mode=self.mode, 20 | align_corners=self.align_corners, 21 | recompute_scale_factor=False 22 | ) 23 | return outputs 24 | -------------------------------------------------------------------------------- /model/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from model.base import BaseModule 4 | 5 | 6 | class Conv1dWithInitialization(BaseModule): 7 | def __init__(self, **kwargs): 8 | super(Conv1dWithInitialization, self).__init__() 9 | self.conv1d = torch.nn.Conv1d(**kwargs) 10 | torch.nn.init.orthogonal_(self.conv1d.weight.data, gain=1) 11 | 12 | def forward(self, x): 13 | return self.conv1d(x) 14 | -------------------------------------------------------------------------------- /model/linear_modulation.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from model.base import BaseModule 6 | from model.layers import Conv1dWithInitialization 7 | 8 | 9 | LINEAR_SCALE=5000 10 | 11 | 12 | class PositionalEncoding(BaseModule): 13 | def __init__(self, n_channels): 14 | super(PositionalEncoding, self).__init__() 15 | self.n_channels = n_channels 16 | 17 | def forward(self, noise_level): 18 | if len(noise_level.shape) > 1: 19 | noise_level = noise_level.squeeze(-1) 20 | half_dim = self.n_channels // 2 21 | exponents = torch.arange(half_dim, dtype=torch.float32).to(noise_level) / float(half_dim) 22 | exponents = 1e-4 ** exponents 23 | exponents = LINEAR_SCALE * noise_level.unsqueeze(1) * exponents.unsqueeze(0) 24 | return torch.cat([exponents.sin(), exponents.cos()], dim=-1) 25 | 26 | 27 | class FeatureWiseLinearModulation(BaseModule): 28 | def __init__(self, in_channels, out_channels, input_dscaled_by): 29 | super(FeatureWiseLinearModulation, self).__init__() 30 | self.signal_conv = torch.nn.Sequential(*[ 31 | Conv1dWithInitialization( 32 | in_channels=in_channels, 33 | out_channels=in_channels, 34 | kernel_size=3, 35 | stride=1, 36 | padding=1 37 | ), 38 | torch.nn.LeakyReLU(0.2) 39 | ]) 40 | self.positional_encoding = PositionalEncoding(in_channels) 41 | self.scale_conv = Conv1dWithInitialization( 42 | in_channels=in_channels, 43 | out_channels=out_channels, 44 | kernel_size=3, 45 | stride=1, 46 | padding=1 47 | ) 48 | self.shift_conv = Conv1dWithInitialization( 49 | in_channels=in_channels, 50 | out_channels=out_channels, 51 | kernel_size=3, 52 | stride=1, 53 | padding=1 54 | ) 55 | 56 | def forward(self, x, noise_level): 57 | outputs = self.signal_conv(x) 58 | outputs = outputs + self.positional_encoding(noise_level).unsqueeze(-1) 59 | scale, shift = self.scale_conv(outputs), self.shift_conv(outputs) 60 | return scale, shift 61 | 62 | 63 | class FeatureWiseAffine(BaseModule): 64 | def __init__(self): 65 | super(FeatureWiseAffine, self).__init__() 66 | 67 | def forward(self, x, scale, shift): 68 | outputs = scale * x + shift 69 | return outputs 70 | -------------------------------------------------------------------------------- /model/nn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | from model.base import BaseModule 6 | from model.layers import Conv1dWithInitialization 7 | from model.upsampling import UpsamplingBlock as UBlock 8 | from model.downsampling import DownsamplingBlock as DBlock 9 | from model.linear_modulation import FeatureWiseLinearModulation as FiLM 10 | 11 | 12 | class WaveGradNN(BaseModule): 13 | """ 14 | WaveGrad is a fully-convolutional mel-spectrogram conditional 15 | vocoder model for waveform generation introduced in 16 | "WaveGrad: Estimating Gradients for Waveform Generation" paper (link: https://arxiv.org/pdf/2009.00713.pdf). 17 | The concept is built on the prior work on score matching and diffusion probabilistic models. 18 | Current implementation follows described architecture in the paper. 
19 | """ 20 | def __init__(self, config): 21 | super(WaveGradNN, self).__init__() 22 | # Building upsampling branch (mels -> signal) 23 | self.ublock_preconv = Conv1dWithInitialization( 24 | in_channels=config.data_config.n_mels, 25 | out_channels=config.model_config.upsampling_preconv_out_channels, 26 | kernel_size=3, 27 | stride=1, 28 | padding=1 29 | ) 30 | upsampling_in_sizes = [config.model_config.upsampling_preconv_out_channels] \ 31 | + config.model_config.upsampling_out_channels[:-1] 32 | self.ublocks = torch.nn.ModuleList([ 33 | UBlock( 34 | in_channels=in_size, 35 | out_channels=out_size, 36 | factor=factor, 37 | dilations=dilations 38 | ) for in_size, out_size, factor, dilations in zip( 39 | upsampling_in_sizes, 40 | config.model_config.upsampling_out_channels, 41 | config.model_config.factors, 42 | config.model_config.upsampling_dilations 43 | ) 44 | ]) 45 | self.ublock_postconv = Conv1dWithInitialization( 46 | in_channels=config.model_config.upsampling_out_channels[-1], 47 | out_channels=1, 48 | kernel_size=3, 49 | stride=1, 50 | padding=1 51 | ) 52 | 53 | # Building downsampling branch (starting from signal) 54 | self.dblock_preconv = Conv1dWithInitialization( 55 | in_channels=1, 56 | out_channels=config.model_config.downsampling_preconv_out_channels, 57 | kernel_size=5, 58 | stride=1, 59 | padding=2 60 | ) 61 | downsampling_in_sizes = [config.model_config.downsampling_preconv_out_channels] \ 62 | + config.model_config.downsampling_out_channels[:-1] 63 | self.dblocks = torch.nn.ModuleList([ 64 | DBlock( 65 | in_channels=in_size, 66 | out_channels=out_size, 67 | factor=factor, 68 | dilations=dilations 69 | ) for in_size, out_size, factor, dilations in zip( 70 | downsampling_in_sizes, 71 | config.model_config.downsampling_out_channels, 72 | config.model_config.factors[1:][::-1], 73 | config.model_config.downsampling_dilations 74 | ) 75 | ]) 76 | 77 | # Building FiLM connections (in order of downscaling stream) 78 | film_in_sizes = [32] + config.model_config.downsampling_out_channels 79 | film_out_sizes = config.model_config.upsampling_out_channels[::-1] 80 | film_factors = [1] + config.model_config.factors[1:][::-1] 81 | self.films = torch.nn.ModuleList([ 82 | FiLM( 83 | in_channels=in_size, 84 | out_channels=out_size, 85 | input_dscaled_by=np.product(film_factors[:i+1]) # for proper positional encodings initialization 86 | ) for i, (in_size, out_size) in enumerate( 87 | zip(film_in_sizes, film_out_sizes) 88 | ) 89 | ]) 90 | 91 | def forward(self, mels, yn, noise_level): 92 | """ 93 | Computes forward pass of neural network. 
94 |         :param mels (torch.Tensor): mel-spectrogram acoustic features of shape [B, n_mels, T//hop_length]
95 |         :param yn (torch.Tensor): noised signal `y_n` of shape [B, T]
96 |         :param noise_level (torch.Tensor): continuous noise level of shape [B, 1] or [B]
97 |         :return (torch.Tensor): predicted epsilon noise of shape [B, T]
98 |         """
99 |         # Prepare inputs
100 |         assert len(mels.shape) == 3  # B, n_mels, T
101 |         yn = yn.unsqueeze(1)
102 |         assert len(yn.shape) == 3  # B, 1, T
103 | 
104 |         # Downsampling stream + Linear Modulation statistics calculation
105 |         statistics = []
106 |         dblock_outputs = self.dblock_preconv(yn)
107 |         scale, shift = self.films[0](x=dblock_outputs, noise_level=noise_level)
108 |         statistics.append([scale, shift])
109 |         for dblock, film in zip(self.dblocks, self.films[1:]):
110 |             dblock_outputs = dblock(dblock_outputs)
111 |             scale, shift = film(x=dblock_outputs, noise_level=noise_level)
112 |             statistics.append([scale, shift])
113 |         statistics = statistics[::-1]
114 | 
115 |         # Upsampling stream
116 |         ublock_outputs = self.ublock_preconv(mels)
117 |         for i, ublock in enumerate(self.ublocks):
118 |             scale, shift = statistics[i]
119 |             ublock_outputs = ublock(x=ublock_outputs, scale=scale, shift=shift)
120 |         outputs = self.ublock_postconv(ublock_outputs)
121 |         return outputs.squeeze(1)
122 | 
--------------------------------------------------------------------------------
/model/upsampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from model.base import BaseModule
4 | from model.linear_modulation import FeatureWiseAffine
5 | from model.interpolation import InterpolationBlock
6 | from model.layers import Conv1dWithInitialization
7 | 
8 | 
9 | class BasicModulationBlock(BaseModule):
10 |     """
11 |     Linear modulation part of UBlock, represented by a sequence of the following layers:
12 |         - Feature-wise Affine
13 |         - LReLU
14 |         - 3x1 Conv
15 |     """
16 |     def __init__(self, n_channels, dilation):
17 |         super(BasicModulationBlock, self).__init__()
18 |         self.featurewise_affine = FeatureWiseAffine()
19 |         self.leaky_relu = torch.nn.LeakyReLU(0.2)
20 |         self.convolution = Conv1dWithInitialization(
21 |             in_channels=n_channels,
22 |             out_channels=n_channels,
23 |             kernel_size=3,
24 |             stride=1,
25 |             padding=dilation,
26 |             dilation=dilation
27 |         )
28 | 
29 |     def forward(self, x, scale, shift):
30 |         outputs = self.featurewise_affine(x, scale, shift)
31 |         outputs = self.leaky_relu(outputs)
32 |         outputs = self.convolution(outputs)
33 |         return outputs
34 | 
35 | 
36 | class UpsamplingBlock(BaseModule):
37 |     def __init__(self, in_channels, out_channels, factor, dilations):
38 |         super(UpsamplingBlock, self).__init__()
39 |         self.first_block_main_branch = torch.nn.ModuleDict({
40 |             'upsampling': torch.nn.Sequential(*[
41 |                 torch.nn.LeakyReLU(0.2),
42 |                 InterpolationBlock(
43 |                     scale_factor=factor,
44 |                     mode='linear',
45 |                     align_corners=False
46 |                 ),
47 |                 Conv1dWithInitialization(
48 |                     in_channels=in_channels,
49 |                     out_channels=out_channels,
50 |                     kernel_size=3,
51 |                     stride=1,
52 |                     padding=dilations[0],
53 |                     dilation=dilations[0]
54 |                 )
55 |             ]),
56 |             'modulation': BasicModulationBlock(
57 |                 out_channels, dilation=dilations[1]
58 |             )
59 |         })
60 |         self.first_block_residual_branch = torch.nn.Sequential(*[
61 |             Conv1dWithInitialization(
62 |                 in_channels=in_channels,
63 |                 out_channels=out_channels,
64 |                 kernel_size=1,
65 |                 stride=1
66 |             ),
67 |             InterpolationBlock(
68 |                 scale_factor=factor,
69 |                 mode='linear',
70 |                 align_corners=False
71 |             )
72 |         ])
73 |         self.second_block_main_branch = torch.nn.ModuleDict({
74 |             f'modulation_{idx}': BasicModulationBlock(
75 |                 out_channels, dilation=dilations[2 + idx]
76 |             ) for idx in range(2)
77 |         })
78 | 
79 |     def forward(self, x, scale, shift):
80 |         # First upsampling residual block
81 |         outputs = self.first_block_main_branch['upsampling'](x)
82 |         outputs = self.first_block_main_branch['modulation'](outputs, scale, shift)
83 |         outputs = outputs + self.first_block_residual_branch(x)
84 | 
85 |         # Second residual block
86 |         residual = self.second_block_main_branch['modulation_0'](outputs, scale, shift)
87 |         outputs = outputs + self.second_block_main_branch['modulation_1'](residual, scale, shift)
88 |         return outputs
89 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.6.0
2 | torchaudio==0.6.0
3 | numpy==1.18.5
4 | matplotlib>=3.3.1
5 | tqdm
--------------------------------------------------------------------------------
/runs/inference.sh:
--------------------------------------------------------------------------------
1 | CONFIG_PATH=$1 #configs/default.json
2 | CHECKPOINT_PATH=$2 #logs/pretrained_ljspeech.pt
3 | NOISE_SCHEDULE_PATH=$3 #schedules/pretrained/6iters.pt
4 | MEL_FILELIST_PATH=$4 #tmp/mel_filelist.txt
5 | VERBOSE=$5 #'yes'
6 | 
7 | python inference.py -c $CONFIG_PATH -ch $CHECKPOINT_PATH -ns $NOISE_SCHEDULE_PATH -m $MEL_FILELIST_PATH -v $VERBOSE
8 | 
--------------------------------------------------------------------------------
/runs/train.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0,1
2 | 
3 | CONFIG_PATH=configs/default.json
4 | VERBOSE="yes"
5 | 
6 | python train.py -c $CONFIG_PATH -v $VERBOSE
7 | 
--------------------------------------------------------------------------------
/schedules/pretrained/1000iters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ivanvovk/WaveGrad/721c37c216132a2ef0a16adc38439f993998e0b7/schedules/pretrained/1000iters.pt
--------------------------------------------------------------------------------
/schedules/pretrained/100iters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ivanvovk/WaveGrad/721c37c216132a2ef0a16adc38439f993998e0b7/schedules/pretrained/100iters.pt
--------------------------------------------------------------------------------
/schedules/pretrained/12iters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ivanvovk/WaveGrad/721c37c216132a2ef0a16adc38439f993998e0b7/schedules/pretrained/12iters.pt
--------------------------------------------------------------------------------
/schedules/pretrained/25iters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ivanvovk/WaveGrad/721c37c216132a2ef0a16adc38439f993998e0b7/schedules/pretrained/25iters.pt
--------------------------------------------------------------------------------
/schedules/pretrained/50iters.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ivanvovk/WaveGrad/721c37c216132a2ef0a16adc38439f993998e0b7/schedules/pretrained/50iters.pt
--------------------------------------------------------------------------------
/schedules/pretrained/6iters.pt:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivanvovk/WaveGrad/721c37c216132a2ef0a16adc38439f993998e0b7/schedules/pretrained/6iters.pt -------------------------------------------------------------------------------- /schedules/pretrained/7iters.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ivanvovk/WaveGrad/721c37c216132a2ef0a16adc38439f993998e0b7/schedules/pretrained/7iters.pt -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import numpy as np 5 | 6 | import torch 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | from torch.utils.data.distributed import DistributedSampler 10 | from torch.utils.data import DataLoader 11 | 12 | from datetime import datetime 13 | from tqdm import tqdm 14 | 15 | from logger import Logger 16 | from model import WaveGrad 17 | from data import AudioDataset, MelSpectrogramFixed 18 | from benchmark import compute_rtf 19 | from utils import ConfigWrapper, show_message, str2bool 20 | 21 | 22 | def run_training(rank, config, args): 23 | if args.n_gpus > 1: 24 | init_distributed(rank, args.n_gpus, config.dist_config) 25 | torch.cuda.set_device(f'cuda:{rank}') 26 | 27 | show_message('Initializing logger...', verbose=args.verbose, rank=rank) 28 | logger = Logger(config, rank=rank) 29 | 30 | show_message('Initializing model...', verbose=args.verbose, rank=rank) 31 | model = WaveGrad(config).cuda() 32 | show_message(f'Number of WaveGrad parameters: {model.nparams}', verbose=args.verbose, rank=rank) 33 | mel_fn = MelSpectrogramFixed( 34 | sample_rate=config.data_config.sample_rate, 35 | n_fft=config.data_config.n_fft, 36 | win_length=config.data_config.win_length, 37 | hop_length=config.data_config.hop_length, 38 | f_min=config.data_config.f_min, 39 | f_max=config.data_config.f_max, 40 | n_mels=config.data_config.n_mels, 41 | window_fn=torch.hann_window 42 | ).cuda() 43 | 44 | show_message('Initializing optimizer, scheduler and losses...', verbose=args.verbose, rank=rank) 45 | optimizer = torch.optim.Adam(params=model.parameters(), lr=config.training_config.lr) 46 | scheduler = torch.optim.lr_scheduler.StepLR( 47 | optimizer, 48 | step_size=config.training_config.scheduler_step_size, 49 | gamma=config.training_config.scheduler_gamma 50 | ) 51 | if config.training_config.use_fp16: 52 | scaler = torch.cuda.amp.GradScaler() 53 | 54 | show_message('Initializing data loaders...', verbose=args.verbose, rank=rank) 55 | train_dataset = AudioDataset(config, training=True) 56 | train_sampler = DistributedSampler(train_dataset) if args.n_gpus > 1 else None 57 | train_dataloader = DataLoader( 58 | train_dataset, batch_size=config.training_config.batch_size, 59 | sampler=train_sampler, drop_last=True 60 | ) 61 | 62 | if rank == 0: 63 | test_dataset = AudioDataset(config, training=False) 64 | test_dataloader = DataLoader(test_dataset, batch_size=1) 65 | test_batch = test_dataset.sample_test_batch( 66 | config.training_config.n_samples_to_test 67 | ) 68 | 69 | if config.training_config.continue_training: 70 | show_message('Loading latest checkpoint to continue training...', verbose=args.verbose, rank=rank) 71 | model, optimizer, iteration = logger.load_latest_checkpoint(model, optimizer) 72 | epoch_size = len(train_dataset) // 
config.training_config.batch_size 73 | epoch_start = iteration // epoch_size 74 | else: 75 | iteration = 0 76 | epoch_start = 0 77 | 78 | # Log ground truth test batch 79 | if rank == 0: 80 | audios = { 81 | f'audio_{index}/gt': audio 82 | for index, audio in enumerate(test_batch) 83 | } 84 | logger.log_audios(0, audios) 85 | specs = { 86 | f'mel_{index}/gt': mel_fn(audio.cuda()).cpu().squeeze() 87 | for index, audio in enumerate(test_batch) 88 | } 89 | logger.log_specs(0, specs) 90 | 91 | if args.n_gpus > 1: 92 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) 93 | show_message(f'INITIALIZATION IS DONE ON RANK {rank}.') 94 | 95 | show_message('Start training...', verbose=args.verbose, rank=rank) 96 | try: 97 | for epoch in range(epoch_start, config.training_config.n_epoch): 98 | # Training step 99 | model.train() 100 | (model if args.n_gpus == 1 else model.module).set_new_noise_schedule( 101 | init=torch.linspace, 102 | init_kwargs={ 103 | 'steps': config.training_config.training_noise_schedule.n_iter, 104 | 'start': config.training_config.training_noise_schedule.betas_range[0], 105 | 'end': config.training_config.training_noise_schedule.betas_range[1] 106 | } 107 | ) 108 | for batch in ( 109 | tqdm(train_dataloader, leave=False) \ 110 | if args.verbose and rank == 0 else train_dataloader 111 | ): 112 | model.zero_grad() 113 | 114 | batch = batch.cuda() 115 | mels = mel_fn(batch) 116 | 117 | if config.training_config.use_fp16: 118 | with torch.cuda.amp.autocast(): 119 | loss = (model if args.n_gpus == 1 else model.module).compute_loss(mels, batch) 120 | scaler.scale(loss).backward() 121 | scaler.unscale_(optimizer) 122 | else: 123 | loss = (model if args.n_gpus == 1 else model.module).compute_loss(mels, batch) 124 | loss.backward() 125 | 126 | grad_norm = torch.nn.utils.clip_grad_norm_( 127 | parameters=model.parameters(), 128 | max_norm=config.training_config.grad_clip_threshold 129 | ) 130 | 131 | if config.training_config.use_fp16: 132 | scaler.step(optimizer) 133 | scaler.update() 134 | else: 135 | optimizer.step() 136 | 137 | loss_stats = { 138 | 'total_loss': loss.item(), 139 | 'grad_norm': grad_norm.item() 140 | } 141 | logger.log_training(iteration, loss_stats, verbose=False) 142 | 143 | iteration += 1 144 | 145 | # Test step after epoch on rank==0 GPU 146 | if epoch % config.training_config.test_interval == 0 and rank == 0: 147 | model.eval() 148 | (model if args.n_gpus == 1 else model.module).set_new_noise_schedule( 149 | init=torch.linspace, 150 | init_kwargs={ 151 | 'steps': config.training_config.test_noise_schedule.n_iter, 152 | 'start': config.training_config.test_noise_schedule.betas_range[0], 153 | 'end': config.training_config.test_noise_schedule.betas_range[1] 154 | } 155 | ) 156 | with torch.no_grad(): 157 | # Calculate test set loss 158 | test_loss = 0 159 | for i, batch in enumerate( 160 | tqdm(test_dataloader) \ 161 | if args.verbose and rank == 0 else test_dataloader 162 | ): 163 | batch = batch.cuda() 164 | mels = mel_fn(batch) 165 | test_loss_ = (model if args.n_gpus == 1 else model.module).compute_loss(mels, batch) 166 | test_loss += test_loss_ 167 | test_loss /= (i + 1) 168 | loss_stats = {'total_loss': test_loss.item()} 169 | 170 | # Restore random batch from test dataset 171 | audios = {} 172 | specs = {} 173 | test_l1_loss = 0 174 | test_l1_spec_loss = 0 175 | average_rtf = 0 176 | 177 | for index, test_sample in enumerate(test_batch): 178 | test_sample = test_sample[None].cuda() 179 | test_mel = mel_fn(test_sample.cuda()) 180 | 
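                    # Time one full denoising pass for this test sample below;
                    # compute_rtf (defined in benchmark.py, not shown here) is assumed
                    # to relate wall-clock generation time to the generated audio's duration.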
181 |                     start = datetime.now()
182 |                     y_0_hat = (model if args.n_gpus == 1 else model.module).forward(
183 |                         test_mel, store_intermediate_states=False
184 |                     )
185 |                     y_0_hat_mel = mel_fn(y_0_hat)
186 |                     end = datetime.now()
187 |                     generation_time = (end - start).total_seconds()
188 |                     average_rtf += compute_rtf(
189 |                         y_0_hat, generation_time, config.data_config.sample_rate
190 |                     )
191 | 
192 |                     test_l1_loss += torch.nn.L1Loss()(y_0_hat, test_sample).item()
193 |                     test_l1_spec_loss += torch.nn.L1Loss()(y_0_hat_mel, test_mel).item()
194 | 
195 |                     audios[f'audio_{index}/predicted'] = y_0_hat.cpu().squeeze()
196 |                     specs[f'mel_{index}/predicted'] = y_0_hat_mel.cpu().squeeze()
197 | 
198 |                 average_rtf /= len(test_batch)
199 |                 show_message(f'Device: GPU. average_rtf={average_rtf}', verbose=args.verbose)
200 | 
201 |                 test_l1_loss /= len(test_batch)
202 |                 loss_stats['l1_test_batch_loss'] = test_l1_loss
203 |                 test_l1_spec_loss /= len(test_batch)
204 |                 loss_stats['l1_spec_test_batch_loss'] = test_l1_spec_loss
205 | 
206 |                 logger.log_test(iteration, loss_stats, verbose=args.verbose)
207 |                 logger.log_audios(iteration, audios)
208 |                 logger.log_specs(iteration, specs)
209 | 
210 |                 logger.save_checkpoint(
211 |                     iteration,
212 |                     model if args.n_gpus == 1 else model.module,
213 |                     optimizer
214 |                 )
215 |             if epoch % (epoch//10 + 1) == 0:  # step the LR scheduler less frequently as training progresses
216 |                 scheduler.step()
217 |     except KeyboardInterrupt:
218 |         print('KeyboardInterrupt: training has been stopped.')
219 |         if args.n_gpus > 1: cleanup()  # tear down the process group only if one was created
220 |         return
221 | 
222 | 
223 | def run_distributed(fn, config, args):
224 |     try:
225 |         mp.spawn(fn, args=(config, args), nprocs=args.n_gpus, join=True)
226 |     except:
227 |         cleanup()
228 |         raise  # surface the spawn failure instead of swallowing it
229 | 
230 | def init_distributed(rank, n_gpus, dist_config):
231 |     assert torch.cuda.is_available(), "Distributed mode requires CUDA."
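    # One spawned process per GPU: rendezvous happens through the MASTER_ADDR/MASTER_PORT
    # environment variables set below, after which each rank joins the NCCL process group.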
232 | 
233 |     torch.cuda.set_device(rank % n_gpus)
234 | 
235 |     os.environ['MASTER_ADDR'] = dist_config.MASTER_ADDR
236 |     os.environ['MASTER_PORT'] = dist_config.MASTER_PORT
237 | 
238 |     torch.distributed.init_process_group(
239 |         backend='nccl', world_size=n_gpus, rank=rank
240 |     )
241 | 
242 | 
243 | def cleanup():
244 |     dist.destroy_process_group()
245 | 
246 | 
247 | if __name__ == '__main__':
248 |     torch.manual_seed(1234)
249 |     np.random.seed(1234)
250 | 
251 |     parser = argparse.ArgumentParser()
252 |     parser.add_argument('-c', '--config', required=True, type=str, help='configuration file')
253 |     parser.add_argument(
254 |         '-v', '--verbose', required=False, type=str2bool,
255 |         nargs='?', const=True, default=True, help='verbosity level'
256 |     )
257 |     args = parser.parse_args()
258 | 
259 |     with open(args.config) as f:
260 |         config = ConfigWrapper(**json.load(f))
261 | 
262 |     torch.backends.cudnn.enabled = True
263 |     torch.backends.cudnn.benchmark = True
264 | 
265 |     n_gpus = torch.cuda.device_count()
266 |     args.__setattr__('n_gpus', n_gpus)
267 | 
268 |     if args.n_gpus > 1:
269 |         run_distributed(run_training, config, args)
270 |     else:
271 |         run_training(0, config, args)
272 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import argparse
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | 
7 | import torch
8 | 
9 | 
10 | def show_message(text, verbose=True, end='\n', rank=0):
11 |     if verbose and (rank == 0): print(text, end=end)
12 | 
13 | 
14 | def str2bool(v):
15 |     if isinstance(v, bool): return v
16 |     if v.lower() in ('yes', 'true', 't', 'y', '1'): return True
17 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'): return False
18 |     else: raise argparse.ArgumentTypeError('Boolean value expected.')
19 | 
20 | 
21 | def parse_filelist(filelist_path):
22 |     with open(filelist_path, 'r') as f:
23 |         filelist = [line.strip() for line in f.readlines()]
24 |     return filelist
25 | 
26 | 
27 | def latest_checkpoint_path(dir_path, regex="checkpoint_*.pt"):
28 |     f_list = glob.glob(os.path.join(dir_path, regex))
29 |     f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
30 |     x = f_list[-1]
31 |     return x
32 | 
33 | 
34 | def load_latest_checkpoint(logdir, model, optimizer=None):
35 |     latest_model_path = latest_checkpoint_path(logdir, regex="checkpoint_*.pt")
36 |     print(f'Latest checkpoint: {latest_model_path}')
37 |     d = torch.load(
38 |         latest_model_path,
39 |         map_location=lambda storage, loc: storage  # deserialize onto CPU
40 |     )
41 |     iteration = d['iteration']
42 |     valid_incompatible_unexp_keys = [
43 |         'betas',
44 |         'alphas',
45 |         'alphas_cumprod',
46 |         'alphas_cumprod_prev',
47 |         'sqrt_alphas_cumprod',
48 |         'sqrt_recip_alphas_cumprod',
49 |         'sqrt_recipm1_alphas_cumprod',
50 |         'posterior_log_variance_clipped',
51 |         'posterior_mean_coef1',
52 |         'posterior_mean_coef2'
53 |     ]
54 |     d['model'] = {
55 |         key: value for key, value in d['model'].items() \
56 |         if key not in valid_incompatible_unexp_keys
57 |     }
58 |     model.load_state_dict(d['model'], strict=False)
59 |     if not isinstance(optimizer, type(None)):
60 |         optimizer.load_state_dict(d['optimizer'])
61 |     return model, optimizer, iteration
62 | 
63 | 
64 | def save_figure_to_numpy(fig):
65 |     # save it to a numpy array.
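    # Assumes the figure has already been rendered: plot_tensor_to_numpy calls
    # fig.canvas.draw() before invoking this helper, so the RGB buffer read
    # below contains valid pixel data.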
66 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # read the rendered RGB bytes into a uint8 array
67 |     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
68 |     return data
69 | 
70 | 
71 | def plot_tensor_to_numpy(tensor):
72 |     plt.style.use('default')
73 |     fig, ax = plt.subplots(figsize=(12, 3))
74 |     im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
75 |     plt.colorbar(im, ax=ax)
76 |     plt.tight_layout()
77 | 
78 |     fig.canvas.draw()
79 |     data = save_figure_to_numpy(fig)
80 |     plt.close()
81 |     return data
82 | 
83 | 
84 | class ConfigWrapper(object):
85 |     """
86 |     Wrapper around a plain dict that allows attribute-style access:
87 |     `config.sample_rate` instead of `config["sample_rate"]`.
88 |     """
89 |     def __init__(self, **kwargs):
90 |         for k, v in kwargs.items():
91 |             if isinstance(v, dict):
92 |                 v = ConfigWrapper(**v)
93 |             self[k] = v
94 | 
95 |     def keys(self):
96 |         return self.__dict__.keys()
97 | 
98 |     def items(self):
99 |         return self.__dict__.items()
100 | 
101 |     def values(self):
102 |         return self.__dict__.values()
103 | 
104 |     def to_dict_type(self):
105 |         return {
106 |             key: (value if not isinstance(value, ConfigWrapper) else value.to_dict_type())
107 |             for key, value in dict(**self).items()
108 |         }
109 | 
110 |     def __len__(self):
111 |         return len(self.__dict__)
112 | 
113 |     def __getitem__(self, key):
114 |         return getattr(self, key)
115 | 
116 |     def __setitem__(self, key, value):
117 |         return setattr(self, key, value)
118 | 
119 |     def __contains__(self, key):
120 |         return key in self.__dict__
121 | 
122 |     def __repr__(self):
123 |         return self.__dict__.__repr__()
124 | 
--------------------------------------------------------------------------------
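A minimal usage sketch of ConfigWrapper. The nested values below are illustrative; the key names mirror those accessed in train.py, which builds the wrapper as ConfigWrapper(**json.load(f)) from the JSON config:

    from utils import ConfigWrapper

    # Illustrative nested config; key names follow train.py.
    raw = {
        'data_config': {'sample_rate': 22050, 'n_mels': 80},
        'training_config': {'batch_size': 96, 'lr': 1e-4}
    }
    config = ConfigWrapper(**raw)

    assert config.data_config.sample_rate == 22050      # attribute-style access
    assert config['training_config'].batch_size == 96   # dict-style access still works
    assert config.to_dict_type() == raw                 # converts back to plain nested dicts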