├── LICENSE ├── README.md ├── setup.py └── src └── diffwave ├── __init__.py ├── __main__.py ├── dataset.py ├── inference.py ├── learner.py ├── model.py ├── params.py └── preprocess.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 LMNT, Inc. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffWave 2 | ![PyPI Release](https://img.shields.io/pypi/v/diffwave?label=release) [![License](https://img.shields.io/github/license/lmnt-com/diffwave)](https://github.com/lmnt-com/diffwave/blob/master/LICENSE) 3 | 4 | **We're hiring!** 5 | If you like what we're building here, [come join us at LMNT](https://explore.lmnt.com). 6 | 7 | DiffWave is a fast, high-quality neural vocoder and waveform synthesizer. It starts with Gaussian noise and converts it into speech via iterative refinement. The speech can be controlled by providing a conditioning signal (e.g. log-scaled Mel spectrogram). The model and architecture details are described in [DiffWave: A Versatile Diffusion Model for Audio Synthesis](https://arxiv.org/pdf/2009.09761.pdf). 8 | 9 | ## What's new (2021-11-09) 10 | - unconditional waveform synthesis (thanks to [Andrechang](https://github.com/Andrechang)!) 11 | 12 | ## What's new (2021-04-01) 13 | - fast sampling algorithm based on v3 of the DiffWave paper 14 | 15 | ## What's new (2020-10-14) 16 | - new pretrained model trained for 1M steps 17 | - updated audio samples with output from new model 18 | 19 | ## Status (2021-11-09) 20 | - [x] fast inference procedure 21 | - [x] stable training 22 | - [x] high-quality synthesis 23 | - [x] mixed-precision training 24 | - [x] multi-GPU training 25 | - [x] command-line inference 26 | - [x] programmatic inference API 27 | - [x] PyPI package 28 | - [x] audio samples 29 | - [x] pretrained models 30 | - [x] unconditional waveform synthesis 31 | 32 | Big thanks to [Zhifeng Kong](https://github.com/FengNiMa) (lead author of DiffWave) for pointers and bug fixes. 33 | 34 | ## Audio samples 35 | [22.05 kHz audio samples](https://lmnt.com/assets/diffwave) 36 | 37 | ## Pretrained models 38 | [22.05 kHz pretrained model](https://lmnt.com/assets/diffwave/diffwave-ljspeech-22kHz-1000578.pt) (31 MB, SHA256: `d415d2117bb0bba3999afabdd67ed11d9e43400af26193a451d112e2560821a8`) 39 | 40 | This pre-trained model is able to synthesize speech with a real-time factor of 0.87 (smaller is faster). 
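If you'd like to verify the download, here is a minimal sketch that checks the file against the published SHA256 (it assumes the checkpoint kept its original filename):

```python
# Hash the downloaded checkpoint in 1 MiB chunks and compare it
# against the published digest from the README.
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

expected = 'd415d2117bb0bba3999afabdd67ed11d9e43400af26193a451d112e2560821a8'
assert sha256_of('diffwave-ljspeech-22kHz-1000578.pt') == expected, 'checksum mismatch'
```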
41 | 42 | ### Pre-trained model details 43 | - trained on 4x 1080Ti 44 | - default parameters 45 | - single precision floating point (FP32) 46 | - trained on LJSpeech dataset excluding LJ001* and LJ002* 47 | - trained for 1000578 steps (1273 epochs) 48 | 49 | ## Install 50 | 51 | Install using pip: 52 | ``` 53 | pip install diffwave 54 | ``` 55 | 56 | or from GitHub: 57 | ``` 58 | git clone https://github.com/lmnt-com/diffwave.git 59 | cd diffwave 60 | pip install . 61 | ``` 62 | 63 | ### Training 64 | Before you start training, you'll need to prepare a training dataset. The dataset can have any directory structure as long as the contained .wav files are 16-bit mono (e.g. [LJSpeech](https://keithito.com/LJ-Speech-Dataset/), [VCTK](https://pytorch.org/audio/_modules/torchaudio/datasets/vctk.html)). By default, this implementation assumes a sample rate of 22.05 kHz. If you need to change this value, edit [params.py](https://github.com/lmnt-com/diffwave/blob/master/src/diffwave/params.py). 65 | 66 | ``` 67 | python -m diffwave.preprocess /path/to/dir/containing/wavs 68 | python -m diffwave /path/to/model/dir /path/to/dir/containing/wavs 69 | 70 | # in another shell to monitor training progress: 71 | tensorboard --logdir /path/to/model/dir --bind_all 72 | ``` 73 | 74 | You should expect to hear intelligible (but noisy) speech by ~8k steps (~1.5h on a 2080 Ti). 75 | 76 | #### Multi-GPU training 77 | By default, this implementation uses as many GPUs in parallel as returned by [`torch.cuda.device_count()`](https://pytorch.org/docs/stable/cuda.html#torch.cuda.device_count). You can specify which GPUs to use by setting the [`CUDA_VISIBLE_DEVICES`](https://developer.nvidia.com/blog/cuda-pro-tip-control-gpu-visibility-cuda_visible_devices/) environment variable before running the training module. 78 | 79 | ### Inference API 80 | Basic usage: 81 | 82 | ```python 83 | import numpy as np 84 | import torch 85 | 86 | from diffwave.inference import predict as diffwave_predict 87 | 88 | model_dir = '/path/to/model/dir' 89 | # e.g. a spectrogram written by diffwave.preprocess, with a batch dimension added to make it [N,C,W]: 90 | spectrogram = torch.from_numpy(np.load('/path/to/audio.wav.spec.npy')).unsqueeze(0) 91 | audio, sample_rate = diffwave_predict(spectrogram, model_dir, fast_sampling=True) 92 | 93 | # audio is a GPU tensor in [N,T] format. 94 | ``` 95 | 96 | ### Inference CLI 97 | ``` 98 | python -m diffwave.inference --fast /path/to/model /path/to/spectrogram -o output.wav 99 | ``` 100 | 101 | ## References 102 | - [DiffWave: A Versatile Diffusion Model for Audio Synthesis](https://arxiv.org/pdf/2009.09761.pdf) 103 | - [Denoising Diffusion Probabilistic Models](https://arxiv.org/pdf/2006.11239.pdf) 104 | - [Code for Denoising Diffusion Probabilistic Models](https://github.com/hojonathanho/diffusion) 105 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 LMNT, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | from setuptools import find_packages, setup 17 | 18 | 19 | VERSION = '0.1.8' 20 | DESCRIPTION = 'diffwave' 21 | AUTHOR = 'LMNT, Inc.' 22 | AUTHOR_EMAIL = 'github@lmnt.com' 23 | URL = 'https://www.lmnt.com' 24 | LICENSE = 'Apache 2.0' 25 | KEYWORDS = ['diffwave machine learning neural vocoder tts speech'] 26 | CLASSIFIERS = [ 27 | 'Development Status :: 4 - Beta', 28 | 'Intended Audience :: Developers', 29 | 'Intended Audience :: Education', 30 | 'Intended Audience :: Science/Research', 31 | 'License :: OSI Approved :: Apache Software License', 32 | 'Programming Language :: Python :: 3.5', 33 | 'Programming Language :: Python :: 3.6', 34 | 'Programming Language :: Python :: 3.7', 35 | 'Programming Language :: Python :: 3.8', 36 | 'Topic :: Scientific/Engineering :: Mathematics', 37 | 'Topic :: Software Development :: Libraries :: Python Modules', 38 | 'Topic :: Software Development :: Libraries', 39 | ] 40 | 41 | 42 | setup(name = 'diffwave', 43 | version = VERSION, 44 | description = DESCRIPTION, 45 | long_description = open('README.md', 'r').read(), 46 | long_description_content_type = 'text/markdown', 47 | author = AUTHOR, 48 | author_email = AUTHOR_EMAIL, 49 | url = URL, 50 | license = LICENSE, 51 | keywords = KEYWORDS, 52 | packages = find_packages('src'), 53 | package_dir = { '': 'src' }, 54 | install_requires = [ 55 | 'numpy', 56 | 'torch>=1.6', 57 | 'torchaudio>=0.9.0', 58 | 'tqdm' 59 | ], 60 | classifiers = CLASSIFIERS) 61 | -------------------------------------------------------------------------------- /src/diffwave/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmnt-com/diffwave/0594106093b8d8c444de8bd8cd26482f653c569f/src/diffwave/__init__.py -------------------------------------------------------------------------------- /src/diffwave/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 LMNT, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | from argparse import ArgumentParser 17 | from torch.cuda import device_count 18 | from torch.multiprocessing import spawn 19 | 20 | from diffwave.learner import train, train_distributed 21 | from diffwave.params import params 22 | 23 | 24 | def _get_free_port(): 25 | import socketserver 26 | with socketserver.TCPServer(('localhost', 0), None) as s: 27 | return s.server_address[1] 28 | 29 | 30 | def main(args): 31 | replica_count = device_count() 32 | if replica_count > 1: 33 | if params.batch_size % replica_count != 0: 34 | raise ValueError(f'Batch size {params.batch_size} is not evenly divisible by # GPUs {replica_count}.') 35 | params.batch_size = params.batch_size // replica_count 36 | port = _get_free_port() 37 | spawn(train_distributed, args=(replica_count, port, args, params), nprocs=replica_count, join=True) 38 | else: 39 | train(args, params) 40 | 41 | 42 | if __name__ == '__main__': 43 | parser = ArgumentParser(description='train (or resume training) a DiffWave model') 44 | parser.add_argument('model_dir', 45 | help='directory in which to store model checkpoints and training logs') 46 | parser.add_argument('data_dirs', nargs='+', 47 | help='space-separated list of directories from which to read .wav files for training') 48 | parser.add_argument('--max_steps', default=None, type=int, 49 | help='maximum number of training steps') 50 | parser.add_argument('--fp16', action='store_true', default=False, 51 | help='use 16-bit floating point operations for training') 52 | main(parser.parse_args()) 53 | -------------------------------------------------------------------------------- /src/diffwave/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 LMNT, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import os 18 | import random 19 | import torch 20 | import torch.nn.functional as F 21 | import torchaudio 22 | 23 | from glob import glob 24 | from torch.utils.data.distributed import DistributedSampler 25 | 26 | 27 | class ConditionalDataset(torch.utils.data.Dataset): 28 | def __init__(self, paths): 29 | super().__init__() 30 | self.filenames = [] 31 | for path in paths: 32 | self.filenames += glob(f'{path}/**/*.wav', recursive=True) 33 | 34 | def __len__(self): 35 | return len(self.filenames) 36 | 37 | def __getitem__(self, idx): 38 | audio_filename = self.filenames[idx] 39 | spec_filename = f'{audio_filename}.spec.npy' 40 | signal, _ = torchaudio.load(audio_filename) 41 | spectrogram = np.load(spec_filename) 42 | return { 43 | 'audio': signal[0], 44 | 'spectrogram': spectrogram.T 45 | } 46 | 47 | 48 | class UnconditionalDataset(torch.utils.data.Dataset): 49 | def __init__(self, paths): 50 | super().__init__() 51 | self.filenames = [] 52 | for path in paths: 53 | self.filenames += glob(f'{path}/**/*.wav', recursive=True) 54 | 55 | def __len__(self): 56 | return len(self.filenames) 57 | 58 | def __getitem__(self, idx): 59 | audio_filename = self.filenames[idx] 60 | spec_filename = f'{audio_filename}.spec.npy' 61 | signal, _ = torchaudio.load(audio_filename) 62 | return { 63 | 'audio': signal[0], 64 | 'spectrogram': None 65 | } 66 | 67 | 68 | class Collator: 69 | def __init__(self, params): 70 | self.params = params 71 | 72 | def collate(self, minibatch): 73 | samples_per_frame = self.params.hop_samples 74 | for record in minibatch: 75 | if self.params.unconditional: 76 | # Filter out records that aren't long enough. 77 | if len(record['audio']) < self.params.audio_len: 78 | del record['spectrogram'] 79 | del record['audio'] 80 | continue 81 | 82 | start = random.randint(0, record['audio'].shape[-1] - self.params.audio_len) 83 | end = start + self.params.audio_len 84 | record['audio'] = record['audio'][start:end] 85 | record['audio'] = np.pad(record['audio'], (0, (end - start) - len(record['audio'])), mode='constant') 86 | else: 87 | # Filter out records that aren't long enough. 
88 | if len(record['spectrogram']) < self.params.crop_mel_frames: 89 | del record['spectrogram'] 90 | del record['audio'] 91 | continue 92 | 93 | start = random.randint(0, record['spectrogram'].shape[0] - self.params.crop_mel_frames) 94 | end = start + self.params.crop_mel_frames 95 | record['spectrogram'] = record['spectrogram'][start:end].T 96 | 97 | start *= samples_per_frame 98 | end *= samples_per_frame 99 | record['audio'] = record['audio'][start:end] 100 | record['audio'] = np.pad(record['audio'], (0, (end-start) - len(record['audio'])), mode='constant') 101 | 102 | audio = np.stack([record['audio'] for record in minibatch if 'audio' in record]) 103 | if self.params.unconditional: 104 | return { 105 | 'audio': torch.from_numpy(audio), 106 | 'spectrogram': None, 107 | } 108 | spectrogram = np.stack([record['spectrogram'] for record in minibatch if 'spectrogram' in record]) 109 | return { 110 | 'audio': torch.from_numpy(audio), 111 | 'spectrogram': torch.from_numpy(spectrogram), 112 | } 113 | 114 | # for gtzan 115 | def collate_gtzan(self, minibatch): 116 | ldata = [] 117 | mean_audio_len = self.params.audio_len # change to fit in gpu memory 118 | # audio total generated time = audio_len * sample_rate 119 | # GTZAN statistics 120 | # max len audio 675808; min len audio sample 660000; mean len audio sample 662117 121 | # max audio sample 1; min audio sample -1; mean audio sample -0.0010 (normalized) 122 | # sample rate of all is 22050 123 | for data in minibatch: 124 | if data[0].shape[-1] < mean_audio_len: # pad 125 | data_audio = F.pad(data[0], (0, mean_audio_len - data[0].shape[-1]), mode='constant', value=0) 126 | elif data[0].shape[-1] > mean_audio_len: # crop 127 | start = random.randint(0, data[0].shape[-1] - mean_audio_len) 128 | end = start + mean_audio_len 129 | data_audio = data[0][:, start:end] 130 | else: 131 | data_audio = data[0] 132 | ldata.append(data_audio) 133 | audio = torch.cat(ldata, dim=0) 134 | return { 135 | 'audio': audio, 136 | 'spectrogram': None, 137 | } 138 | 139 | 140 | def from_path(data_dirs, params, is_distributed=False): 141 | if params.unconditional: 142 | dataset = UnconditionalDataset(data_dirs) 143 | else:#with condition 144 | dataset = ConditionalDataset(data_dirs) 145 | return torch.utils.data.DataLoader( 146 | dataset, 147 | batch_size=params.batch_size, 148 | collate_fn=Collator(params).collate, 149 | shuffle=not is_distributed, 150 | num_workers=os.cpu_count(), 151 | sampler=DistributedSampler(dataset) if is_distributed else None, 152 | pin_memory=True, 153 | drop_last=True) 154 | 155 | 156 | def from_gtzan(params, is_distributed=False): 157 | dataset = torchaudio.datasets.GTZAN('./data', download=True) 158 | return torch.utils.data.DataLoader( 159 | dataset, 160 | batch_size=params.batch_size, 161 | collate_fn=Collator(params).collate_gtzan, 162 | shuffle=not is_distributed, 163 | num_workers=os.cpu_count(), 164 | sampler=DistributedSampler(dataset) if is_distributed else None, 165 | pin_memory=True, 166 | drop_last=True) 167 | -------------------------------------------------------------------------------- /src/diffwave/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 LMNT, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import os 18 | import torch 19 | import torchaudio 20 | 21 | from argparse import ArgumentParser 22 | 23 | from diffwave.params import AttrDict, params as base_params 24 | from diffwave.model import DiffWave 25 | 26 | 27 | models = {} 28 | 29 | def predict(spectrogram=None, model_dir=None, params=None, device=torch.device('cuda'), fast_sampling=False): 30 | # Lazy load model. 31 | if not model_dir in models: 32 | if os.path.exists(f'{model_dir}/weights.pt'): 33 | checkpoint = torch.load(f'{model_dir}/weights.pt') 34 | else: 35 | checkpoint = torch.load(model_dir) 36 | model = DiffWave(AttrDict(base_params)).to(device) 37 | model.load_state_dict(checkpoint['model']) 38 | model.eval() 39 | models[model_dir] = model 40 | 41 | model = models[model_dir] 42 | model.params.override(params) 43 | with torch.no_grad(): 44 | # Change in notation from the DiffWave paper for fast sampling. 45 | # DiffWave paper -> Implementation below 46 | # -------------------------------------- 47 | # alpha -> talpha 48 | # beta -> training_noise_schedule 49 | # gamma -> alpha 50 | # eta -> beta 51 | training_noise_schedule = np.array(model.params.noise_schedule) 52 | inference_noise_schedule = np.array(model.params.inference_noise_schedule) if fast_sampling else training_noise_schedule 53 | 54 | talpha = 1 - training_noise_schedule 55 | talpha_cum = np.cumprod(talpha) 56 | 57 | beta = inference_noise_schedule 58 | alpha = 1 - beta 59 | alpha_cum = np.cumprod(alpha) 60 | 61 | T = [] 62 | for s in range(len(inference_noise_schedule)): 63 | for t in range(len(training_noise_schedule) - 1): 64 | if talpha_cum[t+1] <= alpha_cum[s] <= talpha_cum[t]: 65 | twiddle = (talpha_cum[t]**0.5 - alpha_cum[s]**0.5) / (talpha_cum[t]**0.5 - talpha_cum[t+1]**0.5) 66 | T.append(t + twiddle) 67 | break 68 | T = np.array(T, dtype=np.float32) 69 | 70 | 71 | if not model.params.unconditional: 72 | if len(spectrogram.shape) == 2:# Expand rank 2 tensors by adding a batch dimension. 
73 | spectrogram = spectrogram.unsqueeze(0) 74 | spectrogram = spectrogram.to(device) 75 | audio = torch.randn(spectrogram.shape[0], model.params.hop_samples * spectrogram.shape[-1], device=device) 76 | else: 77 | audio = torch.randn(1, model.params.audio_len, device=device) 78 | noise_scale = torch.from_numpy(alpha_cum**0.5).float().unsqueeze(1).to(device) 79 | 80 | for n in range(len(alpha) - 1, -1, -1): 81 | c1 = 1 / alpha[n]**0.5 82 | c2 = beta[n] / (1 - alpha_cum[n])**0.5 83 | audio = c1 * (audio - c2 * model(audio, torch.tensor([T[n]], device=audio.device), spectrogram).squeeze(1)) 84 | if n > 0: 85 | noise = torch.randn_like(audio) 86 | sigma = ((1.0 - alpha_cum[n-1]) / (1.0 - alpha_cum[n]) * beta[n])**0.5 87 | audio += sigma * noise 88 | audio = torch.clamp(audio, -1.0, 1.0) 89 | return audio, model.params.sample_rate 90 | 91 | 92 | def main(args): 93 | if args.spectrogram_path: 94 | spectrogram = torch.from_numpy(np.load(args.spectrogram_path)) 95 | else: 96 | spectrogram = None 97 | audio, sr = predict(spectrogram, model_dir=args.model_dir, fast_sampling=args.fast, params=base_params) 98 | torchaudio.save(args.output, audio.cpu(), sample_rate=sr) 99 | 100 | 101 | if __name__ == '__main__': 102 | parser = ArgumentParser(description='runs inference on a spectrogram file generated by diffwave.preprocess') 103 | parser.add_argument('model_dir', 104 | help='directory containing a trained model (or full path to weights.pt file)') 105 | parser.add_argument('--spectrogram_path', '-s', 106 | help='path to a spectrogram file generated by diffwave.preprocess') 107 | parser.add_argument('--output', '-o', default='output.wav', 108 | help='output file name') 109 | parser.add_argument('--fast', '-f', action='store_true', 110 | help='fast sampling procedure') 111 | main(parser.parse_args()) 112 | -------------------------------------------------------------------------------- /src/diffwave/learner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 LMNT, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import os 18 | import torch 19 | import torch.nn as nn 20 | 21 | from torch.nn.parallel import DistributedDataParallel 22 | from torch.utils.tensorboard import SummaryWriter 23 | from tqdm import tqdm 24 | 25 | from diffwave.dataset import from_path, from_gtzan 26 | from diffwave.model import DiffWave 27 | from diffwave.params import AttrDict 28 | 29 | 30 | def _nested_map(struct, map_fn): 31 | if isinstance(struct, tuple): 32 | return tuple(_nested_map(x, map_fn) for x in struct) 33 | if isinstance(struct, list): 34 | return [_nested_map(x, map_fn) for x in struct] 35 | if isinstance(struct, dict): 36 | return { k: _nested_map(v, map_fn) for k, v in struct.items() } 37 | return map_fn(struct) 38 | 39 | 40 | class DiffWaveLearner: 41 | def __init__(self, model_dir, model, dataset, optimizer, params, *args, **kwargs): 42 | os.makedirs(model_dir, exist_ok=True) 43 | self.model_dir = model_dir 44 | self.model = model 45 | self.dataset = dataset 46 | self.optimizer = optimizer 47 | self.params = params 48 | self.autocast = torch.cuda.amp.autocast(enabled=kwargs.get('fp16', False)) 49 | self.scaler = torch.cuda.amp.GradScaler(enabled=kwargs.get('fp16', False)) 50 | self.step = 0 51 | self.is_master = True 52 | 53 | beta = np.array(self.params.noise_schedule) 54 | noise_level = np.cumprod(1 - beta) 55 | self.noise_level = torch.tensor(noise_level.astype(np.float32)) 56 | self.loss_fn = nn.L1Loss() 57 | self.summary_writer = None 58 | 59 | def state_dict(self): 60 | if hasattr(self.model, 'module') and isinstance(self.model.module, nn.Module): 61 | model_state = self.model.module.state_dict() 62 | else: 63 | model_state = self.model.state_dict() 64 | return { 65 | 'step': self.step, 66 | 'model': { k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model_state.items() }, 67 | 'optimizer': { k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.optimizer.state_dict().items() }, 68 | 'params': dict(self.params), 69 | 'scaler': self.scaler.state_dict(), 70 | } 71 | 72 | def load_state_dict(self, state_dict): 73 | if hasattr(self.model, 'module') and isinstance(self.model.module, nn.Module): 74 | self.model.module.load_state_dict(state_dict['model']) 75 | else: 76 | self.model.load_state_dict(state_dict['model']) 77 | self.optimizer.load_state_dict(state_dict['optimizer']) 78 | self.scaler.load_state_dict(state_dict['scaler']) 79 | self.step = state_dict['step'] 80 | 81 | def save_to_checkpoint(self, filename='weights'): 82 | save_basename = f'{filename}-{self.step}.pt' 83 | save_name = f'{self.model_dir}/{save_basename}' 84 | link_name = f'{self.model_dir}/{filename}.pt' 85 | torch.save(self.state_dict(), save_name) 86 | if os.name == 'nt': 87 | torch.save(self.state_dict(), link_name) 88 | else: 89 | if os.path.islink(link_name): 90 | os.unlink(link_name) 91 | os.symlink(save_basename, link_name) 92 | 93 | def restore_from_checkpoint(self, filename='weights'): 94 | try: 95 | checkpoint = torch.load(f'{self.model_dir}/{filename}.pt') 96 | self.load_state_dict(checkpoint) 97 | return True 98 | except FileNotFoundError: 99 | return False 100 | 101 | def train(self, max_steps=None): 102 | device = next(self.model.parameters()).device 103 | while True: 104 | for features in tqdm(self.dataset, desc=f'Epoch {self.step // len(self.dataset)}') if self.is_master else self.dataset: 105 | if max_steps is not None and self.step >= max_steps: 106 | return 107 | features = 
_nested_map(features, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x) 108 | loss = self.train_step(features) 109 | if torch.isnan(loss).any(): 110 | raise RuntimeError(f'Detected NaN loss at step {self.step}.') 111 | if self.is_master: 112 | if self.step % 50 == 0: 113 | self._write_summary(self.step, features, loss) 114 | if self.step % len(self.dataset) == 0: 115 | self.save_to_checkpoint() 116 | self.step += 1 117 | 118 | def train_step(self, features): 119 | for param in self.model.parameters(): 120 | param.grad = None 121 | 122 | audio = features['audio'] 123 | spectrogram = features['spectrogram'] 124 | 125 | N, T = audio.shape 126 | device = audio.device 127 | self.noise_level = self.noise_level.to(device) 128 | 129 | with self.autocast: 130 | t = torch.randint(0, len(self.params.noise_schedule), [N], device=audio.device) 131 | noise_scale = self.noise_level[t].unsqueeze(1) 132 | noise_scale_sqrt = noise_scale**0.5 133 | noise = torch.randn_like(audio) 134 | noisy_audio = noise_scale_sqrt * audio + (1.0 - noise_scale)**0.5 * noise 135 | 136 | predicted = self.model(noisy_audio, t, spectrogram) 137 | loss = self.loss_fn(noise, predicted.squeeze(1)) 138 | 139 | self.scaler.scale(loss).backward() 140 | self.scaler.unscale_(self.optimizer) 141 | self.grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(), self.params.max_grad_norm or 1e9) 142 | self.scaler.step(self.optimizer) 143 | self.scaler.update() 144 | return loss 145 | 146 | def _write_summary(self, step, features, loss): 147 | writer = self.summary_writer or SummaryWriter(self.model_dir, purge_step=step) 148 | writer.add_audio('feature/audio', features['audio'][0], step, sample_rate=self.params.sample_rate) 149 | if not self.params.unconditional: 150 | writer.add_image('feature/spectrogram', torch.flip(features['spectrogram'][:1], [1]), step) 151 | writer.add_scalar('train/loss', loss, step) 152 | writer.add_scalar('train/grad_norm', self.grad_norm, step) 153 | writer.flush() 154 | self.summary_writer = writer 155 | 156 | 157 | def _train_impl(replica_id, model, dataset, args, params): 158 | torch.backends.cudnn.benchmark = True 159 | opt = torch.optim.Adam(model.parameters(), lr=params.learning_rate) 160 | 161 | learner = DiffWaveLearner(args.model_dir, model, dataset, opt, params, fp16=args.fp16) 162 | learner.is_master = (replica_id == 0) 163 | learner.restore_from_checkpoint() 164 | learner.train(max_steps=args.max_steps) 165 | 166 | 167 | def train(args, params): 168 | if args.data_dirs[0] == 'gtzan': 169 | dataset = from_gtzan(params) 170 | else: 171 | dataset = from_path(args.data_dirs, params) 172 | model = DiffWave(params).cuda() 173 | _train_impl(0, model, dataset, args, params) 174 | 175 | 176 | def train_distributed(replica_id, replica_count, port, args, params): 177 | os.environ['MASTER_ADDR'] = 'localhost' 178 | os.environ['MASTER_PORT'] = str(port) 179 | torch.distributed.init_process_group('nccl', rank=replica_id, world_size=replica_count) 180 | if args.data_dirs[0] == 'gtzan': 181 | dataset = from_gtzan(params, is_distributed=True) 182 | else: 183 | dataset = from_path(args.data_dirs, params, is_distributed=True) 184 | device = torch.device('cuda', replica_id) 185 | torch.cuda.set_device(device) 186 | model = DiffWave(params).to(device) 187 | model = DistributedDataParallel(model, device_ids=[replica_id]) 188 | _train_impl(replica_id, model, dataset, args, params) 189 | -------------------------------------------------------------------------------- /src/diffwave/model.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2020 LMNT, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | from math import sqrt 22 | 23 | 24 | Linear = nn.Linear 25 | ConvTranspose2d = nn.ConvTranspose2d 26 | 27 | 28 | def Conv1d(*args, **kwargs): 29 | layer = nn.Conv1d(*args, **kwargs) 30 | nn.init.kaiming_normal_(layer.weight) 31 | return layer 32 | 33 | 34 | @torch.jit.script 35 | def silu(x): 36 | return x * torch.sigmoid(x) 37 | 38 | 39 | class DiffusionEmbedding(nn.Module): 40 | def __init__(self, max_steps): 41 | super().__init__() 42 | self.register_buffer('embedding', self._build_embedding(max_steps), persistent=False) 43 | self.projection1 = Linear(128, 512) 44 | self.projection2 = Linear(512, 512) 45 | 46 | def forward(self, diffusion_step): 47 | if diffusion_step.dtype in [torch.int32, torch.int64]: 48 | x = self.embedding[diffusion_step] 49 | else: 50 | x = self._lerp_embedding(diffusion_step) 51 | x = self.projection1(x) 52 | x = silu(x) 53 | x = self.projection2(x) 54 | x = silu(x) 55 | return x 56 | 57 | def _lerp_embedding(self, t): 58 | low_idx = torch.floor(t).long() 59 | high_idx = torch.ceil(t).long() 60 | low = self.embedding[low_idx] 61 | high = self.embedding[high_idx] 62 | return low + (high - low) * (t - low_idx) 63 | 64 | def _build_embedding(self, max_steps): 65 | steps = torch.arange(max_steps).unsqueeze(1) # [T,1] 66 | dims = torch.arange(64).unsqueeze(0) # [1,64] 67 | table = steps * 10.0**(dims * 4.0 / 63.0) # [T,64] 68 | table = torch.cat([torch.sin(table), torch.cos(table)], dim=1) 69 | return table 70 | 71 | 72 | class SpectrogramUpsampler(nn.Module): 73 | def __init__(self, n_mels): 74 | super().__init__() 75 | self.conv1 = ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8]) 76 | self.conv2 = ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8]) 77 | 78 | def forward(self, x): 79 | x = torch.unsqueeze(x, 1) 80 | x = self.conv1(x) 81 | x = F.leaky_relu(x, 0.4) 82 | x = self.conv2(x) 83 | x = F.leaky_relu(x, 0.4) 84 | x = torch.squeeze(x, 1) 85 | return x 86 | 87 | 88 | class ResidualBlock(nn.Module): 89 | def __init__(self, n_mels, residual_channels, dilation, uncond=False): 90 | ''' 91 | :param n_mels: inplanes of conv1x1 for spectrogram conditional 92 | :param residual_channels: audio conv 93 | :param dilation: audio conv dilation 94 | :param uncond: disable spectrogram conditional 95 | ''' 96 | super().__init__() 97 | self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation) 98 | self.diffusion_projection = Linear(512, residual_channels) 99 | if not uncond: # conditional model 100 | self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1) 101 | else: # 
unconditional model 102 | self.conditioner_projection = None 103 | 104 | self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) 105 | 106 | def forward(self, x, diffusion_step, conditioner=None): 107 | assert (conditioner is None and self.conditioner_projection is None) or \ 108 | (conditioner is not None and self.conditioner_projection is not None) 109 | 110 | diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) 111 | y = x + diffusion_step 112 | if self.conditioner_projection is None: # using an unconditional model 113 | y = self.dilated_conv(y) 114 | else: 115 | conditioner = self.conditioner_projection(conditioner) 116 | y = self.dilated_conv(y) + conditioner 117 | 118 | gate, filter = torch.chunk(y, 2, dim=1) 119 | y = torch.sigmoid(gate) * torch.tanh(filter) 120 | 121 | y = self.output_projection(y) 122 | residual, skip = torch.chunk(y, 2, dim=1) 123 | return (x + residual) / sqrt(2.0), skip 124 | 125 | 126 | class DiffWave(nn.Module): 127 | def __init__(self, params): 128 | super().__init__() 129 | self.params = params 130 | self.input_projection = Conv1d(1, params.residual_channels, 1) 131 | self.diffusion_embedding = DiffusionEmbedding(len(params.noise_schedule)) 132 | if self.params.unconditional: # use unconditional model 133 | self.spectrogram_upsampler = None 134 | else: 135 | self.spectrogram_upsampler = SpectrogramUpsampler(params.n_mels) 136 | 137 | self.residual_layers = nn.ModuleList([ 138 | ResidualBlock(params.n_mels, params.residual_channels, 2**(i % params.dilation_cycle_length), uncond=params.unconditional) 139 | for i in range(params.residual_layers) 140 | ]) 141 | self.skip_projection = Conv1d(params.residual_channels, params.residual_channels, 1) 142 | self.output_projection = Conv1d(params.residual_channels, 1, 1) 143 | nn.init.zeros_(self.output_projection.weight) 144 | 145 | def forward(self, audio, diffusion_step, spectrogram=None): 146 | assert (spectrogram is None and self.spectrogram_upsampler is None) or \ 147 | (spectrogram is not None and self.spectrogram_upsampler is not None) 148 | x = audio.unsqueeze(1) 149 | x = self.input_projection(x) 150 | x = F.relu(x) 151 | 152 | diffusion_step = self.diffusion_embedding(diffusion_step) 153 | if self.spectrogram_upsampler: # use conditional model 154 | spectrogram = self.spectrogram_upsampler(spectrogram) 155 | 156 | skip = None 157 | for layer in self.residual_layers: 158 | x, skip_connection = layer(x, diffusion_step, spectrogram) 159 | skip = skip_connection if skip is None else skip_connection + skip 160 | 161 | x = skip / sqrt(len(self.residual_layers)) 162 | x = self.skip_projection(x) 163 | x = F.relu(x) 164 | x = self.output_projection(x) 165 | return x 166 | -------------------------------------------------------------------------------- /src/diffwave/params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 LMNT, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import numpy as np 17 | 18 | 19 | class AttrDict(dict): 20 | def __init__(self, *args, **kwargs): 21 | super(AttrDict, self).__init__(*args, **kwargs) 22 | self.__dict__ = self 23 | 24 | def override(self, attrs): 25 | if isinstance(attrs, dict): 26 | self.__dict__.update(**attrs) 27 | elif isinstance(attrs, (list, tuple, set)): 28 | for attr in attrs: 29 | self.override(attr) 30 | elif attrs is not None: 31 | raise NotImplementedError 32 | return self 33 | 34 | 35 | params = AttrDict( 36 | # Training params 37 | batch_size=16, 38 | learning_rate=2e-4, 39 | max_grad_norm=None, 40 | 41 | # Data params 42 | sample_rate=22050, 43 | n_mels=80, 44 | n_fft=1024, 45 | hop_samples=256, 46 | crop_mel_frames=62, # Probably an error in paper. 47 | 48 | # Model params 49 | residual_layers=30, 50 | residual_channels=64, 51 | dilation_cycle_length=10, 52 | unconditional = False, 53 | noise_schedule=np.linspace(1e-4, 0.05, 50).tolist(), 54 | inference_noise_schedule=[0.0001, 0.001, 0.01, 0.05, 0.2, 0.5], 55 | 56 | # unconditional sample len 57 | audio_len = 22050*5, # unconditional_synthesis_samples 58 | ) 59 | -------------------------------------------------------------------------------- /src/diffwave/preprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 LMNT, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import torch 18 | import torchaudio as T 19 | import torchaudio.transforms as TT 20 | 21 | from argparse import ArgumentParser 22 | from concurrent.futures import ProcessPoolExecutor 23 | from glob import glob 24 | from tqdm import tqdm 25 | 26 | from diffwave.params import params 27 | 28 | 29 | def transform(filename): 30 | audio, sr = T.load(filename) 31 | audio = torch.clamp(audio[0], -1.0, 1.0) 32 | 33 | if params.sample_rate != sr: 34 | raise ValueError(f'Invalid sample rate {sr}.') 35 | mel_args = { 36 | 'sample_rate': sr, 37 | 'win_length': params.hop_samples * 4, 38 | 'hop_length': params.hop_samples, 39 | 'n_fft': params.n_fft, 40 | 'f_min': 20.0, 41 | 'f_max': sr / 2.0, 42 | 'n_mels': params.n_mels, 43 | 'power': 1.0, 44 | 'normalized': True, 45 | } 46 | mel_spec_transform = TT.MelSpectrogram(**mel_args) 47 | 48 | with torch.no_grad(): 49 | spectrogram = mel_spec_transform(audio) 50 | spectrogram = 20 * torch.log10(torch.clamp(spectrogram, min=1e-5)) - 20 51 | spectrogram = torch.clamp((spectrogram + 100) / 100, 0.0, 1.0) 52 | np.save(f'{filename}.spec.npy', spectrogram.cpu().numpy()) 53 | 54 | 55 | def main(args): 56 | filenames = glob(f'{args.dir}/**/*.wav', recursive=True) 57 | with ProcessPoolExecutor() as executor: 58 | list(tqdm(executor.map(transform, filenames), desc='Preprocessing', total=len(filenames))) 59 | 60 | 61 | if __name__ == '__main__': 62 | parser = ArgumentParser(description='prepares a dataset to train DiffWave') 63 | parser.add_argument('dir', 64 | help='directory containing .wav files for training') 65 | main(parser.parse_args()) 66 | --------------------------------------------------------------------------------
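Putting the pieces together, here is a minimal end-to-end sketch (an illustration only, not part of the package: it assumes a trained model directory, a CUDA device, and a 22.05 kHz 16-bit mono .wav at the placeholder paths):

```python
# Preprocess one wav into a mel spectrogram, then vocode it back to audio
# with a trained DiffWave model, using the APIs defined in the files above.
import numpy as np
import torch
import torchaudio

from diffwave.preprocess import transform
from diffwave.inference import predict

wav = '/path/to/example.wav'        # placeholder path
model_dir = '/path/to/model/dir'    # placeholder path

transform(wav)  # writes '/path/to/example.wav.spec.npy' next to the wav
spectrogram = torch.from_numpy(np.load(f'{wav}.spec.npy'))  # [n_mels, frames]

# predict() adds the batch dimension to rank-2 spectrograms and returns a
# [N, T] GPU tensor plus the sample rate.
audio, sr = predict(spectrogram, model_dir=model_dir, fast_sampling=True)
torchaudio.save('resynthesized.wav', audio.cpu(), sample_rate=sr)
```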