├── .gitignore
├── LICENSE
├── README.md
├── check_dataset.py
├── dataset.py
├── inference.py
├── model.py
├── params.py
└── train.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Rishikesh (ऋषिकेश)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# NU-Wave 2 [WIP]
Unofficial implementation of [NU-Wave 2: A General Neural Audio Upsampling Model for Various Sampling Rates](https://arxiv.org/abs/2206.08545).

## Coming Soon

## Official Code is released [here](https://github.com/mindslab-ai/nuwave2).

## Citations:
```
@misc{https://doi.org/10.48550/arxiv.2206.08545,
  doi       = {10.48550/ARXIV.2206.08545},
  url       = {https://arxiv.org/abs/2206.08545},
  author    = {Han, Seungu and Lee, Junhyeok},
  keywords  = {Audio and Speech Processing (eess.AS), Machine Learning (cs.LG), FOS: Electrical engineering, electronic engineering, information engineering, FOS: Computer and information sciences},
  title     = {NU-Wave 2: A General Neural Audio Upsampling Model for Various Sampling Rates},
  publisher = {arXiv},
  year      = {2022},
  copyright = {arXiv.org perpetual, non-exclusive license}
}
```
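
## Example Usage (sketch)

A minimal end-to-end sketch of the intended workflow. It assumes a checkpoint trained with `train.py` exists under `./checkpoints/` and that a CUDA device is available (the default in `inference.predict`); the file names below are placeholders:

```python
import torchaudio
from inference import predict

# Load a 22.05 kHz recording and upsample it 2x to 44.1 kHz.
lr_audio, sr = torchaudio.load("input_22k.wav")  # (channels, frames), in [-1, 1]
audio, new_sr = predict(lr_audio[0], model_dir="./checkpoints")
torchaudio.save("output_44k.wav", audio.cpu(), sample_rate=new_sr)
```

The same workflow is available from the command line: `python train.py ./checkpoints` to train, then `python inference.py ./checkpoints input_22k.wav -o output_44k.wav` to upsample.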
--------------------------------------------------------------------------------
/check_dataset.py:
--------------------------------------------------------------------------------
import os
import glob
import numpy as np
import torchaudio as T
from params import params
from tqdm import tqdm


class AudioDataset:
    """Dry-run version of dataset.AudioDataset: iterates the corpus once and
    surfaces files that would fail during training (wrong sample rate, clips
    shorter than n_segment, etc.)."""

    def __init__(self, params):
        self.params = params
        self.path = params.path
        self.wav_list = glob.glob(
            os.path.join(self.path, "**", "*.wav"), recursive=True
        )

        self.mapping = [i for i in range(len(self.wav_list))]
        # 44.1 kHz -> 22.05 kHz, matching the transform used in dataset.py.
        self.downsample = T.transforms.Resample(
            params.new_sample_rate,
            params.sample_rate,
            resampling_method="sinc_interpolation",
        )

    def check_dataset(self, idx):
        wavpath = self.wav_list[idx]
        id = os.path.basename(wavpath).split(".")[0]
        # load_wav returns int16-scaled floats, hence the 32767.5 division below.
        audio, sr = T.load_wav(wavpath)
        if self.params.new_sample_rate != sr:
            raise ValueError(f"Invalid sample rate {sr}.")

        # Raises for clips shorter than n_segment; that is the point of this check.
        start = np.random.randint(0, audio.shape[1] - self.params.n_segment - 1)

        if audio.shape[0] == 2:  # keep only the first channel of stereo files
            audio = audio[0, :]
        audio = audio.squeeze(0)[start : start + self.params.n_segment]

        # Downsample first, then normalize each signal exactly once (mirrors
        # dataset.py; normalizing before downsampling and then dividing
        # lr_audio again would scale it down twice).
        lr_audio = self.downsample(audio)
        lr_audio = lr_audio / 32767.5
        audio = audio / 32767.5

        return {"audio": audio, "lr_audio": lr_audio, "id": id}


if __name__ == "__main__":
    M = AudioDataset(params)
    for i in tqdm(range(0, len(M.wav_list))):
        try:
            out = M.check_dataset(i)
        except Exception as e:
            print(e)
            continue
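
For reference, the recurring `32767.5` constant maps int16-scaled samples into roughly `[-1, 1)`; a standalone illustration:

```python
import numpy as np

# int16 samples span [-32768, 32767]; dividing by the midpoint scale 32767.5
# places both extremes just inside +/-1.
x = np.array([-32768, 0, 32767], dtype=np.float32)
print(x / 32767.5)  # approximately [-1.0000153, 0.0, 0.9999847]
```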
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
import glob
import random
import numpy as np
import torchaudio as T
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler


def create_dataloader(params, train, is_distributed=False):
    dataset = AudioDataset(params, train)

    return DataLoader(
        dataset=dataset,
        batch_size=params.batch_size,
        shuffle=not is_distributed,  # DistributedSampler handles shuffling otherwise
        sampler=DistributedSampler(dataset) if is_distributed else None,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
    )


class AudioDataset(Dataset):
    """Yields paired (44.1 kHz, 22.05 kHz) segments for training."""

    def __init__(self, params, train):
        self.params = params
        self.train = train
        self.path = params.path
        self.wav_list = glob.glob(
            os.path.join(self.path, "**", "*.wav"), recursive=True
        )

        self.mapping = [i for i in range(len(self.wav_list))]
        self.downsample = T.transforms.Resample(
            params.new_sample_rate,
            params.sample_rate,
            resampling_method="sinc_interpolation",
        )

    def __len__(self):
        return len(self.wav_list)

    def __getitem__(self, idx):
        return self.my_getitem(idx)

    def shuffle_mapping(self):
        random.shuffle(self.mapping)

    def my_getitem(self, idx):
        wavpath = self.wav_list[idx]
        id = os.path.basename(wavpath).split(".")[0]
        audio, sr = T.load_wav(wavpath)  # int16-scaled floats
        if self.params.new_sample_rate != sr:
            raise ValueError(f"Invalid sample rate {sr}.")

        # Random crop of n_segment samples at the high sample rate.
        start = np.random.randint(0, audio.shape[1] - self.params.n_segment - 1)

        if audio.shape[0] == 2:  # keep only the first channel of stereo files
            audio = audio[0, :]
        audio = audio.squeeze(0)[start : start + self.params.n_segment]

        # Downsample, then normalize both signals to roughly [-1, 1).
        lr_audio = self.downsample(audio)
        lr_audio = lr_audio / 32767.5
        audio = audio / 32767.5

        return {"audio": audio, "lr_audio": lr_audio, "id": id}
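
A minimal sketch of pulling one training batch, assuming `params.path` points at a directory tree of 44.1 kHz wav files, each longer than `params.n_segment` samples:

```python
from dataset import create_dataloader
from params import params

loader = create_dataloader(params, train=True)
batch = next(iter(loader))
print(batch["audio"].shape)     # (batch_size, n_segment)      -> (12, 32768)
print(batch["lr_audio"].shape)  # (batch_size, n_segment // 2) -> (12, 16384)
```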
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import torch
import torchaudio

from argparse import ArgumentParser

from params import AttrDict, params as base_params
from model import NUWave


models = {}


def predict(lr_audio, model_dir=None, params=None, device=torch.device("cuda")):
    # Lazy-load and cache the model per model_dir.
    if model_dir not in models:
        if os.path.exists(f"{model_dir}/weights.pt"):
            checkpoint = torch.load(f"{model_dir}/weights.pt", map_location=device)
        else:
            checkpoint = torch.load(model_dir, map_location=device)
        model = NUWave(AttrDict(base_params)).to(device)
        model.load_state_dict(checkpoint["model"])
        model.eval()
        models[model_dir] = model

    model = models[model_dir]
    model.params.override(params)
    with torch.no_grad():
        beta = np.array(model.params.inference_noise_schedule)
        alpha = 1 - beta
        alpha_cum = np.cumprod(alpha)

        # Expand rank-1 tensors by adding a batch dimension.
        if len(lr_audio.shape) == 1:
            lr_audio = lr_audio.unsqueeze(0)
        lr_audio = lr_audio.to(device)

        # Start from Gaussian noise at twice the input length (2x upsampling).
        audio = torch.randn(lr_audio.shape[0], 2 * lr_audio.shape[-1], device=device)
        noise_scale = torch.from_numpy(alpha_cum ** 0.5).float().unsqueeze(1).to(device)

        # Reverse diffusion: iteratively remove the predicted noise.
        for n in range(len(alpha) - 1, -1, -1):
            c1 = 1 / alpha[n] ** 0.5
            c2 = (1 - alpha[n]) / (1 - alpha_cum[n]) ** 0.5
            audio = c1 * (
                audio - c2 * model(audio, lr_audio, noise_scale[n]).squeeze(1)
            )
            if n > 0:
                noise = torch.randn_like(audio)
                sigma = (
                    (1.0 - alpha_cum[n - 1]) / (1.0 - alpha_cum[n]) * beta[n]
                ) ** 0.5
                audio += sigma * noise
            audio = torch.clamp(audio, -1.0, 1.0)
    return audio, model.params.new_sample_rate


def main(args):
    lr_audio, sr = torchaudio.load(args.audio_path)
    if base_params.sample_rate != sr:
        raise ValueError(f"Invalid sample rate {sr}.")
    params = {}
    # if args.noise_schedule:
    #     params["noise_schedule"] = torch.from_numpy(np.load(args.noise_schedule))

    audio, sr = predict(lr_audio, model_dir=args.model_dir, params=params)
    torchaudio.save(args.output, audio.cpu(), sample_rate=sr)


if __name__ == "__main__":
    parser = ArgumentParser(
        description="upsamples a low-resolution (22.05 kHz) audio file with a trained NU-Wave model"
    )
    parser.add_argument(
        "model_dir",
        help="directory containing a trained model (or full path to weights.pt file)",
    )
    parser.add_argument(
        "audio_path",
        help="path to the low-resolution audio file to upsample",
    )

    parser.add_argument("--output", "-o", default="output.wav", help="output file name")
    main(parser.parse_args())
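
The loop above is the standard DDPM ancestral-sampling update for a noise-prediction model: each step computes the posterior mean `(audio - (1 - α_n) / √(1 - ᾱ_n) · ε̂) / √α_n`, and every step but the last adds Gaussian noise with variance `σ_n² = (1 - ᾱ_{n-1}) / (1 - ᾱ_n) · β_n`. The schedule quantities come directly from `params`; a standalone sketch:

```python
import numpy as np
from params import params

beta = np.array(params.inference_noise_schedule)  # 150 linear steps, 1e-6 .. 6e-3
alpha = 1 - beta
alpha_cum = np.cumprod(alpha)                     # cumulative signal level
print(len(beta), alpha_cum[-1])                   # 150, roughly 0.637
```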
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from math import sqrt


Linear = nn.Linear
ConvTranspose2d = nn.ConvTranspose2d


def Conv1d(*args, **kwargs):
    layer = nn.Conv1d(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer


@torch.jit.script
def silu(x):
    return x * torch.sigmoid(x)


LINEAR_SCALE = 50000
GAMMA = 1 / 16  # true division; `1 // 16` would evaluate to 0 and collapse the encoding


class PositionalEncoding(nn.Module):
    """Sinusoidal embedding of the continuous noise level."""

    def __init__(self, n_channels):
        super(PositionalEncoding, self).__init__()
        self.n_channels = n_channels

    def forward(self, noise_level):
        if len(noise_level.shape) > 1:
            noise_level = noise_level.squeeze(-1)
        half_dim = self.n_channels // 2
        exponents = torch.arange(half_dim, dtype=torch.float32).to(noise_level) / float(
            half_dim
        )
        exponents = 10 ** -(exponents * GAMMA)
        exponents = LINEAR_SCALE * noise_level.unsqueeze(1) * exponents.unsqueeze(0)
        return torch.cat([exponents.sin(), exponents.cos()], dim=-1)


class DiffusionEmbedding(nn.Module):
    def __init__(self, n_channels):
        super().__init__()
        self.embedding = PositionalEncoding(n_channels)
        self.projection1 = Linear(n_channels, 512)
        self.projection2 = Linear(512, 512)

    def forward(self, noise_level):
        x = self.embedding(noise_level)
        x = self.projection1(x)
        x = silu(x)
        x = self.projection2(x)
        x = silu(x)
        return x


class ResidualBlock(nn.Module):
    def __init__(self, residual_channels, dilation):
        super().__init__()
        self.dilated_conv = Conv1d(
            residual_channels,
            2 * residual_channels,
            3,
            padding=dilation,
            dilation=dilation,
        )
        self.diffusion_projection = Linear(512, residual_channels)
        self.conditioner_projection = Conv1d(
            residual_channels,
            2 * residual_channels,
            3,
            padding=dilation,
            dilation=dilation,
        )
        self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)

    def forward(self, x, conditioner, noise_scale):
        noise_scale = self.diffusion_projection(noise_scale).unsqueeze(-1)
        conditioner = self.conditioner_projection(conditioner)
        y = x + noise_scale
        y = self.dilated_conv(y) + conditioner

        # Gated activation unit (WaveNet-style).
        gate, filter = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter)

        y = self.output_projection(y)
        residual, skip = torch.chunk(y, 2, dim=1)
        return (x + residual) / sqrt(2.0), skip


class NUWave(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.params = params
        self.factor = 2  # 2x upsampling, e.g. 22.05 kHz -> 44.1 kHz
        self.input_projection = Conv1d(
            params.input_channels, params.residual_channels, 1
        )
        self.conditioner_projection = Conv1d(
            params.input_channels, params.residual_channels, 1
        )
        self.diffusion_embedding = DiffusionEmbedding(params.residual_channels)
        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    params.residual_channels,
                    2 ** (i % params.dilation_cycle_length),
                )
                for i in range(params.residual_layers)
            ]
        )
        self.skip_projection = Conv1d(
            params.residual_channels, params.residual_channels, 1
        )
        self.output_projection = Conv1d(
            params.residual_channels, params.output_channels, 1
        )
        nn.init.zeros_(self.output_projection.weight)

    def forward(self, audio, lr_audio, noise_scale):
        x = audio.unsqueeze(1)
        x = self.input_projection(x)
        x = silu(x)

        noise_scale = self.diffusion_embedding(noise_scale)
        # Nearest-neighbor upsample the conditioner to the target length.
        lr_audio = lr_audio.unsqueeze(1)
        cond = F.interpolate(lr_audio, size=lr_audio.shape[-1] * self.factor)
        cond = self.conditioner_projection(cond)
        cond = silu(cond)

        skip = []
        for layer in self.residual_layers:
            x, skip_connection = layer(x, cond, noise_scale)
            skip.append(skip_connection)

        x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers))
        x = self.skip_projection(x)
        x = silu(x)
        x = self.output_projection(x)
        return x
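
A quick shape sanity check on CPU with random tensors (batch of 2, segment sizes taken from `params`):

```python
import torch
from model import NUWave
from params import params

model = NUWave(params)
audio = torch.randn(2, 32768)     # noisy candidate at 44.1 kHz
lr_audio = torch.randn(2, 16384)  # conditioner at 22.05 kHz
noise_scale = torch.rand(2)       # one continuous noise level per example
out = model(audio, lr_audio, noise_scale)
print(out.shape)                  # torch.Size([2, 1, 32768]) -- predicted noise
```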
--------------------------------------------------------------------------------
/params.py:
--------------------------------------------------------------------------------
import numpy as np


class AttrDict(dict):
    """dict whose keys are also readable/writable as attributes."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

    def override(self, attrs):
        if isinstance(attrs, dict):
            self.__dict__.update(**attrs)
        elif isinstance(attrs, (list, tuple, set)):
            for attr in attrs:
                self.override(attr)
        elif attrs is not None:
            raise NotImplementedError
        return self


params = AttrDict(
    # Training params
    path="wavs_dir",  # root directory searched recursively for *.wav files
    data_dir="./data/",
    batch_size=12,
    learning_rate=3e-5,
    max_grad_norm=None,
    # Data params
    sample_rate=22050,  # low (input) sample rate
    n_mels=256,  # unused in this implementation
    n_fft=1024,  # unused in this implementation
    hop_samples=256,  # unused in this implementation
    crop_mel_frames=62,  # unused in this implementation; probably an error in paper
    n_segment=32768,  # segment length at 44.1 kHz; must be even for 44.1 kHz -> 22.05 kHz
    new_sample_rate=44100,  # high (target) sample rate
    # Model params
    input_channels=1,
    output_channels=1,
    residual_layers=30,
    residual_channels=64,
    dilation_cycle_length=10,
    noise_schedule=np.linspace(1e-6, 0.006, 1000).tolist(),
    inference_noise_schedule=np.linspace(1e-6, 0.006, 150).tolist(),
)
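
`AttrDict` is what lets `inference.predict` apply per-run overrides via `model.params.override(...)`; a standalone illustration:

```python
from params import AttrDict

p = AttrDict(sample_rate=22050, new_sample_rate=44100)
p.override({"new_sample_rate": 48000})          # updates in place, returns self
print(p.new_sample_rate, p["new_sample_rate"])  # 48000 48000
```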
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import torch
import torch.nn as nn
from argparse import ArgumentParser
from torch.cuda import device_count
from torch.multiprocessing import spawn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataset import create_dataloader
from model import NUWave
from params import params


def _nested_map(struct, map_fn):
    if isinstance(struct, tuple):
        return tuple(_nested_map(x, map_fn) for x in struct)
    if isinstance(struct, list):
        return [_nested_map(x, map_fn) for x in struct]
    if isinstance(struct, dict):
        return {k: _nested_map(v, map_fn) for k, v in struct.items()}
    return map_fn(struct)


class NUWaveLearner:
    def __init__(self, model_dir, model, dataset, optimizer, params, *args, **kwargs):
        os.makedirs(model_dir, exist_ok=True)
        self.model_dir = model_dir
        self.model = model
        self.dataset = dataset
        self.optimizer = optimizer
        self.params = params
        self.autocast = torch.cuda.amp.autocast(enabled=kwargs.get("fp16", False))
        self.scaler = torch.cuda.amp.GradScaler(enabled=kwargs.get("fp16", False))
        self.step = 0
        self.is_master = True

        # sqrt of the cumulative product of (1 - beta): the same quantity
        # (sqrt of alpha_cum) that inference.py feeds to the model.
        beta = np.array(self.params.noise_schedule)
        noise_level = np.cumprod(1 - beta) ** 0.5
        noise_level = np.concatenate([[1.0], noise_level], axis=0)
        self.noise_level = torch.tensor(noise_level.astype(np.float32))
        self.loss_fn = nn.L1Loss()
        self.summary_writer = None

    def state_dict(self):
        if hasattr(self.model, "module") and isinstance(self.model.module, nn.Module):
            model_state = self.model.module.state_dict()
        else:
            model_state = self.model.state_dict()
        return {
            "step": self.step,
            "model": {
                k: v.cpu() if isinstance(v, torch.Tensor) else v
                for k, v in model_state.items()
            },
            "optimizer": {
                k: v.cpu() if isinstance(v, torch.Tensor) else v
                for k, v in self.optimizer.state_dict().items()
            },
            "params": dict(self.params),
            "scaler": self.scaler.state_dict(),
        }

    def load_state_dict(self, state_dict):
        if hasattr(self.model, "module") and isinstance(self.model.module, nn.Module):
            self.model.module.load_state_dict(state_dict["model"])
        else:
            self.model.load_state_dict(state_dict["model"])
        self.optimizer.load_state_dict(state_dict["optimizer"])
        self.scaler.load_state_dict(state_dict["scaler"])
        self.step = state_dict["step"]

    def save_to_checkpoint(self, filename="weights"):
        save_basename = f"{filename}-{self.step}.pt"
        save_name = f"{self.model_dir}/{save_basename}"
        link_name = f"{self.model_dir}/{filename}.pt"
        torch.save(self.state_dict(), save_name)
        if os.name == "nt":
            # Windows: symlinks need elevated permissions, so save a copy instead.
            torch.save(self.state_dict(), link_name)
        else:
            if os.path.islink(link_name):
                os.unlink(link_name)
            os.symlink(save_basename, link_name)

    def restore_from_checkpoint(self, filename="weights"):
        try:
            checkpoint = torch.load(f"{self.model_dir}/{filename}.pt")
            self.load_state_dict(checkpoint)
            return True
        except FileNotFoundError:
            return False

    def train(self, max_steps=None):
        device = next(self.model.parameters()).device
        while True:
            for features in (
                tqdm(self.dataset, desc=f"Epoch {self.step // len(self.dataset)}")
                if self.is_master
                else self.dataset
            ):
                if max_steps is not None and self.step >= max_steps:
                    return
                features = _nested_map(
                    features,
                    lambda x: x.to(device) if isinstance(x, torch.Tensor) else x,
                )
                loss = self.train_step(features)
                if torch.isnan(loss).any():
                    raise RuntimeError(f"Detected NaN loss at step {self.step}.")
                if self.is_master:
                    if self.step % 50 == 0:
                        self._write_summary(self.step, features, loss)
                    if self.step % len(self.dataset) == 0:
                        self.save_to_checkpoint()
                self.step += 1

    def train_step(self, features):
        for param in self.model.parameters():
            param.grad = None

        lr_audio = features["lr_audio"]
        audio = features["audio"]

        N, T = audio.shape
        S = len(self.params.noise_schedule)  # number of discrete diffusion steps (1000)
        device = audio.device
        self.noise_level = self.noise_level.to(device)

        with self.autocast:
            # Sample a continuous noise level between two adjacent discrete
            # levels (WaveGrad-style), then corrupt the clean audio with it.
            s = torch.randint(1, S + 1, [N], device=audio.device)
            l_a, l_b = self.noise_level[s - 1], self.noise_level[s]
            noise_scale = l_a + torch.rand(N, device=audio.device) * (l_b - l_a)
            noise_scale = noise_scale.unsqueeze(1)
            noise = torch.randn_like(audio)

            noisy_audio = noise_scale * audio + (1.0 - noise_scale ** 2) ** 0.5 * noise
            predicted = self.model(noisy_audio, lr_audio, noise_scale.squeeze(1))
            loss = self.loss_fn(noise, predicted.squeeze(1))

        self.scaler.scale(loss).backward()
        self.scaler.unscale_(self.optimizer)
        self.grad_norm = nn.utils.clip_grad_norm_(
            self.model.parameters(), self.params.max_grad_norm or 1e9
        )
        self.scaler.step(self.optimizer)
        self.scaler.update()
        return loss

    def _write_summary(self, step, features, loss):
        writer = self.summary_writer or SummaryWriter(self.model_dir, purge_step=step)
        writer.add_audio(
            "feature/audio",
            features["audio"][0],
            step,
            sample_rate=self.params.new_sample_rate,
        )
        writer.add_audio(
            "feature/lr_audio",
            features["lr_audio"][0],
            step,
            sample_rate=self.params.sample_rate,
        )
        writer.add_scalar("train/loss", loss, step)
        writer.add_scalar("train/grad_norm", self.grad_norm, step)
        writer.flush()
        self.summary_writer = writer


def _train_impl(replica_id, model, dataset, args, params):
    torch.backends.cudnn.benchmark = True
    opt = torch.optim.Adam(model.parameters(), lr=params.learning_rate)

    learner = NUWaveLearner(args.model_dir, model, dataset, opt, params, fp16=args.fp16)
    learner.is_master = replica_id == 0
    learner.restore_from_checkpoint()
    learner.train(max_steps=args.max_steps)


def train(args, params):
    dataset = create_dataloader(
        params,
        True,
    )
    model = NUWave(params).cuda()
    _train_impl(0, model, dataset, args, params)


def train_distributed(replica_id, replica_count, port, args, params):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    torch.distributed.init_process_group(
        "nccl", rank=replica_id, world_size=replica_count
    )

    device = torch.device("cuda", replica_id)
    torch.cuda.set_device(device)
    model = NUWave(params).to(device)
    model = DistributedDataParallel(model, device_ids=[replica_id])
    _train_impl(
        replica_id,
        model,
        create_dataloader(params, True, is_distributed=True),
        args,
        params,
    )


def _get_free_port():
    import socketserver

    with socketserver.TCPServer(("localhost", 0), None) as s:
        return s.server_address[1]


def main(args):
    replica_count = device_count()
    if replica_count > 1:
        if params.batch_size % replica_count != 0:
            raise ValueError(
                f"Batch size {params.batch_size} is not evenly divisible by # GPUs {replica_count}."
            )
        params.batch_size = params.batch_size // replica_count
        port = _get_free_port()
        spawn(
            train_distributed,
            args=(replica_count, port, args, params),
            nprocs=replica_count,
            join=True,
        )
    else:
        train(args, params)


if __name__ == "__main__":
    parser = ArgumentParser(description="train (or resume training) a NU-Wave model")
    parser.add_argument(
        "model_dir",
        help="directory in which to store model checkpoints and training logs",
    )

    parser.add_argument(
        "--max_steps", default=None, type=int, help="maximum number of training steps"
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        default=False,
        help="use 16-bit floating point operations for training",
    )
    main(parser.parse_args())
--------------------------------------------------------------------------------
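
For reference, a standalone sketch of the forward (noising) process that `train_step` implements: sample a continuous noise level between two adjacent discrete levels of the schedule, then mix clean audio with Gaussian noise at that level (batch size and lengths are illustrative):

```python
import numpy as np
import torch

beta = np.linspace(1e-6, 0.006, 1000)              # params.noise_schedule
noise_level = np.concatenate([[1.0], np.cumprod(1 - beta) ** 0.5])

s = np.random.randint(1, 1001, size=4)             # one step index per example
l_a, l_b = noise_level[s - 1], noise_level[s]      # adjacent discrete levels
noise_scale = torch.from_numpy(
    l_a + np.random.rand(4) * (l_b - l_a)
).float().unsqueeze(1)                             # (4, 1)

audio = torch.randn(4, 32768)                      # stand-in for clean audio
noise = torch.randn_like(audio)
noisy_audio = noise_scale * audio + (1.0 - noise_scale ** 2) ** 0.5 * noise
# The model is trained with L1 loss to recover `noise` from `noisy_audio`.
```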