├── .github └── stale.yml ├── .gitignore ├── .travis.yml ├── LICENSE.md ├── MANIFEST.in ├── README.md ├── appveyor.yml ├── audio.py ├── compute-meanvar-stats.py ├── datasets └── wavallin.py ├── docs ├── .gitignore ├── config.toml ├── content │ └── index.md ├── layouts │ ├── _default │ │ ├── list.html │ │ └── single.html │ ├── index.html │ └── partials │ │ ├── footer.html │ │ ├── header.html │ │ ├── mathjax.html │ │ └── social.html └── static │ ├── css │ ├── custom.css │ ├── normalize.css │ └── skeleton.css │ ├── favicon.png │ └── images │ └── r9y9.jpg ├── egs ├── README.md ├── gaussian │ ├── conf │ │ ├── gaussian_wavenet.json │ │ └── gaussian_wavenet_demo.json │ └── run.sh ├── mol │ ├── conf │ │ ├── mol_wavenet.json │ │ └── mol_wavenet_demo.json │ └── run.sh └── mulaw256 │ ├── conf │ ├── mulaw256_wavenet.json │ └── mulaw256_wavenet_demo.json │ └── run.sh ├── evaluate.py ├── hparams.py ├── lrschedule.py ├── mksubset.py ├── preprocess.py ├── preprocess_normalize.py ├── release.sh ├── setup.py ├── synthesis.py ├── tests ├── test_audio.py ├── test_misc.py ├── test_mixture.py └── test_model.py ├── tojson.py ├── tox.ini ├── train.py ├── utils └── parse_options.sh └── wavenet_vocoder ├── __init__.py ├── conv.py ├── mixture.py ├── modules.py ├── tfcompat ├── __init__.py ├── hparam.py └── readme.md ├── upsample.py ├── util.py ├── version.py └── wavenet.py /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an Issue or Pull Request becomes stale 2 | daysUntilStale: 60 3 | 4 | # Number of days of inactivity before an Issue or Pull Request with the stale label is closed. 5 | # Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale. 6 | daysUntilClose: 7 7 | 8 | # Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled) 9 | onlyLabels: [] 10 | 11 | # Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable 12 | exemptLabels: 13 | - roadmap 14 | - bug 15 | - design 16 | - help wanted 17 | - doc 18 | 19 | # Set to true to ignore issues in a project (defaults to false) 20 | exemptProjects: true 21 | 22 | # Set to true to ignore issues in a milestone (defaults to false) 23 | exemptMilestones: true 24 | 25 | # Label to use when marking as stale 26 | staleLabel: wontfix 27 | 28 | # Comment to post when marking as stale. Set to `false` to disable 29 | markComment: > 30 | This issue has been automatically marked as stale because it has not had 31 | recent activity. It will be closed if no further activity occurs. Thank you 32 | for your contributions. 33 | 34 | # Limit the number of actions per hour, from 1-30. 
Default is 30 35 | limitPerRun: 30 36 | 37 | # Limit to only `issues` or `pulls` 38 | only: issues 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | foobar* 2 | pretrained_models 3 | notebooks 4 | checkpoints* 5 | log 6 | generated 7 | data 8 | text 9 | 10 | # recipe related (adapted from espnet) 11 | egs/*/data* 12 | egs/*/db 13 | egs/*/downloads 14 | egs/*/dump 15 | egs/*/enhan 16 | egs/*/exp 17 | egs/*/fbank 18 | egs/*/mfcc 19 | egs/*/stft 20 | egs/*/tensorboard 21 | egs/*/wav* 22 | 23 | # Created by https://www.gitignore.io 24 | 25 | ### Python ### 26 | # Byte-compiled / optimized / DLL files 27 | __pycache__/ 28 | *.py[cod] 29 | 30 | # C extensions 31 | *.so 32 | 33 | # Distribution / packaging 34 | .Python 35 | env/ 36 | build/ 37 | develop-eggs/ 38 | dist/ 39 | downloads/ 40 | eggs/ 41 | lib/ 42 | lib64/ 43 | parts/ 44 | sdist/ 45 | var/ 46 | *.egg-info/ 47 | .installed.cfg 48 | *.egg 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .coverage 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | 82 | ### IPythonNotebook ### 83 | # Temporary data 84 | .ipynb_checkpoints/ 85 | 86 | 87 | ### SublimeText ### 88 | # cache files for sublime text 89 | *.tmlanguage.cache 90 | *.tmPreferences.cache 91 | *.stTheme.cache 92 | 93 | # workspace files are user-specific 94 | *.sublime-workspace 95 | 96 | # project files should be checked into the repository, unless a significant 97 | # proportion of contributors will probably not be using SublimeText 98 | # *.sublime-project 99 | 100 | # sftp configuration file 101 | sftp-config.json 102 | 103 | 104 | ### Emacs ### 105 | # -*- mode: gitignore; -*- 106 | *~ 107 | \#*\# 108 | /.emacs.desktop 109 | /.emacs.desktop.lock 110 | *.elc 111 | auto-save-list 112 | tramp 113 | .\#* 114 | 115 | # Org-mode 116 | .org-id-locations 117 | *_archive 118 | 119 | # flymake-mode 120 | *_flymake.* 121 | 122 | # eshell files 123 | /eshell/history 124 | /eshell/lastdir 125 | 126 | # elpa packages 127 | /elpa/ 128 | 129 | # reftex files 130 | *.rel 131 | 132 | # AUCTeX auto folder 133 | /auto/ 134 | 135 | # cask packages 136 | .cask/ 137 | 138 | 139 | ### Vim ### 140 | [._]*.s[a-w][a-z] 141 | [._]s[a-w][a-z] 142 | *.un~ 143 | Session.vim 144 | .netrwhist 145 | *~ 146 | 147 | 148 | ### C++ ### 149 | # Compiled Object files 150 | *.slo 151 | *.lo 152 | *.o 153 | *.obj 154 | 155 | # Precompiled Headers 156 | *.gch 157 | *.pch 158 | 159 | # Compiled Dynamic libraries 160 | *.so 161 | *.dylib 162 | *.dll 163 | 164 | # Fortran module files 165 | *.mod 166 | 167 | # Compiled Static libraries 168 | *.lai 169 | *.la 170 | *.a 171 | *.lib 172 | 173 | # Executables 174 | *.exe 175 | *.out 176 | *.app 177 | 178 | 179 | ### OSX ### 180 | .DS_Store 181 | .AppleDouble 182 | .LSOverride 183 | 184 | # Icon must end with two \r 185 | Icon 186 | 187 | 188 | # Thumbnails 189 | ._* 190 | 191 | # Files that might appear on external disk 192 | .Spotlight-V100 193 | .Trashes 194 | 195 | 
# Directories potentially created on remote AFP share 196 | .AppleDB 197 | .AppleDesktop 198 | Network Trash Folder 199 | Temporary Items 200 | .apdisk 201 | 202 | 203 | ### Linux ### 204 | *~ 205 | 206 | # KDE directory preferences 207 | .directory 208 | 209 | # Linux trash folder which might appear on any partition or disk 210 | .Trash-* 211 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | dist: xenial 3 | language: python 4 | 5 | python: 6 | - "3.6" 7 | 8 | notifications: 9 | email: false 10 | 11 | before_install: 12 | - sudo apt-get update 13 | - sudo apt-get install libsndfile-dev sox libav-tools 14 | - if [["$TRAVIS_PYTHON_VERSION" == "2.7"]]; then 15 | wget http://repo.continuum.io/miniconda/Miniconda-3.8.3-Linux-x86_64.sh -O miniconda.sh; 16 | else 17 | wget http://repo.continuum.io/miniconda/Miniconda3-3.8.3-Linux-x86_64.sh -O miniconda.sh; 18 | fi 19 | - bash miniconda.sh -b -p $HOME/miniconda 20 | - export PATH="$HOME/miniconda/bin:$PATH" 21 | - hash -r 22 | - conda config --set always_yes yes --set changeps1 no 23 | - conda update -q conda 24 | # Useful for debugging any issues with conda 25 | - conda config --add channels pypi 26 | - conda info -a 27 | - deps='pip numpy scipy cython nose pytorch' 28 | - conda create -q -n test-environment "python=$TRAVIS_PYTHON_VERSION" $deps -c pytorch 29 | - source activate test-environment 30 | 31 | install: 32 | - pip install -e ".[test]" 33 | script: 34 | - nosetests -v -w tests/ -a '!local_only' 35 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The wavenet_vocoder package is licensed under the MIT "Expat" License: 2 | 3 | > Copyright (c) 2017: Ryuichi Yamamoto. 4 | > 5 | > Permission is hereby granted, free of charge, to any person obtaining 6 | > a copy of this software and associated documentation files (the 7 | > "Software"), to deal in the Software without restriction, including 8 | > without limitation the rights to use, copy, modify, merge, publish, 9 | > distribute, sublicense, and/or sell copies of the Software, and to 10 | > permit persons to whom the Software is furnished to do so, subject to 11 | > the following conditions: 12 | > 13 | > The above copyright notice and this permission notice shall be 14 | > included in all copies or substantial portions of the Software. 15 | > 16 | > THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | > EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | > MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | > IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | > CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 21 | > TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 22 | > SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
23 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md LICENSE.md 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WaveNet vocoder 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/wavenet_vocoder.svg)](https://pypi.python.org/pypi/wavenet_vocoder) 4 | [![Build Status](https://travis-ci.org/r9y9/wavenet_vocoder.svg?branch=master)](https://travis-ci.org/r9y9/wavenet_vocoder) 5 | [![Build status](https://ci.appveyor.com/api/projects/status/lvt9jtimtg0koxwj?svg=true)](https://ci.appveyor.com/project/r9y9/wavenet-vocoder) 6 | [![DOI](https://zenodo.org/badge/115492234.svg)](https://zenodo.org/badge/latestdoi/115492234) 7 | 8 | **NOTE**: This is the development version. If you need a stable version, please checkout the v0.1.1. 9 | 10 | The goal of the repository is to provide an implementation of the WaveNet vocoder, which can generate high quality raw speech samples conditioned on linguistic or acoustic features. 11 | 12 | Audio samples are available at https://r9y9.github.io/wavenet_vocoder/. 13 | 14 | ## News 15 | 16 | - 2019/10/31: The repository has been adapted to [ESPnet](https://github.com/espnet/espnet). English, Chinese, and Japanese samples and pretrained models are available there. See https://github.com/espnet/espnet and https://github.com/espnet/espnet#tts-results for details. 17 | 18 | ## Online TTS demo 19 | 20 | A notebook supposed to be executed on https://colab.research.google.com is available: 21 | 22 | - [Tacotron2: WaveNet-based text-to-speech demo](https://colab.research.google.com/github/r9y9/Colaboratory/blob/master/Tacotron2_and_WaveNet_text_to_speech_demo.ipynb) 23 | 24 | ## Highlights 25 | 26 | - Focus on local and global conditioning of WaveNet, which is essential for vocoder. 27 | - 16-bit raw audio modeling by mixture distributions: mixture of logistics (MoL), mixture of Gaussians, and single Gaussian distributions are supported. 28 | - Various audio samples and pre-trained models 29 | - Fast inference by caching intermediate states in convolutions. Similar to [arXiv:1611.09482](https://arxiv.org/abs/1611.09482) 30 | - Integration with ESPNet (https://github.com/espnet/espnet) 31 | 32 | ## Pre-trained models 33 | 34 | **Note**: This is not itself a text-to-speech (TTS) model. With a pre-trained model provided here, you can synthesize waveform given a *mel spectrogram*, not raw text. You will need mel-spectrogram prediction model (such as Tacotron2) to use the pre-trained models for TTS. 35 | 36 | **Note**: As for the pretrained model for LJSpeech, the model was fine-tuned multiple times and trained for more than 1000k steps in total. Please refer to the issues ([#1](https://github.com/r9y9/wavenet_vocoder/issues/1#issuecomment-361130247), [#75](https://github.com/r9y9/wavenet_vocoder/issues/75), [#45](https://github.com/r9y9/wavenet_vocoder/issues/45#issuecomment-383313651)) to know how the model was trained. 
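Concretely, the pre-trained checkpoints expect an 80-band mel-spectrogram as the local conditioning input (`cin_channels=80` in the presets), stored as a `(frames, num_mels)` float array the way `preprocess.py` writes it. As a quick sanity check before running synthesis, you can verify a conditioning feature file has that layout; the snippet below is a minimal illustrative sketch (the path is just an example), not part of the repository:

```python
import numpy as np

# Example conditioning feature; any (T, 80) float array produced by
# preprocess.py or a Tacotron2-style model follows the same layout.
mel = np.load("./data/ljspeech/ljspeech-mel-00001.npy")

# Time frames along axis 0, mel bands along axis 1.
# The number of mel bands must match cin_channels in the preset json (80 here).
assert mel.ndim == 2 and mel.shape[1] == 80, mel.shape
print("frames: {}, mel bands: {}".format(mel.shape[0], mel.shape[1]))
```
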
37 | 38 | | Model URL | Data | Hyper params URL | Git commit | Steps | 39 | |----------------------------------------------------------------------------------------------------------------------------------|------------|------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------|---------------| 40 | | [link](https://www.dropbox.com/s/zdbfprugbagfp2w/20180510_mixture_lj_checkpoint_step000320000_ema.pth?dl=0) | LJSpeech | [link](https://www.dropbox.com/s/0vsd7973w20eskz/20180510_mixture_lj_checkpoint_step000320000_ema.json?dl=0) | [2092a64](https://github.com/r9y9/wavenet_vocoder/commit/2092a647e60ce002389818de1fa66d0a2c5763d8) | 1000k~ steps | 41 | | [link](https://www.dropbox.com/s/d0qk4ow9uuh2lww/20180212_mixture_multispeaker_cmu_arctic_checkpoint_step000740000_ema.pth?dl=0) | CMU ARCTIC | [link](https://www.dropbox.com/s/i35yigj5hvmeol8/20180212_multispeaker_cmu_arctic_mixture.json?dl=0) | [b1a1076](https://github.com/r9y9/wavenet_vocoder/tree/b1a1076e8b5d9b3e275c28f2f7f4d7cd0e75dae4) | 740k steps | 42 | 43 | To use pre-trained models, first checkout the specific git commit noted above. i.e., 44 | 45 | ``` 46 | git checkout ${commit_hash} 47 | ``` 48 | 49 | And then follows "Synthesize from a checkpoint" section in the README. Note that old version of synthesis.py may not accept `--preset=` parameter and you might have to change `hparams.py` according to the preset (json) file. 50 | 51 | You could try for example: 52 | 53 | ``` 54 | # Assuming you have downloaded LJSpeech-1.1 at ~/data/LJSpeech-1.1 55 | # pretrained model (20180510_mixture_lj_checkpoint_step000320000_ema.pth) 56 | # hparams (20180510_mixture_lj_checkpoint_step000320000_ema.json) 57 | git checkout 2092a64 58 | python preprocess.py ljspeech ~/data/LJSpeech-1.1 ./data/ljspeech \ 59 | --preset=20180510_mixture_lj_checkpoint_step000320000_ema.json 60 | python synthesis.py --preset=20180510_mixture_lj_checkpoint_step000320000_ema.json \ 61 | --conditional=./data/ljspeech/ljspeech-mel-00001.npy \ 62 | 20180510_mixture_lj_checkpoint_step000320000_ema.pth \ 63 | generated 64 | ``` 65 | 66 | You can find a generated wav file in `generated` directory. Wonder how it works? then take a look at code:) 67 | 68 | ## Repository structure 69 | 70 | The repository consists of 1) pytorch library, 2) command line tools, and 3) [ESPnet](https://github.com/espnet/espnet)-style recipes. The first one is a pytorch library to provide WavaNet functionality. The second one is a set of tools to run WaveNet training/inference, data processing, etc. The last one is the reproducible recipes combining the WaveNet library and utility tools. Please take a look at them depending on your purpose. If you want to build your WaveNet on your dataset (I guess this is the most likely case), the recipe is the way for you. 71 | 72 | ## Requirements 73 | 74 | - Python 3 75 | - CUDA >= 8.0 76 | - PyTorch >= v0.4.0 77 | 78 | ## Installation 79 | 80 | 81 | ``` 82 | git clone https://github.com/r9y9/wavenet_vocoder && cd wavenet_vocoder 83 | pip install -e . 84 | ``` 85 | 86 | If you only need the library part, you can install it from pypi: 87 | 88 | ``` 89 | pip install wavenet_vocoder 90 | ``` 91 | 92 | ## Getting started 93 | 94 | ### Kaldi-style recipes 95 | 96 | The repository provides Kaldi-style recipes to make experiments reproducible and easily manageable. 
Available recipes are as follows: 97 | 98 | - `mulaw256`: WaveNet that uses categorical output distribution. The input is 8-bit mulaw quantized waveform. 99 | - `mol`: Mixture of Logistics (MoL) WaveNet. The input is 16-bit raw audio. 100 | - `gaussian`: Single-Gaussian WaveNet (a.k.a. teacher WaveNet of [ClariNet](https://clarinet-demo.github.io/)). The input is 16-bit raw audio. 101 | 102 | All the recipe has `run.sh`, which specifies all the steps to perform WaveNet training/inference including data preprocessing. Please see run.sh in [egs](egs) directory for details. 103 | 104 | **NOTICE**: Global conditioning for multi-speaker WaveNet is not supported in the above recipes (it shouldn't be difficult to implement though). Please check v0.1.12 for the feature, or if you *really* need the feature, please raise an issue. 105 | 106 | #### Apply recipe to your own dataset 107 | 108 | The recipes are designed to be generic so that one can use them for any dataset. To apply recipes to your own dataset, you'd need to put *all* the wav files in a single flat directory. i.e., 109 | 110 | ``` 111 | > tree -L 1 ~/data/LJSpeech-1.1/wavs/ | head 112 | /Users/ryuichi/data/LJSpeech-1.1/wavs/ 113 | ├── LJ001-0001.wav 114 | ├── LJ001-0002.wav 115 | ├── LJ001-0003.wav 116 | ├── LJ001-0004.wav 117 | ├── LJ001-0005.wav 118 | ├── LJ001-0006.wav 119 | ├── LJ001-0007.wav 120 | ├── LJ001-0008.wav 121 | ├── LJ001-0009.wav 122 | ``` 123 | 124 | That's it! The last step is to modify `db_root` in run.sh or give `db_root` as the command line argment for run.sh. 125 | 126 | ``` 127 | ./run.sh --stage 0 --stop-stage 0 --db-root ~/data/LJSpeech-1.1/wavs/ 128 | ``` 129 | 130 | ### Step-by-step 131 | 132 | A recipe typically consists of multiple steps. It is strongly recommended to run the recipe step-by-step to understand how it works for the first time. To do so, specify `stage` and `stop_stage` as follows: 133 | 134 | ``` 135 | ./run.sh --stage 0 --stop-stage 0 136 | ``` 137 | 138 | ``` 139 | ./run.sh --stage 1 --stop-stage 1 140 | ``` 141 | 142 | ``` 143 | ./run.sh --stage 2 --stop-stage 2 144 | ``` 145 | 146 | In typical situations, you'd need to specify CUDA devices explciitly expecially for training step. 147 | 148 | ``` 149 | CUDA_VISIBLE_DEVICES="0,1" ./run.sh --stage 2 --stop-stage 2 150 | ``` 151 | 152 | ### Docs for command line tools 153 | 154 | Command line tools are writtern with [docopt](http://docopt.org/). See each docstring for the basic usages. 155 | 156 | #### tojson.py 157 | 158 | Dump hyperparameters to a json file. 159 | 160 | Usage: 161 | 162 | ``` 163 | python tojson.py --hparams="parameters you want to override" 164 | ``` 165 | 166 | #### preprocess.py 167 | 168 | Usage: 169 | 170 | ``` 171 | python preprocess.py wavallin ${dataset_path} ${out_dir} --preset= 172 | ``` 173 | 174 | #### train.py 175 | 176 | > Note: for multi gpu training, you have better ensure that batch_size % num_gpu == 0 177 | 178 | Usage: 179 | 180 | ``` 181 | python train.py --dump-root=${dump-root} --preset=\ 182 | --hparams="parameters you want to override" 183 | ``` 184 | 185 | 186 | #### evaluate.py 187 | 188 | Given a directoy that contains local conditioning features, synthesize waveforms for them. 189 | 190 | Usage: 191 | 192 | ``` 193 | python evaluate.py ${dump_root} ${checkpoint} ${output_dir} --dump-root="data location"\ 194 | --preset= --hparams="parameters you want to override" 195 | ``` 196 | 197 | Options: 198 | 199 | - `--num-utterances=`: Number of utterances to be generated. 
If not specified, generate all uttereances. This is useful for debugging. 200 | 201 | #### synthesis.py 202 | 203 | **NOTICE**: This is probably not working now. Please use evaluate.py instead. 204 | 205 | Synthesize waveform give a conditioning feature. 206 | 207 | Usage: 208 | 209 | ``` 210 | python synthesis.py ${checkpoint_path} ${output_dir} --preset= --hparams="parameters you want to override" 211 | ``` 212 | 213 | Important options: 214 | 215 | - `--conditional=`: (Required for conditional WaveNet) Path of local conditional features (.npy). If this is specified, number of time steps to generate is determined by the size of conditional feature. 216 | 217 | 218 | ### Training scenarios 219 | 220 | #### Training un-conditional WaveNet 221 | 222 | **NOTICE**: This is probably not working now. Please check v0.1.1 for the working version. 223 | 224 | ``` 225 | python train.py --dump-root=./data/cmu_arctic/ 226 | --hparams="cin_channels=-1,gin_channels=-1" 227 | ``` 228 | 229 | You have to disable global and local conditioning by setting `gin_channels` and `cin_channels` to negative values. 230 | 231 | #### Training WaveNet conditioned on mel-spectrogram 232 | 233 | ``` 234 | python train.py --dump-root=./data/cmu_arctic/ --speaker-id=0 \ 235 | --hparams="cin_channels=80,gin_channels=-1" 236 | ``` 237 | 238 | #### Training WaveNet conditioned on mel-spectrogram and speaker embedding 239 | 240 | **NOTICE**: This is probably not working now. Please check v0.1.1 for the working version. 241 | 242 | ``` 243 | python train.py --dump-root=./data/cmu_arctic/ \ 244 | --hparams="cin_channels=80,gin_channels=16,n_speakers=7" 245 | ``` 246 | 247 | ### Misc 248 | 249 | #### Monitor with Tensorboard 250 | 251 | Logs are dumped in `./log` directory by default. You can monitor logs by tensorboard: 252 | 253 | ``` 254 | tensorboard --logdir=log 255 | ``` 256 | 257 | 258 | ### List of papers that used the repository 259 | 260 | - A Comparison of Recent Neural Vocoders for Speech Signal Reconstruction https://www.isca-speech.org/archive/SSW_2019/abstracts/SSW10_O_1-2.html 261 | - WaveGlow: A Flow-based Generative Network for Speech Synthesis https://arxiv.org/abs/1811.00002 262 | - WaveCycleGAN2: Time-domain Neural Post-filter for Speech Waveform Generation https://arxiv.org/abs/1904.02892 263 | - Parametric Resynthesis with neural vocoders https://arxiv.org/abs/1906.06762 264 | - Representation Mixing fo TTS Synthesis https://arxiv.org/abs/1811.07240 265 | - A Unified Neural Architecture for Instrumental Audio Tasks https://arxiv.org/abs/1903.00142 266 | - ESPnet-TTS: Unified, Reproducible, and Integratable Open Source End-to-End Text-to-Speech Toolkit: https://arxiv.org/abs/1910.10909 267 | 268 | Thank you very much!! If you find a new one, please submit a PR. 269 | 270 | ## Sponsors 271 | 272 | - https://github.com/echelon 273 | 274 | ## References 275 | 276 | - [Aaron van den Oord, Sander Dieleman, Heiga Zen, et al, "WaveNet: A Generative Model for Raw Audio", arXiv:1609.03499, Sep 2016.](https://arxiv.org/abs/1609.03499) 277 | - [Aaron van den Oord, Yazhe Li, Igor Babuschkin, et al, "Parallel WaveNet: Fast High-Fidelity Speech Synthesis", arXiv:1711.10433, Nov 2017.](https://arxiv.org/abs/1711.10433) 278 | - [Tamamori, Akira, et al. "Speaker-dependent WaveNet vocoder." Proceedings of Interspeech. 2017.](http://www.isca-speech.org/archive/Interspeech_2017/pdfs/0314.PDF) 279 | - [Jonathan Shen, Ruoming Pang, Ron J. 
Weiss, et al, "Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions", arXiv:1712.05884, Dec 2017.](https://arxiv.org/abs/1712.05884) 280 | - [Wei Ping, Kainan Peng, Andrew Gibiansky, et al, "Deep Voice 3: 2000-Speaker Neural Text-to-Speech", arXiv:1710.07654, Oct. 2017.](https://arxiv.org/abs/1710.07654) 281 | - [Tom Le Paine, Pooya Khorrami, Shiyu Chang, et al, "Fast Wavenet Generation Algorithm", arXiv:1611.09482, Nov. 2016](https://arxiv.org/abs/1611.09482) 282 | - [Ye Jia, Yu Zhang, Ron J. Weiss, Quan Wang, Jonathan Shen, Fei Ren, Zhifeng Chen, Patrick Nguyen, Ruoming Pang, Ignacio Lopez Moreno, Yonghui Wu, et al, "Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis" , arXiv:1806.04558v4 cs.CL 2 Jan 2019](https://arxiv.org/abs/1806.04558) 283 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | environment: 2 | matrix: 3 | - PYTHON_VERSION: "3.6" 4 | PYTHON_ARCH: "64" 5 | MINICONDA: C:\Miniconda36-x64 6 | 7 | branches: 8 | only: 9 | - master 10 | - /release-.*/ 11 | 12 | skip_commits: 13 | message: /\[av skip\]/ 14 | 15 | notifications: 16 | - provider: Email 17 | on_build_success: false 18 | on_build_failure: false 19 | on_build_status_changed: false 20 | 21 | init: 22 | - "ECHO %PYTHON_VERSION% %PYTHON_ARCH% %MINICONDA%" 23 | 24 | install: 25 | - "SET PATH=%MINICONDA%;%MINICONDA%\\Scripts;%PATH%" 26 | - conda config --set always_yes yes --set changeps1 no 27 | - conda update -q conda 28 | - conda info -a 29 | - "conda create -q -n test-environment python=%PYTHON_VERSION% numpy scipy cython nose pytorch -c pytorch" 30 | - activate test-environment 31 | 32 | build_script: 33 | - pip install -e ".[test]" 34 | 35 | test_script: 36 | - nosetests -v -w tests/ -a "!local_only" 37 | -------------------------------------------------------------------------------- /audio.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import librosa.filters 3 | import numpy as np 4 | from hparams import hparams 5 | from scipy.io import wavfile 6 | from nnmnkwii import preprocessing as P 7 | 8 | 9 | def low_cut_filter(x, fs, cutoff=70): 10 | """APPLY LOW CUT FILTER. 11 | 12 | https://github.com/kan-bayashi/PytorchWaveNetVocoder 13 | 14 | Args: 15 | x (ndarray): Waveform sequence. 16 | fs (int): Sampling frequency. 17 | cutoff (float): Cutoff frequency of low cut filter. 18 | Return: 19 | ndarray: Low cut filtered waveform sequence. 
20 | """ 21 | nyquist = fs // 2 22 | norm_cutoff = cutoff / nyquist 23 | from scipy.signal import firwin, lfilter 24 | 25 | # low cut filter 26 | fil = firwin(255, norm_cutoff, pass_zero=False) 27 | lcf_x = lfilter(fil, 1, x) 28 | 29 | return lcf_x 30 | 31 | 32 | def load_wav(path): 33 | sr, x = wavfile.read(path) 34 | signed_int16_max = 2**15 35 | if x.dtype == np.int16: 36 | x = x.astype(np.float32) / signed_int16_max 37 | if sr != hparams.sample_rate: 38 | x = librosa.resample(x, sr, hparams.sample_rate) 39 | x = np.clip(x, -1.0, 1.0) 40 | return x 41 | 42 | 43 | def save_wav(wav, path): 44 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 45 | wavfile.write(path, hparams.sample_rate, wav.astype(np.int16)) 46 | 47 | 48 | def trim(quantized): 49 | start, end = start_and_end_indices(quantized, hparams.silence_threshold) 50 | return quantized[start:end] 51 | 52 | 53 | def preemphasis(x, coef=0.85): 54 | return P.preemphasis(x, coef) 55 | 56 | 57 | def inv_preemphasis(x, coef=0.85): 58 | return P.inv_preemphasis(x, coef) 59 | 60 | 61 | def adjust_time_resolution(quantized, mel): 62 | """Adjust time resolution by repeating features 63 | 64 | Args: 65 | quantized (ndarray): (T,) 66 | mel (ndarray): (N, D) 67 | 68 | Returns: 69 | tuple: Tuple of (T,) and (T, D) 70 | """ 71 | assert len(quantized.shape) == 1 72 | assert len(mel.shape) == 2 73 | 74 | upsample_factor = quantized.size // mel.shape[0] 75 | mel = np.repeat(mel, upsample_factor, axis=0) 76 | n_pad = quantized.size - mel.shape[0] 77 | if n_pad != 0: 78 | assert n_pad > 0 79 | mel = np.pad(mel, [(0, n_pad), (0, 0)], mode="constant", constant_values=0) 80 | 81 | # trim 82 | start, end = start_and_end_indices(quantized, hparams.silence_threshold) 83 | 84 | return quantized[start:end], mel[start:end, :] 85 | 86 | 87 | def start_and_end_indices(quantized, silence_threshold=2): 88 | for start in range(quantized.size): 89 | if abs(quantized[start] - 127) > silence_threshold: 90 | break 91 | for end in range(quantized.size - 1, 1, -1): 92 | if abs(quantized[end] - 127) > silence_threshold: 93 | break 94 | 95 | assert abs(quantized[start] - 127) > silence_threshold 96 | assert abs(quantized[end] - 127) > silence_threshold 97 | 98 | return start, end 99 | 100 | 101 | def logmelspectrogram(y, pad_mode="reflect"): 102 | """Same log-melspectrogram computation as espnet 103 | https://github.com/espnet/espnet 104 | from espnet.transform.spectrogram import logmelspectrogram 105 | """ 106 | D = _stft(y, pad_mode=pad_mode) 107 | S = _linear_to_mel(np.abs(D)) 108 | S = np.log10(np.maximum(S, 1e-10)) 109 | return S 110 | 111 | 112 | def get_hop_size(): 113 | hop_size = hparams.hop_size 114 | if hop_size is None: 115 | assert hparams.frame_shift_ms is not None 116 | hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate) 117 | return hop_size 118 | 119 | 120 | def get_win_length(): 121 | win_length = hparams.win_length 122 | if win_length < 0: 123 | assert hparams.win_length_ms > 0 124 | win_length = int(hparams.win_length_ms / 1000 * hparams.sample_rate) 125 | return win_length 126 | 127 | 128 | def _stft(y, pad_mode="constant"): 129 | # use constant padding (defaults to zeros) instead of reflection padding 130 | return librosa.stft(y=y, n_fft=hparams.fft_size, hop_length=get_hop_size(), 131 | win_length=get_win_length(), window=hparams.window, 132 | pad_mode=pad_mode) 133 | 134 | 135 | def pad_lr(x, fsize, fshift): 136 | return (0, fsize) 137 | 138 | # Conversions: 139 | 140 | 141 | _mel_basis = None 142 | 143 | 144 | def 
_linear_to_mel(spectrogram): 145 | global _mel_basis 146 | if _mel_basis is None: 147 | _mel_basis = _build_mel_basis() 148 | return np.dot(_mel_basis, spectrogram) 149 | 150 | 151 | def _build_mel_basis(): 152 | if hparams.fmax is not None: 153 | assert hparams.fmax <= hparams.sample_rate // 2 154 | return librosa.filters.mel(hparams.sample_rate, hparams.fft_size, 155 | fmin=hparams.fmin, fmax=hparams.fmax, 156 | n_mels=hparams.num_mels) 157 | 158 | 159 | def _amp_to_db(x): 160 | min_level = np.exp(hparams.min_level_db / 20 * np.log(10)) 161 | return 20 * np.log10(np.maximum(min_level, x)) 162 | 163 | 164 | def _db_to_amp(x): 165 | return np.power(10.0, x * 0.05) 166 | 167 | 168 | def _normalize(S): 169 | return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1) 170 | 171 | 172 | def _denormalize(S): 173 | return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db 174 | -------------------------------------------------------------------------------- /compute-meanvar-stats.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """Compute mean-variance normalization stats. 3 | 4 | usage: compute_meanvar_stats.py [options] 5 | 6 | options: 7 | -h, --help Show help message. 8 | --verbose= Verbosity [default: 0]. 9 | """ 10 | from docopt import docopt 11 | import sys 12 | from tqdm import tqdm 13 | import numpy as np 14 | import json 15 | 16 | from sklearn.preprocessing import StandardScaler 17 | import joblib 18 | 19 | if __name__ == "__main__": 20 | args = docopt(__doc__) 21 | list_file = args[""] 22 | out_path = args[""] 23 | verbose = int(args["--verbose"]) 24 | 25 | scaler = StandardScaler() 26 | with open(list_file) as f: 27 | lines = f.readlines() 28 | assert len(lines) > 0 29 | for path in tqdm(lines): 30 | c = np.load(path.strip()) 31 | scaler.partial_fit(c) 32 | joblib.dump(scaler, out_path) 33 | 34 | if verbose > 0: 35 | print("mean:\n{}".format(scaler.mean_)) 36 | print("var:\n{}".format(scaler.var_)) 37 | 38 | sys.exit(0) 39 | -------------------------------------------------------------------------------- /datasets/wavallin.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ProcessPoolExecutor 2 | from functools import partial 3 | import numpy as np 4 | import os 5 | import audio 6 | 7 | from nnmnkwii import preprocessing as P 8 | from hparams import hparams 9 | from os.path import exists, basename, splitext 10 | import librosa 11 | from glob import glob 12 | from os.path import join 13 | 14 | from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw, is_raw 15 | 16 | 17 | def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x): 18 | executor = ProcessPoolExecutor(max_workers=num_workers) 19 | futures = [] 20 | index = 1 21 | src_files = sorted(glob(join(in_dir, "*.wav"))) 22 | for wav_path in src_files: 23 | futures.append(executor.submit( 24 | partial(_process_utterance, out_dir, index, wav_path, "dummy"))) 25 | index += 1 26 | return [future.result() for future in tqdm(futures)] 27 | 28 | 29 | def _process_utterance(out_dir, index, wav_path, text): 30 | # Load the audio to a numpy array: 31 | wav = audio.load_wav(wav_path) 32 | 33 | # Trim begin/end silences 34 | # NOTE: the threshold was chosen for clean signals 35 | wav, _ = librosa.effects.trim(wav, top_db=60, frame_length=2048, hop_length=512) 36 | 37 | if hparams.highpass_cutoff > 0.0: 38 | wav = audio.low_cut_filter(wav, hparams.sample_rate, 
hparams.highpass_cutoff) 39 | 40 | # Mu-law quantize 41 | if is_mulaw_quantize(hparams.input_type): 42 | # Trim silences in mul-aw quantized domain 43 | silence_threshold = 0 44 | if silence_threshold > 0: 45 | # [0, quantize_channels) 46 | out = P.mulaw_quantize(wav, hparams.quantize_channels - 1) 47 | start, end = audio.start_and_end_indices(out, silence_threshold) 48 | wav = wav[start:end] 49 | constant_values = P.mulaw_quantize(0, hparams.quantize_channels - 1) 50 | out_dtype = np.int16 51 | elif is_mulaw(hparams.input_type): 52 | # [-1, 1] 53 | constant_values = P.mulaw(0.0, hparams.quantize_channels - 1) 54 | out_dtype = np.float32 55 | else: 56 | # [-1, 1] 57 | constant_values = 0.0 58 | out_dtype = np.float32 59 | 60 | # Compute a mel-scale spectrogram from the trimmed wav: 61 | # (N, D) 62 | mel_spectrogram = audio.logmelspectrogram(wav).astype(np.float32).T 63 | 64 | if hparams.global_gain_scale > 0: 65 | wav *= hparams.global_gain_scale 66 | 67 | # Time domain preprocessing 68 | if hparams.preprocess is not None and hparams.preprocess not in ["", "none"]: 69 | f = getattr(audio, hparams.preprocess) 70 | wav = f(wav) 71 | 72 | # Clip 73 | if np.abs(wav).max() > 1.0: 74 | print("""Warning: abs max value exceeds 1.0: {}""".format(np.abs(wav).max())) 75 | # ignore this sample 76 | return ("dummy", "dummy", -1, "dummy") 77 | 78 | wav = np.clip(wav, -1.0, 1.0) 79 | 80 | # Set waveform target (out) 81 | if is_mulaw_quantize(hparams.input_type): 82 | out = P.mulaw_quantize(wav, hparams.quantize_channels - 1) 83 | elif is_mulaw(hparams.input_type): 84 | out = P.mulaw(wav, hparams.quantize_channels - 1) 85 | else: 86 | out = wav 87 | 88 | # zero pad 89 | # this is needed to adjust time resolution between audio and mel-spectrogram 90 | l, r = audio.pad_lr(out, hparams.fft_size, audio.get_hop_size()) 91 | if l > 0 or r > 0: 92 | out = np.pad(out, (l, r), mode="constant", constant_values=constant_values) 93 | N = mel_spectrogram.shape[0] 94 | assert len(out) >= N * audio.get_hop_size() 95 | 96 | # time resolution adjustment 97 | # ensure length of raw audio is multiple of hop_size so that we can use 98 | # transposed convolution to upsample 99 | out = out[:N * audio.get_hop_size()] 100 | assert len(out) % audio.get_hop_size() == 0 101 | 102 | # Write the spectrograms to disk: 103 | name = splitext(basename(wav_path))[0] 104 | audio_filename = '%s-wave.npy' % (name) 105 | mel_filename = '%s-feats.npy' % (name) 106 | np.save(os.path.join(out_dir, audio_filename), 107 | out.astype(out_dtype), allow_pickle=False) 108 | np.save(os.path.join(out_dir, mel_filename), 109 | mel_spectrogram.astype(np.float32), allow_pickle=False) 110 | 111 | # Return a tuple describing this training example: 112 | return (audio_filename, mel_filename, N, text) 113 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | public 2 | static/audio -------------------------------------------------------------------------------- /docs/config.toml: -------------------------------------------------------------------------------- 1 | baseURL = "https://r9y9.github.io/wavenet_vocoder/" 2 | languageCode = "ja-jp" 3 | title = "An open source implementation of WaveNet vocoder" 4 | author = "Ryuichi YAMAMOTO" 5 | 6 | [params] 7 | author = "Ryuichi YAMAMOTO" 8 | project = "wavenet_vocoder" 9 | logo = "/images/r9y9.jpg" 10 | twitter = "r9y9" 11 | github = "r9y9" 12 | analytics = "UA-44433856-1" 13 | 
-------------------------------------------------------------------------------- /docs/content/index.md: -------------------------------------------------------------------------------- 1 | +++ 2 | Categories = [] 3 | Description = "" 4 | Keywords = [] 5 | Tags = [] 6 | date = "2018-01-04T19:42:01+09:00" 7 | title = "index" 8 | type = "index" 9 | +++ 10 | 11 |
12 | 13 | - Github: https://github.com/r9y9/wavenet_vocoder 14 | 15 | This page provides audio samples for the open source implementation of the **WaveNet (WN)** vocoder. 16 | Text-to-speech samples are found at the last section. 17 | 18 | - WN conditioned on mel-spectrogram (16-bit linear PCM, 22.5kHz) 19 | - WN conditioned on mel-spectrogram (8-bit mu-law, 16kHz) 20 | - WN conditioned on mel-spectrogram and speaker-embedding (16-bit linear PCM, 16kHz) 21 | - Tacotron2: WN-based text-to-speech (**New!**) 22 | 23 | ## WN conditioned on mel-spectrogram (16-bit linear PCM, 22.5kHz) 24 | 25 | - Samples from a model trained for over 400k steps. 26 | - Left: generated, Right: ground truth 27 | 28 | 32 | 36 | 37 | 41 | 45 | 46 | 50 | 54 | 55 | 59 | 63 | 64 | 68 | 72 | 73 | 77 | 81 | 82 | 86 | 90 | 91 | 95 | 99 | 100 | 104 | 108 | 109 | 113 | 117 | 118 | | key | value | 119 | |---------------------------------|------------------------------------------------------| 120 | | Data | LJSpeech (12522 for training, 578 for testing) | 121 | | Input type | 16-bit linear PCM | 122 | | Sampling frequency | 22.5kHz | 123 | | Local conditioning | 80-dim mel-spectrogram | 124 | | Hop size | 256 | 125 | | Global conditioning | N/A | 126 | | Total layers | 24 | 127 | | Num cycles | 4 | 128 | | Residual / Gate / Skip-out channels | 512 / 512 / 256 | 129 | | Receptive field (samples / ms) | 505 / 22.9 | 130 | | Numer of mixtures | 10 | 131 | | Number of upsampling layers | 4 | 132 | 133 | ## WN conditioned on mel-spectrogram (8-bit mu-law, 16kHz) 134 | 135 | - Samples from a model trained for 100k steps (~22 hours) 136 | - Left: generated, Right: (mu-law encoded) ground truth 137 | 138 | 142 | 146 | 147 | 151 | 155 | 156 | 160 | 164 | 165 | 169 | 173 | 174 | 178 | 182 | 183 | 187 | 191 | 192 | 196 | 200 | 201 | 205 | 209 | 210 | 214 | 218 | 219 | 223 | 227 | 228 | | key | value | 229 | |---------------------------------|------------------------------------------------------| 230 | | Data | CMU ARCTIC (`clb`) (1183 for training, 50 for testing) | 231 | | Input type | 8-bit mu-law encoded one-hot vector | 232 | | Sampling frequency | 16kHz | 233 | | Local conditioning | 80-dim mel-spectrogram | 234 | | Hop size | 256 | 235 | | Global conditioning | N/A | 236 | | Total layers | 16 | 237 | | Num cycles | 2 | 238 | | Residual / Gate / Skip-out channels | 512 / 512 / 256 | 239 | | Receptive field (samples / ms) | 1021 / 63.8 | 240 | | Number of upsampling layers | N/A | 241 | 242 | 243 | ## WN conditioned on mel-spectrogram and speaker-embedding (16-bit linear PCM, 16kHz) 244 | 245 | - Samples from a model trained for over 1000k steps 246 | - Left: generated, Right: ground truth 247 | 248 | **awb** 249 | 250 | 254 | 258 | 259 | 263 | 267 | 268 | **bdl** 269 | 270 | 274 | 278 | 279 | 283 | 287 | 288 | **clb** 289 | 290 | 294 | 298 | 299 | 303 | 307 | 308 | **jmk** 309 | 310 | 314 | 318 | 319 | 323 | 327 | 328 | 329 | **ksp** 330 | 331 | 335 | 339 | 340 | 344 | 348 | 349 | 350 | **rms** 351 | 352 | 356 | 360 | 361 | 365 | 369 | 370 | **slt** 371 | 372 | 376 | 380 | 381 | 385 | 389 | 390 | | key | value | 391 | |---------------------------------|------------------------------------------------------| 392 | | Data | CMU ARCTIC (7580 for training, 350 for testing) | 393 | | Input type | 16-bit linear PCM | 394 | | Local conditioning | 80-dim mel-spectrogram | 395 | | Hop size | 256 | 396 | | Global conditioning | 16-dim speaker embedding [^1] | 397 | | Total layers | 24 | 398 | | Num cycles | 4 | 399 | | Residual / Gate / 
Skip-out channels | 512 / 512 / 256 | 400 | | Receptive field (samples / ms) | 505 / 22.9 | 401 | | Numer of mixtures | 10 | 402 | | Number of upsampling layers | 4 | 403 | 404 | [^1]: Note that mel-spectrogram used in local conditioning is dependent on speaker characteristics, so we cannot simply change the speaker identity of the generated audio samples using the model. It should work without speaker embedding, but it might have helped training speed. 405 | 406 | ## Tacotron2: WN-based text-to-speech 407 | 408 | - Tacotron2 (mel-spectrogram prediction part): trained 189k steps on LJSpeech dataset ([Pre-trained model](https://www.dropbox.com/s/vx7y4qqs732sqgg/pretrained.tar.gz?dl=0), [Hyper params](https://github.com/r9y9/Tacotron-2/blob/9ce1a0e65b9217cdc19599c192c5cd68b4cece5b/hparams.py)). The work has been done by [@Rayhane-mamah](https://github.com/Rayhane-mamah). See https://github.com/Rayhane-mamah/Tacotron-2 for details. 409 | - WaveNet: trained over 1000k steps on LJSpeech dataset ([Pre-trained model](https://www.dropbox.com/s/zdbfprugbagfp2w/20180510_mixture_lj_checkpoint_step000320000_ema.pth?dl=0), [Hyper params](https://www.dropbox.com/s/0vsd7973w20eskz/20180510_mixture_lj_checkpoint_step000320000_ema.json?dl=0)) 410 | 411 | 412 | Scientists at the CERN laboratory say they have discovered a new particle. 413 | 414 | 418 | 419 | 420 | There's a way to measure the acute emotional intelligence that has never gone out of style. 421 | 422 | 426 | 427 | 428 | President Trump met with other leaders at the Group of 20 conference. 429 | 430 | 434 | 435 | 436 | The Senate's bill to repeal and replace the Affordable Care Act is now imperiled. 437 | 438 | 442 | 443 | 444 | Generative adversarial network or variational auto-encoder. 445 | 446 | 450 | 451 | 452 | Basilar membrane and otolaryngology are not auto-correlations. 453 | 454 | 458 | 459 | 460 | He has read the whole thing. 461 | 462 | 466 | 467 | 468 | He reads books. 469 | 470 | 474 | 475 | 476 | Don't desert me here in the desert! 477 | 478 | 482 | 483 | 484 | He thought it was time to present the present. 485 | 486 | 490 | 491 | Thisss isrealy awhsome. 492 | 493 | 497 | 498 | 499 | Punctuation sensitivity, is working. 500 | 501 | 505 | 506 | 507 | Punctuation sensitivity is working. 508 | 509 | 513 | 514 | 515 | The buses aren't the problem, they actually provide a solution. 516 | 517 | 521 | 522 | 523 | The buses aren't the PROBLEM, they actually provide a SOLUTION. 524 | 525 | 529 | 530 | 531 | The quick brown fox jumps over the lazy dog. 532 | 533 | 537 | 538 | Does the quick brown fox jump over the lazy dog? 539 | 540 | 544 | 545 | 546 | Peter Piper picked a peck of pickled peppers. How many pickled peppers did Peter Piper pick? 547 | 548 | 552 | 553 | 554 | She sells sea-shells on the sea-shore. The shells she sells are sea-shells I'm sure. 555 | 556 | 560 | 561 | 562 | The blue lagoon is a nineteen eighty American romance adventure film. 563 | 564 | 568 | 569 | 570 | ### On-line demo 571 | 572 | A demonstration notebook supposed to be run on Google colab can be found at [Tacotron2: WaveNet-basd text-to-speech demo](https://colab.research.google.com/github/r9y9/Colaboratory/blob/master/Tacotron2_and_WaveNet_text_to_speech_demo.ipynb). 
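### Note on the receptive-field figures

The "Receptive field (samples / ms)" rows in the tables above follow the usual dilated-convolution arithmetic: with dilations doubling within each cycle and a kernel size of 3 (the configuration implied by the figures above), the receptive field is `(kernel_size - 1) * sum(dilations) + 1` samples, and dividing by the sampling rate converts samples to milliseconds. The function below is an illustrative sketch of that calculation, not code from the repository:

```python
def receptive_field_size(total_layers, num_cycles, kernel_size=3):
    """Receptive field in samples for a dilated stack whose dilations
    double within each cycle (1, 2, 4, ...) and reset every cycle."""
    assert total_layers % num_cycles == 0
    layers_per_cycle = total_layers // num_cycles
    dilations = [2 ** (i % layers_per_cycle) for i in range(total_layers)]
    return (kernel_size - 1) * sum(dilations) + 1

# 24 layers / 4 cycles -> 505 samples (matches the mel-spectrogram tables)
print(receptive_field_size(24, 4))
# 16 layers / 2 cycles -> 1021 samples (matches the 8-bit mu-law table)
print(receptive_field_size(16, 2))
```
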
573 | 574 | 575 | ## References 576 | 577 | - [Aaron van den Oord, Sander Dieleman, Heiga Zen, et al, "WaveNet: A Generative Model for Raw Audio", arXiv:1609.03499, Sep 2016.](https://arxiv.org/abs/1609.03499) 578 | - [Aaron van den Oord, Yazhe Li, Igor Babuschkin, et al, "Parallel WaveNet: Fast High-Fidelity Speech Synthesis", arXiv:1711.10433, Nov 2017.](https://arxiv.org/abs/1711.10433) 579 | - [Tamamori, Akira, et al. "Speaker-dependent WaveNet vocoder." Proceedings of Interspeech. 2017.](http://www.isca-speech.org/archive/Interspeech_2017/pdfs/0314.PDF) 580 | - [Jonathan Shen, Ruoming Pang, Ron J. Weiss, et al, "Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions", arXiv:1712.05884, Dec 2017.](https://arxiv.org/abs/1712.05884) 581 | - [Wei Ping, Kainan Peng, Andrew Gibiansky, et al, "Deep Voice 3: 2000-Speaker Neural Text-to-Speech", arXiv:1710.07654, Oct. 2017.](https://arxiv.org/abs/1710.07654) 582 | - [Jonathan Shen, Ruoming Pang, Ron J. Weiss, et al, "Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions", arXiv:1712.05884, Dec 2017.](https://arxiv.org/abs/1712.05884) 583 | -------------------------------------------------------------------------------- /docs/layouts/_default/list.html: -------------------------------------------------------------------------------- 1 | {{ partial "header.html" . }} 2 | 3 |
4 |

{{ .Title }}

5 | {{ range .Data.Pages }} 6 | 10 | {{ end }} 11 |
12 | 13 | {{ partial "footer.html" . }} -------------------------------------------------------------------------------- /docs/layouts/_default/single.html: -------------------------------------------------------------------------------- 1 | {{ partial "header.html" . }} 2 | 3 |
4 |
5 |

{{ .Title }}

6 | 7 |
8 | {{ .Content }} 9 | {{ partial "social.html" . }} 10 |
11 |
12 |
13 | 14 | {{ partial "footer.html" . }} 15 | -------------------------------------------------------------------------------- /docs/layouts/index.html: -------------------------------------------------------------------------------- 1 | {{ template "partials/header.html" . }} 2 | {{ range .Data.Pages }} 3 | {{if eq .Type "index" }} 4 | {{.Content}} 5 | {{end}} 6 | {{ end }} 7 | {{ template "partials/footer.html" . }} 8 | -------------------------------------------------------------------------------- /docs/layouts/partials/footer.html: -------------------------------------------------------------------------------- 1 | 2 | 22 | 23 | 24 | 25 | {{ with .Site.Params.analytics }}{{ end }} 33 | 34 | 35 | 36 | 37 | {{ partial "mathjax.html" . }} 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /docs/layouts/partials/header.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | {{ .Hugo.Generator }} 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | {{ $isHomePage := eq .Title .Site.Title }}{{ .Title }}{{ if eq $isHomePage false }} - {{ .Site.Title }}{{ end }} 15 | 16 | 17 | 18 |
19 | 20 |
21 | 24 | {{ if eq $isHomePage true }}

{{ .Site.Title }}

{{ end }} 25 |
26 | -------------------------------------------------------------------------------- /docs/layouts/partials/mathjax.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 22 | 23 | 30 | 31 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /docs/layouts/partials/social.html: -------------------------------------------------------------------------------- 1 | {{ if isset .Site.Params "twitter" }} 2 | 8 | {{ end }} 9 | -------------------------------------------------------------------------------- /docs/static/css/custom.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: "Roboto", "HelveticaNeue", "Helvetica Neue", Helvetica, Arial, sans-serif; 3 | background-color: #FCFCFC; 4 | -webkit-font-smoothing: antialiased; 5 | font-size: 1.8em; 6 | line-height: 1.5; 7 | font-weight: 300; 8 | } 9 | 10 | h1, h2, h3, h4, h5, h6 { 11 | color: #263c4c; 12 | } 13 | h2, h3, h4, h5, h6 { 14 | margin-top: 5rem; 15 | margin-bottom: 3rem; 16 | font-weight: bold; 17 | padding-bottom: 10px; 18 | } 19 | 20 | h1 { font-size: 3.0rem; } 21 | h2 { 22 | margin-top: 6rem; 23 | font-size: 2.6rem; 24 | } 25 | h3 { font-size: 2.1rem; } 26 | h4, 27 | h5, 28 | h6 { font-size: 1.9rem; } 29 | 30 | h2.entry-title { 31 | font-size: 2.1rem; 32 | margin-top: 0; 33 | font-weight: 400; 34 | border-bottom: none; 35 | } 36 | 37 | li { 38 | margin-bottom: 0.5rem; 39 | margin-left: 0.7em; 40 | } 41 | 42 | img { 43 | max-width: 100%; 44 | height: auto; 45 | vertical-align: middle; 46 | border: 0; 47 | margin: 1em 0; 48 | } 49 | 50 | header, 51 | footer { 52 | margin: 4rem 0; 53 | text-align: center; 54 | } 55 | 56 | main { 57 | margin: 4rem 0; 58 | } 59 | 60 | .container { 61 | width: 90%; 62 | max-width: 700px; 63 | } 64 | 65 | .header-logo img { 66 | border-radius: 50%; 67 | border: 2px solid #E1E1E1; 68 | } 69 | 70 | .header-logo img:hover { 71 | border-color: #F1F1F1; 72 | } 73 | 74 | .site-title { 75 | margin-top: 2rem; 76 | } 77 | 78 | .entry-title { 79 | margin-bottom: 0; 80 | } 81 | 82 | .entry-title a { 83 | text-decoration: none; 84 | } 85 | 86 | .entry-meta { 87 | display: inline-block; 88 | margin-bottom: 2rem; 89 | font-size: 1.6rem; 90 | color: #888; 91 | } 92 | 93 | .footer-link { 94 | margin: 2rem 0; 95 | } 96 | 97 | .hr { 98 | height: 1px; 99 | margin: 2rem 0; 100 | background: #E1E1E1; 101 | background: -webkit-gradient(linear, left top, right top, from(white), color-stop(#E1E1E1), to(white)); 102 | background: -webkit-linear-gradient(left, white, #E1E1E1, white); 103 | background: linear-gradient(to right, white, #E1E1E1, white); 104 | } 105 | 106 | article .social { 107 | height: 40px; 108 | padding: 10px 0; 109 | } 110 | 111 | address { 112 | margin: 0; 113 | font-size:0.9em; 114 | max-height: 60px; 115 | font-weight: 300; 116 | font-style: normal; 117 | display: block; 118 | } 119 | 120 | address a { 121 | text-decoration: none; 122 | } 123 | 124 | .avatar-bottom img { 125 | border-radius: 50%; 126 | border: 1px solid #E1E1E1; 127 | float: left; 128 | max-width: 100%; 129 | vertical-align: middle; 130 | width: 32px; 131 | height: 32px; 132 | margin: 0 20px 0 0; 133 | margin-top: -7px; 134 | } 135 | 136 | .avatar-bottom img:hover { 137 | border-color: #F1F1F1; 138 | } 139 | 140 | .copyright { 141 | font-size:0.9em; 142 | font-weight: 300; 143 | } 144 | 145 | .github { 146 | float: right; 147 | } 148 | 149 | blockquote { 150 | position: relative; 151 | padding: 
10px 10px 10px 32px; 152 | box-sizing: border-box; 153 | font-style: italic; 154 | color: #464646; 155 | background: #e0e0e0; 156 | } 157 | 158 | blockquote:before{ 159 | display: inline-block; 160 | position: absolute; 161 | top: 0; 162 | left: 0; 163 | vertical-align: middle; 164 | content: "\f10d"; 165 | font-family: FontAwesome; 166 | color: #e0e0e0; 167 | font-size: 22px; 168 | line-height: 1; 169 | z-index: 2; 170 | } 171 | 172 | blockquote:after{ 173 | position: absolute; 174 | content: ''; 175 | left: 0; 176 | top: 0; 177 | border-width: 0 0 40px 40px; 178 | border-style: solid; 179 | border-color: transparent #ffffff; 180 | } 181 | 182 | blockquote p { 183 | position: relative; 184 | padding: 0; 185 | margin: 10px 0; 186 | z-index: 3; 187 | line-height: 1.7; 188 | } 189 | 190 | blockquote cite { 191 | display: block; 192 | text-align: right; 193 | color: #888888; 194 | font-size: 0.9em; 195 | } 196 | -------------------------------------------------------------------------------- /docs/static/css/normalize.css: -------------------------------------------------------------------------------- 1 | /*! normalize.css v3.0.2 | MIT License | git.io/normalize */ 2 | 3 | /** 4 | * 1. Set default font family to sans-serif. 5 | * 2. Prevent iOS text size adjust after orientation change, without disabling 6 | * user zoom. 7 | */ 8 | 9 | html { 10 | font-family: sans-serif; /* 1 */ 11 | -ms-text-size-adjust: 100%; /* 2 */ 12 | -webkit-text-size-adjust: 100%; /* 2 */ 13 | } 14 | 15 | /** 16 | * Remove default margin. 17 | */ 18 | 19 | body { 20 | margin: 0; 21 | } 22 | 23 | /* HTML5 display definitions 24 | ========================================================================== */ 25 | 26 | /** 27 | * Correct `block` display not defined for any HTML5 element in IE 8/9. 28 | * Correct `block` display not defined for `details` or `summary` in IE 10/11 29 | * and Firefox. 30 | * Correct `block` display not defined for `main` in IE 11. 31 | */ 32 | 33 | article, 34 | aside, 35 | details, 36 | figcaption, 37 | figure, 38 | footer, 39 | header, 40 | hgroup, 41 | main, 42 | menu, 43 | nav, 44 | section, 45 | summary { 46 | display: block; 47 | } 48 | 49 | /** 50 | * 1. Correct `inline-block` display not defined in IE 8/9. 51 | * 2. Normalize vertical alignment of `progress` in Chrome, Firefox, and Opera. 52 | */ 53 | 54 | audio, 55 | canvas, 56 | progress, 57 | video { 58 | display: inline-block; /* 1 */ 59 | vertical-align: baseline; /* 2 */ 60 | } 61 | 62 | /** 63 | * Prevent modern browsers from displaying `audio` without controls. 64 | * Remove excess height in iOS 5 devices. 65 | */ 66 | 67 | audio:not([controls]) { 68 | display: none; 69 | height: 0; 70 | } 71 | 72 | /** 73 | * Address `[hidden]` styling not present in IE 8/9/10. 74 | * Hide the `template` element in IE 8/9/11, Safari, and Firefox < 22. 75 | */ 76 | 77 | [hidden], 78 | template { 79 | display: none; 80 | } 81 | 82 | /* Links 83 | ========================================================================== */ 84 | 85 | /** 86 | * Remove the gray background color from active links in IE 10. 87 | */ 88 | 89 | a { 90 | background-color: transparent; 91 | } 92 | 93 | /** 94 | * Improve readability when focused and also mouse hovered in all browsers. 95 | */ 96 | 97 | a:active, 98 | a:hover { 99 | outline: 0; 100 | } 101 | 102 | /* Text-level semantics 103 | ========================================================================== */ 104 | 105 | /** 106 | * Address styling not present in IE 8/9/10/11, Safari, and Chrome. 
107 | */ 108 | 109 | abbr[title] { 110 | border-bottom: 1px dotted; 111 | } 112 | 113 | /** 114 | * Address style set to `bolder` in Firefox 4+, Safari, and Chrome. 115 | */ 116 | 117 | b, 118 | strong { 119 | font-weight: bold; 120 | } 121 | 122 | /** 123 | * Address styling not present in Safari and Chrome. 124 | */ 125 | 126 | dfn { 127 | font-style: italic; 128 | } 129 | 130 | /** 131 | * Address variable `h1` font-size and margin within `section` and `article` 132 | * contexts in Firefox 4+, Safari, and Chrome. 133 | */ 134 | 135 | h1 { 136 | font-size: 2em; 137 | margin: 0.67em 0; 138 | } 139 | 140 | /** 141 | * Address styling not present in IE 8/9. 142 | */ 143 | 144 | mark { 145 | background: #ff0; 146 | color: #000; 147 | } 148 | 149 | /** 150 | * Address inconsistent and variable font size in all browsers. 151 | */ 152 | 153 | small { 154 | font-size: 80%; 155 | } 156 | 157 | /** 158 | * Prevent `sub` and `sup` affecting `line-height` in all browsers. 159 | */ 160 | 161 | sub, 162 | sup { 163 | font-size: 75%; 164 | line-height: 0; 165 | position: relative; 166 | vertical-align: baseline; 167 | } 168 | 169 | sup { 170 | top: -0.5em; 171 | } 172 | 173 | sub { 174 | bottom: -0.25em; 175 | } 176 | 177 | /* Embedded content 178 | ========================================================================== */ 179 | 180 | /** 181 | * Remove border when inside `a` element in IE 8/9/10. 182 | */ 183 | 184 | img { 185 | border: 0; 186 | } 187 | 188 | /** 189 | * Correct overflow not hidden in IE 9/10/11. 190 | */ 191 | 192 | svg:not(:root) { 193 | overflow: hidden; 194 | } 195 | 196 | /* Grouping content 197 | ========================================================================== */ 198 | 199 | /** 200 | * Address margin not present in IE 8/9 and Safari. 201 | */ 202 | 203 | figure { 204 | margin: 1em 40px; 205 | } 206 | 207 | /** 208 | * Address differences between Firefox and other browsers. 209 | */ 210 | 211 | hr { 212 | -moz-box-sizing: content-box; 213 | box-sizing: content-box; 214 | height: 0; 215 | } 216 | 217 | /** 218 | * Contain overflow in all browsers. 219 | */ 220 | 221 | pre { 222 | overflow: auto; 223 | } 224 | 225 | /** 226 | * Address odd `em`-unit font size rendering in all browsers. 227 | */ 228 | 229 | code, 230 | kbd, 231 | pre, 232 | samp { 233 | font-family: monospace, monospace; 234 | font-size: 1em; 235 | } 236 | 237 | /* Forms 238 | ========================================================================== */ 239 | 240 | /** 241 | * Known limitation: by default, Chrome and Safari on OS X allow very limited 242 | * styling of `select`, unless a `border` property is set. 243 | */ 244 | 245 | /** 246 | * 1. Correct color not being inherited. 247 | * Known issue: affects color of disabled elements. 248 | * 2. Correct font properties not being inherited. 249 | * 3. Address margins set differently in Firefox 4+, Safari, and Chrome. 250 | */ 251 | 252 | button, 253 | input, 254 | optgroup, 255 | select, 256 | textarea { 257 | color: inherit; /* 1 */ 258 | font: inherit; /* 2 */ 259 | margin: 0; /* 3 */ 260 | } 261 | 262 | /** 263 | * Address `overflow` set to `hidden` in IE 8/9/10/11. 264 | */ 265 | 266 | button { 267 | overflow: visible; 268 | } 269 | 270 | /** 271 | * Address inconsistent `text-transform` inheritance for `button` and `select`. 272 | * All other form control elements do not inherit `text-transform` values. 273 | * Correct `button` style inheritance in Firefox, IE 8/9/10/11, and Opera. 274 | * Correct `select` style inheritance in Firefox. 
275 | */ 276 | 277 | button, 278 | select { 279 | text-transform: none; 280 | } 281 | 282 | /** 283 | * 1. Avoid the WebKit bug in Android 4.0.* where (2) destroys native `audio` 284 | * and `video` controls. 285 | * 2. Correct inability to style clickable `input` types in iOS. 286 | * 3. Improve usability and consistency of cursor style between image-type 287 | * `input` and others. 288 | */ 289 | 290 | button, 291 | html input[type="button"], /* 1 */ 292 | input[type="reset"], 293 | input[type="submit"] { 294 | -webkit-appearance: button; /* 2 */ 295 | cursor: pointer; /* 3 */ 296 | } 297 | 298 | /** 299 | * Re-set default cursor for disabled elements. 300 | */ 301 | 302 | button[disabled], 303 | html input[disabled] { 304 | cursor: default; 305 | } 306 | 307 | /** 308 | * Remove inner padding and border in Firefox 4+. 309 | */ 310 | 311 | button::-moz-focus-inner, 312 | input::-moz-focus-inner { 313 | border: 0; 314 | padding: 0; 315 | } 316 | 317 | /** 318 | * Address Firefox 4+ setting `line-height` on `input` using `!important` in 319 | * the UA stylesheet. 320 | */ 321 | 322 | input { 323 | line-height: normal; 324 | } 325 | 326 | /** 327 | * It's recommended that you don't attempt to style these elements. 328 | * Firefox's implementation doesn't respect box-sizing, padding, or width. 329 | * 330 | * 1. Address box sizing set to `content-box` in IE 8/9/10. 331 | * 2. Remove excess padding in IE 8/9/10. 332 | */ 333 | 334 | input[type="checkbox"], 335 | input[type="radio"] { 336 | box-sizing: border-box; /* 1 */ 337 | padding: 0; /* 2 */ 338 | } 339 | 340 | /** 341 | * Fix the cursor style for Chrome's increment/decrement buttons. For certain 342 | * `font-size` values of the `input`, it causes the cursor style of the 343 | * decrement button to change from `default` to `text`. 344 | */ 345 | 346 | input[type="number"]::-webkit-inner-spin-button, 347 | input[type="number"]::-webkit-outer-spin-button { 348 | height: auto; 349 | } 350 | 351 | /** 352 | * 1. Address `appearance` set to `searchfield` in Safari and Chrome. 353 | * 2. Address `box-sizing` set to `border-box` in Safari and Chrome 354 | * (include `-moz` to future-proof). 355 | */ 356 | 357 | input[type="search"] { 358 | -webkit-appearance: textfield; /* 1 */ 359 | -moz-box-sizing: content-box; 360 | -webkit-box-sizing: content-box; /* 2 */ 361 | box-sizing: content-box; 362 | } 363 | 364 | /** 365 | * Remove inner padding and search cancel button in Safari and Chrome on OS X. 366 | * Safari (but not Chrome) clips the cancel button when the search input has 367 | * padding (and `textfield` appearance). 368 | */ 369 | 370 | input[type="search"]::-webkit-search-cancel-button, 371 | input[type="search"]::-webkit-search-decoration { 372 | -webkit-appearance: none; 373 | } 374 | 375 | /** 376 | * Define consistent border, margin, and padding. 377 | */ 378 | 379 | fieldset { 380 | border: 1px solid #c0c0c0; 381 | margin: 0 2px; 382 | padding: 0.35em 0.625em 0.75em; 383 | } 384 | 385 | /** 386 | * 1. Correct `color` not being inherited in IE 8/9/10/11. 387 | * 2. Remove padding so people aren't caught out if they zero out fieldsets. 388 | */ 389 | 390 | legend { 391 | border: 0; /* 1 */ 392 | padding: 0; /* 2 */ 393 | } 394 | 395 | /** 396 | * Remove default vertical scrollbar in IE 8/9/10/11. 397 | */ 398 | 399 | textarea { 400 | overflow: auto; 401 | } 402 | 403 | /** 404 | * Don't inherit the `font-weight` (applied by a rule above). 405 | * NOTE: the default cannot safely be changed in Chrome and Safari on OS X. 
406 | */ 407 | 408 | optgroup { 409 | font-weight: bold; 410 | } 411 | 412 | /* Tables 413 | ========================================================================== */ 414 | 415 | /** 416 | * Remove most spacing between table cells. 417 | */ 418 | 419 | table { 420 | border-collapse: collapse; 421 | border-spacing: 0; 422 | } 423 | 424 | td, 425 | th { 426 | padding: 0; 427 | } -------------------------------------------------------------------------------- /docs/static/css/skeleton.css: -------------------------------------------------------------------------------- 1 | /* 2 | * Skeleton V2.0.4 3 | * Copyright 2014, Dave Gamache 4 | * www.getskeleton.com 5 | * Free to use under the MIT license. 6 | * http://www.opensource.org/licenses/mit-license.php 7 | * 12/29/2014 8 | */ 9 | 10 | 11 | /* Table of contents 12 | –––––––––––––––––––––––––––––––––––––––––––––––––– 13 | - Grid 14 | - Base Styles 15 | - Typography 16 | - Links 17 | - Buttons 18 | - Forms 19 | - Lists 20 | - Code 21 | - Tables 22 | - Spacing 23 | - Utilities 24 | - Clearing 25 | - Media Queries 26 | */ 27 | 28 | 29 | /* Grid 30 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 31 | .container { 32 | position: relative; 33 | width: 100%; 34 | max-width: 960px; 35 | margin: 0 auto; 36 | padding: 0 20px; 37 | box-sizing: border-box; } 38 | .column, 39 | .columns { 40 | width: 100%; 41 | float: left; 42 | box-sizing: border-box; } 43 | 44 | /* For devices larger than 400px */ 45 | @media (min-width: 400px) { 46 | .container { 47 | width: 85%; 48 | padding: 0; } 49 | } 50 | 51 | /* For devices larger than 550px */ 52 | @media (min-width: 550px) { 53 | .container { 54 | width: 80%; } 55 | .column, 56 | .columns { 57 | margin-left: 4%; } 58 | .column:first-child, 59 | .columns:first-child { 60 | margin-left: 0; } 61 | 62 | .one.column, 63 | .one.columns { width: 4.66666666667%; } 64 | .two.columns { width: 13.3333333333%; } 65 | .three.columns { width: 22%; } 66 | .four.columns { width: 30.6666666667%; } 67 | .five.columns { width: 39.3333333333%; } 68 | .six.columns { width: 48%; } 69 | .seven.columns { width: 56.6666666667%; } 70 | .eight.columns { width: 65.3333333333%; } 71 | .nine.columns { width: 74.0%; } 72 | .ten.columns { width: 82.6666666667%; } 73 | .eleven.columns { width: 91.3333333333%; } 74 | .twelve.columns { width: 100%; margin-left: 0; } 75 | 76 | .one-third.column { width: 30.6666666667%; } 77 | .two-thirds.column { width: 65.3333333333%; } 78 | 79 | .one-half.column { width: 48%; } 80 | 81 | /* Offsets */ 82 | .offset-by-one.column, 83 | .offset-by-one.columns { margin-left: 8.66666666667%; } 84 | .offset-by-two.column, 85 | .offset-by-two.columns { margin-left: 17.3333333333%; } 86 | .offset-by-three.column, 87 | .offset-by-three.columns { margin-left: 26%; } 88 | .offset-by-four.column, 89 | .offset-by-four.columns { margin-left: 34.6666666667%; } 90 | .offset-by-five.column, 91 | .offset-by-five.columns { margin-left: 43.3333333333%; } 92 | .offset-by-six.column, 93 | .offset-by-six.columns { margin-left: 52%; } 94 | .offset-by-seven.column, 95 | .offset-by-seven.columns { margin-left: 60.6666666667%; } 96 | .offset-by-eight.column, 97 | .offset-by-eight.columns { margin-left: 69.3333333333%; } 98 | .offset-by-nine.column, 99 | .offset-by-nine.columns { margin-left: 78.0%; } 100 | .offset-by-ten.column, 101 | .offset-by-ten.columns { margin-left: 86.6666666667%; } 102 | .offset-by-eleven.column, 103 | .offset-by-eleven.columns { margin-left: 95.3333333333%; } 104 | 105 | .offset-by-one-third.column, 
106 | .offset-by-one-third.columns { margin-left: 34.6666666667%; } 107 | .offset-by-two-thirds.column, 108 | .offset-by-two-thirds.columns { margin-left: 69.3333333333%; } 109 | 110 | .offset-by-one-half.column, 111 | .offset-by-one-half.columns { margin-left: 52%; } 112 | 113 | } 114 | 115 | 116 | /* Base Styles 117 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 118 | /* NOTE 119 | html is set to 62.5% so that all the REM measurements throughout Skeleton 120 | are based on 10px sizing. So basically 1.5rem = 15px :) */ 121 | html { 122 | font-size: 62.5%; } 123 | body { 124 | font-size: 1.5em; /* currently ems cause chrome bug misinterpreting rems on body element */ 125 | line-height: 1.6; 126 | font-weight: 400; 127 | font-family: "Raleway", "HelveticaNeue", "Helvetica Neue", Helvetica, Arial, sans-serif; 128 | color: #222; } 129 | 130 | 131 | /* Typography 132 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 133 | h1, h2, h3, h4, h5, h6 { 134 | margin-top: 0; 135 | margin-bottom: 2rem; 136 | font-weight: 300; } 137 | h1 { font-size: 4.0rem; line-height: 1.2; letter-spacing: -.1rem;} 138 | h2 { font-size: 3.6rem; line-height: 1.25; letter-spacing: -.1rem; } 139 | h3 { font-size: 3.0rem; line-height: 1.3; letter-spacing: -.1rem; } 140 | h4 { font-size: 2.4rem; line-height: 1.35; letter-spacing: -.08rem; } 141 | h5 { font-size: 1.8rem; line-height: 1.5; letter-spacing: -.05rem; } 142 | h6 { font-size: 1.5rem; line-height: 1.6; letter-spacing: 0; } 143 | 144 | /* Larger than phablet */ 145 | @media (min-width: 550px) { 146 | h1 { font-size: 5.0rem; } 147 | h2 { font-size: 4.2rem; } 148 | h3 { font-size: 3.6rem; } 149 | h4 { font-size: 3.0rem; } 150 | h5 { font-size: 2.4rem; } 151 | h6 { font-size: 1.5rem; } 152 | } 153 | 154 | p { 155 | margin-top: 0; } 156 | 157 | 158 | /* Links 159 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 160 | a { 161 | color: #1EAEDB; } 162 | a:hover { 163 | color: #0FA0CE; } 164 | 165 | 166 | /* Buttons 167 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 168 | .button, 169 | button, 170 | input[type="submit"], 171 | input[type="reset"], 172 | input[type="button"] { 173 | display: inline-block; 174 | height: 38px; 175 | padding: 0 30px; 176 | color: #555; 177 | text-align: center; 178 | font-size: 11px; 179 | font-weight: 600; 180 | line-height: 38px; 181 | letter-spacing: .1rem; 182 | text-transform: uppercase; 183 | text-decoration: none; 184 | white-space: nowrap; 185 | background-color: transparent; 186 | border-radius: 4px; 187 | border: 1px solid #bbb; 188 | cursor: pointer; 189 | box-sizing: border-box; } 190 | .button:hover, 191 | button:hover, 192 | input[type="submit"]:hover, 193 | input[type="reset"]:hover, 194 | input[type="button"]:hover, 195 | .button:focus, 196 | button:focus, 197 | input[type="submit"]:focus, 198 | input[type="reset"]:focus, 199 | input[type="button"]:focus { 200 | color: #333; 201 | border-color: #888; 202 | outline: 0; } 203 | .button.button-primary, 204 | button.button-primary, 205 | input[type="submit"].button-primary, 206 | input[type="reset"].button-primary, 207 | input[type="button"].button-primary { 208 | color: #FFF; 209 | background-color: #33C3F0; 210 | border-color: #33C3F0; } 211 | .button.button-primary:hover, 212 | button.button-primary:hover, 213 | input[type="submit"].button-primary:hover, 214 | input[type="reset"].button-primary:hover, 215 | input[type="button"].button-primary:hover, 216 | .button.button-primary:focus, 217 | button.button-primary:focus, 218 | 
input[type="submit"].button-primary:focus, 219 | input[type="reset"].button-primary:focus, 220 | input[type="button"].button-primary:focus { 221 | color: #FFF; 222 | background-color: #1EAEDB; 223 | border-color: #1EAEDB; } 224 | 225 | 226 | /* Forms 227 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 228 | input[type="email"], 229 | input[type="number"], 230 | input[type="search"], 231 | input[type="text"], 232 | input[type="tel"], 233 | input[type="url"], 234 | input[type="password"], 235 | textarea, 236 | select { 237 | height: 38px; 238 | padding: 6px 10px; /* The 6px vertically centers text on FF, ignored by Webkit */ 239 | background-color: #fff; 240 | border: 1px solid #D1D1D1; 241 | border-radius: 4px; 242 | box-shadow: none; 243 | box-sizing: border-box; } 244 | /* Removes awkward default styles on some inputs for iOS */ 245 | input[type="email"], 246 | input[type="number"], 247 | input[type="search"], 248 | input[type="text"], 249 | input[type="tel"], 250 | input[type="url"], 251 | input[type="password"], 252 | textarea { 253 | -webkit-appearance: none; 254 | -moz-appearance: none; 255 | appearance: none; } 256 | textarea { 257 | min-height: 65px; 258 | padding-top: 6px; 259 | padding-bottom: 6px; } 260 | input[type="email"]:focus, 261 | input[type="number"]:focus, 262 | input[type="search"]:focus, 263 | input[type="text"]:focus, 264 | input[type="tel"]:focus, 265 | input[type="url"]:focus, 266 | input[type="password"]:focus, 267 | textarea:focus, 268 | select:focus { 269 | border: 1px solid #33C3F0; 270 | outline: 0; } 271 | label, 272 | legend { 273 | display: block; 274 | margin-bottom: .5rem; 275 | font-weight: 600; } 276 | fieldset { 277 | padding: 0; 278 | border-width: 0; } 279 | input[type="checkbox"], 280 | input[type="radio"] { 281 | display: inline; } 282 | label > .label-body { 283 | display: inline-block; 284 | margin-left: .5rem; 285 | font-weight: normal; } 286 | 287 | 288 | /* Lists 289 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 290 | ul { 291 | list-style: circle inside; } 292 | ol { 293 | list-style: decimal inside; } 294 | ol, ul { 295 | padding-left: 0; 296 | margin-top: 0; } 297 | ul ul, 298 | ul ol, 299 | ol ol, 300 | ol ul { 301 | margin: 1.5rem 0 1.5rem 3rem; 302 | font-size: 90%; } 303 | li { 304 | margin-bottom: 1rem; } 305 | 306 | 307 | /* Code 308 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 309 | code { 310 | padding: .2rem .5rem; 311 | margin: 0 .2rem; 312 | font-size: 90%; 313 | white-space: nowrap; 314 | background: #F1F1F1; 315 | border: 1px solid #E1E1E1; 316 | border-radius: 4px; } 317 | pre > code { 318 | display: block; 319 | padding: 1rem 1.5rem; 320 | white-space: pre; } 321 | 322 | 323 | /* Tables 324 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 325 | th, 326 | td { 327 | padding: 12px 15px; 328 | text-align: left; 329 | border-bottom: 1px solid #E1E1E1; } 330 | th:first-child, 331 | td:first-child { 332 | padding-left: 0; } 333 | th:last-child, 334 | td:last-child { 335 | padding-right: 0; } 336 | 337 | 338 | /* Spacing 339 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 340 | button, 341 | .button { 342 | margin-bottom: 1rem; } 343 | input, 344 | textarea, 345 | select, 346 | fieldset { 347 | margin-bottom: 1.5rem; } 348 | pre, 349 | blockquote, 350 | dl, 351 | figure, 352 | table, 353 | p, 354 | ul, 355 | ol, 356 | form { 357 | margin-bottom: 2.5rem; } 358 | 359 | 360 | /* Utilities 361 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 362 | .u-full-width { 363 | width: 100%; 364 | 
box-sizing: border-box; } 365 | .u-max-full-width { 366 | max-width: 100%; 367 | box-sizing: border-box; } 368 | .u-pull-right { 369 | float: right; } 370 | .u-pull-left { 371 | float: left; } 372 | 373 | 374 | /* Misc 375 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 376 | hr { 377 | margin-top: 3rem; 378 | margin-bottom: 3.5rem; 379 | border-width: 0; 380 | border-top: 1px solid #E1E1E1; } 381 | 382 | 383 | /* Clearing 384 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 385 | 386 | /* Self Clearing Goodness */ 387 | .container:after, 388 | .row:after, 389 | .u-cf { 390 | content: ""; 391 | display: table; 392 | clear: both; } 393 | 394 | 395 | /* Media Queries 396 | –––––––––––––––––––––––––––––––––––––––––––––––––– */ 397 | /* 398 | Note: The best way to structure the use of media queries is to create the queries 399 | near the relevant code. For example, if you wanted to change the styles for buttons 400 | on small devices, paste the mobile query code up in the buttons section and style it 401 | there. 402 | */ 403 | 404 | 405 | /* Larger than mobile */ 406 | @media (min-width: 400px) {} 407 | 408 | /* Larger than phablet (also point when grid becomes active) */ 409 | @media (min-width: 550px) {} 410 | 411 | /* Larger than tablet */ 412 | @media (min-width: 750px) {} 413 | 414 | /* Larger than desktop */ 415 | @media (min-width: 1000px) {} 416 | 417 | /* Larger than Desktop HD */ 418 | @media (min-width: 1200px) {} 419 | -------------------------------------------------------------------------------- /docs/static/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/r9y9/wavenet_vocoder/a35fff76ea3687b05e1a10023cad3f7f64fa25a3/docs/static/favicon.png -------------------------------------------------------------------------------- /docs/static/images/r9y9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/r9y9/wavenet_vocoder/a35fff76ea3687b05e1a10023cad3f7f64fa25a3/docs/static/images/r9y9.jpg -------------------------------------------------------------------------------- /egs/README.md: -------------------------------------------------------------------------------- 1 | ## Recipes 2 | 3 | Experimental https://github.com/espnet/espnet style recipes. 
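
All three recipes (`gaussian`, `mol`, `mulaw256`) share the same stage-based `run.sh`; they differ only in the preset under `conf/`, which selects the output distribution and the matching `out_channels` (2 for a single Gaussian, 30 for a 10-component mixture of logistics, 256 for the mu-law softmax) and keeps `np.prod(upsample_scales) == hop_size`. The snippet below is a minimal sketch, not part of the recipes: the `check_preset` helper name is made up here, it assumes it is run from the repository root, and it only restates the invariants documented in `hparams.py` against the bundled presets.

```python
# Hypothetical preset sanity check (illustration only, not shipped with the repo).
import json
import numpy as np

def check_preset(path):
    with open(path) as f:
        hp = json.load(f)
    # Conditioning features are upsampled back to the waveform rate, so the
    # upsample factors must multiply to hop_size (4*4*4*4 == 256 here).
    assert np.prod(hp["upsample_params"]["upsample_scales"]) == hp["hop_size"]
    if hp["input_type"] == "mulaw-quantize":
        # one-hot input, softmax over quantize_channels classes
        assert hp["out_channels"] == hp["quantize_channels"]
    elif hp["output_distribution"] == "Normal":
        # single Gaussian output: (mean, log_scale)
        assert hp["out_channels"] == 2
    else:
        # mixture of logistics: 3 parameters (pi, mean, log_scale) per component
        assert hp["out_channels"] % 3 == 0
    return hp

for conf in ["egs/gaussian/conf/gaussian_wavenet.json",
             "egs/mol/conf/mol_wavenet.json",
             "egs/mulaw256/conf/mulaw256_wavenet.json"]:
    check_preset(conf)
```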
4 | -------------------------------------------------------------------------------- /egs/gaussian/conf/gaussian_wavenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "wavenet_vocoder", 3 | "input_type": "raw", 4 | "quantize_channels": 65536, 5 | "preprocess": "preemphasis", 6 | "postprocess": "inv_preemphasis", 7 | "global_gain_scale": 0.55, 8 | "sample_rate": 22050, 9 | "silence_threshold": 2, 10 | "num_mels": 80, 11 | "fmin": 125, 12 | "fmax": 7600, 13 | "fft_size": 1024, 14 | "hop_size": 256, 15 | "frame_shift_ms": null, 16 | "win_length": 1024, 17 | "win_length_ms": -1.0, 18 | "window": "hann", 19 | "highpass_cutoff": 70.0, 20 | "output_distribution": "Normal", 21 | "log_scale_min": -16.0, 22 | "out_channels": 2, 23 | "layers": 24, 24 | "stacks": 4, 25 | "residual_channels": 128, 26 | "gate_channels": 256, 27 | "skip_out_channels": 128, 28 | "dropout": 0.0, 29 | "kernel_size": 3, 30 | "cin_channels": 80, 31 | "cin_pad": 2, 32 | "upsample_conditional_features": true, 33 | "upsample_net": "ConvInUpsampleNetwork", 34 | "upsample_params": { 35 | "upsample_scales": [ 36 | 4, 37 | 4, 38 | 4, 39 | 4 40 | ] 41 | }, 42 | "gin_channels": -1, 43 | "n_speakers": 7, 44 | "pin_memory": true, 45 | "num_workers": 2, 46 | "batch_size": 8, 47 | "optimizer": "Adam", 48 | "optimizer_params": { 49 | "lr": 0.001, 50 | "eps": 1e-08, 51 | "weight_decay": 0.0 52 | }, 53 | "lr_schedule": "step_learning_rate_decay", 54 | "lr_schedule_kwargs": { 55 | "anneal_rate": 0.5, 56 | "anneal_interval": 200000 57 | }, 58 | "max_train_steps": 1000000, 59 | "nepochs": 2000, 60 | "clip_thresh": -1, 61 | "max_time_sec": null, 62 | "max_time_steps": 10240, 63 | "exponential_moving_average": true, 64 | "ema_decay": 0.9999, 65 | "checkpoint_interval": 100000, 66 | "train_eval_interval": 100000, 67 | "test_eval_epoch_interval": 50, 68 | "save_optimizer_state": true 69 | } 70 | -------------------------------------------------------------------------------- /egs/gaussian/conf/gaussian_wavenet_demo.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "wavenet_vocoder", 3 | "input_type": "raw", 4 | "quantize_channels": 65536, 5 | "preprocess": "", 6 | "postprocess": "", 7 | "global_gain_scale": 1.0, 8 | "sample_rate": 22050, 9 | "silence_threshold": 2, 10 | "num_mels": 80, 11 | "fmin": 125, 12 | "fmax": 7600, 13 | "fft_size": 1024, 14 | "hop_size": 256, 15 | "frame_shift_ms": null, 16 | "win_length": 1024, 17 | "win_length_ms": -1.0, 18 | "window": "hann", 19 | "highpass_cutoff": 70.0, 20 | "output_distribution": "Normal", 21 | "log_scale_min": -9.0, 22 | "out_channels": 2, 23 | "layers": 2, 24 | "stacks": 1, 25 | "residual_channels": 4, 26 | "gate_channels": 4, 27 | "skip_out_channels": 4, 28 | "dropout": 0.0, 29 | "kernel_size": 3, 30 | "cin_channels": 80, 31 | "cin_pad": 2, 32 | "upsample_conditional_features": true, 33 | "upsample_net": "ConvInUpsampleNetwork", 34 | "upsample_params": { 35 | "upsample_scales": [ 36 | 4, 37 | 4, 38 | 4, 39 | 4 40 | ] 41 | }, 42 | "gin_channels": -1, 43 | "n_speakers": 7, 44 | "pin_memory": true, 45 | "num_workers": 2, 46 | "batch_size": 1, 47 | "optimizer": "Adam", 48 | "optimizer_params": { 49 | "lr": 0.001, 50 | "eps": 1e-08, 51 | "weight_decay": 0.0 52 | }, 53 | "lr_schedule": "step_learning_rate_decay", 54 | "lr_schedule_kwargs": { 55 | "anneal_rate": 0.5, 56 | "anneal_interval": 200000 57 | }, 58 | "max_train_steps": 100, 59 | "nepochs": 2000, 60 | "clip_thresh": -1, 61 | 
"max_time_sec": null, 62 | "max_time_steps": 2560, 63 | "exponential_moving_average": true, 64 | "ema_decay": 0.9999, 65 | "checkpoint_interval": 50, 66 | "train_eval_interval": 50, 67 | "test_eval_epoch_interval": 1, 68 | "save_optimizer_state": true 69 | } -------------------------------------------------------------------------------- /egs/gaussian/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | script_dir=$(cd $(dirname ${BASH_SOURCE:-$0}); pwd) 4 | VOC_DIR=$script_dir/../../ 5 | 6 | # Directory that contains all wav files 7 | # **CHANGE** this to your database path 8 | db_root=~/data/LJSpeech-1.1/wavs/ 9 | 10 | spk="lj" 11 | dumpdir=dump 12 | 13 | # train/dev/eval split 14 | dev_size=10 15 | eval_size=10 16 | # Maximum size of train/dev/eval data (in hours). 17 | # set small value (e.g. 0.2) for testing 18 | limit=1000000 19 | 20 | # waveform global gain normalization scale 21 | global_gain_scale=0.55 22 | 23 | stage=0 24 | stop_stage=0 25 | 26 | # Hyper parameters (.json) 27 | # **CHANGE** here to your own hparams 28 | hparams=conf/gaussian_wavenet_demo.json 29 | 30 | # Batch size at inference time. 31 | inference_batch_size=32 32 | # Leave empty to use latest checkpoint 33 | eval_checkpoint= 34 | # Max number of utts. for evaluation( for debugging) 35 | eval_max_num_utt=1000000 36 | 37 | # exp tag 38 | tag="" # tag for managing experiments. 39 | 40 | . $VOC_DIR/utils/parse_options.sh || exit 1; 41 | 42 | # Set bash to 'debug' mode, it will exit on : 43 | # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', 44 | set -e 45 | set -u 46 | set -o pipefail 47 | 48 | train_set="train_no_dev" 49 | dev_set="dev" 50 | eval_set="eval" 51 | datasets=($train_set $dev_set $eval_set) 52 | 53 | # exp name 54 | if [ -z ${tag} ]; then 55 | expname=${spk}_${train_set}_$(basename ${hparams%.*}) 56 | else 57 | expname=${spk}_${train_set}_${tag} 58 | fi 59 | expdir=exp/$expname 60 | 61 | feat_typ="logmelspectrogram" 62 | 63 | # Output directories 64 | data_root=data/$spk # train/dev/eval splitted data 65 | dump_org_dir=$dumpdir/$spk/$feat_typ/org # extracted features (pair of ) 66 | dump_norm_dir=$dumpdir/$spk/$feat_typ/norm # extracted features (pair of ) 67 | 68 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 69 | echo "stage 0: train/dev/eval split" 70 | if [ -z $db_root ]; then 71 | echo "ERROR: DB ROOT must be specified for train/dev/eval splitting." 
72 | echo " Use option --db-root \${path_contains_wav_files}" 73 | exit 1 74 | fi 75 | python $VOC_DIR/mksubset.py $db_root $data_root \ 76 | --train-dev-test-split --dev-size $dev_size --test-size $eval_size \ 77 | --limit=$limit 78 | fi 79 | 80 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 81 | echo "stage 1: Feature Generation" 82 | for s in ${datasets[@]}; 83 | do 84 | python $VOC_DIR/preprocess.py wavallin $data_root/$s ${dump_org_dir}/$s \ 85 | --hparams="global_gain_scale=${global_gain_scale}" --preset=$hparams 86 | done 87 | 88 | # Compute mean-var normalization stats 89 | find $dump_org_dir/$train_set -type f -name "*feats.npy" > train_list.txt 90 | python $VOC_DIR/compute-meanvar-stats.py train_list.txt $dump_org_dir/meanvar.joblib 91 | rm -f train_list.txt 92 | 93 | # Apply normalization 94 | for s in ${datasets[@]}; 95 | do 96 | python $VOC_DIR/preprocess_normalize.py ${dump_org_dir}/$s $dump_norm_dir/$s \ 97 | $dump_org_dir/meanvar.joblib 98 | done 99 | cp -f $dump_org_dir/meanvar.joblib ${dump_norm_dir}/meanvar.joblib 100 | fi 101 | 102 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 103 | echo "stage 2: WaveNet training" 104 | python $VOC_DIR/train.py --dump-root $dump_norm_dir --preset $hparams \ 105 | --checkpoint-dir=$expdir \ 106 | --log-event-path=tensorboard/${expname} 107 | fi 108 | 109 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then 110 | echo "stage 3: Synthesis waveform from WaveNet" 111 | if [ -z $eval_checkpoint ]; then 112 | eval_checkpoint=$expdir/checkpoint_latest.pth 113 | fi 114 | name=$(basename $eval_checkpoint) 115 | name=${name/.pth/} 116 | for s in $dev_set $eval_set; 117 | do 118 | dst_dir=$expdir/generated/$name/$s 119 | python $VOC_DIR/evaluate.py $dump_norm_dir/$s $eval_checkpoint $dst_dir \ 120 | --preset $hparams --hparams="batch_size=$inference_batch_size" \ 121 | --num-utterances=$eval_max_num_utt 122 | done 123 | fi 124 | -------------------------------------------------------------------------------- /egs/mol/conf/mol_wavenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "wavenet_vocoder", 3 | "input_type": "raw", 4 | "quantize_channels": 65536, 5 | "preprocess": "preemphasis", 6 | "postprocess": "inv_preemphasis", 7 | "global_gain_scale": 0.55, 8 | "sample_rate": 22050, 9 | "silence_threshold": 2, 10 | "num_mels": 80, 11 | "fmin": 125, 12 | "fmax": 7600, 13 | "fft_size": 1024, 14 | "hop_size": 256, 15 | "frame_shift_ms": null, 16 | "win_length": 1024, 17 | "win_length_ms": -1.0, 18 | "window": "hann", 19 | "highpass_cutoff": 70.0, 20 | "output_distribution": "Logistic", 21 | "log_scale_min": -16.0, 22 | "out_channels": 30, 23 | "layers": 24, 24 | "stacks": 4, 25 | "residual_channels": 128, 26 | "gate_channels": 256, 27 | "skip_out_channels": 128, 28 | "dropout": 0.0, 29 | "kernel_size": 3, 30 | "cin_channels": 80, 31 | "cin_pad": 2, 32 | "upsample_conditional_features": true, 33 | "upsample_net": "ConvInUpsampleNetwork", 34 | "upsample_params": { 35 | "upsample_scales": [ 36 | 4, 37 | 4, 38 | 4, 39 | 4 40 | ] 41 | }, 42 | "gin_channels": -1, 43 | "n_speakers": 7, 44 | "pin_memory": true, 45 | "num_workers": 2, 46 | "batch_size": 8, 47 | "optimizer": "Adam", 48 | "optimizer_params": { 49 | "lr": 0.001, 50 | "eps": 1e-08, 51 | "weight_decay": 0.0 52 | }, 53 | "lr_schedule": "step_learning_rate_decay", 54 | "lr_schedule_kwargs": { 55 | "anneal_rate": 0.5, 56 | "anneal_interval": 200000 57 | }, 58 | "max_train_steps": 1000000, 59 | "nepochs": 2000, 60 | 
"clip_thresh": -1, 61 | "max_time_sec": null, 62 | "max_time_steps": 10240, 63 | "exponential_moving_average": true, 64 | "ema_decay": 0.9999, 65 | "checkpoint_interval": 100000, 66 | "train_eval_interval": 100000, 67 | "test_eval_epoch_interval": 50, 68 | "save_optimizer_state": true 69 | } 70 | -------------------------------------------------------------------------------- /egs/mol/conf/mol_wavenet_demo.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "wavenet_vocoder", 3 | "input_type": "raw", 4 | "quantize_channels": 65536, 5 | "preprocess": "", 6 | "postprocess": "", 7 | "global_gain_scale": 1.0, 8 | "sample_rate": 22050, 9 | "silence_threshold": 2, 10 | "num_mels": 80, 11 | "fmin": 125, 12 | "fmax": 7600, 13 | "fft_size": 1024, 14 | "hop_size": 256, 15 | "frame_shift_ms": null, 16 | "win_length": 1024, 17 | "win_length_ms": -1.0, 18 | "window": "hann", 19 | "highpass_cutoff": 70.0, 20 | "output_distribution": "Logistic", 21 | "log_scale_min": -9.0, 22 | "out_channels": 30, 23 | "layers": 2, 24 | "stacks": 1, 25 | "residual_channels": 4, 26 | "gate_channels": 4, 27 | "skip_out_channels": 4, 28 | "dropout": 0.0, 29 | "kernel_size": 3, 30 | "cin_channels": 80, 31 | "cin_pad": 2, 32 | "upsample_conditional_features": true, 33 | "upsample_net": "ConvInUpsampleNetwork", 34 | "upsample_params": { 35 | "upsample_scales": [ 36 | 4, 37 | 4, 38 | 4, 39 | 4 40 | ] 41 | }, 42 | "gin_channels": -1, 43 | "n_speakers": 7, 44 | "pin_memory": true, 45 | "num_workers": 2, 46 | "batch_size": 1, 47 | "optimizer": "Adam", 48 | "optimizer_params": { 49 | "lr": 0.001, 50 | "eps": 1e-08, 51 | "weight_decay": 0.0 52 | }, 53 | "lr_schedule": "step_learning_rate_decay", 54 | "lr_schedule_kwargs": { 55 | "anneal_rate": 0.5, 56 | "anneal_interval": 200000 57 | }, 58 | "max_train_steps": 100, 59 | "nepochs": 2000, 60 | "clip_thresh": -1, 61 | "max_time_sec": null, 62 | "max_time_steps": 2560, 63 | "exponential_moving_average": true, 64 | "ema_decay": 0.9999, 65 | "checkpoint_interval": 50, 66 | "train_eval_interval": 50, 67 | "test_eval_epoch_interval": 1, 68 | "save_optimizer_state": true 69 | } -------------------------------------------------------------------------------- /egs/mol/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | script_dir=$(cd $(dirname ${BASH_SOURCE:-$0}); pwd) 4 | VOC_DIR=$script_dir/../../ 5 | 6 | # Directory that contains all wav files 7 | # **CHANGE** this to your database path 8 | db_root=~/data/LJSpeech-1.1/wavs/ 9 | 10 | spk="lj" 11 | dumpdir=dump 12 | 13 | # train/dev/eval split 14 | dev_size=10 15 | eval_size=10 16 | # Maximum size of train/dev/eval data (in hours). 17 | # set small value (e.g. 0.2) for testing 18 | limit=1000000 19 | 20 | # waveform global gain normalization scale 21 | global_gain_scale=0.55 22 | 23 | stage=0 24 | stop_stage=0 25 | 26 | # Hyper parameters (.json) 27 | # **CHANGE** here to your own hparams 28 | hparams=conf/mol_wavenet_demo.json 29 | 30 | # Batch size at inference time. 31 | inference_batch_size=32 32 | # Leave empty to use latest checkpoint 33 | eval_checkpoint= 34 | # Max number of utts. for evaluation( for debugging) 35 | eval_max_num_utt=1000000 36 | 37 | # exp tag 38 | tag="" # tag for managing experiments. 39 | 40 | . $VOC_DIR/utils/parse_options.sh || exit 1; 41 | 42 | # Set bash to 'debug' mode, it will exit on : 43 | # -e 'error', -u 'undefined variable', -o ... 
'error in pipeline', -x 'print commands', 44 | set -e 45 | set -u 46 | set -o pipefail 47 | 48 | train_set="train_no_dev" 49 | dev_set="dev" 50 | eval_set="eval" 51 | datasets=($train_set $dev_set $eval_set) 52 | 53 | # exp name 54 | if [ -z ${tag} ]; then 55 | expname=${spk}_${train_set}_$(basename ${hparams%.*}) 56 | else 57 | expname=${spk}_${train_set}_${tag} 58 | fi 59 | expdir=exp/$expname 60 | 61 | feat_typ="logmelspectrogram" 62 | 63 | # Output directories 64 | data_root=data/$spk # train/dev/eval splitted data 65 | dump_org_dir=$dumpdir/$spk/$feat_typ/org # extracted features (pair of ) 66 | dump_norm_dir=$dumpdir/$spk/$feat_typ/norm # extracted features (pair of ) 67 | 68 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 69 | echo "stage 0: train/dev/eval split" 70 | if [ -z $db_root ]; then 71 | echo "ERROR: DB ROOT must be specified for train/dev/eval splitting." 72 | echo " Use option --db-root \${path_contains_wav_files}" 73 | exit 1 74 | fi 75 | python $VOC_DIR/mksubset.py $db_root $data_root \ 76 | --train-dev-test-split --dev-size $dev_size --test-size $eval_size \ 77 | --limit=$limit 78 | fi 79 | 80 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 81 | echo "stage 1: Feature Generation" 82 | for s in ${datasets[@]}; 83 | do 84 | python $VOC_DIR/preprocess.py wavallin $data_root/$s ${dump_org_dir}/$s \ 85 | --hparams="global_gain_scale=${global_gain_scale}" --preset=$hparams 86 | done 87 | 88 | # Compute mean-var normalization stats 89 | find $dump_org_dir/$train_set -type f -name "*feats.npy" > train_list.txt 90 | python $VOC_DIR/compute-meanvar-stats.py train_list.txt $dump_org_dir/meanvar.joblib 91 | rm -f train_list.txt 92 | 93 | # Apply normalization 94 | for s in ${datasets[@]}; 95 | do 96 | python $VOC_DIR/preprocess_normalize.py ${dump_org_dir}/$s $dump_norm_dir/$s \ 97 | $dump_org_dir/meanvar.joblib 98 | done 99 | cp -f $dump_org_dir/meanvar.joblib ${dump_norm_dir}/meanvar.joblib 100 | fi 101 | 102 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 103 | echo "stage 2: WaveNet training" 104 | python $VOC_DIR/train.py --dump-root $dump_norm_dir --preset $hparams \ 105 | --checkpoint-dir=$expdir \ 106 | --log-event-path=tensorboard/${expname} 107 | fi 108 | 109 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then 110 | echo "stage 3: Synthesis waveform from WaveNet" 111 | if [ -z $eval_checkpoint ]; then 112 | eval_checkpoint=$expdir/checkpoint_latest.pth 113 | fi 114 | name=$(basename $eval_checkpoint) 115 | name=${name/.pth/} 116 | for s in $dev_set $eval_set; 117 | do 118 | dst_dir=$expdir/generated/$name/$s 119 | python $VOC_DIR/evaluate.py $dump_norm_dir/$s $eval_checkpoint $dst_dir \ 120 | --preset $hparams --hparams="batch_size=$inference_batch_size" \ 121 | --num-utterances=$eval_max_num_utt 122 | done 123 | fi 124 | -------------------------------------------------------------------------------- /egs/mulaw256/conf/mulaw256_wavenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "wavenet_vocoder", 3 | "input_type": "mulaw-quantize", 4 | "quantize_channels": 256, 5 | "preprocess": "preemphasis", 6 | "postprocess": "inv_preemphasis", 7 | "global_gain_scale": 0.55, 8 | "sample_rate": 22050, 9 | "silence_threshold": 2, 10 | "num_mels": 80, 11 | "fmin": 125, 12 | "fmax": 7600, 13 | "fft_size": 1024, 14 | "hop_size": 256, 15 | "frame_shift_ms": null, 16 | "win_length": 1024, 17 | "win_length_ms": -1.0, 18 | "window": "hann", 19 | "highpass_cutoff": 70.0, 20 | "output_distribution": 
"Logistic", 21 | "log_scale_min": -9.0, 22 | "out_channels": 256, 23 | "layers": 30, 24 | "stacks": 3, 25 | "residual_channels": 128, 26 | "gate_channels": 256, 27 | "skip_out_channels": 128, 28 | "dropout": 0.0, 29 | "kernel_size": 3, 30 | "cin_channels": 80, 31 | "cin_pad": 2, 32 | "upsample_conditional_features": true, 33 | "upsample_net": "ConvInUpsampleNetwork", 34 | "upsample_params": { 35 | "upsample_scales": [ 36 | 4, 37 | 4, 38 | 4, 39 | 4 40 | ] 41 | }, 42 | "gin_channels": -1, 43 | "n_speakers": 7, 44 | "pin_memory": true, 45 | "num_workers": 2, 46 | "batch_size": 8, 47 | "optimizer": "Adam", 48 | "optimizer_params": { 49 | "lr": 0.001, 50 | "eps": 1e-08, 51 | "weight_decay": 0.0 52 | }, 53 | "lr_schedule": "step_learning_rate_decay", 54 | "lr_schedule_kwargs": { 55 | "anneal_rate": 0.5, 56 | "anneal_interval": 200000 57 | }, 58 | "max_train_steps": 500000, 59 | "nepochs": 2000, 60 | "clip_thresh": -1, 61 | "max_time_sec": null, 62 | "max_time_steps": 10240, 63 | "exponential_moving_average": true, 64 | "ema_decay": 0.9999, 65 | "checkpoint_interval": 100000, 66 | "train_eval_interval": 100000, 67 | "test_eval_epoch_interval": 50, 68 | "save_optimizer_state": true 69 | } -------------------------------------------------------------------------------- /egs/mulaw256/conf/mulaw256_wavenet_demo.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "wavenet_vocoder", 3 | "input_type": "mulaw-quantize", 4 | "quantize_channels": 256, 5 | "preprocess": "", 6 | "postprocess": "", 7 | "global_gain_scale": 1.0, 8 | "sample_rate": 22050, 9 | "silence_threshold": 2, 10 | "num_mels": 80, 11 | "fmin": 125, 12 | "fmax": 7600, 13 | "fft_size": 1024, 14 | "hop_size": 256, 15 | "frame_shift_ms": null, 16 | "win_length": 1024, 17 | "win_length_ms": -1.0, 18 | "window": "hann", 19 | "highpass_cutoff": 70.0, 20 | "output_distribution": "Logistic", 21 | "log_scale_min": -9.0, 22 | "out_channels": 256, 23 | "layers": 2, 24 | "stacks": 1, 25 | "residual_channels": 4, 26 | "gate_channels": 4, 27 | "skip_out_channels": 4, 28 | "dropout": 0.0, 29 | "kernel_size": 3, 30 | "cin_channels": 80, 31 | "cin_pad": 2, 32 | "upsample_conditional_features": true, 33 | "upsample_net": "ConvInUpsampleNetwork", 34 | "upsample_params": { 35 | "upsample_scales": [ 36 | 4, 37 | 4, 38 | 4, 39 | 4 40 | ] 41 | }, 42 | "gin_channels": -1, 43 | "n_speakers": 7, 44 | "pin_memory": true, 45 | "num_workers": 2, 46 | "batch_size": 1, 47 | "optimizer": "Adam", 48 | "optimizer_params": { 49 | "lr": 0.001, 50 | "eps": 1e-08, 51 | "weight_decay": 0.0 52 | }, 53 | "lr_schedule": "step_learning_rate_decay", 54 | "lr_schedule_kwargs": { 55 | "anneal_rate": 0.5, 56 | "anneal_interval": 200000 57 | }, 58 | "max_train_steps": 100, 59 | "nepochs": 2000, 60 | "clip_thresh": -1, 61 | "max_time_sec": null, 62 | "max_time_steps": 2560, 63 | "exponential_moving_average": true, 64 | "ema_decay": 0.9999, 65 | "checkpoint_interval": 50, 66 | "train_eval_interval": 50, 67 | "test_eval_epoch_interval": 1, 68 | "save_optimizer_state": true 69 | } -------------------------------------------------------------------------------- /egs/mulaw256/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | script_dir=$(cd $(dirname ${BASH_SOURCE:-$0}); pwd) 4 | VOC_DIR=$script_dir/../../ 5 | 6 | # Directory that contains all wav files 7 | # **CHANGE** this to your database path 8 | db_root=~/data/LJSpeech-1.1/wavs/ 9 | 10 | spk="lj" 11 | dumpdir=dump 12 | 13 
| # train/dev/eval split 14 | dev_size=10 15 | eval_size=10 16 | # Maximum size of train/dev/eval data (in hours). 17 | # set small value (e.g. 0.2) for testing 18 | limit=1000000 19 | 20 | # waveform global gain normalization scale 21 | global_gain_scale=0.55 22 | 23 | stage=0 24 | stop_stage=0 25 | 26 | # Hyper parameters (.json) 27 | # **CHANGE** here to your own hparams 28 | hparams=conf/mulaw256_wavenet_demo.json 29 | 30 | # Batch size at inference time. 31 | inference_batch_size=32 32 | # Leave empty to use latest checkpoint 33 | eval_checkpoint= 34 | # Max number of utts. for evaluation( for debugging) 35 | eval_max_num_utt=1000000 36 | 37 | # exp tag 38 | tag="" # tag for managing experiments. 39 | 40 | . $VOC_DIR/utils/parse_options.sh || exit 1; 41 | 42 | # Set bash to 'debug' mode, it will exit on : 43 | # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', 44 | set -e 45 | set -u 46 | set -o pipefail 47 | 48 | train_set="train_no_dev" 49 | dev_set="dev" 50 | eval_set="eval" 51 | datasets=($train_set $dev_set $eval_set) 52 | 53 | # exp name 54 | if [ -z ${tag} ]; then 55 | expname=${spk}_${train_set}_$(basename ${hparams%.*}) 56 | else 57 | expname=${spk}_${train_set}_${tag} 58 | fi 59 | expdir=exp/$expname 60 | 61 | feat_typ="logmelspectrogram" 62 | 63 | # Output directories 64 | data_root=data/$spk # train/dev/eval splitted data 65 | dump_org_dir=$dumpdir/$spk/$feat_typ/org # extracted features (pair of ) 66 | dump_norm_dir=$dumpdir/$spk/$feat_typ/norm # extracted features (pair of ) 67 | 68 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 69 | echo "stage 0: train/dev/eval split" 70 | if [ -z $db_root ]; then 71 | echo "ERROR: DB ROOT must be specified for train/dev/eval splitting." 72 | echo " Use option --db-root \${path_contains_wav_files}" 73 | exit 1 74 | fi 75 | python $VOC_DIR/mksubset.py $db_root $data_root \ 76 | --train-dev-test-split --dev-size $dev_size --test-size $eval_size \ 77 | --limit=$limit 78 | fi 79 | 80 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 81 | echo "stage 1: Feature Generation" 82 | for s in ${datasets[@]}; 83 | do 84 | python $VOC_DIR/preprocess.py wavallin $data_root/$s ${dump_org_dir}/$s \ 85 | --hparams="global_gain_scale=${global_gain_scale}" --preset=$hparams 86 | done 87 | 88 | # Compute mean-var normalization stats 89 | find $dump_org_dir/$train_set -type f -name "*feats.npy" > train_list.txt 90 | python $VOC_DIR/compute-meanvar-stats.py train_list.txt $dump_org_dir/meanvar.joblib 91 | rm -f train_list.txt 92 | 93 | # Apply normalization 94 | for s in ${datasets[@]}; 95 | do 96 | python $VOC_DIR/preprocess_normalize.py ${dump_org_dir}/$s $dump_norm_dir/$s \ 97 | $dump_org_dir/meanvar.joblib 98 | done 99 | cp -f $dump_org_dir/meanvar.joblib ${dump_norm_dir}/meanvar.joblib 100 | fi 101 | 102 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 103 | echo "stage 2: WaveNet training" 104 | python $VOC_DIR/train.py --dump-root $dump_norm_dir --preset $hparams \ 105 | --checkpoint-dir=$expdir \ 106 | --log-event-path=tensorboard/${expname} 107 | fi 108 | 109 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then 110 | echo "stage 3: Synthesis waveform from WaveNet" 111 | if [ -z $eval_checkpoint ]; then 112 | eval_checkpoint=$expdir/checkpoint_latest.pth 113 | fi 114 | name=$(basename $eval_checkpoint) 115 | name=${name/.pth/} 116 | for s in $dev_set $eval_set; 117 | do 118 | dst_dir=$expdir/generated/$name/$s 119 | python $VOC_DIR/evaluate.py $dump_norm_dir/$s $eval_checkpoint 
$dst_dir \ 120 | --preset $hparams --hparams="batch_size=$inference_batch_size" \ 121 | --num-utterances=$eval_max_num_utt 122 | done 123 | fi 124 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Synthesis waveform for testset 4 | 5 | usage: evaluate.py [options] 6 | 7 | options: 8 | --hparams= Hyper parameters [default: ]. 9 | --preset= Path of preset parameters (json). 10 | --length= Steps to generate [default: 32000]. 11 | --speaker-id= Use specific speaker of data in case for multi-speaker datasets. 12 | --initial-value= Initial value for the WaveNet decoder. 13 | --output-html Output html for blog post. 14 | --num-utterances=N> Generate N utterenaces per speaker [default: -1]. 15 | --verbose= Verbosity level [default: 0]. 16 | -h, --help Show help message. 17 | """ 18 | from docopt import docopt 19 | 20 | import sys 21 | from glob import glob 22 | import os 23 | from os.path import dirname, join, basename, splitext, exists 24 | import torch 25 | import numpy as np 26 | from nnmnkwii import preprocessing as P 27 | from tqdm import tqdm 28 | from scipy.io import wavfile 29 | from torch.utils import data as data_utils 30 | from torch.nn import functional as F 31 | 32 | from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw, is_raw 33 | 34 | import audio 35 | from hparams import hparams 36 | from train import RawAudioDataSource, MelSpecDataSource, PyTorchDataset, _pad_2d 37 | from nnmnkwii.datasets import FileSourceDataset 38 | 39 | use_cuda = torch.cuda.is_available() 40 | device = torch.device("cuda" if use_cuda else "cpu") 41 | 42 | 43 | def to_int16(x): 44 | if x.dtype == np.int16: 45 | return x 46 | assert x.dtype == np.float32 47 | assert x.min() >= -1 and x.max() <= 1.0 48 | return (x * 32767).astype(np.int16) 49 | 50 | 51 | def dummy_collate(batch): 52 | N = len(batch) 53 | input_lengths = [(len(x) - hparams.cin_pad * 2) * audio.get_hop_size() for x in batch] 54 | input_lengths = torch.LongTensor(input_lengths) 55 | max_len = max([len(x) for x in batch]) 56 | c_batch = np.array([_pad_2d(x, max_len) for x in batch], dtype=np.float32) 57 | c_batch = torch.FloatTensor(c_batch).transpose(1, 2).contiguous() 58 | return [None]*N, [None]*N, c_batch, None, input_lengths 59 | 60 | 61 | def get_data_loader(data_dir, collate_fn): 62 | wav_paths = glob(join(data_dir, "*-wave.npy")) 63 | if len(wav_paths) != 0: 64 | X = FileSourceDataset(RawAudioDataSource(data_dir, 65 | hop_size=audio.get_hop_size(), 66 | max_steps=None, cin_pad=hparams.cin_pad)) 67 | else: 68 | X = None 69 | C = FileSourceDataset(MelSpecDataSource(data_dir, 70 | hop_size=audio.get_hop_size(), 71 | max_steps=None, cin_pad=hparams.cin_pad)) 72 | # No audio found: 73 | if X is None: 74 | assert len(C) > 0 75 | data_loader = data_utils.DataLoader( 76 | C, batch_size=hparams.batch_size, drop_last=False, 77 | num_workers=hparams.num_workers, sampler=None, shuffle=False, 78 | collate_fn=dummy_collate, pin_memory=hparams.pin_memory) 79 | else: 80 | assert len(X) == len(C) 81 | if C[0].shape[-1] != hparams.cin_channels: 82 | raise RuntimeError( 83 | """Invalid cin_channnels {}. 
Expectd to be {}.""".format( 84 | hparams.cin_channels, C[0].shape[-1])) 85 | dataset = PyTorchDataset(X, C) 86 | 87 | data_loader = data_utils.DataLoader( 88 | dataset, batch_size=hparams.batch_size, drop_last=False, 89 | num_workers=hparams.num_workers, sampler=None, shuffle=False, 90 | collate_fn=collate_fn, pin_memory=hparams.pin_memory) 91 | 92 | return data_loader 93 | 94 | 95 | if __name__ == "__main__": 96 | args = docopt(__doc__) 97 | verbose = int(args["--verbose"]) 98 | if verbose > 0: 99 | print("Command line args:\n", args) 100 | data_root = args[""] 101 | checkpoint_path = args[""] 102 | dst_dir = args[""] 103 | 104 | length = int(args["--length"]) 105 | # Note that speaker-id is used for filtering out unrelated-speaker from 106 | # multi-speaker dataset. 107 | speaker_id = args["--speaker-id"] 108 | speaker_id = int(speaker_id) if speaker_id is not None else None 109 | initial_value = args["--initial-value"] 110 | initial_value = None if initial_value is None else float(initial_value) 111 | output_html = args["--output-html"] 112 | num_utterances = int(args["--num-utterances"]) 113 | preset = args["--preset"] 114 | 115 | # Load preset if specified 116 | if preset is not None: 117 | with open(preset) as f: 118 | hparams.parse_json(f.read()) 119 | else: 120 | hparams_json = join(dirname(checkpoint_path), "hparams.json") 121 | if exists(hparams_json): 122 | print("Loading hparams from {}".format(hparams_json)) 123 | with open(hparams_json) as f: 124 | hparams.parse_json(f.read()) 125 | 126 | # Override hyper parameters 127 | hparams.parse(args["--hparams"]) 128 | assert hparams.name == "wavenet_vocoder" 129 | 130 | hparams.max_time_sec = None 131 | hparams.max_time_steps = None 132 | 133 | from train import build_model, get_data_loaders 134 | from synthesis import batch_wavegen 135 | 136 | # Data 137 | # Use exactly same testset used in training script 138 | # disable shuffle for convenience 139 | # test_data_loader = get_data_loaders(data_root, speaker_id, test_shuffle=False)["test"] 140 | from train import collate_fn 141 | test_data_loader = get_data_loader(data_root, collate_fn) 142 | test_dataset = test_data_loader.dataset 143 | 144 | # Model 145 | model = build_model().to(device) 146 | 147 | # Load checkpoint 148 | print("Load checkpoint from {}".format(checkpoint_path)) 149 | if use_cuda: 150 | checkpoint = torch.load(checkpoint_path) 151 | else: 152 | checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) 153 | model.load_state_dict(checkpoint["state_dict"]) 154 | checkpoint_name = splitext(basename(checkpoint_path))[0] 155 | 156 | os.makedirs(dst_dir, exist_ok=True) 157 | dst_dir_name = basename(os.path.normpath(dst_dir)) 158 | 159 | generated_utterances = {} 160 | cin_pad = hparams.cin_pad 161 | file_idx = 0 162 | for idx, (x, y, c, g, input_lengths) in enumerate(test_data_loader): 163 | if cin_pad > 0: 164 | c = F.pad(c, pad=(cin_pad, cin_pad), mode="replicate") 165 | 166 | # B x 1 x T 167 | if x[0] is not None: 168 | B, _, T = x.shape 169 | else: 170 | B, _, Tn = c.shape 171 | T = Tn * audio.get_hop_size() 172 | 173 | if g is None and num_utterances > 0 and B * idx >= num_utterances: 174 | break 175 | 176 | ref_files = [] 177 | ref_feats = [] 178 | for i in range(B): 179 | # Yes this is ugly... 
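# When the dump contained *-wave.npy files, get_data_loader() wrapped the data in
# PyTorchDataset, which exposes .X (raw audio) and .Mel (conditioning features);
# their collected_files hold the reference paths used below to write "*_ref.wav"
# next to the generated "*_gen.wav". With a feature-only dump, the dataset is the
# mel FileSourceDataset itself, so the feature path only serves to derive output
# file names. file_idx walks items in loader order, which relies on shuffle=False
# in get_data_loader().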
180 | if hasattr(test_data_loader.dataset, "X"): 181 | ref_files.append(test_data_loader.dataset.X.collected_files[file_idx][0]) 182 | else: 183 | pass 184 | if hasattr(test_data_loader.dataset, "Mel"): 185 | ref_feats.append(test_data_loader.dataset.Mel.collected_files[file_idx][0]) 186 | else: 187 | ref_feats.append(test_data_loader.dataset.collected_files[file_idx][0]) 188 | file_idx += 1 189 | 190 | if num_utterances > 0 and g is not None: 191 | try: 192 | generated_utterances[g] += 1 193 | if generated_utterances[g] > num_utterances: 194 | continue 195 | except KeyError: 196 | generated_utterances[g] = 1 197 | 198 | if output_html: 199 | def _tqdm(x): return x 200 | else: 201 | _tqdm = tqdm 202 | 203 | # Generate 204 | y_hats = batch_wavegen(model, c=c, g=g, fast=True, tqdm=_tqdm) 205 | 206 | # Save each utt. 207 | has_ref_file = len(ref_files) > 0 208 | for i, (ref, gen, length) in enumerate(zip(x, y_hats, input_lengths)): 209 | if has_ref_file: 210 | if is_mulaw_quantize(hparams.input_type): 211 | # needs to be float since mulaw_inv returns in range of [-1, 1] 212 | ref = ref.max(0)[1].view(-1).float().cpu().numpy()[:length] 213 | else: 214 | ref = ref.view(-1).cpu().numpy()[:length] 215 | gen = gen[:length] 216 | if has_ref_file: 217 | target_audio_path = ref_files[i] 218 | name = splitext(basename(target_audio_path))[0].replace("-wave", "") 219 | else: 220 | target_feat_path = ref_feats[i] 221 | name = splitext(basename(target_feat_path))[0].replace("-feats", "") 222 | 223 | # Paths 224 | if g is None: 225 | dst_wav_path = join(dst_dir, "{}_gen.wav".format( 226 | name)) 227 | target_wav_path = join(dst_dir, "{}_ref.wav".format( 228 | name)) 229 | else: 230 | dst_wav_path = join(dst_dir, "speaker{}_{}_gen.wav".format( 231 | g, name)) 232 | target_wav_path = join(dst_dir, "speaker{}_{}_ref.wav".format( 233 | g, name)) 234 | 235 | # save 236 | if has_ref_file: 237 | if is_mulaw_quantize(hparams.input_type): 238 | ref = P.inv_mulaw_quantize(ref, hparams.quantize_channels - 1) 239 | elif is_mulaw(hparams.input_type): 240 | ref = P.inv_mulaw(ref, hparams.quantize_channels - 1) 241 | if hparams.postprocess is not None and hparams.postprocess not in ["", "none"]: 242 | ref = getattr(audio, hparams.postprocess)(ref) 243 | if hparams.global_gain_scale > 0: 244 | ref /= hparams.global_gain_scale 245 | 246 | # clip (just in case) 247 | gen = np.clip(gen, -1.0, 1.0) 248 | if has_ref_file: 249 | ref = np.clip(ref, -1.0, 1.0) 250 | 251 | wavfile.write(dst_wav_path, hparams.sample_rate, to_int16(gen)) 252 | if has_ref_file: 253 | wavfile.write(target_wav_path, hparams.sample_rate, to_int16(ref)) 254 | 255 | # log (TODO) 256 | if output_html and False: 257 | print(""" 258 | 262 | """.format(hparams.name, dst_dir_name, basename(dst_wav_path))) 263 | 264 | print("Finished! Check out {} for generated audio samples.".format(dst_dir)) 265 | sys.exit(0) 266 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | from wavenet_vocoder.tfcompat.hparam import HParams 2 | import numpy as np 3 | 4 | # NOTE: If you want full control for model architecture. please take a look 5 | # at the code and change whatever you want. Some hyper parameters are hardcoded. 6 | 7 | # Default hyperparameters: 8 | hparams = HParams( 9 | name="wavenet_vocoder", 10 | 11 | # Input type: 12 | # 1. raw [-1, 1] 13 | # 2. mulaw [-1, 1] 14 | # 3. 
mulaw-quantize [0, mu] 15 | # If input_type is raw or mulaw, network assumes scalar input and 16 | # discretized mixture of logistic distributions output, otherwise one-hot 17 | # input and softmax output are assumed. 18 | # **NOTE**: if you change the one of the two parameters below, you need to 19 | # re-run preprocessing before training. 20 | input_type="raw", 21 | quantize_channels=65536, # 65536 or 256 22 | 23 | # Audio: 24 | # time-domain pre/post-processing 25 | # e.g., preemphasis/inv_preemphasis 26 | # ref: LPCNet https://arxiv.org/abs/1810.11846 27 | preprocess="", 28 | postprocess="", 29 | # waveform domain scaling 30 | global_gain_scale=1.0, 31 | 32 | sample_rate=22050, 33 | # this is only valid for mulaw is True 34 | silence_threshold=2, 35 | num_mels=80, 36 | fmin=125, 37 | fmax=7600, 38 | fft_size=1024, 39 | # shift can be specified by either hop_size or frame_shift_ms 40 | hop_size=256, 41 | frame_shift_ms=None, 42 | win_length=1024, 43 | win_length_ms=-1.0, 44 | window="hann", 45 | 46 | # DC removal 47 | highpass_cutoff=70.0, 48 | 49 | # Parametric output distribution type for scalar input 50 | # 1) Logistic or 2) Normal 51 | output_distribution="Logistic", 52 | log_scale_min=-16.0, 53 | 54 | # Model: 55 | # This should equal to `quantize_channels` if mu-law quantize enabled 56 | # otherwise num_mixture * 3 (pi, mean, log_scale) 57 | # single mixture case: 2 58 | out_channels=10 * 3, 59 | layers=24, 60 | stacks=4, 61 | residual_channels=128, 62 | gate_channels=256, # split into 2 gropus internally for gated activation 63 | skip_out_channels=128, 64 | dropout=0.0, 65 | kernel_size=3, 66 | 67 | # Local conditioning (set negative value to disable)) 68 | cin_channels=80, 69 | cin_pad=2, 70 | # If True, use transposed convolutions to upsample conditional features, 71 | # otherwise repeat features to adjust time resolution 72 | upsample_conditional_features=True, 73 | upsample_net="ConvInUpsampleNetwork", 74 | upsample_params={ 75 | "upsample_scales": [4, 4, 4, 4], # should np.prod(upsample_scales) == hop_size 76 | }, 77 | 78 | # Global conditioning (set negative value to disable) 79 | # currently limited for speaker embedding 80 | # this should only be enabled for multi-speaker dataset 81 | gin_channels=-1, # i.e., speaker embedding dim 82 | n_speakers=7, # 7 for CMU ARCTIC 83 | 84 | # Data loader 85 | pin_memory=True, 86 | num_workers=2, 87 | 88 | # Loss 89 | 90 | # Training: 91 | batch_size=8, 92 | optimizer="Adam", 93 | optimizer_params={ 94 | "lr": 1e-3, 95 | "eps": 1e-8, 96 | "weight_decay": 0.0, 97 | }, 98 | 99 | # see lrschedule.py for available lr_schedule 100 | lr_schedule="step_learning_rate_decay", 101 | lr_schedule_kwargs={"anneal_rate": 0.5, "anneal_interval": 200000}, 102 | 103 | max_train_steps=1000000, 104 | nepochs=2000, 105 | 106 | clip_thresh=-1, 107 | 108 | # max time steps can either be specified as sec or steps 109 | # if both are None, then full audio samples are used in a batch 110 | max_time_sec=None, 111 | max_time_steps=10240, # 256 * 40 112 | 113 | # Hold moving averaged parameters and use them for evaluation 114 | exponential_moving_average=True, 115 | # averaged = decay * averaged + (1 - decay) * x 116 | ema_decay=0.9999, 117 | 118 | # Save 119 | # per-step intervals 120 | checkpoint_interval=100000, 121 | train_eval_interval=100000, 122 | # per-epoch interval 123 | test_eval_epoch_interval=50, 124 | save_optimizer_state=True, 125 | 126 | # Eval: 127 | ) 128 | 129 | 130 | def hparams_debug_string(): 131 | values = hparams.values() 132 | hp = [' %s: 
%s' % (name, values[name]) for name in sorted(values)] 133 | return 'Hyperparameters:\n' + '\n'.join(hp) 134 | -------------------------------------------------------------------------------- /lrschedule.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # https://github.com/tensorflow/tensor2tensor/issues/280#issuecomment-339110329 5 | def noam_learning_rate_decay(init_lr, global_step, warmup_steps=4000): 6 | # Noam scheme from tensor2tensor: 7 | warmup_steps = float(warmup_steps) 8 | step = global_step + 1. 9 | lr = init_lr * warmup_steps**0.5 * np.minimum( 10 | step * warmup_steps**-1.5, step**-0.5) 11 | return lr 12 | 13 | 14 | def step_learning_rate_decay(init_lr, global_step, 15 | anneal_rate=0.98, 16 | anneal_interval=30000): 17 | return init_lr * anneal_rate ** (global_step // anneal_interval) 18 | 19 | 20 | def cyclic_cosine_annealing(init_lr, global_step, T, M): 21 | """Cyclic cosine annealing 22 | 23 | https://arxiv.org/pdf/1704.00109.pdf 24 | 25 | Args: 26 | init_lr (float): Initial learning rate 27 | global_step (int): Current iteration number 28 | T (int): Total iteration number (i,e. nepoch) 29 | M (int): Number of ensembles we want 30 | 31 | Returns: 32 | float: Annealed learning rate 33 | """ 34 | TdivM = T // M 35 | return init_lr / 2.0 * (np.cos(np.pi * ((global_step - 1) % TdivM) / TdivM) + 1.0) 36 | -------------------------------------------------------------------------------- /mksubset.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Make subset of dataset 4 | 5 | usage: mksubset.py [options] 6 | 7 | options: 8 | -h, --help Show help message. 9 | --limit= Limit dataset size by N-hours [default: 10000]. 10 | --train-dev-test-split Train/test split. 11 | --dev-size= Development size or rate [default: 0.1]. 12 | --test-size= Test size or rate [default: 0.1]. 13 | --target-sr= Resampling. 14 | --random-state= Random seed [default: 1234]. 
15 | """ 16 | from docopt import docopt 17 | import librosa 18 | from glob import glob 19 | from os.path import join, basename, exists, splitext 20 | from tqdm import tqdm 21 | import sys 22 | import os 23 | from shutil import copy2 24 | from scipy.io import wavfile 25 | import numpy as np 26 | 27 | 28 | def read_wav_or_raw(src_file, is_raw): 29 | if is_raw: 30 | sr = 24000 # hard coded for now 31 | x = np.fromfile(src_file, dtype=np.int16) 32 | else: 33 | sr, x = wavfile.read(src_file) 34 | return sr, x 35 | 36 | 37 | def write_wav_or_raw(dst_path, sr, x, is_raw): 38 | if is_raw: 39 | x.tofile(dst_path) 40 | else: 41 | wavfile.write(dst_path, sr, x) 42 | 43 | if __name__ == "__main__": 44 | args = docopt(__doc__) 45 | in_dir = args[""] 46 | out_dir = args[""] 47 | limit = float(args["--limit"]) 48 | train_dev_test_split = args["--train-dev-test-split"] 49 | dev_size = float(args["--dev-size"]) 50 | test_size = float(args["--test-size"]) 51 | target_sr = args["--target-sr"] 52 | target_sr = int(target_sr) if target_sr is not None else None 53 | random_state = int(args["--random-state"]) 54 | 55 | src_files = sorted(glob(join(in_dir, "*.wav"))) 56 | raw_files = sorted(glob(join(in_dir, "*.raw"))) 57 | is_raw = len(src_files) == 0 and len(raw_files) > 0 58 | if is_raw: 59 | print("Assuming 24kHz /16bit audio data") 60 | src_files = raw_files 61 | if len(src_files) == 0: 62 | raise RuntimeError("No files found in {}".format(in_dir)) 63 | 64 | total_samples = 0 65 | indices = [] 66 | signed_int16_max = 2**15 67 | 68 | os.makedirs(out_dir, exist_ok=True) 69 | if train_dev_test_split: 70 | os.makedirs(join(out_dir, "train_no_dev"), exist_ok=True) 71 | os.makedirs(join(out_dir, "dev"), exist_ok=True) 72 | os.makedirs(join(out_dir, "eval"), exist_ok=True) 73 | 74 | print("Total number of utterances: {}".format(len(src_files))) 75 | for idx, src_file in tqdm(enumerate(src_files)): 76 | sr, x = read_wav_or_raw(src_file, is_raw) 77 | if x.dtype == np.int16: 78 | x = x.astype(np.float32) / signed_int16_max 79 | total_samples += len(x) 80 | total_hours = float(total_samples) / sr / 3600.0 81 | indices.append(idx) 82 | 83 | if total_hours > limit: 84 | print("Total hours {:.3f} exceeded limit ({} hours).".format(total_hours, limit)) 85 | break 86 | print("Total number of collected utterances: {}".format(len(indices))) 87 | 88 | if train_dev_test_split: 89 | from sklearn.model_selection import train_test_split as split 90 | # Get test and dev set from last 91 | if test_size > 1 and dev_size > 1: 92 | test_size = int(test_size) 93 | dev_size = int(dev_size) 94 | testdev_size = test_size + dev_size 95 | train_indices = indices[:-testdev_size] 96 | dev_indices = indices[-testdev_size:-testdev_size + dev_size] 97 | test_indices = indices[-test_size:] 98 | else: 99 | train_indices, dev_test_indices = split( 100 | indices, test_size=test_size + dev_size, random_state=random_state) 101 | dev_indices, test_indices = split( 102 | dev_test_indices, test_size=test_size / (test_size + dev_size), 103 | random_state=random_state) 104 | sets = [ 105 | (sorted(train_indices), join(out_dir, "train_no_dev")), 106 | (sorted(dev_indices), join(out_dir, "dev")), 107 | (sorted(test_indices), join(out_dir, "eval")), 108 | ] 109 | else: 110 | sets = [(indices, out_dir)] 111 | 112 | from sklearn.preprocessing import MinMaxScaler 113 | scaler = MinMaxScaler() 114 | 115 | total_samples = {} 116 | sr = 0 117 | for indices, d in sets: 118 | set_name = basename(d) 119 | total_samples[set_name] = 0 120 | for idx in tqdm(indices): 121 
| src_file = src_files[idx] 122 | dst_path = join(d, basename(src_file)) 123 | if target_sr is not None: 124 | sr, x = read_wav_or_raw(src_file, is_raw) 125 | is_int16 = x.dtype == np.int16 126 | if is_int16: 127 | x = x.astype(np.float32) / signed_int16_max 128 | if target_sr is not None and target_sr != sr: 129 | x = librosa.resample(x, sr, target_sr) 130 | sr = target_sr 131 | scaler.partial_fit(x.astype(np.float64).reshape(-1, 1)) 132 | if is_int16: 133 | x = (x * signed_int16_max).astype(np.int16) 134 | write_wav_or_raw(dst_path, sr, x, is_raw) 135 | total_samples[set_name] += len(x) 136 | else: 137 | sr, x = read_wav_or_raw(src_file, is_raw) 138 | is_int16 = x.dtype == np.int16 139 | if is_int16: 140 | x = x.astype(np.float32) / signed_int16_max 141 | scaler.partial_fit(x.astype(np.float64).reshape(-1, 1)) 142 | total_samples[set_name] += len(x) 143 | copy2(src_file, dst_path) 144 | 145 | print("Waveform min: {}".format(scaler.data_min_)) 146 | print("Waveform max: {}".format(scaler.data_max_)) 147 | absmax = max(np.abs(scaler.data_min_[0]), np.abs(scaler.data_max_[0])) 148 | print("Waveform absolute max: {}".format(absmax)) 149 | if absmax > 1.0: 150 | print("There were clipping(s) in your dataset.") 151 | print("Global scaling factor would be around {}".format(1.0 / absmax)) 152 | 153 | if train_dev_test_split: 154 | print("Train/dev/test split:") 155 | for n, s in zip(["train_no_dev", "dev", "eval"], sets): 156 | hours = total_samples[n] / sr / 3600.0 157 | print("{}: {:.2f} hours ({} utt)".format(n, hours, len(s[0]))) 158 | 159 | sys.exit(0) 160 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Preprocess dataset 4 | 5 | usage: preprocess.py [options] 6 | 7 | options: 8 | --num_workers= Num workers. 9 | --hparams= Hyper parameters [default: ]. 10 | --preset= Path of preset parameters (json). 11 | -h, --help Show help message. 
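Example (paths are illustrative): run the generic wavallin dataset module over a directory of wav files and write its output, including a train.txt index, into the output directory:

    python preprocess.py wavallin /path/to/wavs ./dump/wavallin --preset=/path/to/preset.json --num_workers=8

The three positional arguments (dataset module name, input directory, output directory) are the ones read in the __main__ block below; --preset points at a hyper-parameter JSON such as one produced by tojson.py.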
12 | """ 13 | from docopt import docopt 14 | import os 15 | from os.path import join 16 | from multiprocessing import cpu_count 17 | from tqdm import tqdm 18 | import importlib 19 | from hparams import hparams 20 | 21 | 22 | def preprocess(mod, in_dir, out_root, num_workers): 23 | os.makedirs(out_dir, exist_ok=True) 24 | metadata = mod.build_from_path(in_dir, out_dir, num_workers, tqdm=tqdm) 25 | write_metadata(metadata, out_dir) 26 | 27 | 28 | def write_metadata(metadata, out_dir): 29 | with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f: 30 | for m in metadata: 31 | f.write('|'.join([str(x) for x in m]) + '\n') 32 | frames = sum([m[2] for m in metadata]) 33 | sr = hparams.sample_rate 34 | hours = frames / sr / 3600 35 | print('Wrote %d utterances, %d time steps (%.2f hours)' % (len(metadata), frames, hours)) 36 | print('Min frame length: %d' % min(m[2] for m in metadata)) 37 | print('Max frame length: %d' % max(m[2] for m in metadata)) 38 | 39 | 40 | if __name__ == "__main__": 41 | args = docopt(__doc__) 42 | name = args[""] 43 | in_dir = args[""] 44 | out_dir = args[""] 45 | num_workers = args["--num_workers"] 46 | num_workers = cpu_count() // 2 if num_workers is None else int(num_workers) 47 | preset = args["--preset"] 48 | 49 | # Load preset if specified 50 | if preset is not None: 51 | with open(preset) as f: 52 | hparams.parse_json(f.read()) 53 | # Override hyper parameters 54 | hparams.parse(args["--hparams"]) 55 | assert hparams.name == "wavenet_vocoder" 56 | 57 | print("Sampling frequency: {}".format(hparams.sample_rate)) 58 | if name in ["cmu_arctic", "jsut", "librivox"]: 59 | print("""warn!: {} is no longer explicitly supported! 60 | 61 | Please use a generic dataest 'wavallin' instead. 62 | All you need to do is to put all wav files in a single directory.""".format(name)) 63 | sys.exit(1) 64 | 65 | if name == "ljspeech": 66 | print("""warn: ljspeech is deprecated! 67 | Please use a generic dataset 'wavallin' instead.""") 68 | sys.exit(1) 69 | 70 | mod = importlib.import_module("datasets." + name) 71 | preprocess(mod, in_dir, out_dir, num_workers) 72 | -------------------------------------------------------------------------------- /preprocess_normalize.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """Perform meanvar normalization to preprocessed features. 3 | 4 | usage: preprocess_normalize.py [options] 5 | 6 | options: 7 | --inverse Inverse transform. 8 | --num_workers= Num workers. 9 | -h, --help Show help message. 
10 | """ 11 | from docopt import docopt 12 | import os 13 | from os.path import join, exists, basename, splitext 14 | from multiprocessing import cpu_count 15 | from tqdm import tqdm 16 | from nnmnkwii import preprocessing as P 17 | import numpy as np 18 | import json 19 | from concurrent.futures import ProcessPoolExecutor 20 | from functools import partial 21 | from shutil import copyfile 22 | 23 | import joblib 24 | from glob import glob 25 | from itertools import zip_longest 26 | 27 | 28 | def get_paths_by_glob(in_dir, filt): 29 | return glob(join(in_dir, filt)) 30 | 31 | 32 | def _process_utterance(out_dir, audio_path, feat_path, scaler, inverse): 33 | # [Optional] copy audio with the same name if exists 34 | if audio_path is not None and exists(audio_path): 35 | name = splitext(basename(audio_path))[0] 36 | np.save(join(out_dir, name), np.load(audio_path), allow_pickle=False) 37 | 38 | # [Required] apply normalization for features 39 | assert exists(feat_path) 40 | x = np.load(feat_path) 41 | if inverse: 42 | y = scaler.inverse_transform(x) 43 | else: 44 | y = scaler.transform(x) 45 | assert x.dtype == y.dtype 46 | name = splitext(basename(feat_path))[0] 47 | np.save(join(out_dir, name), y, allow_pickle=False) 48 | 49 | 50 | def apply_normalization_dir2dir(in_dir, out_dir, scaler, inverse, num_workers): 51 | # NOTE: at this point, audio_paths can be empty 52 | audio_paths = get_paths_by_glob(in_dir, "*-wave.npy") 53 | feature_paths = get_paths_by_glob(in_dir, "*-feats.npy") 54 | executor = ProcessPoolExecutor(max_workers=num_workers) 55 | futures = [] 56 | for audio_path, feature_path in zip_longest(audio_paths, feature_paths): 57 | futures.append(executor.submit( 58 | partial(_process_utterance, out_dir, audio_path, feature_path, scaler, inverse))) 59 | for future in tqdm(futures): 60 | future.result() 61 | 62 | 63 | if __name__ == "__main__": 64 | args = docopt(__doc__) 65 | in_dir = args[""] 66 | out_dir = args[""] 67 | scaler_path = args[""] 68 | scaler = joblib.load(scaler_path) 69 | inverse = args["--inverse"] 70 | num_workers = args["--num_workers"] 71 | num_workers = cpu_count() // 2 if num_workers is None else int(num_workers) 72 | 73 | os.makedirs(out_dir, exist_ok=True) 74 | apply_normalization_dir2dir(in_dir, out_dir, scaler, inverse, num_workers) 75 | 76 | # Copy meta information if exists 77 | traintxt = join(in_dir, "train.txt") 78 | if exists(traintxt): 79 | copyfile(join(in_dir, "train.txt"), join(out_dir, "train.txt")) 80 | -------------------------------------------------------------------------------- /release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Script for Pypi release 4 | # 0. Make sure you are on git tag 5 | # 1. Run the script 6 | # 2. Upload sdist 7 | 8 | set -e 9 | 10 | script_dir=$(cd $(dirname ${BASH_SOURCE:-$0}); pwd) 11 | cd $script_dir 12 | 13 | TAG=$(git describe --exact-match --tags HEAD) 14 | 15 | VERSION=${TAG/v/} 16 | 17 | WAVENET_VOCODER_BUILD_VERSION=$VERSION python setup.py develop sdist 18 | echo "*** Ready to release! wavenet_vocoder $TAG ***" 19 | echo "Please make sure that release verion is correct." 
20 | cat wavenet_vocoder/version.py 21 | echo "Please run the following command manually:" 22 | echo twine upload dist/wavenet_vocoder-${VERSION}.tar.gz --repository-url https://upload.pypi.org/legacy/ 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | from importlib.machinery import SourceFileLoader 6 | 7 | version = SourceFileLoader('wavenet_vocoder.version', 8 | 'wavenet_vocoder/version.py').load_module().version 9 | 10 | 11 | setup(name='wavenet_vocoder', 12 | version=version, 13 | description='PyTorch implementation of WaveNet vocoder', 14 | packages=find_packages(), 15 | install_requires=[ 16 | "numpy", 17 | "scipy", 18 | "torch >= 0.4.1", 19 | "docopt", 20 | "joblib", 21 | "tqdm", 22 | "tensorboardX", 23 | "nnmnkwii >= 0.0.11", 24 | "scikit-learn", 25 | "librosa", 26 | ], 27 | extras_require={ 28 | "test": [ 29 | "nose", 30 | "pysptk >= 0.1.9", 31 | "matplotlib", 32 | ], 33 | }) 34 | -------------------------------------------------------------------------------- /synthesis.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Synthesis waveform from trained WaveNet. 4 | 5 | usage: synthesis.py [options] 6 | 7 | options: 8 | --hparams= Hyper parameters [default: ]. 9 | --preset= Path of preset parameters (json). 10 | --length= Steps to generate [default: 32000]. 11 | --initial-value= Initial value for the WaveNet decoder. 12 | --conditional=
Conditional features path. 13 | --file-name-suffix= File name suffix [default: ]. 14 | --speaker-id= Speaker ID (for multi-speaker model). 15 | --output-html Output html for blog post. 16 | -h, --help Show help message. 17 | """ 18 | from docopt import docopt 19 | 20 | import sys 21 | import os 22 | from os.path import dirname, join, basename, splitext 23 | import torch 24 | import numpy as np 25 | from nnmnkwii import preprocessing as P 26 | from tqdm import tqdm 27 | import librosa 28 | 29 | from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw, is_raw 30 | 31 | import audio 32 | from hparams import hparams 33 | 34 | from train import to_categorical 35 | 36 | 37 | torch.set_num_threads(4) 38 | use_cuda = torch.cuda.is_available() 39 | device = torch.device("cuda" if use_cuda else "cpu") 40 | 41 | 42 | def batch_wavegen(model, c=None, g=None, fast=True, tqdm=tqdm): 43 | from train import sanity_check 44 | sanity_check(model, c, g) 45 | assert c is not None 46 | B = c.shape[0] 47 | model.eval() 48 | if fast: 49 | model.make_generation_fast_() 50 | 51 | # Transform data to GPU 52 | g = None if g is None else g.to(device) 53 | c = None if c is None else c.to(device) 54 | 55 | if hparams.upsample_conditional_features: 56 | length = (c.shape[-1] - hparams.cin_pad * 2) * audio.get_hop_size() 57 | else: 58 | # already dupulicated 59 | length = c.shape[-1] 60 | 61 | with torch.no_grad(): 62 | y_hat = model.incremental_forward( 63 | c=c, g=g, T=length, tqdm=tqdm, softmax=True, quantize=True, 64 | log_scale_min=hparams.log_scale_min) 65 | 66 | if is_mulaw_quantize(hparams.input_type): 67 | # needs to be float since mulaw_inv returns in range of [-1, 1] 68 | y_hat = y_hat.max(1)[1].view(B, -1).float().cpu().data.numpy() 69 | for i in range(B): 70 | y_hat[i] = P.inv_mulaw_quantize(y_hat[i], hparams.quantize_channels - 1) 71 | elif is_mulaw(hparams.input_type): 72 | y_hat = y_hat.view(B, -1).cpu().data.numpy() 73 | for i in range(B): 74 | y_hat[i] = P.inv_mulaw(y_hat[i], hparams.quantize_channels - 1) 75 | else: 76 | y_hat = y_hat.view(B, -1).cpu().data.numpy() 77 | 78 | if hparams.postprocess is not None and hparams.postprocess not in ["", "none"]: 79 | for i in range(B): 80 | y_hat[i] = getattr(audio, hparams.postprocess)(y_hat[i]) 81 | 82 | if hparams.global_gain_scale > 0: 83 | for i in range(B): 84 | y_hat[i] /= hparams.global_gain_scale 85 | 86 | return y_hat 87 | 88 | 89 | def _to_numpy(x): 90 | # this is ugly 91 | if x is None: 92 | return None 93 | if isinstance(x, np.ndarray) or np.isscalar(x): 94 | return x 95 | # remove batch axis 96 | if x.dim() == 3: 97 | x = x.squeeze(0) 98 | return x.numpy() 99 | 100 | 101 | def wavegen(model, length=None, c=None, g=None, initial_value=None, 102 | fast=False, tqdm=tqdm): 103 | """Generate waveform samples by WaveNet. 104 | 105 | Args: 106 | model (nn.Module) : WaveNet decoder 107 | length (int): Time steps to generate. If conditinlal features are given, 108 | then this is determined by the feature size. 109 | c (numpy.ndarray): Conditional features, of shape T x C 110 | g (scaler): Speaker ID 111 | initial_value (int) : initial_value for the WaveNet decoder. 112 | fast (Bool): Whether to remove weight normalization or not. 
113 | tqdm (lambda): tqdm 114 | 115 | Returns: 116 | numpy.ndarray : Generated waveform samples 117 | """ 118 | from train import sanity_check 119 | sanity_check(model, c, g) 120 | 121 | c = _to_numpy(c) 122 | g = _to_numpy(g) 123 | 124 | model.eval() 125 | if fast: 126 | model.make_generation_fast_() 127 | 128 | if c is None: 129 | assert length is not None 130 | else: 131 | # (Tc, D) 132 | if c.ndim != 2: 133 | raise RuntimeError( 134 | "Expected 2-dim shape (T, {}) for the conditional feature, but {} was actually given.".format(hparams.cin_channels, c.shape)) 135 | assert c.ndim == 2 136 | Tc = c.shape[0] 137 | upsample_factor = audio.get_hop_size() 138 | # Overwrite length according to feature size 139 | length = Tc * upsample_factor 140 | # (Tc, D) -> (Tc', D) 141 | # Repeat features before feeding it to the network 142 | if not hparams.upsample_conditional_features: 143 | c = np.repeat(c, upsample_factor, axis=0) 144 | 145 | # B x C x T 146 | c = torch.FloatTensor(c.T).unsqueeze(0) 147 | 148 | if initial_value is None: 149 | if is_mulaw_quantize(hparams.input_type): 150 | initial_value = P.mulaw_quantize(0, hparams.quantize_channels - 1) 151 | else: 152 | initial_value = 0.0 153 | 154 | if is_mulaw_quantize(hparams.input_type): 155 | assert initial_value >= 0 and initial_value < hparams.quantize_channels 156 | initial_input = np_utils.to_categorical( 157 | initial_value, num_classes=hparams.quantize_channels).astype(np.float32) 158 | initial_input = torch.from_numpy(initial_input).view( 159 | 1, 1, hparams.quantize_channels) 160 | else: 161 | initial_input = torch.zeros(1, 1, 1).fill_(initial_value) 162 | 163 | g = None if g is None else torch.LongTensor([g]) 164 | 165 | # Transform data to GPU 166 | initial_input = initial_input.to(device) 167 | g = None if g is None else g.to(device) 168 | c = None if c is None else c.to(device) 169 | 170 | with torch.no_grad(): 171 | y_hat = model.incremental_forward( 172 | initial_input, c=c, g=g, T=length, tqdm=tqdm, softmax=True, quantize=True, 173 | log_scale_min=hparams.log_scale_min) 174 | 175 | if is_mulaw_quantize(hparams.input_type): 176 | y_hat = y_hat.max(1)[1].view(-1).long().cpu().data.numpy() 177 | y_hat = P.inv_mulaw_quantize(y_hat, hparams.quantize_channels) 178 | elif is_mulaw(hparams.input_type): 179 | y_hat = P.inv_mulaw(y_hat.view(-1).cpu().data.numpy(), hparams.quantize_channels) 180 | else: 181 | y_hat = y_hat.view(-1).cpu().data.numpy() 182 | 183 | if hparams.postprocess is not None and hparams.postprocess not in ["", "none"]: 184 | y_hat = getattr(audio, hparams.postprocess)(y_hat) 185 | 186 | if hparams.global_gain_scale > 0: 187 | y_hat /= hparams.global_gain_scale 188 | 189 | return y_hat 190 | 191 | 192 | if __name__ == "__main__": 193 | args = docopt(__doc__) 194 | print("Command line args:\n", args) 195 | checkpoint_path = args[""] 196 | dst_dir = args[""] 197 | 198 | length = int(args["--length"]) 199 | initial_value = args["--initial-value"] 200 | initial_value = None if initial_value is None else float(initial_value) 201 | conditional_path = args["--conditional"] 202 | 203 | file_name_suffix = args["--file-name-suffix"] 204 | output_html = args["--output-html"] 205 | speaker_id = args["--speaker-id"] 206 | speaker_id = None if speaker_id is None else int(speaker_id) 207 | preset = args["--preset"] 208 | 209 | # Load preset if specified 210 | if preset is not None: 211 | with open(preset) as f: 212 | hparams.parse_json(f.read()) 213 | # Override hyper parameters 214 | hparams.parse(args["--hparams"]) 215 | assert 
hparams.name == "wavenet_vocoder" 216 | 217 | # Load conditional features 218 | if conditional_path is not None: 219 | c = np.load(conditional_path) 220 | if c.shape[1] != hparams.num_mels: 221 | c = np.swapaxes(c, 0, 1) 222 | else: 223 | c = None 224 | 225 | from train import build_model 226 | 227 | # Model 228 | model = build_model().to(device) 229 | 230 | # Load checkpoint 231 | print("Load checkpoint from {}".format(checkpoint_path)) 232 | if use_cuda: 233 | checkpoint = torch.load(checkpoint_path) 234 | else: 235 | checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) 236 | model.load_state_dict(checkpoint["state_dict"]) 237 | checkpoint_name = splitext(basename(checkpoint_path))[0] 238 | 239 | os.makedirs(dst_dir, exist_ok=True) 240 | dst_wav_path = join(dst_dir, "{}{}.wav".format(checkpoint_name, file_name_suffix)) 241 | 242 | # DO generate 243 | waveform = batch_wavegen(model, length, c=c, g=speaker_id, initial_value=initial_value, fast=True) 244 | 245 | # save 246 | librosa.output.write_wav(dst_wav_path, waveform, sr=hparams.sample_rate) 247 | 248 | print("Finished! Check out {} for generated audio samples.".format(dst_dir)) 249 | sys.exit(0) 250 | -------------------------------------------------------------------------------- /tests/test_audio.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import with_statement, print_function, absolute_import 3 | 4 | import sys 5 | from os.path import dirname, join 6 | sys.path.insert(0, join(dirname(__file__), "..")) 7 | 8 | import numpy as np 9 | from nose.plugins.attrib import attr 10 | 11 | import logging 12 | logging.getLogger('tensorflow').disabled = True 13 | 14 | 15 | @attr("local_only") 16 | def test_amp_to_db(): 17 | import audio 18 | x = np.random.rand(10) 19 | x_hat = audio._db_to_amp(audio._amp_to_db(x)) 20 | assert np.allclose(x, x_hat) 21 | -------------------------------------------------------------------------------- /tests/test_misc.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import with_statement, print_function, absolute_import 3 | 4 | from wavenet_vocoder import receptive_field_size 5 | 6 | 7 | def test_receptive_field_size(): 8 | # Table 4 in https://arxiv.org/abs/1711.10433 9 | assert receptive_field_size(total_layers=30, num_cycles=3, kernel_size=3) == 6139 10 | assert receptive_field_size(total_layers=24, num_cycles=4, kernel_size=3) == 505 11 | assert receptive_field_size(total_layers=12, num_cycles=2, kernel_size=3) == 253 12 | assert receptive_field_size(total_layers=30, num_cycles=1, 13 | kernel_size=3, dilation=lambda x: 1) == 61 14 | -------------------------------------------------------------------------------- /tests/test_mixture.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import with_statement, print_function, absolute_import 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | import librosa 10 | import pysptk 11 | from nose.plugins.attrib import attr 12 | 13 | 14 | from wavenet_vocoder.mixture import discretized_mix_logistic_loss 15 | from wavenet_vocoder.mixture import sample_from_discretized_mix_logistic 16 | from wavenet_vocoder.mixture import mix_gaussian_loss 17 | from wavenet_vocoder.mixture import sample_from_mix_gaussian 18 | 19 | 20 | def log_prob_from_logits(x): 21 | 
""" numerically stable log_softmax implementation that prevents overflow """ 22 | # TF ordering 23 | axis = len(x.size()) - 1 24 | m, _ = torch.max(x, dim=-1, keepdim=True) 25 | return x - m - torch.log(torch.sum(torch.exp(x - m), dim=axis, keepdim=True)) 26 | 27 | 28 | @attr("mixture") 29 | def test_log_softmax(): 30 | x = torch.rand(2, 16000, 30) 31 | y = log_prob_from_logits(x) 32 | y_hat = F.log_softmax(x, -1) 33 | 34 | y = y.data.cpu().numpy() 35 | y_hat = y_hat.data.cpu().numpy() 36 | assert np.allclose(y, y_hat) 37 | 38 | 39 | @attr("mixture") 40 | def test_logistic_mixture(): 41 | np.random.seed(1234) 42 | 43 | x, sr = librosa.load(pysptk.util.example_audio_file(), sr=None) 44 | assert sr == 16000 45 | 46 | T = len(x) 47 | x = x.reshape(1, T, 1) 48 | y = torch.from_numpy(x).float() 49 | y_hat = torch.rand(1, 30, T).float() 50 | 51 | print(y.shape, y_hat.shape) 52 | 53 | loss = discretized_mix_logistic_loss(y_hat, y) 54 | print(loss) 55 | 56 | loss = discretized_mix_logistic_loss(y_hat, y, reduce=False) 57 | print(loss.size(), y.size()) 58 | assert loss.size() == y.size() 59 | 60 | y = sample_from_discretized_mix_logistic(y_hat) 61 | print(y.shape) 62 | 63 | 64 | @attr("mixture") 65 | def test_gaussian_mixture(): 66 | np.random.seed(1234) 67 | 68 | x, sr = librosa.load(pysptk.util.example_audio_file(), sr=None) 69 | assert sr == 16000 70 | 71 | T = len(x) 72 | x = x.reshape(1, T, 1) 73 | y = torch.from_numpy(x).float() 74 | y_hat = torch.rand(1, 30, T).float() 75 | 76 | print(y.shape, y_hat.shape) 77 | 78 | loss = mix_gaussian_loss(y_hat, y) 79 | print(loss) 80 | 81 | loss = mix_gaussian_loss(y_hat, y, reduce=False) 82 | print(loss.size(), y.size()) 83 | assert loss.size() == y.size() 84 | 85 | y = sample_from_mix_gaussian(y_hat) 86 | print(y.shape) 87 | 88 | 89 | @attr("mixture") 90 | def test_misc(): 91 | # https://en.wikipedia.org/wiki/Logistic_distribution 92 | # what i have learned 93 | # m = (x - mu) / s 94 | m = torch.rand(10, 10) 95 | log_pdf_mid1 = -2 * torch.log(torch.exp(m / 2) + torch.exp(-m / 2)) 96 | log_pdf_mid2 = m - 2 * F.softplus(m) 97 | assert np.allclose(log_pdf_mid1.data.numpy(), log_pdf_mid2.data.numpy()) 98 | 99 | # Edge case for 0 100 | plus_in = torch.rand(10, 10) 101 | log_cdf_plus1 = torch.sigmoid(m).log() 102 | log_cdf_plus2 = m - F.softplus(m) 103 | assert np.allclose(log_cdf_plus1.data.numpy(), log_cdf_plus2.data.numpy()) 104 | 105 | # Edge case for 255 106 | min_in = torch.rand(10, 10) 107 | log_one_minus_cdf_min1 = (1 - torch.sigmoid(min_in)).log() 108 | log_one_minus_cdf_min2 = -F.softplus(min_in) 109 | assert np.allclose(log_one_minus_cdf_min1.data.numpy(), log_one_minus_cdf_min2.data.numpy()) 110 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import with_statement, print_function, absolute_import 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from nnmnkwii import preprocessing as P 8 | from pysptk.util import example_audio_file 9 | import librosa 10 | import numpy as np 11 | from tqdm import tqdm 12 | from os.path import join, dirname, exists 13 | from functools import partial 14 | from nose.plugins.attrib import attr 15 | 16 | from wavenet_vocoder.modules import ResidualConv1dGLU 17 | from wavenet_vocoder import WaveNet 18 | 19 | use_cuda = False 20 | device = torch.device("cuda" if use_cuda else "cpu") 21 | 22 | # For test 23 | 
build_compact_model = partial(WaveNet, layers=4, stacks=2, residual_channels=32, 24 | gate_channels=32, skip_out_channels=32, 25 | scalar_input=False) 26 | 27 | # https://github.com/keras-team/keras/blob/master/keras/utils/np_utils.py 28 | # copied to avoid keras dependency in tests 29 | 30 | 31 | def to_categorical(y, num_classes=None): 32 | """Converts a class vector (integers) to binary class matrix. 33 | E.g. for use with categorical_crossentropy. 34 | # Arguments 35 | y: class vector to be converted into a matrix 36 | (integers from 0 to num_classes). 37 | num_classes: total number of classes. 38 | # Returns 39 | A binary matrix representation of the input. 40 | """ 41 | y = np.array(y, dtype='int') 42 | input_shape = y.shape 43 | if input_shape and input_shape[-1] == 1 and len(input_shape) > 1: 44 | input_shape = tuple(input_shape[:-1]) 45 | y = y.ravel() 46 | if not num_classes: 47 | num_classes = np.max(y) + 1 48 | n = y.shape[0] 49 | categorical = np.zeros((n, num_classes)) 50 | categorical[np.arange(n), y] = 1 51 | output_shape = input_shape + (num_classes,) 52 | categorical = np.reshape(categorical, output_shape) 53 | return categorical 54 | 55 | 56 | def test_conv_block(): 57 | conv = ResidualConv1dGLU(30, 30, kernel_size=3, dropout=1 - 0.95) 58 | print(conv) 59 | x = torch.zeros(16, 30, 16000) 60 | y, h = conv(x) 61 | print(y.size(), h.size()) 62 | 63 | 64 | def test_wavenet(): 65 | model = build_compact_model() 66 | print(model) 67 | x = torch.zeros(16, 256, 1000) 68 | y = model(x) 69 | print(y.size()) 70 | 71 | 72 | def _test_data(sr=4000, N=3000, returns_power=False, mulaw=True): 73 | x, _ = librosa.load(example_audio_file(), sr=sr) 74 | x, _ = librosa.effects.trim(x, top_db=15) 75 | 76 | # To save computational cost 77 | x = x[:N] 78 | 79 | # For power conditioning wavenet 80 | if returns_power: 81 | # (1 x N') 82 | p = librosa.feature.rms(x, frame_length=256, hop_length=128) 83 | upsample_factor = x.size // p.size 84 | # (1 x N) 85 | p = np.repeat(p, upsample_factor, axis=-1) 86 | if p.size < x.size: 87 | # pad against time axis 88 | p = np.pad(p, [(0, 0), (0, x.size - p.size)], mode="constant", constant_values=0) 89 | 90 | # shape adajst 91 | p = p.reshape(1, 1, -1) 92 | 93 | # (T,) 94 | if mulaw: 95 | x = P.mulaw_quantize(x) 96 | x_org = P.inv_mulaw_quantize(x) 97 | # (C, T) 98 | x = to_categorical(x, num_classes=256).T 99 | # (1, C, T) 100 | x = x.reshape(1, 256, -1).astype(np.float32) 101 | else: 102 | x_org = x 103 | x = x.reshape(1, 1, -1) 104 | 105 | if returns_power: 106 | return x, x_org, p 107 | 108 | return x, x_org 109 | 110 | 111 | @attr("mixture") 112 | def test_mixture_wavenet(): 113 | x, x_org, c = _test_data(returns_power=True, mulaw=False) 114 | # 10 mixtures 115 | model = build_compact_model(out_channels=3 * 10, cin_channels=1, 116 | scalar_input=True) 117 | T = x.shape[-1] 118 | print(model.first_conv) 119 | 120 | # scalar input, not one-hot 121 | assert x.shape[1] == 1 122 | 123 | x = torch.from_numpy(x).contiguous().to(device) 124 | c = torch.from_numpy(c).contiguous().to(device) 125 | 126 | # make batch 127 | x = x.expand((x.shape[0] * 2, x.shape[1], x.shape[2])) 128 | c = c.expand((c.shape[0] * 2, c.shape[1], c.shape[2])) 129 | 130 | print(c.size()) 131 | 132 | model.eval() 133 | 134 | # Incremental forward with forced teaching 135 | y_online = model.incremental_forward( 136 | test_inputs=x, c=c, T=None, tqdm=tqdm) 137 | 138 | assert y_online.size() == x.size() 139 | 140 | y_online2 = model.incremental_forward( 141 | test_inputs=None, c=c, T=T, 
tqdm=tqdm) 142 | 143 | assert y_online2.size() == x.size() 144 | print(x.size()) 145 | 146 | 147 | @attr("local_conditioning") 148 | def test_local_conditioning_correctness(): 149 | # condition by power 150 | x, x_org, c = _test_data(returns_power=True) 151 | model = build_compact_model(cin_channels=1) 152 | assert model.local_conditioning_enabled() 153 | assert not model.has_speaker_embedding() 154 | 155 | x = torch.from_numpy(x).contiguous().to(device) 156 | 157 | c = torch.from_numpy(c).contiguous().to(device) 158 | print(x.size(), c.size()) 159 | 160 | model.eval() 161 | 162 | y_offline = model(x, c=c, softmax=True) 163 | 164 | # Incremental forward with forced teaching 165 | y_online = model.incremental_forward( 166 | test_inputs=x, c=c, T=None, tqdm=tqdm, softmax=True, quantize=False) 167 | 168 | # (1 x C x T) 169 | c = (y_offline - y_online).abs() 170 | print(c.mean(), c.max()) 171 | 172 | try: 173 | assert np.allclose(y_offline.cpu().data.numpy(), 174 | y_online.cpu().data.numpy(), atol=1e-4) 175 | except Exception: 176 | from warnings import warn 177 | warn("oops! must be a bug!") 178 | 179 | 180 | @attr("local_conditioning") 181 | def test_local_conditioning_upsample_correctness(): 182 | # condition by power 183 | x, x_org, c = _test_data(returns_power=True) 184 | 185 | # downsample by 4 186 | assert c.shape[-1] % 4 == 0 187 | c = c[:, :, 0::4] 188 | 189 | model = build_compact_model( 190 | cin_channels=1, upsample_conditional_features=True, 191 | upsample_params={"upsample_scales": [2, 2], "cin_channels": 1}) 192 | assert model.local_conditioning_enabled() 193 | assert not model.has_speaker_embedding() 194 | 195 | x = torch.from_numpy(x).contiguous().to(device) 196 | 197 | c = torch.from_numpy(c).contiguous().to(device) 198 | print(x.size(), c.size()) 199 | 200 | model.eval() 201 | 202 | y_offline = model(x, c=c, softmax=True) 203 | 204 | # Incremental forward with forced teaching 205 | y_online = model.incremental_forward( 206 | test_inputs=x, c=c, T=None, tqdm=tqdm, softmax=True, quantize=False) 207 | 208 | # (1 x C x T) 209 | c = (y_offline - y_online).abs() 210 | print(c.mean(), c.max()) 211 | 212 | try: 213 | assert np.allclose(y_offline.cpu().data.numpy(), 214 | y_online.cpu().data.numpy(), atol=1e-4) 215 | except Exception: 216 | from warnings import warn 217 | warn("oops! must be a bug!") 218 | 219 | 220 | @attr("global_conditioning") 221 | def test_global_conditioning_with_embedding_correctness(): 222 | # condition by mean power 223 | x, x_org, c = _test_data(returns_power=True) 224 | g = c.mean(axis=-1, keepdims=True).astype(np.int) 225 | model = build_compact_model(gin_channels=16, n_speakers=256, 226 | use_speaker_embedding=True) 227 | assert not model.local_conditioning_enabled() 228 | assert model.has_speaker_embedding() 229 | 230 | x = torch.from_numpy(x).contiguous().to(device) 231 | 232 | g = torch.from_numpy(g).long().contiguous().to(device) 233 | print(g.size()) 234 | 235 | model.eval() 236 | 237 | y_offline = model(x, g=g, softmax=True) 238 | 239 | # Incremental forward with forced teaching 240 | y_online = model.incremental_forward( 241 | test_inputs=x, g=g, T=None, tqdm=tqdm, softmax=True, quantize=False) 242 | 243 | # (1 x C x T) 244 | c = (y_offline - y_online).abs() 245 | print(c.mean(), c.max()) 246 | 247 | try: 248 | assert np.allclose(y_offline.cpu().data.numpy(), 249 | y_online.cpu().data.numpy(), atol=1e-4) 250 | except Exception: 251 | from warnings import warn 252 | warn("oops! 
must be a bug!") 253 | 254 | 255 | @attr("global_conditioning") 256 | def test_global_conditioning_correctness(): 257 | # condition by mean power 258 | x, x_org, c = _test_data(returns_power=True) 259 | # must be floating-point type 260 | g = c.mean(axis=-1, keepdims=True).astype(np.float32) 261 | model = build_compact_model(gin_channels=1, use_speaker_embedding=False) 262 | assert not model.local_conditioning_enabled() 263 | # `use_speaker_embedding` False should diable embedding layer 264 | assert not model.has_speaker_embedding() 265 | 266 | x = torch.from_numpy(x).contiguous().to(device) 267 | 268 | g = torch.from_numpy(g).contiguous().to(device) 269 | print(g.size()) 270 | 271 | model.eval() 272 | y_offline = model(x, g=g, softmax=True) 273 | 274 | # Incremental forward with forced teaching 275 | y_online = model.incremental_forward( 276 | test_inputs=x, g=g, T=None, tqdm=tqdm, softmax=True, quantize=False) 277 | 278 | # (1 x C x T) 279 | c = (y_offline - y_online).abs() 280 | print(c.mean(), c.max()) 281 | 282 | try: 283 | assert np.allclose(y_offline.cpu().data.numpy(), 284 | y_online.cpu().data.numpy(), atol=1e-4) 285 | except Exception: 286 | from warnings import warn 287 | warn("oops! must be a bug!") 288 | 289 | 290 | @attr("local_and_global_conditioning") 291 | def test_global_and_local_conditioning_correctness(): 292 | x, x_org, c = _test_data(returns_power=True) 293 | g = c.mean(axis=-1, keepdims=True).astype(np.int) 294 | model = build_compact_model( 295 | cin_channels=1, gin_channels=16, use_speaker_embedding=True, n_speakers=256) 296 | assert model.local_conditioning_enabled() 297 | assert model.has_speaker_embedding() 298 | 299 | x = torch.from_numpy(x).contiguous().to(device) 300 | 301 | # per-sample power 302 | c = torch.from_numpy(c).contiguous().to(device) 303 | 304 | # mean power 305 | g = torch.from_numpy(g).long().contiguous().to(device) 306 | 307 | print(c.size(), g.size()) 308 | 309 | model.eval() 310 | 311 | y_offline = model(x, c=c, g=g, softmax=True) 312 | 313 | # Incremental forward with forced teaching 314 | y_online = model.incremental_forward( 315 | test_inputs=x, c=c, g=g, T=None, tqdm=tqdm, softmax=True, quantize=False) 316 | # (1 x C x T) 317 | 318 | c = (y_offline - y_online).abs() 319 | print(c.mean(), c.max()) 320 | 321 | try: 322 | assert np.allclose(y_offline.cpu().data.numpy(), 323 | y_online.cpu().data.numpy(), atol=1e-4) 324 | except Exception: 325 | from warnings import warn 326 | warn("oops! 
must be a bug!") 327 | 328 | 329 | @attr("local_only") 330 | def test_incremental_forward_correctness(): 331 | import librosa.display 332 | from matplotlib import pyplot as plt 333 | 334 | model = build_compact_model().to(device) 335 | 336 | checkpoint_path = join(dirname(__file__), "..", "foobar/checkpoint_step000058000.pth") 337 | if exists(checkpoint_path): 338 | print("Loading from:", checkpoint_path) 339 | checkpoint = torch.load(checkpoint_path) 340 | model.load_state_dict(checkpoint["state_dict"]) 341 | 342 | sr = 4000 343 | x, x_org = _test_data(sr=sr, N=3000) 344 | x = torch.from_numpy(x).contiguous().to(device) 345 | 346 | model.eval() 347 | 348 | # Batch forward 349 | y_offline = model(x, softmax=True) 350 | 351 | # Test from zero start 352 | y_online = model.incremental_forward(initial_input=None, T=100, tqdm=tqdm, softmax=True) 353 | 354 | # Incremental forward with forced teaching 355 | y_online = model.incremental_forward(test_inputs=x, tqdm=tqdm, softmax=True, quantize=False) 356 | 357 | # (1 x C x T) 358 | c = (y_offline - y_online).abs() 359 | print(c.mean(), c.max()) 360 | 361 | try: 362 | assert np.allclose(y_offline.cpu().data.numpy(), 363 | y_online.cpu().data.numpy(), atol=1e-4) 364 | except Exception: 365 | from warnings import warn 366 | warn("oops! must be a bug!") 367 | 368 | # (1, T, C) 369 | xt = x.transpose(1, 2).contiguous() 370 | 371 | initial_input = xt[:, 0, :].unsqueeze(1).contiguous() 372 | print(initial_input.size()) 373 | print("Inital value:", initial_input.view(-1).max(0)[1]) 374 | 375 | # With zero start 376 | zerostart = True 377 | if zerostart: 378 | y_inference = model.incremental_forward( 379 | initial_input=initial_input, T=xt.size(1), tqdm=tqdm, softmax=True, quantize=True) 380 | else: 381 | # Feed a few samples as test_inputs and then generate auto-regressively 382 | N = 1000 383 | y_inference = model.incremental_forward( 384 | initial_input=None, test_inputs=xt[:, :N, :], 385 | T=xt.size(1), tqdm=tqdm, softmax=True, quantize=True) 386 | 387 | # Waveforms 388 | # (T,) 389 | y_offline = y_offline.max(1)[1].view(-1) 390 | y_online = y_online.max(1)[1].view(-1) 391 | y_inference = y_inference.max(1)[1].view(-1) 392 | 393 | y_offline = P.inv_mulaw_quantize(y_offline.cpu().data.long().numpy()) 394 | y_online = P.inv_mulaw_quantize(y_online.cpu().data.long().numpy()) 395 | y_inference = P.inv_mulaw_quantize(y_inference.cpu().data.long().numpy()) 396 | 397 | plt.figure(figsize=(16, 10)) 398 | plt.subplot(4, 1, 1) 399 | librosa.display.waveplot(x_org, sr=sr) 400 | plt.subplot(4, 1, 2) 401 | librosa.display.waveplot(y_offline, sr=sr) 402 | plt.subplot(4, 1, 3) 403 | librosa.display.waveplot(y_online, sr=sr) 404 | plt.subplot(4, 1, 4) 405 | librosa.display.waveplot(y_inference, sr=sr) 406 | plt.show() 407 | 408 | save_audio = False 409 | if save_audio: 410 | librosa.output.write_wav("target.wav", x_org, sr=sr) 411 | librosa.output.write_wav("online.wav", y_online, sr=sr) 412 | librosa.output.write_wav("inference.wav", y_inference, sr=sr) 413 | -------------------------------------------------------------------------------- /tojson.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Dump hyper parameters to json file. 4 | 5 | usage: tojson.py [options] 6 | 7 | options: 8 | --hparams= Hyper parameters [default: ]. 9 | -h, --help Show help message. 
10 | """ 11 | from docopt import docopt 12 | 13 | import sys 14 | import os 15 | from os.path import dirname, join, basename, splitext 16 | import json 17 | 18 | from hparams import hparams 19 | 20 | if __name__ == "__main__": 21 | args = docopt(__doc__) 22 | output_json_path = args[""] 23 | 24 | hparams.parse(args["--hparams"]) 25 | j = hparams.values() 26 | with open(output_json_path, "w") as f: 27 | json.dump(j, f, indent=2) 28 | sys.exit(0) 29 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E305,E402,E704,E721,E741,F401,F403,F405,F821,F841,F999 4 | exclude = docs/,data,build,dist,notebooks,checkpoints*,legacy 5 | -------------------------------------------------------------------------------- /utils/parse_options.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); 4 | # Arnab Ghoshal, Karel Vesely 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # Parse command-line options. 21 | # To be sourced by another script (as in ". parse_options.sh"). 22 | # Option format is: --option-name arg 23 | # and shell variable "option_name" gets set to value "arg." 24 | # The exception is --help, which takes no arguments, but prints the 25 | # $help_message variable (if defined). 26 | 27 | 28 | ### 29 | ### The --config file options have lower priority to command line 30 | ### options, so we need to import them first... 31 | ### 32 | 33 | # Now import all the configs specified by command-line, in left-to-right order 34 | for ((argpos=1; argpos<$#; argpos++)); do 35 | if [ "${!argpos}" == "--config" ]; then 36 | argpos_plus1=$((argpos+1)) 37 | config=${!argpos_plus1} 38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 39 | . $config # source the config file. 40 | fi 41 | done 42 | 43 | 44 | ### 45 | ### Now we process the command line options 46 | ### 47 | while true; do 48 | [ -z "${1:-}" ] && break; # break if there are no arguments 49 | case "$1" in 50 | # If the enclosing script is called with --help option, print the help 51 | # message and exit. Scripts should put help messages in $help_message 52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; 53 | else printf "$help_message\n" 1>&2 ; fi; 54 | exit 0 ;; 55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" 56 | exit 1 ;; 57 | # If the first command-line argument begins with "--" (e.g. --foo-bar), 58 | # then work out the variable name as $name, which will equal "foo_bar". 59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; 60 | # Next we test whether the variable in question is undefned-- if so it's 61 | # an invalid option and we die. 
Note: $0 evaluates to the name of the 62 | # enclosing script. 63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar 64 | # is undefined. We then have to wrap this test inside "eval" because 65 | # foo_bar is itself inside a variable ($name). 66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; 67 | 68 | oldval="`eval echo \\$$name`"; 69 | # Work out whether we seem to be expecting a Boolean argument. 70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then 71 | was_bool=true; 72 | else 73 | was_bool=false; 74 | fi 75 | 76 | # Set the variable to the right value-- the escaped quotes make it work if 77 | # the option had spaces, like --cmd "queue.pl -sync y" 78 | eval $name=\"$2\"; 79 | 80 | # Check that Boolean-valued arguments are really Boolean. 81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then 82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 83 | exit 1; 84 | fi 85 | shift 2; 86 | ;; 87 | *) break; 88 | esac 89 | done 90 | 91 | 92 | # Check for an empty argument to the --cmd option, which can easily occur as a 93 | # result of scripting errors. 94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; 95 | 96 | 97 | true; # so this script returns exit code 0. 98 | -------------------------------------------------------------------------------- /wavenet_vocoder/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import with_statement, print_function, absolute_import 3 | 4 | from .version import version as __version__ 5 | 6 | from .wavenet import receptive_field_size, WaveNet 7 | -------------------------------------------------------------------------------- /wavenet_vocoder/conv.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | 7 | class Conv1d(nn.Conv1d): 8 | """Extended nn.Conv1d for incremental dilated convolutions 9 | """ 10 | 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.clear_buffer() 14 | self._linearized_weight = None 15 | self.register_backward_hook(self._clear_linearized_weight) 16 | 17 | def incremental_forward(self, input): 18 | # input: (B, T, C) 19 | if self.training: 20 | raise RuntimeError('incremental_forward only supports eval mode') 21 | 22 | # run forward pre hooks (e.g., weight norm) 23 | for hook in self._forward_pre_hooks.values(): 24 | hook(self, input) 25 | 26 | # reshape weight 27 | weight = self._get_linearized_weight() 28 | kw = self.kernel_size[0] 29 | dilation = self.dilation[0] 30 | 31 | bsz = input.size(0) # input: bsz x len x dim 32 | if kw > 1: 33 | input = input.data 34 | if self.input_buffer is None: 35 | self.input_buffer = input.new(bsz, kw + (kw - 1) * (dilation - 1), input.size(2)) 36 | self.input_buffer.zero_() 37 | else: 38 | # shift buffer 39 | self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone() 40 | # append next input 41 | self.input_buffer[:, -1, :] = input[:, -1, :] 42 | input = self.input_buffer 43 | if dilation > 1: 44 | input = input[:, 0::dilation, :].contiguous() 45 | output = F.linear(input.view(bsz, -1), weight, self.bias) 46 | return output.view(bsz, 1, -1) 47 | 48 | def clear_buffer(self): 49 | self.input_buffer = None 50 | 51 | def _get_linearized_weight(self): 52 | if self._linearized_weight is None: 53 | kw = 
self.kernel_size[0] 54 | # nn.Conv1d 55 | if self.weight.size() == (self.out_channels, self.in_channels, kw): 56 | weight = self.weight.transpose(1, 2).contiguous() 57 | else: 58 | # fairseq.modules.conv_tbc.ConvTBC 59 | weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous() 60 | assert weight.size() == (self.out_channels, kw, self.in_channels) 61 | self._linearized_weight = weight.view(self.out_channels, -1) 62 | return self._linearized_weight 63 | 64 | def _clear_linearized_weight(self, *args): 65 | self._linearized_weight = None 66 | -------------------------------------------------------------------------------- /wavenet_vocoder/mixture.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Code is adapted from: 3 | # https://github.com/pclucas14/pixel-cnn-pp 4 | # https://github.com/openai/pixel-cnn 5 | 6 | from __future__ import with_statement, print_function, absolute_import 7 | 8 | import math 9 | import numpy as np 10 | 11 | import torch 12 | from torch import nn 13 | from torch.nn import functional as F 14 | from torch.distributions import Normal 15 | 16 | 17 | def log_sum_exp(x): 18 | """ numerically stable log_sum_exp implementation that prevents overflow """ 19 | # TF ordering 20 | axis = len(x.size()) - 1 21 | m, _ = torch.max(x, dim=axis) 22 | m2, _ = torch.max(x, dim=axis, keepdim=True) 23 | return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) 24 | 25 | 26 | def discretized_mix_logistic_loss(y_hat, y, num_classes=256, 27 | log_scale_min=-7.0, reduce=True): 28 | """Discretized mixture of logistic distributions loss 29 | 30 | Note that it is assumed that input is scaled to [-1, 1]. 31 | 32 | Args: 33 | y_hat (Tensor): Predicted output (B x C x T) 34 | y (Tensor): Target (B x T x 1). 35 | num_classes (int): Number of classes 36 | log_scale_min (float): Log scale minimum value 37 | reduce (bool): If True, the losses are averaged or summed for each 38 | minibatch. 39 | 40 | Returns 41 | Tensor: loss 42 | """ 43 | assert y_hat.dim() == 3 44 | assert y_hat.size(1) % 3 == 0 45 | nr_mix = y_hat.size(1) // 3 46 | 47 | # (B x T x C) 48 | y_hat = y_hat.transpose(1, 2) 49 | 50 | # unpack parameters. (B, T, num_mixtures) x 3 51 | logit_probs = y_hat[:, :, :nr_mix] 52 | means = y_hat[:, :, nr_mix:2 * nr_mix] 53 | log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min) 54 | 55 | # B x T x 1 -> B x T x num_mixtures 56 | y = y.expand_as(means) 57 | 58 | centered_y = y - means 59 | inv_stdv = torch.exp(-log_scales) 60 | plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1)) 61 | cdf_plus = torch.sigmoid(plus_in) 62 | min_in = inv_stdv * (centered_y - 1. / (num_classes - 1)) 63 | cdf_min = torch.sigmoid(min_in) 64 | 65 | # log probability for edge case of 0 (before scaling) 66 | # equivalent: torch.log(torch.sigmoid(plus_in)) 67 | log_cdf_plus = plus_in - F.softplus(plus_in) 68 | 69 | # log probability for edge case of 255 (before scaling) 70 | # equivalent: (1 - torch.sigmoid(min_in)).log() 71 | log_one_minus_cdf_min = -F.softplus(min_in) 72 | 73 | # probability for all other cases 74 | cdf_delta = cdf_plus - cdf_min 75 | 76 | mid_in = inv_stdv * centered_y 77 | # log probability in the center of the bin, to be used in extreme cases 78 | # (not actually used in our code) 79 | log_pdf_mid = mid_in - log_scales - 2. 
* F.softplus(mid_in) 80 | 81 | # tf equivalent 82 | """ 83 | log_probs = tf.where(x < -0.999, log_cdf_plus, 84 | tf.where(x > 0.999, log_one_minus_cdf_min, 85 | tf.where(cdf_delta > 1e-5, 86 | tf.log(tf.maximum(cdf_delta, 1e-12)), 87 | log_pdf_mid - np.log(127.5)))) 88 | """ 89 | # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value 90 | # for num_classes=65536 case? 1e-7? not sure.. 91 | inner_inner_cond = (cdf_delta > 1e-5).float() 92 | 93 | inner_inner_out = inner_inner_cond * \ 94 | torch.log(torch.clamp(cdf_delta, min=1e-12)) + \ 95 | (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2)) 96 | inner_cond = (y > 0.999).float() 97 | inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out 98 | cond = (y < -0.999).float() 99 | log_probs = cond * log_cdf_plus + (1. - cond) * inner_out 100 | 101 | log_probs = log_probs + F.log_softmax(logit_probs, -1) 102 | 103 | if reduce: 104 | return -torch.sum(log_sum_exp(log_probs)) 105 | else: 106 | return -log_sum_exp(log_probs).unsqueeze(-1) 107 | 108 | 109 | def to_one_hot(tensor, n, fill_with=1.): 110 | # we perform one hot encore with respect to the last axis 111 | one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_() 112 | if tensor.is_cuda: 113 | one_hot = one_hot.cuda() 114 | one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) 115 | return one_hot 116 | 117 | 118 | def sample_from_discretized_mix_logistic(y, log_scale_min=-7.0, 119 | clamp_log_scale=False): 120 | """ 121 | Sample from discretized mixture of logistic distributions 122 | 123 | Args: 124 | y (Tensor): B x C x T 125 | log_scale_min (float): Log scale minimum value 126 | 127 | Returns: 128 | Tensor: sample in range of [-1, 1]. 129 | """ 130 | assert y.size(1) % 3 == 0 131 | nr_mix = y.size(1) // 3 132 | 133 | # B x T x C 134 | y = y.transpose(1, 2) 135 | logit_probs = y[:, :, :nr_mix] 136 | 137 | # sample mixture indicator from softmax 138 | temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) 139 | temp = logit_probs.data - torch.log(- torch.log(temp)) 140 | _, argmax = temp.max(dim=-1) 141 | 142 | # (B, T) -> (B, T, nr_mix) 143 | one_hot = to_one_hot(argmax, nr_mix) 144 | # select logistic parameters 145 | means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1) 146 | log_scales = torch.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1) 147 | if clamp_log_scale: 148 | log_scales = torch.clamp(log_scales, min=log_scale_min) 149 | # sample from logistic & clip to interval 150 | # we don't actually round to the nearest 8bit value when sampling 151 | u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) 152 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u)) 153 | 154 | x = torch.clamp(torch.clamp(x, min=-1.), max=1.) 155 | 156 | return x 157 | 158 | 159 | # we can easily define discretized version of the gaussian loss, however, 160 | # use continuous version as same as the https://clarinet-demo.github.io/ 161 | def mix_gaussian_loss(y_hat, y, log_scale_min=-7.0, reduce=True): 162 | """Mixture of continuous gaussian distributions loss 163 | 164 | Note that it is assumed that input is scaled to [-1, 1]. 165 | 166 | Args: 167 | y_hat (Tensor): Predicted output (B x C x T) 168 | y (Tensor): Target (B x T x 1). 169 | log_scale_min (float): Log scale minimum value 170 | reduce (bool): If True, the losses are averaged or summed for each 171 | minibatch. 
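A minimal sketch of how this loss and its sampler are exercised (shapes follow tests/test_mixture.py; the data here are random):

    import torch
    from wavenet_vocoder.mixture import mix_gaussian_loss, sample_from_mix_gaussian

    y_hat = torch.rand(1, 30, 16000)       # (B, 3*num_mixtures, T): mixture logits, means, log scales
    y = torch.rand(1, 16000, 1) * 2 - 1    # (B, T, 1) target scaled to [-1, 1]

    loss = mix_gaussian_loss(y_hat, y)                      # scalar (reduce=True)
    per_sample = mix_gaussian_loss(y_hat, y, reduce=False)  # same size as y
    x = sample_from_mix_gaussian(y_hat)                     # (B, T) samples clamped to [-1, 1]

discretized_mix_logistic_loss and sample_from_discretized_mix_logistic take the same shapes; the single-Gaussian special case uses two output channels (mean and log scale), i.e. y_hat of shape (B, 2, T).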
172 | Returns 173 | Tensor: loss 174 | """ 175 | assert y_hat.dim() == 3 176 | C = y_hat.size(1) 177 | if C == 2: 178 | nr_mix = 1 179 | else: 180 | assert y_hat.size(1) % 3 == 0 181 | nr_mix = y_hat.size(1) // 3 182 | 183 | # (B x T x C) 184 | y_hat = y_hat.transpose(1, 2) 185 | 186 | # unpack parameters. 187 | if C == 2: 188 | # special case for C == 2, just for compatibility 189 | logit_probs = None 190 | means = y_hat[:, :, 0:1] 191 | log_scales = torch.clamp(y_hat[:, :, 1:2], min=log_scale_min) 192 | else: 193 | # (B, T, num_mixtures) x 3 194 | logit_probs = y_hat[:, :, :nr_mix] 195 | means = y_hat[:, :, nr_mix:2 * nr_mix] 196 | log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min) 197 | 198 | # B x T x 1 -> B x T x num_mixtures 199 | y = y.expand_as(means) 200 | 201 | centered_y = y - means 202 | dist = Normal(loc=0., scale=torch.exp(log_scales)) 203 | # do we need to add a trick to avoid log(0)? 204 | log_probs = dist.log_prob(centered_y) 205 | 206 | if nr_mix > 1: 207 | log_probs = log_probs + F.log_softmax(logit_probs, -1) 208 | 209 | if reduce: 210 | if nr_mix == 1: 211 | return -torch.sum(log_probs) 212 | else: 213 | return -torch.sum(log_sum_exp(log_probs)) 214 | else: 215 | if nr_mix == 1: 216 | return -log_probs 217 | else: 218 | return -log_sum_exp(log_probs).unsqueeze(-1) 219 | 220 | 221 | def sample_from_mix_gaussian(y, log_scale_min=-7.0): 222 | """ 223 | Sample from (discretized) mixture of gaussian distributions 224 | Args: 225 | y (Tensor): B x C x T 226 | log_scale_min (float): Log scale minimum value 227 | Returns: 228 | Tensor: sample in range of [-1, 1]. 229 | """ 230 | C = y.size(1) 231 | if C == 2: 232 | nr_mix = 1 233 | else: 234 | assert y.size(1) % 3 == 0 235 | nr_mix = y.size(1) // 3 236 | 237 | # B x T x C 238 | y = y.transpose(1, 2) 239 | 240 | if C == 2: 241 | logit_probs = None 242 | else: 243 | logit_probs = y[:, :, :nr_mix] 244 | 245 | if nr_mix > 1: 246 | # sample mixture indicator from softmax 247 | temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) 248 | temp = logit_probs.data - torch.log(- torch.log(temp)) 249 | _, argmax = temp.max(dim=-1) 250 | 251 | # (B, T) -> (B, T, nr_mix) 252 | one_hot = to_one_hot(argmax, nr_mix) 253 | 254 | # Select means and log scales 255 | means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1) 256 | log_scales = torch.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1) 257 | else: 258 | if C == 2: 259 | means, log_scales = y[:, :, 0], y[:, :, 1] 260 | elif C == 3: 261 | means, log_scales = y[:, :, 1], y[:, :, 2] 262 | else: 263 | assert False, "shouldn't happen" 264 | 265 | scales = torch.exp(log_scales) 266 | dist = Normal(loc=means, scale=scales) 267 | x = dist.sample() 268 | 269 | x = torch.clamp(x, min=-1.0, max=1.0) 270 | return x 271 | -------------------------------------------------------------------------------- /wavenet_vocoder/modules.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import with_statement, print_function, absolute_import 3 | 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | from wavenet_vocoder import conv 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | 13 | def Conv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs): 14 | m = conv.Conv1d(in_channels, out_channels, kernel_size, **kwargs) 15 | nn.init.kaiming_normal_(m.weight, nonlinearity="relu") 16 | if m.bias is not None: 17 | nn.init.constant_(m.bias, 
0) 18 | return nn.utils.weight_norm(m) 19 | 20 | 21 | def Embedding(num_embeddings, embedding_dim, padding_idx, std=0.01): 22 | m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) 23 | m.weight.data.normal_(0, std) 24 | return m 25 | 26 | 27 | def ConvTranspose2d(in_channels, out_channels, kernel_size, **kwargs): 28 | freq_axis_kernel_size = kernel_size[0] 29 | m = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, **kwargs) 30 | m.weight.data.fill_(1.0 / freq_axis_kernel_size) 31 | m.bias.data.zero_() 32 | return nn.utils.weight_norm(m) 33 | 34 | 35 | def Conv1d1x1(in_channels, out_channels, bias=True): 36 | """1-by-1 convolution layer 37 | """ 38 | return Conv1d(in_channels, out_channels, kernel_size=1, padding=0, 39 | dilation=1, bias=bias) 40 | 41 | 42 | def _conv1x1_forward(conv, x, is_incremental): 43 | """Conv1x1 forward 44 | """ 45 | if is_incremental: 46 | x = conv.incremental_forward(x) 47 | else: 48 | x = conv(x) 49 | return x 50 | 51 | 52 | class ResidualConv1dGLU(nn.Module): 53 | """Residual dilated conv1d + Gated linear unit 54 | 55 | Args: 56 | residual_channels (int): Residual input / output channels 57 | gate_channels (int): Gated activation channels. 58 | kernel_size (int): Kernel size of convolution layers. 59 | skip_out_channels (int): Skip connection channels. If None, set to same 60 | as ``residual_channels``. 61 | cin_channels (int): Local conditioning channels. If negative value is 62 | set, local conditioning is disabled. 63 | gin_channels (int): Global conditioning channels. If negative value is 64 | set, global conditioning is disabled. 65 | dropout (float): Dropout probability. 66 | padding (int): Padding for convolution layers. If None, proper padding 67 | is computed depends on dilation and kernel_size. 68 | dilation (int): Dilation factor. 
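A minimal usage sketch (mirroring tests/test_model.py): the block maps a (B, residual_channels, T) input to a residual output and a skip connection of the same length:

    import torch
    from wavenet_vocoder.modules import ResidualConv1dGLU

    conv = ResidualConv1dGLU(30, 30, kernel_size=3, dropout=1 - 0.95)
    x = torch.zeros(16, 30, 16000)   # (B, residual_channels, T)
    y, s = conv(x)                   # both (16, 30, 16000): residual output and skip connection

Inside, the dilated convolution output is split into two halves a and b, gated as tanh(a) * sigmoid(b), and the residual path is scaled by sqrt(0.5) after the input is added back.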
69 | """ 70 | 71 | def __init__(self, residual_channels, gate_channels, kernel_size, 72 | skip_out_channels=None, 73 | cin_channels=-1, gin_channels=-1, 74 | dropout=1 - 0.95, padding=None, dilation=1, causal=True, 75 | bias=True, *args, **kwargs): 76 | super(ResidualConv1dGLU, self).__init__() 77 | self.dropout = dropout 78 | if skip_out_channels is None: 79 | skip_out_channels = residual_channels 80 | if padding is None: 81 | # no future time stamps available 82 | if causal: 83 | padding = (kernel_size - 1) * dilation 84 | else: 85 | padding = (kernel_size - 1) // 2 * dilation 86 | self.causal = causal 87 | 88 | self.conv = Conv1d(residual_channels, gate_channels, kernel_size, 89 | padding=padding, dilation=dilation, 90 | bias=bias, *args, **kwargs) 91 | 92 | # local conditioning 93 | if cin_channels > 0: 94 | self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, bias=False) 95 | else: 96 | self.conv1x1c = None 97 | 98 | # global conditioning 99 | if gin_channels > 0: 100 | self.conv1x1g = Conv1d1x1(gin_channels, gate_channels, bias=False) 101 | else: 102 | self.conv1x1g = None 103 | 104 | # conv output is split into two groups 105 | gate_out_channels = gate_channels // 2 106 | self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias) 107 | self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, bias=bias) 108 | 109 | def forward(self, x, c=None, g=None): 110 | return self._forward(x, c, g, False) 111 | 112 | def incremental_forward(self, x, c=None, g=None): 113 | return self._forward(x, c, g, True) 114 | 115 | def _forward(self, x, c, g, is_incremental): 116 | """Forward 117 | 118 | Args: 119 | x (Tensor): B x C x T 120 | c (Tensor): B x C x T, Local conditioning features 121 | g (Tensor): B x C x T, Expanded global conditioning features 122 | is_incremental (Bool) : Whether incremental mode or not 123 | 124 | Returns: 125 | Tensor: output 126 | """ 127 | residual = x 128 | x = F.dropout(x, p=self.dropout, training=self.training) 129 | if is_incremental: 130 | splitdim = -1 131 | x = self.conv.incremental_forward(x) 132 | else: 133 | splitdim = 1 134 | x = self.conv(x) 135 | # remove future time steps 136 | x = x[:, :, :residual.size(-1)] if self.causal else x 137 | 138 | a, b = x.split(x.size(splitdim) // 2, dim=splitdim) 139 | 140 | # local conditioning 141 | if c is not None: 142 | assert self.conv1x1c is not None 143 | c = _conv1x1_forward(self.conv1x1c, c, is_incremental) 144 | ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) 145 | a, b = a + ca, b + cb 146 | 147 | # global conditioning 148 | if g is not None: 149 | assert self.conv1x1g is not None 150 | g = _conv1x1_forward(self.conv1x1g, g, is_incremental) 151 | ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim) 152 | a, b = a + ga, b + gb 153 | 154 | x = torch.tanh(a) * torch.sigmoid(b) 155 | 156 | # For skip connection 157 | s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental) 158 | 159 | # For residual connection 160 | x = _conv1x1_forward(self.conv1x1_out, x, is_incremental) 161 | 162 | x = (x + residual) * math.sqrt(0.5) 163 | return x, s 164 | 165 | def clear_buffer(self): 166 | for c in [self.conv, self.conv1x1_out, self.conv1x1_skip, 167 | self.conv1x1c, self.conv1x1g]: 168 | if c is not None: 169 | c.clear_buffer() 170 | -------------------------------------------------------------------------------- /wavenet_vocoder/tfcompat/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/r9y9/wavenet_vocoder/a35fff76ea3687b05e1a10023cad3f7f64fa25a3/wavenet_vocoder/tfcompat/__init__.py -------------------------------------------------------------------------------- /wavenet_vocoder/tfcompat/readme.md: -------------------------------------------------------------------------------- 1 | Source: hparam.py copied from tensorflow v1.12.0. 2 | 3 | https://github.com/tensorflow/tensorflow/blob/v1.12.0/tensorflow/contrib/training/python/training/hparam.py 4 | 5 | with the following command: 6 | wget https://github.com/tensorflow/tensorflow/raw/v1.12.0/tensorflow/contrib/training/python/training/hparam.py 7 | 8 | All other tensorflow dependencies of this file were then removed so that the class keeps working on its own. Functions that became unavailable through this process are not used in this project. 9 | -------------------------------------------------------------------------------- /wavenet_vocoder/upsample.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import with_statement, print_function, absolute_import 3 | 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | 12 | class Stretch2d(nn.Module): 13 | def __init__(self, x_scale, y_scale, mode="nearest"): 14 | super(Stretch2d, self).__init__() 15 | self.x_scale = x_scale 16 | self.y_scale = y_scale 17 | self.mode = mode 18 | 19 | def forward(self, x): 20 | return F.interpolate( 21 | x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) 22 | 23 | 24 | def _get_activation(upsample_activation): 25 | nonlinear = getattr(nn, upsample_activation) 26 | return nonlinear 27 | 28 | 29 | class UpsampleNetwork(nn.Module): 30 | def __init__(self, upsample_scales, upsample_activation="none", 31 | upsample_activation_params={}, mode="nearest", 32 | freq_axis_kernel_size=1, cin_pad=0, cin_channels=80): 33 | super(UpsampleNetwork, self).__init__() 34 | self.up_layers = nn.ModuleList() 35 | total_scale = np.prod(upsample_scales) 36 | self.indent = cin_pad * total_scale 37 | for scale in upsample_scales: 38 | freq_axis_padding = (freq_axis_kernel_size - 1) // 2 39 | k_size = (freq_axis_kernel_size, scale * 2 + 1) 40 | padding = (freq_axis_padding, scale) 41 | stretch = Stretch2d(scale, 1, mode) 42 | conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False) 43 | conv.weight.data.fill_(1.
/ np.prod(k_size)) 44 | conv = nn.utils.weight_norm(conv) 45 | self.up_layers.append(stretch) 46 | self.up_layers.append(conv) 47 | if upsample_activation != "none": 48 | nonlinear = _get_activation(upsample_activation) 49 | self.up_layers.append(nonlinear(**upsample_activation_params)) 50 | 51 | def forward(self, c): 52 | """ 53 | Args: 54 | c : B x C x T 55 | """ 56 | 57 | # B x 1 x C x T 58 | c = c.unsqueeze(1) 59 | for f in self.up_layers: 60 | c = f(c) 61 | # B x C x T 62 | c = c.squeeze(1) 63 | 64 | if self.indent > 0: 65 | c = c[:, :, self.indent:-self.indent] 66 | return c 67 | 68 | 69 | class ConvInUpsampleNetwork(nn.Module): 70 | def __init__(self, upsample_scales, upsample_activation="none", 71 | upsample_activation_params={}, mode="nearest", 72 | freq_axis_kernel_size=1, cin_pad=0, 73 | cin_channels=80): 74 | super(ConvInUpsampleNetwork, self).__init__() 75 | # To capture wide-context information in conditional features 76 | # meaningless if cin_pad == 0 77 | ks = 2 * cin_pad + 1 78 | self.conv_in = nn.Conv1d(cin_channels, cin_channels, kernel_size=ks, bias=False) 79 | self.upsample = UpsampleNetwork( 80 | upsample_scales, upsample_activation, upsample_activation_params, 81 | mode, freq_axis_kernel_size, cin_pad=0, cin_channels=cin_channels) 82 | 83 | def forward(self, c): 84 | c_up = self.upsample(self.conv_in(c)) 85 | return c_up 86 | -------------------------------------------------------------------------------- /wavenet_vocoder/util.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import with_statement, print_function, absolute_import 3 | 4 | 5 | def _assert_valid_input_type(s): 6 | assert s == "mulaw-quantize" or s == "mulaw" or s == "raw" 7 | 8 | 9 | def is_mulaw_quantize(s): 10 | _assert_valid_input_type(s) 11 | return s == "mulaw-quantize" 12 | 13 | 14 | def is_mulaw(s): 15 | _assert_valid_input_type(s) 16 | return s == "mulaw" 17 | 18 | 19 | def is_raw(s): 20 | _assert_valid_input_type(s) 21 | return s == "raw" 22 | 23 | 24 | def is_scalar_input(s): 25 | return is_raw(s) or is_mulaw(s) 26 | -------------------------------------------------------------------------------- /wavenet_vocoder/version.py: -------------------------------------------------------------------------------- 1 | version = '0.2.0' 2 | -------------------------------------------------------------------------------- /wavenet_vocoder/wavenet.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import with_statement, print_function, absolute_import 3 | 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from .modules import Embedding 12 | 13 | from .modules import Conv1d1x1, ResidualConv1dGLU, ConvTranspose2d 14 | from .mixture import sample_from_discretized_mix_logistic 15 | from .mixture import sample_from_mix_gaussian 16 | from wavenet_vocoder import upsample 17 | 18 | 19 | def _expand_global_features(B, T, g, bct=True): 20 | """Expand global conditioning features to all time steps 21 | 22 | Args: 23 | B (int): Batch size. 24 | T (int): Time length. 25 | g (Tensor): Global features, (B x C) or (B x C x 1). 
26 | bct (bool) : returns (B x C x T) if True, otherwise (B x T x C) 27 | 28 | Returns: 29 | Tensor: B x C x T or B x T x C or None 30 | """ 31 | if g is None: 32 | return None 33 | g = g.unsqueeze(-1) if g.dim() == 2 else g 34 | if bct: 35 | g_bct = g.expand(B, -1, T) 36 | return g_bct.contiguous() 37 | else: 38 | g_btc = g.expand(B, -1, T).transpose(1, 2) 39 | return g_btc.contiguous() 40 | 41 | 42 | def receptive_field_size(total_layers, num_cycles, kernel_size, 43 | dilation=lambda x: 2**x): 44 | """Compute receptive field size 45 | 46 | Args: 47 | total_layers (int): total layers 48 | num_cycles (int): number of dilation cycles 49 | kernel_size (int): kernel size 50 | dilation (lambda): lambda to compute dilation factor. ``lambda x : 1`` 51 | to disable dilated convolution. 52 | 53 | Returns: 54 | int: receptive field size in samples 55 | 56 | """ 57 | assert total_layers % num_cycles == 0 58 | layers_per_cycle = total_layers // num_cycles 59 | dilations = [dilation(i % layers_per_cycle) for i in range(total_layers)] 60 | return (kernel_size - 1) * sum(dilations) + 1 61 | 62 | 63 | class WaveNet(nn.Module): 64 | """The WaveNet model that supports local and global conditioning. 65 | 66 | Args: 67 | out_channels (int): Output channels. If input_type is mu-law quantized 68 | one-hot vector, this must be equal to the number of quantize channels. Otherwise 69 | num_mixtures x 3 (pi, mu, log_scale). 70 | layers (int): Number of total layers 71 | stacks (int): Number of dilation cycles 72 | residual_channels (int): Residual input / output channels 73 | gate_channels (int): Gated activation channels. 74 | skip_out_channels (int): Skip connection channels. 75 | kernel_size (int): Kernel size of convolution layers. 76 | dropout (float): Dropout probability. 77 | cin_channels (int): Local conditioning channels. If a negative value is 78 | set, local conditioning is disabled. 79 | gin_channels (int): Global conditioning channels. If a negative value is 80 | set, global conditioning is disabled. 81 | n_speakers (int): Number of speakers. Used only if global conditioning 82 | is enabled. 83 | upsample_conditional_features (bool): Whether to upsample local 84 | conditioning features by transposed convolution layers or not. 85 | upsample_scales (list): List of upsample scales. 86 | ``np.prod(upsample_scales)`` must be equal to the hop size. Used only if 87 | upsample_conditional_features is enabled. 88 | freq_axis_kernel_size (int): Freq-axis kernel_size for transposed 89 | convolution layers for upsampling. If you only care about time-axis 90 | upsampling, set this to 1. 91 | scalar_input (Bool): If True, scalar input ([-1, 1]) is expected, otherwise 92 | quantized one-hot vector is expected. 93 | use_speaker_embedding (Bool): Use speaker embedding or not. Set to False 94 | if you want to disable the embedding layer and use external features 95 | directly.
96 | """ 97 | 98 | def __init__(self, out_channels=256, layers=20, stacks=2, 99 | residual_channels=512, 100 | gate_channels=512, 101 | skip_out_channels=512, 102 | kernel_size=3, dropout=1 - 0.95, 103 | cin_channels=-1, gin_channels=-1, n_speakers=None, 104 | upsample_conditional_features=False, 105 | upsample_net="ConvInUpsampleNetwork", 106 | upsample_params={"upsample_scales": [4, 4, 4, 4]}, 107 | scalar_input=False, 108 | use_speaker_embedding=False, 109 | output_distribution="Logistic", 110 | cin_pad=0, 111 | ): 112 | super(WaveNet, self).__init__() 113 | self.scalar_input = scalar_input 114 | self.out_channels = out_channels 115 | self.cin_channels = cin_channels 116 | self.output_distribution = output_distribution 117 | assert layers % stacks == 0 118 | layers_per_stack = layers // stacks 119 | if scalar_input: 120 | self.first_conv = Conv1d1x1(1, residual_channels) 121 | else: 122 | self.first_conv = Conv1d1x1(out_channels, residual_channels) 123 | 124 | self.conv_layers = nn.ModuleList() 125 | for layer in range(layers): 126 | dilation = 2**(layer % layers_per_stack) 127 | conv = ResidualConv1dGLU( 128 | residual_channels, gate_channels, 129 | kernel_size=kernel_size, 130 | skip_out_channels=skip_out_channels, 131 | bias=True, # magenta uses bias, but musyoku doesn't 132 | dilation=dilation, dropout=dropout, 133 | cin_channels=cin_channels, 134 | gin_channels=gin_channels) 135 | self.conv_layers.append(conv) 136 | self.last_conv_layers = nn.ModuleList([ 137 | nn.ReLU(inplace=True), 138 | Conv1d1x1(skip_out_channels, skip_out_channels), 139 | nn.ReLU(inplace=True), 140 | Conv1d1x1(skip_out_channels, out_channels), 141 | ]) 142 | 143 | if gin_channels > 0 and use_speaker_embedding: 144 | assert n_speakers is not None 145 | self.embed_speakers = Embedding( 146 | n_speakers, gin_channels, padding_idx=None, std=0.1) 147 | else: 148 | self.embed_speakers = None 149 | 150 | # Upsample conv net 151 | if upsample_conditional_features: 152 | self.upsample_net = getattr(upsample, upsample_net)(**upsample_params) 153 | else: 154 | self.upsample_net = None 155 | 156 | self.receptive_field = receptive_field_size(layers, stacks, kernel_size) 157 | 158 | def has_speaker_embedding(self): 159 | return self.embed_speakers is not None 160 | 161 | def local_conditioning_enabled(self): 162 | return self.cin_channels > 0 163 | 164 | def forward(self, x, c=None, g=None, softmax=False): 165 | """Forward step 166 | 167 | Args: 168 | x (Tensor): One-hot encoded audio signal, shape (B x C x T) 169 | c (Tensor): Local conditioning features, 170 | shape (B x cin_channels x T) 171 | g (Tensor): Global conditioning features, 172 | shape (B x gin_channels x 1) or speaker Ids of shape (B x 1). 173 | Note that ``self.use_speaker_embedding`` must be False when you 174 | want to disable the embedding layer and use external features 175 | directly (e.g., one-hot vector). 176 | Also, the input tensor must be a FloatTensor, not a LongTensor, 177 | when ``self.use_speaker_embedding`` is False. 178 | softmax (bool): Whether to apply softmax or not.
179 | 180 | Returns: 181 | Tensor: output, shape B x out_channels x T 182 | """ 183 | B, _, T = x.size() 184 | 185 | if g is not None: 186 | if self.embed_speakers is not None: 187 | # (B x 1) -> (B x 1 x gin_channels) 188 | g = self.embed_speakers(g.view(B, -1)) 189 | # (B x gin_channels x 1) 190 | g = g.transpose(1, 2) 191 | assert g.dim() == 3 192 | # Expand global conditioning features to all time steps 193 | g_bct = _expand_global_features(B, T, g, bct=True) 194 | 195 | if c is not None and self.upsample_net is not None: 196 | c = self.upsample_net(c) 197 | assert c.size(-1) == x.size(-1) 198 | 199 | # Feed data to network 200 | x = self.first_conv(x) 201 | skips = 0 202 | for f in self.conv_layers: 203 | x, h = f(x, c, g_bct) 204 | skips += h 205 | skips *= math.sqrt(1.0 / len(self.conv_layers)) 206 | 207 | x = skips 208 | for f in self.last_conv_layers: 209 | x = f(x) 210 | 211 | x = F.softmax(x, dim=1) if softmax else x 212 | 213 | return x 214 | 215 | def incremental_forward(self, initial_input=None, c=None, g=None, 216 | T=100, test_inputs=None, 217 | tqdm=lambda x: x, softmax=True, quantize=True, 218 | log_scale_min=-50.0): 219 | """Incremental forward step 220 | 221 | Due to linearized convolutions, inputs of shape (B x C x T) are reshaped 222 | to (B x T x C) internally and fed to the network for each time step. 223 | Input of each time step will be of shape (B x 1 x C). 224 | 225 | Args: 226 | initial_input (Tensor): Initial decoder input, (B x C x 1) 227 | c (Tensor): Local conditioning features, shape (B x C' x T) 228 | g (Tensor): Global conditioning features, shape (B x C'' or B x C'' x 1) 229 | T (int): Number of time steps to generate. 230 | test_inputs (Tensor): Teacher forcing inputs (for debugging) 231 | tqdm (lambda): Progress bar wrapper (e.g., tqdm) 232 | softmax (bool) : Whether to apply softmax or not 233 | quantize (bool): Whether to quantize the softmax output before feeding it 234 | back as input for the next time step. TODO: rename 235 | log_scale_min (float): Log scale minimum value. 236 | 237 | Returns: 238 | Tensor: Generated one-hot encoded samples, B x C x T, 239 | or scalar vector B x 1 x T 240 | """ 241 | self.clear_buffer() 242 | B = 1 243 | 244 | # Note: shape should be **(B x T x C)**, not (B x C x T), as opposed to 245 | # batch forward, due to linearized convolution 246 | if test_inputs is not None: 247 | if self.scalar_input: 248 | if test_inputs.size(1) == 1: 249 | test_inputs = test_inputs.transpose(1, 2).contiguous() 250 | else: 251 | if test_inputs.size(1) == self.out_channels: 252 | test_inputs = test_inputs.transpose(1, 2).contiguous() 253 | 254 | B = test_inputs.size(0) 255 | if T is None: 256 | T = test_inputs.size(1) 257 | else: 258 | T = max(T, test_inputs.size(1)) 259 | # cast to int in case of numpy.int64...
260 | T = int(T) 261 | 262 | # Global conditioning 263 | if g is not None: 264 | if self.embed_speakers is not None: 265 | g = self.embed_speakers(g.view(B, -1)) 266 | # (B x gin_channels, 1) 267 | g = g.transpose(1, 2) 268 | assert g.dim() == 3 269 | g_btc = _expand_global_features(B, T, g, bct=False) 270 | 271 | # Local conditioning 272 | if c is not None: 273 | B = c.shape[0] 274 | if self.upsample_net is not None: 275 | c = self.upsample_net(c) 276 | assert c.size(-1) == T 277 | if c.size(-1) == T: 278 | c = c.transpose(1, 2).contiguous() 279 | 280 | outputs = [] 281 | if initial_input is None: 282 | if self.scalar_input: 283 | initial_input = torch.zeros(B, 1, 1) 284 | else: 285 | initial_input = torch.zeros(B, 1, self.out_channels) 286 | initial_input[:, :, 127] = 1 # TODO: is this ok? 287 | # https://github.com/pytorch/pytorch/issues/584#issuecomment-275169567 288 | if next(self.parameters()).is_cuda: 289 | initial_input = initial_input.cuda() 290 | else: 291 | if initial_input.size(1) == self.out_channels: 292 | initial_input = initial_input.transpose(1, 2).contiguous() 293 | 294 | current_input = initial_input 295 | 296 | for t in tqdm(range(T)): 297 | if test_inputs is not None and t < test_inputs.size(1): 298 | current_input = test_inputs[:, t, :].unsqueeze(1) 299 | else: 300 | if t > 0: 301 | current_input = outputs[-1] 302 | 303 | # Conditioning features for single time step 304 | ct = None if c is None else c[:, t, :].unsqueeze(1) 305 | gt = None if g is None else g_btc[:, t, :].unsqueeze(1) 306 | 307 | x = current_input 308 | x = self.first_conv.incremental_forward(x) 309 | skips = 0 310 | for f in self.conv_layers: 311 | x, h = f.incremental_forward(x, ct, gt) 312 | skips += h 313 | skips *= math.sqrt(1.0 / len(self.conv_layers)) 314 | x = skips 315 | for f in self.last_conv_layers: 316 | try: 317 | x = f.incremental_forward(x) 318 | except AttributeError: 319 | x = f(x) 320 | 321 | # Generate next input by sampling 322 | if self.scalar_input: 323 | if self.output_distribution == "Logistic": 324 | x = sample_from_discretized_mix_logistic( 325 | x.view(B, -1, 1), log_scale_min=log_scale_min) 326 | elif self.output_distribution == "Normal": 327 | x = sample_from_mix_gaussian( 328 | x.view(B, -1, 1), log_scale_min=log_scale_min) 329 | else: 330 | assert False 331 | else: 332 | x = F.softmax(x.view(B, -1), dim=1) if softmax else x.view(B, -1) 333 | if quantize: 334 | dist = torch.distributions.OneHotCategorical(x) 335 | x = dist.sample() 336 | outputs += [x.data] 337 | # T x B x C 338 | outputs = torch.stack(outputs) 339 | # B x C x T 340 | outputs = outputs.transpose(0, 1).transpose(1, 2).contiguous() 341 | 342 | self.clear_buffer() 343 | return outputs 344 | 345 | def clear_buffer(self): 346 | self.first_conv.clear_buffer() 347 | for f in self.conv_layers: 348 | f.clear_buffer() 349 | for f in self.last_conv_layers: 350 | try: 351 | f.clear_buffer() 352 | except AttributeError: 353 | pass 354 | 355 | def make_generation_fast_(self): 356 | def remove_weight_norm(m): 357 | try: 358 | nn.utils.remove_weight_norm(m) 359 | except ValueError: # this module didn't have weight norm 360 | return 361 | self.apply(remove_weight_norm) 362 | --------------------------------------------------------------------------------
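ConvInUpsampleNetwork (upsample.py) stretches conditioning features along the time axis so that each frame ends up covering exactly hop-size samples. A minimal shape-check sketch, assuming 80-dim mel features and a hop size of 256 (the upsample_scales [4, 4, 4, 4] and tensor sizes below are illustrative, not the project defaults):

import torch
from wavenet_vocoder.upsample import ConvInUpsampleNetwork

# 80-dim mel frames, upsampled by prod([4, 4, 4, 4]) == 256 (assumed hop size)
up = ConvInUpsampleNetwork(upsample_scales=[4, 4, 4, 4], cin_pad=0, cin_channels=80)
mel = torch.randn(1, 80, 50)             # B x cin_channels x T_frames
out = up(mel)
assert out.shape == (1, 80, 50 * 256)    # each frame now spans hop-size samples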
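A minimal usage sketch of the WaveNet class (wavenet.py) for the mu-law one-hot case; the hyperparameter values below are illustrative assumptions rather than the defaults shipped in hparams.py:

import torch
from wavenet_vocoder.wavenet import WaveNet

model = WaveNet(
    out_channels=256,                    # mu-law quantize channels (assumed)
    layers=24, stacks=4,                 # 24 dilated layers in 4 dilation cycles
    residual_channels=128, gate_channels=256, skip_out_channels=128,
    kernel_size=3,
    cin_channels=80,                     # local conditioning on mel-spectrograms
    upsample_conditional_features=True,
    upsample_params={"upsample_scales": [4, 4, 4, 4],  # prod == assumed hop size 256
                     "cin_pad": 0, "cin_channels": 80},
    scalar_input=False,
)
# receptive_field_size(24, 4, 3): (3 - 1) * 4 * (1+2+4+8+16+32) + 1 = 505 samples
print(model.receptive_field)

B, n_frames = 2, 20
x = torch.zeros(B, 256, n_frames * 256)  # one-hot encoded mu-law audio
x[:, 127, :] = 1
c = torch.randn(B, 80, n_frames)         # mel frames (B x cin_channels x T_frames)
y = model(x, c=c, softmax=False)         # (B, 256, T) unnormalized logits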
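Autoregressive generation with incremental_forward, continuing from the sketch above (model, c, and the hop size of 256 remain assumptions); make_generation_fast_ removes weight normalization to speed up per-sample inference:

model.eval()
model.make_generation_fast_()            # strip weight norm for faster sampling
with torch.no_grad():
    c_one = c[:1]                        # condition on a single utterance
    T_gen = c_one.size(-1) * 256         # samples to generate for these frames
    y_hat = model.incremental_forward(
        c=c_one, g=None, T=T_gen, softmax=True, quantize=True)
# y_hat: (1, 256, T_gen) one-hot samples; the channel argmax gives mu-law
# quantized indices, to be decoded with inverse mu-law quantization.
mulaw_indices = y_hat.argmax(dim=1)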