├── .gitignore ├── LICENSE ├── README.md ├── audio.py ├── data.py ├── gen_spec.py ├── hparams.py ├── loss.py ├── model.py ├── preprocess.py ├── requirements.txt ├── samples ├── sample1_gen.wav ├── sample1_gt.wav ├── sample2_gen.wav ├── sample2_gt.wav ├── sample3_gen.wav ├── sample3_gt.wav ├── sample4_gen.wav └── sample4_gt.wav ├── stft.py ├── synthesis.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # IPython 77 | profile_default/ 78 | ipython_config.py 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | .dmypy.json 111 | dmypy.json 112 | 113 | 114 | # vscode 115 | # vscode 116 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tuan Nguyen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CNNVocoder
2 | NOTE: I'm no longer working on this project. See [#9](https://github.com/tuan3w/cnn_vocoder/issues/9#issuecomment-642700118).
3 | ## A CNN-based vocoder.
4 |
5 | This work is inspired by the m-cnn model described in [Fast Spectrogram Inversion using Multi-head Convolutional Neural Networks](https://arxiv.org/abs/1808.06719).
6 | The authors show that even a simple upsampling network is enough to synthesize a waveform from a spectrogram/mel-spectrogram.
7 |
8 | In this repo, I train the model on spectrogram features, since they carry more information than mel-spectrogram features. However, because the mel-spectrogram is just a linear projection of the spectrogram, you can train a simple network to predict the spectrogram from the mel-spectrogram (a short sketch of this projection is given further down in this README), or change the parameters to train the vocoder directly on mel-spectrogram features.
9 |
10 | ## [Sample Audios](https://soundcloud.com/nguyen-duc-tuan-80422561/sets/cnn-vocoder-samples)
11 |
12 | ## Architecture notes
13 |
14 | Compared with m-cnn, my proposed network has some differences:
15 | - I use Upsampling + Conv layers instead of a TransposedConv layer. This helps prevent [checkerboard artifacts](https://distill.pub/2016/deconv-checkerboard/).
16 | - The model uses a number of residual blocks before and after the upsampling module to make the network larger/deeper.
17 | - I only use an L1 loss between the log-scale STFT magnitudes of the predicted and target waveforms. Computing the loss in log space works better than on raw STFT magnitudes because it is closer to [human perception of loudness](http://faculty.tamuc.edu/cbertulani/music/lectures/Lec12/Lec12.pdf). I also tried computing the loss on the spectrogram feature, but it didn't help much.
18 |
19 | ## Install requirements
20 |
21 | ```bash
22 | $ pip install -r requirements.txt
23 | ```
24 | ## Training vocoder
25 | ### 1. Prepare dataset
26 |
27 | I use the [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) dataset for my experiments. If you don't have it yet, please download the dataset and put it somewhere.
28 |
29 | After that, run the following command to generate the training dataset:
30 |
31 | ```bash
32 | $ python preprocess.py --samples_per_audio 20 \
33 |     --out_dir ljspeech \
34 |     --data_dir path/to/ljspeech/dataset \
35 |     --num_workers 4
36 | ```
37 |
38 | ### 2. Train vocoder
39 |
40 | ```bash
41 | $ python train.py --out_dir ${output_directory}
42 | ```
43 | For more training options, please run:
44 | ```bash
45 | $ python train.py --help
46 | ```
47 |
48 | ## Generate audio from spectrogram
49 | - Generate spectrogram from audio
50 | ```bash
51 | $ python gen_spec.py -i sample.wav -o out.npy
52 | ```
53 | - Generate audio from spectrogram
54 |
55 | ```bash
56 | $ python synthesis.py --model_path path/to/checkpoint \
57 |     --spec_path out.npy \
58 |     --out_path out.wav
59 | ```
60 |
61 | ## Pretrained model
62 | You can get my pre-trained model [here](https://drive.google.com/drive/folders/1aUwC8PFXnpWJuAKlhXP3HsOLesNgELn-?usp=sharing).
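## Spectrogram vs. mel-spectrogram

As noted above, a mel-spectrogram is just a fixed linear projection of the linear spectrogram, which is why the vocoder can be adapted to mel input with only parameter changes. Below is a minimal sketch of that projection, assuming the defaults from `hparams.py`/`audio.py` (22050 Hz, n_fft = 2048, hop_length = 256, 80 mel bands, fmin = 125, fmax = 7600); `sample.wav` stands for any audio clip:

```python
import librosa
import numpy as np

# Load audio and compute the linear-magnitude spectrogram (shape: 1025 x T).
wav, _ = librosa.load("sample.wav", sr=22050)
spec = np.abs(librosa.stft(wav, n_fft=2048, hop_length=256, win_length=1024))

# The mel filterbank is a fixed (80 x 1025) matrix, so the mel-spectrogram is
# obtained from the linear spectrogram by a single matrix multiplication.
mel_basis = librosa.filters.mel(sr=22050, n_fft=2048, n_mels=80, fmin=125, fmax=7600)
mel = mel_basis @ spec  # shape: 80 x T
```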
63 | 64 | ## Acknowledgements 65 | This implementation uses code from [NVIDIA](https://github.com/NVIDIA), [Ryuichi Yamamoto](https://github.com/r9y9), [Keith Ito](https://github.com/keithito) as described in my code. 66 | 67 | ## License 68 | [MIT](LICENSE) 69 | -------------------------------------------------------------------------------- /audio.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/keithito/tacotron/blob/master/util/audio.py 2 | 3 | import librosa 4 | import librosa.filters 5 | import math 6 | import numpy as np 7 | from scipy import signal 8 | from hparams import hparams 9 | from scipy.io import wavfile 10 | 11 | 12 | n_fft = hparams.fft_size * 2 13 | hop_length = hparams.hop_length 14 | win_length = hparams.win_length 15 | 16 | def load_wav(path): 17 | return librosa.core.load(path, sr=hparams.sample_rate)[0] 18 | 19 | 20 | def save_wav(wav, path): 21 | wav *= hparams.audio_max_value / max(0.01, np.max(np.abs(wav))) 22 | wavfile.write(path, hparams.sample_rate, wav.astype(np.int16)) 23 | 24 | 25 | def trim(quantized): 26 | start, end = start_and_end_indices(quantized, hparams.silence_threshold) 27 | return quantized[start:end] 28 | 29 | 30 | def adjust_time_resolution(quantized, mel): 31 | """Adjust time resolution by repeating features 32 | 33 | Args: 34 | quantized (ndarray): (T,) 35 | mel (ndarray): (N, D) 36 | 37 | Returns: 38 | tuple: Tuple of (T,) and (T, D) 39 | """ 40 | assert len(quantized.shape) == 1 41 | assert len(mel.shape) == 2 42 | 43 | upsample_factor = quantized.size // mel.shape[0] 44 | mel = np.repeat(mel, upsample_factor, axis=0) 45 | n_pad = quantized.size - mel.shape[0] 46 | if n_pad != 0: 47 | assert n_pad > 0 48 | mel = np.pad(mel, [(0, n_pad), (0, 0)], mode="constant", constant_values=0) 49 | 50 | # trim 51 | start, end = start_and_end_indices(quantized, hparams.silence_threshold) 52 | 53 | return quantized[start:end], mel[start:end, :] 54 | adjast_time_resolution = adjust_time_resolution # 'adjust' is correct spelling, this is for compatibility 55 | 56 | 57 | def start_and_end_indices(quantized, silence_threshold=2): 58 | for start in range(quantized.size): 59 | if abs(quantized[start] - 127) > silence_threshold: 60 | break 61 | for end in range(quantized.size - 1, 1, -1): 62 | if abs(quantized[end] - 127) > silence_threshold: 63 | break 64 | 65 | assert abs(quantized[start] - 127) > silence_threshold 66 | assert abs(quantized[end] - 127) > silence_threshold 67 | 68 | return start, end 69 | 70 | def _stft(y): 71 | return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=True) 72 | 73 | def melspectrogram(y): 74 | D = _stft(y) 75 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hparams.ref_level_db 76 | if not hparams.allow_clipping_in_normalization: 77 | assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0 78 | return _normalize(S) 79 | 80 | def spectrogram(y): 81 | D = _stft(y) 82 | S = _amp_to_db(np.abs(D)) - hparams.ref_level_db 83 | return _normalize(S) 84 | 85 | 86 | 87 | 88 | def get_hop_size(): 89 | hop_size = hparams.hop_size 90 | if hop_size is None: 91 | assert hparams.frame_shift_ms is not None 92 | hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate) 93 | return hop_size 94 | 95 | # Conversions: 96 | 97 | _mel_basis = None 98 | 99 | 100 | def _linear_to_mel(spectrogram): 101 | global _mel_basis 102 | if _mel_basis is None: 103 | _mel_basis = _build_mel_basis() 104 | return np.dot(_mel_basis, spectrogram) 105 | 106 | 
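# Note: with the hparams.py defaults (fft_size=1024, hence n_fft=2048 above), the
# mel basis built below has shape (num_mels, 1 + n_fft // 2) = (80, 1025), so
# spectrogram() returns (1025, T) features and melspectrogram() returns (80, T)
# features, both clipped to [0, 1] by _normalize().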
107 | def _build_mel_basis(): 108 | assert hparams.fmax <= hparams.sample_rate // 2 109 | return librosa.filters.mel(hparams.sample_rate, n_fft, 110 | fmin=hparams.fmin, fmax=hparams.fmax, 111 | n_mels=hparams.num_mels) 112 | 113 | 114 | def _amp_to_db(x): 115 | min_level = np.exp(hparams.min_level_db / 20 * np.log(10)) 116 | return 20 * np.log10(np.maximum(min_level, x)) 117 | 118 | 119 | def _db_to_amp(x): 120 | return np.power(10.0, x * 0.05) 121 | 122 | 123 | def _normalize(S): 124 | return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1) 125 | 126 | 127 | def _denormalize(S): 128 | return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db 129 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import audio 2 | import numpy as np 3 | from torch.utils.data.dataset import Dataset 4 | from torch.utils.data import DataLoader 5 | 6 | class MelDataset(Dataset): 7 | def __init__(self, file_path): 8 | with open(file_path) as f: 9 | self._files = [] 10 | for l in f: 11 | l = l.strip() 12 | self._files.append(l) 13 | 14 | def __len__(self): 15 | return len(self._files) 16 | 17 | def __getitem__(self, idx): 18 | p = self._files[idx] 19 | f = np.load(p) 20 | return f['wav'].astype(np.float32), f['spec'].astype(np.float32) 21 | -------------------------------------------------------------------------------- /gen_spec.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import audio 4 | 5 | def main(args): 6 | wav = audio.load_wav(args.wav_path) 7 | spec = audio.spectrogram(wav) 8 | np.save(args.out_path, spec) 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('-i', '--wav_path', type=str, required=True, 13 | help='path to audio path') 14 | parser.add_argument('-o', '--out_path', type=str, required=True, 15 | help='output path') 16 | args = parser.parse_args() 17 | main(args) 18 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | hparams = tf.contrib.training.HParams( 5 | ##################################### 6 | # Audio config 7 | ##################################### 8 | sample_rate=22050, 9 | silence_threshold=2, 10 | num_mels=80, 11 | fmin=125, 12 | fmax=7600, 13 | fft_size=1024, 14 | win_length=1024, 15 | hop_length=256, 16 | min_level_db=-100, 17 | ref_level_db=20, 18 | rescaling=True, 19 | rescaling_max=0.999, 20 | audio_max_value=32767, 21 | allow_clipping_in_normalization=True, 22 | 23 | ##################################### 24 | # Data config 25 | ##################################### 26 | seg_len= 81 * 256, 27 | file_list="training_data/files.txt", 28 | spec_len= 81, 29 | 30 | ##################################### 31 | # Model parameters 32 | ##################################### 33 | n_heads = 2, 34 | pre_residuals = 4, 35 | up_residuals=0, 36 | post_residuals = 12, 37 | pre_conv_channels = [1, 1, 2], 38 | layer_channels = [1025 * 2, 1024, 512, 256, 128, 64, 32, 16, 8], 39 | 40 | 41 | ##################################### 42 | # Training config 43 | ##################################### 44 | n_workers=2, 45 | seed=12345, 46 | batch_size=40, 47 | lr=1.0 * 1e-3, 48 | weight_decay=1e-5, 49 | epochs=50000, 50 | 
grad_clip_thresh=5.0, 51 | checkpoint_interval=1000, 52 | ) 53 | 54 | 55 | def hparams_debug_string(): 56 | values = hparams.values() 57 | hp = [' %s: %s' % (name, values[name]) for name in sorted(values)] 58 | return 'Hyperparameters:\n' + '\n'.join(hp) 59 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from hparams import hparams 5 | import stft 6 | 7 | def compute_stft(audio, n_fft=1024, win_length=1024, hop_length=256): 8 | """ 9 | Computes STFT transformation of given audio 10 | 11 | Args: 12 | audio (Tensor): B x T, batch of audio 13 | 14 | Returns: 15 | mag (Tensor): STFT magnitudes 16 | real (Tensor): Real part of STFT transformation result 17 | im (Tensor): Imagine part of STFT transformation result 18 | """ 19 | win = torch.hann_window(win_length).cuda() 20 | 21 | # add some padding because torch 4.0 doesn't 22 | signal_dim = audio.dim() 23 | extended_shape = [1] * (3 - signal_dim) + list(audio.size()) 24 | # pad = int(self.n_fft // 2) 25 | pad = win_length 26 | audio = F.pad(audio.view(extended_shape), (pad, pad), 'constant') 27 | audio = audio.view(audio.shape[-signal_dim:]) 28 | 29 | stft = torch.stft(audio, win_length, hop_length, fft_size=n_fft, window=win) 30 | real = stft[:, :, :, 0] 31 | im = stft[:, :, :, 1] 32 | power = torch.sqrt(torch.pow(real, 2) + torch.pow(im, 2)) 33 | return power, real, im 34 | 35 | def compute_loss(pred, target): 36 | """ 37 | Computes loss value 38 | 39 | Args: 40 | pred (Tensor): B x T, predicted wavs 41 | target (Tensor): B x T, target wavs 42 | """ 43 | stft_pred, _, _= compute_stft(pred, n_fft=2048, win_length=1024, hop_length=256) 44 | stft_target, _, _ = compute_stft(target, n_fft=2048, win_length=1024, hop_length=256) 45 | l1_loss = nn.L1Loss() 46 | 47 | log_stft_pred = torch.log(stft_pred + 1e-8) 48 | log_stft_target = torch.log(stft_target + 1e-8) 49 | l1 = l1_loss(log_stft_pred, log_stft_target) 50 | l2 = l1_loss(log_stft_pred[:, :, :500], log_stft_target[:, :,:500]) 51 | l3 = l1_loss(stft_pred[:,:,:500], stft_target[:,:,:500]) 52 | loss = l1 + l2 + l3 53 | return loss, l1, l2, l3 54 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | class ResnetBlock(nn.Module): 6 | """Residual Block 7 | Args: 8 | in_channels (int): number of channels in input data 9 | out_channels (int): number of channels in output 10 | """ 11 | def __init__(self, in_channels, out_channels, kernel_size=3, one_d=False): 12 | super(ResnetBlock, self).__init__() 13 | self.build_conv_block(in_channels, out_channels, one_d, kernel_size=kernel_size) 14 | 15 | def build_conv_block(self, in_channels, out_channels, one_d, kernel_size=3): 16 | padding = (kernel_size -1)//2 17 | if not one_d: 18 | conv = nn.Conv2d 19 | norm = nn.BatchNorm2d 20 | else: 21 | conv = nn.Conv1d 22 | norm = nn.BatchNorm1d 23 | 24 | self.conv1 = nn.Sequential( 25 | conv(in_channels, out_channels, kernel_size=kernel_size, padding=padding), 26 | norm(out_channels), 27 | nn.ELU() 28 | ) 29 | self.conv2 = nn.Sequential( 30 | conv(out_channels, out_channels, kernel_size=kernel_size, padding=padding), 31 | norm(out_channels), 32 | ) 33 | if in_channels != out_channels: 34 | self.down = 
nn.Sequential( 35 | conv(in_channels, out_channels, kernel_size=1, bias=False), 36 | norm(out_channels) 37 | ) 38 | else: 39 | self.down = None 40 | 41 | self.act = nn.ELU() 42 | 43 | def forward(self, x): 44 | """ 45 | Args: 46 | x (Tensor): B x C x T 47 | """ 48 | residual = x 49 | out = self.conv1(x) 50 | out = self.conv2(out) 51 | if self.down is not None: 52 | residual = self.down(residual) 53 | return self.act(out + residual) 54 | 55 | class UpsamplingLayer(nn.Module): 56 | """Applies 1D upsampling operator over input tensor. 57 | 58 | Args: 59 | in_channels (int): number of input channels 60 | out_channels (int): number of output channels 61 | residuals (int, optional): number of residual blocks. Default=0 62 | """ 63 | def __init__(self, in_channels, out_channels, residuals=0): 64 | super(UpsamplingLayer, self).__init__() 65 | # TODO: try umsampling with bilinear interpolation 66 | self.upsample = nn.Upsample(scale_factor=2, mode='linear') 67 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1) 68 | torch.nn.init.xavier_uniform_(self.conv.weight) 69 | self.bn = nn.BatchNorm1d(out_channels) 70 | self.act = nn.ELU() 71 | 72 | if residuals != 0: 73 | # resnet blocks 74 | layers = [] 75 | for _ in range(residuals): 76 | layers.append( 77 | ResnetBlock(out_channels, out_channels, one_d=True) 78 | ) 79 | self.res_blocks = nn.Sequential(*layers) 80 | else: 81 | self.res_blocks = None 82 | 83 | 84 | def forward(self, x): 85 | """ 86 | Args: 87 | x (Tensor): B x in_channels x T 88 | 89 | Returns: 90 | Tensor of shape (B, out_channels, T x 2) 91 | """ 92 | # upsample network 93 | B, C, T = x.shape 94 | # upsample 95 | # x = x.unsqueeze(dim=3) 96 | # x = F.upsample(x, size=(T*2, 1), mode='bilinear').squeeze(3) 97 | x = self.upsample(x) 98 | # x = self.pad(x) 99 | x = self.conv(x) 100 | x = self.bn(x) 101 | x = self.act(x) 102 | 103 | # pass through resnet blocks to improve internal representations 104 | # of data 105 | if self.res_blocks != None: 106 | x = self.res_blocks(x) 107 | return x 108 | 109 | 110 | class ConvBlock(nn.Module): 111 | def __init__(self, in_channels, out_channels, kernel_size=3, padding=1): 112 | super(ConvBlock, self).__init__() 113 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding) 114 | self.bn = nn.BatchNorm2d(out_channels) 115 | self.act = nn.ELU() 116 | 117 | def forward(self, x): 118 | x = self.conv(x) 119 | x = self.bn(x) 120 | x = self.act(x) 121 | return x 122 | 123 | class Head(nn.Module): 124 | """Head module 125 | 126 | Args: 127 | channels (list): list of #channels in each upsampling layer 128 | pre_residuals (int, optional): number of residual blocks before upsampling. Default: 64 129 | down_conv_channels (list): list of #channels in each down_conv blocks 130 | up_residuals (int, optional): number of residual blocks in each upsampling module. 
Default: 0 131 | """ 132 | def __init__(self, channels, 133 | pre_residuals=64, 134 | pre_conv_channels=[64, 32, 16, 8, 4], 135 | up_residuals=0, 136 | post_residuals=2): 137 | super(Head, self).__init__() 138 | pre_convs = [] 139 | c0 = pre_conv_channels[0] 140 | pre_convs.append(ConvBlock(1, c0, kernel_size=3, padding=1)) 141 | for _ in range(pre_residuals): 142 | pre_convs.append(ResnetBlock(c0, c0)) 143 | 144 | for i in range(len(pre_conv_channels) -1): 145 | in_c = pre_conv_channels[i] 146 | out_c = pre_conv_channels[i + 1] 147 | pre_convs.append(ResnetBlock(in_c, out_c)) 148 | for _ in range(pre_residuals): 149 | pre_convs.append(ResnetBlock(out_c, out_c)) 150 | self.pre_conv = nn.Sequential(*pre_convs) 151 | 152 | up_layers = [] 153 | for i in range(len(channels) - 1): 154 | in_channels = channels[i] 155 | out_channels = channels[i + 1] 156 | layer = UpsamplingLayer(in_channels, out_channels, residuals=up_residuals) 157 | up_layers.append(layer) 158 | self.upsampling = nn.Sequential(*up_layers) 159 | 160 | post_convs = [] 161 | last_channels = channels[-1] 162 | for i in range(post_residuals): 163 | post_convs.append(ResnetBlock(last_channels, last_channels, one_d=True, kernel_size=5)) 164 | self.post_conv = nn.Sequential(*post_convs) 165 | 166 | def forward(self, x): 167 | """ 168 | forward pass 169 | Args: 170 | x (Tensor): B x C x T 171 | 172 | Returns: 173 | Tensor: B x C x (2^#channels * T) 174 | """ 175 | x = x.unsqueeze(1) # reshape to [B x 1 x C x T] 176 | x = self.pre_conv(x) 177 | s1, _, _, s4 = x.shape 178 | x = x.reshape(s1, -1, s4) 179 | x = self.upsampling(x) 180 | x2 = self.post_conv(x) 181 | return x, x2 182 | 183 | 184 | DEFAULT_LAYERS_PARAMS = [80, 128, 128, 64, 64, 32, 16, 8, 1] 185 | class CNNVocoder(nn.Module): 186 | """CNN Vocoder 187 | 188 | Args: 189 | n_heads (int): Number of heads 190 | layer_channels (list): list of #channels of each layer 191 | """ 192 | def __init__(self, n_heads=3, 193 | layer_channels=DEFAULT_LAYERS_PARAMS, 194 | pre_conv_channels=[64, 32, 16, 8, 4], 195 | pre_residuals=64, 196 | up_residuals=0, 197 | post_residuals=3): 198 | super(CNNVocoder, self).__init__() 199 | self.head = Head(layer_channels, 200 | pre_conv_channels=pre_conv_channels, 201 | pre_residuals=pre_residuals, up_residuals=up_residuals, 202 | post_residuals=post_residuals) 203 | self.linear = nn.Linear(layer_channels[-1], 1) 204 | self.act_fn = nn.Softsign() 205 | 206 | def forward(self, x): 207 | b = x.shape[0] 208 | pre, post = self.head(x) 209 | 210 | rs0 = self.linear(pre.transpose(1, 2)) 211 | rs0 = self.act_fn(rs0).squeeze(-1) 212 | 213 | rs1 = self.linear(post.transpose(1, 2)) 214 | rs1 = self.act_fn(rs1).squeeze(-1) 215 | return rs0, rs1 216 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import glob 4 | from functools import partial 5 | import numpy as np 6 | from tqdm import tqdm 7 | import os 8 | import audio 9 | from random import shuffle 10 | from concurrent.futures import ProcessPoolExecutor 11 | from hparams import hparams 12 | 13 | def gen_samples(out_dir, wav_path, n_samples): 14 | wav = audio.load_wav(wav_path) 15 | hop_size = hparams.hop_length 16 | seg_len = hparams.seg_len 17 | spec_len = hparams.spec_len 18 | # not sure why we have to minus 1 here ? 
19 | wav_len = wav.shape[0] // hop_size * hop_size -1 20 | wav = wav[:wav_len] 21 | spec = audio.spectrogram(wav) 22 | mel = audio.melspectrogram(wav) 23 | max_val = spec.shape[1] - 1 - spec_len 24 | if max_val < 0: 25 | return [] 26 | idx = np.random.randint(0, max_val, size=(n_samples)) 27 | d = [] 28 | i = 0 29 | for offset in idx: 30 | i += 1 31 | w = wav[offset * hop_size: offset * hop_size + seg_len] 32 | s = spec[:,offset:offset + spec_len] 33 | m = mel[:,offset:offset + spec_len] 34 | wav_name = wav_path.split('/')[-1].split('.')[0] 35 | file_path = "{0}/{1}_{2:03d}.npz".format(out_dir, wav_name, i) 36 | np.savez(file_path, wav=w, spec=s, mel=m) 37 | d.append(file_path) 38 | return d 39 | 40 | def main(args): 41 | executor = ProcessPoolExecutor( 42 | max_workers=args.num_workers) 43 | files = [] 44 | audio_dir = os.path.join(args.data_dir, 'wavs') 45 | out_dir = args.out_dir 46 | audio_files = glob.glob(audio_dir + '/*.wav') 47 | samples_per_audio = args.samples_per_audio 48 | futures = [] 49 | index = 0 50 | for wav_path in audio_files: 51 | futures.append(executor.submit(partial(gen_samples, out_dir, wav_path, samples_per_audio))) 52 | index += 1 53 | files = [future.result() for future in tqdm(futures)] 54 | files =[y for z in files for y in z] 55 | txt_path = os.path.join(out_dir, "files.txt") 56 | with open(txt_path, 'w', encoding='utf-8') as f: 57 | files = [f + '\n' for f in files] 58 | # shuffle data 59 | shuffle(files) 60 | f.writelines(files) 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('-o', '--out_dir', default='training_data', type=str, help='Output directory to store dataset') 66 | parser.add_argument('-d', '--data_dir', default="ljspeech", type=str, help='Path to ljspeech dataset') 67 | parser.add_argument('-s', '--samples_per_audio', type=int, default=400, help='Number of sample per audio') 68 | parser.add_argument('-n', '--num_workers', type=int, default=4) 69 | 70 | args = parser.parse_args() 71 | main(args) 72 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==0.4.0 2 | tensorflow>=1.3.0 3 | pysptk >= 0.1.9 4 | tensorboardX 5 | librosa 6 | tqdm 7 | matplotlib 8 | numpy 9 | scipy 10 | 11 | -------------------------------------------------------------------------------- /samples/sample1_gen.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuan3w/cnn_vocoder/7855be5174bb61a38c4f7c7c482365d460f93a2a/samples/sample1_gen.wav -------------------------------------------------------------------------------- /samples/sample1_gt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuan3w/cnn_vocoder/7855be5174bb61a38c4f7c7c482365d460f93a2a/samples/sample1_gt.wav -------------------------------------------------------------------------------- /samples/sample2_gen.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuan3w/cnn_vocoder/7855be5174bb61a38c4f7c7c482365d460f93a2a/samples/sample2_gen.wav -------------------------------------------------------------------------------- /samples/sample2_gt.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tuan3w/cnn_vocoder/7855be5174bb61a38c4f7c7c482365d460f93a2a/samples/sample2_gt.wav -------------------------------------------------------------------------------- /samples/sample3_gen.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuan3w/cnn_vocoder/7855be5174bb61a38c4f7c7c482365d460f93a2a/samples/sample3_gen.wav -------------------------------------------------------------------------------- /samples/sample3_gt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuan3w/cnn_vocoder/7855be5174bb61a38c4f7c7c482365d460f93a2a/samples/sample3_gt.wav -------------------------------------------------------------------------------- /samples/sample4_gen.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuan3w/cnn_vocoder/7855be5174bb61a38c4f7c7c482365d460f93a2a/samples/sample4_gen.wav -------------------------------------------------------------------------------- /samples/sample4_gt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuan3w/cnn_vocoder/7855be5174bb61a38c4f7c7c482365d460f93a2a/samples/sample4_gt.wav -------------------------------------------------------------------------------- /stft.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from hparams import hparams 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import librosa 7 | 8 | def _normalize(S): 9 | return (S - hparams.min_level_db)/-hparams.min_level_db 10 | 11 | def _build_mel_basis(n_fft, n_mels=80): 12 | return torch.FloatTensor(librosa.filters.mel(hparams.sample_rate, n_fft, n_mels=n_mels)).transpose(0, 1) 13 | 14 | def _amp_to_db(x): 15 | #return 20 * torch.log10(torch.clamp(x, min=1e-5)) 16 | return 20 * torch.log10(x + 1e-5) 17 | 18 | 19 | class Spectrogram(nn.Module): 20 | """Spectrogram transformation. 
21 | 22 | Args: 23 | win_length (int): stft window length 24 | hop_length (int): stft hop length 25 | n_fft (int): number of fft basis 26 | preemp (bool): whether pre-emphasis audio before do stft 27 | """ 28 | def __init__(self, win_length=1024, hop_length=256, n_fft=2048, preemp=True): 29 | super(Spectrogram, self).__init__() 30 | if preemp: 31 | self.preemp = nn.Conv1d(1, 1, 2, bias=False, padding=1) 32 | self.preemp.weight.data[0][0][0] = -0.97 33 | self.preemp.weight.data[0][0][1] = 1.0 34 | self.preemp.weight.requires_grad = False 35 | else: 36 | self.preemp = None 37 | 38 | win = torch.hann_window(win_length) 39 | self.register_buffer('win', win) 40 | self.win_length = win_length 41 | self.hop_length = hop_length 42 | self.n_fft = n_fft 43 | 44 | def forward(self, x): 45 | if self.preemp is not None: 46 | x = x.unsqueeze(1) 47 | # conv and remove last padding 48 | x = self.preemp(x)[:, :, :-1] 49 | x = x.squeeze(1) 50 | 51 | # center=True 52 | # torch 0.4 doesnt support like librosa 53 | signal_dim = x.dim() 54 | extended_shape = [1] * (3 - signal_dim) + list(x.size()) 55 | # pad = int(self.n_fft // 2) 56 | pad = self.win_length 57 | x = F.pad(x.view(extended_shape), (pad, pad), 'constant') 58 | x = x.view(x.shape[-signal_dim:]) 59 | stft = torch.stft(x, 60 | self.win_length, 61 | self.hop_length, 62 | window=self.win, 63 | fft_size=self.n_fft) 64 | real = stft[:, :, :, 0] 65 | im = stft[:, :, :, 1] 66 | p = torch.sqrt(torch.pow(real, 2) + torch.pow(im, 2)) 67 | 68 | # convert volume to db 69 | spec = _amp_to_db(p) - hparams.ref_level_db 70 | return spec, p 71 | 72 | class MelSpectrogram(nn.Module): 73 | """MelSpectrogram transformation. 74 | 75 | Args: 76 | win_length (int): stft window length 77 | hop_length (int): stft hop length 78 | n_fft (int): number of fft basis 79 | n_mels (int): number of mel filters 80 | preemp (bool): whether pre-emphasis audio before do stft 81 | """ 82 | def __init__(self, win_length=1024, hop_length=256, n_fft=2048, n_mels=80, preemp=True): 83 | super(MelSpectrogram, self).__init__() 84 | if preemp: 85 | self.preemp = nn.Conv1d(1, 1, 2, bias=False, padding=1) 86 | self.preemp.weight.data[0][0][0] = -0.97 87 | self.preemp.weight.data[0][0][1] = 1.0 88 | self.preemp.weight.requires_grad = False 89 | else: 90 | self.preemp = None 91 | 92 | self.register_buffer('mel_basis', _build_mel_basis(n_fft, n_mels)) 93 | 94 | win = torch.hann_window(win_length) 95 | self.register_buffer('win', win) 96 | 97 | self.win_length = win_length 98 | self.hop_length = hop_length 99 | self.n_fft = n_fft 100 | 101 | def forward(self, x): 102 | if self.preemp is not None: 103 | x = x.unsqueeze(1) 104 | x = self.preemp(x) 105 | x = x.squeeze(1) 106 | stft = torch.stft(x, 107 | self.win_length, 108 | self.hop_length, 109 | fft_size=self.n_fft, 110 | window=self.win) 111 | real = stft[:, :, :, 0] 112 | im = stft[:, :, :, 1] 113 | spec = torch.sqrt(torch.pow(real, 2) + torch.pow(im, 2)) 114 | 115 | # convert linear spec to mel 116 | mel = torch.matmul(spec, self.mel_basis) 117 | # convert to db 118 | mel = _amp_to_db(mel) - hparams.ref_level_db 119 | return _normalize(mel) 120 | -------------------------------------------------------------------------------- /synthesis.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from time import time 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import audio 8 | from hparams import hparams, hparams_debug_string 9 | from model import CNNVocoder 10 | from utils import 
load_checkpoint 11 | 12 | 13 | def main(args): 14 | model = CNNVocoder( 15 | n_heads=hparams.n_heads, 16 | layer_channels=hparams.layer_channels, 17 | pre_conv_channels=hparams.pre_conv_channels, 18 | pre_residuals=hparams.pre_residuals, 19 | up_residuals=hparams.up_residuals, 20 | post_residuals=hparams.post_residuals 21 | ) 22 | model = model.cuda() 23 | 24 | model, _, _, _ = load_checkpoint( 25 | args.model_path, model) 26 | spec = np.load(args.spec_path) 27 | spec = torch.FloatTensor(spec).unsqueeze(0).cuda() 28 | t1 = time() 29 | _, wav = model(spec) 30 | dt = time() - t1 31 | print('Synthesized audio in {}s'.format(dt)) 32 | wav = wav.data.cpu()[0].numpy() 33 | audio.save_wav(wav, args.out_path) 34 | 35 | if __name__ == '__main__': 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--model_path', type=str, required=True, 38 | help='path to model checkpoint') 39 | parser.add_argument('--spec_path', type=str, default="logs", 40 | help='path to spec file') 41 | parser.add_argument('--out_path', type=str, default=None, 42 | help='output wav path') 43 | args = parser.parse_args() 44 | main(args) 45 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import matplotlib 5 | import numpy as np 6 | import torch 7 | from torch import optim 8 | from torch.nn.utils import clip_grad_norm_ 9 | from torch.utils.data import DataLoader 10 | 11 | import audio 12 | import librosa 13 | from data import MelDataset 14 | from hparams import hparams, hparams_debug_string 15 | from loss import compute_loss 16 | from model import CNNVocoder 17 | from tensorboardX import SummaryWriter 18 | from tqdm import tqdm 19 | from utils import (load_checkpoint, plot_spectrogram_to_numpy, save_checkpoint, 20 | weights_init) 21 | 22 | torch.backends.cudnn.benchmark = True 23 | # Set random seed to make training reproducible 24 | np.random.seed(hparams.seed) 25 | torch.manual_seed(hparams.seed) 26 | torch.cuda.manual_seed(hparams.seed) 27 | 28 | def add_log(writer, loss, l1, l2, l3, steps): 29 | writer.add_scalar("loss", loss, steps) 30 | writer.add_scalar("loss.log_stft", l1, steps) 31 | writer.add_scalar("loss.log_stft_low_freqs", l2, steps) 32 | writer.add_scalar("loss.stft_low_freqs", l3, steps) 33 | # writer.add_scalar("grad_norm", grad_norm, steps) 34 | 35 | def add_spec_sample(writer, mel_target, mel_predicted, steps): 36 | writer.add_image( 37 | "mel_target", 38 | plot_spectrogram_to_numpy(mel_target), 39 | steps) 40 | writer.add_image( 41 | "mel_predicted", 42 | plot_spectrogram_to_numpy(mel_predicted), 43 | steps) 44 | 45 | 46 | def prepare_directories(out_dir, log_dir, checkpoint_dir): 47 | log_dir = os.path.join(out_dir, log_dir) 48 | checkpoint_dir = os.path.join(out_dir, checkpoint_dir) 49 | dirs = [out_dir, log_dir, checkpoint_dir] 50 | for d in dirs: 51 | print('prepare dir: {}'.format(d)) 52 | if not os.path.isdir(d): 53 | os.makedirs(d) 54 | 55 | 56 | def train(args): 57 | print(hparams_debug_string()) 58 | 59 | # prepare logging, checkpoint directories 60 | prepare_directories(args.out_dir, args.log_dir, args.checkpoint_dir) 61 | # create model 62 | model = CNNVocoder( 63 | n_heads=hparams.n_heads, 64 | layer_channels=hparams.layer_channels, 65 | pre_conv_channels=hparams.pre_conv_channels, 66 | pre_residuals=hparams.pre_residuals, 67 | up_residuals=hparams.up_residuals, 68 | post_residuals=hparams.post_residuals 69 | ) 70 | 
model.apply(weights_init) 71 | model = model.cuda() 72 | 73 | # create optimizer 74 | lr = hparams.lr 75 | optimizer = optim.Adam(model.parameters(), 76 | lr=lr, 77 | weight_decay=hparams.weight_decay) 78 | 79 | dataloader = DataLoader( 80 | MelDataset(hparams.file_list), 81 | batch_size=hparams.batch_size, shuffle=True, 82 | num_workers=hparams.n_workers) 83 | 84 | steps = 0 85 | checkpoint_dir = os.path.join(args.out_dir, args.checkpoint_dir) 86 | log_dir = os.path.join(args.out_dir, args.log_dir) 87 | writer = SummaryWriter(log_dir) 88 | 89 | # load model from checkpoint 90 | if args.checkpoint_path: 91 | model, optimizer, lr, steps = load_checkpoint( 92 | args.checkpoint_path, model, optimizer, warm_start=args.warm_start) 93 | 94 | for i in range(hparams.epochs): 95 | print('Epoch: {}'.format(i)) 96 | for idx, batch in enumerate(dataloader): 97 | steps += 1 98 | wav, spec = batch[0].cuda(), batch[1].cuda() 99 | optimizer.zero_grad() 100 | pre_predict, predict = model(spec) 101 | post_loss, l1, l2, l3 = compute_loss(predict, wav) 102 | loss = post_loss 103 | print('Step: {:8d}, Loss = {:8.4f}, post_loss = {:8.4f}, pre_loss = {:8.4f}'.format(steps, loss, post_loss, post_loss)) 104 | if torch.isnan(loss).item() != 0: 105 | print('nan loss, ignore') 106 | return 107 | loss.backward() 108 | # clip grad norm 109 | grad_norm = clip_grad_norm_( 110 | model.parameters(), hparams.grad_clip_thresh) 111 | optimizer.step() 112 | 113 | # log training 114 | # add_log(writer, loss, p_loss, low_p_loss, phrase_loss, p1, grad_norm, steps) 115 | add_log(writer, loss, l1, l2, l3, steps) 116 | 117 | if steps > 0 and steps % hparams.checkpoint_interval == 0: 118 | checkpoint_path = '{}/checkpoint_{}'.format( 119 | checkpoint_dir, steps) 120 | save_checkpoint(checkpoint_path, lr, 121 | steps, model, optimizer) 122 | 123 | # saving example 124 | idx = np.random.randint(wav.shape[0]) 125 | t1 = wav[idx].data.cpu().numpy() 126 | t2 = predict[idx].data.cpu().numpy() 127 | audio.save_wav( 128 | t2, '{}/generated_{}.wav'.format(checkpoint_dir, steps)) 129 | audio.save_wav( 130 | t1, '{}/target_{}.wav'.format(checkpoint_dir, steps)) 131 | 132 | # add spec sample 133 | # spec_pred = audio.melspectrogram(t2) 134 | # spec_target = audio.melspectrogram(t1) 135 | # add_spec_sample(writer, spec_target, spec_pred, steps) 136 | 137 | 138 | if __name__ == '__main__': 139 | parser = argparse.ArgumentParser() 140 | parser.add_argument('-o', '--out_dir', type=str, default="output", 141 | help='directory to save checkpoints') 142 | parser.add_argument('-l', '--log_dir', type=str, default="logs", 143 | help='log directory ${out_directory}/${log_directory}') 144 | parser.add_argument('--checkpoint_path', type=str, default=None, 145 | help='checkpoint model') 146 | parser.add_argument('-c', '--checkpoint_dir', type=str, default="checkpoints", 147 | required=False, help='checkpoint directory ${out_directory}/${checkpoint_dir}') 148 | 149 | parser.add_argument('--warm_start', action='store_true', 150 | help='load the model only (warm start)') 151 | 152 | args = parser.parse_args() 153 | train(args) 154 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #copy from https://github.com/NVIDIA/tacotron2/blob/master/plotting_utils.py 2 | import os 3 | import time, sys, math 4 | 5 | import matplotlib 6 | matplotlib.use("Agg") 7 | import matplotlib.pylab as plt 8 | import numpy as np 9 | import torch 10 | 11 | 12 | 
13 | def save_figure_to_numpy(fig): 14 | # save it to a numpy array. 15 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 16 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 17 | return data 18 | 19 | def plot_spectrogram_to_numpy(spectrogram): 20 | fig, ax = plt.subplots(figsize=(12, 3)) 21 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 22 | interpolation='none') 23 | plt.colorbar(im, ax=ax) 24 | plt.xlabel("Frames") 25 | plt.ylabel("Channels") 26 | plt.tight_layout() 27 | 28 | fig.canvas.draw() 29 | data = save_figure_to_numpy(fig) 30 | plt.close() 31 | return data 32 | 33 | # from https://github.com/NVIDIA/vid2vid/blob/951a52bb38c2aa227533b3731b73f40cbd3843c4/models/networks.py#L17 34 | def weights_init(m): 35 | classname = m.__class__.__name__ 36 | if classname.find('Conv') != -1 and hasattr(m, 'weight'): 37 | # m.weight.data.normal_(0.0, 0.02) 38 | torch.nn.init.xavier_uniform_(m.weight) 39 | elif classname.find('BatchNorm') != -1: 40 | m.weight.data.normal_(1.0, 0.02) 41 | m.bias.data.fill_(0) 42 | 43 | def load_checkpoint(checkpoint_path, model, optimizer=None, warm_start=False): 44 | """Loads model from given checkpoint 45 | 46 | Args: 47 | model (nn.Module): vocoder model 48 | optimizer (Optimizer): optimizer 49 | """ 50 | assert os.path.isfile(checkpoint_path) 51 | print("Loading checkpoint from '{}'".format(checkpoint_path)) 52 | ckpt = torch.load(checkpoint_path, map_location='cpu') 53 | lr = ckpt['lr'] 54 | model.load_state_dict(ckpt['model_state']) 55 | if optimizer != None and not warm_start: 56 | print('update optimizer') 57 | optimizer.load_state_dict(ckpt['optim_state']) 58 | steps = ckpt['steps'] + 1 59 | if warm_start: 60 | steps = 0 61 | return model, optimizer, lr, steps 62 | 63 | 64 | def save_checkpoint(filename, lr, steps, model, optimizer): 65 | """Saves model 66 | Args: 67 | filename (string): checkpoint path 68 | lr (float): learning rate 69 | model (nn.Module): vocoder model 70 | optimizer (Optimizer): optimizer 71 | """ 72 | print('Saving checkpoint at step {}'.format(steps)) 73 | torch.save({ 74 | 'steps': steps, 75 | 'lr': lr, 76 | 'model_state': model.state_dict(), 77 | 'optim_state': optimizer.state_dict() 78 | }, filename) 79 | 80 | def time_since(started) : 81 | elapsed = time.time() - started 82 | m = int(elapsed // 60) 83 | s = int(elapsed % 60) 84 | if m >= 60 : 85 | h = int(m // 60) 86 | m = m % 60 87 | return f'{h}h {m}m {s}s' 88 | else : 89 | return f'{m}m {s}s' 90 | --------------------------------------------------------------------------------