├── LICENSE ├── README.md ├── setup.py └── src └── wavegrad ├── __init__.py ├── __main__.py ├── dataset.py ├── inference.py ├── learner.py ├── model.py ├── noise_schedule.py ├── params.py └── preprocess.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 LMNT, Inc. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WaveGrad 2 | ![PyPI Release](https://img.shields.io/pypi/v/wavegrad?label=release) [![License](https://img.shields.io/github/license/lmnt-com/wavegrad)](https://github.com/lmnt-com/wavegrad/blob/master/LICENSE) 3 | 4 | **We're hiring!** 5 | If you like what we're building here, [come join us at LMNT](https://explore.lmnt.com). 6 | 7 | WaveGrad is a fast, high-quality neural vocoder designed by the folks at Google Brain. The architecture is described in [WaveGrad: Estimating Gradients for Waveform Generation](https://arxiv.org/pdf/2009.00713.pdf). In short, this model takes a log-scaled Mel spectrogram and converts it to a waveform via iterative refinement. 8 | 9 | ## Status (2020-10-15) 10 | - [x] stable training (22 kHz, 24 kHz) 11 | - [x] high-quality synthesis 12 | - [x] mixed-precision training 13 | - [x] multi-GPU training 14 | - [x] custom noise schedule (faster inference) 15 | - [x] command-line inference 16 | - [x] programmatic inference API 17 | - [x] PyPI package 18 | - [x] audio samples 19 | - [x] pretrained models 20 | - [ ] precomputed noise schedule 21 | 22 | ## Audio samples 23 | [24 kHz audio samples](https://lmnt.com/assets/wavegrad/24kHz) 24 | 25 | ## Pretrained models 26 | [24 kHz pretrained model](https://lmnt.com/assets/wavegrad/wavegrad-24kHz.pt) (183 MB, SHA256: `65e9366da318d58d60d2c78416559351ad16971de906e53b415836c068e335f3`) 27 | 28 | ## Install 29 | 30 | Install using pip: 31 | ``` 32 | pip install wavegrad 33 | ``` 34 | 35 | or from GitHub: 36 | ``` 37 | git clone https://github.com/lmnt-com/wavegrad.git 38 | cd wavegrad 39 | pip install . 40 | ``` 41 | 42 | ### Training 43 | Before you start training, you'll need to prepare a training dataset. The dataset can have any directory structure as long as the contained .wav files are 16-bit mono (e.g. [LJSpeech](https://keithito.com/LJ-Speech-Dataset/), [VCTK](https://pytorch.org/audio/_modules/torchaudio/datasets/vctk.html)). By default, this implementation assumes a sample rate of 22 kHz. If you need to change this value, edit [params.py](https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/params.py). 
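44 | 
45 | If your corpus is at a different rate and you'd rather conform the audio than edit the config, a minimal resampling sketch follows. It assumes torchaudio>=0.9 (as pinned in setup.py); the helper name and file paths are placeholders, not part of this repo:
46 | 
47 | ```python
48 | import torchaudio
49 | import torchaudio.transforms as TT
50 | 
51 | # Hypothetical helper: downmix to mono, resample to the rate expected by
52 | # params.py, and write 16-bit PCM, which is what the trainer expects.
53 | def resample_for_wavegrad(in_path, out_path, target_sr=22050):
54 |   audio, sr = torchaudio.load(in_path)
55 |   audio = audio.mean(dim=0, keepdim=True)  # downmix to mono
56 |   if sr != target_sr:
57 |     audio = TT.Resample(orig_freq=sr, new_freq=target_sr)(audio)
58 |   torchaudio.save(out_path, audio, target_sr, encoding='PCM_S', bits_per_sample=16)
59 | ```
60 | 
61 | Once the dataset is ready, preprocess it and kick off training: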
62 | ```
63 | python -m wavegrad.preprocess /path/to/dir/containing/wavs
64 | python -m wavegrad /path/to/model/dir /path/to/dir/containing/wavs
65 | 
66 | # in another shell to monitor training progress:
67 | tensorboard --logdir /path/to/model/dir --bind_all
68 | ```
69 | 
70 | You should expect to hear intelligible speech by ~20k steps (~1.5h on a 2080 Ti).
71 | 
72 | ### Inference API
73 | Basic usage:
74 | 
75 | ```python
76 | from wavegrad.inference import predict as wavegrad_predict
77 | 
78 | model_dir = '/path/to/model/dir'
79 | spectrogram = ...  # get your hands on a spectrogram in [N,C,W] format
80 | audio, sample_rate = wavegrad_predict(spectrogram, model_dir)
81 | 
82 | # `audio` is a GPU tensor in [N,T] format.
83 | ```
84 | 
85 | If you have a custom noise schedule (see below):
86 | ```python
87 | import numpy as np
88 | 
89 | from wavegrad.inference import predict as wavegrad_predict
90 | 
91 | params = { 'noise_schedule': np.load('/path/to/noise_schedule.npy') }
92 | model_dir = '/path/to/model/dir'
93 | spectrogram = ...  # get your hands on a spectrogram in [N,C,W] format
94 | audio, sample_rate = wavegrad_predict(spectrogram, model_dir, params=params)
95 | 
96 | # `audio` is a GPU tensor in [N,T] format.
97 | ```
98 | 
99 | ### Inference CLI
100 | ```
101 | python -m wavegrad.inference /path/to/model /path/to/spectrogram -o output.wav
102 | ```
103 | 
104 | ### Noise schedule
105 | The default implementation uses 1000 iterations to refine the waveform, which runs slower than real-time. WaveGrad can achieve high-quality, faster-than-real-time synthesis with as few as 6 iterations, without re-training the model with new hyperparameters.
106 | 
107 | To achieve this speed-up, you will need to search for a noise schedule that works well for your dataset. This implementation provides a script to perform the search for you:
108 | 
109 | ```
110 | python -m wavegrad.noise_schedule /path/to/trained/model /path/to/preprocessed/validation/dataset
111 | python -m wavegrad.inference /path/to/trained/model /path/to/spectrogram -n noise_schedule.npy -o output.wav
112 | ```
113 | 
114 | The default settings should give good results without spending too much time on the search. If you'd like to find a better noise schedule or use a different number of inference iterations, run the `noise_schedule` script with `--help` to see additional configuration options.
115 | 
116 | 
117 | ## References
118 | - [WaveGrad: Estimating Gradients for Waveform Generation](https://arxiv.org/pdf/2009.00713.pdf)
119 | - [Denoising Diffusion Probabilistic Models](https://arxiv.org/pdf/2006.11239.pdf)
120 | - [Code for Denoising Diffusion Probabilistic Models](https://github.com/hojonathanho/diffusion)
121 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 LMNT, Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | from setuptools import find_packages, setup 17 | 18 | 19 | VERSION = '0.1.5' 20 | DESCRIPTION = 'wavegrad' 21 | AUTHOR = 'LMNT, Inc.' 22 | AUTHOR_EMAIL = 'github@lmnt.com' 23 | URL = 'https://www.lmnt.com' 24 | LICENSE = 'Apache 2.0' 25 | KEYWORDS = ['wavegrad machine learning neural vocoder tts speech'] 26 | CLASSIFIERS = [ 27 | 'Development Status :: 4 - Beta', 28 | 'Intended Audience :: Developers', 29 | 'Intended Audience :: Education', 30 | 'Intended Audience :: Science/Research', 31 | 'License :: OSI Approved :: Apache Software License', 32 | 'Programming Language :: Python :: 3.5', 33 | 'Programming Language :: Python :: 3.6', 34 | 'Programming Language :: Python :: 3.7', 35 | 'Programming Language :: Python :: 3.8', 36 | 'Topic :: Scientific/Engineering :: Mathematics', 37 | 'Topic :: Software Development :: Libraries :: Python Modules', 38 | 'Topic :: Software Development :: Libraries', 39 | ] 40 | 41 | 42 | setup(name = 'wavegrad', 43 | version = VERSION, 44 | description = DESCRIPTION, 45 | long_description = open('README.md', 'r').read(), 46 | long_description_content_type = 'text/markdown', 47 | author = AUTHOR, 48 | author_email = AUTHOR_EMAIL, 49 | url = URL, 50 | license = LICENSE, 51 | keywords = KEYWORDS, 52 | packages = find_packages('src'), 53 | package_dir = { '': 'src' }, 54 | install_requires = [ 55 | 'numpy', 56 | 'torch>=1.6', 57 | 'torchaudio>=0.9.0', 58 | 'tqdm' 59 | ], 60 | classifiers = CLASSIFIERS) 61 | -------------------------------------------------------------------------------- /src/wavegrad/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmnt-com/wavegrad/c3d54c89631c61fb1099a88ffd8fa7623734f180/src/wavegrad/__init__.py -------------------------------------------------------------------------------- /src/wavegrad/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 LMNT, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | 
16 | from argparse import ArgumentParser
17 | from torch.cuda import device_count
18 | from torch.multiprocessing import spawn
19 | 
20 | from wavegrad.learner import train, train_distributed
21 | from wavegrad.params import params
22 | 
23 | 
24 | def _get_free_port():
25 |   import socketserver
26 |   with socketserver.TCPServer(('localhost', 0), None) as s:
27 |     return s.server_address[1]
28 | 
29 | 
30 | def main(args):
31 |   replica_count = device_count()
32 |   if replica_count > 1:
33 |     if params.batch_size % replica_count != 0:
34 |       raise ValueError(f'Batch size {params.batch_size} is not evenly divisible by # GPUs {replica_count}.')
35 |     params.batch_size = params.batch_size // replica_count
36 |     port = _get_free_port()
37 |     spawn(train_distributed, args=(replica_count, port, args, params), nprocs=replica_count, join=True)
38 |   else:
39 |     train(args, params)
40 | 
41 | 
42 | if __name__ == '__main__':
43 |   parser = ArgumentParser(description='train (or resume training) a WaveGrad model')
44 |   parser.add_argument('model_dir',
45 |       help='directory in which to store model checkpoints and training logs')
46 |   parser.add_argument('data_dirs', nargs='+',
47 |       help='space-separated list of directories from which to read .wav files for training')
48 |   parser.add_argument('--max_steps', default=None, type=int,
49 |       help='maximum number of training steps')
50 |   parser.add_argument('--fp16', action='store_true', default=False,
51 |       help='use 16-bit floating point operations for training')
52 |   main(parser.parse_args())
53 | 
--------------------------------------------------------------------------------
/src/wavegrad/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 LMNT, Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import os 18 | import random 19 | import torch 20 | import torchaudio 21 | 22 | from glob import glob 23 | from torch.utils.data.distributed import DistributedSampler 24 | 25 | 26 | class NumpyDataset(torch.utils.data.Dataset): 27 | def __init__(self, paths): 28 | super().__init__() 29 | self.filenames = [] 30 | for path in paths: 31 | self.filenames += glob(f'{path}/**/*.wav', recursive=True) 32 | 33 | def __len__(self): 34 | return len(self.filenames) 35 | 36 | def __getitem__(self, idx): 37 | audio_filename = self.filenames[idx] 38 | spec_filename = f'{audio_filename}.spec.npy' 39 | signal, _ = torchaudio.load(audio_filename) 40 | spectrogram = np.load(spec_filename) 41 | return { 42 | 'audio': signal[0], 43 | 'spectrogram': spectrogram.T 44 | } 45 | 46 | 47 | class Collator: 48 | def __init__(self, params): 49 | self.params = params 50 | 51 | def collate(self, minibatch): 52 | samples_per_frame = self.params.hop_samples 53 | for record in minibatch: 54 | # Filter out records that aren't long enough. 55 | if len(record['spectrogram']) < self.params.crop_mel_frames: 56 | del record['spectrogram'] 57 | del record['audio'] 58 | continue 59 | 60 | start = random.randint(0, record['spectrogram'].shape[0] - self.params.crop_mel_frames) 61 | end = start + self.params.crop_mel_frames 62 | record['spectrogram'] = record['spectrogram'][start:end].T 63 | 64 | start *= samples_per_frame 65 | end *= samples_per_frame 66 | record['audio'] = record['audio'][start:end] 67 | record['audio'] = np.pad(record['audio'], (0, (end-start) - len(record['audio'])), mode='constant') 68 | 69 | audio = np.stack([record['audio'] for record in minibatch if 'audio' in record]) 70 | spectrogram = np.stack([record['spectrogram'] for record in minibatch if 'spectrogram' in record]) 71 | return { 72 | 'audio': torch.from_numpy(audio), 73 | 'spectrogram': torch.from_numpy(spectrogram), 74 | } 75 | 76 | 77 | def from_path(data_dirs, params, is_distributed=False): 78 | dataset = NumpyDataset(data_dirs) 79 | return torch.utils.data.DataLoader( 80 | dataset, 81 | batch_size=params.batch_size, 82 | collate_fn=Collator(params).collate, 83 | shuffle=not is_distributed, 84 | sampler=DistributedSampler(dataset) if is_distributed else None, 85 | pin_memory=True, 86 | drop_last=True, 87 | num_workers=os.cpu_count()) 88 | -------------------------------------------------------------------------------- /src/wavegrad/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 LMNT, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | 
16 | import numpy as np
17 | import os
18 | import torch
19 | import torchaudio
20 | 
21 | from argparse import ArgumentParser
22 | 
23 | from wavegrad.params import AttrDict, params as base_params
24 | from wavegrad.model import WaveGrad
25 | 
26 | 
27 | models = {}
28 | 
29 | def predict(spectrogram, model_dir=None, params=None, device=torch.device('cuda')):
30 |   # Lazy load model.
31 |   if model_dir not in models:
32 |     if os.path.exists(f'{model_dir}/weights.pt'):
33 |       checkpoint = torch.load(f'{model_dir}/weights.pt')
34 |     else:
35 |       checkpoint = torch.load(model_dir)
36 |     model = WaveGrad(AttrDict(base_params)).to(device)
37 |     model.load_state_dict(checkpoint['model'])
38 |     model.eval()
39 |     models[model_dir] = model
40 | 
41 |   model = models[model_dir]
42 |   model.params.override(params)
43 |   with torch.no_grad():
44 |     beta = np.array(model.params.noise_schedule)
45 |     alpha = 1 - beta
46 |     alpha_cum = np.cumprod(alpha)
47 | 
48 |     # Expand rank 2 tensors by adding a batch dimension.
49 |     if len(spectrogram.shape) == 2:
50 |       spectrogram = spectrogram.unsqueeze(0)
51 |     spectrogram = spectrogram.to(device)
52 | 
53 |     # Start from Gaussian noise, one output sample per hop of every spectrogram frame.
54 |     audio = torch.randn(spectrogram.shape[0], model.params.hop_samples * spectrogram.shape[-1], device=device)
55 |     noise_scale = torch.from_numpy(alpha_cum**0.5).float().unsqueeze(1).to(device)
56 | 
57 |     # Reverse diffusion: step backwards through the schedule, removing the noise the model predicts.
58 |     for n in range(len(alpha) - 1, -1, -1):
59 |       c1 = 1 / alpha[n]**0.5
60 |       c2 = (1 - alpha[n]) / (1 - alpha_cum[n])**0.5
61 |       audio = c1 * (audio - c2 * model(audio, spectrogram, noise_scale[n]).squeeze(1))
62 |       if n > 0:
63 |         # Add the posterior noise term on every step except the last.
64 |         noise = torch.randn_like(audio)
65 |         sigma = ((1.0 - alpha_cum[n-1]) / (1.0 - alpha_cum[n]) * beta[n])**0.5
66 |         audio += sigma * noise
67 |       audio = torch.clamp(audio, -1.0, 1.0)
68 |   return audio, model.params.sample_rate
69 | 
70 | 
71 | def main(args):
72 |   spectrogram = torch.from_numpy(np.load(args.spectrogram_path))
73 |   params = {}
74 |   if args.noise_schedule:
75 |     params['noise_schedule'] = torch.from_numpy(np.load(args.noise_schedule))
76 |   audio, sr = predict(spectrogram, model_dir=args.model_dir, params=params)
77 |   torchaudio.save(args.output, audio.cpu(), sample_rate=sr)
78 | 
79 | 
80 | if __name__ == '__main__':
81 |   parser = ArgumentParser(description='runs inference on a spectrogram file generated by wavegrad.preprocess')
82 |   parser.add_argument('model_dir',
83 |       help='directory containing a trained model (or full path to weights.pt file)')
84 |   parser.add_argument('spectrogram_path',
85 |       help='path to a spectrogram file generated by wavegrad.preprocess')
86 |   parser.add_argument('--noise-schedule', '-n', default=None,
87 |       help='path to a custom noise schedule file generated by wavegrad.noise_schedule')
88 |   parser.add_argument('--output', '-o', default='output.wav',
89 |       help='output file name')
90 |   main(parser.parse_args())
91 | 
--------------------------------------------------------------------------------
/src/wavegrad/learner.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 LMNT, Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import os 18 | import torch 19 | import torch.nn as nn 20 | 21 | from torch.nn.parallel import DistributedDataParallel 22 | from torch.utils.tensorboard import SummaryWriter 23 | from tqdm import tqdm 24 | 25 | from wavegrad.dataset import from_path as dataset_from_path 26 | from wavegrad.model import WaveGrad 27 | 28 | 29 | def _nested_map(struct, map_fn): 30 | if isinstance(struct, tuple): 31 | return tuple(_nested_map(x, map_fn) for x in struct) 32 | if isinstance(struct, list): 33 | return [_nested_map(x, map_fn) for x in struct] 34 | if isinstance(struct, dict): 35 | return { k: _nested_map(v, map_fn) for k, v in struct.items() } 36 | return map_fn(struct) 37 | 38 | 39 | class WaveGradLearner: 40 | def __init__(self, model_dir, model, dataset, optimizer, params, *args, **kwargs): 41 | os.makedirs(model_dir, exist_ok=True) 42 | self.model_dir = model_dir 43 | self.model = model 44 | self.dataset = dataset 45 | self.optimizer = optimizer 46 | self.params = params 47 | self.autocast = torch.cuda.amp.autocast(enabled=kwargs.get('fp16', False)) 48 | self.scaler = torch.cuda.amp.GradScaler(enabled=kwargs.get('fp16', False)) 49 | self.step = 0 50 | self.is_master = True 51 | 52 | beta = np.array(self.params.noise_schedule) 53 | noise_level = np.cumprod(1 - beta)**0.5 54 | noise_level = np.concatenate([[1.0], noise_level], axis=0) 55 | self.noise_level = torch.tensor(noise_level.astype(np.float32)) 56 | self.loss_fn = nn.L1Loss() 57 | self.summary_writer = None 58 | 59 | def state_dict(self): 60 | if hasattr(self.model, 'module') and isinstance(self.model.module, nn.Module): 61 | model_state = self.model.module.state_dict() 62 | else: 63 | model_state = self.model.state_dict() 64 | return { 65 | 'step': self.step, 66 | 'model': { k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model_state.items() }, 67 | 'optimizer': { k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.optimizer.state_dict().items() }, 68 | 'params': dict(self.params), 69 | 'scaler': self.scaler.state_dict(), 70 | } 71 | 72 | def load_state_dict(self, state_dict): 73 | if hasattr(self.model, 'module') and isinstance(self.model.module, nn.Module): 74 | self.model.module.load_state_dict(state_dict['model']) 75 | else: 76 | self.model.load_state_dict(state_dict['model']) 77 | self.optimizer.load_state_dict(state_dict['optimizer']) 78 | self.scaler.load_state_dict(state_dict['scaler']) 79 | self.step = state_dict['step'] 80 | 81 | def save_to_checkpoint(self, filename='weights'): 82 | save_basename = f'{filename}-{self.step}.pt' 83 | save_name = f'{self.model_dir}/{save_basename}' 84 | link_name = f'{self.model_dir}/{filename}.pt' 85 | torch.save(self.state_dict(), save_name) 86 | if os.name == 'nt': 87 | torch.save(self.state_dict(), link_name) 88 | else: 89 | if os.path.islink(link_name): 90 | os.unlink(link_name) 91 | os.symlink(save_basename, link_name) 92 | 93 | def restore_from_checkpoint(self, filename='weights'): 94 | try: 95 | checkpoint 
= torch.load(f'{self.model_dir}/{filename}.pt') 96 | self.load_state_dict(checkpoint) 97 | return True 98 | except FileNotFoundError: 99 | return False 100 | 101 | def train(self, max_steps=None): 102 | device = next(self.model.parameters()).device 103 | while True: 104 | for features in tqdm(self.dataset, desc=f'Epoch {self.step // len(self.dataset)}') if self.is_master else self.dataset: 105 | if max_steps is not None and self.step >= max_steps: 106 | return 107 | features = _nested_map(features, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x) 108 | loss = self.train_step(features) 109 | if torch.isnan(loss).any(): 110 | raise RuntimeError(f'Detected NaN loss at step {self.step}.') 111 | if self.is_master: 112 | if self.step % 100 == 0: 113 | self._write_summary(self.step, features, loss) 114 | if self.step % len(self.dataset) == 0: 115 | self.save_to_checkpoint() 116 | self.step += 1 117 | 118 | def train_step(self, features): 119 | for param in self.model.parameters(): 120 | param.grad = None 121 | 122 | audio = features['audio'] 123 | spectrogram = features['spectrogram'] 124 | 125 | N, T = audio.shape 126 | S = 1000 127 | device = audio.device 128 | self.noise_level = self.noise_level.to(device) 129 | 130 | with self.autocast: 131 | s = torch.randint(1, S + 1, [N], device=audio.device) 132 | l_a, l_b = self.noise_level[s-1], self.noise_level[s] 133 | noise_scale = l_a + torch.rand(N, device=audio.device) * (l_b - l_a) 134 | noise_scale = noise_scale.unsqueeze(1) 135 | noise = torch.randn_like(audio) 136 | noisy_audio = noise_scale * audio + (1.0 - noise_scale**2)**0.5 * noise 137 | 138 | predicted = self.model(noisy_audio, spectrogram, noise_scale.squeeze(1)) 139 | loss = self.loss_fn(noise, predicted.squeeze(1)) 140 | 141 | self.scaler.scale(loss).backward() 142 | self.scaler.unscale_(self.optimizer) 143 | self.grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(), self.params.max_grad_norm) 144 | self.scaler.step(self.optimizer) 145 | self.scaler.update() 146 | return loss 147 | 148 | def _write_summary(self, step, features, loss): 149 | writer = self.summary_writer or SummaryWriter(self.model_dir, purge_step=step) 150 | writer.add_audio('audio/reference', features['audio'][0], step, sample_rate=self.params.sample_rate) 151 | writer.add_scalar('train/loss', loss, step) 152 | writer.add_scalar('train/grad_norm', self.grad_norm, step) 153 | writer.flush() 154 | self.summary_writer = writer 155 | 156 | 157 | def _train_impl(replica_id, model, dataset, args, params): 158 | torch.backends.cudnn.benchmark = True 159 | opt = torch.optim.Adam(model.parameters(), lr=params.learning_rate) 160 | 161 | learner = WaveGradLearner(args.model_dir, model, dataset, opt, params, fp16=args.fp16) 162 | learner.is_master = (replica_id == 0) 163 | learner.restore_from_checkpoint() 164 | learner.train(max_steps=args.max_steps) 165 | 166 | 167 | def train(args, params): 168 | dataset = dataset_from_path(args.data_dirs, params) 169 | model = WaveGrad(params).cuda() 170 | _train_impl(0, model, dataset, args, params) 171 | 172 | 173 | def train_distributed(replica_id, replica_count, port, args, params): 174 | os.environ['MASTER_ADDR'] = 'localhost' 175 | os.environ['MASTER_PORT'] = str(port) 176 | torch.distributed.init_process_group('nccl', rank=replica_id, world_size=replica_count) 177 | 178 | device = torch.device('cuda', replica_id) 179 | torch.cuda.set_device(device) 180 | model = WaveGrad(params).to(device) 181 | model = DistributedDataParallel(model, device_ids=[replica_id]) 182 | 
_train_impl(replica_id, model, dataset_from_path(args.data_dirs, params, is_distributed=True), args, params)
183 | 
--------------------------------------------------------------------------------
/src/wavegrad/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 LMNT, Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | import numpy as np
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | 
21 | from math import log as ln
22 | 
23 | 
24 | class Conv1d(nn.Conv1d):
25 |   def __init__(self, *args, **kwargs):
26 |     super().__init__(*args, **kwargs)
27 |     self.reset_parameters()
28 | 
29 |   def reset_parameters(self):
30 |     nn.init.orthogonal_(self.weight)
31 |     nn.init.zeros_(self.bias)
32 | 
33 | 
34 | class PositionalEncoding(nn.Module):
35 |   def __init__(self, dim):
36 |     super().__init__()
37 |     self.dim = dim
38 | 
39 |   def forward(self, x, noise_level):
40 |     """
41 |     Arguments:
42 |       x:
43 |           (shape: [N,C,T], dtype: float32)
44 |       noise_level:
45 |           (shape: [N], dtype: float32)
46 | 
47 |     Returns:
48 |       x:
49 |           (shape: [N,C,T], dtype: float32), the input with the noise-level encoding added
50 |     """
51 |     N = x.shape[0]
52 |     T = x.shape[2]
53 |     return (x + self._build_encoding(noise_level)[:, :, None])
54 | 
55 |   def _build_encoding(self, noise_level):
56 |     count = self.dim // 2
57 |     step = torch.arange(count, dtype=noise_level.dtype, device=noise_level.device) / count
58 |     encoding = noise_level.unsqueeze(1) * torch.exp(-ln(1e4) * step.unsqueeze(0))
59 |     encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
60 |     return encoding
61 | 
62 | 
63 | class FiLM(nn.Module):
64 |   def __init__(self, input_size, output_size):
65 |     super().__init__()
66 |     self.encoding = PositionalEncoding(input_size)
67 |     self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
68 |     self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)
69 |     self.reset_parameters()
70 | 
71 |   def reset_parameters(self):
72 |     nn.init.xavier_uniform_(self.input_conv.weight)
73 |     nn.init.xavier_uniform_(self.output_conv.weight)
74 |     nn.init.zeros_(self.input_conv.bias)
75 |     nn.init.zeros_(self.output_conv.bias)
76 | 
77 |   def forward(self, x, noise_scale):
78 |     x = self.input_conv(x)
79 |     x = F.leaky_relu(x, 0.2)
80 |     x = self.encoding(x, noise_scale)
81 |     shift, scale = torch.chunk(self.output_conv(x), 2, dim=1)
82 |     return shift, scale
83 | 
84 | 
85 | class UBlock(nn.Module):
86 |   def __init__(self, input_size, hidden_size, factor, dilation):
87 |     super().__init__()
88 |     assert isinstance(dilation, (list, tuple))
89 |     assert len(dilation) == 4
90 | 
91 |     self.factor = factor
92 |     self.block1 = Conv1d(input_size, hidden_size, 1)
93 |     self.block2 = nn.ModuleList([
94 |         Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]),
95 |         Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1],
padding=dilation[1]) 96 | ]) 97 | self.block3 = nn.ModuleList([ 98 | Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]), 99 | Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]) 100 | ]) 101 | 102 | def forward(self, x, film_shift, film_scale): 103 | block1 = F.interpolate(x, size=x.shape[-1] * self.factor) 104 | block1 = self.block1(block1) 105 | 106 | block2 = F.leaky_relu(x, 0.2) 107 | block2 = F.interpolate(block2, size=x.shape[-1] * self.factor) 108 | block2 = self.block2[0](block2) 109 | block2 = film_shift + film_scale * block2 110 | block2 = F.leaky_relu(block2, 0.2) 111 | block2 = self.block2[1](block2) 112 | 113 | x = block1 + block2 114 | 115 | block3 = film_shift + film_scale * x 116 | block3 = F.leaky_relu(block3, 0.2) 117 | block3 = self.block3[0](block3) 118 | block3 = film_shift + film_scale * block3 119 | block3 = F.leaky_relu(block3, 0.2) 120 | block3 = self.block3[1](block3) 121 | 122 | x = x + block3 123 | return x 124 | 125 | 126 | class DBlock(nn.Module): 127 | def __init__(self, input_size, hidden_size, factor): 128 | super().__init__() 129 | self.factor = factor 130 | self.residual_dense = Conv1d(input_size, hidden_size, 1) 131 | self.conv = nn.ModuleList([ 132 | Conv1d(input_size, hidden_size, 3, dilation=1, padding=1), 133 | Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2), 134 | Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4), 135 | ]) 136 | 137 | def forward(self, x): 138 | size = x.shape[-1] // self.factor 139 | 140 | residual = self.residual_dense(x) 141 | residual = F.interpolate(residual, size=size) 142 | 143 | x = F.interpolate(x, size=size) 144 | for layer in self.conv: 145 | x = F.leaky_relu(x, 0.2) 146 | x = layer(x) 147 | 148 | return x + residual 149 | 150 | 151 | class WaveGrad(nn.Module): 152 | def __init__(self, params): 153 | super().__init__() 154 | self.params = params 155 | self.downsample = nn.ModuleList([ 156 | Conv1d(1, 32, 5, padding=2), 157 | DBlock(32, 128, 2), 158 | DBlock(128, 128, 2), 159 | DBlock(128, 256, 3), 160 | DBlock(256, 512, 5), 161 | ]) 162 | self.film = nn.ModuleList([ 163 | FiLM(32, 128), 164 | FiLM(128, 128), 165 | FiLM(128, 256), 166 | FiLM(256, 512), 167 | FiLM(512, 512), 168 | ]) 169 | self.upsample = nn.ModuleList([ 170 | UBlock(768, 512, 5, [1, 2, 1, 2]), 171 | UBlock(512, 512, 5, [1, 2, 1, 2]), 172 | UBlock(512, 256, 3, [1, 2, 4, 8]), 173 | UBlock(256, 128, 2, [1, 2, 4, 8]), 174 | UBlock(128, 128, 2, [1, 2, 4, 8]), 175 | ]) 176 | self.first_conv = Conv1d(128, 768, 3, padding=1) 177 | self.last_conv = Conv1d(128, 1, 3, padding=1) 178 | 179 | def forward(self, audio, spectrogram, noise_scale): 180 | x = audio.unsqueeze(1) 181 | downsampled = [] 182 | for film, layer in zip(self.film, self.downsample): 183 | x = layer(x) 184 | downsampled.append(film(x, noise_scale)) 185 | 186 | x = self.first_conv(spectrogram) 187 | for layer, (film_shift, film_scale) in zip(self.upsample, reversed(downsampled)): 188 | x = layer(x, film_shift, film_scale) 189 | x = self.last_conv(x) 190 | return x 191 | -------------------------------------------------------------------------------- /src/wavegrad/noise_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 LMNT, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import os 18 | import torch 19 | import torchaudio as T 20 | import torchaudio.transforms as TT 21 | 22 | from argparse import ArgumentParser 23 | from glob import glob 24 | from itertools import product as cartesian_product 25 | from tqdm import tqdm 26 | 27 | from wavegrad.params import params 28 | from wavegrad.inference import predict 29 | 30 | 31 | def _round_up(x, multiple): 32 | return (x + multiple - 1) // multiple * multiple 33 | 34 | 35 | def _ls_mse(reference, predicted): 36 | sr = params.sample_rate 37 | hop = params.hop_samples // 2 38 | win = params.hop_samples * 4 39 | n_fft = 2**((win-1).bit_length()) 40 | f_max = sr / 2.0 41 | mel_spec_transform = TT.MelSpectrogram( 42 | sample_rate=params.sample_rate, 43 | n_fft=n_fft, 44 | win_length=win, 45 | hop_length=hop, 46 | f_min=20.0, 47 | f_max=f_max, 48 | power=1.0, 49 | normalized=True).cuda() 50 | 51 | reference = torch.log(mel_spec_transform(reference) + 1e-5) 52 | predicted = torch.log(mel_spec_transform(predicted) + 1e-5) 53 | return torch.sum((reference - predicted)**2) 54 | 55 | 56 | def main(args): 57 | audio_filenames = glob(f'{args.data_dir}/**/*.wav', recursive=True) 58 | audio_filenames = [f for f in audio_filenames if os.path.exists(f'{f}.spec.npy')] 59 | if len(audio_filenames) == 0: 60 | raise ValueError('No files found.') 61 | 62 | audio = [] 63 | spectrogram = [] 64 | max_audio_len = 0 65 | max_spec_len = 0 66 | for filename in audio_filenames[:args.batch_size]: 67 | clip, _ = T.load(filename) 68 | clip = clip[0].numpy() 69 | spec = np.load(f'{filename}.spec.npy') 70 | audio.append(clip) 71 | spectrogram.append(spec) 72 | max_audio_len = max(max_audio_len, _round_up(len(clip), params.hop_samples)) 73 | max_spec_len = max(max_spec_len, spec.shape[1]) 74 | 75 | padded_audio = [np.pad(a, [0, max_audio_len - len(a)], mode='constant') for a in audio] 76 | spectrogram = [np.pad(s, [[0, 0], [0, max_spec_len - s.shape[1]]], mode='constant') for s in spectrogram] 77 | 78 | padded_audio = torch.from_numpy(np.stack(padded_audio)).cuda() 79 | spectrogram = torch.from_numpy(np.stack(spectrogram)).cuda() 80 | 81 | mantissa = list(sorted(10 * np.random.uniform(size=args.search_level))) 82 | exponent = 10**np.linspace(-6, -1, num=args.iterations) 83 | best_score = 1e32 84 | for candidate in tqdm(cartesian_product(mantissa, repeat=args.iterations), total=len(mantissa)**args.iterations): 85 | noise_schedule = np.array(candidate) * exponent 86 | predicted, _ = predict(spectrogram, model_dir=args.model_dir, params={ 'noise_schedule': noise_schedule }) 87 | score = _ls_mse(padded_audio, predicted) 88 | if score < best_score: 89 | best_score = score 90 | np.save(args.output, noise_schedule) 91 | 92 | 93 | if __name__ == '__main__': 94 | parser = ArgumentParser(description='runs a search to find the best noise schedule for a specified number of inference iterations') 95 | parser.add_argument('model_dir', 96 | help='directory containing a trained model (or full path to 
weights.pt file)')
97 |   parser.add_argument('data_dir',
98 |       help='directory from which to read .wav and spectrogram files for noise schedule optimization')
99 |   parser.add_argument('--batch-size', '-b', type=int, default=1,
100 |       help='how many wav files to use for the optimization process')
101 |   parser.add_argument('--iterations', '-i', type=int, default=6,
102 |       help='how many refinement steps to use during inference (more => increased inference time)')
103 |   parser.add_argument('--search-level', '-s', type=int, default=3, choices=range(1, 20),
104 |       help='how many points to use for the search (more => exponentially more time to search)')
105 |   parser.add_argument('--output', '-o', default='noise_schedule.npy',
106 |       help='output file name')
107 |   main(parser.parse_args())
108 | 
--------------------------------------------------------------------------------
/src/wavegrad/params.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 LMNT, Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | import numpy as np
17 | 
18 | 
19 | class AttrDict(dict):
20 |   def __init__(self, *args, **kwargs):
21 |     super(AttrDict, self).__init__(*args, **kwargs)
22 |     self.__dict__ = self
23 | 
24 |   def override(self, attrs):
25 |     if isinstance(attrs, dict):
26 |       self.__dict__.update(**attrs)
27 |     elif isinstance(attrs, (list, tuple, set)):
28 |       for attr in attrs:
29 |         self.override(attr)
30 |     elif attrs is None:
31 |       pass
32 |     else:
33 |       raise NotImplementedError
34 |     return self
35 | 
36 | 
37 | params = AttrDict(
38 |     # Training params
39 |     batch_size=32,
40 |     learning_rate=2e-4,
41 |     max_grad_norm=1.0,
42 | 
43 |     # Data params
44 |     sample_rate=22050,
45 |     hop_samples=300,  # Don't change this. Really.
46 |     crop_mel_frames=24,
47 | 
48 |     # Model params
49 |     noise_schedule=np.linspace(1e-6, 0.01, 1000).tolist(),
50 | )
51 | 
--------------------------------------------------------------------------------
/src/wavegrad/preprocess.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 LMNT, Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import torch 18 | import torchaudio as T 19 | import torchaudio.transforms as TT 20 | 21 | from argparse import ArgumentParser 22 | from concurrent.futures import ProcessPoolExecutor 23 | from glob import glob 24 | from tqdm import tqdm 25 | 26 | from wavegrad.params import params 27 | 28 | 29 | def transform(filename): 30 | audio, sr = T.load(filename) 31 | if params.sample_rate != sr: 32 | raise ValueError(f'Invalid sample rate {sr}.') 33 | audio = torch.clamp(audio[0], -1.0, 1.0) 34 | 35 | hop = params.hop_samples 36 | win = hop * 4 37 | n_fft = 2**((win-1).bit_length()) 38 | f_max = sr / 2.0 39 | mel_spec_transform = TT.MelSpectrogram(sample_rate=sr, n_fft=n_fft, win_length=win, hop_length=hop, f_min=20.0, f_max=f_max, power=1.0, normalized=True) 40 | 41 | with torch.no_grad(): 42 | spectrogram = mel_spec_transform(audio) 43 | spectrogram = 20 * torch.log10(torch.clamp(spectrogram, min=1e-5)) - 20 44 | spectrogram = torch.clamp((spectrogram + 100) / 100, 0.0, 1.0) 45 | np.save(f'{filename}.spec.npy', spectrogram.cpu().numpy()) 46 | 47 | 48 | def main(args): 49 | filenames = glob(f'{args.dir}/**/*.wav', recursive=True) 50 | with ProcessPoolExecutor() as executor: 51 | list(tqdm(executor.map(transform, filenames), desc='Preprocessing', total=len(filenames))) 52 | 53 | 54 | if __name__ == '__main__': 55 | parser = ArgumentParser(description='prepares a dataset to train WaveGrad') 56 | parser.add_argument('dir', 57 | help='directory containing .wav files for training') 58 | main(parser.parse_args()) 59 | --------------------------------------------------------------------------------
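
Putting it all together, a hypothetical end-to-end pass over a single clip might look like the sketch below. All paths are placeholders, and a CUDA device is assumed (`predict` defaults to `torch.device('cuda')`):

```python
import numpy as np
import torch
import torchaudio

from wavegrad.inference import predict

# wavegrad.preprocess writes a `<clip>.wav.spec.npy` file next to each clip;
# load one and vocode it with a trained model. `predict` accepts either a
# model directory containing weights.pt or a direct path to a checkpoint.
spectrogram = torch.from_numpy(np.load('/path/to/clip.wav.spec.npy'))
audio, sr = predict(spectrogram, model_dir='/path/to/model/dir')
torchaudio.save('resynthesized.wav', audio.cpu(), sample_rate=sr)
```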