├── .gitattributes ├── .gitignore ├── LICENSE.txt ├── README.md ├── assets ├── WaveRNN.png ├── tacotron_wavernn.png ├── training_viz.gif └── wavernn_alt_model_hrz2.png ├── gen_tacotron.py ├── gen_wavernn.py ├── hparams.py ├── models ├── __init__.py ├── deepmind_version.py ├── fatchord_version.py └── tacotron.py ├── notebooks ├── NB1 - Fit a Sine Wave.ipynb ├── NB2 - Fit a Short Sample.ipynb ├── NB3 - Fit a 30min Sample.ipynb ├── NB4a - Alternative Model (Preprocessing).ipynb ├── NB4b - Alternative Model (Training).ipynb ├── Pruning - Scratchpad.ipynb ├── __init__.py ├── models │ └── wavernn.py ├── outputs │ ├── nb1 │ │ └── model_output.wav │ ├── nb2 │ │ └── 3k_steps.wav │ └── nb3 │ │ └── 12k_steps.wav └── utils │ ├── __init__.py │ ├── display.py │ └── dsp.py ├── preprocess.py ├── pretrained ├── ljspeech.tacotron.r2.180k.zip └── ljspeech.wavernn.mol.800k.zip ├── quick_start.py ├── requirements.txt ├── sentences.txt ├── train_tacotron.py ├── train_wavernn.py └── utils ├── __init__.py ├── checkpoints.py ├── dataset.py ├── display.py ├── distribution.py ├── dsp.py ├── files.py ├── paths.py └── text ├── LICENSE ├── __init__.py ├── cleaners.py ├── cmudict.py ├── numbers.py ├── recipes.py └── symbols.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-language=Python -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE files 2 | .idea 3 | .vscode 4 | 5 | # Mac files 6 | .DS_Store 7 | 8 | # Environments 9 | .env 10 | .venv 11 | env/ 12 | venv/ 13 | ENV/ 14 | env.bak/ 15 | venv.bak/ 16 | 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | pip-wheel-metadata/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Jupyter Notebook 48 | .ipynb_checkpoints 49 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 fatchord (https://github.com/fatchord) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WaveRNN 2 | 3 | ##### (Update: Vanilla Tacotron One TTS system just implemented - more coming soon!) 4 | 5 | ![Tacotron with WaveRNN diagrams](assets/tacotron_wavernn.png) 6 | 7 | PyTorch implementation of DeepMind's WaveRNN model from [Efficient Neural Audio Synthesis](https://arxiv.org/abs/1802.08435v1). 8 | 9 | # Installation 10 | 11 | Ensure you have: 12 | 13 | * Python >= 3.6 14 | * [PyTorch 1 with CUDA](https://pytorch.org/) 15 | 16 | Then install the rest with pip: 17 | 18 | > pip install -r requirements.txt 19 | 20 | # How to Use 21 | 22 | ### Quick Start 23 | 24 | If you want to use TTS functionality immediately, you can simply run: 25 | 26 | > python quick_start.py 27 | 28 | This will generate everything in the default sentences.txt file and output to a new 'quick_start' folder, where you can play back the wav files and take a look at the attention plots. 29 | 30 | You can also use that script to generate custom TTS sentences and/or use '-u' for unbatched generation (better audio quality): 31 | 32 | > python quick_start.py -u --input_text "What will happen if I run this command?" 33 | 34 | 35 | ### Training your own Models 36 | ![Attention and Mel Training GIF](assets/training_viz.gif) 37 | 38 | Download the [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) Dataset. 39 | 40 | Edit **hparams.py**, point **wav_path** to your dataset and run: 41 | 42 | > python preprocess.py 43 | 44 | or use preprocess.py --path to point directly to the dataset. 45 | ___ 46 | 47 | Here's my recommendation on what order to run things: 48 | 49 | 1 - Train Tacotron with: 50 | 51 | > python train_tacotron.py 52 | 53 | 2 - You can let that finish training, or at any point you can use: 54 | 55 | > python train_tacotron.py --force_gta 56 | 57 | this will force Tacotron to create a GTA dataset even if it hasn't finished training. 58 | 59 | 3 - Train WaveRNN with: 60 | 61 | > python train_wavernn.py --gta 62 | 63 | NB: You can always just run train_wavernn.py without --gta if you're not interested in TTS. 64 | 65 | 4 - Generate Sentences with both models using: 66 | 67 | > python gen_tacotron.py wavernn 68 | 69 | this will generate the default sentences.
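For reference, a typical end-to-end run simply chains the steps above (this assumes LJSpeech has been downloaded and **wav_path** in hparams.py points at it):

> python preprocess.py
> python train_tacotron.py
> python train_tacotron.py --force_gta
> python train_wavernn.py --gta
> python gen_tacotron.py wavernn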
If you want to generate custom sentences you can use: 70 | 71 | > python gen_tacotron.py --input_text "this is whatever you want it to be" wavernn 72 | 73 | And finally, you can always use --help on any of those scripts to see what options are available :) 74 | 75 | 76 | 77 | # Samples 78 | 79 | [Can be found here.](https://fatchord.github.io/model_outputs/) 80 | 81 | # Pretrained Models 82 | 83 | Currently there are two pretrained models available in the /pretrained/ folder: 84 | 85 | Both are trained on LJSpeech: 86 | 87 | * WaveRNN (Mixture of Logistics output) trained to 800k steps 88 | * Tacotron trained to 180k steps 89 | 90 | ____ 91 | 92 | ### References 93 | 94 | * [Efficient Neural Audio Synthesis](https://arxiv.org/abs/1802.08435v1) 95 | * [Tacotron: Towards End-to-End Speech Synthesis](https://arxiv.org/abs/1703.10135) 96 | * [Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions](https://arxiv.org/abs/1712.05884) 97 | 98 | ### Acknowledgements 99 | 100 | * [https://github.com/keithito/tacotron](https://github.com/keithito/tacotron) 101 | * [https://github.com/r9y9/wavenet_vocoder](https://github.com/r9y9/wavenet_vocoder) 102 | * Special thanks to GitHub users [G-Wang](https://github.com/G-Wang), [geneing](https://github.com/geneing) & [erogol](https://github.com/erogol) 103 | -------------------------------------------------------------------------------- /assets/WaveRNN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatchord/WaveRNN/83c08fdcd625be56244f4145f41500468974b144/assets/WaveRNN.png -------------------------------------------------------------------------------- /assets/tacotron_wavernn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatchord/WaveRNN/83c08fdcd625be56244f4145f41500468974b144/assets/tacotron_wavernn.png -------------------------------------------------------------------------------- /assets/training_viz.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatchord/WaveRNN/83c08fdcd625be56244f4145f41500468974b144/assets/training_viz.gif -------------------------------------------------------------------------------- /assets/wavernn_alt_model_hrz2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatchord/WaveRNN/83c08fdcd625be56244f4145f41500468974b144/assets/wavernn_alt_model_hrz2.png -------------------------------------------------------------------------------- /gen_tacotron.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.fatchord_version import WaveRNN 3 | from utils import hparams as hp 4 | from utils.text.symbols import symbols 5 | from utils.paths import Paths 6 | from models.tacotron import Tacotron 7 | import argparse 8 | from utils.text import text_to_sequence 9 | from utils.display import save_attention, simple_table 10 | from utils.dsp import reconstruct_waveform, save_wav 11 | import numpy as np 12 | 13 | if __name__ == "__main__": 14 | 15 | # Parse Arguments 16 | parser = argparse.ArgumentParser(description='TTS Generator') 17 | parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!') 18 | parser.add_argument('--tts_weights', type=str, help='[string/path] Load in different Tacotron weights') 19 |
parser.add_argument('--save_attention', '-a', dest='save_attn', action='store_true', help='Save Attention Plots') 20 | parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment') 21 | parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters') 22 | 23 | parser.set_defaults(input_text=None) 24 | parser.set_defaults(weights_path=None) 25 | 26 | # name of subcommand goes to args.vocoder 27 | subparsers = parser.add_subparsers(required=True, dest='vocoder') 28 | 29 | wr_parser = subparsers.add_parser('wavernn', aliases=['wr']) 30 | wr_parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation') 31 | wr_parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation') 32 | wr_parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples') 33 | wr_parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index') 34 | wr_parser.add_argument('--voc_weights', type=str, help='[string/path] Load in different WaveRNN weights') 35 | wr_parser.set_defaults(batched=None) 36 | 37 | gl_parser = subparsers.add_parser('griffinlim', aliases=['gl']) 38 | gl_parser.add_argument('--iters', type=int, default=32, help='[int] number of griffinlim iterations') 39 | 40 | args = parser.parse_args() 41 | 42 | if args.vocoder in ['griffinlim', 'gl']: 43 | args.vocoder = 'griffinlim' 44 | elif args.vocoder in ['wavernn', 'wr']: 45 | args.vocoder = 'wavernn' 46 | else: 47 | raise argparse.ArgumentError('Must provide a valid vocoder type!') 48 | 49 | hp.configure(args.hp_file) # Load hparams from file 50 | # set defaults for any arguments that depend on hparams 51 | if args.vocoder == 'wavernn': 52 | if args.target is None: 53 | args.target = hp.voc_target 54 | if args.overlap is None: 55 | args.overlap = hp.voc_overlap 56 | if args.batched is None: 57 | args.batched = hp.voc_gen_batched 58 | 59 | batched = args.batched 60 | target = args.target 61 | overlap = args.overlap 62 | 63 | input_text = args.input_text 64 | tts_weights = args.tts_weights 65 | save_attn = args.save_attn 66 | 67 | paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id) 68 | 69 | if not args.force_cpu and torch.cuda.is_available(): 70 | device = torch.device('cuda') 71 | else: 72 | device = torch.device('cpu') 73 | print('Using device:', device) 74 | 75 | if args.vocoder == 'wavernn': 76 | print('\nInitialising WaveRNN Model...\n') 77 | # Instantiate WaveRNN Model 78 | voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims, 79 | fc_dims=hp.voc_fc_dims, 80 | bits=hp.bits, 81 | pad=hp.voc_pad, 82 | upsample_factors=hp.voc_upsample_factors, 83 | feat_dims=hp.num_mels, 84 | compute_dims=hp.voc_compute_dims, 85 | res_out_dims=hp.voc_res_out_dims, 86 | res_blocks=hp.voc_res_blocks, 87 | hop_length=hp.hop_length, 88 | sample_rate=hp.sample_rate, 89 | mode=hp.voc_mode).to(device) 90 | 91 | voc_load_path = args.voc_weights if args.voc_weights else paths.voc_latest_weights 92 | voc_model.load(voc_load_path) 93 | 94 | print('\nInitialising Tacotron Model...\n') 95 | 96 | # Instantiate Tacotron Model 97 | tts_model = Tacotron(embed_dims=hp.tts_embed_dims, 98 | num_chars=len(symbols), 99 | encoder_dims=hp.tts_encoder_dims, 100 | decoder_dims=hp.tts_decoder_dims, 101 | n_mels=hp.num_mels, 102 | fft_bins=hp.num_mels, 103 | postnet_dims=hp.tts_postnet_dims, 104 | 
encoder_K=hp.tts_encoder_K, 105 | lstm_dims=hp.tts_lstm_dims, 106 | postnet_K=hp.tts_postnet_K, 107 | num_highways=hp.tts_num_highways, 108 | dropout=hp.tts_dropout, 109 | stop_threshold=hp.tts_stop_threshold).to(device) 110 | 111 | tts_load_path = tts_weights if tts_weights else paths.tts_latest_weights 112 | tts_model.load(tts_load_path) 113 | 114 | if input_text: 115 | inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)] 116 | else: 117 | with open('sentences.txt') as f: 118 | inputs = [text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f] 119 | 120 | if args.vocoder == 'wavernn': 121 | voc_k = voc_model.get_step() // 1000 122 | tts_k = tts_model.get_step() // 1000 123 | 124 | simple_table([('Tacotron', str(tts_k) + 'k'), 125 | ('r', tts_model.r), 126 | ('Vocoder Type', 'WaveRNN'), 127 | ('WaveRNN', str(voc_k) + 'k'), 128 | ('Generation Mode', 'Batched' if batched else 'Unbatched'), 129 | ('Target Samples', target if batched else 'N/A'), 130 | ('Overlap Samples', overlap if batched else 'N/A')]) 131 | 132 | elif args.vocoder == 'griffinlim': 133 | tts_k = tts_model.get_step() // 1000 134 | simple_table([('Tacotron', str(tts_k) + 'k'), 135 | ('r', tts_model.r), 136 | ('Vocoder Type', 'Griffin-Lim'), 137 | ('GL Iters', args.iters)]) 138 | 139 | for i, x in enumerate(inputs, 1): 140 | 141 | print(f'\n| Generating {i}/{len(inputs)}') 142 | _, m, attention = tts_model.generate(x) 143 | # Fix mel spectrogram scaling to be from 0 to 1 144 | m = (m + 4) / 8 145 | np.clip(m, 0, 1, out=m) 146 | 147 | if args.vocoder == 'griffinlim': 148 | v_type = args.vocoder 149 | elif args.vocoder == 'wavernn' and args.batched: 150 | v_type = 'wavernn_batched' 151 | else: 152 | v_type = 'wavernn_unbatched' 153 | 154 | if input_text: 155 | save_path = paths.tts_output/f'__input_{input_text[:10]}_{v_type}_{tts_k}k.wav' 156 | else: 157 | save_path = paths.tts_output/f'{i}_{v_type}_{tts_k}k.wav' 158 | 159 | if save_attn: save_attention(attention, save_path) 160 | 161 | if args.vocoder == 'wavernn': 162 | m = torch.tensor(m).unsqueeze(0) 163 | voc_model.generate(m, save_path, batched, hp.voc_target, hp.voc_overlap, hp.mu_law) 164 | elif args.vocoder == 'griffinlim': 165 | wav = reconstruct_waveform(m, n_iter=args.iters) 166 | save_wav(wav, save_path) 167 | 168 | print('\n\nDone.\n') 169 | -------------------------------------------------------------------------------- /gen_wavernn.py: -------------------------------------------------------------------------------- 1 | from utils.dataset import get_vocoder_datasets 2 | from utils.dsp import * 3 | from models.fatchord_version import WaveRNN 4 | from utils.paths import Paths 5 | from utils.display import simple_table 6 | import torch 7 | import argparse 8 | from pathlib import Path 9 | 10 | 11 | def gen_testset(model: WaveRNN, test_set, samples, batched, target, overlap, save_path: Path): 12 | 13 | k = model.get_step() // 1000 14 | 15 | for i, (m, x) in enumerate(test_set, 1): 16 | 17 | if i > samples: break 18 | 19 | print('\n| Generating: %i/%i' % (i, samples)) 20 | 21 | x = x[0].numpy() 22 | 23 | bits = 16 if hp.voc_mode == 'MOL' else hp.bits 24 | 25 | if hp.mu_law and hp.voc_mode != 'MOL': 26 | x = decode_mu_law(x, 2**bits, from_labels=True) 27 | else: 28 | x = label_2_float(x, bits) 29 | 30 | save_wav(x, save_path/f'{k}k_steps_{i}_target.wav') 31 | 32 | batch_str = f'gen_batched_target{target}_overlap{overlap}' if batched else 'gen_NOT_BATCHED' 33 | save_str = str(save_path/f'{k}k_steps_{i}_{batch_str}.wav') 34 | 35 | _ = 
model.generate(m, save_str, batched, target, overlap, hp.mu_law) 36 | 37 | 38 | def gen_from_file(model: WaveRNN, load_path: Path, save_path: Path, batched, target, overlap): 39 | 40 | k = model.get_step() // 1000 41 | file_name = load_path.stem 42 | 43 | suffix = load_path.suffix 44 | if suffix == ".wav": 45 | wav = load_wav(load_path) 46 | save_wav(wav, save_path/f'__{file_name}__{k}k_steps_target.wav') 47 | mel = melspectrogram(wav) 48 | elif suffix == ".npy": 49 | mel = np.load(load_path) 50 | if mel.ndim != 2 or mel.shape[0] != hp.num_mels: 51 | raise ValueError(f'Expected a numpy array shaped (n_mels, n_hops), but got {mel.shape}!') 52 | _max = np.max(mel) 53 | _min = np.min(mel) 54 | if _max >= 1.01 or _min <= -0.01: 55 | raise ValueError(f'Expected spectrogram range in [0,1] but was instead [{_min}, {_max}]') 56 | else: 57 | raise ValueError(f"Expected an extension of .wav or .npy, but got {suffix}!") 58 | 59 | 60 | mel = torch.tensor(mel).unsqueeze(0) 61 | 62 | batch_str = f'gen_batched_target{target}_overlap{overlap}' if batched else 'gen_NOT_BATCHED' 63 | save_str = save_path/f'__{file_name}__{k}k_steps_{batch_str}.wav' 64 | 65 | _ = model.generate(mel, save_str, batched, target, overlap, hp.mu_law) 66 | 67 | 68 | if __name__ == "__main__": 69 | 70 | parser = argparse.ArgumentParser(description='Generate WaveRNN Samples') 71 | parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation') 72 | parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation') 73 | parser.add_argument('--samples', '-s', type=int, help='[int] number of utterances to generate') 74 | parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index') 75 | parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples') 76 | parser.add_argument('--file', '-f', type=str, help='[string/path] for testing a wav outside dataset') 77 | parser.add_argument('--voc_weights', '-w', type=str, help='[string/path] Load in different WaveRNN weights') 78 | parser.add_argument('--gta', '-g', dest='gta', action='store_true', help='Generate from GTA testset') 79 | parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment') 80 | parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters') 81 | 82 | parser.set_defaults(batched=None) 83 | 84 | args = parser.parse_args() 85 | 86 | hp.configure(args.hp_file) # Load hparams from file 87 | # set defaults for any arguments that depend on hparams 88 | if args.target is None: 89 | args.target = hp.voc_target 90 | if args.overlap is None: 91 | args.overlap = hp.voc_overlap 92 | if args.batched is None: 93 | args.batched = hp.voc_gen_batched 94 | if args.samples is None: 95 | args.samples = hp.voc_gen_at_checkpoint 96 | 97 | batched = args.batched 98 | samples = args.samples 99 | target = args.target 100 | overlap = args.overlap 101 | file = args.file 102 | gta = args.gta 103 | 104 | if not args.force_cpu and torch.cuda.is_available(): 105 | device = torch.device('cuda') 106 | else: 107 | device = torch.device('cpu') 108 | print('Using device:', device) 109 | 110 | print('\nInitialising Model...\n') 111 | 112 | model = WaveRNN(rnn_dims=hp.voc_rnn_dims, 113 | fc_dims=hp.voc_fc_dims, 114 | bits=hp.bits, 115 | pad=hp.voc_pad, 116 | upsample_factors=hp.voc_upsample_factors, 117 | feat_dims=hp.num_mels, 118 |
compute_dims=hp.voc_compute_dims, 119 | res_out_dims=hp.voc_res_out_dims, 120 | res_blocks=hp.voc_res_blocks, 121 | hop_length=hp.hop_length, 122 | sample_rate=hp.sample_rate, 123 | mode=hp.voc_mode).to(device) 124 | 125 | paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id) 126 | 127 | voc_weights = args.voc_weights if args.voc_weights else paths.voc_latest_weights 128 | 129 | model.load(voc_weights) 130 | 131 | simple_table([('Generation Mode', 'Batched' if batched else 'Unbatched'), 132 | ('Target Samples', target if batched else 'N/A'), 133 | ('Overlap Samples', overlap if batched else 'N/A')]) 134 | 135 | if file: 136 | file = Path(file).expanduser() 137 | gen_from_file(model, file, paths.voc_output, batched, target, overlap) 138 | else: 139 | _, test_set = get_vocoder_datasets(paths.data, 1, gta) 140 | gen_testset(model, test_set, samples, batched, target, overlap, paths.voc_output) 141 | 142 | print('\n\nExiting...\n') 143 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | 2 | # CONFIG -----------------------------------------------------------------------------------------------------------# 3 | 4 | # Here are the input and output data paths (Note: you can override wav_path in preprocess.py) 5 | wav_path = '/path/to/wav_files/' 6 | data_path = 'data/' 7 | 8 | # model ids are separate - that way you can use a new tts with an old wavernn and vice versa 9 | # NB: expect undefined behaviour if models were trained on different DSP settings 10 | voc_model_id = 'ljspeech_mol' 11 | tts_model_id = 'ljspeech_lsa_smooth_attention' 12 | 13 | # set this to True if you are only interested in WaveRNN 14 | ignore_tts = False 15 | 16 | 17 | # DSP --------------------------------------------------------------------------------------------------------------# 18 | 19 | # Settings for all models 20 | sample_rate = 22050 21 | n_fft = 2048 22 | fft_bins = n_fft // 2 + 1 23 | num_mels = 80 24 | hop_length = 275 # 12.5ms - in line with Tacotron 2 paper 25 | win_length = 1100 # 50ms - same reason as above 26 | fmin = 40 27 | min_level_db = -100 28 | ref_level_db = 20 29 | bits = 9 # bit depth of signal 30 | mu_law = True # Recommended to suppress noise if using raw bits in hp.voc_mode below 31 | peak_norm = False # Normalise to the peak of each wav file 32 | 33 | 34 | # WAVERNN / VOCODER ------------------------------------------------------------------------------------------------# 35 | 36 | 37 | # Model Hparams 38 | voc_mode = 'MOL' # either 'RAW' (softmax on raw bits) or 'MOL' (sample from mixture of logistics) 39 | voc_upsample_factors = (5, 5, 11) # NB - this needs to correctly factorise hop_length 40 | voc_rnn_dims = 512 41 | voc_fc_dims = 512 42 | voc_compute_dims = 128 43 | voc_res_out_dims = 128 44 | voc_res_blocks = 10 45 | 46 | # Training 47 | voc_batch_size = 32 48 | voc_lr = 1e-4 49 | voc_checkpoint_every = 25_000 50 | voc_gen_at_checkpoint = 5 # number of samples to generate at each checkpoint 51 | voc_total_steps = 1_000_000 # Total number of training steps 52 | voc_test_samples = 50 # How many unseen samples to put aside for testing 53 | voc_pad = 2 # this will pad the input so that the resnet can 'see' wider than input length 54 | voc_seq_len = hop_length * 5 # must be a multiple of hop_length 55 | voc_clip_grad_norm = 4 # set to None if no gradient clipping needed 56 | 57 | # Generating / Synthesizing 58 | voc_gen_batched = True # very fast (realtime+) 
single utterance batched generation 59 | voc_target = 11_000 # target number of samples to be generated in each batch entry 60 | voc_overlap = 550 # number of samples for crossfading between batches 61 | 62 | 63 | # TACOTRON/TTS -----------------------------------------------------------------------------------------------------# 64 | 65 | 66 | # Model Hparams 67 | tts_embed_dims = 256 # embedding dimension for the graphemes/phoneme inputs 68 | tts_encoder_dims = 128 69 | tts_decoder_dims = 256 70 | tts_postnet_dims = 128 71 | tts_encoder_K = 16 72 | tts_lstm_dims = 512 73 | tts_postnet_K = 8 74 | tts_num_highways = 4 75 | tts_dropout = 0.5 76 | tts_cleaner_names = ['english_cleaners'] 77 | tts_stop_threshold = -3.4 # Value below which audio generation ends. 78 | # For example, for a range of [-4, 4], this 79 | # will terminate the sequence at the first 80 | # frame that has all values < -3.4 81 | 82 | # Training 83 | 84 | tts_schedule = [(7, 1e-3, 10_000, 32), # progressive training schedule 85 | (5, 1e-4, 100_000, 32), # (r, lr, step, batch_size) 86 | (2, 1e-4, 180_000, 16), 87 | (2, 1e-4, 350_000, 8)] 88 | 89 | tts_max_mel_len = 1250 # if you have a couple of extremely long spectrograms you might want to use this 90 | tts_bin_lengths = True # bins the spectrogram lengths before sampling in data loader - speeds up training 91 | tts_clip_grad_norm = 1.0 # clips the gradient norm to prevent explosion - set to None if not needed 92 | tts_checkpoint_every = 2_000 # checkpoints the model every X steps 93 | # TODO: tts_phoneme_prob = 0.0 # [0 <-> 1] probability for feeding model phonemes vrs graphemes 94 | 95 | 96 | # ------------------------------------------------------------------------------------------------------------------# 97 | 98 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatchord/WaveRNN/83c08fdcd625be56244f4145f41500468974b144/models/__init__.py -------------------------------------------------------------------------------- /models/deepmind_version.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.display import * 5 | from utils.dsp import * 6 | import numpy as np 7 | 8 | class WaveRNN(nn.Module): 9 | def __init__(self, hidden_size=896, quantisation=256): 10 | super(WaveRNN, self).__init__() 11 | 12 | self.hidden_size = hidden_size 13 | self.split_size = hidden_size // 2 14 | 15 | # The main matmul 16 | self.R = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False) 17 | 18 | # Output fc layers 19 | self.O1 = nn.Linear(self.split_size, self.split_size) 20 | self.O2 = nn.Linear(self.split_size, quantisation) 21 | self.O3 = nn.Linear(self.split_size, self.split_size) 22 | self.O4 = nn.Linear(self.split_size, quantisation) 23 | 24 | # Input fc layers 25 | self.I_coarse = nn.Linear(2, 3 * self.split_size, bias=False) 26 | self.I_fine = nn.Linear(3, 3 * self.split_size, bias=False) 27 | 28 | # biases for the gates 29 | self.bias_u = nn.Parameter(torch.zeros(self.hidden_size)) 30 | self.bias_r = nn.Parameter(torch.zeros(self.hidden_size)) 31 | self.bias_e = nn.Parameter(torch.zeros(self.hidden_size)) 32 | 33 | # display num params 34 | self.num_params() 35 | 36 | 37 | def forward(self, prev_y, prev_hidden, current_coarse): 38 | 39 | # Main matmul - the projection is split 3 ways 
40 | R_hidden = self.R(prev_hidden) 41 | R_u, R_r, R_e, = torch.split(R_hidden, self.hidden_size, dim=1) 42 | 43 | # Project the prev input 44 | coarse_input_proj = self.I_coarse(prev_y) 45 | I_coarse_u, I_coarse_r, I_coarse_e = \ 46 | torch.split(coarse_input_proj, self.split_size, dim=1) 47 | 48 | # Project the prev input and current coarse sample 49 | fine_input = torch.cat([prev_y, current_coarse], dim=1) 50 | fine_input_proj = self.I_fine(fine_input) 51 | I_fine_u, I_fine_r, I_fine_e = \ 52 | torch.split(fine_input_proj, self.split_size, dim=1) 53 | 54 | # concatenate for the gates 55 | I_u = torch.cat([I_coarse_u, I_fine_u], dim=1) 56 | I_r = torch.cat([I_coarse_r, I_fine_r], dim=1) 57 | I_e = torch.cat([I_coarse_e, I_fine_e], dim=1) 58 | 59 | # Compute all gates for coarse and fine 60 | u = F.sigmoid(R_u + I_u + self.bias_u) 61 | r = F.sigmoid(R_r + I_r + self.bias_r) 62 | e = F.tanh(r * R_e + I_e + self.bias_e) 63 | hidden = u * prev_hidden + (1. - u) * e 64 | 65 | # Split the hidden state 66 | hidden_coarse, hidden_fine = torch.split(hidden, self.split_size, dim=1) 67 | 68 | # Compute outputs 69 | out_coarse = self.O2(F.relu(self.O1(hidden_coarse))) 70 | out_fine = self.O4(F.relu(self.O3(hidden_fine))) 71 | 72 | return out_coarse, out_fine, hidden 73 | 74 | 75 | def generate(self, seq_len): 76 | device = next(self.parameters()).device # use same device as parameters 77 | 78 | with torch.no_grad(): 79 | 80 | # First split up the biases for the gates 81 | b_coarse_u, b_fine_u = torch.split(self.bias_u, self.split_size) 82 | b_coarse_r, b_fine_r = torch.split(self.bias_r, self.split_size) 83 | b_coarse_e, b_fine_e = torch.split(self.bias_e, self.split_size) 84 | 85 | # Lists for the two output seqs 86 | c_outputs, f_outputs = [], [] 87 | 88 | # Some initial inputs 89 | out_coarse = torch.tensor([0], dtype=torch.long, device=device) 90 | out_fine = torch.tensor([0], dtype=torch.long, device=device) 91 | 92 | # We'll meed a hidden state 93 | hidden = self.get_initial_hidden() 94 | 95 | # Need a clock for display 96 | start = time.time() 97 | 98 | # Loop for generation 99 | for i in range(seq_len): 100 | 101 | # Split into two hidden states 102 | hidden_coarse, hidden_fine = \ 103 | torch.split(hidden, self.split_size, dim=1) 104 | 105 | # Scale and concat previous predictions 106 | out_coarse = out_coarse.unsqueeze(0).float() / 127.5 - 1. 107 | out_fine = out_fine.unsqueeze(0).float() / 127.5 - 1. 108 | prev_outputs = torch.cat([out_coarse, out_fine], dim=1) 109 | 110 | # Project input 111 | coarse_input_proj = self.I_coarse(prev_outputs) 112 | I_coarse_u, I_coarse_r, I_coarse_e = \ 113 | torch.split(coarse_input_proj, self.split_size, dim=1) 114 | 115 | # Project hidden state and split 6 ways 116 | R_hidden = self.R(hidden) 117 | R_coarse_u , R_fine_u, \ 118 | R_coarse_r, R_fine_r, \ 119 | R_coarse_e, R_fine_e = torch.split(R_hidden, self.split_size, dim=1) 120 | 121 | # Compute the coarse gates 122 | u = F.sigmoid(R_coarse_u + I_coarse_u + b_coarse_u) 123 | r = F.sigmoid(R_coarse_r + I_coarse_r + b_coarse_r) 124 | e = F.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e) 125 | hidden_coarse = u * hidden_coarse + (1. 
- u) * e 126 | 127 | # Compute the coarse output 128 | out_coarse = self.O2(F.relu(self.O1(hidden_coarse))) 129 | posterior = F.softmax(out_coarse, dim=1) 130 | distrib = torch.distributions.Categorical(posterior) 131 | out_coarse = distrib.sample() 132 | c_outputs.append(out_coarse) 133 | 134 | # Project the [prev outputs and predicted coarse sample] 135 | coarse_pred = out_coarse.float() / 127.5 - 1. 136 | fine_input = torch.cat([prev_outputs, coarse_pred.unsqueeze(0)], dim=1) 137 | fine_input_proj = self.I_fine(fine_input) 138 | I_fine_u, I_fine_r, I_fine_e = \ 139 | torch.split(fine_input_proj, self.split_size, dim=1) 140 | 141 | # Compute the fine gates 142 | u = F.sigmoid(R_fine_u + I_fine_u + b_fine_u) 143 | r = F.sigmoid(R_fine_r + I_fine_r + b_fine_r) 144 | e = F.tanh(r * R_fine_e + I_fine_e + b_fine_e) 145 | hidden_fine = u * hidden_fine + (1. - u) * e 146 | 147 | # Compute the fine output 148 | out_fine = self.O4(F.relu(self.O3(hidden_fine))) 149 | posterior = F.softmax(out_fine, dim=1) 150 | distrib = torch.distributions.Categorical(posterior) 151 | out_fine = distrib.sample() 152 | f_outputs.append(out_fine) 153 | 154 | # Put the hidden state back together 155 | hidden = torch.cat([hidden_coarse, hidden_fine], dim=1) 156 | 157 | # Display progress 158 | speed = (i + 1) / (time.time() - start) 159 | stream('Gen: %i/%i -- Speed: %i', (i + 1, seq_len, speed)) 160 | 161 | coarse = torch.stack(c_outputs).squeeze(1).cpu().data.numpy() 162 | fine = torch.stack(f_outputs).squeeze(1).cpu().data.numpy() 163 | output = combine_signal(coarse, fine) 164 | 165 | return output, coarse, fine 166 | 167 | def get_initial_hidden(self, batch_size=1): 168 | device = next(self.parameters()).device # use same device as parameters 169 | return torch.zeros(batch_size, self.hidden_size, device=device) 170 | 171 | def num_params(self, print_out=True): 172 | parameters = filter(lambda p: p.requires_grad, self.parameters()) 173 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 174 | if print_out: 175 | print('Trainable Parameters: %.3f million' % parameters) 176 | return parameters -------------------------------------------------------------------------------- /models/fatchord_version.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.distribution import sample_from_discretized_mix_logistic 5 | from utils.display import * 6 | from utils.dsp import * 7 | import os 8 | import numpy as np 9 | from pathlib import Path 10 | from typing import Union 11 | 12 | 13 | class ResBlock(nn.Module): 14 | def __init__(self, dims): 15 | super().__init__() 16 | self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) 17 | self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) 18 | self.batch_norm1 = nn.BatchNorm1d(dims) 19 | self.batch_norm2 = nn.BatchNorm1d(dims) 20 | 21 | def forward(self, x): 22 | residual = x 23 | x = self.conv1(x) 24 | x = self.batch_norm1(x) 25 | x = F.relu(x) 26 | x = self.conv2(x) 27 | x = self.batch_norm2(x) 28 | return x + residual 29 | 30 | 31 | class MelResNet(nn.Module): 32 | def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad): 33 | super().__init__() 34 | k_size = pad * 2 + 1 35 | self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False) 36 | self.batch_norm = nn.BatchNorm1d(compute_dims) 37 | self.layers = nn.ModuleList() 38 | for i in range(res_blocks): 39 | 
self.layers.append(ResBlock(compute_dims)) 40 | self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1) 41 | 42 | def forward(self, x): 43 | x = self.conv_in(x) 44 | x = self.batch_norm(x) 45 | x = F.relu(x) 46 | for f in self.layers: x = f(x) 47 | x = self.conv_out(x) 48 | return x 49 | 50 | 51 | class Stretch2d(nn.Module): 52 | def __init__(self, x_scale, y_scale): 53 | super().__init__() 54 | self.x_scale = x_scale 55 | self.y_scale = y_scale 56 | 57 | def forward(self, x): 58 | b, c, h, w = x.size() 59 | x = x.unsqueeze(-1).unsqueeze(3) 60 | x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale) 61 | return x.view(b, c, h * self.y_scale, w * self.x_scale) 62 | 63 | 64 | class UpsampleNetwork(nn.Module): 65 | def __init__(self, feat_dims, upsample_scales, compute_dims, 66 | res_blocks, res_out_dims, pad): 67 | super().__init__() 68 | total_scale = np.cumproduct(upsample_scales)[-1] 69 | self.indent = pad * total_scale 70 | self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad) 71 | self.resnet_stretch = Stretch2d(total_scale, 1) 72 | self.up_layers = nn.ModuleList() 73 | for scale in upsample_scales: 74 | k_size = (1, scale * 2 + 1) 75 | padding = (0, scale) 76 | stretch = Stretch2d(scale, 1) 77 | conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False) 78 | conv.weight.data.fill_(1. / k_size[1]) 79 | self.up_layers.append(stretch) 80 | self.up_layers.append(conv) 81 | 82 | def forward(self, m): 83 | aux = self.resnet(m).unsqueeze(1) 84 | aux = self.resnet_stretch(aux) 85 | aux = aux.squeeze(1) 86 | m = m.unsqueeze(1) 87 | for f in self.up_layers: m = f(m) 88 | m = m.squeeze(1)[:, :, self.indent:-self.indent] 89 | return m.transpose(1, 2), aux.transpose(1, 2) 90 | 91 | 92 | class WaveRNN(nn.Module): 93 | def __init__(self, rnn_dims, fc_dims, bits, pad, upsample_factors, 94 | feat_dims, compute_dims, res_out_dims, res_blocks, 95 | hop_length, sample_rate, mode='RAW'): 96 | super().__init__() 97 | self.mode = mode 98 | self.pad = pad 99 | if self.mode == 'RAW': 100 | self.n_classes = 2 ** bits 101 | elif self.mode == 'MOL': 102 | self.n_classes = 30 103 | else: 104 | RuntimeError("Unknown model mode value - ", self.mode) 105 | 106 | # List of rnns to call `flatten_parameters()` on 107 | self._to_flatten = [] 108 | 109 | self.rnn_dims = rnn_dims 110 | self.aux_dims = res_out_dims // 4 111 | self.hop_length = hop_length 112 | self.sample_rate = sample_rate 113 | 114 | self.upsample = UpsampleNetwork(feat_dims, upsample_factors, compute_dims, res_blocks, res_out_dims, pad) 115 | self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims) 116 | 117 | self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True) 118 | self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True) 119 | self._to_flatten += [self.rnn1, self.rnn2] 120 | 121 | self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims) 122 | self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims) 123 | self.fc3 = nn.Linear(fc_dims, self.n_classes) 124 | 125 | self.register_buffer('step', torch.zeros(1, dtype=torch.long)) 126 | self.num_params() 127 | 128 | # Avoid fragmentation of RNN parameters and associated warning 129 | self._flatten_parameters() 130 | 131 | def forward(self, x, mels): 132 | device = next(self.parameters()).device # use same device as parameters 133 | 134 | # Although we `_flatten_parameters()` on init, when using DataParallel 135 | # the model gets replicated, making it no longer guaranteed that the 136 | # weights are contiguous in GPU memory. 
Hence, we must call it again 137 | self._flatten_parameters() 138 | 139 | self.step += 1 140 | bsize = x.size(0) 141 | h1 = torch.zeros(1, bsize, self.rnn_dims, device=device) 142 | h2 = torch.zeros(1, bsize, self.rnn_dims, device=device) 143 | mels, aux = self.upsample(mels) 144 | 145 | aux_idx = [self.aux_dims * i for i in range(5)] 146 | a1 = aux[:, :, aux_idx[0]:aux_idx[1]] 147 | a2 = aux[:, :, aux_idx[1]:aux_idx[2]] 148 | a3 = aux[:, :, aux_idx[2]:aux_idx[3]] 149 | a4 = aux[:, :, aux_idx[3]:aux_idx[4]] 150 | 151 | x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2) 152 | x = self.I(x) 153 | res = x 154 | x, _ = self.rnn1(x, h1) 155 | 156 | x = x + res 157 | res = x 158 | x = torch.cat([x, a2], dim=2) 159 | x, _ = self.rnn2(x, h2) 160 | 161 | x = x + res 162 | x = torch.cat([x, a3], dim=2) 163 | x = F.relu(self.fc1(x)) 164 | 165 | x = torch.cat([x, a4], dim=2) 166 | x = F.relu(self.fc2(x)) 167 | return self.fc3(x) 168 | 169 | def generate(self, mels, save_path: Union[str, Path], batched, target, overlap, mu_law): 170 | self.eval() 171 | 172 | device = next(self.parameters()).device # use same device as parameters 173 | 174 | mu_law = mu_law if self.mode == 'RAW' else False 175 | 176 | output = [] 177 | start = time.time() 178 | rnn1 = self.get_gru_cell(self.rnn1) 179 | rnn2 = self.get_gru_cell(self.rnn2) 180 | 181 | with torch.no_grad(): 182 | 183 | mels = torch.as_tensor(mels, device=device) 184 | wave_len = (mels.size(-1) - 1) * self.hop_length 185 | mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side='both') 186 | mels, aux = self.upsample(mels.transpose(1, 2)) 187 | 188 | if batched: 189 | mels = self.fold_with_overlap(mels, target, overlap) 190 | aux = self.fold_with_overlap(aux, target, overlap) 191 | 192 | b_size, seq_len, _ = mels.size() 193 | 194 | h1 = torch.zeros(b_size, self.rnn_dims, device=device) 195 | h2 = torch.zeros(b_size, self.rnn_dims, device=device) 196 | x = torch.zeros(b_size, 1, device=device) 197 | 198 | d = self.aux_dims 199 | aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)] 200 | 201 | for i in range(seq_len): 202 | 203 | m_t = mels[:, i, :] 204 | 205 | a1_t, a2_t, a3_t, a4_t = \ 206 | (a[:, i, :] for a in aux_split) 207 | 208 | x = torch.cat([x, m_t, a1_t], dim=1) 209 | x = self.I(x) 210 | h1 = rnn1(x, h1) 211 | 212 | x = x + h1 213 | inp = torch.cat([x, a2_t], dim=1) 214 | h2 = rnn2(inp, h2) 215 | 216 | x = x + h2 217 | x = torch.cat([x, a3_t], dim=1) 218 | x = F.relu(self.fc1(x)) 219 | 220 | x = torch.cat([x, a4_t], dim=1) 221 | x = F.relu(self.fc2(x)) 222 | 223 | logits = self.fc3(x) 224 | 225 | if self.mode == 'MOL': 226 | sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2)) 227 | output.append(sample.view(-1)) 228 | # x = torch.FloatTensor([[sample]]).cuda() 229 | x = sample.transpose(0, 1) 230 | 231 | elif self.mode == 'RAW': 232 | posterior = F.softmax(logits, dim=1) 233 | distrib = torch.distributions.Categorical(posterior) 234 | 235 | sample = 2 * distrib.sample().float() / (self.n_classes - 1.) - 1. 
236 | output.append(sample) 237 | x = sample.unsqueeze(-1) 238 | else: 239 | raise RuntimeError("Unknown model mode value - ", self.mode) 240 | 241 | if i % 100 == 0: self.gen_display(i, seq_len, b_size, start) 242 | 243 | output = torch.stack(output).transpose(0, 1) 244 | output = output.cpu().numpy() 245 | output = output.astype(np.float64) 246 | 247 | if mu_law: 248 | output = decode_mu_law(output, self.n_classes, False) 249 | 250 | if batched: 251 | output = self.xfade_and_unfold(output, target, overlap) 252 | else: 253 | output = output[0] 254 | 255 | # Fade-out at the end to avoid signal cutting out suddenly 256 | fade_out = np.linspace(1, 0, 20 * self.hop_length) 257 | output = output[:wave_len] 258 | output[-20 * self.hop_length:] *= fade_out 259 | 260 | save_wav(output, save_path) 261 | 262 | self.train() 263 | 264 | return output 265 | 266 | 267 | def gen_display(self, i, seq_len, b_size, start): 268 | gen_rate = (i + 1) / (time.time() - start) * b_size / 1000 269 | pbar = progbar(i, seq_len) 270 | msg = f'| {pbar} {i*b_size}/{seq_len*b_size} | Batch Size: {b_size} | Gen Rate: {gen_rate:.1f}kHz | ' 271 | stream(msg) 272 | 273 | def get_gru_cell(self, gru): 274 | gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size) 275 | gru_cell.weight_hh.data = gru.weight_hh_l0.data 276 | gru_cell.weight_ih.data = gru.weight_ih_l0.data 277 | gru_cell.bias_hh.data = gru.bias_hh_l0.data 278 | gru_cell.bias_ih.data = gru.bias_ih_l0.data 279 | return gru_cell 280 | 281 | def pad_tensor(self, x, pad, side='both'): 282 | # NB - this is just a quick method i need right now 283 | # i.e., it won't generalise to other shapes/dims 284 | b, t, c = x.size() 285 | total = t + 2 * pad if side == 'both' else t + pad 286 | padded = torch.zeros(b, total, c, device=x.device) 287 | if side == 'before' or side == 'both': 288 | padded[:, pad:pad + t, :] = x 289 | elif side == 'after': 290 | padded[:, :t, :] = x 291 | return padded 292 | 293 | def fold_with_overlap(self, x, target, overlap): 294 | 295 | ''' Fold the tensor with overlap for quick batched inference. 296 | Overlap will be used for crossfading in xfade_and_unfold() 297 | 298 | Args: 299 | x (tensor) : Upsampled conditioning features. 300 | shape=(1, timesteps, features) 301 | target (int) : Target timesteps for each index of batch 302 | overlap (int) : Timesteps for both xfade and rnn warmup 303 | 304 | Return: 305 | (tensor) : shape=(num_folds, target + 2 * overlap, features) 306 | 307 | Details: 308 | x = [[h1, h2, ... 
hn]] 309 | 310 | Where each h is a vector of conditioning features 311 | 312 | Eg: target=2, overlap=1 with x.size(1)=10 313 | 314 | folded = [[h1, h2, h3, h4], 315 | [h4, h5, h6, h7], 316 | [h7, h8, h9, h10]] 317 | ''' 318 | 319 | _, total_len, features = x.size() 320 | 321 | # Calculate variables needed 322 | num_folds = (total_len - overlap) // (target + overlap) 323 | extended_len = num_folds * (overlap + target) + overlap 324 | remaining = total_len - extended_len 325 | 326 | # Pad if some time steps poking out 327 | if remaining != 0: 328 | num_folds += 1 329 | padding = target + 2 * overlap - remaining 330 | x = self.pad_tensor(x, padding, side='after') 331 | 332 | folded = torch.zeros(num_folds, target + 2 * overlap, features, device=x.device) 333 | 334 | # Get the values for the folded tensor 335 | for i in range(num_folds): 336 | start = i * (target + overlap) 337 | end = start + target + 2 * overlap 338 | folded[i] = x[:, start:end, :] 339 | 340 | return folded 341 | 342 | def xfade_and_unfold(self, y, target, overlap): 343 | 344 | ''' Applies a crossfade and unfolds into a 1d array. 345 | 346 | Args: 347 | y (ndarry) : Batched sequences of audio samples 348 | shape=(num_folds, target + 2 * overlap) 349 | dtype=np.float64 350 | overlap (int) : Timesteps for both xfade and rnn warmup 351 | 352 | Return: 353 | (ndarry) : audio samples in a 1d array 354 | shape=(total_len) 355 | dtype=np.float64 356 | 357 | Details: 358 | y = [[seq1], 359 | [seq2], 360 | [seq3]] 361 | 362 | Apply a gain envelope at both ends of the sequences 363 | 364 | y = [[seq1_in, seq1_target, seq1_out], 365 | [seq2_in, seq2_target, seq2_out], 366 | [seq3_in, seq3_target, seq3_out]] 367 | 368 | Stagger and add up the groups of samples: 369 | 370 | [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...] 
371 | 372 | ''' 373 | 374 | num_folds, length = y.shape 375 | target = length - 2 * overlap 376 | total_len = num_folds * (target + overlap) + overlap 377 | 378 | # Need some silence for the rnn warmup 379 | silence_len = overlap // 2 380 | fade_len = overlap - silence_len 381 | silence = np.zeros((silence_len), dtype=np.float64) 382 | linear = np.ones((silence_len), dtype=np.float64) 383 | 384 | # Equal power crossfade 385 | t = np.linspace(-1, 1, fade_len, dtype=np.float64) 386 | fade_in = np.sqrt(0.5 * (1 + t)) 387 | fade_out = np.sqrt(0.5 * (1 - t)) 388 | 389 | # Concat the silence to the fades 390 | fade_in = np.concatenate([silence, fade_in]) 391 | fade_out = np.concatenate([linear, fade_out]) 392 | 393 | # Apply the gain to the overlap samples 394 | y[:, :overlap] *= fade_in 395 | y[:, -overlap:] *= fade_out 396 | 397 | unfolded = np.zeros((total_len), dtype=np.float64) 398 | 399 | # Loop to add up all the samples 400 | for i in range(num_folds): 401 | start = i * (target + overlap) 402 | end = start + target + 2 * overlap 403 | unfolded[start:end] += y[i] 404 | 405 | return unfolded 406 | 407 | def get_step(self): 408 | return self.step.data.item() 409 | 410 | def log(self, path, msg): 411 | with open(path, 'a') as f: 412 | print(msg, file=f) 413 | 414 | def load(self, path: Union[str, Path]): 415 | # Use device of model params as location for loaded state 416 | device = next(self.parameters()).device 417 | self.load_state_dict(torch.load(path, map_location=device), strict=False) 418 | 419 | def save(self, path: Union[str, Path]): 420 | # No optimizer argument because saving a model should not include data 421 | # only relevant in the training process - it should only be properties 422 | # of the model itself. Let caller take care of saving optimzier state. 423 | torch.save(self.state_dict(), path) 424 | 425 | def num_params(self, print_out=True): 426 | parameters = filter(lambda p: p.requires_grad, self.parameters()) 427 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 428 | if print_out: 429 | print('Trainable Parameters: %.3fM' % parameters) 430 | return parameters 431 | 432 | def _flatten_parameters(self): 433 | """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used 434 | to improve efficiency and avoid PyTorch yelling at us.""" 435 | [m.flatten_parameters() for m in self._to_flatten] 436 | -------------------------------------------------------------------------------- /models/tacotron.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from pathlib import Path 7 | from typing import Union 8 | 9 | 10 | class HighwayNetwork(nn.Module): 11 | def __init__(self, size): 12 | super().__init__() 13 | self.W1 = nn.Linear(size, size) 14 | self.W2 = nn.Linear(size, size) 15 | self.W1.bias.data.fill_(0.) 16 | 17 | def forward(self, x): 18 | x1 = self.W1(x) 19 | x2 = self.W2(x) 20 | g = torch.sigmoid(x2) 21 | y = g * F.relu(x1) + (1. 
- g) * x 22 | return y 23 | 24 | 25 | class Encoder(nn.Module): 26 | def __init__(self, embed_dims, num_chars, cbhg_channels, K, num_highways, dropout): 27 | super().__init__() 28 | self.embedding = nn.Embedding(num_chars, embed_dims) 29 | self.pre_net = PreNet(embed_dims) 30 | self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels, 31 | proj_channels=[cbhg_channels, cbhg_channels], 32 | num_highways=num_highways) 33 | 34 | def forward(self, x): 35 | x = self.embedding(x) 36 | x = self.pre_net(x) 37 | x.transpose_(1, 2) 38 | x = self.cbhg(x) 39 | return x 40 | 41 | 42 | class BatchNormConv(nn.Module): 43 | def __init__(self, in_channels, out_channels, kernel, relu=True): 44 | super().__init__() 45 | self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False) 46 | self.bnorm = nn.BatchNorm1d(out_channels) 47 | self.relu = relu 48 | 49 | def forward(self, x): 50 | x = self.conv(x) 51 | x = F.relu(x) if self.relu is True else x 52 | return self.bnorm(x) 53 | 54 | 55 | class CBHG(nn.Module): 56 | def __init__(self, K, in_channels, channels, proj_channels, num_highways): 57 | super().__init__() 58 | 59 | # List of all rnns to call `flatten_parameters()` on 60 | self._to_flatten = [] 61 | 62 | self.bank_kernels = [i for i in range(1, K + 1)] 63 | self.conv1d_bank = nn.ModuleList() 64 | for k in self.bank_kernels: 65 | conv = BatchNormConv(in_channels, channels, k) 66 | self.conv1d_bank.append(conv) 67 | 68 | self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1) 69 | 70 | self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3) 71 | self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False) 72 | 73 | # Fix the highway input if necessary 74 | if proj_channels[-1] != channels: 75 | self.highway_mismatch = True 76 | self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False) 77 | else: 78 | self.highway_mismatch = False 79 | 80 | self.highways = nn.ModuleList() 81 | for i in range(num_highways): 82 | hn = HighwayNetwork(channels) 83 | self.highways.append(hn) 84 | 85 | self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True) 86 | self._to_flatten.append(self.rnn) 87 | 88 | # Avoid fragmentation of RNN parameters and associated warning 89 | self._flatten_parameters() 90 | 91 | def forward(self, x): 92 | # Although we `_flatten_parameters()` on init, when using DataParallel 93 | # the model gets replicated, making it no longer guaranteed that the 94 | # weights are contiguous in GPU memory. 
Hence, we must call it again 95 | self._flatten_parameters() 96 | 97 | # Save these for later 98 | residual = x 99 | seq_len = x.size(-1) 100 | conv_bank = [] 101 | 102 | # Convolution Bank 103 | for conv in self.conv1d_bank: 104 | c = conv(x) # Convolution 105 | conv_bank.append(c[:, :, :seq_len]) 106 | 107 | # Stack along the channel axis 108 | conv_bank = torch.cat(conv_bank, dim=1) 109 | 110 | # dump the last padding to fit residual 111 | x = self.maxpool(conv_bank)[:, :, :seq_len] 112 | 113 | # Conv1d projections 114 | x = self.conv_project1(x) 115 | x = self.conv_project2(x) 116 | 117 | # Residual Connect 118 | x = x + residual 119 | 120 | # Through the highways 121 | x = x.transpose(1, 2) 122 | if self.highway_mismatch is True: 123 | x = self.pre_highway(x) 124 | for h in self.highways: x = h(x) 125 | 126 | # And then the RNN 127 | x, _ = self.rnn(x) 128 | return x 129 | 130 | def _flatten_parameters(self): 131 | """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used 132 | to improve efficiency and avoid PyTorch yelling at us.""" 133 | [m.flatten_parameters() for m in self._to_flatten] 134 | 135 | class PreNet(nn.Module): 136 | def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5): 137 | super().__init__() 138 | self.fc1 = nn.Linear(in_dims, fc1_dims) 139 | self.fc2 = nn.Linear(fc1_dims, fc2_dims) 140 | self.p = dropout 141 | 142 | def forward(self, x): 143 | x = self.fc1(x) 144 | x = F.relu(x) 145 | x = F.dropout(x, self.p, training=self.training) 146 | x = self.fc2(x) 147 | x = F.relu(x) 148 | x = F.dropout(x, self.p, training=self.training) 149 | return x 150 | 151 | 152 | class Attention(nn.Module): 153 | def __init__(self, attn_dims): 154 | super().__init__() 155 | self.W = nn.Linear(attn_dims, attn_dims, bias=False) 156 | self.v = nn.Linear(attn_dims, 1, bias=False) 157 | 158 | def forward(self, encoder_seq_proj, query, t): 159 | 160 | # print(encoder_seq_proj.shape) 161 | # Transform the query vector 162 | query_proj = self.W(query).unsqueeze(1) 163 | 164 | # Compute the scores 165 | u = self.v(torch.tanh(encoder_seq_proj + query_proj)) 166 | scores = F.softmax(u, dim=1) 167 | 168 | return scores.transpose(1, 2) 169 | 170 | 171 | class LSA(nn.Module): 172 | def __init__(self, attn_dim, kernel_size=31, filters=32): 173 | super().__init__() 174 | self.conv = nn.Conv1d(2, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=False) 175 | self.L = nn.Linear(filters, attn_dim, bias=True) 176 | self.W = nn.Linear(attn_dim, attn_dim, bias=True) 177 | self.v = nn.Linear(attn_dim, 1, bias=False) 178 | self.cumulative = None 179 | self.attention = None 180 | 181 | def init_attention(self, encoder_seq_proj): 182 | device = next(self.parameters()).device # use same device as parameters 183 | b, t, c = encoder_seq_proj.size() 184 | self.cumulative = torch.zeros(b, t, device=device) 185 | self.attention = torch.zeros(b, t, device=device) 186 | 187 | def forward(self, encoder_seq_proj, query, t): 188 | 189 | if t == 0: self.init_attention(encoder_seq_proj) 190 | 191 | processed_query = self.W(query).unsqueeze(1) 192 | 193 | location = torch.cat([self.cumulative.unsqueeze(1), self.attention.unsqueeze(1)], dim=1) 194 | processed_loc = self.L(self.conv(location).transpose(1, 2)) 195 | 196 | u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc)) 197 | u = u.squeeze(-1) 198 | 199 | # Smooth Attention 200 | scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True) 201 | # scores = F.softmax(u, dim=1) 202 | 
self.attention = scores 203 | self.cumulative += self.attention 204 | 205 | return scores.unsqueeze(-1).transpose(1, 2) 206 | 207 | 208 | class Decoder(nn.Module): 209 | # Class variable because its value doesn't change between classes 210 | # yet ought to be scoped by class because its a property of a Decoder 211 | max_r = 20 212 | def __init__(self, n_mels, decoder_dims, lstm_dims): 213 | super().__init__() 214 | self.register_buffer('r', torch.tensor(1, dtype=torch.int)) 215 | self.n_mels = n_mels 216 | self.prenet = PreNet(n_mels) 217 | self.attn_net = LSA(decoder_dims) 218 | self.attn_rnn = nn.GRUCell(decoder_dims + decoder_dims // 2, decoder_dims) 219 | self.rnn_input = nn.Linear(2 * decoder_dims, lstm_dims) 220 | self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims) 221 | self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims) 222 | self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False) 223 | 224 | def zoneout(self, prev, current, p=0.1): 225 | device = next(self.parameters()).device # Use same device as parameters 226 | mask = torch.zeros(prev.size(), device=device).bernoulli_(p) 227 | return prev * mask + current * (1 - mask) 228 | 229 | def forward(self, encoder_seq, encoder_seq_proj, prenet_in, 230 | hidden_states, cell_states, context_vec, t): 231 | 232 | # Need this for reshaping mels 233 | batch_size = encoder_seq.size(0) 234 | 235 | # Unpack the hidden and cell states 236 | attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states 237 | rnn1_cell, rnn2_cell = cell_states 238 | 239 | # PreNet for the Attention RNN 240 | prenet_out = self.prenet(prenet_in) 241 | 242 | # Compute the Attention RNN hidden state 243 | attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1) 244 | attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden) 245 | 246 | # Compute the attention scores 247 | scores = self.attn_net(encoder_seq_proj, attn_hidden, t) 248 | 249 | # Dot product to create the context vector 250 | context_vec = scores @ encoder_seq 251 | context_vec = context_vec.squeeze(1) 252 | 253 | # Concat Attention RNN output w. 
Context Vector & project 254 | x = torch.cat([context_vec, attn_hidden], dim=1) 255 | x = self.rnn_input(x) 256 | 257 | # Compute first Residual RNN 258 | rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell)) 259 | if self.training: 260 | rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next) 261 | else: 262 | rnn1_hidden = rnn1_hidden_next 263 | x = x + rnn1_hidden 264 | 265 | # Compute second Residual RNN 266 | rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell)) 267 | if self.training: 268 | rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next) 269 | else: 270 | rnn2_hidden = rnn2_hidden_next 271 | x = x + rnn2_hidden 272 | 273 | # Project Mels 274 | mels = self.mel_proj(x) 275 | mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r] 276 | hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden) 277 | cell_states = (rnn1_cell, rnn2_cell) 278 | 279 | return mels, scores, hidden_states, cell_states, context_vec 280 | 281 | 282 | class Tacotron(nn.Module): 283 | def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels, fft_bins, postnet_dims, 284 | encoder_K, lstm_dims, postnet_K, num_highways, dropout, stop_threshold): 285 | super().__init__() 286 | self.n_mels = n_mels 287 | self.lstm_dims = lstm_dims 288 | self.decoder_dims = decoder_dims 289 | self.encoder = Encoder(embed_dims, num_chars, encoder_dims, 290 | encoder_K, num_highways, dropout) 291 | self.encoder_proj = nn.Linear(decoder_dims, decoder_dims, bias=False) 292 | self.decoder = Decoder(n_mels, decoder_dims, lstm_dims) 293 | self.postnet = CBHG(postnet_K, n_mels, postnet_dims, [256, 80], num_highways) 294 | self.post_proj = nn.Linear(postnet_dims * 2, fft_bins, bias=False) 295 | 296 | self.init_model() 297 | self.num_params() 298 | 299 | self.register_buffer('step', torch.zeros(1, dtype=torch.long)) 300 | self.register_buffer('stop_threshold', torch.tensor(stop_threshold, dtype=torch.float32)) 301 | 302 | @property 303 | def r(self): 304 | return self.decoder.r.item() 305 | 306 | @r.setter 307 | def r(self, value): 308 | self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False) 309 | 310 | def forward(self, x, m, generate_gta=False): 311 | device = next(self.parameters()).device # use same device as parameters 312 | 313 | self.step += 1 314 | 315 | if generate_gta: 316 | self.eval() 317 | else: 318 | self.train() 319 | 320 | batch_size, _, steps = m.size() 321 | 322 | # Initialise all hidden states and pack into tuple 323 | attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device) 324 | rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device) 325 | rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device) 326 | hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden) 327 | 328 | # Initialise all lstm cell states and pack into tuple 329 | rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device) 330 | rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device) 331 | cell_states = (rnn1_cell, rnn2_cell) 332 | 333 | # Frame for start of decoder loop 334 | go_frame = torch.zeros(batch_size, self.n_mels, device=device) 335 | 336 | # Need an initial context vector 337 | context_vec = torch.zeros(batch_size, self.decoder_dims, device=device) 338 | 339 | # Project the encoder outputs to avoid 340 | # unnecessary matmuls in the decoder loop 341 | encoder_seq = self.encoder(x) 342 | encoder_seq_proj = self.encoder_proj(encoder_seq) 343 | 344 | # Need a couple of lists for outputs 345 | 
mel_outputs, attn_scores = [], [] 346 | 347 | # Run the decoder loop 348 | for t in range(0, steps, self.r): 349 | prenet_in = m[:, :, t - 1] if t > 0 else go_frame 350 | mel_frames, scores, hidden_states, cell_states, context_vec = \ 351 | self.decoder(encoder_seq, encoder_seq_proj, prenet_in, 352 | hidden_states, cell_states, context_vec, t) 353 | mel_outputs.append(mel_frames) 354 | attn_scores.append(scores) 355 | 356 | # Concat the mel outputs into sequence 357 | mel_outputs = torch.cat(mel_outputs, dim=2) 358 | 359 | # Post-Process for Linear Spectrograms 360 | postnet_out = self.postnet(mel_outputs) 361 | linear = self.post_proj(postnet_out) 362 | linear = linear.transpose(1, 2) 363 | 364 | # For easy visualisation 365 | attn_scores = torch.cat(attn_scores, 1) 366 | # attn_scores = attn_scores.cpu().data.numpy() 367 | 368 | return mel_outputs, linear, attn_scores 369 | 370 | def generate(self, x, steps=2000): 371 | self.eval() 372 | device = next(self.parameters()).device # use same device as parameters 373 | 374 | batch_size = 1 375 | x = torch.as_tensor(x, dtype=torch.long, device=device).unsqueeze(0) 376 | 377 | # Need to initialise all hidden states and pack into tuple for tidyness 378 | attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device) 379 | rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device) 380 | rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device) 381 | hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden) 382 | 383 | # Need to initialise all lstm cell states and pack into tuple for tidyness 384 | rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device) 385 | rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device) 386 | cell_states = (rnn1_cell, rnn2_cell) 387 | 388 | # Need a Frame for start of decoder loop 389 | go_frame = torch.zeros(batch_size, self.n_mels, device=device) 390 | 391 | # Need an initial context vector 392 | context_vec = torch.zeros(batch_size, self.decoder_dims, device=device) 393 | 394 | # Project the encoder outputs to avoid 395 | # unnecessary matmuls in the decoder loop 396 | encoder_seq = self.encoder(x) 397 | encoder_seq_proj = self.encoder_proj(encoder_seq) 398 | 399 | # Need a couple of lists for outputs 400 | mel_outputs, attn_scores = [], [] 401 | 402 | # Run the decoder loop 403 | for t in range(0, steps, self.r): 404 | prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame 405 | mel_frames, scores, hidden_states, cell_states, context_vec = \ 406 | self.decoder(encoder_seq, encoder_seq_proj, prenet_in, 407 | hidden_states, cell_states, context_vec, t) 408 | mel_outputs.append(mel_frames) 409 | attn_scores.append(scores) 410 | # Stop the loop if silent frames present 411 | if (mel_frames < self.stop_threshold).all() and t > 10: break 412 | 413 | # Concat the mel outputs into sequence 414 | mel_outputs = torch.cat(mel_outputs, dim=2) 415 | 416 | # Post-Process for Linear Spectrograms 417 | postnet_out = self.postnet(mel_outputs) 418 | linear = self.post_proj(postnet_out) 419 | 420 | 421 | linear = linear.transpose(1, 2)[0].cpu().data.numpy() 422 | mel_outputs = mel_outputs[0].cpu().data.numpy() 423 | 424 | # For easy visualisation 425 | attn_scores = torch.cat(attn_scores, 1) 426 | attn_scores = attn_scores.cpu().data.numpy()[0] 427 | 428 | self.train() 429 | 430 | return mel_outputs, linear, attn_scores 431 | 432 | def init_model(self): 433 | for p in self.parameters(): 434 | if p.dim() > 1: nn.init.xavier_uniform_(p) 435 | 436 | def get_step(self): 437 | return 
self.step.data.item() 438 | 439 | def reset_step(self): 440 | # assignment to parameters or buffers is overloaded, updates internal dict entry 441 | self.step = self.step.data.new_tensor(1) 442 | 443 | def log(self, path, msg): 444 | with open(path, 'a') as f: 445 | print(msg, file=f) 446 | 447 | def load(self, path: Union[str, Path]): 448 | # Use device of model params as location for loaded state 449 | device = next(self.parameters()).device 450 | state_dict = torch.load(path, map_location=device) 451 | 452 | # Backwards compatibility with old saved models 453 | if 'r' in state_dict and not 'decoder.r' in state_dict: 454 | self.r = state_dict['r'] 455 | 456 | self.load_state_dict(state_dict, strict=False) 457 | 458 | def save(self, path: Union[str, Path]): 459 | # No optimizer argument because saving a model should not include data 460 | # only relevant in the training process - it should only be properties 461 | # of the model itself. Let caller take care of saving optimzier state. 462 | torch.save(self.state_dict(), path) 463 | 464 | def num_params(self, print_out=True): 465 | parameters = filter(lambda p: p.requires_grad, self.parameters()) 466 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 467 | if print_out: 468 | print('Trainable Parameters: %.3fM' % parameters) 469 | return parameters 470 | -------------------------------------------------------------------------------- /notebooks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatchord/WaveRNN/83c08fdcd625be56244f4145f41500468974b144/notebooks/__init__.py -------------------------------------------------------------------------------- /notebooks/models/wavernn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class WaveRNN(nn.Module) : 7 | def __init__(self, hidden_size=896, quantisation=256) : 8 | super(WaveRNN, self).__init__() 9 | 10 | self.hidden_size = hidden_size 11 | self.split_size = hidden_size // 2 12 | 13 | # The main matmul 14 | self.R = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False) 15 | 16 | # Output fc layers 17 | self.O1 = nn.Linear(self.split_size, self.split_size) 18 | self.O2 = nn.Linear(self.split_size, quantisation) 19 | self.O3 = nn.Linear(self.split_size, self.split_size) 20 | self.O4 = nn.Linear(self.split_size, quantisation) 21 | 22 | # Input fc layers 23 | self.I_coarse = nn.Linear(2, 3 * self.split_size, bias=False) 24 | self.I_fine = nn.Linear(3, 3 * self.split_size, bias=False) 25 | 26 | # biases for the gates 27 | self.bias_u = nn.Parameter(torch.zeros(self.hidden_size)) 28 | self.bias_r = nn.Parameter(torch.zeros(self.hidden_size)) 29 | self.bias_e = nn.Parameter(torch.zeros(self.hidden_size)) 30 | 31 | # display num params 32 | self.num_params() 33 | 34 | 35 | def forward(self, prev_y, prev_hidden, current_coarse) : 36 | 37 | # Main matmul - the projection is split 3 ways 38 | R_hidden = self.R(prev_hidden) 39 | R_u, R_r, R_e, = torch.split(R_hidden, self.hidden_size, dim=1) 40 | 41 | # Project the prev input 42 | coarse_input_proj = self.I_coarse(prev_y) 43 | I_coarse_u, I_coarse_r, I_coarse_e = \ 44 | torch.split(coarse_input_proj, self.split_size, dim=1) 45 | 46 | # Project the prev input and current coarse sample 47 | fine_input = torch.cat([prev_y, current_coarse], dim=1) 48 | fine_input_proj = self.I_fine(fine_input) 49 | I_fine_u, I_fine_r, I_fine_e 
= \ 50 | torch.split(fine_input_proj, self.split_size, dim=1) 51 | 52 | # concatenate for the gates 53 | I_u = torch.cat([I_coarse_u, I_fine_u], dim=1) 54 | I_r = torch.cat([I_coarse_r, I_fine_r], dim=1) 55 | I_e = torch.cat([I_coarse_e, I_fine_e], dim=1) 56 | 57 | # Compute all gates for coarse and fine 58 | u = F.sigmoid(R_u + I_u + self.bias_u) 59 | r = F.sigmoid(R_r + I_r + self.bias_r) 60 | e = F.tanh(r * R_e + I_e + self.bias_e) 61 | hidden = u * prev_hidden + (1. - u) * e 62 | 63 | # Split the hidden state 64 | hidden_coarse, hidden_fine = torch.split(hidden, self.split_size, dim=1) 65 | 66 | # Compute outputs 67 | out_coarse = self.O2(F.relu(self.O1(hidden_coarse))) 68 | out_fine = self.O4(F.relu(self.O3(hidden_fine))) 69 | 70 | return out_coarse, out_fine, hidden 71 | 72 | 73 | def generate(self, seq_len) : 74 | 75 | with torch.no_grad() : 76 | 77 | # First split up the biases for the gates 78 | b_coarse_u, b_fine_u = torch.split(self.bias_u, self.split_size) 79 | b_coarse_r, b_fine_r = torch.split(self.bias_r, self.split_size) 80 | b_coarse_e, b_fine_e = torch.split(self.bias_e, self.split_size) 81 | 82 | # Lists for the two output seqs 83 | c_outputs, f_outputs = [], [] 84 | 85 | # Some initial inputs 86 | out_coarse = torch.LongTensor([0]).cuda() 87 | out_fine = torch.LongTensor([0]).cuda() 88 | 89 | # We'll meed a hidden state 90 | hidden = self.init_hidden() 91 | 92 | # Need a clock for display 93 | start = time.time() 94 | 95 | # Loop for generation 96 | for i in range(seq_len) : 97 | 98 | # Split into two hidden states 99 | hidden_coarse, hidden_fine = \ 100 | torch.split(hidden, self.split_size, dim=1) 101 | 102 | # Scale and concat previous predictions 103 | out_coarse = out_coarse.unsqueeze(0).float() / 127.5 - 1. 104 | out_fine = out_fine.unsqueeze(0).float() / 127.5 - 1. 105 | prev_outputs = torch.cat([out_coarse, out_fine], dim=1) 106 | 107 | # Project input 108 | coarse_input_proj = self.I_coarse(prev_outputs) 109 | I_coarse_u, I_coarse_r, I_coarse_e = \ 110 | torch.split(coarse_input_proj, self.split_size, dim=1) 111 | 112 | # Project hidden state and split 6 ways 113 | R_hidden = self.R(hidden) 114 | R_coarse_u , R_fine_u, \ 115 | R_coarse_r, R_fine_r, \ 116 | R_coarse_e, R_fine_e = torch.split(R_hidden, self.split_size, dim=1) 117 | 118 | # Compute the coarse gates 119 | u = F.sigmoid(R_coarse_u + I_coarse_u + b_coarse_u) 120 | r = F.sigmoid(R_coarse_r + I_coarse_r + b_coarse_r) 121 | e = F.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e) 122 | hidden_coarse = u * hidden_coarse + (1. - u) * e 123 | 124 | # Compute the coarse output 125 | out_coarse = self.O2(F.relu(self.O1(hidden_coarse))) 126 | posterior = F.softmax(out_coarse, dim=1) 127 | distrib = torch.distributions.Categorical(posterior) 128 | out_coarse = distrib.sample() 129 | c_outputs.append(out_coarse) 130 | 131 | # Project the [prev outputs and predicted coarse sample] 132 | coarse_pred = out_coarse.float() / 127.5 - 1. 133 | fine_input = torch.cat([prev_outputs, coarse_pred.unsqueeze(0)], dim=1) 134 | fine_input_proj = self.I_fine(fine_input) 135 | I_fine_u, I_fine_r, I_fine_e = \ 136 | torch.split(fine_input_proj, self.split_size, dim=1) 137 | 138 | # Compute the fine gates 139 | u = F.sigmoid(R_fine_u + I_fine_u + b_fine_u) 140 | r = F.sigmoid(R_fine_r + I_fine_r + b_fine_r) 141 | e = F.tanh(r * R_fine_e + I_fine_e + b_fine_e) 142 | hidden_fine = u * hidden_fine + (1. 
- u) * e 143 | 144 | # Compute the fine output 145 | out_fine = self.O4(F.relu(self.O3(hidden_fine))) 146 | posterior = F.softmax(out_fine, dim=1) 147 | distrib = torch.distributions.Categorical(posterior) 148 | out_fine = distrib.sample() 149 | f_outputs.append(out_fine) 150 | 151 | # Put the hidden state back together 152 | hidden = torch.cat([hidden_coarse, hidden_fine], dim=1) 153 | 154 | # Display progress 155 | speed = (i + 1) / (time.time() - start) 156 | stream('Gen: %i/%i -- Speed: %i', (i + 1, seq_len, speed)) 157 | 158 | coarse = torch.stack(c_outputs).squeeze(1).cpu().data.numpy() 159 | fine = torch.stack(f_outputs).squeeze(1).cpu().data.numpy() 160 | output = combine_signal(coarse, fine) 161 | 162 | return output, coarse, fine 163 | 164 | 165 | def init_hidden(self, batch_size=1) : 166 | return torch.zeros(batch_size, self.hidden_size).cuda() 167 | 168 | 169 | def num_params(self) : 170 | parameters = filter(lambda p: p.requires_grad, self.parameters()) 171 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 172 | print('Trainable Parameters: %.3f million' % parameters) -------------------------------------------------------------------------------- /notebooks/outputs/nb1/model_output.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatchord/WaveRNN/83c08fdcd625be56244f4145f41500468974b144/notebooks/outputs/nb1/model_output.wav -------------------------------------------------------------------------------- /notebooks/outputs/nb2/3k_steps.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatchord/WaveRNN/83c08fdcd625be56244f4145f41500468974b144/notebooks/outputs/nb2/3k_steps.wav -------------------------------------------------------------------------------- /notebooks/outputs/nb3/12k_steps.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatchord/WaveRNN/83c08fdcd625be56244f4145f41500468974b144/notebooks/outputs/nb3/12k_steps.wav -------------------------------------------------------------------------------- /notebooks/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatchord/WaveRNN/83c08fdcd625be56244f4145f41500468974b144/notebooks/utils/__init__.py -------------------------------------------------------------------------------- /notebooks/utils/display.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import time, sys, math 3 | import numpy as np 4 | 5 | def stream(string, variables) : 6 | sys.stdout.write(f'\r{string}' % variables) 7 | 8 | def num_params(model) : 9 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 10 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 11 | print('Trainable Parameters: %.3f million' % parameters) 12 | 13 | def time_since(started) : 14 | elapsed = time.time() - started 15 | m = int(elapsed // 60) 16 | s = int(elapsed % 60) 17 | if m >= 60 : 18 | h = int(m // 60) 19 | m = m % 60 20 | return f'{h}h {m}m {s}s' 21 | else : 22 | return f'{m}m {s}s' 23 | 24 | def plot(array) : 25 | fig = plt.figure(figsize=(30, 5)) 26 | ax = fig.add_subplot(111) 27 | ax.xaxis.label.set_color('grey') 28 | ax.yaxis.label.set_color('grey') 29 | ax.xaxis.label.set_fontsize(23) 30 | ax.yaxis.label.set_fontsize(23) 31 | ax.tick_params(axis='x', 
colors='grey', labelsize=23) 32 | ax.tick_params(axis='y', colors='grey', labelsize=23) 33 | plt.plot(array) 34 | 35 | def plot_spec(M) : 36 | M = np.flip(M, axis=0) 37 | plt.figure(figsize=(18,4)) 38 | plt.imshow(M, interpolation='nearest', aspect='auto') 39 | plt.show() 40 | 41 | -------------------------------------------------------------------------------- /notebooks/utils/dsp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import librosa, math 3 | 4 | sample_rate = 22050 5 | n_fft = 2048 6 | fft_bins = n_fft // 2 + 1 7 | num_mels = 80 8 | hop_length = int(sample_rate * 0.0125) # 12.5ms 9 | win_length = int(sample_rate * 0.05) # 50ms 10 | fmin = 40 11 | min_level_db = -100 12 | ref_level_db = 20 13 | 14 | def load_wav(filename, encode=True) : 15 | x = librosa.load(filename, sr=sample_rate)[0] 16 | if encode == True : x = encode_16bits(x) 17 | return x 18 | 19 | def save_wav(y, filename) : 20 | if y.dtype != 'int16' : 21 | y = encode_16bits(y) 22 | librosa.output.write_wav(filename, y.astype(np.int16), sample_rate) 23 | 24 | def split_signal(x) : 25 | unsigned = x + 2**15 26 | coarse = unsigned // 256 27 | fine = unsigned % 256 28 | return coarse, fine 29 | 30 | def combine_signal(coarse, fine) : 31 | return coarse * 256 + fine - 2**15 32 | 33 | def encode_16bits(x) : 34 | return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16) 35 | 36 | mel_basis = None 37 | 38 | def linear_to_mel(spectrogram): 39 | global mel_basis 40 | if mel_basis is None: 41 | mel_basis = build_mel_basis() 42 | return np.dot(mel_basis, spectrogram) 43 | 44 | def build_mel_basis(): 45 | return librosa.filters.mel(sample_rate, n_fft, n_mels=num_mels, fmin=fmin) 46 | 47 | def normalize(S): 48 | return np.clip((S - min_level_db) / -min_level_db, 0, 1) 49 | 50 | def denormalize(S): 51 | return (np.clip(S, 0, 1) * -min_level_db) + min_level_db 52 | 53 | def amp_to_db(x): 54 | return 20 * np.log10(np.maximum(1e-5, x)) 55 | 56 | def db_to_amp(x): 57 | return np.power(10.0, x * 0.05) 58 | 59 | def spectrogram(y): 60 | D = stft(y) 61 | S = amp_to_db(np.abs(D)) - ref_level_db 62 | return normalize(S) 63 | 64 | def melspectrogram(y): 65 | D = stft(y) 66 | S = amp_to_db(linear_to_mel(np.abs(D))) 67 | return normalize(S) 68 | 69 | def stft(y): 70 | return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length) -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from utils.display import * 3 | from utils.dsp import * 4 | from utils import hparams as hp 5 | from multiprocessing import Pool, cpu_count 6 | from utils.paths import Paths 7 | import pickle 8 | import argparse 9 | from utils.text.recipes import ljspeech 10 | from utils.files import get_files 11 | from pathlib import Path 12 | 13 | 14 | # Helper functions for argument types 15 | def valid_n_workers(num): 16 | n = int(num) 17 | if n < 1: 18 | raise argparse.ArgumentTypeError('%r must be an integer greater than 0' % num) 19 | return n 20 | 21 | parser = argparse.ArgumentParser(description='Preprocessing for WaveRNN and Tacotron') 22 | parser.add_argument('--path', '-p', help='directly point to dataset path (overrides hparams.wav_path') 23 | parser.add_argument('--extension', '-e', metavar='EXT', default='.wav', help='file extension to search for in dataset folder') 24 | parser.add_argument('--num_workers', '-w', metavar='N', 
type=valid_n_workers, default=cpu_count()-1, help='The number of worker threads to use for preprocessing') 25 | parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters') 26 | args = parser.parse_args() 27 | 28 | hp.configure(args.hp_file) # Load hparams from file 29 | if args.path is None: 30 | args.path = hp.wav_path 31 | 32 | extension = args.extension 33 | path = args.path 34 | 35 | 36 | def convert_file(path: Path): 37 | y = load_wav(path) 38 | peak = np.abs(y).max() 39 | if hp.peak_norm or peak > 1.0: 40 | y /= peak 41 | mel = melspectrogram(y) 42 | if hp.voc_mode == 'RAW': 43 | quant = encode_mu_law(y, mu=2**hp.bits) if hp.mu_law else float_2_label(y, bits=hp.bits) 44 | elif hp.voc_mode == 'MOL': 45 | quant = float_2_label(y, bits=16) 46 | 47 | return mel.astype(np.float32), quant.astype(np.int64) 48 | 49 | 50 | def process_wav(path: Path): 51 | wav_id = path.stem 52 | m, x = convert_file(path) 53 | np.save(paths.mel/f'{wav_id}.npy', m, allow_pickle=False) 54 | np.save(paths.quant/f'{wav_id}.npy', x, allow_pickle=False) 55 | return wav_id, m.shape[-1] 56 | 57 | 58 | wav_files = get_files(path, extension) 59 | paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id) 60 | 61 | print(f'\n{len(wav_files)} {extension[1:]} files found in "{path}"\n') 62 | 63 | if len(wav_files) == 0: 64 | 65 | print('Please point wav_path in hparams.py to your dataset,') 66 | print('or use the --path option.\n') 67 | 68 | else: 69 | 70 | if not hp.ignore_tts: 71 | 72 | text_dict = ljspeech(path) 73 | 74 | with open(paths.data/'text_dict.pkl', 'wb') as f: 75 | pickle.dump(text_dict, f) 76 | 77 | n_workers = max(1, args.num_workers) 78 | 79 | simple_table([ 80 | ('Sample Rate', hp.sample_rate), 81 | ('Bit Depth', hp.bits), 82 | ('Mu Law', hp.mu_law), 83 | ('Hop Length', hp.hop_length), 84 | ('CPU Usage', f'{n_workers}/{cpu_count()}') 85 | ]) 86 | 87 | pool = Pool(processes=n_workers) 88 | dataset = [] 89 | 90 | for i, (item_id, length) in enumerate(pool.imap_unordered(process_wav, wav_files), 1): 91 | dataset += [(item_id, length)] 92 | bar = progbar(i, len(wav_files)) 93 | message = f'{bar} {i}/{len(wav_files)} ' 94 | stream(message) 95 | 96 | with open(paths.data/'dataset.pkl', 'wb') as f: 97 | pickle.dump(dataset, f) 98 | 99 | print('\n\nCompleted. Ready to run "python train_tacotron.py" or "python train_wavernn.py". 
\n') 100 | -------------------------------------------------------------------------------- /pretrained/ljspeech.tacotron.r2.180k.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatchord/WaveRNN/83c08fdcd625be56244f4145f41500468974b144/pretrained/ljspeech.tacotron.r2.180k.zip -------------------------------------------------------------------------------- /pretrained/ljspeech.wavernn.mol.800k.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatchord/WaveRNN/83c08fdcd625be56244f4145f41500468974b144/pretrained/ljspeech.wavernn.mol.800k.zip -------------------------------------------------------------------------------- /quick_start.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.fatchord_version import WaveRNN 3 | from utils import hparams as hp 4 | from utils.text.symbols import symbols 5 | from models.tacotron import Tacotron 6 | import argparse 7 | from utils.text import text_to_sequence 8 | from utils.display import save_attention, simple_table 9 | import zipfile, os 10 | 11 | 12 | os.makedirs('quick_start/tts_weights/', exist_ok=True) 13 | os.makedirs('quick_start/voc_weights/', exist_ok=True) 14 | 15 | zip_ref = zipfile.ZipFile('pretrained/ljspeech.wavernn.mol.800k.zip', 'r') 16 | zip_ref.extractall('quick_start/voc_weights/') 17 | zip_ref.close() 18 | 19 | zip_ref = zipfile.ZipFile('pretrained/ljspeech.tacotron.r2.180k.zip', 'r') 20 | zip_ref.extractall('quick_start/tts_weights/') 21 | zip_ref.close() 22 | 23 | 24 | if __name__ == "__main__": 25 | 26 | # Parse Arguments 27 | parser = argparse.ArgumentParser(description='TTS Generator') 28 | parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!') 29 | parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation (lower quality)') 30 | parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slower Unbatched Generation (better quality)') 31 | parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment') 32 | parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', 33 | help='The file to use for the hyperparameters') 34 | args = parser.parse_args() 35 | 36 | hp.configure(args.hp_file) # Load hparams from file 37 | 38 | parser.set_defaults(batched=True) 39 | parser.set_defaults(input_text=None) 40 | 41 | batched = args.batched 42 | input_text = args.input_text 43 | 44 | if not args.force_cpu and torch.cuda.is_available(): 45 | device = torch.device('cuda') 46 | else: 47 | device = torch.device('cpu') 48 | print('Using device:', device) 49 | 50 | print('\nInitialising WaveRNN Model...\n') 51 | 52 | # Instantiate WaveRNN Model 53 | voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims, 54 | fc_dims=hp.voc_fc_dims, 55 | bits=hp.bits, 56 | pad=hp.voc_pad, 57 | upsample_factors=hp.voc_upsample_factors, 58 | feat_dims=hp.num_mels, 59 | compute_dims=hp.voc_compute_dims, 60 | res_out_dims=hp.voc_res_out_dims, 61 | res_blocks=hp.voc_res_blocks, 62 | hop_length=hp.hop_length, 63 | sample_rate=hp.sample_rate, 64 | mode='MOL').to(device) 65 | 66 | voc_model.load('quick_start/voc_weights/latest_weights.pyt') 67 | 68 | print('\nInitialising Tacotron Model...\n') 69 | 70 | # Instantiate Tacotron Model 71 | tts_model = 
Tacotron(embed_dims=hp.tts_embed_dims, 72 | num_chars=len(symbols), 73 | encoder_dims=hp.tts_encoder_dims, 74 | decoder_dims=hp.tts_decoder_dims, 75 | n_mels=hp.num_mels, 76 | fft_bins=hp.num_mels, 77 | postnet_dims=hp.tts_postnet_dims, 78 | encoder_K=hp.tts_encoder_K, 79 | lstm_dims=hp.tts_lstm_dims, 80 | postnet_K=hp.tts_postnet_K, 81 | num_highways=hp.tts_num_highways, 82 | dropout=hp.tts_dropout, 83 | stop_threshold=hp.tts_stop_threshold).to(device) 84 | 85 | 86 | tts_model.load('quick_start/tts_weights/latest_weights.pyt') 87 | 88 | if input_text: 89 | inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)] 90 | else: 91 | with open('sentences.txt') as f: 92 | inputs = [text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f] 93 | 94 | voc_k = voc_model.get_step() // 1000 95 | tts_k = tts_model.get_step() // 1000 96 | 97 | r = tts_model.r 98 | 99 | simple_table([('WaveRNN', str(voc_k) + 'k'), 100 | (f'Tacotron(r={r})', str(tts_k) + 'k'), 101 | ('Generation Mode', 'Batched' if batched else 'Unbatched'), 102 | ('Target Samples', 11_000 if batched else 'N/A'), 103 | ('Overlap Samples', 550 if batched else 'N/A')]) 104 | 105 | for i, x in enumerate(inputs, 1): 106 | 107 | print(f'\n| Generating {i}/{len(inputs)}') 108 | _, m, attention = tts_model.generate(x) 109 | 110 | if input_text: 111 | save_path = f'quick_start/__input_{input_text[:10]}_{tts_k}k.wav' 112 | else: 113 | save_path = f'quick_start/{i}_batched{str(batched)}_{tts_k}k.wav' 114 | 115 | # save_attention(attention, save_path) 116 | 117 | m = torch.tensor(m).unsqueeze(0) 118 | m = (m + 4) / 8 119 | 120 | voc_model.generate(m, save_path, batched, 11_000, 550, hp.mu_law) 121 | 122 | print('\n\nDone.\n') 123 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.22.0 2 | librosa==0.6.3 3 | matplotlib 4 | unidecode 5 | inflect 6 | nltk -------------------------------------------------------------------------------- /sentences.txt: -------------------------------------------------------------------------------- 1 | Scientists at the CERN laboratory say they have discovered a new particle. 2 | There's a way to measure the acute emotional intelligence that has never gone out of style. 3 | President Trump met with other leaders at the Group of 20 conference. 4 | The Senate's bill to repeal and replace the Affordable Care-Act is now imperiled. 5 | Generative adversarial network or variational auto-encoder. 6 | Basilar membrane and otolaryngology are not auto-correlations. 
7 | -------------------------------------------------------------------------------- /train_tacotron.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | import torch.nn.functional as F 4 | from utils import hparams as hp 5 | from utils.display import * 6 | from utils.dataset import get_tts_datasets 7 | from utils.text.symbols import symbols 8 | from utils.paths import Paths 9 | from models.tacotron import Tacotron 10 | import argparse 11 | from utils import data_parallel_workaround 12 | import os 13 | from pathlib import Path 14 | import time 15 | import numpy as np 16 | import sys 17 | from utils.checkpoints import save_checkpoint, restore_checkpoint 18 | 19 | 20 | def np_now(x: torch.Tensor): return x.detach().cpu().numpy() 21 | 22 | 23 | def main(): 24 | # Parse Arguments 25 | parser = argparse.ArgumentParser(description='Train Tacotron TTS') 26 | parser.add_argument('--force_train', '-f', action='store_true', help='Forces the model to train past total steps') 27 | parser.add_argument('--force_gta', '-g', action='store_true', help='Force the model to create GTA features') 28 | parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment') 29 | parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters') 30 | args = parser.parse_args() 31 | 32 | hp.configure(args.hp_file) # Load hparams from file 33 | paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id) 34 | 35 | force_train = args.force_train 36 | force_gta = args.force_gta 37 | 38 | if not args.force_cpu and torch.cuda.is_available(): 39 | device = torch.device('cuda') 40 | for session in hp.tts_schedule: 41 | _, _, _, batch_size = session 42 | if batch_size % torch.cuda.device_count() != 0: 43 | raise ValueError('`batch_size` must be evenly divisible by n_gpus!') 44 | else: 45 | device = torch.device('cpu') 46 | print('Using device:', device) 47 | 48 | # Instantiate Tacotron Model 49 | print('\nInitialising Tacotron Model...\n') 50 | model = Tacotron(embed_dims=hp.tts_embed_dims, 51 | num_chars=len(symbols), 52 | encoder_dims=hp.tts_encoder_dims, 53 | decoder_dims=hp.tts_decoder_dims, 54 | n_mels=hp.num_mels, 55 | fft_bins=hp.num_mels, 56 | postnet_dims=hp.tts_postnet_dims, 57 | encoder_K=hp.tts_encoder_K, 58 | lstm_dims=hp.tts_lstm_dims, 59 | postnet_K=hp.tts_postnet_K, 60 | num_highways=hp.tts_num_highways, 61 | dropout=hp.tts_dropout, 62 | stop_threshold=hp.tts_stop_threshold).to(device) 63 | 64 | optimizer = optim.Adam(model.parameters()) 65 | restore_checkpoint('tts', paths, model, optimizer, create_if_missing=True) 66 | 67 | if not force_gta: 68 | for i, session in enumerate(hp.tts_schedule): 69 | current_step = model.get_step() 70 | 71 | r, lr, max_step, batch_size = session 72 | 73 | training_steps = max_step - current_step 74 | 75 | # Do we need to change to the next session? 76 | if current_step >= max_step: 77 | # Are there no further sessions than the current one? 78 | if i == len(hp.tts_schedule)-1: 79 | # There are no more sessions. Check if we force training. 80 | if force_train: 81 | # Don't finish the loop - train forever 82 | training_steps = 999_999_999 83 | else: 84 | # We have completed training. 
Breaking is same as continue 85 | break 86 | else: 87 | # There is a following session, go to it 88 | continue 89 | 90 | model.r = r 91 | 92 | simple_table([(f'Steps with r={r}', str(training_steps//1000) + 'k Steps'), 93 | ('Batch Size', batch_size), 94 | ('Learning Rate', lr), 95 | ('Outputs/Step (r)', model.r)]) 96 | 97 | train_set, attn_example = get_tts_datasets(paths.data, batch_size, r) 98 | tts_train_loop(paths, model, optimizer, train_set, lr, training_steps, attn_example) 99 | 100 | print('Training Complete.') 101 | print('To continue training increase tts_total_steps in hparams.py or use --force_train\n') 102 | 103 | 104 | print('Creating Ground Truth Aligned Dataset...\n') 105 | 106 | train_set, attn_example = get_tts_datasets(paths.data, 8, model.r) 107 | create_gta_features(model, train_set, paths.gta) 108 | 109 | print('\n\nYou can now train WaveRNN on GTA features - use python train_wavernn.py --gta\n') 110 | 111 | 112 | def tts_train_loop(paths: Paths, model: Tacotron, optimizer, train_set, lr, train_steps, attn_example): 113 | device = next(model.parameters()).device # use same device as model parameters 114 | 115 | for g in optimizer.param_groups: g['lr'] = lr 116 | 117 | total_iters = len(train_set) 118 | epochs = train_steps // total_iters + 1 119 | 120 | for e in range(1, epochs+1): 121 | 122 | start = time.time() 123 | running_loss = 0 124 | 125 | # Perform 1 epoch 126 | for i, (x, m, ids, _) in enumerate(train_set, 1): 127 | 128 | x, m = x.to(device), m.to(device) 129 | 130 | # Parallelize model onto GPUS using workaround due to python bug 131 | if device.type == 'cuda' and torch.cuda.device_count() > 1: 132 | m1_hat, m2_hat, attention = data_parallel_workaround(model, x, m) 133 | else: 134 | m1_hat, m2_hat, attention = model(x, m) 135 | 136 | m1_loss = F.l1_loss(m1_hat, m) 137 | m2_loss = F.l1_loss(m2_hat, m) 138 | 139 | loss = m1_loss + m2_loss 140 | 141 | optimizer.zero_grad() 142 | loss.backward() 143 | if hp.tts_clip_grad_norm is not None: 144 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.tts_clip_grad_norm) 145 | if np.isnan(grad_norm): 146 | print('grad_norm was NaN!') 147 | 148 | optimizer.step() 149 | 150 | running_loss += loss.item() 151 | avg_loss = running_loss / i 152 | 153 | speed = i / (time.time() - start) 154 | 155 | step = model.get_step() 156 | k = step // 1000 157 | 158 | if step % hp.tts_checkpoint_every == 0: 159 | ckpt_name = f'taco_step{k}K' 160 | save_checkpoint('tts', paths, model, optimizer, 161 | name=ckpt_name, is_silent=True) 162 | 163 | if attn_example in ids: 164 | idx = ids.index(attn_example) 165 | save_attention(np_now(attention[idx][:, :160]), paths.tts_attention/f'{step}') 166 | save_spectrogram(np_now(m2_hat[idx]), paths.tts_mel_plot/f'{step}', 600) 167 | 168 | msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:#.4} | {speed:#.2} steps/s | Step: {k}k | ' 169 | stream(msg) 170 | 171 | # Must save latest optimizer state to ensure that resuming training 172 | # doesn't produce artifacts 173 | save_checkpoint('tts', paths, model, optimizer, is_silent=True) 174 | model.log(paths.tts_log, msg) 175 | print(' ') 176 | 177 | 178 | def create_gta_features(model: Tacotron, train_set, save_path: Path): 179 | device = next(model.parameters()).device # use same device as model parameters 180 | 181 | iters = len(train_set) 182 | 183 | for i, (x, mels, ids, mel_lens) in enumerate(train_set, 1): 184 | 185 | x, mels = x.to(device), mels.to(device) 186 | 187 | with torch.no_grad(): _, gta, _ = model(x, mels) 
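        # gta is the model's post-net output; since this script builds Tacotron with
        # fft_bins == num_mels, it is a refined mel rather than a true linear
        # spectrogram, and the (mel + 4) / 8 below maps it from the model's [-4, 4]
        # range back to the [0, 1] range of the preprocessed features WaveRNN trains on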
188 | 189 | gta = gta.cpu().numpy() 190 | 191 | for j, item_id in enumerate(ids): 192 | mel = gta[j][:, :mel_lens[j]] 193 | mel = (mel + 4) / 8 194 | np.save(save_path/f'{item_id}.npy', mel, allow_pickle=False) 195 | 196 | bar = progbar(i, iters) 197 | msg = f'{bar} {i}/{iters} Batches ' 198 | stream(msg) 199 | 200 | 201 | if __name__ == "__main__": 202 | main() 203 | -------------------------------------------------------------------------------- /train_wavernn.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | from torch import optim 5 | import torch.nn.functional as F 6 | from utils.display import stream, simple_table 7 | from utils.dataset import get_vocoder_datasets 8 | from utils.distribution import discretized_mix_logistic_loss 9 | from utils import hparams as hp 10 | from models.fatchord_version import WaveRNN 11 | from gen_wavernn import gen_testset 12 | from utils.paths import Paths 13 | import argparse 14 | from utils import data_parallel_workaround 15 | from utils.checkpoints import save_checkpoint, restore_checkpoint 16 | 17 | 18 | def main(): 19 | 20 | # Parse Arguments 21 | parser = argparse.ArgumentParser(description='Train WaveRNN Vocoder') 22 | parser.add_argument('--lr', '-l', type=float, help='[float] override hparams.py learning rate') 23 | parser.add_argument('--batch_size', '-b', type=int, help='[int] override hparams.py batch size') 24 | parser.add_argument('--force_train', '-f', action='store_true', help='Forces the model to train past total steps') 25 | parser.add_argument('--gta', '-g', action='store_true', help='train wavernn on GTA features') 26 | parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment') 27 | parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters') 28 | args = parser.parse_args() 29 | 30 | hp.configure(args.hp_file) # load hparams from file 31 | if args.lr is None: 32 | args.lr = hp.voc_lr 33 | if args.batch_size is None: 34 | args.batch_size = hp.voc_batch_size 35 | 36 | paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id) 37 | 38 | batch_size = args.batch_size 39 | force_train = args.force_train 40 | train_gta = args.gta 41 | lr = args.lr 42 | 43 | if not args.force_cpu and torch.cuda.is_available(): 44 | device = torch.device('cuda') 45 | if batch_size % torch.cuda.device_count() != 0: 46 | raise ValueError('`batch_size` must be evenly divisible by n_gpus!') 47 | else: 48 | device = torch.device('cpu') 49 | print('Using device:', device) 50 | 51 | print('\nInitialising Model...\n') 52 | 53 | # Instantiate WaveRNN Model 54 | voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims, 55 | fc_dims=hp.voc_fc_dims, 56 | bits=hp.bits, 57 | pad=hp.voc_pad, 58 | upsample_factors=hp.voc_upsample_factors, 59 | feat_dims=hp.num_mels, 60 | compute_dims=hp.voc_compute_dims, 61 | res_out_dims=hp.voc_res_out_dims, 62 | res_blocks=hp.voc_res_blocks, 63 | hop_length=hp.hop_length, 64 | sample_rate=hp.sample_rate, 65 | mode=hp.voc_mode).to(device) 66 | 67 | # Check to make sure the hop length is correctly factorised 68 | assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length 69 | 70 | optimizer = optim.Adam(voc_model.parameters()) 71 | restore_checkpoint('voc', paths, voc_model, optimizer, create_if_missing=True) 72 | 73 | train_set, test_set = get_vocoder_datasets(paths.data, batch_size, train_gta) 74 | 75 | total_steps = 
10_000_000 if force_train else hp.voc_total_steps 76 | 77 | simple_table([('Remaining', str((total_steps - voc_model.get_step())//1000) + 'k Steps'), 78 | ('Batch Size', batch_size), 79 | ('LR', lr), 80 | ('Sequence Len', hp.voc_seq_len), 81 | ('GTA Train', train_gta)]) 82 | 83 | loss_func = F.cross_entropy if voc_model.mode == 'RAW' else discretized_mix_logistic_loss 84 | 85 | voc_train_loop(paths, voc_model, loss_func, optimizer, train_set, test_set, lr, total_steps) 86 | 87 | print('Training Complete.') 88 | print('To continue training increase voc_total_steps in hparams.py or use --force_train') 89 | 90 | 91 | def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer, train_set, test_set, lr, total_steps): 92 | # Use same device as model parameters 93 | device = next(model.parameters()).device 94 | 95 | for g in optimizer.param_groups: g['lr'] = lr 96 | 97 | total_iters = len(train_set) 98 | epochs = (total_steps - model.get_step()) // total_iters + 1 99 | 100 | for e in range(1, epochs + 1): 101 | 102 | start = time.time() 103 | running_loss = 0. 104 | 105 | for i, (x, y, m) in enumerate(train_set, 1): 106 | x, m, y = x.to(device), m.to(device), y.to(device) 107 | 108 | # Parallelize model onto GPUS using workaround due to python bug 109 | if device.type == 'cuda' and torch.cuda.device_count() > 1: 110 | y_hat = data_parallel_workaround(model, x, m) 111 | else: 112 | y_hat = model(x, m) 113 | 114 | if model.mode == 'RAW': 115 | y_hat = y_hat.transpose(1, 2).unsqueeze(-1) 116 | 117 | elif model.mode == 'MOL': 118 | y = y.float() 119 | 120 | y = y.unsqueeze(-1) 121 | 122 | 123 | loss = loss_func(y_hat, y) 124 | 125 | optimizer.zero_grad() 126 | loss.backward() 127 | if hp.voc_clip_grad_norm is not None: 128 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.voc_clip_grad_norm) 129 | if np.isnan(grad_norm): 130 | print('grad_norm was NaN!') 131 | optimizer.step() 132 | 133 | running_loss += loss.item() 134 | avg_loss = running_loss / i 135 | 136 | speed = i / (time.time() - start) 137 | 138 | step = model.get_step() 139 | k = step // 1000 140 | 141 | if step % hp.voc_checkpoint_every == 0: 142 | gen_testset(model, test_set, hp.voc_gen_at_checkpoint, hp.voc_gen_batched, 143 | hp.voc_target, hp.voc_overlap, paths.voc_output) 144 | ckpt_name = f'wave_step{k}K' 145 | save_checkpoint('voc', paths, model, optimizer, 146 | name=ckpt_name, is_silent=True) 147 | 148 | msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | ' 149 | stream(msg) 150 | 151 | # Must save latest optimizer state to ensure that resuming training 152 | # doesn't produce artifacts 153 | save_checkpoint('voc', paths, model, optimizer, is_silent=True) 154 | model.log(paths.voc_log, msg) 155 | print(' ') 156 | 157 | 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Make it explicit that we do it the Python 3 way 2 | from __future__ import absolute_import, division, print_function, unicode_literals 3 | from builtins import * 4 | 5 | import sys 6 | import torch 7 | import re 8 | 9 | from importlib.util import spec_from_file_location, module_from_spec 10 | from pathlib import Path 11 | from typing import Union 12 | 13 | # Credit: Ryuichi Yamamoto (https://github.com/r9y9/wavenet_vocoder/blob/1717f145c8f8c0f3f85ccdf346b5209fa2e1c920/train.py#L599) 14 | # 
Modified by: Ryan Butler (https://github.com/TheButlah) 15 | # workaround for https://github.com/pytorch/pytorch/issues/15716 16 | # the idea is to return outputs and replicas explicitly, so that making pytorch 17 | # not to release the nodes (this is a pytorch bug though) 18 | 19 | _output_ref = None 20 | _replicas_ref = None 21 | 22 | def data_parallel_workaround(model, *input): 23 | global _output_ref 24 | global _replicas_ref 25 | device_ids = list(range(torch.cuda.device_count())) 26 | output_device = device_ids[0] 27 | replicas = torch.nn.parallel.replicate(model, device_ids) 28 | # input.shape = (num_args, batch, ...) 29 | inputs = torch.nn.parallel.scatter(input, device_ids) 30 | # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...) 31 | replicas = replicas[:len(inputs)] 32 | outputs = torch.nn.parallel.parallel_apply(replicas, inputs) 33 | y_hat = torch.nn.parallel.gather(outputs, output_device) 34 | _output_ref = outputs 35 | _replicas_ref = replicas 36 | return y_hat 37 | 38 | 39 | ###### Deal with hparams import that has to be configured at runtime ###### 40 | class __HParams: 41 | """Manages the hyperparams pseudo-module""" 42 | def __init__(self, path: Union[str, Path]=None): 43 | """Constructs the hyperparameters from a path to a python module. If 44 | `path` is None, will raise an AttributeError whenever its attributes 45 | are accessed. Otherwise, configures self based on `path`.""" 46 | if path is None: 47 | self._configured = False 48 | else: 49 | self.configure(path) 50 | 51 | def __getattr__(self, item): 52 | if not self.is_configured(): 53 | raise AttributeError("HParams not configured yet. Call self.configure()") 54 | else: 55 | return super().__getattr__(item) 56 | 57 | def configure(self, path: Union[str, Path]): 58 | """Configures hparams by copying over atrributes from a module with the 59 | given path. Raises an exception if already configured.""" 60 | if self.is_configured(): 61 | raise RuntimeError("Cannot reconfigure hparams!") 62 | 63 | ###### Check for proper path ###### 64 | if not isinstance(path, Path): 65 | path = Path(path).expanduser() 66 | if not path.exists(): 67 | raise FileNotFoundError(f"Could not find hparams file {path}") 68 | elif path.suffix != ".py": 69 | raise ValueError("`path` must be a python file") 70 | 71 | ###### Load in attributes from module ###### 72 | m = _import_from_file("hparams", path) 73 | 74 | reg = re.compile(r"^__.+__$") # Matches magic methods 75 | for name, value in m.__dict__.items(): 76 | if reg.match(name): 77 | # Skip builtins 78 | continue 79 | if name in self.__dict__: 80 | # Cannot overwrite already existing attributes 81 | raise AttributeError( 82 | f"module at `path` cannot contain attribute {name} as it " 83 | "overwrites an attribute of the same name in utils.hparams") 84 | # Fair game to copy over the attribute 85 | self.__setattr__(name, value) 86 | 87 | self._configured = True 88 | 89 | def is_configured(self): 90 | return self._configured 91 | 92 | hparams = __HParams() 93 | 94 | 95 | def _import_from_file(name, path: Path): 96 | """Programmatically returns a module object from a filepath""" 97 | if not Path(path).exists(): 98 | raise FileNotFoundError('"%s" doesn\'t exist!' 
% path) 99 | spec = spec_from_file_location(name, path) 100 | if spec is None: 101 | raise ValueError('could not load module from "%s"' % path) 102 | m = module_from_spec(spec) 103 | spec.loader.exec_module(m) 104 | return m -------------------------------------------------------------------------------- /utils/checkpoints.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.paths import Paths 3 | from models.tacotron import Tacotron 4 | 5 | 6 | def get_checkpoint_paths(checkpoint_type: str, paths: Paths): 7 | """ 8 | Returns the correct checkpointing paths 9 | depending on whether model is Vocoder or TTS 10 | 11 | Args: 12 | checkpoint_type: Either 'voc' or 'tts' 13 | paths: Paths object 14 | """ 15 | if checkpoint_type is 'tts': 16 | weights_path = paths.tts_latest_weights 17 | optim_path = paths.tts_latest_optim 18 | checkpoint_path = paths.tts_checkpoints 19 | elif checkpoint_type is 'voc': 20 | weights_path = paths.voc_latest_weights 21 | optim_path = paths.voc_latest_optim 22 | checkpoint_path = paths.voc_checkpoints 23 | else: 24 | raise NotImplementedError 25 | 26 | return weights_path, optim_path, checkpoint_path 27 | 28 | 29 | def save_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *, 30 | name=None, is_silent=False): 31 | """Saves the training session to disk. 32 | 33 | Args: 34 | paths: Provides information about the different paths to use. 35 | model: A `Tacotron` or `WaveRNN` model to save the parameters and buffers from. 36 | optimizer: An optmizer to save the state of (momentum, etc). 37 | name: If provided, will name to a checkpoint with the given name. Note 38 | that regardless of whether this is provided or not, this function 39 | will always update the files specified in `paths` that give the 40 | location of the latest weights and optimizer state. Saving 41 | a named checkpoint happens in addition to this update. 42 | """ 43 | def helper(path_dict, is_named): 44 | s = 'named' if is_named else 'latest' 45 | num_exist = sum(p.exists() for p in path_dict.values()) 46 | 47 | if num_exist not in (0,2): 48 | # Checkpoint broken 49 | raise FileNotFoundError( 50 | f'We expected either both or no files in the {s} checkpoint to ' 51 | 'exist, but instead we got exactly one!') 52 | 53 | if num_exist == 0: 54 | if not is_silent: print(f'Creating {s} checkpoint...') 55 | for p in path_dict.values(): 56 | p.parent.mkdir(parents=True, exist_ok=True) 57 | else: 58 | if not is_silent: print(f'Saving to existing {s} checkpoint...') 59 | 60 | if not is_silent: print(f'Saving {s} weights: {path_dict["w"]}') 61 | model.save(path_dict['w']) 62 | if not is_silent: print(f'Saving {s} optimizer state: {path_dict["o"]}') 63 | torch.save(optimizer.state_dict(), path_dict['o']) 64 | 65 | weights_path, optim_path, checkpoint_path = \ 66 | get_checkpoint_paths(checkpoint_type, paths) 67 | 68 | latest_paths = {'w': weights_path, 'o': optim_path} 69 | helper(latest_paths, False) 70 | 71 | if name: 72 | named_paths = { 73 | 'w': checkpoint_path/f'{name}_weights.pyt', 74 | 'o': checkpoint_path/f'{name}_optim.pyt', 75 | } 76 | helper(named_paths, True) 77 | 78 | 79 | def restore_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *, 80 | name=None, create_if_missing=False): 81 | """Restores from a training session saved to disk. 82 | 83 | NOTE: The optimizer's state is placed on the same device as it's model 84 | parameters. Therefore, be sure you have done `model.to(device)` before 85 | calling this method. 
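    A minimal usage sketch, mirroring how train_tacotron.py calls this (it
    assumes `paths`, `device` and the hparams module are already configured):

        model = Tacotron(...).to(device)
        optimizer = optim.Adam(model.parameters())
        restore_checkpoint('tts', paths, model, optimizer, create_if_missing=True)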
86 | 87 | Args: 88 | paths: Provides information about the different paths to use. 89 | model: A `Tacotron` or `WaveRNN` model to save the parameters and buffers from. 90 | optimizer: An optmizer to save the state of (momentum, etc). 91 | name: If provided, will restore from a checkpoint with the given name. 92 | Otherwise, will restore from the latest weights and optimizer state 93 | as specified in `paths`. 94 | create_if_missing: If `True`, will create the checkpoint if it doesn't 95 | yet exist, as well as update the files specified in `paths` that 96 | give the location of the current latest weights and optimizer state. 97 | If `False` and the checkpoint doesn't exist, will raise a 98 | `FileNotFoundError`. 99 | """ 100 | 101 | weights_path, optim_path, checkpoint_path = \ 102 | get_checkpoint_paths(checkpoint_type, paths) 103 | 104 | if name: 105 | path_dict = { 106 | 'w': checkpoint_path/f'{name}_weights.pyt', 107 | 'o': checkpoint_path/f'{name}_optim.pyt', 108 | } 109 | s = 'named' 110 | else: 111 | path_dict = { 112 | 'w': weights_path, 113 | 'o': optim_path 114 | } 115 | s = 'latest' 116 | 117 | num_exist = sum(p.exists() for p in path_dict.values()) 118 | if num_exist == 2: 119 | # Checkpoint exists 120 | print(f'Restoring from {s} checkpoint...') 121 | print(f'Loading {s} weights: {path_dict["w"]}') 122 | model.load(path_dict['w']) 123 | print(f'Loading {s} optimizer state: {path_dict["o"]}') 124 | optimizer.load_state_dict(torch.load(path_dict['o'])) 125 | elif create_if_missing: 126 | save_checkpoint(checkpoint_type, paths, model, optimizer, name=name, is_silent=False) 127 | else: 128 | raise FileNotFoundError(f'The {s} checkpoint could not be found!') -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | from torch.utils.data.sampler import Sampler 6 | from utils.dsp import * 7 | from utils import hparams as hp 8 | from utils.text import text_to_sequence 9 | from utils.paths import Paths 10 | from pathlib import Path 11 | 12 | 13 | ################################################################################### 14 | # WaveRNN/Vocoder Dataset ######################################################### 15 | ################################################################################### 16 | 17 | 18 | class VocoderDataset(Dataset): 19 | def __init__(self, path: Path, dataset_ids, train_gta=False): 20 | self.metadata = dataset_ids 21 | self.mel_path = path/'gta' if train_gta else path/'mel' 22 | self.quant_path = path/'quant' 23 | 24 | 25 | def __getitem__(self, index): 26 | item_id = self.metadata[index] 27 | m = np.load(self.mel_path/f'{item_id}.npy') 28 | x = np.load(self.quant_path/f'{item_id}.npy') 29 | return m, x 30 | 31 | def __len__(self): 32 | return len(self.metadata) 33 | 34 | 35 | def get_vocoder_datasets(path: Path, batch_size, train_gta): 36 | 37 | with open(path/'dataset.pkl', 'rb') as f: 38 | dataset = pickle.load(f) 39 | 40 | dataset_ids = [x[0] for x in dataset] 41 | 42 | random.seed(1234) 43 | random.shuffle(dataset_ids) 44 | 45 | test_ids = dataset_ids[-hp.voc_test_samples:] 46 | train_ids = dataset_ids[:-hp.voc_test_samples] 47 | 48 | train_dataset = VocoderDataset(path, train_ids, train_gta) 49 | test_dataset = VocoderDataset(path, test_ids, train_gta) 50 | 51 | train_set = DataLoader(train_dataset, 52 | 
collate_fn=collate_vocoder, 53 | batch_size=batch_size, 54 | num_workers=2, 55 | shuffle=True, 56 | pin_memory=True) 57 | 58 | test_set = DataLoader(test_dataset, 59 | batch_size=1, 60 | num_workers=1, 61 | shuffle=False, 62 | pin_memory=True) 63 | 64 | return train_set, test_set 65 | 66 | 67 | def collate_vocoder(batch): 68 | mel_win = hp.voc_seq_len // hp.hop_length + 2 * hp.voc_pad 69 | max_offsets = [x[0].shape[-1] -2 - (mel_win + 2 * hp.voc_pad) for x in batch] 70 | mel_offsets = [np.random.randint(0, offset) for offset in max_offsets] 71 | sig_offsets = [(offset + hp.voc_pad) * hp.hop_length for offset in mel_offsets] 72 | 73 | mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win] for i, x in enumerate(batch)] 74 | 75 | labels = [x[1][sig_offsets[i]:sig_offsets[i] + hp.voc_seq_len + 1] for i, x in enumerate(batch)] 76 | 77 | mels = np.stack(mels).astype(np.float32) 78 | labels = np.stack(labels).astype(np.int64) 79 | 80 | mels = torch.tensor(mels) 81 | labels = torch.tensor(labels).long() 82 | 83 | x = labels[:, :hp.voc_seq_len] 84 | y = labels[:, 1:] 85 | 86 | bits = 16 if hp.voc_mode == 'MOL' else hp.bits 87 | 88 | x = label_2_float(x.float(), bits) 89 | 90 | if hp.voc_mode == 'MOL': 91 | y = label_2_float(y.float(), bits) 92 | 93 | return x, y, mels 94 | 95 | 96 | ################################################################################### 97 | # Tacotron/TTS Dataset ############################################################ 98 | ################################################################################### 99 | 100 | 101 | def get_tts_datasets(path: Path, batch_size, r): 102 | 103 | with open(path/'dataset.pkl', 'rb') as f: 104 | dataset = pickle.load(f) 105 | 106 | dataset_ids = [] 107 | mel_lengths = [] 108 | 109 | for (item_id, len) in dataset: 110 | if len <= hp.tts_max_mel_len: 111 | dataset_ids += [item_id] 112 | mel_lengths += [len] 113 | 114 | with open(path/'text_dict.pkl', 'rb') as f: 115 | text_dict = pickle.load(f) 116 | 117 | train_dataset = TTSDataset(path, dataset_ids, text_dict) 118 | 119 | sampler = None 120 | 121 | if hp.tts_bin_lengths: 122 | sampler = BinnedLengthSampler(mel_lengths, batch_size, batch_size * 3) 123 | 124 | train_set = DataLoader(train_dataset, 125 | collate_fn=lambda batch: collate_tts(batch, r), 126 | batch_size=batch_size, 127 | sampler=sampler, 128 | num_workers=1, 129 | pin_memory=True) 130 | 131 | longest = mel_lengths.index(max(mel_lengths)) 132 | 133 | # Used to evaluate attention during training process 134 | attn_example = dataset_ids[longest] 135 | 136 | # print(attn_example) 137 | 138 | return train_set, attn_example 139 | 140 | 141 | class TTSDataset(Dataset): 142 | def __init__(self, path: Path, dataset_ids, text_dict): 143 | self.path = path 144 | self.metadata = dataset_ids 145 | self.text_dict = text_dict 146 | 147 | def __getitem__(self, index): 148 | item_id = self.metadata[index] 149 | x = text_to_sequence(self.text_dict[item_id], hp.tts_cleaner_names) 150 | mel = np.load(self.path/'mel'/f'{item_id}.npy') 151 | mel_len = mel.shape[-1] 152 | return x, mel, item_id, mel_len 153 | 154 | def __len__(self): 155 | return len(self.metadata) 156 | 157 | 158 | def pad1d(x, max_len): 159 | return np.pad(x, (0, max_len - len(x)), mode='constant') 160 | 161 | 162 | def pad2d(x, max_len): 163 | return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode='constant') 164 | 165 | 166 | def collate_tts(batch, r): 167 | 168 | x_lens = [len(x[0]) for x in batch] 169 | max_x_len = max(x_lens) 170 | 171 | chars = [pad1d(x[0], 
max_x_len) for x in batch] 172 | chars = np.stack(chars) 173 | 174 | spec_lens = [x[1].shape[-1] for x in batch] 175 | max_spec_len = max(spec_lens) + 1 176 | if max_spec_len % r != 0: 177 | max_spec_len += r - max_spec_len % r 178 | 179 | mel = [pad2d(x[1], max_spec_len) for x in batch] 180 | mel = np.stack(mel) 181 | 182 | ids = [x[2] for x in batch] 183 | mel_lens = [x[3] for x in batch] 184 | 185 | chars = torch.tensor(chars).long() 186 | mel = torch.tensor(mel) 187 | 188 | # scale spectrograms to -4 <--> 4 189 | mel = (mel * 8.) - 4. 190 | return chars, mel, ids, mel_lens 191 | 192 | 193 | class BinnedLengthSampler(Sampler): 194 | def __init__(self, lengths, batch_size, bin_size): 195 | _, self.idx = torch.sort(torch.tensor(lengths).long()) 196 | self.batch_size = batch_size 197 | self.bin_size = bin_size 198 | assert self.bin_size % self.batch_size == 0 199 | 200 | def __iter__(self): 201 | # Need to change to numpy since there's a bug in random.shuffle(tensor) 202 | # TODO: Post an issue on pytorch repo 203 | idx = self.idx.numpy() 204 | bins = [] 205 | 206 | for i in range(len(idx) // self.bin_size): 207 | this_bin = idx[i * self.bin_size:(i + 1) * self.bin_size] 208 | random.shuffle(this_bin) 209 | bins += [this_bin] 210 | 211 | random.shuffle(bins) 212 | binned_idx = np.stack(bins).reshape(-1) 213 | 214 | if len(binned_idx) < len(idx): 215 | last_bin = idx[len(binned_idx):] 216 | random.shuffle(last_bin) 217 | binned_idx = np.concatenate([binned_idx, last_bin]) 218 | 219 | return iter(torch.tensor(binned_idx).long()) 220 | 221 | def __len__(self): 222 | return len(self.idx) 223 | -------------------------------------------------------------------------------- /utils/display.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | mpl.use('agg') # Use non-interactive backend by default 3 | import matplotlib.pyplot as plt 4 | import time 5 | import numpy as np 6 | import sys 7 | 8 | 9 | def progbar(i, n, size=16): 10 | done = (i * size) // n 11 | bar = '' 12 | for i in range(size): 13 | bar += '█' if i <= done else '░' 14 | return bar 15 | 16 | 17 | def stream(message): 18 | sys.stdout.write(f"\r{message}") 19 | 20 | 21 | def simple_table(item_tuples): 22 | 23 | border_pattern = '+---------------------------------------' 24 | whitespace = ' ' 25 | 26 | headings, cells, = [], [] 27 | 28 | for item in item_tuples: 29 | 30 | heading, cell = str(item[0]), str(item[1]) 31 | 32 | pad_head = True if len(heading) < len(cell) else False 33 | 34 | pad = abs(len(heading) - len(cell)) 35 | pad = whitespace[:pad] 36 | 37 | pad_left = pad[:len(pad)//2] 38 | pad_right = pad[len(pad)//2:] 39 | 40 | if pad_head: 41 | heading = pad_left + heading + pad_right 42 | else: 43 | cell = pad_left + cell + pad_right 44 | 45 | headings += [heading] 46 | cells += [cell] 47 | 48 | border, head, body = '', '', '' 49 | 50 | for i in range(len(item_tuples)): 51 | 52 | temp_head = f'| {headings[i]} ' 53 | temp_body = f'| {cells[i]} ' 54 | 55 | border += border_pattern[:len(temp_head)] 56 | head += temp_head 57 | body += temp_body 58 | 59 | if i == len(item_tuples) - 1: 60 | head += '|' 61 | body += '|' 62 | border += '+' 63 | 64 | print(border) 65 | print(head) 66 | print(border) 67 | print(body) 68 | print(border) 69 | print(' ') 70 | 71 | 72 | def time_since(started): 73 | elapsed = time.time() - started 74 | m = int(elapsed // 60) 75 | s = int(elapsed % 60) 76 | if m >= 60: 77 | h = int(m // 60) 78 | m = m % 60 79 | return f'{h}h {m}m {s}s' 80 | else: 
81 | return f'{m}m {s}s' 82 | 83 | 84 | def save_attention(attn, path): 85 | fig = plt.figure(figsize=(12, 6)) 86 | plt.imshow(attn.T, interpolation='nearest', aspect='auto') 87 | fig.savefig(path.parent/f'{path.stem}.png', bbox_inches='tight') 88 | plt.close(fig) 89 | 90 | 91 | def save_spectrogram(M, path, length=None): 92 | M = np.flip(M, axis=0) 93 | if length: M = M[:, :length] 94 | fig = plt.figure(figsize=(12, 6)) 95 | plt.imshow(M, interpolation='nearest', aspect='auto') 96 | fig.savefig(f'{path}.png', bbox_inches='tight') 97 | plt.close(fig) 98 | 99 | 100 | def plot(array): 101 | mpl.interactive(True) 102 | fig = plt.figure(figsize=(30, 5)) 103 | ax = fig.add_subplot(111) 104 | ax.xaxis.label.set_color('grey') 105 | ax.yaxis.label.set_color('grey') 106 | ax.xaxis.label.set_fontsize(23) 107 | ax.yaxis.label.set_fontsize(23) 108 | ax.tick_params(axis='x', colors='grey', labelsize=23) 109 | ax.tick_params(axis='y', colors='grey', labelsize=23) 110 | plt.plot(array) 111 | mpl.interactive(False) 112 | 113 | 114 | def plot_spec(M): 115 | mpl.interactive(True) 116 | M = np.flip(M, axis=0) 117 | plt.figure(figsize=(18,4)) 118 | plt.imshow(M, interpolation='nearest', aspect='auto') 119 | plt.show() 120 | mpl.interactive(False) 121 | 122 | -------------------------------------------------------------------------------- /utils/distribution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def log_sum_exp(x): 7 | """ numerically stable log_sum_exp implementation that prevents overflow """ 8 | # TF ordering 9 | axis = len(x.size()) - 1 10 | m, _ = torch.max(x, dim=axis) 11 | m2, _ = torch.max(x, dim=axis, keepdim=True) 12 | return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) 13 | 14 | 15 | # It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py 16 | def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, 17 | log_scale_min=None, reduce=True): 18 | if log_scale_min is None: 19 | log_scale_min = float(np.log(1e-14)) 20 | y_hat = y_hat.permute(0,2,1) 21 | assert y_hat.dim() == 3 22 | assert y_hat.size(1) % 3 == 0 23 | nr_mix = y_hat.size(1) // 3 24 | 25 | # (B x T x C) 26 | y_hat = y_hat.transpose(1, 2) 27 | 28 | # unpack parameters. (B, T, num_mixtures) x 3 29 | logit_probs = y_hat[:, :, :nr_mix] 30 | means = y_hat[:, :, nr_mix:2 * nr_mix] 31 | log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min) 32 | 33 | # B x T x 1 -> B x T x num_mixtures 34 | y = y.expand_as(means) 35 | 36 | centered_y = y - means 37 | inv_stdv = torch.exp(-log_scales) 38 | plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1)) 39 | cdf_plus = torch.sigmoid(plus_in) 40 | min_in = inv_stdv * (centered_y - 1. / (num_classes - 1)) 41 | cdf_min = torch.sigmoid(min_in) 42 | 43 | # log probability for edge case of 0 (before scaling) 44 | # equivalent: torch.log(F.sigmoid(plus_in)) 45 | log_cdf_plus = plus_in - F.softplus(plus_in) 46 | 47 | # log probability for edge case of 255 (before scaling) 48 | # equivalent: (1 - F.sigmoid(min_in)).log() 49 | log_one_minus_cdf_min = -F.softplus(min_in) 50 | 51 | # probability for all other cases 52 | cdf_delta = cdf_plus - cdf_min 53 | 54 | mid_in = inv_stdv * centered_y 55 | # log probability in the center of the bin, to be used in extreme cases 56 | # (not actually used in our code) 57 | log_pdf_mid = mid_in - log_scales - 2. 
* F.softplus(mid_in) 58 | 59 | # tf equivalent 60 | """ 61 | log_probs = tf.where(x < -0.999, log_cdf_plus, 62 | tf.where(x > 0.999, log_one_minus_cdf_min, 63 | tf.where(cdf_delta > 1e-5, 64 | tf.log(tf.maximum(cdf_delta, 1e-12)), 65 | log_pdf_mid - np.log(127.5)))) 66 | """ 67 | # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value 68 | # for num_classes=65536 case? 1e-7? not sure.. 69 | inner_inner_cond = (cdf_delta > 1e-5).float() 70 | 71 | inner_inner_out = inner_inner_cond * \ 72 | torch.log(torch.clamp(cdf_delta, min=1e-12)) + \ 73 | (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2)) 74 | inner_cond = (y > 0.999).float() 75 | inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out 76 | cond = (y < -0.999).float() 77 | log_probs = cond * log_cdf_plus + (1. - cond) * inner_out 78 | 79 | log_probs = log_probs + F.log_softmax(logit_probs, -1) 80 | 81 | if reduce: 82 | return -torch.mean(log_sum_exp(log_probs)) 83 | else: 84 | return -log_sum_exp(log_probs).unsqueeze(-1) 85 | 86 | 87 | def sample_from_discretized_mix_logistic(y, log_scale_min=None): 88 | """ 89 | Sample from discretized mixture of logistic distributions 90 | Args: 91 | y (Tensor): B x C x T 92 | log_scale_min (float): Log scale minimum value 93 | Returns: 94 | Tensor: sample in range of [-1, 1]. 95 | """ 96 | if log_scale_min is None: 97 | log_scale_min = float(np.log(1e-14)) 98 | assert y.size(1) % 3 == 0 99 | nr_mix = y.size(1) // 3 100 | 101 | # B x T x C 102 | y = y.transpose(1, 2) 103 | logit_probs = y[:, :, :nr_mix] 104 | 105 | # sample mixture indicator from softmax 106 | temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) 107 | temp = logit_probs.data - torch.log(- torch.log(temp)) 108 | _, argmax = temp.max(dim=-1) 109 | 110 | # (B, T) -> (B, T, nr_mix) 111 | one_hot = F.one_hot(argmax, nr_mix).float() 112 | # select logistic parameters 113 | means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1) 114 | log_scales = torch.clamp(torch.sum( 115 | y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min) 116 | # sample from logistic & clip to interval 117 | # we don't actually round to the nearest 8bit value when sampling 118 | u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) 119 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u)) 120 | 121 | x = torch.clamp(torch.clamp(x, min=-1.), max=1.) 122 | 123 | return x 124 | 125 | ''' 126 | def to_one_hot(tensor, n, fill_with=1.): 127 | # we perform one hot encore with respect to the last axis 128 | one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_() 129 | if tensor.is_cuda: 130 | one_hot = one_hot.cuda() 131 | one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) 132 | return one_hot''' 133 | -------------------------------------------------------------------------------- /utils/dsp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import librosa 4 | from utils import hparams as hp 5 | from scipy.signal import lfilter 6 | 7 | 8 | def label_2_float(x, bits): 9 | return 2 * x / (2**bits - 1.) - 1. 10 | 11 | 12 | def float_2_label(x, bits): 13 | assert abs(x).max() <= 1.0 14 | x = (x + 1.) 
* (2**bits - 1) / 2 15 | return x.clip(0, 2**bits - 1) 16 | 17 | 18 | def load_wav(path): 19 | return librosa.load(path, sr=hp.sample_rate)[0] 20 | 21 | 22 | def save_wav(x, path): 23 | librosa.output.write_wav(path, x.astype(np.float32), sr=hp.sample_rate) 24 | 25 | 26 | def split_signal(x): 27 | unsigned = x + 2**15 28 | coarse = unsigned // 256 29 | fine = unsigned % 256 30 | return coarse, fine 31 | 32 | 33 | def combine_signal(coarse, fine): 34 | return coarse * 256 + fine - 2**15 35 | 36 | 37 | def encode_16bits(x): 38 | return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16) 39 | 40 | 41 | def linear_to_mel(spectrogram): 42 | return librosa.feature.melspectrogram( 43 | S=spectrogram, sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, fmin=hp.fmin) 44 | 45 | ''' 46 | def build_mel_basis(): 47 | return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels, fmin=hp.fmin) 48 | ''' 49 | 50 | def normalize(S): 51 | return np.clip((S - hp.min_level_db) / -hp.min_level_db, 0, 1) 52 | 53 | 54 | def denormalize(S): 55 | return (np.clip(S, 0, 1) * -hp.min_level_db) + hp.min_level_db 56 | 57 | 58 | def amp_to_db(x): 59 | return 20 * np.log10(np.maximum(1e-5, x)) 60 | 61 | 62 | def db_to_amp(x): 63 | return np.power(10.0, x * 0.05) 64 | 65 | 66 | def spectrogram(y): 67 | D = stft(y) 68 | S = amp_to_db(np.abs(D)) - hp.ref_level_db 69 | return normalize(S) 70 | 71 | 72 | def melspectrogram(y): 73 | D = stft(y) 74 | S = amp_to_db(linear_to_mel(np.abs(D))) 75 | return normalize(S) 76 | 77 | 78 | def stft(y): 79 | return librosa.stft( 80 | y=y, 81 | n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length) 82 | 83 | 84 | def pre_emphasis(x): 85 | return lfilter([1, -hp.preemphasis], [1], x) 86 | 87 | 88 | def de_emphasis(x): 89 | return lfilter([1], [1, -hp.preemphasis], x) 90 | 91 | 92 | def encode_mu_law(x, mu): 93 | mu = mu - 1 94 | fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu) 95 | return np.floor((fx + 1) / 2 * mu + 0.5) 96 | 97 | 98 | def decode_mu_law(y, mu, from_labels=True): 99 | # TODO: get rid of log2 - makes no sense 100 | if from_labels: y = label_2_float(y, math.log2(mu)) 101 | mu = mu - 1 102 | x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1) 103 | return x 104 | 105 | def reconstruct_waveform(mel, n_iter=32): 106 | """Uses Griffin-Lim phase reconstruction to convert from a normalized 107 | mel spectrogram back into a waveform.""" 108 | denormalized = denormalize(mel) 109 | amp_mel = db_to_amp(denormalized) 110 | S = librosa.feature.inverse.mel_to_stft( 111 | amp_mel, power=1, sr=hp.sample_rate, 112 | n_fft=hp.n_fft, fmin=hp.fmin) 113 | wav = librosa.core.griffinlim( 114 | S, n_iter=n_iter, 115 | hop_length=hp.hop_length, win_length=hp.win_length) 116 | return wav 117 | 118 | -------------------------------------------------------------------------------- /utils/files.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Union 3 | 4 | def get_files(path: Union[str, Path], extension='.wav'): 5 | if isinstance(path, str): path = Path(path).expanduser().resolve() 6 | return list(path.rglob(f'*{extension}')) 7 | -------------------------------------------------------------------------------- /utils/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | 5 | class Paths: 6 | """Manages and configures the paths used by WaveRNN, Tacotron, and the data.""" 7 | def __init__(self, 
data_path, voc_id, tts_id): 8 | self.base = Path(__file__).parent.parent.expanduser().resolve() 9 | 10 | # Data Paths 11 | self.data = Path(data_path).expanduser().resolve() 12 | self.quant = self.data/'quant' 13 | self.mel = self.data/'mel' 14 | self.gta = self.data/'gta' 15 | 16 | # WaveRNN/Vocoder Paths 17 | self.voc_checkpoints = self.base/'checkpoints'/f'{voc_id}.wavernn' 18 | self.voc_latest_weights = self.voc_checkpoints/'latest_weights.pyt' 19 | self.voc_latest_optim = self.voc_checkpoints/'latest_optim.pyt' 20 | self.voc_output = self.base/'model_outputs'/f'{voc_id}.wavernn' 21 | self.voc_step = self.voc_checkpoints/'step.npy' 22 | self.voc_log = self.voc_checkpoints/'log.txt' 23 | 24 | # Tactron/TTS Paths 25 | self.tts_checkpoints = self.base/'checkpoints'/f'{tts_id}.tacotron' 26 | self.tts_latest_weights = self.tts_checkpoints/'latest_weights.pyt' 27 | self.tts_latest_optim = self.tts_checkpoints/'latest_optim.pyt' 28 | self.tts_output = self.base/'model_outputs'/f'{tts_id}.tacotron' 29 | self.tts_step = self.tts_checkpoints/'step.npy' 30 | self.tts_log = self.tts_checkpoints/'log.txt' 31 | self.tts_attention = self.tts_checkpoints/'attention' 32 | self.tts_mel_plot = self.tts_checkpoints/'mel_plots' 33 | 34 | self.create_paths() 35 | 36 | def create_paths(self): 37 | os.makedirs(self.data, exist_ok=True) 38 | os.makedirs(self.quant, exist_ok=True) 39 | os.makedirs(self.mel, exist_ok=True) 40 | os.makedirs(self.gta, exist_ok=True) 41 | os.makedirs(self.voc_checkpoints, exist_ok=True) 42 | os.makedirs(self.voc_output, exist_ok=True) 43 | os.makedirs(self.tts_checkpoints, exist_ok=True) 44 | os.makedirs(self.tts_output, exist_ok=True) 45 | os.makedirs(self.tts_attention, exist_ok=True) 46 | os.makedirs(self.tts_mel_plot, exist_ok=True) 47 | 48 | def get_tts_named_weights(self, name): 49 | """Gets the path for the weights in a named tts checkpoint.""" 50 | return self.tts_checkpoints/f'{name}_weights.pyt' 51 | 52 | def get_tts_named_optim(self, name): 53 | """Gets the path for the optimizer state in a named tts checkpoint.""" 54 | return self.tts_checkpoints/f'{name}_optim.pyt' 55 | 56 | def get_voc_named_weights(self, name): 57 | """Gets the path for the weights in a named voc checkpoint.""" 58 | return self.voc_checkpoints/f'{name}_weights.pyt' 59 | 60 | def get_voc_named_optim(self, name): 61 | """Gets the path for the optimizer state in a named voc checkpoint.""" 62 | return self.voc_checkpoints/f'{name}_optim.pyt' 63 | 64 | 65 | -------------------------------------------------------------------------------- /utils/text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /utils/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from utils.text import cleaners 4 | from utils.text.symbols import symbols 5 | 6 | 7 | # Mappings from symbol to numeric ID and vice versa: 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | # Regular expression matching text enclosed in curly braces: 12 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 13 | 14 | 15 | def text_to_sequence(text, cleaner_names): 16 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 17 | 18 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 19 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 20 | 21 | Args: 22 | text: string to convert to a sequence 23 | cleaner_names: names of the cleaner functions to run the text through 24 | 25 | Returns: 26 | List of integers corresponding to the symbols in the text 27 | ''' 28 | sequence = [] 29 | 30 | # Check for curly braces and treat their contents as ARPAbet: 31 | while len(text): 32 | m = _curly_re.match(text) 33 | if not m: 34 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 35 | break 36 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 37 | sequence += _arpabet_to_sequence(m.group(2)) 38 | text = m.group(3) 39 | 40 | return sequence 41 | 42 | 43 | def sequence_to_text(sequence): 44 | '''Converts a sequence of IDs back to a string''' 45 | result = '' 46 | for symbol_id in sequence: 47 | if symbol_id in _id_to_symbol: 48 | s = _id_to_symbol[symbol_id] 49 | # Enclose ARPAbet back in curly braces: 50 | if len(s) > 1 and s[0] == '@': 51 | s = '{%s}' % s[1:] 52 | result += s 53 | return result.replace('}{', ' ') 54 | 55 | 56 | def _clean_text(text, cleaner_names): 57 | for name in cleaner_names: 58 | cleaner = getattr(cleaners, name) 59 | if not cleaner: 60 | raise Exception('Unknown cleaner: %s' % name) 61 | text = cleaner(text) 62 | return text 63 | 64 | 65 | def _symbols_to_sequence(symbols): 66 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 67 | 68 | 69 | def _arpabet_to_sequence(text): 70 | return _symbols_to_sequence(['@' + s for s in text.split()]) 71 | 72 | 73 | def _should_keep_symbol(s): 74 | return s in _symbol_to_id and s is not '_' and s is not '~' 75 | -------------------------------------------------------------------------------- /utils/text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. 
"transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | import re 16 | from unidecode import unidecode 17 | from .numbers import normalize_numbers 18 | 19 | 20 | # Regular expression matching whitespace: 21 | _whitespace_re = re.compile(r'\s+') 22 | 23 | # List of (regular expression, replacement) pairs for abbreviations: 24 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 25 | ('mrs', 'misess'), 26 | ('mr', 'mister'), 27 | ('dr', 'doctor'), 28 | ('st', 'saint'), 29 | ('co', 'company'), 30 | ('jr', 'junior'), 31 | ('maj', 'major'), 32 | ('gen', 'general'), 33 | ('drs', 'doctors'), 34 | ('rev', 'reverend'), 35 | ('lt', 'lieutenant'), 36 | ('hon', 'honorable'), 37 | ('sgt', 'sergeant'), 38 | ('capt', 'captain'), 39 | ('esq', 'esquire'), 40 | ('ltd', 'limited'), 41 | ('col', 'colonel'), 42 | ('ft', 'fort'), 43 | ]] 44 | 45 | 46 | def expand_abbreviations(text): 47 | for regex, replacement in _abbreviations: 48 | text = re.sub(regex, replacement, text) 49 | return text 50 | 51 | 52 | def expand_numbers(text): 53 | return normalize_numbers(text) 54 | 55 | 56 | def lowercase(text): 57 | return text.lower() 58 | 59 | 60 | def collapse_whitespace(text): 61 | return re.sub(_whitespace_re, ' ', text) 62 | 63 | 64 | def convert_to_ascii(text): 65 | return unidecode(text) 66 | 67 | 68 | def basic_cleaners(text): 69 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 70 | text = lowercase(text) 71 | text = collapse_whitespace(text) 72 | return text 73 | 74 | 75 | def transliteration_cleaners(text): 76 | '''Pipeline for non-English text that transliterates to ASCII.''' 77 | text = convert_to_ascii(text) 78 | text = lowercase(text) 79 | text = collapse_whitespace(text) 80 | return text 81 | 82 | 83 | def english_cleaners(text): 84 | '''Pipeline for English text, including number and abbreviation expansion.''' 85 | text = convert_to_ascii(text) 86 | text = lowercase(text) 87 | text = expand_numbers(text) 88 | text = expand_abbreviations(text) 89 | text = collapse_whitespace(text) 90 | return text 91 | -------------------------------------------------------------------------------- /utils/text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 14 | ] 15 | 16 | _valid_symbol_set = set(valid_symbols) 17 | 18 | 19 | class CMUDict: 20 | '''Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 21 | def __init__(self, file_or_path, keep_ambiguous=True): 22 | if isinstance(file_or_path, str): 23 | with open(file_or_path, encoding='latin-1') as f: 24 | entries = _parse_cmudict(f) 25 | else: 26 | entries = _parse_cmudict(file_or_path) 27 | if not keep_ambiguous: 28 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 29 | self._entries = entries 30 | 31 | 32 | def __len__(self): 33 | return len(self._entries) 34 | 35 | 36 | def lookup(self, word): 37 | '''Returns list of ARPAbet pronunciations of the given word.''' 38 | return self._entries.get(word.upper()) 39 | 40 | 41 | 42 | _alt_re = re.compile(r'\([0-9]+\)') 43 | 44 | 45 | def _parse_cmudict(file): 46 | cmudict = {} 47 | for line in file: 48 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 49 | parts = line.split(' ') 50 | word = re.sub(_alt_re, '', parts[0]) 51 | pronunciation = _get_pronunciation(parts[1]) 52 | if pronunciation: 53 | if word in cmudict: 54 | cmudict[word].append(pronunciation) 55 | else: 56 | cmudict[word] = [pronunciation] 57 | return cmudict 58 | 59 | 60 | def _get_pronunciation(s): 61 | parts = s.strip().split(' ') 62 | for part in parts: 63 | if part not in _valid_symbol_set: 64 | return None 65 | return ' '.join(parts) 66 | -------------------------------------------------------------------------------- /utils/text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 13 | _number_re = re.compile(r'[0-9]+') 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(',', '') 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace('.', ' point ') 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split('.') 27 | if len(parts) > 2: 28 | return match + ' dollars' # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 33 | cent_unit = 'cent' if cents == 1 else 'cents' 34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 37 | return '%s %s' % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = 'cent' if cents == 1 else 'cents' 40 | return '%s %s' % (cents, cent_unit) 41 | else: 42 | return 'zero dollars' 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return 'two thousand' 54 | elif num > 2000 and num < 2010: 55 | return 'two thousand ' + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + ' hundred' 58 | else: 59 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 60 | else: 61 | return _inflect.number_to_words(num, andword='') 62 | 63 | 64 | def normalize_numbers(text): 65 | text = 
re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r'\1 pounds', text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /utils/text/recipes.py: -------------------------------------------------------------------------------- 1 | from utils.files import get_files 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | 6 | def ljspeech(path: Union[str, Path]): 7 | csv_file = get_files(path, extension='.csv') 8 | 9 | assert len(csv_file) == 1 10 | 11 | text_dict = {} 12 | 13 | with open(csv_file[0], encoding='utf-8') as f : 14 | for line in f : 15 | split = line.split('|') 16 | text_dict[split[0]] = split[-1] 17 | 18 | return text_dict -------------------------------------------------------------------------------- /utils/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' 7 | from utils.text import cmudict 8 | 9 | _pad = '_' 10 | _punctuation = '!\'(),.:;? ' 11 | _special = '-' 12 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 13 | 14 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 15 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 16 | 17 | # Export all symbols: 18 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet 19 | --------------------------------------------------------------------------------
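For reference, a minimal usage sketch of the text front-end defined above (illustrative only, not a file in this repository; it assumes the repository root is on PYTHONPATH so that utils.text imports, and that the inflect and unidecode packages from requirements.txt are installed):

# Minimal sketch of the text pipeline: cleaners -> symbol IDs -> back to text.
from utils.text import text_to_sequence, sequence_to_text

# 'english_cleaners' transliterates to ASCII, lowercases, and expands numbers and
# abbreviations; anything inside {curly braces} is passed through as ARPAbet phonemes.
# The sentence below is illustrative (adapted from the text_to_sequence docstring).
line = "Turn left on {HH AW1 S S T AH0 N} Street, about 2 miles ahead."

ids = text_to_sequence(line, ['english_cleaners'])
print(ids)                    # integer symbol IDs for the model's text input
print(sequence_to_text(ids))  # cleaned text with the {ARPAbet} section restored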