├── .gitignore ├── .gitmodules ├── Description.md ├── LICENSE ├── README.md ├── audio.py ├── convert_model.py ├── dataset.py ├── distributions.py ├── docker ├── Dockerfile ├── install_docker └── startscript ├── hparams.py ├── inputs └── sample.wav ├── library └── src │ ├── CMakeLists.txt │ ├── WaveRNNVocoder.cpp │ ├── cxxopts.hpp │ ├── net_impl.cpp │ ├── net_impl.h │ ├── vocoder.cpp │ ├── vocoder.h │ ├── wavernn.cpp │ └── wavernn.h ├── loss_function.py ├── lrschedule.py ├── model.py ├── model_outputs ├── mel-northandsouth_52_f000076.npy ├── mel-northandsouth_52_f000076.wav └── model.bin ├── preprocess.py ├── requirements.txt ├── synthesize.py ├── test_wavernnvocoder.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | parts/ 18 | sdist/ 19 | var/ 20 | wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *.cover 45 | .hypothesis/ 46 | .pytest_cache/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | db.sqlite3 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # Environments 83 | .env 84 | .venv 85 | env/ 86 | venv/ 87 | ENV/ 88 | env.bak/ 89 | venv.bak/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "library/src/pybind11"] 2 | path = library/src/pybind11 3 | url = https://github.com/pybind/pybind11.git 4 | -------------------------------------------------------------------------------- /Description.md: -------------------------------------------------------------------------------- 1 | 2 | 1. Baseline vocoder 3 | 4 | 2. Pruning gru1. Target density 10% 5 | 6 | 3. Single GRU layer. 7 | 8 | 4. Resnet depth of 3. 9 | 10 | 5. Training length 5->7. 11 | 12 | 6. Trained on northandsouth. 13 | 14 | 7. Rescaling True. 15 | 16 | 8. Fixed aux net count. 17 | 18 | 9. FC layers 3->2 19 | 20 | 10. Prune FC layers. 
21 | 22 | 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License
2 |
3 | Copyright (c) 2018 Gary Wang, Eugene Ingerman
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WaveRNN-Pytorch
2 | This repository is a fork of Fatcord's [Alternative](https://github.com/fatchord/WaveRNN) WaveRNN implementation.
3 | The original model has been significantly simplified to allow real-time synthesis of high-fidelity speech. This repository
4 | also contains a C++ library that can be used for real-time speech synthesis on a single CPU core.
5 |
6 | WaveRNN-Pytorch is a **vocoder** - it converts speech features (i.e. mel spectrograms) into speech audio. One can build a
7 | complete text-to-speech pipeline by using, for example, Tacotron-2 to turn text into speech features, and then use this
8 | vocoder to produce a sound file.
9 |
10 |
11 | # Highlights
12 | * 10-bit quantized wav modeling for higher quality
13 | * Weight pruning for reducing model complexity
14 | * Fast, CPU-only C++ inference library that runs faster than real time on a modern CPU
15 | * Compressed pruned weight format to keep weight files small
16 | * Python bindings for the C++ library
17 | * Can be used with a Tacotron-2 implementation for TTS.
18 |
19 | # Planned
20 | * Real-time inference on modern ARM processors (e.g. inference on a smartphone for high-quality TTS)
21 |
22 | # Audio Samples
23 | * See Wiki
24 |
25 | # Pretrained Checkpoints
26 | * See "model_outputs" directory
27 |
28 | # Requirements
29 | ### Training:
30 | * Python 3
31 | * CUDA >=8.0
32 | * PyTorch >= v1.0
33 | * Python requirements:
34 | >pip install -r requirements.txt
35 | * sudo aptitude install libsoundtouch-dev
36 |
37 | ### C++ library
38 | * cmake, gcc, etc.
39 | * Eigen3 development files
40 | > apt-get install libeigen3-dev
41 | * pybind11 https://github.com/pybind/pybind11
42 |
43 | # Installation
44 | Ensure the above requirements are met.
45 |
46 | ```
47 | git clone https://github.com/geneing/WaveRNN-Pytorch.git
48 | cd WaveRNN-Pytorch
49 | pip3 install -r requirements.txt
50 | ```
51 |
52 | # Build C++ library
53 | ```
54 | cd library
55 | mkdir build
56 | cd build
57 | cmake ../src
58 | make
59 | cp WaveRNNVocoder*.so python_install_directory
60 | ```
61 |
62 | # Usage
63 | ## 1. Adjusting Hyperparameters
64 | Before running scripts, one can adjust hyperparameters in `hparams.py`.
65 |
66 | Some hyperparameters that you might want to adjust:
67 | * `input_type` (best performing ones are currently `bits` and `raw`, see `hparams.py` for more details)
68 | * `batch_size` - depends on your GPU memory. For 8GB memory, you should use batch_size=64
69 | * `save_every_step` (checkpoint saving frequency)
70 | * `evaluate_every_step` (evaluation frequency)
71 | * `seq_len_factor` (sequence length of the training audio; longer sequences use more GPU memory)
72 |
73 | ## 2. Preprocessing
74 |
75 | #### Using TTS preprocessing
76 | If you are planning to use this vocoder together with a TTS network (e.g. Tacotron-2) you should train on exactly the same data.
77 | Each implementation of a TTS network uses a slightly different definition of "mel-spectrogram", so I recommend using the TTS preprocessing.
78 |
79 | This code has been tested with [Tacotron-2](https://github.com/Rayhane-mamah/Tacotron-2) and the M-AILABS dataset. Example:
80 | ```
81 | cd Tacotron-2
82 | python3 preprocess.py --dataset='M-AILABS' --language='en_US' --voice='female' --reader='mary_ann' --merge_books=True --output training_data
83 | ```
84 |
85 | #### Using WaveRNN-Pytorch preprocessing
86 | If you are using the vocoder as a standalone library you can use the native preprocessing.
87 | This function processes raw wav files into corresponding mel-spectrogram and wav files according to the audio processing hyperparameters.
88 |
89 | Example usage:
90 | ```
91 | python3 preprocess.py --output_dir training_data /path/to/my/wav/files
92 | ```
93 | This will process all the `.wav` files in the folder `/path/to/my/wav/files` and save the results in `training_data` (if `--output_dir` is omitted, the default local directory `data_dir` is used).
94 |
95 | ## 3. Training
96 | Start the training process. Checkpoints are stored by default in the local directory `checkpoints`.
97 | The script will automatically save a checkpoint when terminated by `ctrl + c`.
98 |
99 |
100 | Example 1: starting a new model for training from Tacotron-2 data
101 | ```
102 | python3 train.py --dataset Tacotron training_data
103 | ```
104 | `training_data` is the directory containing the processed files.
105 |
106 | Example 2: starting a new model for training from natively preprocessed data
107 | ```
108 | python3 train.py --dataset Audiobooks training_data
109 | ```
110 |
111 | Example 3: Restoring training from a checkpoint
112 | ```
113 | python3 train.py training_data --checkpoint=checkpoints/checkpoint0010000.pth
114 | ```
115 | Evaluation `.wav` files and plots are saved in `checkpoints/eval`.
116 |
117 | ## 4. Converting model for C++ library
118 |
119 | First you need to train the model for at least (hparams.start_prune + hparams.prune_steps) steps (160000 steps with the default settings)
120 | to ensure that the model is properly pruned.
121 |
122 | In order to use the C++ library you need to convert the trained network to the compressed model format.
123 |
124 | ```
125 | python3 convert_model.py --output-dir model_outputs checkpoints/checkpoint_step000400000.pth
126 | ```
127 |
128 | Example 1: Use the python3 interface to the C++ library
129 | ```
130 | import WaveRNNVocoder
131 | import numpy as np
132 |
133 | vocoder=WaveRNNVocoder.Vocoder()
134 | vocoder.loadWeights('model_outputs/model.bin')
135 |
136 | mel = np.load(fname) #make sure that mel.shape[0] == hparams.num_mels
137 | wav = vocoder.melToWav(mel)
138 | ```
139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /audio.py: -------------------------------------------------------------------------------- 1 | import librosa
2 | import librosa.filters
3 | import math
4 | import numpy as np
5 | from scipy import signal
6 | from hparams import hparams
7 | from scipy.io import wavfile
8 |
9 | # r9r9 preprocessing
10 | import lws
11 |
12 |
13 | def load_wav(path):
14 | return librosa.load(path, sr=hparams.sample_rate)[0]
15 |
16 | def save_wav(wav, path):
17 | wav = wav * 32767 / max(0.01, np.max(np.abs(wav)))
18 | wavfile.write(path, hparams.sample_rate, wav.astype(np.int16))
19 |
20 |
21 | def preemphasis(x):
22 | from nnmnkwii.preprocessing import preemphasis
23 | return preemphasis(x, hparams.preemphasis)
24 |
25 |
26 | def inv_preemphasis(x):
27 | from nnmnkwii.preprocessing import inv_preemphasis
28 | return inv_preemphasis(x, hparams.preemphasis)
29 |
30 |
31 | def spectrogram(y):
32 | D = _lws_processor().stft(preemphasis(y)).T
33 | S = _amp_to_db(np.abs(D) ** hparams.magnitude_power) - hparams.ref_level_db
34 | return _normalize(S)
35 |
36 |
37 | def inv_spectrogram(spectrogram):
38 | '''Converts spectrogram to waveform using lws'''
39 | S = _db_to_amp(_denormalize(spectrogram) + hparams.ref_level_db) # Convert back to linear
40 | processor = _lws_processor()
41 | D = processor.run_lws(S.astype(np.float64).T ** (1/hparams.magnitude_power))
42 | y = processor.istft(D).astype(np.float32)
43 | return inv_preemphasis(y)
44 |
45 | def _stft(y):
46 | if hparams.use_lws:
47 | return _lws_processor().stft(y).T
48 | else:
49 | return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=hparams.hop_size, win_length=hparams.win_size, pad_mode='constant')
50 |
51 | # def melspectrogram(y):
52 | # D = _stft(preemphasis(y))
53 | # S = _amp_to_db(_linear_to_mel(np.abs(D)**hparams.magnitude_power)) - hparams.ref_level_db
54 | # if not hparams.allow_clipping_in_normalization:
55 | # assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
56 | # return _normalize(S)
57 |
58 | def melspectrogram(y):
59 | D = _stft(preemphasis(y))
60 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hparams.ref_level_db
61 | return _normalize(S)
62 |
63 | def _lws_processor():
64 | return lws.lws(hparams.win_size, hparams.hop_size, mode="speech")
65 |
66 |
67 | # Conversions:
68 |
69 |
70 | _mel_basis = None
71 |
72 |
73 | def _linear_to_mel(spectrogram):
74 | global _mel_basis
75 | if _mel_basis is None:
76 | _mel_basis = _build_mel_basis()
77 | return np.dot(_mel_basis, spectrogram)
78 |
79 |
80 | def _build_mel_basis():
81 | if hparams.fmax is not None:
82 | assert hparams.fmax <= hparams.sample_rate // 2
83 | return librosa.filters.mel(hparams.sample_rate, hparams.n_fft,
84 | fmin=hparams.fmin, fmax=hparams.fmax,
85 | n_mels=hparams.num_mels)
86 |
87 |
88 | def _amp_to_db(x):
89 | min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
90 | return 20 * np.log10(np.maximum(min_level, x))
91 |
92 |
93 | def _db_to_amp(x):
94 |
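# inverse of _amp_to_db: decibels back to linear amplitude, 10**(x / 20)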
return np.power(10.0, x * 0.05) 95 | 96 | 97 | # def _normalize(S): 98 | # return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1) 99 | # 100 | # 101 | # def _denormalize(S): 102 | # return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db 103 | # 104 | 105 | # def _normalize(S): 106 | # if hparams.allow_clipping_in_normalization: 107 | # if hparams.symmetric_mels: 108 | # return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value, 109 | # -hparams.max_abs_value, hparams.max_abs_value) 110 | # else: 111 | # return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value) 112 | # 113 | # assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0 114 | # if hparams.symmetric_mels: 115 | # return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value 116 | # else: 117 | # return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)) 118 | # 119 | # def _denormalize(D): 120 | # if hparams.allow_clipping_in_normalization: 121 | # if hparams.symmetric_mels: 122 | # return (((np.clip(D, -hparams.max_abs_value, 123 | # hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) 124 | # + hparams.min_level_db) 125 | # else: 126 | # return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db) 127 | # 128 | # if hparams.symmetric_mels: 129 | # return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db) 130 | # else: 131 | # return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db) 132 | 133 | def _normalize(S): 134 | # symmetric mels 135 | return 2 * hparams.max_abs_value * ((S - hparams.min_level_db) / -hparams.min_level_db) - hparams.max_abs_value 136 | 137 | def _denormalize(S): 138 | # symmetric mels 139 | return ((S + hparams.max_abs_value) * -hparams.min_level_db) / (2 * hparams.max_abs_value) + hparams.min_level_db 140 | 141 | 142 | # Fatcord's preprocessing 143 | def quantize(x): 144 | """quantize audio signal 145 | 146 | """ 147 | x = np.clip(x, -1., 1.) 148 | quant = ((x + 1.)/2.) * (2**hparams.bits - 1) 149 | return quant.astype(np.int) 150 | 151 | 152 | # testing 153 | def test_everything(): 154 | wav = np.random.randn(12000,) 155 | mel = melspectrogram(wav) 156 | spec = spectrogram(wav) 157 | quant = quantize(wav) 158 | print(wav.shape, mel.shape, spec.shape, quant.shape) 159 | print(quant.max(), quant.min(), mel.max(), mel.min(), spec.max(), spec.min()) -------------------------------------------------------------------------------- /convert_model.py: -------------------------------------------------------------------------------- 1 | """Convert trained model for libwavernn 2 | 3 | usage: convert_model.py [options] 4 | 5 | options: 6 | --output-dir= Output Directory [default: model_outputs] 7 | -h, --help Show this help message and exit 8 | """ 9 | # --mel= Mel file input for testing. 
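# For reference, a minimal illustrative sketch (the name `decompress_sparse` is
# not part of this script) of how a matrix written by compress()/writeCompressed()
# below can be reconstructed: `weights` holds the surviving values in row-major
# order, and `idx` lists the column index of each retained group of
# hp.sparse_group weights, with 255 marking the end of a row.
import numpy as np

def decompress_sparse(weights, idx, nrows, ncols, group=4):  # group == hp.sparse_group
    W = np.zeros((nrows, ncols), dtype=weights.dtype)
    w = iter(weights)
    row = 0
    for c in idx:
        if c == 255:  # row-end sentinel
            row += 1
        else:  # group c of this row holds `group` consecutive weights
            for k in range(group):
                W[row, c * group + k] = next(w)
    return W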
10 | 11 | from docopt import docopt 12 | from model import * 13 | from hparams import hparams as hp 14 | 15 | import struct 16 | import numpy as np 17 | import scipy as sp 18 | 19 | elSize = 4 #change to 2 for fp16 20 | 21 | def compress(W): 22 | N = W.shape[1] 23 | W_nz = W.copy() 24 | W_nz[W_nz!=0]=1 25 | L = W_nz.reshape([-1, N // hp.sparse_group, hp.sparse_group]) 26 | S = L.max(axis=-1) 27 | #convert to compressed index 28 | #compressed representation has position in each row. "255" denotes row end. 29 | (row,col)=np.nonzero(S) 30 | idx=[] 31 | for i in range(S.shape[0]+1): 32 | idx += list(col[row==i]) 33 | idx += [255] 34 | mask = np.repeat(S, hp.sparse_group, axis=1) 35 | idx = np.asarray(idx, dtype='uint8') 36 | return (W[mask!=0], idx) 37 | 38 | def writeCompressed(f, W): 39 | weights, idx = compress(W) 40 | f.write(struct.pack('@i',weights.size)) 41 | f.write(weights.tobytes(order='C')) 42 | f.write(struct.pack('@i',idx.size)) 43 | f.write(idx.tobytes(order='C')) 44 | return 45 | 46 | 47 | def linear_saver(f, layer): 48 | weight = layer.weight.cpu().detach().numpy() 49 | 50 | bias = layer.bias.cpu().detach().numpy() 51 | nrows, ncols = weight.shape 52 | v = struct.pack('@iii', elSize, nrows, ncols) 53 | f.write(v) 54 | writeCompressed(f, weight) 55 | f.write(bias.tobytes(order='C')) 56 | 57 | def conv1d_saver(f, layer): 58 | weight = layer.weight.cpu().detach().numpy() 59 | out_channels, in_channels, nkernel = weight.shape 60 | v = struct.pack('@iiiii', elSize, not(layer.bias is None), in_channels, out_channels, nkernel) 61 | f.write(v) 62 | f.write(weight.tobytes(order='C')) 63 | if not (layer.bias is None ): 64 | bias = layer.bias.cpu().detach().numpy() 65 | f.write(bias.tobytes(order='C')) 66 | return 67 | 68 | def conv2d_saver(f, layer): 69 | weight = layer.weight.cpu().detach().numpy() 70 | assert(weight.shape[0]==weight.shape[1]==weight.shape[2]==1) #handles only specific type used in WaveRNN 71 | weight = weight.squeeze() 72 | nkernel = weight.shape[0] 73 | 74 | v = struct.pack('@ii', elSize, nkernel) 75 | f.write(v) 76 | f.write(weight.tobytes(order='C')) 77 | return 78 | 79 | def batchnorm1d_saver(f, layer): 80 | 81 | v = struct.pack('@iif', elSize, layer.num_features, layer.eps) 82 | f.write(v) 83 | weight=layer.weight.detach().numpy() 84 | bias=layer.bias.detach().numpy() 85 | running_mean = layer.running_mean.detach().numpy() 86 | running_var = layer.running_var.detach().numpy() 87 | 88 | f.write(weight.tobytes(order='C')) 89 | f.write(bias.tobytes(order='C')) 90 | f.write(running_mean.tobytes(order='C')) 91 | f.write(running_var.tobytes(order='C')) 92 | 93 | return 94 | 95 | def gru_saver(f, layer): 96 | weight_ih_l0 = layer.weight_ih_l0.detach().cpu().numpy() 97 | weight_hh_l0 = layer.weight_hh_l0.detach().cpu().numpy() 98 | bias_ih_l0 = layer.bias_ih_l0.detach().cpu().numpy() 99 | bias_hh_l0 = layer.bias_hh_l0.detach().cpu().numpy() 100 | 101 | W_ir,W_iz,W_in=np.vsplit(weight_ih_l0, 3) 102 | W_hr,W_hz,W_hn=np.vsplit(weight_hh_l0, 3) 103 | 104 | b_ir,b_iz,b_in=np.split(bias_ih_l0, 3) 105 | b_hr,b_hz,b_hn=np.split(bias_hh_l0, 3) 106 | 107 | hidden_size, input_size = W_ir.shape 108 | v = struct.pack('@iii', elSize, hidden_size, input_size) 109 | f.write(v) 110 | writeCompressed(f, W_ir) 111 | writeCompressed(f, W_iz) 112 | writeCompressed(f, W_in) 113 | writeCompressed(f, W_hr) 114 | writeCompressed(f, W_hz) 115 | writeCompressed(f, W_hn) 116 | f.write(b_ir.tobytes(order='C')) 117 | f.write(b_iz.tobytes(order='C')) 118 | f.write(b_in.tobytes(order='C')) 119 | 
f.write(b_hr.tobytes(order='C')) 120 | f.write(b_hz.tobytes(order='C')) 121 | f.write(b_hn.tobytes(order='C')) 122 | return 123 | 124 | def stretch2d_saver(f, layer): 125 | v = struct.pack('@ii', layer.x_scale, layer.y_scale) 126 | f.write(v) 127 | return 128 | 129 | savers = { 'Conv1d':conv1d_saver, 'Conv2d':conv2d_saver, 'BatchNorm1d':batchnorm1d_saver, 'Linear':linear_saver, 'GRU':gru_saver, 'Stretch2d':stretch2d_saver } 130 | layer_enum = { 'Conv1d':1, 'Conv2d':2, 'BatchNorm1d':3, 'Linear':4, 'GRU':5, 'Stretch2d':6 } 131 | 132 | def save_layer(f, layer): 133 | layer_type_name = layer._get_name() 134 | v = struct.pack('@i64s', layer_enum[layer_type_name], layer.__str__().encode() ) 135 | f.write(v) 136 | savers[layer_type_name](f, layer) 137 | return 138 | 139 | def torch_test_gru(model, checkpoint): 140 | x = 1.+1./np.arange(1,513) 141 | h = -3. + 2./np.arange(1,513) 142 | 143 | state=checkpoint['state_dict'] 144 | rnn1 = model.rnn1 145 | 146 | weight_ih_l0 = rnn1.weight_ih_l0.detach().cpu().numpy() 147 | weight_hh_l0 = rnn1.weight_hh_l0.detach().cpu().numpy() 148 | bias_ih_l0 = rnn1.bias_ih_l0.detach().cpu().numpy() 149 | bias_hh_l0 = rnn1.bias_hh_l0.detach().cpu().numpy() 150 | 151 | W_ir,W_iz,W_in=np.vsplit(weight_ih_l0, 3) 152 | W_hr,W_hz,W_hn=np.vsplit(weight_hh_l0, 3) 153 | 154 | b_ir,b_iz,b_in=np.split(bias_ih_l0, 3) 155 | b_hr,b_hz,b_hn=np.split(bias_hh_l0, 3) 156 | 157 | gru_cell = nn.GRUCell(rnn1.input_size, rnn1.hidden_size).cpu() 158 | gru_cell.weight_hh.data = rnn1.weight_hh_l0.cpu().data 159 | gru_cell.weight_ih.data = rnn1.weight_ih_l0.cpu().data 160 | gru_cell.bias_hh.data = rnn1.bias_hh_l0.cpu().data 161 | gru_cell.bias_ih.data = rnn1.bias_ih_l0.cpu().data 162 | 163 | # hx_ref = hx.clone() 164 | # x_ref = x.clone() 165 | #hx_gru = gru_cell(x_ref, hx_ref) 166 | 167 | sigmoid = sp.special.expit 168 | r = sigmoid( np.matmul(W_ir, x).squeeze() + b_ir + np.matmul(W_hr, h).squeeze() + b_hr) 169 | z = sigmoid( np.matmul(W_iz, x).squeeze() + b_iz + np.matmul(W_hz, h).squeeze() + b_hz) 170 | n = np.tanh( np.matmul(W_in, x).squeeze() + b_in + r * (np.matmul(W_hn, h).squeeze() + b_hn)) 171 | hout = (1-z)*n+z*h.squeeze() 172 | print(hout) 173 | 174 | #hx_gru=hx_gru.detach().numpy().squeeze() 175 | #dif = hx_gru-hout 176 | 177 | return 178 | 179 | def torch_test_conv1d( model, checkpoint ): 180 | 181 | #x=np.matmul((1.+1./np.arange(1,81))[:,np.newaxis], (-3 + 2./np.arange(1,101))[np.newaxis,:]) 182 | x=np.matmul((1.+1./np.arange(1,81))[:,np.newaxis], (-3 + 2./np.arange(1,11))[np.newaxis,:]) 183 | 184 | xt=torch.tensor(x[np.newaxis,:,:],dtype=torch.float32) 185 | weight=model.upsample.resnet.conv_in.weight 186 | c = torch.nn.functional.conv1d(torch.tensor(xt).type(torch.FloatTensor), weight) 187 | c1=c.detach().numpy().squeeze() 188 | 189 | 190 | w = weight.detach().numpy() 191 | y = np.zeros([128,6]) 192 | for i1 in range(128): 193 | for i3 in range(6): 194 | y[i1,i3] = (x[:,i3:i3+5]*w[i1,:,:]).sum() 195 | return y 196 | 197 | def torch_test_conv1d_1x( model, checkpoint ): 198 | 199 | #x=np.matmul((1.+1./np.arange(1,81))[:,np.newaxis], (-3 + 2./np.arange(1,101))[np.newaxis,:]) 200 | x=np.matmul((1.+1./np.arange(1,129))[:,np.newaxis], (-3 + 2./np.arange(1,11))[np.newaxis,:]) 201 | 202 | xt=torch.tensor(x[np.newaxis,:,:],dtype=torch.float32) 203 | weight = model.upsample.resnet.layers[0].conv1.weight 204 | c = torch.nn.functional.conv1d(torch.tensor(xt).type(torch.FloatTensor), weight) 205 | c1=c.detach().numpy().squeeze() 206 | 207 | w = weight.detach().numpy() 208 | y = 
np.zeros([128, 10]) 209 | y = np.matmul(w.squeeze(), x) 210 | return y 211 | 212 | def torch_test_conv2d( model, checkpoint ): 213 | layer = model.upsample.up_layers[1] 214 | x=np.matmul((1.+1./np.arange(1,81))[:,np.newaxis], (-3 + 2./np.arange(1,21))[np.newaxis,:]) 215 | 216 | xt=torch.tensor(x[np.newaxis,:,:],dtype=torch.float32) 217 | xt = xt.unsqueeze(1) 218 | c = layer(xt) 219 | 220 | weight = layer.weight 221 | weight=weight.detach().numpy() 222 | assert(weight.shape[0]==weight.shape[1]==weight.shape[2]==1) 223 | 224 | xt=xt.squeeze() 225 | weight = weight.squeeze() 226 | npad = (weight.size-1)/2 227 | 228 | y = np.zeros(xt.shape) 229 | for i in range(xt.shape[0]): 230 | a = np.pad(xt[i,:], (int(npad), int(npad)), mode='constant') 231 | for j in range(xt.shape[1]): 232 | y[i,j] = np.sum(a[j:j+9]*weight) 233 | 234 | x = xt.squeeze() 235 | weight = weight.squeeze() 236 | 237 | 238 | return 239 | 240 | def torch_test_batchnorm1d( model, checkpoint ): 241 | 242 | layer = model.upsample.resnet.layers[0].batch_norm1 243 | x=np.matmul((1.+1./np.arange(1,129))[:,np.newaxis], (-3 + 2./np.arange(1,11))[np.newaxis,:]) 244 | xt=torch.tensor(x[np.newaxis,:,:],dtype=torch.float32) 245 | 246 | weight=layer.weight.detach().numpy() 247 | bias=layer.bias.detach().numpy() 248 | running_mean = layer.running_mean.detach().numpy() 249 | running_var = layer.running_var.detach().numpy() 250 | 251 | x1=xt.detach().numpy().squeeze() 252 | 253 | mean = np.mean(x1, axis=0) 254 | var = np.var(x1, axis=0) 255 | eps = layer.eps 256 | y = ((x1[:,0]-running_mean)/(np.sqrt(running_var+eps)))*weight+bias 257 | #y = ((x1[:,0]-mean[0])/(np.sqrt(var[0]+eps)))*weight+bias 258 | 259 | c = layer(xt) 260 | return 261 | 262 | def save_resnet_block(f, layers): 263 | for l in layers: 264 | save_layer(f, l.conv1) 265 | save_layer(f, l.batch_norm1) 266 | save_layer(f, l.conv2) 267 | save_layer(f, l.batch_norm2) 268 | 269 | def save_resnet( f, model ): 270 | try: 271 | model.upsample.resnet = model.upsample.resnet1 #temp hack 272 | except: pass 273 | 274 | save_layer(f, model.upsample.resnet.conv_in) 275 | save_layer(f, model.upsample.resnet.batch_norm) 276 | save_resnet_block( f, model.upsample.resnet.layers ) #save the resblock stack 277 | save_layer(f, model.upsample.resnet.conv_out) 278 | save_layer(f, model.upsample.resnet_stretch) 279 | return 280 | 281 | def save_upsample(f, model): 282 | for l in model.upsample.up_layers: 283 | save_layer(f, l) 284 | return 285 | 286 | def save_main(f, model): 287 | save_layer(f, model.I) 288 | save_layer(f, model.rnn1) 289 | save_layer(f, model.fc1) 290 | save_layer(f, model.fc3) 291 | return 292 | 293 | if __name__ == "__main__": 294 | args = docopt(__doc__) 295 | print("Command line args:\n", args) 296 | output_path = args["--output-dir"] 297 | # mel_file = args["--mel"] 298 | 299 | device = torch.device("cpu") 300 | 301 | checkpoint_file_name = args[''] 302 | 303 | # build model, create optimizer 304 | model = build_model().to(device) 305 | checkpoint = torch.load(checkpoint_file_name, map_location=device) 306 | model.load_state_dict(checkpoint["state_dict"]) 307 | model = model.eval() 308 | 309 | # torch_test_gru(model, checkpoint) 310 | # torch_test_conv1d(model, checkpoint) 311 | # torch_test_conv1d_1x(model, checkpoint) 312 | # torch_test_conv2d(model, checkpoint) 313 | # torch_test_batchnorm1d(model, checkpoint) 314 | 315 | # mel = np.load(mel_file) 316 | # mel = mel.astype('float32').T 317 | # v = struct.pack('@ii', mel.shape[0], mel.shape[1]) 318 | # with 
open(output_path+'/mel.bin', 'wb') as f: 319 | # f.write(v) 320 | # f.write(mel.tobytes(order='C')) 321 | 322 | with open(output_path+'/model.bin','wb') as f: 323 | v = struct.pack('@iiii', hp.res_blocks, len(hp.upsample_factors), np.prod(hp.upsample_factors), hp.pad) 324 | f.write(v) 325 | save_resnet(f, model) 326 | save_upsample(f, model) 327 | save_main(f, model) 328 | 329 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import os 4 | 5 | import torch 6 | from torch.utils.data import DataLoader, Dataset 7 | from hparams import hparams as hp 8 | from utils import mulaw_quantize, inv_mulaw_quantize 9 | import pickle 10 | import csv 11 | from audio import quantize 12 | 13 | class AudiobookDataset(Dataset): 14 | def __init__(self, data_path): 15 | self.path = os.path.join(data_path, "") 16 | with open(os.path.join(self.path,'dataset_ids.pkl'), 'rb') as f: 17 | self.metadata = pickle.load(f) 18 | self.mel_path = os.path.join(data_path, "mel") 19 | self.wav_path = os.path.join(data_path, "wav") 20 | self.test_path = os.path.join(data_path, "test") 21 | 22 | def __getitem__(self, index): 23 | file = self.metadata[index] 24 | m = np.load(os.path.join(self.mel_path,'{}.npy'.format(file))) 25 | x = np.load(os.path.join(self.wav_path,'{}.npy'.format(file))) 26 | return m, x 27 | 28 | def __len__(self): 29 | return len(self.metadata) 30 | 31 | class TacotronDataset(Dataset): 32 | def __init__(self, data_path): 33 | self.metadata=[] 34 | self.path = os.path.join(data_path, "") 35 | with open(os.path.join(self.path,'train.txt'), 'r', newline='') as f: 36 | csvreader = csv.reader(f, delimiter='|') 37 | for row in csvreader: 38 | self.metadata.append(row) 39 | 40 | self.mel_path = os.path.join(data_path, "mels") 41 | self.wav_path = os.path.join(data_path, "audio") 42 | self.test_path = os.path.join(data_path, "mels") 43 | 44 | def __getitem__(self, index): 45 | entry = self.metadata[index] 46 | m = np.load(os.path.join(self.mel_path, entry[1])).T 47 | wav = np.load(os.path.join(self.wav_path, entry[0])) 48 | 49 | if hp.input_type == 'raw' or hp.input_type=='mixture': 50 | wav = wav.astype(np.float32) 51 | elif hp.input_type == 'mulaw': 52 | wav = mulaw_quantize(wav, hp.mulaw_quantize_channels).astype(np.int) 53 | elif hp.input_type == 'bits': 54 | wav = quantize(wav).astype(np.int) 55 | else: 56 | raise ValueError("hp.input_type {} not recognized".format(hp.input_type)) 57 | return m, wav 58 | 59 | def __len__(self): 60 | return len(self.metadata) 61 | 62 | class Tacotron2Dataset(Dataset): 63 | def __init__(self, data_path): 64 | self.metadata=[] 65 | self.path = os.path.join(data_path, "") 66 | with open(os.path.join(self.path,'train.txt'), 'r', newline='') as f: 67 | csvreader = csv.reader(f, delimiter='|') 68 | for row in csvreader: 69 | self.metadata.append(row) 70 | 71 | self.mel_path = os.path.join(data_path, "mels") 72 | self.wav_path = os.path.join(data_path, "audio") 73 | self.test_path = os.path.join(data_path, "mels") 74 | 75 | def __getitem__(self, index): 76 | entry = self.metadata[index] 77 | m = np.load(os.path.join(self.mel_path, entry[1])).T 78 | wav = np.load(os.path.join(self.wav_path, entry[0])) 79 | 80 | if hp.input_type == 'raw' or hp.input_type=='mixture': 81 | wav = wav.astype(np.float32) 82 | elif hp.input_type == 'mulaw': 83 | wav = mulaw_quantize(wav, hp.mulaw_quantize_channels).astype(np.int) 84 | elif hp.input_type 
== 'bits': 85 | wav = quantize(wav).astype(np.int) 86 | else: 87 | raise ValueError("hp.input_type {} not recognized".format(hp.input_type)) 88 | return m, wav 89 | 90 | def __len__(self): 91 | return len(self.metadata) 92 | 93 | 94 | class MozillaTTS(Dataset): 95 | def __init__(self, data_path): 96 | self.metadata=[] 97 | self.path = os.path.join(data_path, "") 98 | with open(os.path.join(self.path,'tts_metadata.csv'), 'r', newline='') as f: 99 | csvreader = csv.reader(f, delimiter='|') 100 | for row in csvreader: 101 | self.metadata.append(row) 102 | 103 | self.mel_path = os.path.join(data_path, "mel") 104 | self.wav_path = os.path.join(data_path, "audio") 105 | self.test_path = os.path.join(data_path, "mel") 106 | 107 | def __getitem__(self, index): 108 | entry = self.metadata[index] 109 | m = np.load(entry[2].strip()) 110 | wav = np.load(entry[1].strip()) 111 | 112 | if hp.input_type == 'raw' or hp.input_type=='mixture': 113 | wav = wav.astype(np.float32) 114 | elif hp.input_type == 'mulaw': 115 | wav = mulaw_quantize(wav, hp.mulaw_quantize_channels).astype(np.int) 116 | elif hp.input_type == 'bits': 117 | wav = quantize(wav).astype(np.int) 118 | else: 119 | raise ValueError("hp.input_type {} not recognized".format(hp.input_type)) 120 | return m, wav 121 | 122 | def __len__(self): 123 | return len(self.metadata) 124 | 125 | 126 | 127 | def raw_collate(batch) : 128 | """collate function used for raw wav forms, such as using beta/guassian/mixture of logistic 129 | """ 130 | 131 | pad = 2 132 | mel_win = hp.seq_len // hp.hop_size + 2 * pad 133 | max_offsets = [x[0].shape[-1] - (mel_win + 2 * pad) for x in batch] 134 | mel_offsets = [np.random.randint(0, offset) for offset in max_offsets] 135 | sig_offsets = [(offset + pad) * hp.hop_size for offset in mel_offsets] 136 | 137 | mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win] \ 138 | for i, x in enumerate(batch)] 139 | 140 | coarse = [x[1][sig_offsets[i]:sig_offsets[i] + hp.seq_len + 1] \ 141 | for i, x in enumerate(batch)] 142 | 143 | mels = np.stack(mels).astype(np.float32) 144 | coarse = np.stack(coarse).astype(np.float32) 145 | 146 | mels = torch.FloatTensor(mels) 147 | coarse = torch.FloatTensor(coarse) 148 | 149 | x_input = coarse[:,:hp.seq_len] 150 | 151 | y_coarse = coarse[:, 1:] 152 | 153 | return x_input, mels, y_coarse 154 | 155 | 156 | 157 | def discrete_collate(batch) : 158 | """collate function used for discrete wav output, such as 9-bit, mulaw-discrete, etc. 159 | """ 160 | 161 | pad = 2 162 | mel_win = hp.seq_len // hp.hop_size + 2 * pad 163 | max_offsets = [x[0].shape[-1] - (mel_win + 2 * pad) for x in batch] 164 | mel_offsets = [np.random.randint(0, offset) for offset in max_offsets] 165 | sig_offsets = [(offset + pad) * hp.hop_size for offset in mel_offsets] 166 | 167 | mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win] \ 168 | for i, x in enumerate(batch)] 169 | 170 | coarse = [x[1][sig_offsets[i]:sig_offsets[i] + hp.seq_len + 1] \ 171 | for i, x in enumerate(batch)] 172 | 173 | mels = np.stack(mels).astype(np.float32) 174 | coarse = np.stack(coarse).astype(np.int64) 175 | 176 | mels = torch.FloatTensor(mels) 177 | coarse = torch.LongTensor(coarse) 178 | if hp.input_type == 'bits': 179 | x_input = 2 * coarse[:, :hp.seq_len].float() / (2**hp.bits - 1.) - 1. 
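        # (the 'bits' branch above rescales integer levels [0, 2**hp.bits - 1]
        # to floating-point network inputs in [-1, 1])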
180 | elif hp.input_type == 'mulaw':
181 | x_input = inv_mulaw_quantize(coarse[:, :hp.seq_len], hp.mulaw_quantize_channels)
182 |
183 | y_coarse = coarse[:, 1:]
184 |
185 | return x_input, mels, y_coarse
186 |
187 |
188 | def no_test_raw_collate():
189 | import matplotlib.pyplot as plt
190 | from test_utils import plot, plot_spec
191 | data_id_path = "data_dir/"
192 | data_path = "data_dir/"
193 | print(hp.seq_len)
194 |
195 | with open('{}dataset_ids.pkl'.format(data_id_path), 'rb') as f:
196 | dataset_ids = pickle.load(f)
197 | dataset = AudiobookDataset(data_path)
198 | print(len(dataset))
199 |
200 | data_loader = DataLoader(dataset, collate_fn=raw_collate, batch_size=32,
201 | num_workers=0, shuffle=True)
202 |
203 | x, m, y = next(iter(data_loader))
204 | print(x.shape, m.shape, y.shape)
205 | plot(x.numpy()[0])
206 | plot(y.numpy()[0])
207 | plot_spec(m.numpy()[0])
208 |
209 |
210 | def test_discrete_collate():
211 | import matplotlib.pyplot as plt
212 | from test_utils import plot, plot_spec
213 | data_id_path = "data_dir/"
214 | data_path = "data_dir/"
215 | print(hp.seq_len)
216 |
217 | with open('{}dataset_ids.pkl'.format(data_id_path), 'rb') as f:
218 | dataset_ids = pickle.load(f)
219 | dataset = AudiobookDataset(data_path)
220 | print(len(dataset))
221 |
222 | data_loader = DataLoader(dataset, collate_fn=discrete_collate, batch_size=32,
223 | num_workers=0, shuffle=True)
224 |
225 | x, m, y = next(iter(data_loader))
226 | print(x.shape, m.shape, y.shape)
227 | plot(x.numpy()[0])
228 | plot(y.numpy()[0])
229 | plot_spec(m.numpy()[0])
230 | 231 | 232 |
233 | def no_test_dataset():
234 | data_id_path = "data_dir/"
235 | data_path = "data_dir/"
236 | print(hp.seq_len)
237 | dataset = AudiobookDataset(data_path)
238 | -------------------------------------------------------------------------------- /distributions.py: -------------------------------------------------------------------------------- 1 | import math
2 | import numpy as np
3 |
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from torch.distributions import Beta, Normal
8 | from hparams import hparams as hp
9 |
10 | def sample_from_beta_dist(y_hat):
11 | """
12 | y_hat (batch_size x seq_len x 2):
13 |
14 | """
15 | # take exponential to ensure positive parameters
16 | loc_y = y_hat.exp()
17 | alpha = loc_y[:,:,0].unsqueeze(-1)
18 | beta = loc_y[:,:,1].unsqueeze(-1)
19 | dist = Beta(alpha, beta)
20 | sample = dist.sample()
21 | # rescale sample from [0,1] to [-1, 1]
22 | sample = 2.0*sample-1.0
23 | return sample
24 |
25 |
26 | def beta_mle_loss(y_hat, y, reduce=True):
27 | """y_hat (batch_size x seq_len x 2)
28 | y (batch_size x seq_len x 1)
29 |
30 | """
31 | # take exponential to ensure positive parameters
32 | loc_y = y_hat.exp()
33 | alpha = loc_y[:,:,0].unsqueeze(-1)
34 | beta = loc_y[:,:,1].unsqueeze(-1)
35 | dist = Beta(alpha, beta)
36 | # rescale y to be between 0 and 1
37 | y = (y + 1.0)/2.0
38 | # note that we will get inf loss if y == 0 or 1.0 exactly, so we will clip it slightly just in case
39 | y = torch.clamp(y, 1e-5, 0.99999)
40 | # compute logprob
41 | loss = -dist.log_prob(y).squeeze(-1)
42 | if reduce:
43 | return loss.mean()
44 | else:
45 | return loss
46 |
47 |
48 | def log_sum_exp(x):
49 | """ numerically stable log_sum_exp implementation that prevents overflow """
50 | # TF ordering
51 | axis = len(x.size()) - 1
52 | m, _ = torch.max(x, dim=axis)
53 | m2, _ = torch.max(x, dim=axis, keepdim=True)
54 | return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
55 | 56 | 57 | def
discretized_mix_logistic_loss(y_hat, y, num_classes=256, 58 | log_scale_min=hp.log_scale_min, reduce=True): 59 | """Discretized mixture of logistic distributions loss 60 | 61 | Note that it is assumed that input is scaled to [-1, 1]. 62 | 63 | Args: 64 | y_hat (Tensor): Predicted output (B x T x C) 65 | y (Tensor): Target (B x T x 1). 66 | num_classes (int): Number of classes 67 | log_scale_min (float): Log scale minimum value 68 | reduce (bool): If True, the losses are averaged or summed for each 69 | minibatch. 70 | 71 | Returns 72 | Tensor: loss 73 | """ 74 | y_hat = y_hat.permute(0,2,1) 75 | assert y_hat.dim() == 3 76 | assert y_hat.size(1) % 3 == 0 77 | nr_mix = y_hat.size(1) // 3 78 | 79 | # (B x T x C) 80 | y_hat = y_hat.transpose(1, 2) 81 | 82 | # unpack parameters. (B, T, num_mixtures) x 3 83 | logit_probs = y_hat[:, :, :nr_mix] 84 | means = y_hat[:, :, nr_mix:2 * nr_mix] 85 | log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min) 86 | 87 | # B x T x 1 -> B x T x num_mixtures 88 | y = y.expand_as(means) 89 | 90 | centered_y = y - means 91 | inv_stdv = torch.exp(-log_scales) 92 | plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1)) 93 | cdf_plus = torch.sigmoid(plus_in) 94 | min_in = inv_stdv * (centered_y - 1. / (num_classes - 1)) 95 | cdf_min = torch.sigmoid(min_in) 96 | 97 | # log probability for edge case of 0 (before scaling) 98 | # equivalent: torch.log(F.sigmoid(plus_in)) 99 | log_cdf_plus = plus_in - F.softplus(plus_in) 100 | 101 | # log probability for edge case of 255 (before scaling) 102 | # equivalent: (1 - F.sigmoid(min_in)).log() 103 | log_one_minus_cdf_min = -F.softplus(min_in) 104 | 105 | # probability for all other cases 106 | cdf_delta = cdf_plus - cdf_min 107 | 108 | mid_in = inv_stdv * centered_y 109 | # log probability in the center of the bin, to be used in extreme cases 110 | # (not actually used in our code) 111 | log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in) 112 | 113 | # tf equivalent 114 | """ 115 | log_probs = tf.where(x < -0.999, log_cdf_plus, 116 | tf.where(x > 0.999, log_one_minus_cdf_min, 117 | tf.where(cdf_delta > 1e-5, 118 | tf.log(tf.maximum(cdf_delta, 1e-12)), 119 | log_pdf_mid - np.log(127.5)))) 120 | """ 121 | # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value 122 | # for num_classes=65536 case? 1e-7? not sure.. 123 | inner_inner_cond = (cdf_delta > 1e-5).float() 124 | 125 | inner_inner_out = inner_inner_cond * \ 126 | torch.log(torch.clamp(cdf_delta, min=1e-12)) + \ 127 | (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2)) 128 | inner_cond = (y > 0.999).float() 129 | inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out 130 | cond = (y < -0.999).float() 131 | log_probs = cond * log_cdf_plus + (1. 
- cond) * inner_out 132 | 133 | log_probs = log_probs + F.log_softmax(logit_probs, -1) 134 | 135 | if reduce: 136 | return -torch.sum(log_sum_exp(log_probs)) 137 | else: 138 | return -log_sum_exp(log_probs).unsqueeze(-1) 139 | 140 | 141 | def to_one_hot(tensor, n, fill_with=1.): 142 | # we perform one hot encore with respect to the last axis 143 | one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_() 144 | if tensor.is_cuda: 145 | one_hot = one_hot.cuda() 146 | one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) 147 | return one_hot 148 | 149 | 150 | def sample_from_discretized_mix_logistic(y, log_scale_min=hp.log_scale_min): 151 | """ 152 | Sample from discretized mixture of logistic distributions 153 | 154 | Args: 155 | y (Tensor): B x C x T 156 | log_scale_min (float): Log scale minimum value 157 | 158 | Returns: 159 | Tensor: sample in range of [-1, 1]. 160 | """ 161 | assert y.size(1) % 3 == 0 162 | nr_mix = y.size(1) // 3 163 | 164 | # B x T x C 165 | y = y.transpose(1, 2) 166 | logit_probs = y[:, :, :nr_mix] 167 | 168 | # sample mixture indicator from softmax 169 | temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) 170 | temp = logit_probs.data - torch.log(- torch.log(temp)) 171 | _, argmax = temp.max(dim=-1) 172 | 173 | # (B, T) -> (B, T, nr_mix) 174 | one_hot = to_one_hot(argmax, nr_mix) 175 | # select logistic parameters 176 | means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1) 177 | log_scales = torch.clamp(torch.sum( 178 | y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min) 179 | # sample from logistic & clip to interval 180 | # we don't actually round to the nearest 8bit value when sampling 181 | u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) 182 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u)) 183 | 184 | x = torch.clamp(torch.clamp(x, min=-1.), max=1.) 185 | 186 | return x 187 | 188 | 189 | # add gaussian from clarinet implementation:https://raw.githubusercontent.com/ksw0306/ClariNet/master/loss.py 190 | def gaussian_loss(y_hat, y, log_std_min=-7.0, reduce=True): 191 | """y_hat (batch_size x seq_len x 2) 192 | y (batch_size x seq_len x 1) 193 | """ 194 | assert y_hat.dim() == 3 195 | assert y_hat.size(2) == 2 196 | 197 | mean = y_hat[:, :, :1] 198 | log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min) 199 | 200 | log_probs = -0.5 * (- math.log(2.0 * math.pi) - 2. 
* log_std - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std))) 201 | 202 | if reduce: 203 | return log_probs.squeeze().mean() 204 | else: 205 | return log_probs.squeeze() 206 | 207 | 208 | def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.): 209 | """y_hat (batch_size x seq_len x 2) 210 | y (batch_size x seq_len x 1) 211 | """ 212 | assert y_hat.size(2) == 2 213 | 214 | mean = y_hat[:, :, :1] 215 | log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min) 216 | dist = Normal(mean, torch.exp(log_std)) 217 | sample = dist.sample() 218 | sample = torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor) 219 | del dist 220 | return sample 221 | 222 | 223 | 224 | 225 | 226 | def test_gaussian(): 227 | 228 | y_hat = torch.rand(16, 120, 2) 229 | y_true = torch.rand(16, 120, 1) 230 | out = sample_from_gaussian(y_hat) 231 | loss = gaussian_loss(y_hat, y_true) 232 | loss_mean = loss.mean() 233 | print(out.shape, loss.shape, loss_mean.item()) -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.2-cudnn7-devel-ubuntu18.04 2 | 3 | ARG PYTHON_VERSION=3.6 4 | ENV TZ=America/Los_Angeles 5 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 6 | 7 | RUN apt-get update && apt-get install -y --no-install-recommends \ 8 | build-essential \ 9 | cmake ccache \ 10 | git \ 11 | curl \ 12 | vim \ 13 | ca-certificates \ 14 | libjpeg-dev libeigen3-dev libnuma-dev libleveldb-dev liblmdb-dev libopencv-dev libsnappy-dev \ 15 | libpng-dev &&\ 16 | rm -rf /var/lib/apt/lists/* 17 | 18 | RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 19 | chmod +x ~/miniconda.sh && \ 20 | ~/miniconda.sh -b -p /opt/conda && \ 21 | rm ~/miniconda.sh && \ 22 | /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy matplotlib ipython mkl mkl-include cython typing && \ 23 | /opt/conda/bin/conda install -c mingfeima mkldnn && \ 24 | /opt/conda/bin/conda install -y -c pytorch magma-cuda92 25 | #/opt/conda/bin/conda install -c pytorch pytorch 26 | #/opt/conda/bin/conda clean -ya 27 | 28 | RUN /opt/conda/bin/conda install -c anaconda tensorflow-gpu 29 | 30 | ENV PATH /opt/conda/bin:$PATH 31 | # This must be done before pip so that requirements.txt is available 32 | WORKDIR /opt/pytorch 33 | RUN git clone --recursive https://github.com/pytorch/pytorch.git pytorch 34 | WORKDIR /opt/pytorch/pytorch 35 | RUN git checkout v1.0.0 36 | RUN git submodule update --init 37 | 38 | RUN TORCH_CUDA_ARCH_LIST="3.0;3.5;5.2;6.0;6.1" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \ 39 | CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \ 40 | #python setup.py build && \ 41 | pip install -v . 42 | 43 | WORKDIR /opt/pytorch 44 | RUN git clone https://github.com/pytorch/vision.git && cd vision && pip install -v . 
45 | 46 | RUN pip install -U docopts nnmnkwii>=0.0.11 tqdm tensorboardX keras scikit-learn lws librosa 47 | 48 | WORKDIR /workspace 49 | RUN chmod -R a+w /workspace 50 | #RUN git clone --recursive https://github.com/geneing/WaveRNN-Pytorch.git 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /docker/install_docker: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | sudo apt-get remove docker docker-engine docker.io containerd runc 5 | sudo apt-get update 6 | 7 | sudo apt-get -y install \ 8 | apt-transport-https \ 9 | ca-certificates \ 10 | curl \ 11 | gnupg2 \ 12 | software-properties-common 13 | 14 | curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - 15 | 16 | sudo apt-key fingerprint 0EBFCD88 17 | 18 | sudo add-apt-repository \ 19 | "deb [arch=amd64] https://download.docker.com/linux/ubuntu \ 20 | $(lsb_release -cs) \ 21 | stable" 22 | 23 | sudo apt-get update 24 | 25 | sudo apt-get -y install docker-ce 26 | 27 | #install nvidia-docker 28 | docker volume ls -q -f driver=nvidia-docker | xargs -r -I{} -n1 docker ps -q -a -f volume={} | xargs -r docker rm -f 29 | sudo apt-get purge -y nvidia-docker 30 | 31 | # Add the package repositories 32 | curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | \ 33 | sudo apt-key add - 34 | distribution=$(. /etc/os-release;echo $ID$VERSION_ID) 35 | curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | \ 36 | sudo tee /etc/apt/sources.list.d/nvidia-docker.list 37 | sudo apt-get update 38 | 39 | # Install nvidia-docker2 and reload the Docker daemon configuration 40 | sudo apt-get install -y nvidia-docker2 41 | sudo pkill -SIGHUP dockerd 42 | 43 | # Test nvidia-smi with the latest official CUDA image 44 | docker run --runtime=nvidia --rm nvidia/cuda:9.2-base nvidia-smi 45 | -------------------------------------------------------------------------------- /docker/startscript: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo $1 4 | 5 | SD=$PWD/../../ 6 | echo $SD 7 | 8 | CUR=$PWD 9 | echo $CUR 10 | cp ../runscript $CUR/runscript 11 | 12 | docker run --runtime=nvidia -v $CUR:/workspace/ -v $SD/../TrainingData:/workspace/TrainingData/:ro -v $CUR/output:/output/ -e BRANCH=$1 -ti dl_ubuntu /bin/bash -c /workspace/runscript -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | # Default hyperparameters: 4 | hparams = tf.contrib.training.HParams( 5 | name="WaveRNN", 6 | num_workers=32, 7 | # Input type: 8 | # 1. raw [-1, 1] 9 | # 2. mixture [-1, 1] 10 | # 3. bits [0, 512] 11 | # 4. mulaw[0, mulaw_quantize_channels] 12 | # 13 | input_type='bits', 14 | # 15 | # distribution type, currently supports only 'beta' and 'mixture' 16 | distribution='beta', # or "mixture" 17 | log_scale_min=-32.23619130191664, # = float(np.log(1e-7)) 18 | quantize_channels=65536, # quantize channel used for compute loss for mixture of logistics 19 | # 20 | # for Fatcord's original 9 bit audio, specify the audio bit rate. Note this corresponds to network output 21 | # of size 2**bits, so 9 bits would be 512 output, etc. 22 | bits=10, 23 | # for mu-law 24 | mulaw_quantize_channels=512, 25 | # note: r9r9's deepvoice3 preprocessing is used instead of Fatchord's original. 
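# Example: with input_type='bits' and bits=10 the wav samples are quantized to
# integer levels [0, 2**10 - 1] = [0, 1023], i.e. a network output of size 1024
# (the "[0, 512]" range above corresponds to the original 9-bit configuration).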
26 | # --------------
27 | # audio processing parameters
28 | num_mels=80,
29 | fmin=95,
30 | fmax=7600,
31 | n_fft=2048,
32 | hop_size=200,
33 | win_size=800,
34 | sample_rate=16000,
35 |
36 | min_level_db=-100,
37 | ref_level_db=20,
38 | rescaling=False,
39 | rescaling_max=0.999,
40 |
41 | #Mel and Linear spectrograms normalization/scaling and clipping
42 | signal_normalization = True, #Whether to normalize mel spectrograms to some predefined range (following below parameters)
43 | allow_clipping_in_normalization = True, #Only relevant if signal_normalization = True
44 | symmetric_mels = True, #Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, faster and cleaner convergence)
45 | max_abs_value = 4., #max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (must not be too big, to avoid gradient explosion)
46 |
47 | #Contribution by @begeekmyfriend
48 | #Spectrogram Pre-Emphasis (Lfilter: reduces spectrogram noise and helps model certitude levels. Also allows for better G&L phase reconstruction)
49 | preemphasize = False, #whether to apply filter
50 | preemphasis = 0.97, #filter coefficient.
51 |
52 | magnitude_power=2., #The power of the spectrogram magnitude (1. for energy, 2. for power)
53 |
54 | # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
55 | # It's preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
56 | # Does not work if n_fft is not a multiple of hop_size!!
57 | use_lws=False, #Only set to True if using WaveNet; no difference in performance is observed in either case.
58 | silence_threshold=2, #silence threshold used for sound trimming for wavenet preprocessing
59 | 60 |
61 | # ----------------
62 | #
63 | # ----------------
64 | # model parameters
65 | rnn_dims=256,
66 | fc_dims=128,
67 | pad=2,
68 | # note upsample factors must multiply out to be equal to hop_size, so adjust
69 | # if necessary (i.e 4 x 5 x 10 = 200)
70 | upsample_factors=(4, 5, 10),
71 | compute_dims=64,
72 | res_out_dims=32*2, #aux output is fed into 2 downstream nets
73 | res_blocks=3,
74 | # ----------------
75 | #
76 | # ----------------
77 | # training parameters
78 | batch_size=128,
79 | nepochs=5000,
80 | save_every_step=10000,
81 | evaluate_every_step=10000,
82 | # seq_len_factor can be adjusted to increase training sequence length (will increase GPU usage)
83 | seq_len_factor=7,
84 |
85 | grad_norm=10,
86 | # learning rate parameters
87 | initial_learning_rate=1e-3,
88 | lr_schedule_type='noam', # or 'step'
89 |
90 | # for step learning rate schedule
91 | step_gamma=0.5,
92 | lr_step_interval=15000,
93 |
94 | # sparsification
95 | start_prune=80000,
96 | prune_steps=80000, # 20000
97 | sparsity_target=0.90,
98 | sparsity_target_rnn=0.90,
99 | sparse_group=4,
100 |
101 | adam_beta1=0.9,
102 | adam_beta2=0.999,
103 | adam_eps=1e-8,
104 | amsgrad=False,
105 | weight_decay=0, #1e-5,
106 | fix_learning_rate=None,
107 | # modify if one wants to use a fixed learning rate, else set to None to use noam learning rate
108 | # -----------------
109 | batch_size_gen=32,
110 | )
111 |
112 | hparams.seq_len = hparams.seq_len_factor * hparams.hop_size
113 |
114 | # for noam learning rate schedule
115 | hparams.noam_warm_up_steps = 2000 * (hparams.batch_size // 16)
116 | -------------------------------------------------------------------------------- /library/src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required
(VERSION 3.0)
2 |
3 | project(WaveRNN)
4 |
5 | set(CMAKE_CXX_STANDARD 11)
6 |
7 | find_package (Eigen3 3.3 REQUIRED NO_MODULE)
8 | include_directories(${EIGEN3_INCLUDE_DIRS})
9 |
10 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O2 -ffast-math -march=native")
11 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -ffast-math -march=native")
12 |
13 | add_executable (vocoder vocoder.cpp)
14 | add_library(wavernn wavernn.cpp net_impl.cpp)
15 | target_link_libraries(vocoder wavernn Eigen3::Eigen)
16 |
17 | add_subdirectory(pybind11)
18 | pybind11_add_module(WaveRNNVocoder WaveRNNVocoder.cpp wavernn.cpp net_impl.cpp)
19 | 20 | -------------------------------------------------------------------------------- /library/src/WaveRNNVocoder.cpp: -------------------------------------------------------------------------------- 1 |
2 | #include <pybind11/pybind11.h>
3 | #include <pybind11/eigen.h>
4 |
5 | #include <string>
6 |
7 | #include <Eigen/Dense>
8 | #include "net_impl.h"
9 |
10 | namespace py = pybind11;
11 |
12 | typedef Matrixf MatrixPy;
13 |
14 | typedef MatrixPy::Scalar Scalar;
15 | constexpr bool rowMajor = MatrixPy::Flags & Eigen::RowMajorBit;
16 |
17 | class Vocoder {
18 | Model model;
19 | bool isLoaded;
20 | public:
21 | Vocoder() { isLoaded = false; }
22 | void loadWeights( const std::string& fileName ){
23 | FILE* fd = fopen(fileName.c_str(), "rb");
24 | if( not fd ){
25 | throw std::runtime_error("Cannot open file.");
26 | }
27 | model.loadNext(fd);
28 | isLoaded = true;
29 | }
30 |
31 | Vectorf melToWav( Eigen::Ref<MatrixPy> mels ){
32 |
33 | if( not isLoaded ){
34 | throw std::runtime_error("Model hasn't been loaded. Call loadWeights first.");
35 | }
36 | return model.apply(mels);
37 | }
38 |
39 | };
40 |
41 | PYBIND11_MODULE(WaveRNNVocoder, m){
42 | m.doc() = "WaveRNN Vocoder";
43 |
44 | py::class_<MatrixPy>( m, "Matrix", py::buffer_protocol() )
45 | .def("__init__", [](MatrixPy &m, py::buffer b) {
46 | typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides;
47 |
48 | /* Request a buffer descriptor from Python */
49 | py::buffer_info info = b.request();
50 |
51 | /* Some sanity checks ... */
52 | if (info.format != py::format_descriptor<Scalar>::format())
53 | throw std::runtime_error("Incompatible format: expected a float32 array!");
54 |
55 | if (info.ndim != 2)
56 | throw std::runtime_error("Incompatible buffer dimension!");
57 |
58 | auto strides = Strides(
59 | info.strides[rowMajor ? 0 : 1] / (py::ssize_t)sizeof(Scalar),
60 | info.strides[rowMajor ?
1 : 0] / (py::ssize_t)sizeof(Scalar));
61 |
62 | auto map = Eigen::Map<MatrixPy, 0, Strides>(
63 | static_cast<Scalar *>(info.ptr), info.shape[0], info.shape[1], strides);
64 |
65 | new (&m) MatrixPy(map);
66 | });
67 |
68 | py::class_<Vocoder>( m, "Vocoder")
69 | .def(py::init<>())
70 | .def("loadWeights", &Vocoder::loadWeights )
71 | .def("melToWav", &Vocoder::melToWav )
72 | ;
73 | }
74 | -------------------------------------------------------------------------------- /library/src/cxxopts.hpp: -------------------------------------------------------------------------------- 1 | /*
2 |
3 | Copyright (c) 2014, 2015, 2016, 2017 Jarryd Beck
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
23 | */
24 |
25 | #ifndef CXXOPTS_HPP_INCLUDED
26 | #define CXXOPTS_HPP_INCLUDED
27 |
28 | #include <cctype>
29 | #include <cstring>
30 | #include <exception>
31 | #include <iostream>
32 | #include <map>
33 | #include <memory>
34 | #include <regex>
35 | #include <sstream>
36 | #include <string>
37 | #include <unordered_map>
38 | #include <unordered_set>
39 | #include <vector>
40 |
41 | #ifdef __cpp_lib_optional
42 | #include <optional>
43 | #define CXXOPTS_HAS_OPTIONAL
44 | #endif
45 |
46 | #define CXXOPTS__VERSION_MAJOR 2
47 | #define CXXOPTS__VERSION_MINOR 2
48 | #define CXXOPTS__VERSION_PATCH 0
49 |
50 | namespace cxxopts
51 | {
52 | static constexpr struct {
53 | uint8_t major, minor, patch;
54 | } version = {
55 | CXXOPTS__VERSION_MAJOR,
56 | CXXOPTS__VERSION_MINOR,
57 | CXXOPTS__VERSION_PATCH
58 | };
59 | }
60 |
61 | //when we ask cxxopts to use Unicode, help strings are processed using ICU,
62 | //which results in the correct lengths being computed for strings when they
63 | //are formatted for the help output
64 | //it is necessary to make sure that <unicode/unistr.h> can be found by the
65 | //compiler, and that icu-uc is linked in to the binary.
66 | 67 | #ifdef CXXOPTS_USE_UNICODE 68 | #include 69 | 70 | namespace cxxopts 71 | { 72 | typedef icu::UnicodeString String; 73 | 74 | inline 75 | String 76 | toLocalString(std::string s) 77 | { 78 | return icu::UnicodeString::fromUTF8(std::move(s)); 79 | } 80 | 81 | class UnicodeStringIterator : public 82 | std::iterator 83 | { 84 | public: 85 | 86 | UnicodeStringIterator(const icu::UnicodeString* string, int32_t pos) 87 | : s(string) 88 | , i(pos) 89 | { 90 | } 91 | 92 | value_type 93 | operator*() const 94 | { 95 | return s->char32At(i); 96 | } 97 | 98 | bool 99 | operator==(const UnicodeStringIterator& rhs) const 100 | { 101 | return s == rhs.s && i == rhs.i; 102 | } 103 | 104 | bool 105 | operator!=(const UnicodeStringIterator& rhs) const 106 | { 107 | return !(*this == rhs); 108 | } 109 | 110 | UnicodeStringIterator& 111 | operator++() 112 | { 113 | ++i; 114 | return *this; 115 | } 116 | 117 | UnicodeStringIterator 118 | operator+(int32_t v) 119 | { 120 | return UnicodeStringIterator(s, i + v); 121 | } 122 | 123 | private: 124 | const icu::UnicodeString* s; 125 | int32_t i; 126 | }; 127 | 128 | inline 129 | String& 130 | stringAppend(String&s, String a) 131 | { 132 | return s.append(std::move(a)); 133 | } 134 | 135 | inline 136 | String& 137 | stringAppend(String& s, int n, UChar32 c) 138 | { 139 | for (int i = 0; i != n; ++i) 140 | { 141 | s.append(c); 142 | } 143 | 144 | return s; 145 | } 146 | 147 | template 148 | String& 149 | stringAppend(String& s, Iterator begin, Iterator end) 150 | { 151 | while (begin != end) 152 | { 153 | s.append(*begin); 154 | ++begin; 155 | } 156 | 157 | return s; 158 | } 159 | 160 | inline 161 | size_t 162 | stringLength(const String& s) 163 | { 164 | return s.length(); 165 | } 166 | 167 | inline 168 | std::string 169 | toUTF8String(const String& s) 170 | { 171 | std::string result; 172 | s.toUTF8String(result); 173 | 174 | return result; 175 | } 176 | 177 | inline 178 | bool 179 | empty(const String& s) 180 | { 181 | return s.isEmpty(); 182 | } 183 | } 184 | 185 | namespace std 186 | { 187 | inline 188 | cxxopts::UnicodeStringIterator 189 | begin(const icu::UnicodeString& s) 190 | { 191 | return cxxopts::UnicodeStringIterator(&s, 0); 192 | } 193 | 194 | inline 195 | cxxopts::UnicodeStringIterator 196 | end(const icu::UnicodeString& s) 197 | { 198 | return cxxopts::UnicodeStringIterator(&s, s.length()); 199 | } 200 | } 201 | 202 | //ifdef CXXOPTS_USE_UNICODE 203 | #else 204 | 205 | namespace cxxopts 206 | { 207 | typedef std::string String; 208 | 209 | template 210 | T 211 | toLocalString(T&& t) 212 | { 213 | return std::forward(t); 214 | } 215 | 216 | inline 217 | size_t 218 | stringLength(const String& s) 219 | { 220 | return s.length(); 221 | } 222 | 223 | inline 224 | String& 225 | stringAppend(String&s, String a) 226 | { 227 | return s.append(std::move(a)); 228 | } 229 | 230 | inline 231 | String& 232 | stringAppend(String& s, size_t n, char c) 233 | { 234 | return s.append(n, c); 235 | } 236 | 237 | template 238 | String& 239 | stringAppend(String& s, Iterator begin, Iterator end) 240 | { 241 | return s.append(begin, end); 242 | } 243 | 244 | template 245 | std::string 246 | toUTF8String(T&& t) 247 | { 248 | return std::forward(t); 249 | } 250 | 251 | inline 252 | bool 253 | empty(const std::string& s) 254 | { 255 | return s.empty(); 256 | } 257 | } 258 | 259 | //ifdef CXXOPTS_USE_UNICODE 260 | #endif 261 | 262 | namespace cxxopts 263 | { 264 | namespace 265 | { 266 | #ifdef _WIN32 267 | const std::string LQUOTE("\'"); 268 | const std::string 
RQUOTE("\'"); 269 | #else 270 | const std::string LQUOTE("‘"); 271 | const std::string RQUOTE("’"); 272 | #endif 273 | } 274 | 275 | class Value : public std::enable_shared_from_this 276 | { 277 | public: 278 | 279 | virtual ~Value() = default; 280 | 281 | virtual 282 | std::shared_ptr 283 | clone() const = 0; 284 | 285 | virtual void 286 | parse(const std::string& text) const = 0; 287 | 288 | virtual void 289 | parse() const = 0; 290 | 291 | virtual bool 292 | has_default() const = 0; 293 | 294 | virtual bool 295 | is_container() const = 0; 296 | 297 | virtual bool 298 | has_implicit() const = 0; 299 | 300 | virtual std::string 301 | get_default_value() const = 0; 302 | 303 | virtual std::string 304 | get_implicit_value() const = 0; 305 | 306 | virtual std::shared_ptr 307 | default_value(const std::string& value) = 0; 308 | 309 | virtual std::shared_ptr 310 | implicit_value(const std::string& value) = 0; 311 | 312 | virtual bool 313 | is_boolean() const = 0; 314 | }; 315 | 316 | class OptionException : public std::exception 317 | { 318 | public: 319 | OptionException(const std::string& message) 320 | : m_message(message) 321 | { 322 | } 323 | 324 | virtual const char* 325 | what() const noexcept 326 | { 327 | return m_message.c_str(); 328 | } 329 | 330 | private: 331 | std::string m_message; 332 | }; 333 | 334 | class OptionSpecException : public OptionException 335 | { 336 | public: 337 | 338 | OptionSpecException(const std::string& message) 339 | : OptionException(message) 340 | { 341 | } 342 | }; 343 | 344 | class OptionParseException : public OptionException 345 | { 346 | public: 347 | OptionParseException(const std::string& message) 348 | : OptionException(message) 349 | { 350 | } 351 | }; 352 | 353 | class option_exists_error : public OptionSpecException 354 | { 355 | public: 356 | option_exists_error(const std::string& option) 357 | : OptionSpecException(u8"Option " + LQUOTE + option + RQUOTE + u8" already exists") 358 | { 359 | } 360 | }; 361 | 362 | class invalid_option_format_error : public OptionSpecException 363 | { 364 | public: 365 | invalid_option_format_error(const std::string& format) 366 | : OptionSpecException(u8"Invalid option format " + LQUOTE + format + RQUOTE) 367 | { 368 | } 369 | }; 370 | 371 | class option_syntax_exception : public OptionParseException { 372 | public: 373 | option_syntax_exception(const std::string& text) 374 | : OptionParseException(u8"Argument " + LQUOTE + text + RQUOTE + 375 | u8" starts with a - but has incorrect syntax") 376 | { 377 | } 378 | }; 379 | 380 | class option_not_exists_exception : public OptionParseException 381 | { 382 | public: 383 | option_not_exists_exception(const std::string& option) 384 | : OptionParseException(u8"Option " + LQUOTE + option + RQUOTE + u8" does not exist") 385 | { 386 | } 387 | }; 388 | 389 | class missing_argument_exception : public OptionParseException 390 | { 391 | public: 392 | missing_argument_exception(const std::string& option) 393 | : OptionParseException( 394 | u8"Option " + LQUOTE + option + RQUOTE + u8" is missing an argument" 395 | ) 396 | { 397 | } 398 | }; 399 | 400 | class option_requires_argument_exception : public OptionParseException 401 | { 402 | public: 403 | option_requires_argument_exception(const std::string& option) 404 | : OptionParseException( 405 | u8"Option " + LQUOTE + option + RQUOTE + u8" requires an argument" 406 | ) 407 | { 408 | } 409 | }; 410 | 411 | class option_not_has_argument_exception : public OptionParseException 412 | { 413 | public: 414 | 
option_not_has_argument_exception 415 | ( 416 | const std::string& option, 417 | const std::string& arg 418 | ) 419 | : OptionParseException( 420 | u8"Option " + LQUOTE + option + RQUOTE + 421 | u8" does not take an argument, but argument " + 422 | LQUOTE + arg + RQUOTE + " given" 423 | ) 424 | { 425 | } 426 | }; 427 | 428 | class option_not_present_exception : public OptionParseException 429 | { 430 | public: 431 | option_not_present_exception(const std::string& option) 432 | : OptionParseException(u8"Option " + LQUOTE + option + RQUOTE + u8" not present") 433 | { 434 | } 435 | }; 436 | 437 | class argument_incorrect_type : public OptionParseException 438 | { 439 | public: 440 | argument_incorrect_type 441 | ( 442 | const std::string& arg 443 | ) 444 | : OptionParseException( 445 | u8"Argument " + LQUOTE + arg + RQUOTE + u8" failed to parse" 446 | ) 447 | { 448 | } 449 | }; 450 | 451 | class option_required_exception : public OptionParseException 452 | { 453 | public: 454 | option_required_exception(const std::string& option) 455 | : OptionParseException( 456 | u8"Option " + LQUOTE + option + RQUOTE + u8" is required but not present" 457 | ) 458 | { 459 | } 460 | }; 461 | 462 | namespace values 463 | { 464 | namespace 465 | { 466 | std::basic_regex integer_pattern 467 | ("(-)?(0x)?([0-9a-zA-Z]+)|((0x)?0)"); 468 | std::basic_regex truthy_pattern 469 | ("(t|T)(rue)?"); 470 | std::basic_regex falsy_pattern 471 | ("((f|F)(alse)?)?"); 472 | } 473 | 474 | namespace detail 475 | { 476 | template 477 | struct SignedCheck; 478 | 479 | template 480 | struct SignedCheck 481 | { 482 | template 483 | void 484 | operator()(bool negative, U u, const std::string& text) 485 | { 486 | if (negative) 487 | { 488 | if (u > static_cast(-(std::numeric_limits::min)())) 489 | { 490 | throw argument_incorrect_type(text); 491 | } 492 | } 493 | else 494 | { 495 | if (u > static_cast((std::numeric_limits::max)())) 496 | { 497 | throw argument_incorrect_type(text); 498 | } 499 | } 500 | } 501 | }; 502 | 503 | template 504 | struct SignedCheck 505 | { 506 | template 507 | void 508 | operator()(bool, U, const std::string&) {} 509 | }; 510 | 511 | template 512 | void 513 | check_signed_range(bool negative, U value, const std::string& text) 514 | { 515 | SignedCheck::is_signed>()(negative, value, text); 516 | } 517 | } 518 | 519 | template 520 | R 521 | checked_negate(T&& t, const std::string&, std::true_type) 522 | { 523 | // if we got to here, then `t` is a positive number that fits into 524 | // `R`. So to avoid MSVC C4146, we first cast it to `R`. 525 | // See https://github.com/jarro2783/cxxopts/issues/62 for more details. 526 | return -static_cast(t); 527 | } 528 | 529 | template 530 | T 531 | checked_negate(T&&, const std::string& text, std::false_type) 532 | { 533 | throw argument_incorrect_type(text); 534 | } 535 | 536 | template 537 | void 538 | integer_parser(const std::string& text, T& value) 539 | { 540 | std::smatch match; 541 | std::regex_match(text, match, integer_pattern); 542 | 543 | if (match.length() == 0) 544 | { 545 | throw argument_incorrect_type(text); 546 | } 547 | 548 | if (match.length(4) > 0) 549 | { 550 | value = 0; 551 | return; 552 | } 553 | 554 | using US = typename std::make_unsigned::type; 555 | 556 | constexpr auto umax = (std::numeric_limits::max)(); 557 | constexpr bool is_signed = std::numeric_limits::is_signed; 558 | const bool negative = match.length(1) > 0; 559 | const uint8_t base = match.length(2) > 0 ? 
16 : 10; 560 | 561 | auto value_match = match[3]; 562 | 563 | US result = 0; 564 | 565 | for (auto iter = value_match.first; iter != value_match.second; ++iter) 566 | { 567 | US digit = 0; 568 | 569 | if (*iter >= '0' && *iter <= '9') 570 | { 571 | digit = *iter - '0'; 572 | } 573 | else if (base == 16 && *iter >= 'a' && *iter <= 'f') 574 | { 575 | digit = *iter - 'a' + 10; 576 | } 577 | else if (base == 16 && *iter >= 'A' && *iter <= 'F') 578 | { 579 | digit = *iter - 'A' + 10; 580 | } 581 | else 582 | { 583 | throw argument_incorrect_type(text); 584 | } 585 | 586 | if (umax - digit < result * base) 587 | { 588 | throw argument_incorrect_type(text); 589 | } 590 | 591 | result = result * base + digit; 592 | } 593 | 594 | detail::check_signed_range(negative, result, text); 595 | 596 | if (negative) 597 | { 598 | value = checked_negate(result, 599 | text, 600 | std::integral_constant()); 601 | } 602 | else 603 | { 604 | value = result; 605 | } 606 | } 607 | 608 | template 609 | void stringstream_parser(const std::string& text, T& value) 610 | { 611 | std::stringstream in(text); 612 | in >> value; 613 | if (!in) { 614 | throw argument_incorrect_type(text); 615 | } 616 | } 617 | 618 | inline 619 | void 620 | parse_value(const std::string& text, uint8_t& value) 621 | { 622 | integer_parser(text, value); 623 | } 624 | 625 | inline 626 | void 627 | parse_value(const std::string& text, int8_t& value) 628 | { 629 | integer_parser(text, value); 630 | } 631 | 632 | inline 633 | void 634 | parse_value(const std::string& text, uint16_t& value) 635 | { 636 | integer_parser(text, value); 637 | } 638 | 639 | inline 640 | void 641 | parse_value(const std::string& text, int16_t& value) 642 | { 643 | integer_parser(text, value); 644 | } 645 | 646 | inline 647 | void 648 | parse_value(const std::string& text, uint32_t& value) 649 | { 650 | integer_parser(text, value); 651 | } 652 | 653 | inline 654 | void 655 | parse_value(const std::string& text, int32_t& value) 656 | { 657 | integer_parser(text, value); 658 | } 659 | 660 | inline 661 | void 662 | parse_value(const std::string& text, uint64_t& value) 663 | { 664 | integer_parser(text, value); 665 | } 666 | 667 | inline 668 | void 669 | parse_value(const std::string& text, int64_t& value) 670 | { 671 | integer_parser(text, value); 672 | } 673 | 674 | inline 675 | void 676 | parse_value(const std::string& text, bool& value) 677 | { 678 | std::smatch result; 679 | std::regex_match(text, result, truthy_pattern); 680 | 681 | if (!result.empty()) 682 | { 683 | value = true; 684 | return; 685 | } 686 | 687 | std::regex_match(text, result, falsy_pattern); 688 | if (!result.empty()) 689 | { 690 | value = false; 691 | return; 692 | } 693 | 694 | throw argument_incorrect_type(text); 695 | } 696 | 697 | inline 698 | void 699 | parse_value(const std::string& text, std::string& value) 700 | { 701 | value = text; 702 | } 703 | 704 | // The fallback parser. It uses the stringstream parser to parse all types 705 | // that have not been overloaded explicitly. It has to be placed in the 706 | // source code before all other more specialized templates. 
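// Illustrative note (not part of the original header): any user-defined type picks
// up this fallback once it provides a stream extraction operator, e.g.
//
//   struct Level { int value; };
//   std::istream& operator>>(std::istream& in, Level& l) { return in >> l.value; }
//
//   options.add_options()("l,level", "verbosity level", cxxopts::value<Level>());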
707 | template 708 | void 709 | parse_value(const std::string& text, T& value) { 710 | stringstream_parser(text, value); 711 | } 712 | 713 | template 714 | void 715 | parse_value(const std::string& text, std::vector& value) 716 | { 717 | T v; 718 | parse_value(text, v); 719 | value.push_back(v); 720 | } 721 | 722 | #ifdef CXXOPTS_HAS_OPTIONAL 723 | template 724 | void 725 | parse_value(const std::string& text, std::optional& value) 726 | { 727 | T result; 728 | parse_value(text, result); 729 | value = std::move(result); 730 | } 731 | #endif 732 | 733 | template 734 | struct type_is_container 735 | { 736 | static constexpr bool value = false; 737 | }; 738 | 739 | template 740 | struct type_is_container> 741 | { 742 | static constexpr bool value = true; 743 | }; 744 | 745 | template 746 | class abstract_value : public Value 747 | { 748 | using Self = abstract_value; 749 | 750 | public: 751 | abstract_value() 752 | : m_result(std::make_shared()) 753 | , m_store(m_result.get()) 754 | { 755 | } 756 | 757 | abstract_value(T* t) 758 | : m_store(t) 759 | { 760 | } 761 | 762 | virtual ~abstract_value() = default; 763 | 764 | abstract_value(const abstract_value& rhs) 765 | { 766 | if (rhs.m_result) 767 | { 768 | m_result = std::make_shared(); 769 | m_store = m_result.get(); 770 | } 771 | else 772 | { 773 | m_store = rhs.m_store; 774 | } 775 | 776 | m_default = rhs.m_default; 777 | m_implicit = rhs.m_implicit; 778 | m_default_value = rhs.m_default_value; 779 | m_implicit_value = rhs.m_implicit_value; 780 | } 781 | 782 | void 783 | parse(const std::string& text) const 784 | { 785 | parse_value(text, *m_store); 786 | } 787 | 788 | bool 789 | is_container() const 790 | { 791 | return type_is_container::value; 792 | } 793 | 794 | void 795 | parse() const 796 | { 797 | parse_value(m_default_value, *m_store); 798 | } 799 | 800 | bool 801 | has_default() const 802 | { 803 | return m_default; 804 | } 805 | 806 | bool 807 | has_implicit() const 808 | { 809 | return m_implicit; 810 | } 811 | 812 | std::shared_ptr 813 | default_value(const std::string& value) 814 | { 815 | m_default = true; 816 | m_default_value = value; 817 | return shared_from_this(); 818 | } 819 | 820 | std::shared_ptr 821 | implicit_value(const std::string& value) 822 | { 823 | m_implicit = true; 824 | m_implicit_value = value; 825 | return shared_from_this(); 826 | } 827 | 828 | std::string 829 | get_default_value() const 830 | { 831 | return m_default_value; 832 | } 833 | 834 | std::string 835 | get_implicit_value() const 836 | { 837 | return m_implicit_value; 838 | } 839 | 840 | bool 841 | is_boolean() const 842 | { 843 | return std::is_same::value; 844 | } 845 | 846 | const T& 847 | get() const 848 | { 849 | if (m_store == nullptr) 850 | { 851 | return *m_result; 852 | } 853 | else 854 | { 855 | return *m_store; 856 | } 857 | } 858 | 859 | protected: 860 | std::shared_ptr m_result; 861 | T* m_store; 862 | 863 | bool m_default = false; 864 | bool m_implicit = false; 865 | 866 | std::string m_default_value; 867 | std::string m_implicit_value; 868 | }; 869 | 870 | template 871 | class standard_value : public abstract_value 872 | { 873 | public: 874 | using abstract_value::abstract_value; 875 | 876 | std::shared_ptr 877 | clone() const 878 | { 879 | return std::make_shared>(*this); 880 | } 881 | }; 882 | 883 | template <> 884 | class standard_value : public abstract_value 885 | { 886 | public: 887 | ~standard_value() = default; 888 | 889 | standard_value() 890 | { 891 | set_default_and_implicit(); 892 | } 893 | 894 | standard_value(bool* b) 
895 | : abstract_value(b) 896 | { 897 | set_default_and_implicit(); 898 | } 899 | 900 | std::shared_ptr 901 | clone() const 902 | { 903 | return std::make_shared>(*this); 904 | } 905 | 906 | private: 907 | 908 | void 909 | set_default_and_implicit() 910 | { 911 | m_default = true; 912 | m_default_value = "false"; 913 | m_implicit = true; 914 | m_implicit_value = "true"; 915 | } 916 | }; 917 | } 918 | 919 | template 920 | std::shared_ptr 921 | value() 922 | { 923 | return std::make_shared>(); 924 | } 925 | 926 | template 927 | std::shared_ptr 928 | value(T& t) 929 | { 930 | return std::make_shared>(&t); 931 | } 932 | 933 | class OptionAdder; 934 | 935 | class OptionDetails 936 | { 937 | public: 938 | OptionDetails 939 | ( 940 | const std::string& short_, 941 | const std::string& long_, 942 | const String& desc, 943 | std::shared_ptr val 944 | ) 945 | : m_short(short_) 946 | , m_long(long_) 947 | , m_desc(desc) 948 | , m_value(val) 949 | , m_count(0) 950 | { 951 | } 952 | 953 | OptionDetails(const OptionDetails& rhs) 954 | : m_desc(rhs.m_desc) 955 | , m_count(rhs.m_count) 956 | { 957 | m_value = rhs.m_value->clone(); 958 | } 959 | 960 | OptionDetails(OptionDetails&& rhs) = default; 961 | 962 | const String& 963 | description() const 964 | { 965 | return m_desc; 966 | } 967 | 968 | const Value& value() const { 969 | return *m_value; 970 | } 971 | 972 | std::shared_ptr 973 | make_storage() const 974 | { 975 | return m_value->clone(); 976 | } 977 | 978 | const std::string& 979 | short_name() const 980 | { 981 | return m_short; 982 | } 983 | 984 | const std::string& 985 | long_name() const 986 | { 987 | return m_long; 988 | } 989 | 990 | private: 991 | std::string m_short; 992 | std::string m_long; 993 | String m_desc; 994 | std::shared_ptr m_value; 995 | int m_count; 996 | }; 997 | 998 | struct HelpOptionDetails 999 | { 1000 | std::string s; 1001 | std::string l; 1002 | String desc; 1003 | bool has_default; 1004 | std::string default_value; 1005 | bool has_implicit; 1006 | std::string implicit_value; 1007 | std::string arg_help; 1008 | bool is_container; 1009 | bool is_boolean; 1010 | }; 1011 | 1012 | struct HelpGroupDetails 1013 | { 1014 | std::string name; 1015 | std::string description; 1016 | std::vector options; 1017 | }; 1018 | 1019 | class OptionValue 1020 | { 1021 | public: 1022 | void 1023 | parse 1024 | ( 1025 | std::shared_ptr details, 1026 | const std::string& text 1027 | ) 1028 | { 1029 | ensure_value(details); 1030 | ++m_count; 1031 | m_value->parse(text); 1032 | } 1033 | 1034 | void 1035 | parse_default(std::shared_ptr details) 1036 | { 1037 | ensure_value(details); 1038 | m_value->parse(); 1039 | } 1040 | 1041 | size_t 1042 | count() const 1043 | { 1044 | return m_count; 1045 | } 1046 | 1047 | template 1048 | const T& 1049 | as() const 1050 | { 1051 | if (m_value == nullptr) { 1052 | throw std::domain_error("No value"); 1053 | } 1054 | 1055 | #ifdef CXXOPTS_NO_RTTI 1056 | return static_cast&>(*m_value).get(); 1057 | #else 1058 | return dynamic_cast&>(*m_value).get(); 1059 | #endif 1060 | } 1061 | 1062 | private: 1063 | void 1064 | ensure_value(std::shared_ptr details) 1065 | { 1066 | if (m_value == nullptr) 1067 | { 1068 | m_value = details->make_storage(); 1069 | } 1070 | } 1071 | 1072 | std::shared_ptr m_value; 1073 | size_t m_count = 0; 1074 | }; 1075 | 1076 | class KeyValue 1077 | { 1078 | public: 1079 | KeyValue(std::string key_, std::string value_) 1080 | : m_key(std::move(key_)) 1081 | , m_value(std::move(value_)) 1082 | { 1083 | } 1084 | 1085 | const 1086 | std::string& 
1087 | key() const 1088 | { 1089 | return m_key; 1090 | } 1091 | 1092 | const std::string 1093 | value() const 1094 | { 1095 | return m_value; 1096 | } 1097 | 1098 | template 1099 | T 1100 | as() const 1101 | { 1102 | T result; 1103 | values::parse_value(m_value, result); 1104 | return result; 1105 | } 1106 | 1107 | private: 1108 | std::string m_key; 1109 | std::string m_value; 1110 | }; 1111 | 1112 | class ParseResult 1113 | { 1114 | public: 1115 | 1116 | ParseResult( 1117 | const std::shared_ptr< 1118 | std::unordered_map> 1119 | >, 1120 | std::vector, 1121 | bool allow_unrecognised, 1122 | int&, char**&); 1123 | 1124 | size_t 1125 | count(const std::string& o) const 1126 | { 1127 | auto iter = m_options->find(o); 1128 | if (iter == m_options->end()) 1129 | { 1130 | return 0; 1131 | } 1132 | 1133 | auto riter = m_results.find(iter->second); 1134 | 1135 | return riter->second.count(); 1136 | } 1137 | 1138 | const OptionValue& 1139 | operator[](const std::string& option) const 1140 | { 1141 | auto iter = m_options->find(option); 1142 | 1143 | if (iter == m_options->end()) 1144 | { 1145 | throw option_not_present_exception(option); 1146 | } 1147 | 1148 | auto riter = m_results.find(iter->second); 1149 | 1150 | return riter->second; 1151 | } 1152 | 1153 | const std::vector& 1154 | arguments() const 1155 | { 1156 | return m_sequential; 1157 | } 1158 | 1159 | private: 1160 | 1161 | void 1162 | parse(int& argc, char**& argv); 1163 | 1164 | void 1165 | add_to_option(const std::string& option, const std::string& arg); 1166 | 1167 | bool 1168 | consume_positional(std::string a); 1169 | 1170 | void 1171 | parse_option 1172 | ( 1173 | std::shared_ptr value, 1174 | const std::string& name, 1175 | const std::string& arg = "" 1176 | ); 1177 | 1178 | void 1179 | parse_default(std::shared_ptr details); 1180 | 1181 | void 1182 | checked_parse_arg 1183 | ( 1184 | int argc, 1185 | char* argv[], 1186 | int& current, 1187 | std::shared_ptr value, 1188 | const std::string& name 1189 | ); 1190 | 1191 | const std::shared_ptr< 1192 | std::unordered_map> 1193 | > m_options; 1194 | std::vector m_positional; 1195 | std::vector::iterator m_next_positional; 1196 | std::unordered_set m_positional_set; 1197 | std::unordered_map, OptionValue> m_results; 1198 | 1199 | bool m_allow_unrecognised; 1200 | 1201 | std::vector m_sequential; 1202 | }; 1203 | 1204 | class Options 1205 | { 1206 | typedef std::unordered_map> 1207 | OptionMap; 1208 | public: 1209 | 1210 | Options(std::string program, std::string help_string = "") 1211 | : m_program(std::move(program)) 1212 | , m_help_string(toLocalString(std::move(help_string))) 1213 | , m_custom_help("[OPTION...]") 1214 | , m_positional_help("positional parameters") 1215 | , m_show_positional(false) 1216 | , m_allow_unrecognised(false) 1217 | , m_options(std::make_shared()) 1218 | , m_next_positional(m_positional.end()) 1219 | { 1220 | } 1221 | 1222 | Options& 1223 | positional_help(std::string help_text) 1224 | { 1225 | m_positional_help = std::move(help_text); 1226 | return *this; 1227 | } 1228 | 1229 | Options& 1230 | custom_help(std::string help_text) 1231 | { 1232 | m_custom_help = std::move(help_text); 1233 | return *this; 1234 | } 1235 | 1236 | Options& 1237 | show_positional_help() 1238 | { 1239 | m_show_positional = true; 1240 | return *this; 1241 | } 1242 | 1243 | Options& 1244 | allow_unrecognised_options() 1245 | { 1246 | m_allow_unrecognised = true; 1247 | return *this; 1248 | } 1249 | 1250 | ParseResult 1251 | parse(int& argc, char**& argv); 1252 | 1253 | OptionAdder 
1254 | add_options(std::string group = ""); 1255 | 1256 | void 1257 | add_option 1258 | ( 1259 | const std::string& group, 1260 | const std::string& s, 1261 | const std::string& l, 1262 | std::string desc, 1263 | std::shared_ptr value, 1264 | std::string arg_help 1265 | ); 1266 | 1267 | //parse positional arguments into the given option 1268 | void 1269 | parse_positional(std::string option); 1270 | 1271 | void 1272 | parse_positional(std::vector options); 1273 | 1274 | void 1275 | parse_positional(std::initializer_list options); 1276 | 1277 | template 1278 | void 1279 | parse_positional(Iterator begin, Iterator end) { 1280 | parse_positional(std::vector{begin, end}); 1281 | } 1282 | 1283 | std::string 1284 | help(const std::vector& groups = {}) const; 1285 | 1286 | const std::vector 1287 | groups() const; 1288 | 1289 | const HelpGroupDetails& 1290 | group_help(const std::string& group) const; 1291 | 1292 | private: 1293 | 1294 | void 1295 | add_one_option 1296 | ( 1297 | const std::string& option, 1298 | std::shared_ptr details 1299 | ); 1300 | 1301 | String 1302 | help_one_group(const std::string& group) const; 1303 | 1304 | void 1305 | generate_group_help 1306 | ( 1307 | String& result, 1308 | const std::vector& groups 1309 | ) const; 1310 | 1311 | void 1312 | generate_all_groups_help(String& result) const; 1313 | 1314 | std::string m_program; 1315 | String m_help_string; 1316 | std::string m_custom_help; 1317 | std::string m_positional_help; 1318 | bool m_show_positional; 1319 | bool m_allow_unrecognised; 1320 | 1321 | std::shared_ptr m_options; 1322 | std::vector m_positional; 1323 | std::vector::iterator m_next_positional; 1324 | std::unordered_set m_positional_set; 1325 | 1326 | //mapping from groups to help options 1327 | std::map m_help; 1328 | }; 1329 | 1330 | class OptionAdder 1331 | { 1332 | public: 1333 | 1334 | OptionAdder(Options& options, std::string group) 1335 | : m_options(options), m_group(std::move(group)) 1336 | { 1337 | } 1338 | 1339 | OptionAdder& 1340 | operator() 1341 | ( 1342 | const std::string& opts, 1343 | const std::string& desc, 1344 | std::shared_ptr value 1345 | = ::cxxopts::value(), 1346 | std::string arg_help = "" 1347 | ); 1348 | 1349 | private: 1350 | Options& m_options; 1351 | std::string m_group; 1352 | }; 1353 | 1354 | namespace 1355 | { 1356 | constexpr int OPTION_LONGEST = 30; 1357 | constexpr int OPTION_DESC_GAP = 2; 1358 | 1359 | std::basic_regex option_matcher 1360 | ("--([[:alnum:]][-_[:alnum:]]+)(=(.*))?|-([[:alnum:]]+)"); 1361 | 1362 | std::basic_regex option_specifier 1363 | ("(([[:alnum:]]),)?[ ]*([[:alnum:]][-_[:alnum:]]*)?"); 1364 | 1365 | String 1366 | format_option 1367 | ( 1368 | const HelpOptionDetails& o 1369 | ) 1370 | { 1371 | auto& s = o.s; 1372 | auto& l = o.l; 1373 | 1374 | String result = " "; 1375 | 1376 | if (s.size() > 0) 1377 | { 1378 | result += "-" + toLocalString(s) + ","; 1379 | } 1380 | else 1381 | { 1382 | result += " "; 1383 | } 1384 | 1385 | if (l.size() > 0) 1386 | { 1387 | result += " --" + toLocalString(l); 1388 | } 1389 | 1390 | auto arg = o.arg_help.size() > 0 ? 
toLocalString(o.arg_help) : "arg"; 1391 | 1392 | if (!o.is_boolean) 1393 | { 1394 | if (o.has_implicit) 1395 | { 1396 | result += " [=" + arg + "(=" + toLocalString(o.implicit_value) + ")]"; 1397 | } 1398 | else 1399 | { 1400 | result += " " + arg; 1401 | } 1402 | } 1403 | 1404 | return result; 1405 | } 1406 | 1407 | String 1408 | format_description 1409 | ( 1410 | const HelpOptionDetails& o, 1411 | size_t start, 1412 | size_t width 1413 | ) 1414 | { 1415 | auto desc = o.desc; 1416 | 1417 | if (o.has_default && (!o.is_boolean || o.default_value != "false")) 1418 | { 1419 | desc += toLocalString(" (default: " + o.default_value + ")"); 1420 | } 1421 | 1422 | String result; 1423 | 1424 | auto current = std::begin(desc); 1425 | auto startLine = current; 1426 | auto lastSpace = current; 1427 | 1428 | auto size = size_t{}; 1429 | 1430 | while (current != std::end(desc)) 1431 | { 1432 | if (*current == ' ') 1433 | { 1434 | lastSpace = current; 1435 | } 1436 | 1437 | if (*current == '\n') 1438 | { 1439 | startLine = current + 1; 1440 | lastSpace = startLine; 1441 | } 1442 | else if (size > width) 1443 | { 1444 | if (lastSpace == startLine) 1445 | { 1446 | stringAppend(result, startLine, current + 1); 1447 | stringAppend(result, "\n"); 1448 | stringAppend(result, start, ' '); 1449 | startLine = current + 1; 1450 | lastSpace = startLine; 1451 | } 1452 | else 1453 | { 1454 | stringAppend(result, startLine, lastSpace); 1455 | stringAppend(result, "\n"); 1456 | stringAppend(result, start, ' '); 1457 | startLine = lastSpace + 1; 1458 | } 1459 | size = 0; 1460 | } 1461 | else 1462 | { 1463 | ++size; 1464 | } 1465 | 1466 | ++current; 1467 | } 1468 | 1469 | //append whatever is left 1470 | stringAppend(result, startLine, current); 1471 | 1472 | return result; 1473 | } 1474 | } 1475 | 1476 | inline 1477 | ParseResult::ParseResult 1478 | ( 1479 | const std::shared_ptr< 1480 | std::unordered_map> 1481 | > options, 1482 | std::vector positional, 1483 | bool allow_unrecognised, 1484 | int& argc, char**& argv 1485 | ) 1486 | : m_options(options) 1487 | , m_positional(std::move(positional)) 1488 | , m_next_positional(m_positional.begin()) 1489 | , m_allow_unrecognised(allow_unrecognised) 1490 | { 1491 | parse(argc, argv); 1492 | } 1493 | 1494 | inline 1495 | OptionAdder 1496 | Options::add_options(std::string group) 1497 | { 1498 | return OptionAdder(*this, std::move(group)); 1499 | } 1500 | 1501 | inline 1502 | OptionAdder& 1503 | OptionAdder::operator() 1504 | ( 1505 | const std::string& opts, 1506 | const std::string& desc, 1507 | std::shared_ptr value, 1508 | std::string arg_help 1509 | ) 1510 | { 1511 | std::match_results result; 1512 | std::regex_match(opts.c_str(), result, option_specifier); 1513 | 1514 | if (result.empty()) 1515 | { 1516 | throw invalid_option_format_error(opts); 1517 | } 1518 | 1519 | const auto& short_match = result[2]; 1520 | const auto& long_match = result[3]; 1521 | 1522 | if (!short_match.length() && !long_match.length()) 1523 | { 1524 | throw invalid_option_format_error(opts); 1525 | } else if (long_match.length() == 1 && short_match.length()) 1526 | { 1527 | throw invalid_option_format_error(opts); 1528 | } 1529 | 1530 | auto option_names = [] 1531 | ( 1532 | const std::sub_match& short_, 1533 | const std::sub_match& long_ 1534 | ) 1535 | { 1536 | if (long_.length() == 1) 1537 | { 1538 | return std::make_tuple(long_.str(), short_.str()); 1539 | } 1540 | else 1541 | { 1542 | return std::make_tuple(short_.str(), long_.str()); 1543 | } 1544 | }(short_match, long_match); 1545 | 1546 | 
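// option_names now holds (short name, long name); the lambda above moves a
// single-letter "long" capture into the short-name slot, so a spec like "v" is
// normalized the same way as the short half of "v,verbose".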
m_options.add_option 1547 | ( 1548 | m_group, 1549 | std::get<0>(option_names), 1550 | std::get<1>(option_names), 1551 | desc, 1552 | value, 1553 | std::move(arg_help) 1554 | ); 1555 | 1556 | return *this; 1557 | } 1558 | 1559 | inline 1560 | void 1561 | ParseResult::parse_default(std::shared_ptr details) 1562 | { 1563 | m_results[details].parse_default(details); 1564 | } 1565 | 1566 | inline 1567 | void 1568 | ParseResult::parse_option 1569 | ( 1570 | std::shared_ptr value, 1571 | const std::string& /*name*/, 1572 | const std::string& arg 1573 | ) 1574 | { 1575 | auto& result = m_results[value]; 1576 | result.parse(value, arg); 1577 | 1578 | m_sequential.emplace_back(value->long_name(), arg); 1579 | } 1580 | 1581 | inline 1582 | void 1583 | ParseResult::checked_parse_arg 1584 | ( 1585 | int argc, 1586 | char* argv[], 1587 | int& current, 1588 | std::shared_ptr value, 1589 | const std::string& name 1590 | ) 1591 | { 1592 | if (current + 1 >= argc) 1593 | { 1594 | if (value->value().has_implicit()) 1595 | { 1596 | parse_option(value, name, value->value().get_implicit_value()); 1597 | } 1598 | else 1599 | { 1600 | throw missing_argument_exception(name); 1601 | } 1602 | } 1603 | else 1604 | { 1605 | if (value->value().has_implicit()) 1606 | { 1607 | parse_option(value, name, value->value().get_implicit_value()); 1608 | } 1609 | else 1610 | { 1611 | parse_option(value, name, argv[current + 1]); 1612 | ++current; 1613 | } 1614 | } 1615 | } 1616 | 1617 | inline 1618 | void 1619 | ParseResult::add_to_option(const std::string& option, const std::string& arg) 1620 | { 1621 | auto iter = m_options->find(option); 1622 | 1623 | if (iter == m_options->end()) 1624 | { 1625 | throw option_not_exists_exception(option); 1626 | } 1627 | 1628 | parse_option(iter->second, option, arg); 1629 | } 1630 | 1631 | inline 1632 | bool 1633 | ParseResult::consume_positional(std::string a) 1634 | { 1635 | while (m_next_positional != m_positional.end()) 1636 | { 1637 | auto iter = m_options->find(*m_next_positional); 1638 | if (iter != m_options->end()) 1639 | { 1640 | auto& result = m_results[iter->second]; 1641 | if (!iter->second->value().is_container()) 1642 | { 1643 | if (result.count() == 0) 1644 | { 1645 | add_to_option(*m_next_positional, a); 1646 | ++m_next_positional; 1647 | return true; 1648 | } 1649 | else 1650 | { 1651 | ++m_next_positional; 1652 | continue; 1653 | } 1654 | } 1655 | else 1656 | { 1657 | add_to_option(*m_next_positional, a); 1658 | return true; 1659 | } 1660 | } 1661 | ++m_next_positional; 1662 | } 1663 | 1664 | return false; 1665 | } 1666 | 1667 | inline 1668 | void 1669 | Options::parse_positional(std::string option) 1670 | { 1671 | parse_positional(std::vector{std::move(option)}); 1672 | } 1673 | 1674 | inline 1675 | void 1676 | Options::parse_positional(std::vector options) 1677 | { 1678 | m_positional = std::move(options); 1679 | m_next_positional = m_positional.begin(); 1680 | 1681 | m_positional_set.insert(m_positional.begin(), m_positional.end()); 1682 | } 1683 | 1684 | inline 1685 | void 1686 | Options::parse_positional(std::initializer_list options) 1687 | { 1688 | parse_positional(std::vector(std::move(options))); 1689 | } 1690 | 1691 | inline 1692 | ParseResult 1693 | Options::parse(int& argc, char**& argv) 1694 | { 1695 | ParseResult result(m_options, m_positional, m_allow_unrecognised, argc, argv); 1696 | return result; 1697 | } 1698 | 1699 | inline 1700 | void 1701 | ParseResult::parse(int& argc, char**& argv) 1702 | { 1703 | int current = 1; 1704 | 1705 | int nextKeep = 1; 
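// `current` scans argv while `nextKeep` compacts everything that is not consumed
// (non-option arguments, and unrecognised options when those are allowed) toward
// the front of argv; argc is rewritten to nextKeep when parsing finishes.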
1706 | 1707 | bool consume_remaining = false; 1708 | 1709 | while (current != argc) 1710 | { 1711 | if (strcmp(argv[current], "--") == 0) 1712 | { 1713 | consume_remaining = true; 1714 | ++current; 1715 | break; 1716 | } 1717 | 1718 | std::match_results result; 1719 | std::regex_match(argv[current], result, option_matcher); 1720 | 1721 | if (result.empty()) 1722 | { 1723 | //not a flag 1724 | 1725 | // but if it starts with a `-`, then it's an error 1726 | if (argv[current][0] == '-' && argv[current][1] != '\0') { 1727 | throw option_syntax_exception(argv[current]); 1728 | } 1729 | 1730 | //if true is returned here then it was consumed, otherwise it is 1731 | //ignored 1732 | if (consume_positional(argv[current])) 1733 | { 1734 | } 1735 | else 1736 | { 1737 | argv[nextKeep] = argv[current]; 1738 | ++nextKeep; 1739 | } 1740 | //if we return from here then it was parsed successfully, so continue 1741 | } 1742 | else 1743 | { 1744 | //short or long option? 1745 | if (result[4].length() != 0) 1746 | { 1747 | const std::string& s = result[4]; 1748 | 1749 | for (std::size_t i = 0; i != s.size(); ++i) 1750 | { 1751 | std::string name(1, s[i]); 1752 | auto iter = m_options->find(name); 1753 | 1754 | if (iter == m_options->end()) 1755 | { 1756 | if (m_allow_unrecognised) 1757 | { 1758 | continue; 1759 | } 1760 | else 1761 | { 1762 | //error 1763 | throw option_not_exists_exception(name); 1764 | } 1765 | } 1766 | 1767 | auto value = iter->second; 1768 | 1769 | if (i + 1 == s.size()) 1770 | { 1771 | //it must be the last argument 1772 | checked_parse_arg(argc, argv, current, value, name); 1773 | } 1774 | else if (value->value().has_implicit()) 1775 | { 1776 | parse_option(value, name, value->value().get_implicit_value()); 1777 | } 1778 | else 1779 | { 1780 | //error 1781 | throw option_requires_argument_exception(name); 1782 | } 1783 | } 1784 | } 1785 | else if (result[1].length() != 0) 1786 | { 1787 | const std::string& name = result[1]; 1788 | 1789 | auto iter = m_options->find(name); 1790 | 1791 | if (iter == m_options->end()) 1792 | { 1793 | if (m_allow_unrecognised) 1794 | { 1795 | // keep unrecognised options in argument list, skip to next argument 1796 | argv[nextKeep] = argv[current]; 1797 | ++nextKeep; 1798 | ++current; 1799 | continue; 1800 | } 1801 | else 1802 | { 1803 | //error 1804 | throw option_not_exists_exception(name); 1805 | } 1806 | } 1807 | 1808 | auto opt = iter->second; 1809 | 1810 | //equals provided for long option? 
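// (in option_matcher, capture group 1 is the long option name, group 2 the
// "=value" tail and group 3 the value itself)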
1811 | if (result[2].length() != 0) 1812 | { 1813 | //parse the option given 1814 | 1815 | parse_option(opt, name, result[3]); 1816 | } 1817 | else 1818 | { 1819 | //parse the next argument 1820 | checked_parse_arg(argc, argv, current, opt, name); 1821 | } 1822 | } 1823 | 1824 | } 1825 | 1826 | ++current; 1827 | } 1828 | 1829 | for (auto& opt : *m_options) 1830 | { 1831 | auto& detail = opt.second; 1832 | auto& value = detail->value(); 1833 | 1834 | auto& store = m_results[detail]; 1835 | 1836 | if(!store.count() && value.has_default()){ 1837 | parse_default(detail); 1838 | } 1839 | } 1840 | 1841 | if (consume_remaining) 1842 | { 1843 | while (current < argc) 1844 | { 1845 | if (!consume_positional(argv[current])) { 1846 | break; 1847 | } 1848 | ++current; 1849 | } 1850 | 1851 | //adjust argv for any that couldn't be swallowed 1852 | while (current != argc) { 1853 | argv[nextKeep] = argv[current]; 1854 | ++nextKeep; 1855 | ++current; 1856 | } 1857 | } 1858 | 1859 | argc = nextKeep; 1860 | 1861 | } 1862 | 1863 | inline 1864 | void 1865 | Options::add_option 1866 | ( 1867 | const std::string& group, 1868 | const std::string& s, 1869 | const std::string& l, 1870 | std::string desc, 1871 | std::shared_ptr value, 1872 | std::string arg_help 1873 | ) 1874 | { 1875 | auto stringDesc = toLocalString(std::move(desc)); 1876 | auto option = std::make_shared(s, l, stringDesc, value); 1877 | 1878 | if (s.size() > 0) 1879 | { 1880 | add_one_option(s, option); 1881 | } 1882 | 1883 | if (l.size() > 0) 1884 | { 1885 | add_one_option(l, option); 1886 | } 1887 | 1888 | //add the help details 1889 | auto& options = m_help[group]; 1890 | 1891 | options.options.emplace_back(HelpOptionDetails{s, l, stringDesc, 1892 | value->has_default(), value->get_default_value(), 1893 | value->has_implicit(), value->get_implicit_value(), 1894 | std::move(arg_help), 1895 | value->is_container(), 1896 | value->is_boolean()}); 1897 | } 1898 | 1899 | inline 1900 | void 1901 | Options::add_one_option 1902 | ( 1903 | const std::string& option, 1904 | std::shared_ptr details 1905 | ) 1906 | { 1907 | auto in = m_options->emplace(option, details); 1908 | 1909 | if (!in.second) 1910 | { 1911 | throw option_exists_error(option); 1912 | } 1913 | } 1914 | 1915 | inline 1916 | String 1917 | Options::help_one_group(const std::string& g) const 1918 | { 1919 | typedef std::vector> OptionHelp; 1920 | 1921 | auto group = m_help.find(g); 1922 | if (group == m_help.end()) 1923 | { 1924 | return ""; 1925 | } 1926 | 1927 | OptionHelp format; 1928 | 1929 | size_t longest = 0; 1930 | 1931 | String result; 1932 | 1933 | if (!g.empty()) 1934 | { 1935 | result += toLocalString(" " + g + " options:\n"); 1936 | } 1937 | 1938 | for (const auto& o : group->second.options) 1939 | { 1940 | if (o.is_container && 1941 | m_positional_set.find(o.l) != m_positional_set.end() && 1942 | !m_show_positional) 1943 | { 1944 | continue; 1945 | } 1946 | 1947 | auto s = format_option(o); 1948 | longest = (std::max)(longest, stringLength(s)); 1949 | format.push_back(std::make_pair(s, String())); 1950 | } 1951 | 1952 | longest = (std::min)(longest, static_cast(OPTION_LONGEST)); 1953 | 1954 | //widest allowed description 1955 | auto allowed = size_t{76} - longest - OPTION_DESC_GAP; 1956 | 1957 | auto fiter = format.begin(); 1958 | for (const auto& o : group->second.options) 1959 | { 1960 | if (o.is_container && 1961 | m_positional_set.find(o.l) != m_positional_set.end() && 1962 | !m_show_positional) 1963 | { 1964 | continue; 1965 | } 1966 | 1967 | auto d = 
format_description(o, longest + OPTION_DESC_GAP, allowed); 1968 | 1969 | result += fiter->first; 1970 | if (stringLength(fiter->first) > longest) 1971 | { 1972 | result += '\n'; 1973 | result += toLocalString(std::string(longest + OPTION_DESC_GAP, ' ')); 1974 | } 1975 | else 1976 | { 1977 | result += toLocalString(std::string(longest + OPTION_DESC_GAP - 1978 | stringLength(fiter->first), 1979 | ' ')); 1980 | } 1981 | result += d; 1982 | result += '\n'; 1983 | 1984 | ++fiter; 1985 | } 1986 | 1987 | return result; 1988 | } 1989 | 1990 | inline 1991 | void 1992 | Options::generate_group_help 1993 | ( 1994 | String& result, 1995 | const std::vector& print_groups 1996 | ) const 1997 | { 1998 | for (size_t i = 0; i != print_groups.size(); ++i) 1999 | { 2000 | const String& group_help_text = help_one_group(print_groups[i]); 2001 | if (empty(group_help_text)) 2002 | { 2003 | continue; 2004 | } 2005 | result += group_help_text; 2006 | if (i < print_groups.size() - 1) 2007 | { 2008 | result += '\n'; 2009 | } 2010 | } 2011 | } 2012 | 2013 | inline 2014 | void 2015 | Options::generate_all_groups_help(String& result) const 2016 | { 2017 | std::vector all_groups; 2018 | all_groups.reserve(m_help.size()); 2019 | 2020 | for (auto& group : m_help) 2021 | { 2022 | all_groups.push_back(group.first); 2023 | } 2024 | 2025 | generate_group_help(result, all_groups); 2026 | } 2027 | 2028 | inline 2029 | std::string 2030 | Options::help(const std::vector& help_groups) const 2031 | { 2032 | String result = m_help_string + "\nUsage:\n " + 2033 | toLocalString(m_program) + " " + toLocalString(m_custom_help); 2034 | 2035 | if (m_positional.size() > 0 && m_positional_help.size() > 0) { 2036 | result += " " + toLocalString(m_positional_help); 2037 | } 2038 | 2039 | result += "\n\n"; 2040 | 2041 | if (help_groups.size() == 0) 2042 | { 2043 | generate_all_groups_help(result); 2044 | } 2045 | else 2046 | { 2047 | generate_group_help(result, help_groups); 2048 | } 2049 | 2050 | return toUTF8String(result); 2051 | } 2052 | 2053 | inline 2054 | const std::vector 2055 | Options::groups() const 2056 | { 2057 | std::vector g; 2058 | 2059 | std::transform( 2060 | m_help.begin(), 2061 | m_help.end(), 2062 | std::back_inserter(g), 2063 | [] (const std::map::value_type& pair) 2064 | { 2065 | return pair.first; 2066 | } 2067 | ); 2068 | 2069 | return g; 2070 | } 2071 | 2072 | inline 2073 | const HelpGroupDetails& 2074 | Options::group_help(const std::string& group) const 2075 | { 2076 | return m_help.at(group); 2077 | } 2078 | 2079 | } 2080 | 2081 | #endif //CXXOPTS_HPP_INCLUDED 2082 | -------------------------------------------------------------------------------- /library/src/net_impl.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Eugene Ingerman 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
6 | 
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 | */
9 | 
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 | 
16 | #include "wavernn.h"
17 | #include "net_impl.h"
18 | 
19 | Vectorf softmax( const Vectorf& x )
20 | {
21 |     float maxVal = x.maxCoeff();
22 |     Vectorf y = x.array()-maxVal;
23 | 
24 |     y = Eigen::exp(y.array());
25 |     float sum = y.sum();
26 |     return y.array() / sum;
27 | }
28 | 
29 | 
30 | void ResBlock::loadNext(FILE *fd)
31 | {
32 |     resblock.resize( RES_BLOCKS*4 );
33 |     for(int i=0; i
131 |     std::vector<float> cdf(probabilities.size());
132 |     float uniform_random = static_cast<float>(rnd()) / rnd.max();
133 | 
134 |     std::partial_sum(probabilities.data(), probabilities.data()+probabilities.size(), cdf.begin());
135 |     auto it = std::find_if(cdf.cbegin(), cdf.cend(), [uniform_random](float x){ return (x >= uniform_random);});
136 |     int pos = std::distance(cdf.cbegin(), it);
137 |     return pos;
138 | }
139 | 
140 | inline float invMulawQuantize( float x_mu )
141 | {
142 |     const float mu = MULAW_QUANTIZE_CHANNELS - 1;
143 |     float x = (x_mu / mu) * 2.f - 1.f;
144 |     x = std::copysign(1.f, x) * (std::exp(std::fabs(x) * std::log1p(mu) ) - 1.f) / mu;
145 |     return x;
146 | }
147 | 
148 | Vectorf Model::apply(const Matrixf &mels_in)
149 | {
150 |     std::vector<int> rnn_shape = rnn1.shape();
151 | 
152 |     Matrixf mel_padded = pad(mels_in, header.nPad);
153 |     Matrixf mels = upsample.apply(mel_padded);
154 |     int indent = header.nPad * header.total_scale;
155 | 
156 |     mels = mels.block(0,indent, mels.rows(), mels.cols()-2*indent ).eval(); //remove padding added in the previous step
157 | 
158 |     Matrixf aux = resnet.apply(mel_padded);
159 | 
160 |     assert(mels.cols() == aux.cols());
161 |     int seq_len = mels.cols();
162 |     int n_aux = aux.rows();
163 | 
164 |     Matrixf a1 = aux.block(0, 0, n_aux/2-1, aux.cols()); //we are throwing away the last aux row to keep network input a multiple of 8.
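    // In the sampling loop, the first aux half (a1) is concatenated with the
    // previous sample and the current upsampled mel column at the GRU input stage,
    // while the second half (a2) conditions the fully connected layers; their
    // softmax output is sampled via the CDF sampler above and mapped back to an
    // amplitude with invMulawQuantize.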
165 |     Matrixf a2 = aux.block(n_aux/2, 0, n_aux/2, aux.cols());
166 | 
167 |     Vectorf wav_out(seq_len); //output vector
168 | 
169 |     Vectorf x = Vectorf::Zero(1); //current sound amplitude
170 | 
171 |     Vectorf h1 = Vectorf::Zero(rnn_shape[0]);
172 | 
173 |     for(int i=0; i
--------------------------------------------------------------------------------
/library/src/net_impl.h:
--------------------------------------------------------------------------------
1 | #ifndef NET_IMPL_H
2 | #define NET_IMPL_H
3 | 
4 | #include
5 | #include
6 | #include "wavernn.h"
7 | 
8 | const int RES_BLOCKS = 3;
9 | const int UPSAMPLE_LAYERS = 3;
10 | 
11 | Vectorf softmax( const Vectorf& x );
12 | 
13 | class ResBlock{
14 |     std::vector<TorchLayer> resblock;
15 | public:
16 |     ResBlock() = default;
17 |     void loadNext( FILE* fd );
18 |     Matrixf apply( const Matrixf& x );
19 | };
20 | 
21 | class Resnet{
22 |     TorchLayer conv_in;
23 |     TorchLayer batch_norm;
24 |     ResBlock resblock;
25 |     TorchLayer conv_out;
26 |     TorchLayer stretch2d; //moved stretch2d layer into resnet from upsample as in python code
27 | 
28 | public:
29 |     Resnet() = default;
30 |     void loadNext( FILE* fd );
31 |     Matrixf apply( const Matrixf& x );
32 | };
33 | 
34 | class UpsampleNetwork{
35 |     std::vector<TorchLayer> up_layers;
36 | 
37 | public:
38 |     UpsampleNetwork() = default;
39 |     void loadNext( FILE* fd );
40 |     Matrixf apply( const Matrixf& x );
41 | };
42 | 
43 | class Model{
44 | 
45 |     struct Header{
46 |         int num_res_blocks;
47 |         int num_upsample;
48 |         int total_scale;
49 |         int nPad;
50 |     };
51 | 
52 |     Header header;
53 | 
54 |     UpsampleNetwork upsample;
55 |     Resnet resnet;
56 |     TorchLayer I;
57 |     TorchLayer rnn1;
58 |     TorchLayer fc1;
59 |     TorchLayer fc2;
60 | 
61 | public:
62 |     Model() = default;
63 |     void loadNext( FILE* fd );
64 |     Vectorf apply( const Matrixf& x );
65 | };
66 | 
67 | 
68 | #endif // NET_IMPL_H
69 | 
--------------------------------------------------------------------------------
/library/src/vocoder.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 Eugene Ingerman
3 | 
4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
5 | 
6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
7 | 
8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
9 | */ 10 | 11 | #include 12 | #include 13 | #include "cxxopts.hpp" 14 | 15 | #include "vocoder.h" 16 | #include "net_impl.h" 17 | #include "wavernn.h" 18 | 19 | using namespace std; 20 | 21 | Matrixf loadMel( FILE *fd ) 22 | { 23 | 24 | struct Header{ 25 | int nRows, nCols; 26 | } header; 27 | fread( &header, sizeof( Header ), 1, fd); 28 | 29 | Matrixf mel( header.nRows, header.nCols ); 30 | fread(mel.data(), sizeof(float), header.nRows*header.nCols, fd); 31 | 32 | return mel; 33 | } 34 | 35 | int main(int argc, char* argv[]) 36 | { 37 | 38 | cxxopts::Options options("vocoder", "WaveRNN based vocoder"); 39 | options.add_options() 40 | ("w,weights", "File with network weights", cxxopts::value()->default_value("")) 41 | ("m,mel", "File with mel inputs", cxxopts::value()->default_value("")) 42 | ; 43 | auto result = options.parse(argc, argv); 44 | 45 | string weights_file = result["weights"].as(); 46 | string mel_file = result["mel"].as(); 47 | 48 | FILE *fdMel = fopen( mel_file.c_str(), "rb"); 49 | Matrixf mel = loadMel( fdMel ); 50 | 51 | 52 | FILE *fd = fopen(weights_file.c_str(), "rb"); 53 | assert(fd); 54 | 55 | Model model; 56 | model.loadNext(fd); 57 | 58 | Vectorf wav = model.apply(mel); 59 | 60 | FILE *fdout = fopen("wavout.bin","wb"); 61 | fwrite(wav.data(), sizeof(float), wav.size(), fdout); 62 | fclose(fdout); 63 | 64 | // TorchLayer I; I.loadNext(fd); 65 | // TorchLayer GRU; GRU.loadNext(fd); 66 | // TorchLayer conv_in; conv_in.loadNext(fd); 67 | // TorchLayer conv_1; conv_1.loadNext(fd); 68 | // TorchLayer conv_2d; conv_2d.loadNext(fd); 69 | // TorchLayer batch_norm; batch_norm.loadNext(fd); 70 | 71 | // Test for linear layer 72 | // Vectorf x(112); 73 | // for(int j=1; j<=112; ++j) 74 | // x(j-1) = 1. + 1./j; 75 | // Vectorf x1, x2; 76 | // x1 = I(x); 77 | 78 | 79 | // Vectorf x(512), hx(512); 80 | 81 | // for(int j=1; j<=512; ++j){ 82 | // x(j-1) = 1. + 1./j; 83 | // hx(j-1) = -3. 
+ 2./j; 84 | // } 85 | 86 | // Vectorf h1 = GRU(x, hx); 87 | 88 | // Matrixf mel(128,10); 89 | // for(int i=0; i 12 | #include 13 | #include 14 | #include "wavernn.h" 15 | 16 | 17 | Matrixf relu( const Matrixf& x){ 18 | return x.array().max(0.f); 19 | //return x.unaryExpr([](float x){return std::max(0.f, x);}); 20 | } 21 | 22 | inline Vectorf sigmoid( const Vectorf& v ) 23 | { 24 | //TODO: optimize this 25 | //maybe use one of these approximations: https://stackoverflow.com/questions/10732027/fast-sigmoid-algorithm 26 | Vectorf y = 1.f / ( 1.f + Eigen::exp( - v.array())); 27 | return y; 28 | } 29 | 30 | inline Vectorf tanh( const Vectorf& v ) 31 | { 32 | //TODO: optimize this 33 | Vectorf y = Eigen::tanh( v.array() ); 34 | return y; 35 | } 36 | 37 | BaseLayer *TorchLayer::loadNext(FILE *fd) 38 | { 39 | TorchLayer::Header header; 40 | fread(&header, sizeof(TorchLayer::Header), 1, fd); 41 | 42 | std::cerr << "Loading:" << header.name << std::endl; 43 | 44 | switch( header.layerType ){ 45 | 46 | case TorchLayer::Header::LayerType::Linear: 47 | { 48 | impl = new LinearLayer(); 49 | impl->loadNext(fd); 50 | return impl; 51 | } 52 | break; 53 | 54 | case TorchLayer::Header::LayerType::GRU: 55 | { 56 | impl = new GRULayer(); 57 | impl->loadNext(fd); 58 | return impl; 59 | } 60 | break; 61 | 62 | case TorchLayer::Header::LayerType::Conv1d: 63 | { 64 | impl = new Conv1dLayer(); 65 | impl->loadNext(fd); 66 | return impl; 67 | } 68 | case TorchLayer::Header::LayerType::Conv2d:{ 69 | impl = new Conv2dLayer(); 70 | impl->loadNext(fd); 71 | return impl; 72 | } 73 | case TorchLayer::Header::LayerType::BatchNorm1d: 74 | { 75 | impl = new BatchNorm1dLayer(); 76 | impl->loadNext(fd); 77 | return impl; 78 | } 79 | case TorchLayer::Header::LayerType::Stretch2d: 80 | { 81 | impl = new Stretch2dLayer(); 82 | impl->loadNext(fd); 83 | return impl; 84 | } 85 | 86 | default: 87 | return nullptr; 88 | } 89 | } 90 | 91 | 92 | LinearLayer* LinearLayer::loadNext(FILE *fd) 93 | { 94 | LinearLayer::Header header; 95 | fread( &header, sizeof(LinearLayer::Header), 1, fd); 96 | assert(header.elSize==4 or header.elSize==2); 97 | 98 | mat.read(fd, header.elSize, header.nRows, header.nCols); //read compressed array 99 | 100 | bias.resize(header.nRows); 101 | fread(bias.data(), header.elSize, header.nRows, fd); 102 | return this; 103 | } 104 | 105 | Vectorf LinearLayer::apply(const Vectorf &x) 106 | { 107 | return (mat*x)+bias; 108 | } 109 | 110 | GRULayer* GRULayer::loadNext(FILE *fd) 111 | { 112 | GRULayer::Header header; 113 | fread( &header, sizeof(GRULayer::Header), 1, fd); 114 | assert(header.elSize==4 or header.elSize==2); 115 | 116 | nRows = header.nHidden; 117 | nCols = header.nInput; 118 | 119 | b_ir.resize(header.nHidden); 120 | b_iz.resize(header.nHidden); 121 | b_in.resize(header.nHidden); 122 | 123 | b_hr.resize(header.nHidden); 124 | b_hz.resize(header.nHidden); 125 | b_hn.resize(header.nHidden); 126 | 127 | 128 | W_ir.read( fd, header.elSize, header.nHidden, header.nInput); 129 | W_iz.read( fd, header.elSize, header.nHidden, header.nInput); 130 | W_in.read( fd, header.elSize, header.nHidden, header.nInput); 131 | 132 | W_hr.read( fd, header.elSize, header.nHidden, header.nHidden); 133 | W_hz.read( fd, header.elSize, header.nHidden, header.nHidden); 134 | W_hn.read( fd, header.elSize, header.nHidden, header.nHidden); 135 | 136 | fread(b_ir.data(), header.elSize, header.nHidden, fd); 137 | fread(b_iz.data(), header.elSize, header.nHidden, fd); 138 | fread(b_in.data(), header.elSize, header.nHidden, fd); 139 | 140 
| fread(b_hr.data(), header.elSize, header.nHidden, fd); 141 | fread(b_hz.data(), header.elSize, header.nHidden, fd); 142 | fread(b_hn.data(), header.elSize, header.nHidden, fd); 143 | 144 | return this; 145 | } 146 | 147 | 148 | Vectorf GRULayer::apply(const Vectorf &x, const Vectorf &hx) 149 | { 150 | Vectorf r, z, n, hout; 151 | 152 | r = sigmoid( W_ir*x + b_ir + W_hr*hx + b_hr); 153 | z = sigmoid( W_iz*x + b_iz + W_hz*hx + b_hz); 154 | n = tanh( W_in*x + b_in + (r.array() * (W_hn*hx + b_hn).array()).matrix()); 155 | hout = (1.f-z.array()) * n.array() + z.array() * hx.array(); 156 | return hout; 157 | } 158 | 159 | 160 | Vectorf CompMatrix::operator*(const Vectorf &x) 161 | { 162 | Vectorf y = Vectorf::Zero(nRows); 163 | assert(nCols == x.size()); 164 | 165 | int weightPos = 0; 166 | 167 | const float * __restrict x_ptr = x.data(); 168 | float * __restrict y_ptr = y.data(); 169 | 170 | for(int i=0; i 5 | #include 6 | #include 7 | 8 | using namespace Eigen; 9 | 10 | const int SPARSE_GROUP_SIZE = 4; //When pruning we use groups of 4 to reduce index 11 | const int MULAW_QUANTIZE_CHANNELS = 512; //same as hparams.mulaw_quantize_channels 12 | const uint8_t ROW_END_MARKER = 255; 13 | 14 | typedef Matrix Matrixf; 15 | typedef Matrix Vectorf; 16 | typedef Matrix Vectori8; 17 | 18 | 19 | Matrixf relu( const Matrixf& x); 20 | 21 | class CompMatrix{ 22 | //Vectorf weight; 23 | //Vectori8 index; 24 | float __attribute__((aligned (32))) *weight; 25 | int __attribute__((aligned (32))) *rowIdx; 26 | int8_t __attribute__((aligned (32))) *colIdx; //colIdx gets multiplied by SPARSE_GROUP_SIZE to get the actual position 27 | int nGroups; 28 | 29 | int nRows, nCols; 30 | 31 | void prepData( std::vector& wght, std::vector& idx ) 32 | { 33 | weight = static_cast(aligned_alloc(32, sizeof(float)*wght.size())); 34 | 35 | nGroups = wght.size()/SPARSE_GROUP_SIZE; 36 | rowIdx = static_cast(aligned_alloc(32, sizeof(int)*nGroups)); 37 | colIdx = static_cast(aligned_alloc(32, sizeof(int8_t)*nGroups)); 38 | 39 | std::copy(wght.begin(), wght.end(), weight); 40 | 41 | int row = 0; 42 | int n = 0; 43 | 44 | for(int i=0; i weight(nWeights); 74 | fread(weight.data(), elSize, nWeights, fd); 75 | 76 | fread(&nIndex, sizeof(int), 1, fd); 77 | std::vector index(nIndex); 78 | fread(index.data(), sizeof(uint8_t), nIndex, fd); 79 | prepData( weight, index ); 80 | } 81 | 82 | Vectorf operator*( const Vectorf& x); 83 | }; 84 | 85 | class BaseLayer { 86 | public: 87 | BaseLayer() = default; 88 | virtual BaseLayer* loadNext( FILE* fd ) {assert(0); return nullptr;}; 89 | virtual Matrixf apply( const Matrixf& x){assert(0); return Matrixf();}; 90 | virtual Vectorf apply( const Vectorf& x){assert(0); return Vectorf();}; 91 | virtual Vectorf apply( const Vectorf& x, const Vectorf& h){assert(0); return Vectorf();}; 92 | virtual std::vector shape(void) const { return std::vector(); } 93 | 94 | }; 95 | 96 | //TODO: This should be turned into a proper class factory pattern 97 | class TorchLayer : public BaseLayer { 98 | struct Header{ 99 | //int size; //size of data blob, not including this header 100 | enum class LayerType : int { Conv1d=1, Conv2d=2, BatchNorm1d=3, Linear=4, GRU=5, Stretch2d=6 } layerType; 101 | char name[64]; //layer name for debugging 102 | }; 103 | 104 | BaseLayer* impl; 105 | 106 | public: 107 | BaseLayer* loadNext( FILE* fd ); 108 | 109 | template< typename T> T operator()( const T& x){ return impl->apply( x ); } 110 | template< typename T, typename T2> T operator()( const T& x, const T2& h ){ return impl->apply( x, h 
);}
111 |     virtual std::vector<int> shape( void ) const override { return impl->shape(); }
112 | 
113 |     virtual Matrixf apply( const Matrixf& x) override { return impl->apply(x); };
114 |     virtual Vectorf apply( const Vectorf& x) override { return impl->apply(x); };
115 |     virtual Vectorf apply( const Vectorf& x, const Vectorf& h) override { return impl->apply(x, h); };
116 | 
117 |     virtual ~TorchLayer(){
118 |         delete impl;
119 |         impl=nullptr;
120 |     }
121 | };
122 | 
123 | class Conv1dLayer : public TorchLayer{
124 |     struct Header{
125 |         int elSize;  //size of each entry in bytes: 4 for float, 2 for fp16.
126 |         int useBias;
127 |         int inChannels;
128 |         int outChannels;
129 |         int kernelSize;
130 |     };
131 | 
132 |     std::vector<Matrixf> weight;
133 |     Vectorf bias;
134 | 
135 |     bool hasBias;
136 |     int inChannels;
137 |     int outChannels;
138 |     int nKernel;
139 | public:
140 |     Conv1dLayer() = default;
141 |     //call TorchLayer loadNext, not derived loadNext
142 |     Conv1dLayer* loadNext( FILE* fd );
143 |     Matrixf apply( const Matrixf& x ) override;
144 |     virtual std::vector<int> shape( void ) const override { return std::vector<int>({inChannels, outChannels, nKernel}); }
145 | };
146 | 
147 | class Conv2dLayer : public TorchLayer{
148 |     struct Header{
149 |         int elSize;  //size of each entry in bytes: 4 for float, 2 for fp16.
150 |         int nKernel;  //kernel size. special case of conv2d used in WaveRNN
151 |     };
152 | 
153 |     Vectorf weight;
154 |     int nKernel;
155 | 
156 | public:
157 |     Conv2dLayer() = default;
158 |     //call TorchLayer loadNext, not derived loadNext
159 |     Conv2dLayer* loadNext( FILE* fd );
160 |     Matrixf apply( const Matrixf& x ) override;
161 |     virtual std::vector<int> shape(void) const override { return std::vector<int>({nKernel}); }
162 | };
163 | 
164 | class BatchNorm1dLayer : public TorchLayer{
165 |     struct Header{
166 |         int elSize;  //size of each entry in bytes: 4 for float, 2 for fp16.
167 |         int inChannels;
168 |         float eps;
169 |     };
170 | 
171 |     Vectorf weight;
172 |     Vectorf bias;
173 |     Vectorf running_mean;
174 |     Vectorf running_var;
175 |     float eps;
176 |     int nChannels;
177 | 
178 | public:
179 |     //call TorchLayer loadNext, not derived loadNext
180 |     BatchNorm1dLayer* loadNext( FILE* fd );
181 | 
182 |     Matrixf apply(const Matrixf &x ) override;
183 |     virtual std::vector<int> shape(void) const override { return std::vector<int>({nChannels}); }
184 | };
185 | 
186 | 
187 | class LinearLayer : public TorchLayer{
188 |     struct Header{
189 |         int elSize;  //size of each entry in bytes: 4 for float, 2 for fp16.
190 |         int nRows;
191 |         int nCols;
192 |     };
193 | 
194 |     CompMatrix mat;
195 |     Vectorf bias;
196 |     int nRows;
197 |     int nCols;
198 | 
199 | 
200 | public:
201 |     LinearLayer() = default;
202 |     //call TorchLayer loadNext, not derived loadNext
203 |     LinearLayer* loadNext( FILE* fd );
204 |     Vectorf apply( const Vectorf& x ) override;
205 |     virtual std::vector<int> shape(void) const override { return std::vector<int>({nRows, nCols}); }
206 | };
207 | 
208 | 
209 | class GRULayer : public TorchLayer{
210 |     struct Header{
211 |         int elSize;  //size of each entry in bytes: 4 for float, 2 for fp16.
212 |         int nHidden;
213 |         int nInput;
214 |     };
215 | 
216 |     CompMatrix W_ir,W_iz,W_in;
217 |     CompMatrix W_hr,W_hz,W_hn;
218 |     Vectorf b_ir,b_iz,b_in;
219 |     Vectorf b_hr,b_hz,b_hn;
220 |     int nRows;
221 |     int nCols;
222 | 
223 | public:
224 |     GRULayer() = default;
225 |     //call TorchLayer loadNext, not derived loadNext
226 |     GRULayer* loadNext( FILE* fd );
227 |     Vectorf apply( const Vectorf& x, const Vectorf& hx ) override;
228 |     virtual std::vector<int> shape(void) const override { return std::vector<int>({nRows, nCols}); }
229 | };
230 | 
231 | class Stretch2dLayer : public TorchLayer{
232 |     struct Header{
233 |         int x_scale;
234 |         int y_scale;
235 |     };
236 | 
237 |     int x_scale;
238 |     int y_scale;
239 | 
240 | public:
241 |     Stretch2dLayer() = default;
242 |     //call TorchLayer loadNext, not derived loadNext
243 |     Stretch2dLayer* loadNext( FILE* fd );
244 |     Matrixf apply(const Matrixf &x ) override;
245 |     virtual std::vector<int> shape(void) const override { return std::vector<int>({0}); }
246 | };
247 | 
248 | 
249 | #endif // WAVERNN_H
250 | 
-------------------------------------------------------------------------------- /loss_function.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | 
4 | 
5 | def nll_loss(y_hat, y, reduce=True):
6 |     # y_hat: [batch, time, classes] log-probabilities; y: [batch, time, 1] integer class targets
7 |     y_hat = y_hat.permute(0,2,1)
8 |     y = y.squeeze(-1)
9 |     loss = F.nll_loss(y_hat, y)
10 |     return loss
11 | 
12 | def test_loss():
13 |     yhat = torch.rand(16, 100, 54)
14 |     y = torch.randint(0, 54, (16, 100, 1))  # nll_loss expects integer class indices, not floats
15 |     loss = nll_loss(yhat, y.squeeze(-1))
-------------------------------------------------------------------------------- /lrschedule.py: --------------------------------------------------------------------------------
1 | # reference: https://raw.githubusercontent.com/r9y9/wavenet_vocoder/master/lrschedule.py
2 | 
3 | import numpy as np
4 | 
5 | 
6 | # https://github.com/tensorflow/tensor2tensor/issues/280#issuecomment-339110329
7 | def noam_learning_rate_decay(init_lr, global_step, warmup_steps=4000):
8 |     # Noam scheme from tensor2tensor:
9 |     warmup_steps = float(warmup_steps)
10 |     step = global_step + 1.
11 |     lr = init_lr * warmup_steps**0.5 * np.minimum(
12 |         step * warmup_steps**-1.5, step**-0.5)
13 |     return lr
14 | 
15 | 
16 | def step_learning_rate_decay(init_lr, global_step,
17 |                              anneal_rate=0.98,
18 |                              anneal_interval=30000):
19 |     return init_lr * anneal_rate ** (global_step // anneal_interval)
20 | 
21 | 
22 | def cyclic_cosine_annealing(init_lr, global_step, T, M):
23 |     """Cyclic cosine annealing
24 | 
25 |     https://arxiv.org/pdf/1704.00109.pdf
26 | 
27 |     Args:
28 |         init_lr (float): Initial learning rate
29 |         global_step (int): Current iteration number
30 |         T (int): Total iteration number (i.e.
nepoch) 31 | M (int): Number of ensembles we want 32 | 33 | Returns: 34 | float: Annealed learning rate 35 | """ 36 | TdivM = T // M 37 | return init_lr / 2.0 * (np.cos(np.pi * ((global_step - 1) % TdivM) / TdivM) + 1.0) 38 | 39 | 40 | def test_noam(): 41 | lr = 1e-3 42 | init_lr = 1e-3 43 | for i in range(50000): 44 | print(i, lr) 45 | lr = noam_learning_rate_decay(init_lr, i) 46 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from hparams import hparams as hp 5 | from torch.utils.data import DataLoader, Dataset 6 | from distributions import * 7 | from utils import num_params, mulaw_quantize, inv_mulaw_quantize 8 | 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | class ResBlock(nn.Module) : 13 | def __init__(self, dims) : 14 | super().__init__() 15 | self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) 16 | self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) 17 | self.batch_norm1 = nn.BatchNorm1d(dims) 18 | self.batch_norm2 = nn.BatchNorm1d(dims) 19 | 20 | def forward(self, x) : 21 | residual = x 22 | x = self.conv1(x) 23 | x = self.batch_norm1(x) 24 | x = F.relu(x) 25 | x = self.conv2(x) 26 | x = self.batch_norm2(x) 27 | return x + residual 28 | 29 | class MelResNet(nn.Module) : 30 | def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims) : 31 | super().__init__() 32 | self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=5, bias=False) 33 | self.batch_norm = nn.BatchNorm1d(compute_dims) 34 | self.layers = nn.ModuleList() 35 | for i in range(res_blocks) : 36 | self.layers.append(ResBlock(compute_dims)) 37 | self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1) 38 | 39 | def forward(self, x) : 40 | x = self.conv_in(x) 41 | x = self.batch_norm(x) 42 | x = F.relu(x) 43 | for f in self.layers : x = f(x) 44 | x = self.conv_out(x) 45 | return x 46 | 47 | class Stretch2d(nn.Module) : 48 | def __init__(self, x_scale, y_scale) : 49 | super().__init__() 50 | self.x_scale = x_scale 51 | self.y_scale = y_scale 52 | 53 | def forward(self, x) : 54 | b, c, h, w = x.size() 55 | x = x.unsqueeze(-1).unsqueeze(3) 56 | x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale) 57 | return x.view(b, c, h * self.y_scale, w * self.x_scale) 58 | 59 | class UpsampleNetwork(nn.Module) : 60 | def __init__(self, feat_dims, upsample_scales, compute_dims, 61 | res_blocks, res_out_dims, pad) : 62 | super().__init__() 63 | total_scale = np.cumproduct(upsample_scales)[-1] 64 | self.indent = pad * total_scale 65 | self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims) 66 | self.resnet_stretch = Stretch2d(total_scale, 1) 67 | self.up_layers = nn.ModuleList() 68 | for scale in upsample_scales : 69 | k_size = (1, scale * 2 + 1) 70 | padding = (0, scale) 71 | stretch = Stretch2d(scale, 1) 72 | conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False) 73 | conv.weight.data.fill_(1. 
/ k_size[1])
74 |             self.up_layers.append(stretch)
75 |             self.up_layers.append(conv)
76 | 
77 |     def forward(self, m) :
78 |         aux = self.resnet(m).unsqueeze(1)
79 |         aux = self.resnet_stretch(aux)
80 |         aux = aux.squeeze(1)
81 |         m = m.unsqueeze(1)
82 |         for f in self.up_layers:
83 |             m = f(m)
84 |         m = m.squeeze(1)[:, :, self.indent:-self.indent]
85 |         return m.transpose(1, 2), aux.transpose(1, 2)
86 | 
87 | 
88 | class Model(nn.Module) :
89 |     def __init__(self, rnn_dims, fc_dims, bits, pad, upsample_factors,
90 |                  feat_dims, compute_dims, res_out_dims, res_blocks):
91 |         super().__init__()
92 |         if hp.input_type == 'raw':
93 |             self.n_classes = 2
94 |         elif hp.input_type == 'mixture':
95 |             # mixture requires multiple of 3, default at 10 component mixture, i.e. 3 x 10 = 30
96 |             self.n_classes = 30
97 |         elif hp.input_type == 'mulaw':
98 |             self.n_classes = hp.mulaw_quantize_channels
99 |         elif hp.input_type == 'bits':
100 |             self.n_classes = 2**bits
101 |         else:
102 |             raise ValueError(f"input_type: {hp.input_type} not supported")
103 |         self.rnn_dims = rnn_dims
104 |         self.aux_dims = res_out_dims // 2
105 |         self.upsample = UpsampleNetwork(feat_dims, upsample_factors, compute_dims,
106 |                                         res_blocks, res_out_dims, pad)
107 |         self.I = nn.Linear(feat_dims + self.aux_dims - 1 + 1, rnn_dims) #First dimension has to be divisible by 8, so we take away one aux channel
108 |         self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
109 | 
110 |         self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
111 |         #self.fc2 = nn.Linear(fc_dims, fc_dims)
112 |         self.fc3 = nn.Linear(fc_dims, self.n_classes)
113 |         num_params(self)
114 | 
115 |     def forward(self, x, mels) :
116 |         bsize = x.size(0)
117 |         h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
118 | 
119 |         mels, aux = self.upsample(mels)
120 | 
121 |         aux_idx = [self.aux_dims * i for i in range(3)]
122 |         a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
123 |         a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
124 | 
125 |         x = torch.cat([x.unsqueeze(-1), mels, a1[:,:,:-1]], dim=2)
126 |         x = self.I(x)
127 |         res = x
128 |         x, _ = self.rnn1(x, h1)
129 |         x = x + res
130 | 
131 |         x = torch.cat([x, a2], dim=2)
132 |         x = F.relu(self.fc1(x))
133 |         #x = F.relu(self.fc2(x))
134 | 
135 |         x = self.fc3(x)
136 | 
137 |         if hp.input_type == 'raw':
138 |             return x
139 |         elif hp.input_type == 'mixture':
140 |             return x
141 |         elif hp.input_type == 'bits' or hp.input_type == 'mulaw':
142 |             return F.log_softmax(x, dim=-1)
143 |         else:
144 |             raise ValueError(f"input_type: {hp.input_type} not supported")
145 | 
146 | 
147 |     def preview_upsampling(self, mels) :
148 |         mels, aux = self.upsample(mels)
149 |         return mels, aux
150 | 
151 |     # def generate(self, mels) :
152 |     #     self.eval()
153 |     #     output = []
154 |     #     rnn1 = self.get_gru_cell(self.rnn1)
155 |     #     rnn2 = self.get_gru_cell(self.rnn2)
156 |     #
157 |     #     with torch.no_grad() :
158 |     #         x = torch.zeros(1, 1)
159 |     #         h1 = torch.zeros(1, self.rnn_dims)
160 |     #         h2 = torch.zeros(1, self.rnn_dims)
161 |     #
162 |     #         mels = torch.FloatTensor(mels).unsqueeze(0)
163 |     #         mels, aux = self.upsample(mels)
164 |     #
165 |     #         aux_idx = [self.aux_dims * i for i in range(5)]
166 |     #         a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
167 |     #         a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
168 |     #         a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
169 |     #         a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
170 |     #
171 |     #         seq_len = mels.size(1)
172 |     #
173 |     #         for i in tqdm(range(seq_len)) :
174 |     #
175 |     #             m_t = mels[:, i, :]
176 |     #             a1_t = a1[:, i, :]
177 |     #             a2_t = a2[:, i, :]
178 |     #             a3_t = a3[:, i, :]
179 |     #             a4_t = a4[:, i, :]
180 |     #
181 |     #             x = torch.cat([x, m_t,
a1_t], dim=1) 182 | # x = self.I(x) 183 | # h1 = rnn1(x, h1) 184 | # 185 | # x = x + h1 186 | # inp = torch.cat([x, a2_t], dim=1) 187 | # h2 = rnn2(inp, h2) 188 | # 189 | # x = x + h2 190 | # x = torch.cat([x, a3_t], dim=1) 191 | # x = F.relu(self.fc1(x)) 192 | # 193 | # x = torch.cat([x, a4_t], dim=1) 194 | # x = F.relu(self.fc2(x)) 195 | # x = self.fc3(x) 196 | # if hp.input_type == 'raw': 197 | # if hp.distribution == 'beta': 198 | # sample = sample_from_beta_dist(x.unsqueeze(0)) 199 | # elif hp.distribution == 'gaussian': 200 | # sample = sample_from_gaussian(x.unsqueeze(0)) 201 | # elif hp.input_type == 'mixture': 202 | # sample = sample_from_discretized_mix_logistic(x.unsqueeze(-1),hp.log_scale_min) 203 | # elif hp.input_type == 'bits': 204 | # posterior = F.softmax(x, dim=1).view(-1) 205 | # distrib = torch.distributions.Categorical(posterior) 206 | # sample = 2 * distrib.sample().float() / (self.n_classes - 1.) - 1. 207 | # elif hp.input_type == 'mulaw': 208 | # posterior = F.softmax(x, dim=1).view(-1) 209 | # distrib = torch.distributions.Categorical(posterior) 210 | # sample = inv_mulaw_quantize(distrib.sample(), hp.mulaw_quantize_channels, True) 211 | # output.append(sample.view(-1)) 212 | # x = torch.FloatTensor([[sample]]) 213 | # output = torch.stack(output).cpu().numpy() 214 | # self.train() 215 | # return output 216 | 217 | 218 | def pad_tensor(self, x, pad, side='both') : 219 | # NB - this is just a quick method i need right now 220 | # i.e., it won't generalise to other shapes/dims 221 | b, t, c = x.size() 222 | total = t + 2 * pad if side == 'both' else t + pad 223 | padded = torch.zeros(b, total, c).cuda() 224 | if side == 'before' or side == 'both' : 225 | padded[:, pad:pad+t, :] = x 226 | elif side == 'after': 227 | padded[:, :t, :] = x 228 | return padded 229 | 230 | 231 | def fold_with_overlap(self, x, target, overlap) : 232 | 233 | ''' Fold the tensor with overlap for quick batched inference. 234 | Overlap will be used for crossfading in xfade_and_unfold() 235 | 236 | Args: 237 | x (tensor) : Upsampled conditioning features. 238 | shape=(1, timesteps, features) 239 | target (int) : Target timesteps for each index of batch 240 | overlap (int) : Timesteps for both xfade and rnn warmup 241 | 242 | Return: 243 | (tensor) : shape=(num_folds, target + 2 * overlap, features) 244 | 245 | Details: 246 | x = [[h1, h2, ... hn]] 247 | 248 | Where each h is a vector of conditioning features 249 | 250 | Eg: target=2, overlap=1 with x.size(1)=10 251 | 252 | folded = [[h1, h2, h3, h4], 253 | [h4, h5, h6, h7], 254 | [h7, h8, h9, h10]] 255 | ''' 256 | 257 | _, total_len, features = x.size() 258 | 259 | # Calculate variables needed 260 | num_folds = (total_len - overlap) // (target + overlap) 261 | extended_len = num_folds * (overlap + target) + overlap 262 | remaining = total_len - extended_len 263 | 264 | # Pad if some time steps poking out 265 | if remaining != 0 : 266 | num_folds += 1 267 | padding = target + 2 * overlap - remaining 268 | x = self.pad_tensor(x, padding, side='after') 269 | 270 | folded = torch.zeros(num_folds, target + 2 * overlap, features).cuda() 271 | 272 | # Get the values for the folded tensor 273 | for i in range(num_folds) : 274 | start = i * (target + overlap) 275 | end = start + target + 2 * overlap 276 | folded[i] = x[:, start:end, :] 277 | 278 | return folded 279 | 280 | 281 | def xfade_and_unfold(self, y, target, overlap) : 282 | 283 | ''' Applies a crossfade and unfolds into a 1d array. 
284 | 285 | Args: 286 | y (ndarry) : Batched sequences of audio samples 287 | shape=(num_folds, target + 2 * overlap) 288 | dtype=np.float64 289 | overlap (int) : Timesteps for both xfade and rnn warmup 290 | 291 | Return: 292 | (ndarry) : audio samples in a 1d array 293 | shape=(total_len) 294 | dtype=np.float64 295 | 296 | Details: 297 | y = [[seq1], 298 | [seq2], 299 | [seq3]] 300 | 301 | Apply a gain envelope at both ends of the sequences 302 | 303 | y = [[seq1_in, seq1_target, seq1_out], 304 | [seq2_in, seq2_target, seq2_out], 305 | [seq3_in, seq3_target, seq3_out]] 306 | 307 | Stagger and add up the groups of samples: 308 | 309 | [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...] 310 | 311 | ''' 312 | 313 | num_folds, length = y.shape 314 | target = length - 2 * overlap 315 | total_len = num_folds * (target + overlap) + overlap 316 | 317 | # Need some silence for the rnn warmup 318 | silence_len = overlap // 2 319 | fade_len = overlap - silence_len 320 | silence = np.zeros((silence_len)) 321 | 322 | # Equal power crossfade 323 | t = np.linspace(-1, 1, fade_len) 324 | fade_in = np.sqrt(0.5 * (1 + t)) 325 | fade_out = np.sqrt(0.5 * (1 - t)) 326 | 327 | # Concat the silence to the fades 328 | fade_in = np.concatenate([silence, fade_in]) 329 | fade_out = np.concatenate([fade_out, silence]) 330 | 331 | # Apply the gain to the overlap samples 332 | y[:, :overlap] *= fade_in 333 | y[:, -overlap:] *= fade_out 334 | 335 | unfolded = np.zeros((total_len)) 336 | 337 | # Loop to add up all the samples 338 | for i in range(num_folds ) : 339 | start = i * (target + overlap) 340 | end = start + target + 2 * overlap 341 | unfolded[start:end] += y[i] 342 | 343 | return unfolded 344 | 345 | def generate(self, mels, target=11000, overlap=550, batched=True): 346 | 347 | self.eval() 348 | output = [] 349 | 350 | rnn1 = self.get_gru_cell(self.rnn1) 351 | 352 | with torch.no_grad(): 353 | mels = torch.FloatTensor(mels).cuda().unsqueeze(0) 354 | mels = self.pad_tensor(mels.transpose(1, 2), pad=hp.pad, side='both') 355 | 356 | mels, aux = self.upsample(mels.transpose(1, 2)) 357 | 358 | if batched: 359 | mels = self.fold_with_overlap(mels, target, overlap) 360 | aux = self.fold_with_overlap(aux, target, overlap) 361 | 362 | b_size, seq_len, _ = mels.size() 363 | 364 | h1 = torch.zeros(b_size, self.rnn_dims).cuda() 365 | 366 | x = torch.zeros(b_size, 1).cuda() 367 | 368 | d = self.aux_dims 369 | aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(2)] 370 | 371 | for i in range(seq_len): 372 | 373 | m_t = mels[:, i, :] 374 | 375 | a1_t, a2_t = \ 376 | (a[:, i, :] for a in aux_split) 377 | 378 | x = torch.cat([x, m_t, a1_t[:,:-1]], dim=1) 379 | x = self.I(x) 380 | h1 = rnn1(x, h1) 381 | 382 | x = x + h1 383 | x = torch.cat([x, a2_t], dim=1) 384 | x = F.relu(self.fc1(x)) 385 | #x = F.relu(self.fc2(x)) 386 | x = self.fc3(x) 387 | 388 | if hp.input_type == 'raw': 389 | sample = sample_from_beta_dist(x.unsqueeze(0)).view(-1) 390 | elif hp.input_type == 'mixture': 391 | sample = sample_from_discretized_mix_logistic(x.unsqueeze(-1),hp.log_scale_min) 392 | elif hp.input_type == 'bits': 393 | posterior = F.softmax(x, dim=1) 394 | distrib = torch.distributions.Categorical(posterior) 395 | sample = 2 * distrib.sample().float() / (self.n_classes - 1.) - 1. 
396 | elif hp.input_type == 'mulaw': 397 | posterior = F.softmax(x, dim=1) 398 | distrib = torch.distributions.Categorical(posterior) 399 | sample = inv_mulaw_quantize(distrib.sample(), hp.mulaw_quantize_channels, True) 400 | 401 | output.append(sample) 402 | x = sample.unsqueeze(-1) 403 | 404 | output = torch.stack(output).transpose(0, 1) 405 | output = output.cpu().numpy() 406 | 407 | if batched: 408 | output = self.xfade_and_unfold(output, target, overlap) 409 | else: 410 | output = output[0] 411 | 412 | self.train() 413 | return output 414 | 415 | def batch_generate(self, mels) : 416 | """mel should be of shape [batch_size x 80 x mel_length] 417 | """ 418 | self.eval() 419 | output = [] 420 | rnn1 = self.get_gru_cell(self.rnn1) 421 | #rnn2 = self.get_gru_cell(self.rnn2) 422 | b_size = mels.shape[0] 423 | assert len(mels.shape) == 3, "mels should have shape [batch_size x 80 x mel_length]" 424 | 425 | with torch.no_grad() : 426 | x = torch.zeros(b_size, 1).cuda() 427 | h1 = torch.zeros(b_size, self.rnn_dims).cuda() 428 | 429 | mels = torch.FloatTensor(mels).cuda() 430 | mels, aux = self.upsample(mels) 431 | 432 | aux_idx = [self.aux_dims * i for i in range(3)] 433 | a1 = aux[:, :, aux_idx[0]:aux_idx[1]] 434 | a2 = aux[:, :, aux_idx[1]:aux_idx[2]] 435 | 436 | seq_len = mels.size(1) 437 | 438 | for i in tqdm(range(seq_len)) : 439 | 440 | m_t = mels[:, i, :] 441 | a1_t = a1[:, i, :] 442 | a2_t = a2[:, i, :] 443 | 444 | 445 | x = torch.cat([x, m_t, a1_t[:,:-1]], dim=1) 446 | x = self.I(x) 447 | h1 = rnn1(x, h1) 448 | 449 | x = x + h1 450 | x = torch.cat([x, a2_t], dim=1) 451 | 452 | x = F.relu(self.fc1(x)) 453 | #x = F.relu(self.fc2(x)) 454 | x = self.fc3(x) 455 | 456 | if hp.input_type == 'raw': 457 | sample = sample_from_beta_dist(x.unsqueeze(0)) 458 | elif hp.input_type == 'mixture': 459 | sample = sample_from_discretized_mix_logistic(x.unsqueeze(-1),hp.log_scale_min) 460 | elif hp.input_type == 'bits': 461 | posterior = F.softmax(x, dim=1).view(b_size, -1) 462 | distrib = torch.distributions.Categorical(posterior) 463 | sample = 2 * distrib.sample().float() / (self.n_classes - 1.) - 1. 
464 |                 elif hp.input_type == 'mulaw':
465 |                     posterior = F.softmax(x, dim=1).view(b_size, -1)
466 |                     distrib = torch.distributions.Categorical(posterior)
467 |                     sample = inv_mulaw_quantize(distrib.sample(), hp.mulaw_quantize_channels, True)
468 |                 output.append(sample.view(-1))
469 |                 x = sample.view(b_size,1)
470 |         output = torch.stack(output).cpu().numpy()
471 |         self.train()
472 |         # output is a batch of wav segments of shape [batch_size x seq_len]
473 |         # will need to merge into one wav of size [batch_size * seq_len]
474 |         assert output.shape[1] == b_size
475 |         output = (output.swapaxes(1,0)).reshape(-1)
476 |         return output
477 | 
478 |     def get_gru_cell(self, gru) :
479 |         gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
480 |         gru_cell.weight_hh.data = gru.weight_hh_l0.data
481 |         gru_cell.weight_ih.data = gru.weight_ih_l0.data
482 |         gru_cell.bias_hh.data = gru.bias_hh_l0.data
483 |         gru_cell.bias_ih.data = gru.bias_ih_l0.data
484 |         return gru_cell
485 | 
486 | 
487 | def build_model():
488 |     """build model with hparams settings
489 | 
490 |     """
491 |     if hp.input_type == 'raw':
492 |         print('building model with Beta distribution output')
493 |     elif hp.input_type == 'mixture':
494 |         print("building model with mixture of logistic output")
495 |     elif hp.input_type == 'bits':
496 |         print("building model with quantized bit audio")
497 |     elif hp.input_type == 'mulaw':
498 |         print("building model with quantized mulaw encoding")
499 |     else:
500 |         raise ValueError('input_type provided not supported')
501 |     model = Model(hp.rnn_dims, hp.fc_dims, hp.bits,
502 |                   hp.pad, hp.upsample_factors, hp.num_mels,
503 |                   hp.compute_dims, hp.res_out_dims, hp.res_blocks)
504 | 
505 |     return model
506 | 
507 | def no_test_build_model():
508 |     model = Model(hp.rnn_dims, hp.fc_dims, hp.bits,
509 |                   hp.pad, hp.upsample_factors, hp.num_mels,
510 |                   hp.compute_dims, hp.res_out_dims, hp.res_blocks).cuda()
511 |     print(vars(model))
512 | 
513 | 
514 | def test_batch_generate():
515 |     model = Model(hp.rnn_dims, hp.fc_dims, hp.bits,
516 |                   hp.pad, hp.upsample_factors, hp.num_mels,
517 |                   hp.compute_dims, hp.res_out_dims, hp.res_blocks).cuda()
518 |     print(vars(model))
519 |     batch_mel = torch.rand(3, 80, 100)
520 |     output = model.batch_generate(batch_mel)
521 |     print(output.shape)
-------------------------------------------------------------------------------- /model_outputs/mel-northandsouth_52_f000076.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geneing/WaveRNN-Pytorch/7b317c4d930ad8b7405e72c5ace7b6481bdc6f2b/model_outputs/mel-northandsouth_52_f000076.npy -------------------------------------------------------------------------------- /model_outputs/mel-northandsouth_52_f000076.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geneing/WaveRNN-Pytorch/7b317c4d930ad8b7405e72c5ace7b6481bdc6f2b/model_outputs/mel-northandsouth_52_f000076.wav -------------------------------------------------------------------------------- /model_outputs/model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geneing/WaveRNN-Pytorch/7b317c4d930ad8b7405e72c5ace7b6481bdc6f2b/model_outputs/model.bin -------------------------------------------------------------------------------- /preprocess.py: --------------------------------------------------------------------------------
1 | """
2 | Preprocess dataset
3 | 
4 | usage:
5 |     preprocess.py [options] <wav-dir>...
6 | 
7 | options:
8 |     --output-dir=<dir>      Directory where processed outputs are saved. [default: data_dir].
9 |     -h, --help              Show help message.
10 | """
11 | import os
12 | from docopt import docopt
13 | import numpy as np
14 | import math, pickle, os
15 | from audio import *
16 | from hparams import hparams as hp
17 | from utils import *
18 | from tqdm import tqdm
19 | 
20 | def get_wav_mel(path):
21 |     """Given path to .wav file, get the quantized wav and mel spectrogram as numpy vectors
22 | 
23 |     """
24 |     wav = load_wav(path)
25 |     mel = melspectrogram(wav)
26 |     if hp.input_type == 'raw' or hp.input_type=='mixture':
27 |         return wav.astype(np.float32), mel
28 |     elif hp.input_type == 'mulaw':
29 |         quant = mulaw_quantize(wav, hp.mulaw_quantize_channels)
30 |         return quant.astype(np.int), mel
31 |     elif hp.input_type == 'bits':
32 |         quant = quantize(wav)
33 |         return quant.astype(np.int), mel
34 |     else:
35 |         raise ValueError("hp.input_type {} not recognized".format(hp.input_type))
36 | 
37 | 
38 | def process_data(wav_dirs, output_path, mel_path, wav_path):
39 |     """
40 |     given wav directories and output directory, process wav files and save quantized wav and mel
41 |     spectrogram to output directory
42 |     """
43 |     dataset_ids = []
44 |     # get list of wav files
45 |     wav_files=[]
46 |     for wav_dir in wav_dirs:
47 |         thisdir = os.listdir(wav_dir)
48 |         thisdir=[ os.path.join(wav_dir, thisfile) for thisfile in thisdir]
49 |         wav_files += thisdir
50 | 
51 |     # check wav_file
52 |     assert len(wav_files) != 0 and wav_files[0][-4:] == '.wav', "no wav files found!"
53 |     # create training and testing splits
54 |     test_wav_files = wav_files[:4]
55 |     wav_files = wav_files[4:]
56 |     for i, wav_file in enumerate(tqdm(wav_files)):
57 |         # get the file id
58 |         file_id = '{:d}'.format(i).zfill(5)
59 |         wav, mel = get_wav_mel(wav_file)  # wav_file already carries its directory prefix
60 |         # save
61 |         np.save(os.path.join(mel_path,file_id+".npy"), mel)
62 |         np.save(os.path.join(wav_path,file_id+".npy"), wav)
63 |         # add to dataset_ids
64 |         dataset_ids.append(file_id)
65 | 
66 |     # save dataset_ids
67 |     with open(os.path.join(output_path,'dataset_ids.pkl'), 'wb') as f:
68 |         pickle.dump(dataset_ids, f)
69 | 
70 |     # process testing_wavs
71 |     test_path = os.path.join(output_path,'test')
72 |     os.makedirs(test_path, exist_ok=True)
73 |     for i, wav_file in enumerate(test_wav_files):
74 |         wav, mel = get_wav_mel(wav_file)
75 |         # save test_wavs
76 |         np.save(os.path.join(test_path,"test_{}_mel.npy".format(i)),mel)
77 |         np.save(os.path.join(test_path,"test_{}_wav.npy".format(i)),wav)
78 | 
79 | 
80 |     print("\npreprocessing done, total processed wav files:{}.\nProcessed files are located in:{}".format(len(wav_files), os.path.abspath(output_path)))
81 | 
82 | 
83 | 
84 | if __name__=="__main__":
85 |     args = docopt(__doc__)
86 |     wav_dir = args["<wav-dir>"]
87 |     output_dir = args["--output-dir"]
88 | 
89 |     # create paths
90 |     output_path = os.path.join(output_dir,"")
91 |     mel_path = os.path.join(output_dir,"mel")
92 |     wav_path = os.path.join(output_dir,"wav")
93 | 
94 |     # create dirs
95 |     os.makedirs(output_path, exist_ok=True)
96 |     os.makedirs(mel_path, exist_ok=True)
97 |     os.makedirs(wav_path, exist_ok=True)
98 | 
99 |     # process data
100 |     process_data(wav_dir, output_path, mel_path, wav_path)
101 | 
102 | 
103 | 
104 | def test_get_wav_mel():
105 |     wav, mel = get_wav_mel('sample.wav')
106 |     print(wav.shape, mel.shape)
107 |     print(wav)
-------------------------------------------------------------------------------- /requirements.txt:
--------------------------------------------------------------------------------
1 | docopt
2 | librosa
3 | nnmnkwii
4 | tqdm
5 | lws
6 | scipy
7 | numpy
8 | tensorboardx
9 | matplotlib
10 | 
-------------------------------------------------------------------------------- /synthesize.py: --------------------------------------------------------------------------------
1 | """Synthesis script for WaveRNN vocoder
2 | 
3 | usage: synthesize.py [options] <mel_input.npy>
4 | 
5 | options:
6 |     --checkpoint-dir=<dir>       Directory where model checkpoint is saved [default: checkpoints].
7 |     --output-dir=<dir>           Output directory [default: model_outputs]
8 |     --hparams=<params>           Hyper parameters [default: ].
9 |     --preset=<json>              Path of preset parameters (json).
10 |     --checkpoint=<path>          Restore model from checkpoint path if given.
11 |     --no-cuda                    Don't run on GPU
12 |     -h, --help                   Show this help message and exit
13 | """
14 | import os
15 | import librosa
16 | import glob
17 | 
18 | from docopt import docopt
19 | from model import *
20 | from hparams import hparams
21 | from utils import num_params_count
22 | import pickle
23 | import time
24 | import numpy as np
25 | import scipy as sp
26 | 
27 | 
28 | 
29 | if __name__ == "__main__":
30 |     args = docopt(__doc__)
31 |     print("Command line args:\n", args)
32 |     checkpoint_dir = args["--checkpoint-dir"]
33 |     output_path = args["--output-dir"]
34 |     checkpoint_path = args["--checkpoint"]
35 |     preset = args["--preset"]
36 |     no_cuda = args["--no-cuda"]
37 | 
38 |     device = torch.device("cpu" if no_cuda else "cuda")
39 |     print("using device:{}".format(device))
40 | 
41 |     # Load preset if specified
42 |     if preset is not None:
43 |         with open(preset) as f:
44 |             hparams.parse_json(f.read())
45 |     # Override hyper parameters
46 |     hparams.parse(args["--hparams"])
47 | 
48 |     mel_file_name = args['<mel_input.npy>']
49 |     mel = np.load(mel_file_name)
50 |     if mel.shape[0] > mel.shape[1]: #ugly hack for transposed mels
51 |         mel = mel.T
52 | 
53 |     if checkpoint_path is None:
54 |         flist = glob.glob(f'{checkpoint_dir}/checkpoint_*.pth')
55 |         latest_checkpoint = max(flist, key=os.path.getctime)
56 |     else:
57 |         latest_checkpoint = checkpoint_path
58 |     print('Loading: %s'%latest_checkpoint)
59 |     # build model, create optimizer
60 |     model = build_model().to(device)
61 |     checkpoint = torch.load(latest_checkpoint, map_location=device)
62 |     model.load_state_dict(checkpoint["state_dict"])
63 | 
64 |     print("I: %.3f million"%(num_params_count(model.I)))
65 |     print("Upsample: %.3f million"%(num_params_count(model.upsample)))
66 |     print("rnn1: %.3f million"%(num_params_count(model.rnn1)))
67 |     #print("rnn2: %.3f million"%(num_params_count(model.rnn2)))
68 |     print("fc1: %.3f million"%(num_params_count(model.fc1)))
69 |     #print("fc2: %.3f million"%(num_params_count(model.fc2)))
70 |     print("fc3: %.3f million"%(num_params_count(model.fc3)))
71 | 
72 | 
73 |     #onnx export
74 |     model.train(False)
75 |     #wav = np.load('WaveRNN-Pytorch/checkpoint/test_0_wav.npy')
76 | 
77 |     #doesn't work torch.onnx.export(model, (torch.tensor(wav),torch.tensor(mel)), checkpoint_dir+'/wavernn.onnx', verbose=True, input_names=['mel_input'], output_names=['wav_output'])
78 | 
79 | 
80 |     #mel = np.pad(mel,(24000,0),'constant')
81 |     # n_mels = mel.shape[1]
82 |     # n_mels = hparams.batch_size_gen * (n_mels // hparams.batch_size_gen)
83 |     # mel = mel[:, 0:n_mels]
84 | 
85 | 
86 |     mel0 = mel.copy()
87 |     mel0=np.hstack([np.ones([80,40])*(-4), mel0, np.ones([80,40])*(-4)])
88 |     start = time.time()
89 |     output0 = model.generate(mel0, batched=False, target=2000, overlap=64)
90 |     total_time = time.time() - start
91 |     frag_time = len(output0) / hparams.sample_rate
92 |     print("Generation time: {}. Sound time: {}, ratio: {}".format(total_time, frag_time, frag_time/total_time))
93 | 
94 |     librosa.output.write_wav(os.path.join(output_path, os.path.basename(mel_file_name)+'_orig.wav'), output0, hparams.sample_rate)
95 | 
96 |     #mel = mel.reshape([mel.shape[0], hparams.batch_size_gen, -1]).swapaxes(0,1)
97 |     #output, out1 = model.batch_generate(mel)
98 |     #bootstrap_len = hp.hop_size * hp.resnet_pad
99 |     #output=output[:,bootstrap_len:].reshape(-1)
100 |     # librosa.output.write_wav(os.path.join(output_path, os.path.basename(mel_file_name)+'.wav'), output, hparams.sample_rate)
101 |     with open(os.path.join(output_path, os.path.basename(mel_file_name)+'.pkl'), 'wb') as f:
102 |         pickle.dump((output0,), f)
103 |     print('done')
104 | 
-------------------------------------------------------------------------------- /test_wavernnvocoder.py: --------------------------------------------------------------------------------
1 | import sys
2 | import glob
3 | 
4 | from scipy.io.wavfile import write
5 | 
6 | 
7 | sys.path.insert(0,'lib/build-src-RelDebInfo')
8 | sys.path.insert(0,'library/build-src-Desktop-RelWithDebInfo')
9 | import WaveRNNVocoder
10 | import numpy as np
11 | 
12 | vocoder=WaveRNNVocoder.Vocoder()
13 | 
14 | vocoder.loadWeights('model_outputs/model.bin')
15 | 
16 | # mel_file='../TrainingData/LJSpeech-1.0.wavernn/mel/00001.npy'
17 | # mel1 = np.load(mel_file)
18 | # mel1 = mel1.astype('float32')
19 | # wav=vocoder.melToWav(mel)
20 | # print()
21 | 
22 | filelist = glob.glob('eval/mel*.npy')
23 | 
24 | for fname in filelist:
25 |     mel = np.load(fname).T
26 |     wav = vocoder.melToWav(mel)
27 |     break
28 | 
29 | #scaled = np.int16(wav/np.max(np.abs(wav)) * 32767)
30 | write('test.wav',16000, wav)
31 | 
32 | print()
33 | 
34 | fnames=['inputs/00000.npy','inputs/mel-northandsouth_01_f000001.npy']
35 | mel0=np.load(fnames[0])
36 | mel1=np.load(fnames[1]).T
37 | mel2=np.load(filelist[0]).T
38 | 
-------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | """Training WaveRNN Model.
2 | 
3 | usage: train.py [options] <data-root>
4 | 
5 | options:
6 |     --checkpoint-dir=<dir>      Directory where to save model checkpoints [default: checkpoints].
7 |     --checkpoint=<path>         Restore model from checkpoint path if given.
8 |     --log-event-path=<path>     Path to tensorboard event log
9 |     --dataset=<dataset>         Dataset type: Tacotron, TTS, Audiobooks [default: Tacotron].
10 |     -h, --help                  Show this help message and exit
11 | """
12 | import os
13 | 
14 | import librosa
15 | import matplotlib
16 | matplotlib.use('Agg')
17 | import matplotlib.pyplot as plt
18 | from datetime import datetime
19 | from docopt import docopt
20 | from os.path import join
21 | from tensorboardX import SummaryWriter
22 | from torch import nn
23 | from torch import optim
24 | from torch.utils.data import DataLoader
25 | from tqdm import tqdm
26 | 
27 | from dataset import raw_collate, discrete_collate, AudiobookDataset, TacotronDataset, MozillaTTS
28 | from distributions import *
29 | from hparams import hparams as hp
30 | from loss_function import nll_loss
31 | from lrschedule import noam_learning_rate_decay, step_learning_rate_decay
32 | from model import build_model
33 | from utils import num_params_count
34 | 
35 | global_step = 0
36 | global_epoch = 0
37 | global_test_step = 0
38 | use_cuda = torch.cuda.is_available()
39 | 
40 | 
41 | def np_now(tensor):
42 |     return tensor.detach().cpu().numpy()
43 | 
44 | 
45 | def clamp(x, lo=0, hi=1):
46 |     return max(lo, min(hi, x))
47 | 
48 | 
49 | class PruneMask():
50 |     def __init__(self, layer, prune_rnn_input):
51 |         self.mask = []
52 |         self.p_idx = [0]
53 |         self.total_params = 0
54 |         self.pruned_params = 0
55 |         self.split_size = 0
56 |         self.init_mask(layer, prune_rnn_input)
57 | 
58 |     def init_mask(self, layer, prune_rnn_input):
59 |         # Determine the layer type and
60 |         # num matrix splits if rnn
61 |         layer_type = str(layer).split('(')[0]
62 |         splits = {'Linear': 1, 'GRU': 3, 'LSTM': 4}
63 | 
64 |         # Organise the num and indices of layer parameters
65 |         # Dense will have one index and rnns two (if pruning input)
66 |         if layer_type != 'Linear':
67 |             self.p_idx = [0, 1] if prune_rnn_input else [1]
68 | 
69 |         # Get list of parameters from layers
70 |         params = self.get_params(layer)
71 | 
72 |         # For each param matrix in this layer, create a mask
73 |         for W in params:
74 |             self.mask += [torch.ones_like(W)]
75 |             self.total_params += W.size(0) * W.size(1)
76 | 
77 |         # Need a split size for mask_from_matrix() later on
78 |         self.split_size = self.mask[0].size(0) // splits[layer_type]
79 | 
80 |     def get_params(self, layer):
81 |         params = []
82 |         for idx in self.p_idx:
83 |             params += [list(layer.parameters())[idx].data]
84 |         return params
85 | 
86 |     def update_mask(self, layer, z):
87 |         params = self.get_params(layer)
88 |         for i, W in enumerate(params):
89 |             self.mask[i] = self.mask_from_matrix(W, z)
90 |         self.update_prune_count()
91 | 
92 |     def apply_mask(self, layer):
93 |         params = self.get_params(layer)
94 |         for M, W in zip(self.mask, params):
95 |             W *= M
96 | 
97 |     def mask_from_matrix(self, W, z):
98 |         # Split into gate matrices (or not)
99 |         if self.split_size>1:
100 |             W_split = torch.split(W, self.split_size)
101 |         else:
102 |             W_split = (W,)  # one-element tuple, so the loop below sees the whole matrix
103 | 
104 |         M = []
105 |         # Loop through splits
106 |         for W in W_split:
107 |             # Sort the magnitudes
108 |             N = W.shape[1]
109 | 
110 |             W_abs = torch.abs(W)
111 |             L = W_abs.reshape(W.shape[0], N // hp.sparse_group, hp.sparse_group)
112 |             S = L.sum(dim=2)
113 |             sorted_abs, _ = torch.sort(S.view(-1))
114 | 
115 |             # Pick k (num weight groups to zero)
116 |             k = int(W.shape[0] * W.shape[1] // hp.sparse_group * z)
117 |             threshold = sorted_abs[k]
118 |             mask = (S >= threshold).float()
119 |             mask = mask.unsqueeze(2).expand(-1,-1,hp.sparse_group)
120 |             mask = mask.reshape(W.shape[0], W.shape[1])
121 | 
122 |             # Create the mask
123 |             M += [mask]
124 | 
125 |         return torch.cat(M)
126 | 
127 |     def update_prune_count(self):
128 |         self.pruned_params = 0
129 |         for M in self.mask:
130 |             self.pruned_params += int(np_now((M - 1).sum() * -1))
131 | 
132 | 
133 | class Pruner(object):
134 |     def __init__(self, layers, start_prune, prune_steps, target_sparsity,
135 |                  prune_rnn_input=True):
136 |         self.z = 0  # Object's sparsity at time t
137 |         self.t_0 = start_prune
138 |         self.S = prune_steps
139 |         self.Z = target_sparsity
140 |         self.num_pruned = 0
141 |         self.total_params = 0
142 |         self.masks = []
143 |         self.layers = layers
144 |         for (layer,z) in layers:
145 |             self.masks += [PruneMask(layer, prune_rnn_input)]
146 |         self.count_total_params()
147 | 
148 |     def update_sparsity(self, t, Z):
149 |         z = Z * (1 - (1 - (t - self.t_0) / self.S) ** 3)
150 |         z = clamp(z, 0, Z)
151 |         return z
152 | 
153 |     def prune(self, step):
154 |         for ((l,z), m) in zip(self.layers, self.masks):
155 |             z_curr = self.update_sparsity(step, z)
156 |             m.update_mask(l, z_curr)
157 |             m.apply_mask(l)
158 |         return self.count_num_pruned(), z_curr
159 | 
160 |     def restart(self, layers, step):
161 |         # In case training is stopped and restarted:
162 |         # recompute the masks at the current step for each layer
163 |         for ((l, z), m) in zip(layers, self.masks):
164 |             z_curr = self.update_sparsity(step, z)
165 |             m.update_mask(l, z_curr)
166 | 
167 |     def count_num_pruned(self):
168 |         self.num_pruned = 0
169 |         for m in self.masks:
170 |             self.num_pruned += m.pruned_params
171 |         return self.num_pruned
172 | 
173 |     def count_total_params(self):
174 |         for m in self.masks:
175 |             self.total_params += m.total_params
176 |         return self.total_params
177 | 
178 | def save_checkpoint(device, model, optimizer, step, checkpoint_dir, epoch):
179 |     checkpoint_path = join(
180 |         checkpoint_dir, "checkpoint_step{:09d}.pth".format(step))
181 |     optimizer_state = optimizer.state_dict()
182 |     global global_test_step
183 |     torch.save({
184 |         "state_dict": model.state_dict(),
185 |         "optimizer": optimizer_state,
186 |         "global_step": step,
187 |         "global_epoch": epoch,
188 |         "global_test_step": global_test_step,
189 |     }, checkpoint_path)
190 |     print("Saved checkpoint:", checkpoint_path)
191 | 
192 | 
193 | def _load(checkpoint_path):
194 |     if use_cuda:
195 |         checkpoint = torch.load(checkpoint_path)
196 |     else:
197 |         checkpoint = torch.load(checkpoint_path,
198 |                                 map_location=lambda storage, loc: storage)
199 |     return checkpoint
200 | 
201 | 
202 | def load_checkpoint(path, model, optimizer, reset_optimizer):
203 |     global global_step
204 |     global global_epoch
205 |     global global_test_step
206 | 
207 |     print("Load checkpoint from: {}".format(path))
208 |     checkpoint = _load(path)
209 |     model.load_state_dict(checkpoint["state_dict"], strict=False)
210 |     if not reset_optimizer:
211 |         optimizer_state = checkpoint["optimizer"]
212 |         if optimizer_state is not None:
213 |             print("Load optimizer state from {}".format(path))
214 |             try:
215 |                 optimizer.load_state_dict(checkpoint["optimizer"])
216 |             except Exception as e:
217 |                 print(e)
218 |     global_step = checkpoint["global_step"]
219 |     global_epoch = checkpoint["global_epoch"]
220 |     global_test_step = checkpoint.get("global_test_step", 0)
221 | 
222 |     return model
223 | 
224 | 
225 | def test_save_checkpoint():
226 |     checkpoint_path = "checkpoints/"
227 |     device = torch.device("cuda" if use_cuda else "cpu")
228 |     model = build_model()
229 |     optimizer = optim.Adam(model.parameters(), lr=1e-4)
230 |     global global_step, global_epoch, global_test_step
231 |     save_checkpoint(device, model, optimizer, global_step, checkpoint_path, global_epoch)
232 | 
233 |     model = load_checkpoint(checkpoint_path + "checkpoint_step000000000.pth", model, optimizer, False)
234 | 
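# For intuition, the cubic sparsity ramp that Pruner.update_sparsity applies
# (z grows from 0 at step t_0 to the target Z over S steps, following
# Zhu & Gupta, arXiv:1710.01878) can be checked in isolation. A minimal,
# self-contained sketch -- not part of train.py; the schedule constants here
# are illustrative, not the values from hparams.py:
def _sparsity_ramp_demo(t, t_0=20000, S=80000, Z=0.9):
    z = Z * (1 - (1 - (t - t_0) / S) ** 3)  # same formula as update_sparsity
    return clamp(z, 0, Z)                   # clamp() is defined above

# _sparsity_ramp_demo(20000) -> 0.0, _sparsity_ramp_demo(60000) -> 0.7875,
# _sparsity_ramp_demo(100000) and beyond -> 0.9 (the target sparsity).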
235 | 236 | def evaluate_model(model, data_loader, checkpoint_dir, limit_eval_to=5): 237 | """evaluate model and save generated wav and plot 238 | 239 | """ 240 | test_path = data_loader.dataset.test_path 241 | test_files = os.listdir(test_path) 242 | counter = 0 243 | output_dir = os.path.join(checkpoint_dir, 'eval') 244 | for f in test_files: 245 | if (f[-7:] == "mel.npy") or ('mel' in f): 246 | mel = np.load(os.path.join(test_path, f)) 247 | if mel.shape[-1]==hp.num_mels: #fix the order 248 | mel = mel.T 249 | wav = model.generate(mel, batched=True) 250 | # save wav 251 | wav_path = os.path.join(output_dir, "checkpoint_step{:09d}_wav_{}.wav".format(global_step, counter)) 252 | librosa.output.write_wav(wav_path, wav.astype('float32'), sr=hp.sample_rate) 253 | # save wav plot 254 | fig_path = os.path.join(output_dir, "checkpoint_step{:09d}_wav_{}.png".format(global_step, counter)) 255 | fig = plt.plot(wav.reshape(-1)) 256 | plt.savefig(fig_path) 257 | # clear fig to drawing to the same plot 258 | plt.clf() 259 | 260 | if counter == 0: 261 | wav = model.generate(mel, batched=False) 262 | # save wav 263 | wav_path = os.path.join(output_dir, 264 | "checkpoint_step{:09d}_wav_unbatched_{}.wav".format(global_step, counter)) 265 | librosa.output.write_wav(wav_path, wav.astype('float32'), sr=hp.sample_rate) 266 | # save wav plot 267 | fig_path = os.path.join(output_dir, 268 | "checkpoint_step{:09d}_wav_unbatched_{}.png".format(global_step, counter)) 269 | fig = plt.plot(wav.reshape(-1)) 270 | plt.savefig(fig_path) 271 | # clear fig to drawing to the same plot 272 | plt.clf() 273 | 274 | counter += 1 275 | 276 | # stop evaluation early via limit_eval_to 277 | if counter >= limit_eval_to: 278 | break 279 | 280 | 281 | def train_loop(device, model, data_loader, optimizer, checkpoint_dir): 282 | """Main training loop. 
283 | 284 | """ 285 | # create loss and put on device 286 | if hp.input_type == 'raw': 287 | if hp.distribution == 'beta': 288 | criterion = beta_mle_loss 289 | elif hp.distribution == 'gaussian': 290 | criterion = gaussian_loss 291 | elif hp.input_type == 'mixture': 292 | criterion = discretized_mix_logistic_loss 293 | elif hp.input_type in ["bits", "mulaw"]: 294 | criterion = nll_loss 295 | else: 296 | raise ValueError("input_type:{} not supported".format(hp.input_type)) 297 | 298 | # Pruner for reducing memory footprint 299 | layers = [(model.I,hp.sparsity_target), (model.rnn1,hp.sparsity_target_rnn), (model.fc1,hp.sparsity_target), (model.fc3,hp.sparsity_target)] #(model.fc2,hp.sparsity_target), 300 | pruner = Pruner(layers, hp.start_prune, hp.prune_steps, hp.sparsity_target) 301 | 302 | global global_step, global_epoch, global_test_step 303 | while global_epoch < hp.nepochs: 304 | running_loss = 0 305 | for i, (x, m, y) in enumerate(tqdm(data_loader)): 306 | x, m, y = x.to(device), m.to(device), y.to(device) 307 | y_hat = model(x, m) 308 | y = y.unsqueeze(-1) 309 | loss = criterion(y_hat, y) 310 | # calculate learning rate and update learning rate 311 | if hp.fix_learning_rate: 312 | current_lr = hp.fix_learning_rate 313 | elif hp.lr_schedule_type == 'step': 314 | current_lr = step_learning_rate_decay(hp.initial_learning_rate, global_step, hp.step_gamma, 315 | hp.lr_step_interval) 316 | else: 317 | current_lr = noam_learning_rate_decay(hp.initial_learning_rate, global_step, hp.noam_warm_up_steps) 318 | for param_group in optimizer.param_groups: 319 | param_group['lr'] = current_lr 320 | optimizer.zero_grad() 321 | loss.backward() 322 | # clip gradient norm 323 | grad_norm = nn.utils.clip_grad_norm_(model.parameters(), hp.grad_norm) 324 | optimizer.step() 325 | num_pruned, z = pruner.prune(global_step) 326 | 327 | running_loss += loss.item() 328 | avg_loss = running_loss / (i + 1) 329 | 330 | writer.add_scalar("loss", float(loss.item()), global_step) 331 | writer.add_scalar("avg_loss", float(avg_loss), global_step) 332 | writer.add_scalar("learning_rate", float(current_lr), global_step) 333 | writer.add_scalar("grad_norm", float(grad_norm), global_step) 334 | writer.add_scalar("num_pruned", float(num_pruned), global_step) 335 | writer.add_scalar("fraction_pruned", z, global_step) 336 | 337 | # saving checkpoint if needed 338 | if global_step != 0 and global_step % hp.save_every_step == 0: 339 | pruner.prune(global_step) 340 | save_checkpoint(device, model, optimizer, global_step, checkpoint_dir, global_epoch) 341 | # evaluate model if needed 342 | if global_step != 0 and global_test_step != True and global_step % hp.evaluate_every_step == 0: 343 | pruner.prune(global_step) 344 | print("step {}, evaluating model: generating wav from mel...".format(global_step)) 345 | evaluate_model(model, data_loader, checkpoint_dir) 346 | print("evaluation finished, resuming training...") 347 | 348 | # reset global_test_step status after evaluation 349 | if global_test_step is True: 350 | global_test_step = False 351 | global_step += 1 352 | 353 | print("epoch:{}, running loss:{}, average loss:{}, current lr:{}, num_pruned:{} ({}%)".format(global_epoch, running_loss, avg_loss, 354 | current_lr, num_pruned, z)) 355 | global_epoch += 1 356 | 357 | 358 | def test_prune(model): 359 | layers = [model.rnn1] #, model.rnn2] 360 | start_prune = 0 361 | prune_steps = 100 # 20000 362 | sparsity_target = 0.9375 363 | pruner = Pruner(layers, start_prune, prune_steps, sparsity_target) 364 | 365 | for i in 
range(100):
366 |         n_pruned = pruner.prune(100)
367 |         print(f'{i}: {n_pruned}')
368 | 
369 |     return layers
370 | 
371 | 
372 | datasetreader = {"Tacotron":TacotronDataset, "TTS":MozillaTTS, "Audiobooks":AudiobookDataset}
373 | if __name__ == "__main__":
374 |     args = docopt(__doc__)
375 |     # print("Command line args:\n", args)
376 |     checkpoint_dir = args["--checkpoint-dir"]
377 |     checkpoint_path = args["--checkpoint"]
378 |     data_root = args["<data-root>"]
379 |     log_event_path = args["--log-event-path"]
380 |     dataset_type = args["--dataset"]
381 |     # make dirs, load dataloader and set up device
382 |     os.makedirs(checkpoint_dir, exist_ok=True)
383 |     os.makedirs(os.path.join(checkpoint_dir, 'eval'), exist_ok=True)
384 |     #dataset = AudiobookDataset(data_root)
385 |     #dataset = TacotronDataset(data_root)
386 |     dataset = datasetreader[dataset_type]( data_root )
387 | 
388 |     if hp.input_type == 'raw':
389 |         collate_fn = raw_collate
390 |     elif hp.input_type == 'mixture':
391 |         collate_fn = raw_collate
392 |     elif hp.input_type in ['bits', 'mulaw']:
393 |         collate_fn = discrete_collate
394 |     else:
395 |         raise ValueError("input_type:{} not supported".format(hp.input_type))
396 |     data_loader = DataLoader(dataset, collate_fn=collate_fn, shuffle=True, num_workers=int(hp.num_workers),
397 |                              batch_size=hp.batch_size)
398 |     device = torch.device("cuda" if use_cuda else "cpu")
399 |     print("using device:{}".format(device))
400 | 
401 |     if log_event_path is None:
402 |         log_event_path = "log/log_" + datetime.now().strftime("%Y%m%d-%H%M%S")
403 |     else:
404 |         log_event_path += "/" + datetime.now().strftime("%Y%m%d-%H%M%S")
405 |     print("Tensorboard event path: {}".format(log_event_path))
406 |     writer = SummaryWriter(log_dir=log_event_path)
407 | 
408 |     # build model, create optimizer
409 |     model = build_model().to(device)
410 |     print("Parameter Count:")
411 |     print("I: %.3f million" % (num_params_count(model.I)))
412 |     print("Upsample: %.3f million" % (num_params_count(model.upsample)))
413 |     print("rnn1: %.3f million" % (num_params_count(model.rnn1)))
414 |     #print("rnn2: %.3f million" % (num_params_count(model.rnn2)))
415 |     print("fc1: %.3f million" % (num_params_count(model.fc1)))
416 |     #print("fc2: %.3f million" % (num_params_count(model.fc2)))
417 |     print("fc3: %.3f million" % (num_params_count(model.fc3)))
418 |     print(model)
419 | 
420 |     optimizer = optim.Adam(model.parameters(),
421 |                            lr=hp.initial_learning_rate, betas=(
422 |                                hp.adam_beta1, hp.adam_beta2),
423 |                            eps=hp.adam_eps, weight_decay=hp.weight_decay,
424 |                            amsgrad=hp.amsgrad)
425 | 
426 |     if hp.fix_learning_rate:
427 |         print("using fixed learning rate of :{}".format(hp.fix_learning_rate))
428 |     elif hp.lr_schedule_type == 'step':
429 |         print("using exponential learning rate decay")
430 |     elif hp.lr_schedule_type == 'noam':
431 |         print("using noam learning rate decay")
432 | 
433 |     # load checkpoint
434 |     if checkpoint_path is None:
435 |         print("no checkpoint specified as --checkpoint argument, creating new model...")
436 |     else:
437 |         model = load_checkpoint(checkpoint_path, model, optimizer, True)  # pass False to also restore optimizer state
438 |         print("loading model from checkpoint:{}".format(checkpoint_path))
439 |         # set global_test_step to True so we don't evaluate right when we load in the model
440 |         global_test_step = True
441 | 
442 |     # main train loop
443 |     try:
444 |         train_loop(device, model, data_loader, optimizer, checkpoint_dir)
445 |     except KeyboardInterrupt:
446 |         print("Interrupted!")
447 |         pass
448 |     except Exception as e:
449 |         print(e)
450 |     finally:
451 |         print("saving model....")
452 | 
save_checkpoint(device, model, optimizer, global_step, checkpoint_dir, global_epoch)
453 | 
454 | 
455 | def test_eval():
456 |     data_root = "data_dir"
457 |     #dataset = AudiobookDataset(data_root)
458 |     #dataset = TacotronDataset(data_root)
459 |     dataset = MozillaTTS( data_root )
460 |     if hp.input_type == 'raw':
461 |         collate_fn = raw_collate
462 |     elif hp.input_type == 'bits':
463 |         collate_fn = discrete_collate
464 |     else:
465 |         raise ValueError("input_type:{} not supported".format(hp.input_type))
466 |     data_loader = DataLoader(dataset, collate_fn=collate_fn, shuffle=True, num_workers=0, batch_size=hp.batch_size)
467 |     device = torch.device("cuda" if use_cuda else "cpu")
468 |     print("using device:{}".format(device))
469 | 
470 |     # build model, create optimizer
471 |     model = build_model().to(device)
472 | 
473 |     evaluate_model(model, data_loader, "checkpoints")  # evaluate_model also needs a checkpoint_dir
474 | 
-------------------------------------------------------------------------------- /utils.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | def num_params_count(model):
5 |     parameters = filter(lambda p: p.requires_grad, model.parameters())
6 |     parameters = sum([np.prod(p.size()) for p in parameters]) / 1000000
7 |     return parameters
8 | 
9 | def num_params(model):
10 |     print('Trainable Parameters: %.3f million' % num_params_count(model))
11 | 
12 | 
13 | # for mulaw encoding and decoding in torch tensors, modified from: https://github.com/pytorch/audio/blob/master/torchaudio/transforms.py
14 | def mulaw_quantize(x, quantization_channels=256):
15 |     """Encode signal based on mu-law companding.  For more info see the
16 |     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
17 | 
18 |     This algorithm assumes the signal has been scaled to between -1 and 1 and
19 |     returns a signal encoded with values from 0 to quantization_channels - 1
20 | 
21 |     Args:
22 |         quantization_channels (int): Number of channels. default: 256
23 | 
24 |     """
25 |     mu = quantization_channels - 1
26 |     if isinstance(x, np.ndarray):
27 |         x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
28 |         x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int)
29 |     elif isinstance(x, (torch.Tensor, torch.LongTensor)):
30 | 
31 |         if isinstance(x, torch.LongTensor):
32 |             x = x.float()
33 |         mu = torch.FloatTensor([mu])
34 |         x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
35 |         x_mu = ((x_mu + 1) / 2 * mu + 0.5).long()
36 |     return x_mu
37 | 
38 | 
39 | def inv_mulaw_quantize(x_mu, quantization_channels=256, cuda=False):
40 |     """Decode mu-law encoded signal.  For more info see the
41 |     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
42 | 
43 |     This expects an input with values between 0 and quantization_channels - 1
44 |     and returns a signal scaled between -1 and 1.
45 | 
46 |     Args:
47 |         quantization_channels (int): Number of channels. default: 256
48 | 
49 |     """
50 |     mu = quantization_channels - 1.
51 |     if isinstance(x_mu, np.ndarray):
52 |         x = ((x_mu) / mu) * 2 - 1.
53 |         x = np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu
54 |     elif isinstance(x_mu, (torch.Tensor, torch.LongTensor)):
55 |         if isinstance(x_mu, (torch.LongTensor, torch.cuda.LongTensor)):
56 |             x_mu = x_mu.float()
57 |         if cuda:
58 |             mu = (torch.FloatTensor([mu])).cuda()
59 |         else:
60 |             mu = torch.FloatTensor([mu])
61 |         x = ((x_mu) / mu) * 2 - 1.
62 |         x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.)
/ mu 63 | return x 64 | 65 | 66 | def test_inv_mulaw(): 67 | wav = torch.rand(5, 5000) 68 | wav = wav.cuda() 69 | de_quant = inv_mulaw_quantize(wav, 512, True) --------------------------------------------------------------------------------
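Putting the pieces together: the following minimal sketch vocodes one mel spectrogram with the compiled C++ library. It uses only calls that appear in test_wavernnvocoder.py above (Vocoder(), loadWeights(), melToWav()); the input file is the pretrained sample shipped in model_outputs, the 16 kHz output rate follows test_wavernnvocoder.py, and the WaveRNNVocoder module is assumed to have been built from library/src first.

import numpy as np
from scipy.io.wavfile import write

import WaveRNNVocoder  # pybind11 module built from library/src

vocoder = WaveRNNVocoder.Vocoder()
vocoder.loadWeights('model_outputs/model.bin')  # pretrained checkpoint from this repo

mel = np.load('model_outputs/mel-northandsouth_52_f000076.npy').astype('float32')
if mel.shape[0] != 80:  # melToWav expects [num_mels x frames]; transpose if stored the other way
    mel = mel.T
wav = vocoder.melToWav(mel)
write('vocoded.wav', 16000, wav)  # sample rate as used in test_wavernnvocoder.py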