├── .github └── FUNDING.yml ├── .gitignore ├── .gitmodules ├── README.md ├── configs ├── melglow_LJ_speech.json ├── mr_waveglow_LJ_speech.json ├── musicnet_config.json ├── waveflow_LJ_speech.json ├── waveglow_LJ_speech.json ├── waveglow_LJ_speech_fast.json ├── wsrglow_vctk_2x.json └── wsrglow_vctk_3x.json ├── inference.py ├── model ├── __init__.py ├── base.py ├── condition.py ├── efficient_modules.py ├── lightning.py ├── loss.py ├── melglow.py ├── mr_waveglow.py ├── waveflow.py ├── waveglow.py └── wsrglow.py ├── samples ├── 2293_generated.wav ├── 2298_generated.wav └── waveflow_64chs │ ├── LJ001-0001.wav │ ├── LJ010-0001.wav │ └── LJ020-0001.wav ├── test.py ├── tests ├── __init__.py └── test_fwd_bwd.py ├── train.py ├── utils.py └── vctk_wsrglow_infer.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: yoyololicon 4 | custom: PayPal.Me/iamycy 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | runs/* 2 | saved/* 3 | *.pyc 4 | .vscode/* 5 | lightning_logs/* 6 | wsrglow_checkpoints/* -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pytorch-wav-datasets"] 2 | path = datasets 3 | url = https://github.com/yoyololicon/pytorch-wav-datasets 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Constant Memory WaveGlow 2 | [![DOI](https://zenodo.org/badge/159754913.svg)](https://zenodo.org/badge/latestdoi/159754913) 3 | 4 | A PyTorch implementation of 5 | [WaveGlow: A Flow-based Generative Network for Speech Synthesis](https://arxiv.org/abs/1811.00002) 6 | using the constant memory method described in [Training Glow with Constant 7 | Memory Cost](http://bayesiandeeplearning.org/2018/papers/37.pdf). 8 | 9 | The model implementation details differ slightly from the 10 | [official implementation](https://github.com/NVIDIA/waveglow) based on 11 | personal preference, and the project structure is borrowed from 12 | [pytorch-template](https://github.com/victoresque/pytorch-template). 13 | 14 | In addition, we provide implementations of Baidu's [WaveFlow](https://arxiv.org/abs/1912.01219) and [MelGlow](https://arxiv.org/abs/2012.01684), 15 | which are easier to train and more memory friendly. 16 | 17 | Beyond neural vocoders, we also include an implementation of the audio super-resolution model [WSRGlow](https://arxiv.org/abs/2106.08507). 18 | 19 | ## Requirements 20 | 21 | After installing the requirements from [pytorch-template](https://github.com/victoresque/pytorch-template#requirements): 22 | 23 | ```commandline 24 | pip install nnAudio torch_optimizer 25 | ``` 26 | 27 | ## Quick Start 28 | 29 | Modify the `data_dir` in the json file to point to a directory of wave files sharing the same sampling rate, 30 | then you are good to go. The mel-spectrogram will be computed on the fly.
31 | 32 | ```json 33 | { 34 | "data_loader": { 35 | "type": "RandomWaveFileLoader", 36 | "args": { 37 | "data_dir": "/your/data/wave/files", 38 | "batch_size": 8, 39 | "num_workers": 2, 40 | "segment": 16000 41 | } 42 | } 43 | } 44 | ``` 45 | 46 | ``` 47 | python train.py -c config.json 48 | ``` 49 | 50 | ## Memory consumption of model training in PyTorch 51 | 52 | 53 | | Model | Memory (MB) | 54 | ---------------------------------------------------|:-------------:| 55 | | WaveGlow, channels=256, batch size=24 (naive) | N.A. | 56 | | WaveGlow, channels=256, batch size=24 (efficient)| 4951 | 57 | 58 | 59 | 60 | ## Result 61 | 62 | ### WaveGlow 63 | 64 | I trained the model on some cello music pieces from MusicNet using the `musicnet_config.json`. 65 | The clips in the `samples` folder are what I got. Although the audio quality is not very good, it's possible to use 66 | WaveGlow for music generation as well. 67 | The generation speed is around 470 kHz on a 1080 Ti. 68 | 69 | 70 | ### WaveFlow 71 | 72 | I trained on the full LJ Speech dataset using `waveflow_LJ_speech.json`. The settings correspond to the **64 residual channels, h=64** model in the paper. After training for about 1.25M steps, the audio quality is very similar to their official examples. 73 | Samples generated from training data can be listened to [here](samples/waveflow_64chs). 74 | 75 | ### MelGlow 76 | 77 | Coming soon. 78 | 79 | 80 | ### WSRGlow 81 | 82 | Pre-trained models on the VCTK dataset are available [here](). We follow the settings of [NU-Wave](https://arxiv.org/abs/2104.02321) to prepare the training data. 83 | 84 | 85 | ## Citation 86 | If you use our code in any project or research, please cite: 87 | 88 | ```bibtex 89 | @misc{memwaveglow, 90 | doi = {10.5281/zenodo.3874330}, 91 | author = {Chin Yun Yu}, 92 | title = {Constant Memory WaveGlow: A PyTorch implementation of WaveGlow with constant memory cost}, 93 | howpublished = {\url{https://github.com/yoyololicon/constant-memory-waveglow}}, 94 | year = {2019} 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /configs/melglow_LJ_speech.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "PyTorch MelGlow", 3 | "arch": { 4 | "type": "MelGlow", 5 | "args": { 6 | "flows": 12, 7 | "n_group": 8, 8 | "n_early_every": 4, 9 | "n_early_size": 2, 10 | "hop_size": 256, 11 | "n_mels": 80, 12 | "reverse_mode": false, 13 | "memory_efficient": true, 14 | "dilation_channels": 48, 15 | "residual_channels": 48, 16 | "skip_channels": 48, 17 | "depth": 7, 18 | "radix": 3, 19 | "predict_channels": 64, 20 | "predict_layers": 3, 21 | "bias": false 22 | } 23 | }, 24 | "dataset": { 25 | "type": "RandomWAVDataset", 26 | "args": { 27 | "data_dir": "~/data-disk/Datasets/LJSpeech-1.1/wavs/", 28 | "size": 8000, 29 | "segment": 22016 30 | } 31 | }, 32 | "data_loader": { 33 | "batch_size": 8, 34 | "shuffle": true, 35 | "num_workers": 4, 36 | "prefetch_factor": 4, 37 | "pin_memory": true 38 | }, 39 | "optimizer": { 40 | "type": "Adam", 41 | "args": { 42 | "lr": 0.0001 43 | } 44 | }, 45 | "loss": { 46 | "type": "WaveGlowLoss", 47 | "args": { 48 | "sigma": 0.7, 49 | "elementwise_mean": true 50 | } 51 | }, 52 | "conditioner": { 53 | "type": "MelSpec", 54 | "args": { 55 | "sr": 22050, 56 | "n_fft": 1024, 57 | "hop_length": 256, 58 | "f_min": 60, 59 | "f_max": 7600, 60 | "n_mels": 80 61 | } 62 | } 63 | } --------------------------------------------------------------------------------
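Each `"type"`/`"args"` pair in these config files names a class and its keyword arguments; `LightModel` in `model/lightning.py` resolves them with the `get_instance` helper imported from `utils.py` (whose source is not included in this listing, so the local re-implementation below is an assumption about its behavior). A minimal sketch of how a config is turned into model, conditioner, and loss objects:

```python
import json

import model as module_arch            # exposes WaveGlow, WaveFlow, MelGlow, WSRGlow, ...
import model.condition as module_condition
import model.loss as module_loss


def get_instance(module, config, *args):
    # Hypothetical stand-in for utils.get_instance: look up config['type'] in
    # `module` and construct it with config['args'] as keyword arguments.
    return getattr(module, config['type'])(*args, **config['args'])


with open('configs/waveglow_LJ_speech.json') as f:
    config = json.load(f)

net = get_instance(module_arch, config['arch'])                      # e.g. WaveGlow(...)
conditioner = get_instance(module_condition, config['conditioner'])  # e.g. MelSpec(...)
criterion = get_instance(module_loss, config['loss'])                # e.g. WaveGlowLoss(...)
```

The `optimizer` and `dataset` entries are resolved the same way inside `LightModel`, while the `data_loader` entry is passed directly as keyword arguments to `torch.utils.data.DataLoader`.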
/configs/mr_waveglow_LJ_speech.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "PyTorch MR WaveGlow", 3 | "arch": { 4 | "type": "MRWaveGlow", 5 | "args": { 6 | "prior_flows": 4, 7 | "n_group": 8, 8 | "hop_size": 256, 9 | "n_mels": 80, 10 | "memory_efficient": true, 11 | "reverse_mode": false, 12 | "dilation_channels": 256, 13 | "residual_channels": 256, 14 | "skip_channels": 256, 15 | "depth": 8, 16 | "radix": 3, 17 | "bias": false 18 | } 19 | }, 20 | "dataset": { 21 | "type": "RandomWAVDataset", 22 | "args": { 23 | "data_dir": "~/data-disk/Datasets/LJSpeech-1.1/wavs/", 24 | "size": 24000, 25 | "segment": 16000 26 | } 27 | }, 28 | "data_loader": { 29 | "batch_size": 24, 30 | "shuffle": true, 31 | "num_workers": 4, 32 | "prefetch_factor": 6, 33 | "pin_memory": true 34 | }, 35 | "optimizer": { 36 | "type": "Adam", 37 | "args": { 38 | "lr": 0.0001, 39 | "weight_decay": 0 40 | } 41 | }, 42 | "loss": { 43 | "type": "WaveGlowLoss", 44 | "args": { 45 | "sigma": 0.7, 46 | "elementwise_mean": true 47 | } 48 | }, 49 | "conditioner": { 50 | "type": "MelSpec", 51 | "args": { 52 | "sr": 22050, 53 | "n_fft": 1024, 54 | "hop_length": 256, 55 | "f_max": 8000, 56 | "n_mels": 80 57 | } 58 | } 59 | } -------------------------------------------------------------------------------- /configs/musicnet_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "MusicGlow", 3 | "n_gpu": 2, 4 | "arch": { 5 | "type": "WaveGlow", 6 | "args": { 7 | "flows": 18, 8 | "n_group": 8, 9 | "n_early_every": 6, 10 | "n_early_size": 2, 11 | "sr": 22050, 12 | "window_size": 2048, 13 | "hop_size": 512, 14 | "n_mels": 80, 15 | "dilation_channels": 256, 16 | "residual_channels": 256, 17 | "skip_channels": 256, 18 | "depth": 4, 19 | "radix": 3, 20 | "bias": false 21 | } 22 | }, 23 | "data_loader": { 24 | "type": "MusicNetDataLoader", 25 | "args": { 26 | "data_dir": "/host/data_dsk1/dataset/musicnet", 27 | "batch_size": 4, 28 | "num_workers": 2, 29 | "sr": 22050, 30 | "segment": 16384, 31 | "training": true, 32 | "category": "Solo Cello" 33 | } 34 | }, 35 | "optimizer": { 36 | "type": "Adam", 37 | "args": { 38 | "lr": 1e-4, 39 | "weight_decay": 0 40 | } 41 | }, 42 | "loss": { 43 | "type": "WaveGlowLoss", 44 | "args": { 45 | "sigma": 1.0, 46 | "elementwise_mean": true 47 | } 48 | }, 49 | "metrics": [ 50 | ], 51 | "lr_scheduler": { 52 | "type": "StepLR", 53 | "args": { 54 | "step_size": 10000, 55 | "gamma": 0.1 56 | } 57 | }, 58 | "trainer": { 59 | "steps": 300000, 60 | "save_dir": "saved/", 61 | "save_freq": 5000, 62 | "verbosity": 2 63 | }, 64 | "visualization": { 65 | "tensorboardX": true, 66 | "log_dir": "saved/runs" 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /configs/waveflow_LJ_speech.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "PyTorch WaveFlow", 3 | "arch": { 4 | "type": "WaveFlow", 5 | "args": { 6 | "flows": 8, 7 | "n_group": 64, 8 | "n_mels": 80, 9 | "use_conv1x1": false, 10 | "memory_efficient": false, 11 | "reverse_mode": false, 12 | "dilation_channels": 64, 13 | "residual_channels": 64, 14 | "skip_channels": 64, 15 | "bias": false 16 | } 17 | }, 18 | "dataset": { 19 | "type": "RandomWAVDataset", 20 | "args": { 21 | "data_dir": "~/data-disk/Datasets/LJSpeech-1.1/wavs/", 22 | "size": 12000, 23 | "segment": 16000 24 | } 25 | }, 26 | "data_loader": { 27 | "batch_size": 12, 28 | "shuffle": true, 29 | 
"num_workers": 4, 30 | "prefetch_factor": 4, 31 | "pin_memory": true 32 | }, 33 | "optimizer": { 34 | "type": "Adam", 35 | "args": { 36 | "lr": 0.0002, 37 | "weight_decay": 0 38 | } 39 | }, 40 | "loss": { 41 | "type": "WaveGlowLoss", 42 | "args": { 43 | "sigma": 0.7, 44 | "elementwise_mean": true 45 | } 46 | }, 47 | "conditioner": { 48 | "type": "MelSpec", 49 | "args": { 50 | "sr": 22050, 51 | "n_fft": 1024, 52 | "hop_length": 256, 53 | "f_max": 8000, 54 | "n_mels": 80 55 | } 56 | } 57 | } -------------------------------------------------------------------------------- /configs/waveglow_LJ_speech.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "PyTorch WaveGlow", 3 | "arch": { 4 | "type": "WaveGlow", 5 | "args": { 6 | "flows": 12, 7 | "n_group": 8, 8 | "n_early_every": 4, 9 | "n_early_size": 2, 10 | "hop_size": 256, 11 | "n_mels": 80, 12 | "memory_efficient": true, 13 | "reverse_mode": false, 14 | "dilation_channels": 256, 15 | "residual_channels": 256, 16 | "skip_channels": 256, 17 | "depth": 8, 18 | "radix": 3, 19 | "bias": false 20 | } 21 | }, 22 | "dataset": { 23 | "type": "RandomWAVDataset", 24 | "args": { 25 | "data_dir": "~/data-disk/Datasets/LJ/LJSpeech-1.1/wavs/", 26 | "size": 24000, 27 | "segment": 16000 28 | } 29 | }, 30 | "data_loader": { 31 | "batch_size": 24, 32 | "shuffle": true, 33 | "num_workers": 4, 34 | "prefetch_factor": 6, 35 | "pin_memory": true 36 | }, 37 | "optimizer": { 38 | "type": "Adam", 39 | "args": { 40 | "lr": 0.0001, 41 | "weight_decay": 0 42 | } 43 | }, 44 | "loss": { 45 | "type": "WaveGlowLoss", 46 | "args": { 47 | "sigma": 0.7, 48 | "elementwise_mean": true 49 | } 50 | }, 51 | "conditioner": { 52 | "type": "MelSpec", 53 | "args": { 54 | "sr": 22050, 55 | "n_fft": 1024, 56 | "hop_length": 256, 57 | "f_max": 8000, 58 | "n_mels": 80 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /configs/waveglow_LJ_speech_fast.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "PyTorch WaveGlow", 3 | "arch": { 4 | "type": "WaveGlow", 5 | "args": { 6 | "flows": 12, 7 | "n_group": 8, 8 | "n_early_every": 4, 9 | "n_early_size": 2, 10 | "hop_size": 256, 11 | "n_mels": 80, 12 | "memory_efficient": false, 13 | "reverse_mode": false, 14 | "dilation_channels": 256, 15 | "residual_channels": 256, 16 | "skip_channels": 256, 17 | "depth": 8, 18 | "radix": 3, 19 | "bias": false 20 | } 21 | }, 22 | "dataset": { 23 | "type": "RandomWAVDataset", 24 | "args": { 25 | "data_dir": "~/data-disk/Datasets/LJSpeech-1.1/wavs/", 26 | "size": 4000, 27 | "segment": 16000, 28 | "deterministic": false 29 | } 30 | }, 31 | "data_loader": { 32 | "batch_size": 4, 33 | "shuffle": true, 34 | "num_workers": 4, 35 | "prefetch_factor": 2, 36 | "pin_memory": true 37 | }, 38 | "optimizer": { 39 | "type": "Adam", 40 | "args": { 41 | "lr": 0.0001, 42 | "weight_decay": 0 43 | } 44 | }, 45 | "loss": { 46 | "type": "WaveGlowLoss", 47 | "args": { 48 | "sigma": 0.7, 49 | "elementwise_mean": true 50 | } 51 | }, 52 | "conditioner": { 53 | "type": "MelSpec", 54 | "args": { 55 | "sr": 22050, 56 | "n_fft": 1024, 57 | "hop_length": 256, 58 | "f_max": 8000, 59 | "n_mels": 80 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /configs/wsrglow_vctk_2x.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "PyTorch WSRGlow", 3 | "arch": { 4 | "type": 
"WSRGlow", 5 | "args": { 6 | "upsample_rate": 2, 7 | "memory_efficient": true 8 | } 9 | }, 10 | "dataset": { 11 | "type": "RandomWAVDataset", 12 | "args": { 13 | "data_dir": "~/data-disk/Datasets/VCTK-Corpus-0.92/wav48_silence_trimmed/train/", 14 | "size": 12000, 15 | "segment": 8192, 16 | "deterministic": false 17 | } 18 | }, 19 | "data_loader": { 20 | "batch_size": 12, 21 | "shuffle": false, 22 | "num_workers": 4, 23 | "prefetch_factor": 6, 24 | "pin_memory": true 25 | }, 26 | "optimizer": { 27 | "type": "Adam", 28 | "args": { 29 | "lr": 0.0001, 30 | "betas": [ 31 | 0.9, 32 | 0.98 33 | ], 34 | "weight_decay": 0 35 | } 36 | }, 37 | "loss": { 38 | "type": "WaveGlowLoss", 39 | "args": { 40 | "sigma": 1.0, 41 | "elementwise_mean": true 42 | } 43 | }, 44 | "conditioner": { 45 | "type": "STFTDecimate", 46 | "args": { 47 | "r": 2 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /configs/wsrglow_vctk_3x.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "PyTorch WSRGlow", 3 | "arch": { 4 | "type": "WSRGlow", 5 | "args": { 6 | "upsample_rate": 3, 7 | "memory_efficient": true 8 | } 9 | }, 10 | "dataset": { 11 | "type": "RandomWAVDataset", 12 | "args": { 13 | "data_dir": "~/data-disk/Datasets/VCTK-Corpus-0.92/wav48_silence_trimmed/train/", 14 | "size": 12000, 15 | "segment": 8208, 16 | "deterministic": false 17 | } 18 | }, 19 | "data_loader": { 20 | "batch_size": 12, 21 | "shuffle": false, 22 | "num_workers": 4, 23 | "prefetch_factor": 6, 24 | "pin_memory": true 25 | }, 26 | "optimizer": { 27 | "type": "Adam", 28 | "args": { 29 | "lr": 0.0001, 30 | "betas": [ 31 | 0.9, 32 | 0.98 33 | ], 34 | "weight_decay": 0 35 | } 36 | }, 37 | "loss": { 38 | "type": "WaveGlowLoss", 39 | "args": { 40 | "sigma": 1.0, 41 | "elementwise_mean": true 42 | } 43 | }, 44 | "conditioner": { 45 | "type": "STFTDecimate", 46 | "args": { 47 | "r": 3 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from torch.cuda import amp 5 | import torchaudio 6 | from utils import remove_weight_norms 7 | from time import time 8 | import math 9 | 10 | from model import LightModel, condition 11 | 12 | 13 | def main(ckpt, infile, outfile, sigma, half, n_group=None): 14 | lit_model = LightModel.load_from_checkpoint(ckpt, map_location='cpu') 15 | model = lit_model.model 16 | conditioner = lit_model.conditioner 17 | model.apply(remove_weight_norms) 18 | 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | model = model.to(device) 21 | conditioner = conditioner.to(device) 22 | model.eval() 23 | 24 | y, sr = torchaudio.load(infile) 25 | y = y.mean(0, keepdim=True).to(device) 26 | 27 | if n_group: 28 | offset = y.shape[1] % n_group 29 | if offset: 30 | y = y[:, :-offset] 31 | cond = conditioner(y) 32 | 33 | if half: 34 | model = model.half() 35 | cond = cond.half() 36 | y = y.half() 37 | 38 | with torch.no_grad(): 39 | start = time() 40 | z, logdet = model(y, cond) 41 | cost = time() - start 42 | z = z.squeeze() 43 | 44 | print(z.mean().item(), z.std().item()) 45 | print("Forward LL:", logdet.mean().item() / z.size(0) - 0.5 * 46 | (z.pow(2).mean().item() / sigma ** 2 + math.log(2 * math.pi) + 2 * math.log(sigma))) 47 | print("Time cost: {:.4f}, Speed: {:.4f} kHz".format( 48 | cost, z.numel() / cost / 1000)) 
49 | 50 | with torch.no_grad(): 51 | start = time() 52 | x = model.infer(cond, sigma) 53 | cost = time() - start 54 | 55 | print("Time cost: {:.4f}, Speed: {:.4f} kHz".format( 56 | cost, x.numel() / cost / 1000)) 57 | print(x.max().item(), x.min().item()) 58 | 59 | torchaudio.save(outfile, x.unsqueeze(0).cpu(), sr) 60 | 61 | 62 | if __name__ == '__main__': 63 | parser = argparse.ArgumentParser(description='Inferencer') 64 | parser.add_argument('ckpt', type=str) 65 | parser.add_argument('infile', type=str) 66 | parser.add_argument('outfile', type=str) 67 | parser.add_argument('-s', '--sigma', type=float, default=0.6) 68 | parser.add_argument('-n', '--n-group', type=int, default=None) 69 | parser.add_argument('--half', action='store_true') 70 | args = parser.parse_args() 71 | 72 | main(args.ckpt, args.infile, args.outfile, 73 | args.sigma, args.half, args.n_group) 74 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .waveglow import WaveGlow 2 | from .waveflow import WaveFlow 3 | from .melglow import MelGlow 4 | from .mr_waveglow import MRWaveGlow 5 | from .base import FlowBase, Reversible 6 | from .lightning import LightModel 7 | from .wsrglow import WSRGlow 8 | -------------------------------------------------------------------------------- /model/base.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | from torch import Tensor 4 | import torch.nn as nn 5 | 6 | 7 | class Reversible(nn.Module): 8 | _reverse_mode: bool 9 | 10 | def __init__(self, reverse_mode, **kwargs) -> None: 11 | super().__init__(**kwargs) 12 | self._reverse_mode = reverse_mode 13 | 14 | def forward_computation(self, x: Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]: 15 | raise NotImplementedError 16 | 17 | def reverse_computation(self, z: Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]: 18 | raise NotImplementedError 19 | 20 | def forward(self, x: Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]: 21 | if self._reverse_mode: 22 | return self.reverse_computation(x, *args, **kwargs) 23 | return self.forward_computation(x, *args, **kwargs) 24 | 25 | def reverse(self, z: Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]: 26 | if self._reverse_mode: 27 | return self.forward_computation(z, *args, **kwargs) 28 | return self.reverse_computation(z, *args, **kwargs) 29 | 30 | 31 | class FlowBase(Reversible): 32 | def __init__(self, condition_hop_length: int, reverse_mode=False) -> None: 33 | super().__init__(reverse_mode=reverse_mode) 34 | self._hop_length = condition_hop_length 35 | 36 | def forward_computation(self, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]: 37 | raise NotImplementedError 38 | 39 | def reverse_computation(self, z: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]: 40 | raise NotImplementedError 41 | 42 | @torch.no_grad() 43 | def infer(self, h: Tensor, sigma: float = 1.) 
-> Tensor: 44 | if h.dim() == 2: 45 | h = h.unsqueeze(0) 46 | 47 | batch_dim, _, steps = h.shape 48 | samples = steps * self._hop_length 49 | 50 | z = h.new_empty((batch_dim, samples)).normal_(std=sigma) 51 | if self._reverse_mode: 52 | x, _ = self.forward_computation(z, h) 53 | else: 54 | x, _ = self.reverse_computation(z, h) 55 | return x.squeeze() 56 | -------------------------------------------------------------------------------- /model/condition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | from torchaudio.transforms import MelSpectrogram 5 | 6 | 7 | class MelSpec(nn.Module): 8 | def __init__(self, sr, n_fft, hop_length, **kwargs) -> None: 9 | super().__init__() 10 | 11 | self.mel = nn.Sequential( 12 | nn.ReflectionPad1d((n_fft // 2 - hop_length // 2, 13 | n_fft // 2 + hop_length // 2)), 14 | MelSpectrogram(sample_rate=sr, n_fft=n_fft, 15 | hop_length=hop_length, center=False, **kwargs) 16 | ) 17 | 18 | def forward(self, x: Tensor) -> Tensor: 19 | return self.mel(x).add_(1e-7).log_() 20 | 21 | 22 | class LowPass(nn.Module): 23 | def __init__(self, 24 | nfft=1024, 25 | hop=256, 26 | ratio=(1 / 6, 1 / 3, 1 / 2, 2 / 3, 3 / 4, 4 / 5, 5 / 6, 27 | 1 / 1)): 28 | super().__init__() 29 | self.nfft = nfft 30 | self.hop = hop 31 | self.register_buffer('window', torch.hann_window(nfft), False) 32 | f = torch.ones((len(ratio), nfft//2 + 1), dtype=torch.float) 33 | for i, r in enumerate(ratio): 34 | f[i, int((nfft//2+1) * r):] = 0. 35 | self.register_buffer('filters', f, False) 36 | 37 | # x: [B,T], r: [B], int 38 | def forward(self, x, r): 39 | origin_shape = x.shape 40 | T = origin_shape[-1] 41 | x = x.view(-1, T) 42 | 43 | x = F.pad(x, (0, self.nfft), 'constant', 0) 44 | stft = torch.stft(x, 45 | self.nfft, 46 | self.hop, 47 | window=self.window, 48 | ) # return_complex=False) #[B, F, TT,2] 49 | 50 | stft = stft * self.filters[r][:, None, None] 51 | x = torch.istft(stft, 52 | self.nfft, 53 | self.hop, 54 | window=self.window, 55 | ) # return_complex=False) 56 | x = x[:, :T] 57 | return x.view(*origin_shape) 58 | 59 | 60 | class STFTDecimate(LowPass): 61 | def __init__(self, r, *args, **kwargs): 62 | super().__init__(*args, ratio=[1 / r], **kwargs) 63 | self.r = r 64 | 65 | def forward(self, x): 66 | return super().forward(x, 0)[..., ::self.r] 67 | -------------------------------------------------------------------------------- /model/efficient_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | from typing import Tuple 5 | import torch.nn.functional as F 6 | from torch.autograd import Function, set_grad_enabled, grad, gradcheck 7 | from torch.cuda.amp import custom_fwd, custom_bwd 8 | 9 | from .base import Reversible 10 | 11 | 12 | __all__ = [ 13 | 'InvertibleConv1x1', 'AffineCouplingBlock' 14 | ] 15 | 16 | 17 | class InvertibleConv1x1(Reversible, nn.Conv1d): 18 | def __init__(self, c, memory_efficient=False, reverse_mode=False): 19 | super().__init__(in_channels=c, out_channels=c, 20 | kernel_size=1, bias=False, reverse_mode=reverse_mode) 21 | 22 | W = torch.linalg.qr(torch.randn(c, c))[0] 23 | if torch.det(W) < 0: 24 | W[:, 0] = -W[:, 0] 25 | # W = torch.eye(c).flip(0) 26 | self.weight.data[:] = W.contiguous().unsqueeze(-1) 27 | if memory_efficient: 28 | self._efficient_forward = Conv1x1Func.apply 29 | self._efficient_reverse = 
InvConv1x1Func.apply 30 | 31 | def forward_computation(self, x: Tensor) -> Tuple[Tensor, Tensor]: 32 | if hasattr(self, '_efficient_forward'): 33 | z, log_det_W = self._efficient_forward(x, self.weight) 34 | x.storage().resize_(0) 35 | return z, log_det_W 36 | else: 37 | *_, n_of_groups = x.shape 38 | # should fix nan logdet 39 | log_det_W = n_of_groups * self.weight.squeeze().logdet() 40 | z = F.conv1d(x, self.weight) 41 | return z, log_det_W 42 | 43 | def reverse_computation(self, z: Tensor) -> Tuple[Tensor, Tensor]: 44 | if hasattr(self, '_efficient_reverse'): 45 | x, log_det_W = self._efficient_reverse(z, self.weight) 46 | z.storage().resize_(0) 47 | return x, log_det_W 48 | else: 49 | weight = self.weight.squeeze() 50 | *_, n_of_groups = z.shape 51 | log_det_W = -n_of_groups * \ 52 | weight.logdet() # should fix nan logdet 53 | x = F.conv1d(z, weight.inverse().unsqueeze(-1)) 54 | return x, log_det_W 55 | 56 | 57 | class AffineCouplingBlock(Reversible): 58 | def __init__(self, 59 | transform_type, 60 | memory_efficient=True, 61 | reverse_mode=False, 62 | **kwargs): 63 | super().__init__(reverse_mode) 64 | 65 | self.F = transform_type(**kwargs) 66 | if memory_efficient: 67 | self._efficient_forward = AffineCouplingFunc.apply 68 | self._efficient_reverse = InvAffineCouplingFunc.apply 69 | 70 | def forward_computation(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: 71 | if hasattr(self, '_efficient_forward'): 72 | z, log_s = self._efficient_forward( 73 | x, y, self.F, *self.F.parameters()) 74 | x.storage().resize_(0) 75 | return z, log_s 76 | else: 77 | xa, xb = x.chunk(2, 1) 78 | za = xa 79 | log_s, t = self.F(xa, y) 80 | zb = xb * log_s.exp() + t 81 | z = torch.cat((za, zb), 1) 82 | return z, log_s 83 | 84 | def reverse_computation(self, z: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: 85 | if hasattr(self, '_efficient_reverse'): 86 | x, log_s = self._efficient_reverse( 87 | z, y, self.F, *self.F.parameters()) 88 | z.storage().resize_(0) 89 | return x, log_s 90 | else: 91 | za, zb = z.chunk(2, 1) 92 | xa = za 93 | log_s, t = self.F(za, y) 94 | xb = (zb - t) / log_s.exp() 95 | x = torch.cat((xa, xb), 1) 96 | return x, -log_s 97 | 98 | 99 | class AffineCouplingFunc(Function): 100 | @staticmethod 101 | @custom_fwd 102 | def forward(ctx, x, y, F, *F_weights): 103 | ctx.F = F 104 | with torch.no_grad(): 105 | xa, xb = x.chunk(2, 1) 106 | xa, xb = xa.contiguous(), xb.contiguous() 107 | 108 | log_s, t = F(xa, y) 109 | zb = xb * log_s.exp() + t 110 | za = xa 111 | z = torch.cat((za, zb), 1) 112 | 113 | ctx.save_for_backward(x.data, y, z) 114 | return z, log_s 115 | 116 | @staticmethod 117 | @custom_bwd 118 | def backward(ctx, z_grad, log_s_grad): 119 | F = ctx.F 120 | x, y, z = ctx.saved_tensors 121 | 122 | za, zb = z.chunk(2, 1) 123 | za, zb = za.contiguous(), zb.contiguous() 124 | dza, dzb = z_grad.chunk(2, 1) 125 | dza, dzb = dza.contiguous(), dzb.contiguous() 126 | 127 | with set_grad_enabled(True): 128 | xa = za 129 | xa.requires_grad = True 130 | log_s, t = F(xa, y) 131 | 132 | with torch.no_grad(): 133 | s = log_s.exp() 134 | xb = (zb - t) / s 135 | x.storage().resize_(xb.numel() * 2) 136 | torch.cat((xa, xb), 1, out=x) # .contiguous() 137 | # x.copy_(xout) # .detach() 138 | 139 | with set_grad_enabled(True): 140 | param_list = [xa] + list(F.parameters()) 141 | if ctx.needs_input_grad[1]: 142 | param_list += [y] 143 | dtsdxa, *dw = grad(torch.cat((log_s, t), 1), param_list, 144 | grad_outputs=torch.cat((dzb * xb * s + log_s_grad, dzb), 1)) 145 | 146 | dxa = dza + dtsdxa 147 | dxb = 
dzb * s 148 | dx = torch.cat((dxa, dxb), 1) 149 | if ctx.needs_input_grad[1]: 150 | *dw, dy = dw 151 | else: 152 | dy = None 153 | 154 | return (dx, dy, None) + tuple(dw) 155 | 156 | 157 | class InvAffineCouplingFunc(Function): 158 | @staticmethod 159 | @custom_fwd 160 | def forward(ctx, z, y, F, *F_weights): 161 | ctx.F = F 162 | with torch.no_grad(): 163 | za, zb = z.chunk(2, 1) 164 | za, zb = za.contiguous(), zb.contiguous() 165 | 166 | log_s, t = F(za, y) 167 | xb = (zb - t) / log_s.exp() 168 | xa = za 169 | x = torch.cat((xa, xb), 1) 170 | 171 | ctx.save_for_backward(z.data, y, x) 172 | return x, -log_s 173 | 174 | @staticmethod 175 | @custom_bwd 176 | def backward(ctx, x_grad, log_s_grad): 177 | F = ctx.F 178 | z, y, x = ctx.saved_tensors 179 | 180 | xa, xb = x.chunk(2, 1) 181 | xa, xb = xa.contiguous(), xb.contiguous() 182 | dxa, dxb = x_grad.chunk(2, 1) 183 | dxa, dxb = dxa.contiguous(), dxb.contiguous() 184 | 185 | with set_grad_enabled(True): 186 | za = xa 187 | za.requires_grad = True 188 | log_s, t = F(za, y) 189 | s = log_s.exp() 190 | 191 | with torch.no_grad(): 192 | zb = xb * s + t 193 | 194 | z.storage().resize_(zb.numel() * 2) 195 | torch.cat((za, zb), 1, out=z) 196 | # z.copy_(zout) 197 | 198 | with set_grad_enabled(True): 199 | param_list = [za] + list(F.parameters()) 200 | if ctx.needs_input_grad[1]: 201 | param_list += [y] 202 | dtsdza, *dw = grad(torch.cat((-log_s, -t / s), 1), param_list, 203 | grad_outputs=torch.cat((dxb * zb / s.detach() + log_s_grad, dxb), 1)) 204 | 205 | dza = dxa + dtsdza 206 | dzb = dxb / s.detach() 207 | dz = torch.cat((dza, dzb), 1) 208 | if ctx.needs_input_grad[1]: 209 | *dw, dy = dw 210 | else: 211 | dy = None 212 | return (dz, dy, None) + tuple(dw) 213 | 214 | 215 | class Conv1x1Func(Function): 216 | @staticmethod 217 | @custom_fwd 218 | def forward(ctx, x, weight): 219 | with torch.no_grad(): 220 | *_, n_of_groups = x.shape 221 | log_det_W = weight.squeeze().logdet() 222 | log_det_W *= n_of_groups 223 | z = F.conv1d(x, weight) 224 | 225 | ctx.save_for_backward(x.data, weight, z) 226 | return z, log_det_W 227 | 228 | @staticmethod 229 | @custom_bwd 230 | def backward(ctx, z_grad, log_det_W_grad): 231 | x, weight, z = ctx.saved_tensors 232 | *_, n_of_groups = z.shape 233 | 234 | with torch.no_grad(): 235 | inv_weight = weight.squeeze().inverse() 236 | x.storage().resize_(z.numel()) 237 | x[:] = F.conv1d(z, inv_weight.unsqueeze(-1)) 238 | 239 | dx = F.conv1d(z_grad, weight.transpose(0, 1)) 240 | dw = z_grad.transpose(0, 1).contiguous().view(weight.shape[0], -1) @ x.transpose(1, 2).contiguous().view( 241 | -1, weight.shape[1]) 242 | dw += inv_weight.t() * log_det_W_grad * n_of_groups 243 | 244 | return dx, dw.unsqueeze(-1) 245 | 246 | 247 | class InvConv1x1Func(Function): 248 | @staticmethod 249 | @custom_fwd 250 | def forward(ctx, x, inv_weight): 251 | with torch.no_grad(): 252 | sqr_inv_weight = inv_weight.squeeze() 253 | *_, n_of_groups = x.shape 254 | log_det_W = -sqr_inv_weight.logdet() 255 | log_det_W *= n_of_groups 256 | z = F.conv1d(x, sqr_inv_weight.inverse().unsqueeze(-1)) 257 | 258 | ctx.save_for_backward(x.data, inv_weight, z) 259 | return z, log_det_W 260 | 261 | @staticmethod 262 | @custom_bwd 263 | def backward(ctx, z_grad, log_det_W_grad): 264 | x, inv_weight, z = ctx.saved_tensors 265 | *_, n_of_groups = z.shape 266 | 267 | with torch.no_grad(): 268 | x.storage().resize_(z.numel()) 269 | x[:] = F.conv1d(z, inv_weight) 270 | 271 | inv_weight = inv_weight.squeeze() 272 | weight_T = inv_weight.inverse().t() 273 | dx = 
F.conv1d(z_grad, weight_T.unsqueeze(-1)) 274 | dw = z_grad.transpose(0, 1).contiguous().view(weight_T.shape[0], -1) @ \ 275 | x.transpose(1, 2).contiguous().view(-1, weight_T.shape[1]) 276 | dinvw = - weight_T @ dw @ weight_T 277 | dinvw -= weight_T * log_det_W_grad * n_of_groups 278 | 279 | return dx, dinvw.unsqueeze(-1) 280 | -------------------------------------------------------------------------------- /model/lightning.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import torch 3 | from torch import nn 4 | from torch.utils.data import DataLoader 5 | import pytorch_lightning as pl 6 | 7 | 8 | import model as module_arch 9 | from .base import FlowBase 10 | import model.condition as module_condition 11 | import model.loss as module_loss 12 | import datasets as module_data 13 | from utils import get_instance 14 | 15 | 16 | class LightModel(pl.LightningModule): 17 | model: FlowBase 18 | conditioner: nn.Module 19 | criterion: nn.Module 20 | 21 | @staticmethod 22 | def add_model_specific_args(parent_parser: ArgumentParser): 23 | parser = parent_parser.add_argument_group('Lightning') 24 | # parser.add_argument('config', type=str, help='Path to config file') 25 | return parent_parser 26 | 27 | def __init__(self, config: dict = None, **kwargs) -> None: 28 | super().__init__() 29 | 30 | self.save_hyperparameters(config) 31 | self.save_hyperparameters(kwargs) 32 | 33 | model = get_instance(module_arch, self.hparams.arch) 34 | conditioner = get_instance(module_condition, self.hparams.conditioner) 35 | criterion = get_instance(module_loss, self.hparams.loss) 36 | 37 | self.model = model 38 | self.conditioner = conditioner 39 | self.criterion = criterion 40 | 41 | def configure_optimizers(self): 42 | optimizer = get_instance( 43 | torch.optim, self.hparams.optimizer, self.parameters()) 44 | return optimizer 45 | 46 | def train_dataloader(self): 47 | train_data = get_instance(module_data, self.hparams.dataset) 48 | train_loader = DataLoader( 49 | train_data, **self.hparams.data_loader) 50 | return train_loader 51 | 52 | def training_step(self, batch, batch_idx): 53 | x = batch 54 | cond = self.conditioner(x) 55 | z, logdet = self.model(x, cond) 56 | loss = self.criterion(z, logdet) 57 | 58 | values = { 59 | 'logdet': logdet.sum() / z.numel(), 60 | 'z_mean': z.mean(), 61 | 'z_std': z.std() 62 | } 63 | self.log_dict(values, prog_bar=True, sync_dist=True) 64 | self.log('loss', loss, prog_bar=False, sync_dist=True) 65 | return loss 66 | 67 | def forward(self, *args, **kwargs): 68 | return self.model.infer(*args, **kwargs) 69 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class WaveGlowLoss(torch.nn.Module): 5 | def __init__(self, sigma=1., elementwise_mean=True): 6 | super().__init__() 7 | self.sigma2 = sigma ** 2 8 | self.mean = elementwise_mean 9 | 10 | def forward(self, z, logdet): 11 | loss = 0.5 * z.pow(2).sum(1) / self.sigma2 - logdet 12 | loss = loss.mean() 13 | if self.mean: 14 | loss = loss / z.size(1) 15 | return loss 16 | -------------------------------------------------------------------------------- /model/melglow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | from typing import Tuple 5 | import torch.nn.functional as F 6 | 7 | from utils 
import add_weight_norms 8 | from .base import FlowBase 9 | from .waveglow import fused_gate 10 | from .efficient_modules import AffineCouplingBlock, InvertibleConv1x1 11 | 12 | 13 | class Predictor(nn.Module): 14 | def __init__(self, 15 | in_channels, 16 | out_channels, 17 | hidden_channels, 18 | layers, 19 | bias, 20 | groups): 21 | super().__init__() 22 | 23 | self.groups = groups 24 | 25 | self.start = nn.Sequential( 26 | nn.Conv1d(in_channels, hidden_channels * groups, 1, bias=bias), 27 | nn.BatchNorm1d(hidden_channels * groups), 28 | nn.Tanh()) 29 | 30 | self.end = nn.Conv1d(hidden_channels * groups, 31 | out_channels * groups, 1, bias=bias, groups=groups) 32 | self.res_blocks = nn.ModuleList([ 33 | nn.Sequential( 34 | nn.Conv1d(hidden_channels * groups, hidden_channels * 35 | groups, 1, bias=bias, groups=groups), 36 | nn.BatchNorm1d(hidden_channels * groups), 37 | nn.Tanh(), 38 | nn.Conv1d(hidden_channels * groups, hidden_channels * 39 | groups, 1, bias=bias, groups=groups), 40 | nn.BatchNorm1d(hidden_channels * groups), 41 | nn.Tanh() 42 | ) for _ in range(layers) 43 | ]) 44 | 45 | def forward(self, x): 46 | x = self.start(x) 47 | for block in self.res_blocks: 48 | x = block(x) + x 49 | return self.end(x) 50 | 51 | 52 | class NonCausalLayerLVC(nn.Module): 53 | def __init__(self, 54 | dilation, 55 | dilation_channels, 56 | residual_channels, 57 | skip_channels, 58 | radix, 59 | bias, 60 | last_layer=False): 61 | super().__init__() 62 | 63 | self.padding = dilation * (radix - 1) // 2 64 | self.dilation = dilation 65 | 66 | self.chs_split = [skip_channels] 67 | if last_layer: 68 | self.W_o = nn.Conv1d( 69 | dilation_channels, skip_channels, 1, bias=bias) 70 | else: 71 | self.W_o = nn.Conv1d( 72 | dilation_channels, residual_channels + skip_channels, 1, bias=bias) 73 | self.chs_split.insert(0, residual_channels) 74 | 75 | def forward(self, x, weights): 76 | batch, steps, *kernel_size = weights.shape 77 | weights = weights.view(-1, *kernel_size[1:]) 78 | 79 | offset = x.shape[2] // steps 80 | padded_x = F.pad(x, (self.padding,) * 2) 81 | unfolded_x = padded_x.unfold(2, self.padding * 2 + offset, offset).transpose( 82 | 1, 2).contiguous().view(1, -1, self.padding * 2 + offset) 83 | 84 | z = F.conv1d(unfolded_x, weights, dilation=self.dilation, 85 | groups=batch * steps) 86 | zw, zv = z.view(batch, steps, kernel_size[0], -1).transpose( 87 | 1, 2).contiguous().view(batch, kernel_size[0], -1).chunk(2, 1) 88 | z = fused_gate(zw, zv) 89 | *z, skip = self.W_o(z).split(self.chs_split, 1) 90 | return z[0] + x if len(z) else None, skip 91 | 92 | 93 | class WN_LVC(nn.Module): 94 | def __init__(self, 95 | in_channels, 96 | aux_channels, 97 | depth, 98 | dilation_channels, 99 | residual_channels, 100 | skip_channels, 101 | predict_channels, 102 | predict_layers, 103 | radix, 104 | bias, 105 | zero_init=True): 106 | super().__init__() 107 | dilations = 2 ** torch.arange(depth) 108 | self.dilations = dilations.tolist() 109 | self.in_chs = in_channels 110 | self.res_chs = residual_channels 111 | self.dil_chs = dilation_channels 112 | self.skp_chs = skip_channels 113 | self.rdx = radix 114 | self.r_field = sum(self.dilations) + 1 115 | 116 | self.start = nn.Conv1d(in_channels, residual_channels, 1, bias=bias) 117 | self.start.apply(add_weight_norms) 118 | 119 | self.layers = nn.ModuleList(NonCausalLayerLVC(d, 120 | dilation_channels, 121 | residual_channels, 122 | skip_channels, 123 | radix, 124 | bias) for d in self.dilations[:-1]) 125 | self.layers.append(NonCausalLayerLVC(self.dilations[-1], 126 | 
dilation_channels, 127 | residual_channels, 128 | skip_channels, 129 | radix, 130 | bias, 131 | last_layer=True)) 132 | self.layers.apply(add_weight_norms) 133 | 134 | self.end = nn.Conv1d(skip_channels, in_channels * 2, 1, bias=bias) 135 | if zero_init: 136 | self.end.weight.data.zero_() 137 | if bias: 138 | self.end.bias.data.zero_() 139 | 140 | self.pred = Predictor( 141 | aux_channels, 142 | 2 * dilation_channels * residual_channels * radix, 143 | predict_channels, 144 | predict_layers, 145 | bias, 146 | depth 147 | ) 148 | 149 | def forward(self, x, y): 150 | x = self.start(x) 151 | weights = self.pred(y) 152 | weights = weights.view(weights.shape[0], len(self.dilations), -1, 153 | weights.shape[2]).permute(1, 0, 3, 2).contiguous() 154 | cum_skip = 0 155 | for layer, w in zip(self.layers, weights.chunk(len(self.dilations), 0)): 156 | x, skip = layer( 157 | x, w.view(w.shape[1], w.shape[2], 2 * self.dil_chs, self.res_chs, self.rdx)) 158 | cum_skip = cum_skip + skip 159 | return self.end(cum_skip).chunk(2, 1) 160 | 161 | 162 | class MelGlow(FlowBase): 163 | def __init__(self, 164 | flows, 165 | n_group, 166 | n_early_every, 167 | n_early_size, 168 | hop_size, 169 | n_mels, 170 | memory_efficient, 171 | reverse_mode=False, 172 | **kwargs): 173 | super().__init__(hop_size, reverse_mode=reverse_mode) 174 | self.flows = flows 175 | self.n_group = n_group 176 | self.n_early_every = n_early_every 177 | self.n_early_size = n_early_size 178 | self.n_mels = n_mels 179 | self.mem_efficient = memory_efficient 180 | 181 | self.upsample_factor = self._hop_length // n_group 182 | 183 | self.invconv1x1 = nn.ModuleList() 184 | self.WNs = nn.ModuleList() 185 | 186 | # Set up layers with the right sizes based on how many dimensions 187 | # have been output already 188 | n_remaining_channels = n_group 189 | self.z_split_sizes = [] 190 | for k in range(flows): 191 | if k % self.n_early_every == 0 and k: 192 | n_remaining_channels -= n_early_size 193 | self.z_split_sizes.append(n_early_size) 194 | self.invconv1x1.append(InvertibleConv1x1( 195 | n_remaining_channels, memory_efficient=memory_efficient, reverse_mode=reverse_mode)) 196 | self.WNs.append( 197 | AffineCouplingBlock(WN_LVC, memory_efficient=memory_efficient, reverse_mode=reverse_mode, 198 | in_channels=n_remaining_channels // 2, 199 | aux_channels=n_mels, 200 | **kwargs)) 201 | self.z_split_sizes.append(n_remaining_channels) 202 | 203 | def forward_computation(self, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]: 204 | batch_dim = x.size(0) 205 | x = x[:, :x.shape[1] // self._hop_length * self._hop_length] 206 | x = x.view(batch_dim, -1, self.n_group).transpose(1, 2) 207 | y = h[..., :x.shape[2] // self.upsample_factor] 208 | 209 | output_audio = [] 210 | split_sections = [self.n_early_size, self.n_group] 211 | 212 | logdet: Tensor = 0 213 | for k, (invconv, affine_coup) in enumerate(zip(self.invconv1x1, 214 | self.WNs)): 215 | if k % self.n_early_every == 0 and k: 216 | split_sections[1] -= self.n_early_size 217 | early_output, x = x.split(split_sections, 1) 218 | # these 2 lines actually copy tensors, may need optimization in the future 219 | output_audio.append(early_output) 220 | if self.mem_efficient: 221 | x = x.clone() 222 | 223 | x, log_det_W = invconv(x) 224 | x, log_s = affine_coup(x, y) 225 | 226 | logdet += log_det_W + log_s.sum((1, 2)) 227 | 228 | assert split_sections[1] == self.z_split_sizes[-1] 229 | output_audio.append(x) 230 | return torch.cat([o.transpose(1, 2) for o in output_audio], 2).view(batch_dim, -1), logdet 231 | 232 | 
def reverse_computation(self, z: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]: 233 | batch_dim = z.size(0) 234 | z = z[:, :z.shape[1] // self._hop_length * self._hop_length] 235 | z = z.view(batch_dim, -1, self.n_group).transpose(1, 2) 236 | y = h[..., :z.shape[2] // self.upsample_factor] 237 | 238 | if self.mem_efficient: 239 | remained_z = [r.clone() for r in z.split(self.z_split_sizes, 1)] 240 | else: 241 | remained_z = z.split(self.z_split_sizes, 1) 242 | *remained_z, z = remained_z 243 | 244 | logdet: Tensor = 0 245 | for k, invconv, affine_coup in zip(range(self.flows - 1, -1, -1), 246 | self.invconv1x1[::-1], 247 | self.WNs[::-1]): 248 | 249 | z, log_s = affine_coup.reverse(z, y) 250 | z, log_det_W = invconv.reverse(z) 251 | 252 | logdet += log_det_W + log_s.sum((1, 2)) 253 | 254 | if k % self.n_early_every == 0 and k: 255 | z = torch.cat((remained_z.pop(), z), 1) 256 | 257 | z = z.transpose(1, 2).contiguous().view(batch_dim, -1) 258 | return z, logdet 259 | -------------------------------------------------------------------------------- /model/mr_waveglow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | import torch.nn.functional as F 5 | from typing import Tuple 6 | 7 | from utils import add_weight_norms 8 | 9 | from .base import FlowBase 10 | from .efficient_modules import AffineCouplingBlock, InvertibleConv1x1 11 | from .waveglow import WN, fused_gate, NonCausalLayer 12 | 13 | 14 | class MRWaveGlow(FlowBase): 15 | def __init__(self, 16 | prior_flows, 17 | n_group, 18 | hop_size, 19 | n_mels, 20 | memory_efficient, 21 | levels=3, 22 | flows=4, 23 | super_resolution=False, 24 | reverse_mode=False, 25 | **kwargs): 26 | super().__init__(hop_size, reverse_mode) 27 | self.flows = flows 28 | self.prior_flows = prior_flows 29 | self.n_group = n_group 30 | self.n_mels = n_mels 31 | self.super_resolution = super_resolution 32 | self.levels = levels 33 | 34 | self.upsample_factor = hop_size // n_group 35 | 36 | self.prior_invconv1x1 = nn.ModuleList() 37 | self.prior_WNs = nn.ModuleList() 38 | 39 | self.invconv1x1_list = nn.ModuleList() 40 | self.WNs_list = nn.ModuleList() 41 | 42 | in_channels = n_group 43 | for i in range(levels - 1): 44 | in_channels = in_channels // 2 45 | self.invconv1x1_list.append( 46 | nn.ModuleList([InvertibleConv1x1(in_channels, in_channels) for _ in range(flows)])) 47 | 48 | self.WNs_list.append( 49 | nn.ModuleList([ 50 | AffineCouplingBlock(WN, memory_efficient=memory_efficient, reverse_mode=reverse_mode, 51 | in_channels=in_channels // 2, aux_channels=in_channels + (0 if super_resolution else n_mels), **kwargs) for _ in range(flows)])) 52 | 53 | for k in range(prior_flows): 54 | self.prior_invconv1x1.append(InvertibleConv1x1( 55 | in_channels, memory_efficient=memory_efficient, reverse_mode=reverse_mode)) 56 | self.prior_WNs.append( 57 | AffineCouplingBlock(WN, memory_efficient=memory_efficient, in_channels=in_channels // 2, 58 | aux_channels=n_mels, reverse_mode=reverse_mode, **kwargs)) 59 | 60 | def forward_computation(self, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]: 61 | y = self._upsample_h(h) 62 | 63 | batch_dim = x.size(0) 64 | x = x.view(batch_dim, -1, self.n_group).transpose(1, 2) 65 | assert x.size(2) <= y.size(2) 66 | y = y[..., :x.size(2)] 67 | 68 | output_audio = [] 69 | 70 | logdet: torch.Tensor = 0 71 | 72 | for level in range(self.levels - 1): 73 | x0, x1 = x[:, ::2], x[:, 1::2] 74 | x_diff, x = x1 - x0, (x0 + x1) * 0.5 75 | if 
self.super_resolution: 76 | cond = x 77 | else: 78 | cond = torch.cat([x, y], 1) 79 | 80 | for invconv, affine_coup in zip(self.invconv1x1_list[level], self.WNs_list[level]): 81 | x_diff, log_det_W = invconv(x_diff) 82 | x_diff, log_s = affine_coup(x_diff, cond) 83 | logdet += log_det_W + log_s.sum((1, 2)) 84 | 85 | output_audio.append(x_diff) 86 | 87 | for invconv, affine_coup in zip(self.prior_invconv1x1, self.prior_WNs): 88 | x, log_det_W = invconv(x) 89 | x, log_s = affine_coup(x, y) 90 | logdet += log_det_W + log_s.sum((1, 2)) 91 | 92 | output_audio.append(x) 93 | return torch.cat(output_audio, 1).transpose(1, 2).contiguous().view(batch_dim, -1), logdet 94 | 95 | def reverse_computation(self, z: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]: 96 | y = self._upsample_h(h) 97 | batch_dim = z.size(0) 98 | z = z.view(batch_dim, -1, self.n_group).transpose(1, 2) 99 | assert z.size(2) <= y.size(2) 100 | y = y[..., :z.size(2)] 101 | 102 | remained_z = [] 103 | for _ in range(self.levels - 1): 104 | r, z = z.chunk(2, 1) 105 | remained_z.append(r.clone()) 106 | z = z.contiguous() 107 | 108 | logdet: torch.Tensor = 0 109 | for invconv, affine_coup in zip(self.prior_invconv1x1[::-1], self.prior_WNs[::-1]): 110 | z, log_s = affine_coup.reverse(z, y) 111 | z, log_det_W = invconv.reverse(z) 112 | logdet += log_det_W + log_s.sum((1, 2)) 113 | 114 | for level in range(self.levels - 2, -1, -1): 115 | z_diff = remained_z.pop() 116 | if self.super_resolution: 117 | cond = z 118 | else: 119 | cond = torch.cat([z, y], 1) 120 | 121 | for invconv, affine_coup in zip(self.invconv1x1_list[level][::-1], self.WNs_list[level][::-1]): 122 | z_diff, log_s = affine_coup.reverse(z_diff, cond) 123 | z_diff, log_det_W = invconv.reverse(z_diff) 124 | logdet += log_det_W + log_s.sum((1, 2)) 125 | 126 | z_0, z_1 = z - z_diff * 0.5 , z + z_diff * 0.5 127 | z = torch.stack([z_0, z_1], 2).view(batch_dim, -1, z_0.size(2)) 128 | 129 | 130 | z = z.transpose(1, 2).contiguous().view(batch_dim, -1) 131 | return z, logdet 132 | 133 | def _upsample_h(self, h): 134 | return F.interpolate(h, scale_factor=self.upsample_factor, mode='linear') 135 | 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /model/waveflow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | import torch.nn.functional as F 5 | from torch.cuda.amp import autocast 6 | from typing import Tuple 7 | 8 | from utils import add_weight_norms 9 | from .base import FlowBase 10 | from .efficient_modules import InvertibleConv1x1 11 | from .waveglow import fused_gate 12 | 13 | 14 | class NonCausalLayer2D(nn.Module): 15 | def __init__(self, 16 | h_dilation, 17 | dilation, 18 | dilation_channels, 19 | residual_channels, 20 | skip_channels, 21 | radix, 22 | bias, 23 | last_layer=False): 24 | super().__init__() 25 | self.h_pad_size = h_dilation * (radix - 1) 26 | self.pad_size = dilation * (radix - 1) // 2 27 | 28 | self.W = nn.Conv2d(residual_channels, dilation_channels * 2, 29 | kernel_size=radix, 30 | dilation=(h_dilation, dilation), bias=bias) 31 | 32 | self.chs_split = [skip_channels] 33 | if last_layer: 34 | self.W_o = nn.Conv2d( 35 | dilation_channels, skip_channels, 1, bias=bias) 36 | else: 37 | self.W_o = nn.Conv2d( 38 | dilation_channels, residual_channels + skip_channels, 1, bias=bias) 39 | self.chs_split.insert(0, residual_channels) 40 | 41 | def forward(self, x, y): 42 | tmp = F.pad(x, [self.pad_size] * 2 + 
[self.h_pad_size, 0]) 43 | xy = self.W(tmp) + y 44 | zw, zf = xy.chunk(2, 1) 45 | z = fused_gate(zw, zf) 46 | *z, skip = self.W_o(z).split(self.chs_split, 1) 47 | if len(z): 48 | output = z[0] 49 | return output + x[:, :, -output.size(2):], skip 50 | else: 51 | return None, skip 52 | 53 | def reverse_mode_forward(self, x, y, buffer=None): 54 | if buffer is None: 55 | buffer = F.pad(x, [0, 0, self.h_pad_size, 0]) 56 | else: 57 | buffer = torch.cat((buffer[:, :, 1:], x), 2) 58 | tmp = F.pad(buffer, [self.pad_size] * 2) 59 | xy = self.W(tmp) + y 60 | zw, zf = xy.chunk(2, 1) 61 | z = fused_gate(zw, zf) 62 | *z, skip = self.W_o(z).split(self.chs_split, 1) 63 | if len(z): 64 | output = z[0] 65 | return output + x, skip, buffer 66 | else: 67 | return None, skip, buffer 68 | 69 | 70 | class WN2D(nn.Module): 71 | def __init__(self, 72 | n_group, 73 | aux_channels, 74 | dilation_channels=256, 75 | residual_channels=256, 76 | skip_channels=256, 77 | bias=False, 78 | zero_init=True): 79 | super().__init__() 80 | 81 | dilation_dict = { 82 | 8: [1] * 8, 83 | 16: [1] * 8, 84 | 32: [1, 2, 4] * 2 + [1, 2], 85 | 64: [1, 2, 4, 8, 16, 1, 2, 4], 86 | 128: [1, 2, 4, 8, 16, 32, 64, 1], 87 | } 88 | 89 | self.h_dilations = dilation_dict[n_group] 90 | dilations = 2 ** torch.arange(8) 91 | self.dilations = dilations.tolist() 92 | self.n_group = n_group 93 | self.res_chs = residual_channels 94 | self.dil_chs = dilation_channels 95 | self.skp_chs = skip_channels 96 | self.aux_chs = aux_channels 97 | self.r_field = sum(self.dilations) * 2 + 1 98 | self.h_r_field = sum(self.h_dilations) * 2 + 1 99 | 100 | self.V = nn.Conv1d(aux_channels, dilation_channels * 101 | 2 * 8, 1, bias=bias) 102 | self.V.apply(add_weight_norms) 103 | 104 | self.start = nn.Conv2d(1, residual_channels, 1, bias=bias) 105 | self.start.apply(add_weight_norms) 106 | 107 | self.layers = nn.ModuleList(NonCausalLayer2D(hd, d, 108 | dilation_channels, 109 | residual_channels, 110 | skip_channels, 111 | 3, 112 | bias) for hd, d in zip(self.h_dilations[:-1], self.dilations[:-1])) 113 | self.layers.append(NonCausalLayer2D(self.h_dilations[-1], self.dilations[-1], 114 | dilation_channels, 115 | residual_channels, 116 | skip_channels, 117 | 3, 118 | bias, 119 | last_layer=True)) 120 | self.layers.apply(add_weight_norms) 121 | 122 | self.end = nn.Conv2d(skip_channels, 2, 1, bias=bias) 123 | if zero_init: 124 | self.end.weight.data.zero_() 125 | if bias: 126 | self.end.bias.data.zero_() 127 | 128 | def forward(self, x, y): 129 | x = self.start(x) 130 | y = self.V(y).unsqueeze(2) 131 | cum_skip = 0 132 | for layer, v in zip(self.layers, y.chunk(len(self.layers), 1)): 133 | x, skip = layer(x, v) 134 | cum_skip = cum_skip + skip 135 | return self.end(cum_skip).chunk(2, 1) 136 | 137 | def reverse_mode_forward(self, x, y=None, cond=None, buffer_list=None): 138 | x = self.start(x) 139 | new_buffer_list = [] 140 | if buffer_list is None: 141 | buffer_list = [None] * len(self.layers) 142 | if cond is None: 143 | cond = self.V(y).unsqueeze(2).chunk(len(self.layers), 1) 144 | 145 | cum_skip = 0 146 | for layer, buf, v in zip(self.layers, buffer_list, cond): 147 | x, skip, buf = layer.reverse_mode_forward(x, v, buf) 148 | new_buffer_list.append(buf) 149 | cum_skip = cum_skip + skip 150 | 151 | return self.end(cum_skip).chunk(2, 1) + (cond, new_buffer_list,) 152 | 153 | 154 | class WaveFlow(FlowBase): 155 | def __init__(self, 156 | flows, 157 | n_group, 158 | n_mels, 159 | use_conv1x1, 160 | memory_efficient, 161 | reverse_mode=False, 162 | **kwargs): 163 | 
super().__init__(256, reverse_mode) 164 | self.flows = flows 165 | self.n_group = n_group 166 | self.n_mels = n_mels 167 | self.sub_sr = self._hop_length // n_group 168 | 169 | self.upsampler = nn.Sequential( 170 | nn.ReplicationPad1d((0, 1)), 171 | nn.ConvTranspose1d(n_mels, n_mels, self.sub_sr * 172 | 2 + 1, self.sub_sr, padding=self.sub_sr // 2), 173 | nn.LeakyReLU(0.4, True) 174 | ) 175 | self.upsampler.apply(add_weight_norms) 176 | 177 | self.WNs = nn.ModuleList() 178 | 179 | if use_conv1x1: 180 | self.invconv1x1 = nn.ModuleList() 181 | 182 | # Set up layers with the right sizes based on how many dimensions 183 | # have been output already 184 | for k in range(flows): 185 | self.WNs.append(WN2D(n_group, n_mels, **kwargs)) 186 | if use_conv1x1: 187 | self.invconv1x1.append(InvertibleConv1x1( 188 | n_group, memory_efficient=memory_efficient, reverse_mode=reverse_mode)) 189 | 190 | def forward_computation(self, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]: 191 | y = self._upsample_h(h) 192 | 193 | batch_dim = x.size(0) 194 | x = x.view(batch_dim, 1, -1, self.n_group).transpose(2, 3).contiguous() 195 | y = y[..., :x.size(-1)] 196 | 197 | if hasattr(self, 'invconv1x1'): 198 | invconv1x1 = self.invconv1x1 199 | else: 200 | invconv1x1 = [None] * self.flows 201 | 202 | logdet: Tensor = 0 203 | for WN, invconv in zip(self.WNs, invconv1x1): 204 | x0 = x[:, :, :1] 205 | log_s, t = WN(x[:, :, :-1], y) 206 | xout = x[:, :, 1:] * log_s.exp() + t 207 | 208 | logdet += log_s.sum((1, 2, 3)) 209 | 210 | if invconv is None: 211 | x = torch.cat((xout.flip(2), x0), 2) 212 | else: 213 | x, log_det_W = invconv(torch.cat((x0, xout), 2).squeeze(1)) 214 | x = x.unsqueeze(1) 215 | logdet += log_det_W 216 | 217 | return x.squeeze(1).transpose(1, 2).contiguous().view(batch_dim, -1), logdet 218 | 219 | def reverse_computation(self, z: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]: 220 | y = self._upsample_h(h) 221 | 222 | batch_dim = z.size(0) 223 | z = z.view(batch_dim, 1, -1, self.n_group).transpose(2, 3).contiguous() 224 | y = y[..., :z.size(-1)] 225 | 226 | if hasattr(self, 'invconv1x1'): 227 | invconv1x1 = self.invconv1x1 228 | else: 229 | invconv1x1 = [None] * self.flows 230 | 231 | logdet: Tensor = None 232 | for WN, invconv in zip(self.WNs[::-1], invconv1x1[::-1]): 233 | if invconv is None: 234 | z = z.flip(2) 235 | else: 236 | z, log_det_W = invconv.reverse(z.squeeze(1)) 237 | z = z.unsqueeze(1) 238 | if logdet is None: 239 | logdet = log_det_W.repeat(z.shape[0]) 240 | else: 241 | logdet += log_det_W 242 | 243 | xnew = z[:, :, :1] 244 | x = [xnew] 245 | 246 | buffer_list = None 247 | cond = None 248 | for i in range(1, self.n_group): 249 | log_s, t, cond, buffer_list = WN.reverse_mode_forward( 250 | xnew, y if cond is None else None, cond, buffer_list) 251 | xnew = (z[:, :, i:i+1] - t) / log_s.exp() 252 | x.append(xnew) 253 | 254 | if logdet is None: 255 | logdet = -log_s.sum((1, 2, 3)) 256 | else: 257 | logdet -= log_s.sum((1, 2, 3)) 258 | z = torch.cat(x, 2) 259 | 260 | z = z.squeeze(1).transpose(1, 2).contiguous().view(batch_dim, -1) 261 | return z, logdet 262 | 263 | @autocast(enabled=False) 264 | def _upsample_h(self, h): 265 | return self.upsampler(h.float()) 266 | -------------------------------------------------------------------------------- /model/waveglow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | import torch.nn.functional as F 5 | from typing import Tuple 6 | 7 | from utils 
import add_weight_norms 8 | 9 | from .base import FlowBase 10 | from .efficient_modules import AffineCouplingBlock, InvertibleConv1x1 11 | 12 | 13 | @torch.jit.script 14 | def fused_gate(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: 15 | return x1.tanh() * x2.sigmoid() 16 | 17 | 18 | class NonCausalLayer(nn.Module): 19 | def __init__(self, 20 | dilation, 21 | dilation_channels, 22 | residual_channels, 23 | skip_channels, 24 | radix, 25 | bias, 26 | last_layer=False): 27 | super().__init__() 28 | pad_size = dilation * (radix - 1) // 2 29 | self.W = nn.Conv1d(residual_channels, dilation_channels * 2, kernel_size=radix, 30 | padding=pad_size, dilation=dilation, bias=bias) 31 | 32 | self.chs_split = [skip_channels] 33 | if last_layer: 34 | self.W_o = nn.Conv1d( 35 | dilation_channels, skip_channels, 1, bias=bias) 36 | else: 37 | self.W_o = nn.Conv1d( 38 | dilation_channels, residual_channels + skip_channels, 1, bias=bias) 39 | self.chs_split.insert(0, residual_channels) 40 | 41 | def forward(self, x, y): 42 | xy = self.W(x) + y 43 | zw, zf = xy.chunk(2, 1) 44 | z = fused_gate(zw, zf) 45 | *z, skip = self.W_o(z).split(self.chs_split, 1) 46 | return z[0] + x if len(z) else None, skip 47 | 48 | 49 | class WN(nn.Module): 50 | def __init__(self, 51 | in_channels, 52 | aux_channels, 53 | dilation_channels=256, 54 | residual_channels=256, 55 | skip_channels=256, 56 | depth=8, 57 | radix=3, 58 | bias=False, 59 | zero_init=True): 60 | super().__init__() 61 | dilations = 2 ** torch.arange(depth) 62 | self.dilations = dilations.tolist() 63 | self.in_chs = in_channels 64 | self.res_chs = residual_channels 65 | self.dil_chs = dilation_channels 66 | self.skp_chs = skip_channels 67 | self.rdx = radix 68 | self.r_field = sum(self.dilations) + 1 69 | 70 | self.V = nn.Conv1d(aux_channels, dilation_channels * 71 | 2 * depth, 1, bias=bias) 72 | self.V.apply(add_weight_norms) 73 | 74 | self.start = nn.Conv1d(in_channels, residual_channels, 1, bias=bias) 75 | self.start.apply(add_weight_norms) 76 | 77 | self.layers = nn.ModuleList(NonCausalLayer(d, 78 | dilation_channels, 79 | residual_channels, 80 | skip_channels, 81 | radix, 82 | bias) for d in self.dilations[:-1]) 83 | self.layers.append(NonCausalLayer(self.dilations[-1], 84 | dilation_channels, 85 | residual_channels, 86 | skip_channels, 87 | radix, 88 | bias, 89 | last_layer=True)) 90 | self.layers.apply(add_weight_norms) 91 | 92 | self.end = nn.Conv1d(skip_channels, in_channels * 2, 1, bias=bias) 93 | if zero_init: 94 | self.end.weight.data.zero_() 95 | if bias: 96 | self.end.bias.data.zero_() 97 | 98 | def forward(self, x, y): 99 | x = self.start(x) 100 | y = self.V(y) 101 | cum_skip = 0 102 | for layer, v in zip(self.layers, y.chunk(len(self.layers), 1)): 103 | x, skip = layer(x, v) 104 | cum_skip = cum_skip + skip 105 | return self.end(cum_skip).chunk(2, 1) 106 | 107 | 108 | class WaveGlow(FlowBase): 109 | def __init__(self, 110 | flows, 111 | n_group, 112 | n_early_every, 113 | n_early_size, 114 | hop_size, 115 | n_mels, 116 | memory_efficient, 117 | reverse_mode=False, 118 | **kwargs): 119 | super().__init__(hop_size, reverse_mode) 120 | self.n_group = n_group 121 | self.n_early_every = n_early_every 122 | self.n_early_size = n_early_size 123 | self.n_mels = n_mels 124 | self.mem_efficient = memory_efficient 125 | 126 | self.upsample_factor = self._hop_length // n_group 127 | sub_win_size = self.upsample_factor * 2 + 1 128 | self.upsampler = nn.ConvTranspose1d(n_mels, n_mels, sub_win_size, self.upsample_factor, 129 | padding=sub_win_size // 2 - 
self.upsample_factor // 2, groups=n_mels) 130 | self.upsampler.apply(add_weight_norms) 131 | 132 | self.invconv1x1 = nn.ModuleList() 133 | self.WNs = nn.ModuleList() 134 | 135 | # Set up layers with the right sizes based on how many dimensions 136 | # have been output already 137 | n_remaining_channels = n_group 138 | self.z_split_sizes = [] 139 | for k in range(flows): 140 | if k % self.n_early_every == 0 and k: 141 | n_remaining_channels -= n_early_size 142 | self.z_split_sizes.append(n_early_size) 143 | self.invconv1x1.append(InvertibleConv1x1( 144 | n_remaining_channels, memory_efficient=memory_efficient, reverse_mode=reverse_mode)) 145 | self.WNs.append( 146 | AffineCouplingBlock(WN, memory_efficient=memory_efficient, in_channels=n_remaining_channels // 2, 147 | aux_channels=n_mels, reverse_mode=reverse_mode, **kwargs)) 148 | self.z_split_sizes.append(n_remaining_channels) 149 | 150 | def forward_computation(self, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]: 151 | y = self._upsample_h(h) 152 | batch_dim = x.size(0) 153 | x = x.view(batch_dim, -1, self.n_group).transpose(1, 2).contiguous() 154 | # y = y.view(batch_dim, y.size(1), -1, self.n_group).transpose(2, 3) 155 | # y = y.reshape(batch_dim, -1, y.size(-1)) 156 | assert x.size(2) <= y.size(2) 157 | y = y[..., :x.size(2)] 158 | 159 | output_audio = [] 160 | split_sections = [self.n_early_size, self.n_group] 161 | 162 | logdet: torch.Tensor = 0 163 | for k, (invconv, affine_coup) in enumerate(zip(self.invconv1x1, self.WNs)): 164 | if k % self.n_early_every == 0 and k: 165 | split_sections[1] -= self.n_early_size 166 | early_output, x = x.split(split_sections, 1) 167 | # these 2 lines actually copy tensors, may need optimization in the future 168 | output_audio.append(early_output) 169 | if self.mem_efficient: 170 | x = x.clone() 171 | 172 | x, log_det_W = invconv(x) 173 | x, log_s = affine_coup(x, y) 174 | 175 | logdet += log_det_W + log_s.sum((1, 2)) 176 | 177 | # assert split_sections[1] == self.z_split_sizes[-1] 178 | output_audio.append(x) 179 | return torch.cat(output_audio, 1).transpose(1, 2).contiguous().view(batch_dim, -1), logdet 180 | 181 | def reverse_computation(self, z: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]: 182 | y = self._upsample_h(h) 183 | batch_dim = z.size(0) 184 | z = z.view(batch_dim, -1, self.n_group).transpose(1, 2).contiguous() 185 | # y = y.view(batch_dim, y.size(1), -1, self.n_group).transpose(2, 3) 186 | # y = y.reshape(batch_dim, -1, y.size(-1)) 187 | assert z.size(2) <= y.size(2) 188 | y = y[..., :z.size(2)] 189 | 190 | if self.mem_efficient: 191 | remained_z = [r.clone() for r in z.split(self.z_split_sizes, 1)] 192 | else: 193 | remained_z = z.split(self.z_split_sizes, 1) 194 | *remained_z, z = remained_z 195 | 196 | logdet: torch.Tensor = 0 197 | for k, invconv, affine_coup in zip(range(len(self.WNs) - 1, -1, -1), self.invconv1x1[::-1], self.WNs[::-1]): 198 | 199 | z, log_s = affine_coup.reverse(z, y) 200 | z, log_det_W = invconv.reverse(z) 201 | 202 | logdet += log_det_W + log_s.sum((1, 2)) 203 | 204 | if k % self.n_early_every == 0 and k: 205 | z = torch.cat((remained_z.pop(), z), 1) 206 | 207 | z = z.transpose(1, 2).contiguous().view(batch_dim, -1) 208 | return z, logdet 209 | 210 | def _upsample_h(self, h): 211 | # return F.interpolate(h, scale_factor=self.upsample_factor, mode='linear') 212 | return self.upsampler(h) 213 | -------------------------------------------------------------------------------- /model/wsrglow.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchaudio.transforms import MuLawEncoding 5 | from .waveglow import WaveGlow 6 | 7 | 8 | class AngleEmbedding(nn.Module): 9 | def __init__(self, embed_num, hidden_dim): 10 | super(AngleEmbedding, self).__init__() 11 | self.embed_num = embed_num 12 | self.embed = nn.Embedding(num_embeddings=embed_num, 13 | embedding_dim=hidden_dim) 14 | 15 | def forward(self, index): 16 | embed_num = self.embed_num 17 | index = ((index / torch.pi + 1) * 0.5 * (embed_num - 1)).long() 18 | return self.embed(index) 19 | 20 | 21 | class WSRGlow(WaveGlow): 22 | def __init__(self, upsample_rate: int = 2, memory_efficient: bool = False, **kwargs) -> None: 23 | super().__init__( 24 | 12, 8 * upsample_rate, 4, 2, 8 * upsample_rate, 8 * 400 + 51 * 9, 25 | memory_efficient=memory_efficient, **kwargs 26 | ) 27 | self.mu_enc = nn.Sequential( 28 | MuLawEncoding(256), 29 | nn.Embedding(256, 400) 30 | ) 31 | self.angle_embed = AngleEmbedding(embed_num=120, hidden_dim=50) 32 | 33 | self.n_fft = 16 34 | self.hop_length = 8 35 | self.register_buffer('window', torch.hann_window(self.n_fft)) 36 | 37 | def _get_cond(self, c): 38 | c = c.clip_(-1, 1) 39 | c_emb = self.mu_enc(c).view(c.shape[0], -1, 8 * 400).transpose(1, 2) 40 | spec = torch.stft( 41 | F.pad(c.unsqueeze(1), (4, 4), mode='reflect').squeeze(1), 42 | n_fft=self.n_fft, 43 | hop_length=self.hop_length, 44 | window=self.window, 45 | center=False, return_complex=True 46 | ) 47 | mag = spec.abs() 48 | phase_emb = self.angle_embed(spec.angle()).permute( 49 | 0, 1, 3, 2).reshape(spec.shape[0], 50 * 9, -1) 50 | return torch.cat([c_emb, mag, phase_emb], dim=1) 51 | 52 | def forward_computation(self, x, h): 53 | return super().forward_computation(x, self._get_cond(h)) 54 | 55 | def reverse_computation(self, z, h): 56 | return super().reverse_computation(z, self._get_cond(h)) 57 | -------------------------------------------------------------------------------- /samples/2293_generated.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoyolicoris/constant-memory-waveglow/41e2c27201a3df69359e30880b15196cb5d5f0a3/samples/2293_generated.wav -------------------------------------------------------------------------------- /samples/2298_generated.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoyolicoris/constant-memory-waveglow/41e2c27201a3df69359e30880b15196cb5d5f0a3/samples/2298_generated.wav -------------------------------------------------------------------------------- /samples/waveflow_64chs/LJ001-0001.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoyolicoris/constant-memory-waveglow/41e2c27201a3df69359e30880b15196cb5d5f0a3/samples/waveflow_64chs/LJ001-0001.wav -------------------------------------------------------------------------------- /samples/waveflow_64chs/LJ010-0001.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoyolicoris/constant-memory-waveglow/41e2c27201a3df69359e30880b15196cb5d5f0a3/samples/waveflow_64chs/LJ010-0001.wav -------------------------------------------------------------------------------- /samples/waveflow_64chs/LJ020-0001.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoyolicoris/constant-memory-waveglow/41e2c27201a3df69359e30880b15196cb5d5f0a3/samples/waveflow_64chs/LJ020-0001.wav -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from tqdm import tqdm 5 | import data_loader.data_loaders as module_data 6 | import model.loss as module_loss 7 | import model.metric as module_metric 8 | import model.model as module_arch 9 | from train import get_instance 10 | 11 | 12 | def main(config, resume): 13 | # setup data_loader instances 14 | data_loader = getattr(module_data, config['data_loader']['type'])( 15 | config['data_loader']['args']['data_dir'], 16 | batch_size=512, 17 | shuffle=False, 18 | validation_split=0.0, 19 | training=False, 20 | num_workers=2 21 | ) 22 | 23 | # build model architecture 24 | model = get_instance(module_arch, 'arch', config) 25 | model.summary() 26 | 27 | # get function handles of loss and metrics 28 | loss_fn = getattr(module_loss, config['loss']) 29 | metric_fns = [getattr(module_metric, met) for met in config['metrics']] 30 | 31 | # load state dict 32 | checkpoint = torch.load(resume) 33 | state_dict = checkpoint['state_dict'] 34 | if config['n_gpu'] > 1: 35 | model = torch.nn.DataParallel(model) 36 | model.load_state_dict(state_dict) 37 | 38 | # prepare model for testing 39 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 40 | model = model.to(device) 41 | model.eval() 42 | 43 | total_loss = 0.0 44 | total_metrics = torch.zeros(len(metric_fns)) 45 | 46 | with torch.no_grad(): 47 | for i, (data, target) in enumerate(tqdm(data_loader)): 48 | data, target = data.to(device), target.to(device) 49 | output = model(data) 50 | # 51 | # save sample images, or do something with output here 52 | # 53 | 54 | # computing loss, metrics on test set 55 | loss = loss_fn(output, target) 56 | batch_size = data.shape[0] 57 | total_loss += loss.item() * batch_size 58 | for i, metric in enumerate(metric_fns): 59 | total_metrics[i] += metric(output, target) * batch_size 60 | 61 | n_samples = len(data_loader.sampler) 62 | log = {'loss': total_loss / n_samples} 63 | log.update({met.__name__ : total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)}) 64 | print(log) 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = argparse.ArgumentParser(description='PyTorch Template') 69 | 70 | parser.add_argument('-r', '--resume', default=None, type=str, 71 | help='path to latest checkpoint (default: None)') 72 | parser.add_argument('-d', '--device', default=None, type=str, 73 | help='indices of GPUs to enable (default: all)') 74 | 75 | args = parser.parse_args() 76 | 77 | if args.resume: 78 | config = torch.load(args.resume)['config'] 79 | if args.device: 80 | os.environ["CUDA_VISIBLE_DEVICES"]=args.device 81 | 82 | main(config, args.resume) 83 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoyolicoris/constant-memory-waveglow/41e2c27201a3df69359e30880b15196cb5d5f0a3/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_fwd_bwd.py: -------------------------------------------------------------------------------- 1 
| import pytest 2 | import torch 3 | import numpy as np 4 | from torch import nn 5 | 6 | from model.efficient_modules import AffineCouplingBlock, InvertibleConv1x1 7 | from model.waveglow import WN 8 | from model.loss import WaveGlowLoss 9 | 10 | torch.backends.cuda.matmul.allow_tf32 = False 11 | torch.backends.cudnn.allow_tf32 = False 12 | 13 | 14 | def set_seed(seed): 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | 18 | 19 | @pytest.mark.parametrize('batch', list(2 ** i for i in range(6))) 20 | @pytest.mark.parametrize('channels', list(2 ** i for i in range(1, 4))) 21 | @pytest.mark.parametrize('length', [2000]) 22 | def test_conv1x1_fwd_bwd(batch, channels, length): 23 | weights = InvertibleConv1x1(channels).state_dict() 24 | loss_func = WaveGlowLoss().cuda() 25 | 26 | for seed in range(10): 27 | set_seed(seed) 28 | data = torch.rand(batch, channels, length) * 2 - 1 29 | for bwd in [False, True]: 30 | impl_out, impl_grad = [], [] 31 | for keep_input in [True, False]: 32 | model = InvertibleConv1x1(channels, not keep_input) 33 | model.load_state_dict(weights) 34 | model = model.cuda() 35 | model.train() 36 | model.zero_grad() 37 | 38 | x = data.cuda() 39 | 40 | if bwd: 41 | xin = x.clone() 42 | y, log1 = model.reverse(xin) 43 | yrev = y.clone() 44 | xinv, log2 = model(yrev) 45 | else: 46 | xin = x.clone() 47 | y, log1 = model(xin) 48 | yrev = y.clone() 49 | xinv, log2 = model.reverse(yrev) 50 | 51 | assert torch.equal(log1, log2.neg()) 52 | loss = loss_func(y.view(batch, -1), log1) 53 | 54 | if keep_input: 55 | assert xin.data.shape == x.shape 56 | assert y.data.shape == yrev.shape 57 | else: 58 | assert len(xin.data.shape) == 0 \ 59 | or (len(xin.data.shape) == 0 and xin.data.shape[0] == 0) \ 60 | or xin.storage().size() == 0 61 | 62 | assert len(yrev.data.shape) == 0 \ 63 | or (len(yrev.data.shape) == 0 and yrev.data.shape[0] == 0) \ 64 | or yrev.storage().size() == 0 65 | 66 | loss.backward() 67 | 68 | assert y.shape == x.shape 69 | assert x.shape == data.shape 70 | assert torch.allclose(x.cpu(), data) 71 | print(torch.abs(x - xinv).max().item()) 72 | assert torch.allclose(x, xinv, atol=1e-6, rtol=0) 73 | 74 | impl_out.append(y.detach().cpu()) 75 | impl_grad.append([p.grad.cpu() for p in model.parameters()]) 76 | 77 | for p_grad1, p_grad2 in zip(impl_grad[0], impl_grad[1]): 78 | assert torch.allclose(p_grad1, p_grad2, atol=5e-7, rtol=0) 79 | assert torch.allclose(impl_out[0], impl_out[1]) 80 | 81 | 82 | @pytest.mark.parametrize('batch', [2]) 83 | @pytest.mark.parametrize('channels', list(2 ** i for i in range(4, 6))) 84 | @pytest.mark.parametrize('WN_channels', [128]) 85 | @pytest.mark.parametrize('depth', list(range(1, 5))) 86 | @pytest.mark.parametrize('aux_channels', [20, 40]) 87 | @pytest.mark.parametrize('length', [4000]) 88 | def test_affine_fwd_bwd(batch, channels, WN_channels, depth, aux_channels, length): 89 | 90 | weights = AffineCouplingBlock(WN, False, in_channels=channels // 2, aux_channels=aux_channels, 91 | zero_init=False, 92 | dilation_channels=WN_channels, 93 | residual_channels=WN_channels, 94 | skip_channels=WN_channels, 95 | depth=depth).state_dict() 96 | 97 | loss_func = WaveGlowLoss().cuda() 98 | 99 | for seed in range(10): 100 | set_seed(seed) 101 | data = torch.rand(batch, channels, length) * 2 - 1 102 | condition = torch.randn(batch, aux_channels, length) 103 | for bwd in [False, True]: 104 | impl_out, impl_grad = [], [] 105 | for keep_input in [True, False]: 106 | model = AffineCouplingBlock(WN, not keep_input, in_channels=channels // 2, 
aux_channels=aux_channels, 107 | zero_init=False, 108 | dilation_channels=WN_channels, 109 | residual_channels=WN_channels, 110 | skip_channels=WN_channels, 111 | depth=depth) 112 | model.load_state_dict(weights) 113 | model = model.cuda() 114 | model.train() 115 | model.zero_grad() 116 | 117 | x = data.cuda() 118 | h = condition.cuda() 119 | 120 | if bwd: 121 | xin = x.clone() 122 | y, log1 = model.reverse(xin, h) 123 | yrev = y.clone() 124 | xinv, log2 = model(yrev, h) 125 | else: 126 | xin = x.clone() 127 | y, log1 = model(xin, h) 128 | yrev = y.clone() 129 | xinv, log2 = model.reverse(yrev, h) 130 | 131 | assert torch.equal(log1, log2.neg()) 132 | loss = loss_func(y.view(2, -1), log1.sum((1, 2))) 133 | 134 | if keep_input: 135 | assert xin.data.shape == x.shape 136 | assert y.data.shape == yrev.shape 137 | else: 138 | assert len(xin.data.shape) == 0 \ 139 | or (len(xin.data.shape) == 0 and xin.data.shape[0] == 0) \ 140 | or xin.storage().size() == 0 141 | 142 | assert len(yrev.data.shape) == 0 \ 143 | or (len(yrev.data.shape) == 0 and yrev.data.shape[0] == 0) \ 144 | or yrev.storage().size() == 0 145 | assert h.shape == condition.shape 146 | assert torch.allclose(h.cpu(), condition) 147 | 148 | loss.backward() 149 | 150 | assert y.shape == x.shape 151 | assert x.data.shape == data.shape 152 | assert torch.allclose(x.cpu(), data) 153 | print(torch.abs(x - xinv).max().item()) 154 | assert torch.allclose(x, xinv, atol=1e-7) 155 | 156 | impl_out.append(y.cpu().detach()) 157 | impl_grad.append([p.grad.cpu() for p in model.parameters()]) 158 | 159 | for p_grad1, p_grad2 in zip(impl_grad[0], impl_grad[1]): 160 | assert torch.allclose(p_grad1, p_grad2) 161 | assert torch.allclose(impl_out[0], impl_out[1]) 162 | 163 | 164 | @pytest.mark.parametrize('batch', list(2 ** i for i in range(6))) 165 | @pytest.mark.parametrize('channels', list(2 ** i for i in range(1, 4))) 166 | @pytest.mark.parametrize('length', [2000]) 167 | def test_complx_chained(batch, channels, length): 168 | 169 | model1 = nn.ModuleList([InvertibleConv1x1(channels, True), 170 | InvertibleConv1x1(channels, False), 171 | InvertibleConv1x1(channels, True)]) 172 | model2 = nn.ModuleList([InvertibleConv1x1(channels, False), 173 | InvertibleConv1x1(channels, True), 174 | InvertibleConv1x1(channels, False)]) 175 | model2.load_state_dict(model1.state_dict()) 176 | loss_func = WaveGlowLoss().cuda() 177 | 178 | for seed in range(10): 179 | set_seed(seed) 180 | data = torch.rand(batch, channels, length) * 2 - 1 181 | impl_grad = [] 182 | for model in [model1, model2]: 183 | model = model.cuda() 184 | model.train() 185 | model.zero_grad() 186 | 187 | x = data.cuda() 188 | 189 | xin = x.clone() 190 | logdet = 0 191 | for layer in model: 192 | xin, _ = layer.reverse(xin) 193 | logdet = logdet + _ 194 | 195 | loss = loss_func(xin.view(batch, -1), logdet) 196 | 197 | loss.backward() 198 | impl_grad.append([p.grad.cpu() for p in model.parameters()]) 199 | 200 | for p_grad1, p_grad2 in zip(impl_grad[0], impl_grad[1]): 201 | assert torch.allclose(p_grad1, p_grad2, atol=5e-7, rtol=0) 202 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | import torch 6 | 7 | 8 | import pytorch_lightning as pl 9 | from pytorch_lightning.callbacks import ModelSummary, DeviceStatsMonitor, LearningRateMonitor 10 | from pytorch_lightning.plugins import DDPPlugin 11 | import torchaudio 12 | 
13 | 14 | from model import LightModel 15 | 16 | 17 | class TestFileCallBack(pl.Callback): 18 | def __init__(self, test_file: str) -> None: 19 | super().__init__() 20 | 21 | y, sr = torchaudio.load(test_file) 22 | self.test_y = y.mean(0) 23 | self.sr = sr 24 | 25 | def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: LightModel) -> None: 26 | if not trainer.is_global_zero: 27 | return 28 | y = self.test_y.to(pl_module.device).unsqueeze(0) 29 | with torch.no_grad(): 30 | cond = pl_module.conditioner(y) 31 | pred = pl_module(cond, 0.7).cpu() 32 | 33 | trainer.logger.experiment.add_audio( 34 | 'reconstruct_audio', pred[:, None], sample_rate=self.sr, global_step=trainer.global_step) 35 | 36 | 37 | class ChangeLRCallback(pl.Callback): 38 | def __init__(self, lr: float) -> None: 39 | super().__init__() 40 | self.lr = lr 41 | 42 | def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 43 | for optimizer in trainer.optimizers: 44 | for param_group in optimizer.param_groups: 45 | param_group['lr'] = self.lr 46 | 47 | 48 | def main(args, config): 49 | pl.seed_everything(args.seed) 50 | 51 | gpus = torch.cuda.device_count() 52 | if config is not None: 53 | config['data_loader']['batch_size'] //= gpus 54 | 55 | callbacks = [ 56 | ModelSummary(max_depth=2), 57 | LearningRateMonitor('epoch') 58 | # DeviceStatsMonitor() 59 | ] 60 | if args.test_file: 61 | callbacks.append(TestFileCallBack(args.test_file)) 62 | if args.lr: 63 | callbacks.append(ChangeLRCallback(args.lr)) 64 | 65 | if args.ckpt_path: 66 | kwargs = {} 67 | if config is not None: 68 | kwargs['config'] = config 69 | lit_model = LightModel.load_from_checkpoint(args.ckpt_path, **kwargs) 70 | else: 71 | lit_model = LightModel(config) 72 | 73 | trainer = pl.Trainer.from_argparse_args( 74 | args, callbacks=callbacks, log_every_n_steps=1, 75 | benchmark=True, detect_anomaly=True, gpus=gpus, 76 | max_epochs=100, 77 | strategy=DDPPlugin(find_unused_parameters=False) if gpus > 1 else None) 78 | trainer.fit(lit_model, ckpt_path=args.ckpt_path) 79 | 80 | 81 | if __name__ == '__main__': 82 | parser = argparse.ArgumentParser(description='PyTorch WaveGlow') 83 | parser = LightModel.add_model_specific_args(parser) 84 | parser = pl.Trainer.add_argparse_args(parser) 85 | parser.add_argument('--config', type=str, 86 | help='config file path (default: None)') 87 | parser.add_argument('--ckpt-path', type=str) 88 | parser.add_argument('--test-file', type=str) 89 | parser.add_argument('--seed', type=int, default=None) 90 | parser.add_argument('--lr', type=float, default=None, 91 | help='force learning rate') 92 | parser.add_argument('--no-tf32', action='store_true') 93 | args = parser.parse_args() 94 | 95 | if args.no_tf32 and torch.cuda.is_available(): 96 | torch.backends.cuda.matmul.allow_tf32 = False 97 | torch.backends.cudnn.allow_tf32 = False 98 | 99 | config = json.load(open(args.config)) if args.config else None 100 | main(args, config) 101 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch import nn 3 | 4 | 5 | def get_instance(module, config, *args): 6 | return getattr(module, config['type'])(*args, **config['args']) 7 | 8 | 9 | def remove_weight_norms(m): 10 | if hasattr(m, 'weight_g'): 11 | nn.utils.remove_weight_norm(m) 12 | 13 | 14 | def add_weight_norms(m): 15 | if hasattr(m, 'weight'): 16 | nn.utils.weight_norm(m) 17 | 18 | 19 | def ensure_dir(path): 20 | if 
not os.path.exists(path): 21 | os.makedirs(path) 22 | -------------------------------------------------------------------------------- /vctk_wsrglow_infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from pathlib import Path 4 | import argparse 5 | import torchaudio 6 | from kazane import Decimate 7 | from functools import partial 8 | from tqdm import tqdm 9 | from model.condition import STFTDecimate 10 | from model import LightModel 11 | 12 | 13 | class LSD(torch.nn.Module): 14 | def __init__(self, n_fft=2048, hop_length=512): 15 | super().__init__() 16 | self.n_fft = n_fft 17 | self.hop_length = hop_length 18 | self.register_buffer('window', torch.hann_window(n_fft)) 19 | 20 | def forward(self, y_hat, y): 21 | Y_hat = torch.stft(y_hat, self.n_fft, hop_length=self.hop_length, 22 | window=self.window, return_complex=True) 23 | Y = torch.stft(y, self.n_fft, hop_length=self.hop_length, 24 | window=self.window, return_complex=True) 25 | sp = Y_hat.abs().square_().clamp_(min=1e-8).log10_() 26 | st = Y.abs().square_().clamp_(min=1e-8).log10_() 27 | return (sp - st).square_().mean(0).sqrt_().mean() 28 | 29 | 30 | if __name__ == '__main__': 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('vctk', type=str) 33 | parser.add_argument('-q', type=int, default=2) 34 | parser.add_argument('--ckpt', type=str, 35 | default='../WSRGlow/ckpt/x2_best.pt') 36 | parser.add_argument('--downsample-type', type=str, 37 | choices=['sinc', 'stft'], default='stft') 38 | 39 | args = parser.parse_args() 40 | 41 | checkpoint = args.ckpt 42 | model = LightModel.load_from_checkpoint(checkpoint) 43 | model.eval() 44 | model = model.cuda() 45 | 46 | sinc_kwargs = { 47 | 'q': args.q, 48 | 'roll_off': 0.962, 49 | 'num_zeros': 128, 50 | 'window_func': partial(torch.kaiser_window, periodic=False, 51 | beta=14.769656459379492), 52 | } 53 | 54 | if args.downsample_type == 'sinc': 55 | downsampler = Decimate(**sinc_kwargs) 56 | else: 57 | downsampler = STFTDecimate(sinc_kwargs['q']) 58 | downsampler = downsampler.cuda() 59 | evaluater = LSD().cuda() 60 | vctk_path = Path(args.vctk) 61 | test_files = list(vctk_path.glob('*/*.wav')) 62 | 63 | pbar = tqdm(total=len(test_files)) 64 | 65 | lsd_list = [] 66 | chunk_size = 8 * args.q 67 | for filename in test_files: 68 | raw_y, sr = torchaudio.load(filename) 69 | raw_y = raw_y.cuda() 70 | offset = raw_y.shape[1] % chunk_size 71 | if offset > 0: 72 | y = raw_y[:, :-offset] 73 | else: 74 | y = raw_y 75 | 76 | y_lowpass = downsampler(y) 77 | with torch.no_grad(): 78 | y_hat, _ = model.model.reverse(torch.randn_like(y), y_lowpass) 79 | y_hat = y_hat.squeeze() 80 | 81 | if offset > 0: 82 | y_hat = torch.cat([y_hat, y_hat.new_zeros(offset)], dim=0) 83 | raw_y = raw_y.squeeze() 84 | lsd = evaluater(y_hat, raw_y).item() 85 | lsd_list.append(lsd) 86 | pbar.set_postfix(lsd=lsd) 87 | pbar.update(1) 88 | 89 | print(sum(lsd_list) / len(lsd_list)) 90 | --------------------------------------------------------------------------------
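As a quick reference for how the pieces above fit together, here is a minimal inference sketch condensed from `vctk_wsrglow_infer.py`: load a trained `LightModel` checkpoint, simulate a low-resolution input with `STFTDecimate`, and run the flow in reverse to recover the full-band waveform. This is only a sketch under stated assumptions — the checkpoint and wav paths are placeholders, and the length trimming mirrors the script's requirement that the input length be a multiple of `8 * q`.

```python
# Hedged usage sketch based on vctk_wsrglow_infer.py; file paths are placeholders.
import torch
import torchaudio

from model import LightModel
from model.condition import STFTDecimate

q = 2  # upsampling factor of the checkpoint (2x model)

# Load a trained WSRGlow LightningModule checkpoint (placeholder path).
model = LightModel.load_from_checkpoint('wsrglow_checkpoints/x2.ckpt').eval().cuda()
downsampler = STFTDecimate(q).cuda()

# Load audio, fold to mono, and trim so the length divides the group size 8 * q.
y, sr = torchaudio.load('input.wav')
y = y.mean(0, keepdim=True).cuda()
y = y[:, :y.shape[1] - y.shape[1] % (8 * q)]

# Simulate the low-resolution conditioning signal, then invert the flow:
# latent noise + low-res conditioning -> full-band waveform.
y_low = downsampler(y)
with torch.no_grad():
    y_hat, _ = model.model.reverse(torch.randn_like(y), y_low)

torchaudio.save('output.wav', y_hat.reshape(1, -1).cpu(), sr)
```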