├── .gitignore
├── LICENSE
├── README.md
├── audio.py
├── dataset.py
├── distributions.py
├── hparams.py
├── inputs
│   └── sample.wav
├── loss_function.py
├── lrschedule.py
├── model.py
├── preprocess.py
├── requirements.txt
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Gary Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WaveRNN-Pytorch
2 | This repository contains Fatcord's [alternative](https://github.com/fatchord/WaveRNN) WaveRNN (faster training): a fast-training, low-GPU-memory implementation of the WaveRNN vocoder.
3 |
4 | # Model Pruning and Real Time CPU Inference
5 | See geneing's awesome fork, which adds model pruning, export to C++, and real-time inference on CPU: https://github.com/geneing/WaveRNN-Pytorch.
6 |
7 |
8 | # Highlights
9 | * supports raw audio waveform modelling (via a single Beta distribution)
10 | * relatively fast synthesis without much optimization yet (around 2000 samples/sec on a GTX 1060 Ti, 16 GB RAM, i5 processor)
11 | * supports Fatcord's original quantized (9-bit) waveform modelling
12 |
13 | # Audio Samples
14 | 1. [Obama & Bernie Sanders](https://soundcloud.com/gary-wang-23/sets/obama_bernie_fun) See this repo in action!
15 |
16 | 2. [10-bit audio](https://soundcloud.com/gary-wang-23/sets/wavernn-pytorch-10-bit-raw-audio-200k) on held-out testing data from LJSpeech. This model sounds similar to and trains almost as fast as the 9-bit model; higher bit depth is generally preferable.
17 |
18 | 3. [9-bit audio](https://soundcloud.com/gary-wang-23/sets/wave_rnn_9_bit_11k_step) on held-out testing data from LJSpeech. This model trains the fastest (this sample is from around 130 epochs).
19 |
20 | 4. [Single Beta distribution](https://soundcloud.com/gary-wang-23/sets/wavernn-samples) on held-out testing data from LJSpeech. This is trained with the single Beta distribution.
21 |
22 | # Pretrained Checkpoints
23 | 1. [Single Beta Distribution](https://drive.google.com/open?id=138i0MtEkDqLM6fmBniQloEMtMlCHgJha) trained for 112k steps. Make sure to change `hparams.input_type` to `raw`.
24 | 2. [9-bit quantized audio](https://drive.google.com/open?id=114Xk3P9dD-_e2W8jmiKSpOX1UGb7qem3) trained for 11k steps (around 130 epochs); it can be trained further. Make sure to change `hparams.input_type` to `bits`.
25 | 3. [10-bit quantized audio](https://drive.google.com/open?id=1djWm62tHIndopyS5spkHf68lI6-h5a3H). To ensure your model is built properly, download the matching `hparams.py` [here](https://drive.google.com/open?id=1nXSW4u01bEbUkRW4Vd3IQ6soBAXPg6aw) and either replace your local `hparams.py` with it or note and apply any differences.
26 |
27 |
28 |
29 |
30 | # Requirements
31 | * Python 3
32 | * CUDA >=8.0
33 | * PyTorch >= v0.4.1
34 |
35 | # Installation
36 | Ensure the above requirements are met.
37 |
38 | ```
39 | git clone https://github.com/G-Wang/WaveRNN-Pytorch.git
40 | cd WaveRNN-Pytorch
41 | pip install -r requirements.txt
42 | ```
43 |
44 | # Usage
45 | ## 1. Adjusting Hyperparameters
46 | Before running scripts, one can adjust hyperparameters in `hparams.py`.
47 |
48 | Some hyperparameters that you might want to adjust:
49 | * `fix_learning_rate` The model is robust enough to learn well with a fixed learning rate of `1e-4`. I suggest trying this setting for the fastest training; you can decrease it down to `5e-6` for final refinement. Set this to `None` to train with a learning rate schedule instead (see the example excerpt after this list).
50 | * `input_type` (best performing ones are currently `bits` and `raw`, see `hparams.py` for more details)
51 | * `batch_size`
52 | * `save_every_step` (checkpoint saving frequency)
53 | * `evaluate_every_step` (evaluation frequency)
54 | * `seq_len_factor` (controls the sequence length of training audio; longer sequences use more GPU memory)
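
For example, a typical edit of `hparams.py` for raw-audio training with a fixed learning rate might look like the excerpt below (the values are illustrative, taken from this repo's defaults, not tuned recommendations):

```
# hparams.py (excerpt)
input_type = 'raw'          # raw audio with a Beta/Gaussian output
batch_size = 32             # reduce if you run out of GPU memory
seq_len_factor = 5          # seq_len = seq_len_factor * hop_size
fix_learning_rate = 1e-4    # set to None to use the lr_schedule_type schedule
save_every_step = 10000
evaluate_every_step = 5000
```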
55 | ## 2. Preprocessing
56 | This script processes raw wav files into corresponding mel-spectrogram and wav numpy files according to the audio-processing hyperparameters.
57 |
58 | Example usage:
59 | ```
60 | python preprocess.py /path/to/my/wav/files
61 | ```
62 | This will process all the `.wav` files in the folder `/path/to/my/wav/files` and save the processed outputs in the default local directory `data_dir`.
63 |
64 | You can pass `--output-dir` to specify a different directory for the processed outputs, as in the example below.
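
For example (the paths here are placeholders):

```
python preprocess.py /path/to/my/wav/files --output-dir my_data_dir
```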
65 |
66 | ## 3. Training
67 | Start the training process. Checkpoints are stored by default in the local directory `checkpoints`.
68 | The script will automatically save a checkpoint when terminated with `Ctrl + C`.
69 |
70 |
71 | Example 1: starting a new model for training
72 | ```
73 | python train.py data_dir
74 | ```
75 | `data_dir` is the directory containing the processed files.
76 |
77 | Example 2: resuming training from a checkpoint
78 | ```
79 | python train.py data_dir --checkpoint=checkpoints/checkpoint0010000.pth
80 | ```
81 | Evaluation `.wav` files and plots are saved in `checkpoints/eval`.
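
To generate audio outside of the training loop, a minimal synthesis sketch using this repo's `build_model`, `Model.generate` and `audio.save_wav` might look like the following (the checkpoint and mel paths are placeholders; it assumes a CUDA device and that `hparams.py` matches the checkpoint you load):

```
import numpy as np
import torch

from model import build_model
from audio import save_wav

# build the network and load trained weights (checkpoint path is a placeholder)
model = build_model().cuda()
checkpoint = torch.load("checkpoints/checkpoint_step000112000.pth")
model.load_state_dict(checkpoint["state_dict"])

# load a mel spectrogram produced by preprocess.py (shape: [num_mels, T])
mel = np.load("data_dir/test/test_0_mel.npy")

# autoregressively generate the waveform and write it to disk
wav = model.generate(mel)
save_wav(wav.flatten(), "generated.wav")
```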
82 |
83 | # WIP
84 | - [ ] optimize learning rate schedule
85 | - [ ] optimize training hyperparameters (seq_len and batch_size)
86 | - [ ] batch generation for synthesis speedup
87 | - [ ] model pruning
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
--------------------------------------------------------------------------------
/audio.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import librosa.filters
3 | import math
4 | import numpy as np
5 | from scipy import signal
6 | from hparams import hparams
7 | from scipy.io import wavfile
8 |
9 | # r9r9 preprocessing
10 | import lws
11 |
12 |
13 | def load_wav(path):
14 | return librosa.load(path, sr=hparams.sample_rate)[0]
15 |
16 | def save_wav(wav, path):
17 | wav = wav * 32767 / max(0.01, np.max(np.abs(wav)))
18 | wavfile.write(path, hparams.sample_rate, wav.astype(np.int16))
19 |
20 |
21 | def preemphasis(x):
22 | from nnmnkwii.preprocessing import preemphasis
23 | return preemphasis(x, hparams.preemphasis)
24 |
25 |
26 | def inv_preemphasis(x):
27 | from nnmnkwii.preprocessing import inv_preemphasis
28 | return inv_preemphasis(x, hparams.preemphasis)
29 |
30 |
31 | def spectrogram(y):
32 | D = _lws_processor().stft(preemphasis(y)).T
33 | S = _amp_to_db(np.abs(D)) - hparams.ref_level_db
34 | return _normalize(S)
35 |
36 |
37 | def inv_spectrogram(spectrogram):
38 | '''Converts spectrogram to waveform using librosa'''
39 | S = _db_to_amp(_denormalize(spectrogram) + hparams.ref_level_db) # Convert back to linear
40 | processor = _lws_processor()
41 | D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
42 | y = processor.istft(D).astype(np.float32)
43 | return inv_preemphasis(y)
44 |
45 |
46 | def melspectrogram(y):
47 | D = _lws_processor().stft(preemphasis(y)).T
48 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hparams.ref_level_db
49 | if not hparams.allow_clipping_in_normalization:
50 | assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
51 | return _normalize(S)
52 |
53 |
54 | def _lws_processor():
55 | return lws.lws(hparams.fft_size, hparams.hop_size, mode="speech")
56 |
57 |
58 | # Conversions:
59 |
60 |
61 | _mel_basis = None
62 |
63 |
64 | def _linear_to_mel(spectrogram):
65 | global _mel_basis
66 | if _mel_basis is None:
67 | _mel_basis = _build_mel_basis()
68 | return np.dot(_mel_basis, spectrogram)
69 |
70 |
71 | def _build_mel_basis():
72 | if hparams.fmax is not None:
73 | assert hparams.fmax <= hparams.sample_rate // 2
74 | return librosa.filters.mel(hparams.sample_rate, hparams.fft_size,
75 | fmin=hparams.fmin, fmax=hparams.fmax,
76 | n_mels=hparams.num_mels)
77 |
78 |
79 | def _amp_to_db(x):
80 | min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
81 | return 20 * np.log10(np.maximum(min_level, x))
82 |
83 |
84 | def _db_to_amp(x):
85 | return np.power(10.0, x * 0.05)
86 |
87 |
88 | def _normalize(S):
89 | return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1)
90 |
91 |
92 | def _denormalize(S):
93 | return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db
94 |
95 |
96 | # Fatcord's preprocessing
97 | def quantize(x):
98 | """Quantize a [-1, 1] waveform to integers in [0, 2**bits - 1].
99 | 
100 | """
101 | quant = (x + 1.) * (2**hparams.bits - 1) / 2
102 | return quant.astype(np.int)
103 |
104 |
105 | # testing
106 | def test_everything():
107 | wav = np.random.randn(12000,)
108 | mel = melspectrogram(wav)
109 | spec = spectrogram(wav)
110 | quant = quantize(wav)
111 | print(wav.shape, mel.shape, spec.shape, quant.shape)
112 | print(quant.max(), quant.min(), mel.max(), mel.min(), spec.max(), spec.min())
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import os
4 |
5 | import torch
6 | from torch.utils.data import DataLoader, Dataset
7 | from hparams import hparams as hp
8 | from utils import mulaw_quantize, inv_mulaw_quantize
9 | import pickle
10 |
11 |
12 | class AudiobookDataset(Dataset):
13 | def __init__(self, data_path):
14 | self.path = os.path.join(data_path, "")
15 | with open(os.path.join(self.path,'dataset_ids.pkl'), 'rb') as f:
16 | self.metadata = pickle.load(f)
17 | self.mel_path = os.path.join(data_path, "mel")
18 | self.wav_path = os.path.join(data_path, "wav")
19 | self.test_path = os.path.join(data_path, "test")
20 |
21 | def __getitem__(self, index):
22 | file = self.metadata[index]
23 | m = np.load(os.path.join(self.mel_path,'{}.npy'.format(file)))
24 | x = np.load(os.path.join(self.wav_path,'{}.npy'.format(file)))
25 | return m, x
26 |
27 | def __len__(self):
28 | return len(self.metadata)
29 |
30 |
31 | def raw_collate(batch) :
32 | """collate function used for raw waveforms, such as beta/gaussian/mixture-of-logistics outputs
33 | """
34 |
35 | pad = 2
36 | mel_win = hp.seq_len // hp.hop_size + 2 * pad
37 | max_offsets = [x[0].shape[-1] - (mel_win + 2 * pad) for x in batch]
38 | mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
39 | sig_offsets = [(offset + pad) * hp.hop_size for offset in mel_offsets]
40 |
41 | mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win] \
42 | for i, x in enumerate(batch)]
43 |
44 | coarse = [x[1][sig_offsets[i]:sig_offsets[i] + hp.seq_len + 1] \
45 | for i, x in enumerate(batch)]
46 |
47 | mels = np.stack(mels).astype(np.float32)
48 | coarse = np.stack(coarse).astype(np.float32)
49 |
50 | mels = torch.FloatTensor(mels)
51 | coarse = torch.FloatTensor(coarse)
52 |
53 | x_input = coarse[:,:hp.seq_len]
54 |
55 | y_coarse = coarse[:, 1:]
56 |
57 | return x_input, mels, y_coarse
58 |
59 |
60 |
61 | def discrete_collate(batch) :
62 | """collate function used for discrete wav output, such as 9-bit, mulaw-discrete, etc.
63 | """
64 |
65 | pad = 2
66 | mel_win = hp.seq_len // hp.hop_size + 2 * pad
67 | max_offsets = [x[0].shape[-1] - (mel_win + 2 * pad) for x in batch]
68 | mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
69 | sig_offsets = [(offset + pad) * hp.hop_size for offset in mel_offsets]
70 |
71 | mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win] \
72 | for i, x in enumerate(batch)]
73 |
74 | coarse = [x[1][sig_offsets[i]:sig_offsets[i] + hp.seq_len + 1] \
75 | for i, x in enumerate(batch)]
76 |
77 | mels = np.stack(mels).astype(np.float32)
78 | coarse = np.stack(coarse).astype(np.int64)
79 |
80 | mels = torch.FloatTensor(mels)
81 | coarse = torch.LongTensor(coarse)
82 | if hp.input_type == 'bits':
83 | x_input = 2 * coarse[:, :hp.seq_len].float() / (2**hp.bits - 1.) - 1.
84 | elif hp.input_type == 'mulaw':
85 | x_input = inv_mulaw_quantize(coarse[:, :hp.seq_len], hp.mulaw_quantize_channels)
86 |
87 | y_coarse = coarse[:, 1:]
88 |
89 | return x_input, mels, y_coarse
90 |
91 |
92 | def no_test_raw_collate():
93 | import matplotlib.pyplot as plt
94 | from test_utils import plot, plot_spec
95 | data_id_path = "data_dir/"
96 | data_path = "data_dir/"
97 | print(hp.seq_len)
98 |
99 | with open('{}dataset_ids.pkl'.format(data_id_path), 'rb') as f:
100 | dataset_ids = pickle.load(f)
101 | dataset = AudiobookDataset(data_path)
102 | print(len(dataset))
103 |
104 | data_loader = DataLoader(dataset, collate_fn=raw_collate, batch_size=32,
105 | num_workers=0, shuffle=True)
106 |
107 | x, m, y = next(iter(data_loader))
108 | print(x.shape, m.shape, y.shape)
109 | plot(x.numpy()[0])
110 | plot(y.numpy()[0])
111 | plot_spec(m.numpy()[0])
112 |
113 |
114 | def test_discrete_collate():
115 | import matplotlib.pyplot as plt
116 | from test_utils import plot, plot_spec
117 | data_id_path = "data_dir/"
118 | data_path = "data_dir/"
119 | print(hp.seq_len)
120 |
121 | with open('{}dataset_ids.pkl'.format(data_id_path), 'rb') as f:
122 | dataset_ids = pickle.load(f)
123 | dataset = AudiobookDataset(data_path)
124 | print(len(dataset))
125 |
126 | data_loader = DataLoader(dataset, collate_fn=discrete_collate, batch_size=32,
127 | num_workers=0, shuffle=True)
128 |
129 | x, m, y = next(iter(data_loader))
130 | print(x.shape, m.shape, y.shape)
131 | plot(x.numpy()[0])
132 | plot(y.numpy()[0])
133 | plot_spec(m.numpy()[0])
134 |
135 |
136 |
137 | def no_test_dataset():
138 | data_id_path = "data_dir/"
139 | data_path = "data_dir/"
140 | print(hp.seq_len)
141 | dataset = AudiobookDataset(data_path)
142 |
--------------------------------------------------------------------------------
/distributions.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from torch.distributions import Beta, Normal
8 | from hparams import hparams as hp
9 |
10 | def sample_from_beta_dist(y_hat):
11 | """
12 | y_hat (batch_size x seq_len x 2):
13 |
14 | """
15 | # take exponential to ensure positivity
16 | loc_y = y_hat.exp()
17 | alpha = loc_y[:,:,0].unsqueeze(-1)
18 | beta = loc_y[:,:,1].unsqueeze(-1)
19 | dist = Beta(alpha, beta)
20 | sample = dist.sample()
21 | # rescale sample from [0,1] to [-1, 1]
22 | sample = 2.0*sample-1.0
23 | return sample
24 |
25 |
26 | def beta_mle_loss(y_hat, y, reduce=True):
27 | """y_hat (batch_size x seq_len x 2)
28 | y (batch_size x seq_len x 1)
29 |
30 | """
31 | # take exponential to ensure positivity
32 | loc_y = y_hat.exp()
33 | alpha = loc_y[:,:,0].unsqueeze(-1)
34 | beta = loc_y[:,:,1].unsqueeze(-1)
35 | dist = Beta(alpha, beta)
36 | # rescale y from [-1, 1] to [0, 1]
37 | y = (y + 1.0)/2.0
38 | # note that we will get inf loss if y == 0 or 1.0 exactly, so we will clip it slightly just in case
39 | y = torch.clamp(y, 1e-5, 0.99999)
40 | # compute logprob
41 | loss = -dist.log_prob(y).squeeze(-1)
42 | if reduce:
43 | return loss.mean()
44 | else:
45 | return loss
46 |
47 |
48 | def log_sum_exp(x):
49 | """ numerically stable log_sum_exp implementation that prevents overflow """
50 | # TF ordering
51 | axis = len(x.size()) - 1
52 | m, _ = torch.max(x, dim=axis)
53 | m2, _ = torch.max(x, dim=axis, keepdim=True)
54 | return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
55 |
56 |
57 | def discretized_mix_logistic_loss(y_hat, y, num_classes=256,
58 | log_scale_min=hp.log_scale_min, reduce=True):
59 | """Discretized mixture of logistic distributions loss
60 |
61 | Note that it is assumed that input is scaled to [-1, 1].
62 |
63 | Args:
64 | y_hat (Tensor): Predicted output (B x T x C)
65 | y (Tensor): Target (B x T x 1).
66 | num_classes (int): Number of classes
67 | log_scale_min (float): Log scale minimum value
68 | reduce (bool): If True, the losses are averaged or summed for each
69 | minibatch.
70 |
71 | Returns
72 | Tensor: loss
73 | """
74 | y_hat = y_hat.permute(0,2,1)
75 | assert y_hat.dim() == 3
76 | assert y_hat.size(1) % 3 == 0
77 | nr_mix = y_hat.size(1) // 3
78 |
79 | # (B x T x C)
80 | y_hat = y_hat.transpose(1, 2)
81 |
82 | # unpack parameters. (B, T, num_mixtures) x 3
83 | logit_probs = y_hat[:, :, :nr_mix]
84 | means = y_hat[:, :, nr_mix:2 * nr_mix]
85 | log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)
86 |
87 | # B x T x 1 -> B x T x num_mixtures
88 | y = y.expand_as(means)
89 |
90 | centered_y = y - means
91 | inv_stdv = torch.exp(-log_scales)
92 | plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
93 | cdf_plus = torch.sigmoid(plus_in)
94 | min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
95 | cdf_min = torch.sigmoid(min_in)
96 |
97 | # log probability for edge case of 0 (before scaling)
98 | # equivalent: torch.log(F.sigmoid(plus_in))
99 | log_cdf_plus = plus_in - F.softplus(plus_in)
100 |
101 | # log probability for edge case of 255 (before scaling)
102 | # equivalent: (1 - F.sigmoid(min_in)).log()
103 | log_one_minus_cdf_min = -F.softplus(min_in)
104 |
105 | # probability for all other cases
106 | cdf_delta = cdf_plus - cdf_min
107 |
108 | mid_in = inv_stdv * centered_y
109 | # log probability in the center of the bin, to be used in extreme cases
110 | # (not actually used in our code)
111 | log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)
112 |
113 | # tf equivalent
114 | """
115 | log_probs = tf.where(x < -0.999, log_cdf_plus,
116 | tf.where(x > 0.999, log_one_minus_cdf_min,
117 | tf.where(cdf_delta > 1e-5,
118 | tf.log(tf.maximum(cdf_delta, 1e-12)),
119 | log_pdf_mid - np.log(127.5))))
120 | """
121 | # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
122 | # for num_classes=65536 case? 1e-7? not sure..
123 | inner_inner_cond = (cdf_delta > 1e-5).float()
124 |
125 | inner_inner_out = inner_inner_cond * \
126 | torch.log(torch.clamp(cdf_delta, min=1e-12)) + \
127 | (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
128 | inner_cond = (y > 0.999).float()
129 | inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
130 | cond = (y < -0.999).float()
131 | log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
132 |
133 | log_probs = log_probs + F.log_softmax(logit_probs, -1)
134 |
135 | if reduce:
136 | return -torch.sum(log_sum_exp(log_probs))
137 | else:
138 | return -log_sum_exp(log_probs).unsqueeze(-1)
139 |
140 |
141 | def to_one_hot(tensor, n, fill_with=1.):
142 | # we perform one-hot encoding with respect to the last axis
143 | one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
144 | if tensor.is_cuda:
145 | one_hot = one_hot.cuda()
146 | one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
147 | return one_hot
148 |
149 |
150 | def sample_from_discretized_mix_logistic(y, log_scale_min=hp.log_scale_min):
151 | """
152 | Sample from discretized mixture of logistic distributions
153 |
154 | Args:
155 | y (Tensor): B x C x T
156 | log_scale_min (float): Log scale minimum value
157 |
158 | Returns:
159 | Tensor: sample in range of [-1, 1].
160 | """
161 | assert y.size(1) % 3 == 0
162 | nr_mix = y.size(1) // 3
163 |
164 | # B x T x C
165 | y = y.transpose(1, 2)
166 | logit_probs = y[:, :, :nr_mix]
167 |
168 | # sample mixture indicator from softmax
169 | temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
170 | temp = logit_probs.data - torch.log(- torch.log(temp))
171 | _, argmax = temp.max(dim=-1)
172 |
173 | # (B, T) -> (B, T, nr_mix)
174 | one_hot = to_one_hot(argmax, nr_mix)
175 | # select logistic parameters
176 | means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1)
177 | log_scales = torch.clamp(torch.sum(
178 | y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)
179 | # sample from logistic & clip to interval
180 | # we don't actually round to the nearest 8bit value when sampling
181 | u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
182 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
183 |
184 | x = torch.clamp(torch.clamp(x, min=-1.), max=1.)
185 |
186 | return x
187 |
188 |
189 | # add gaussian from clarinet implementation:https://raw.githubusercontent.com/ksw0306/ClariNet/master/loss.py
190 | def gaussian_loss(y_hat, y, log_std_min=-7.0, reduce=True):
191 | """y_hat (batch_size x seq_len x 2)
192 | y (batch_size x seq_len x 1)
193 | """
194 | assert y_hat.dim() == 3
195 | assert y_hat.size(2) == 2
196 |
197 | mean = y_hat[:, :, :1]
198 | log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
199 |
200 | log_probs = -0.5 * (- math.log(2.0 * math.pi) - 2. * log_std - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std)))
201 |
202 | if reduce:
203 | return log_probs.squeeze().mean()
204 | else:
205 | return log_probs.squeeze()
206 |
207 |
208 | def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.):
209 | """y_hat (batch_size x seq_len x 2)
210 | y (batch_size x seq_len x 1)
211 | """
212 | assert y_hat.size(2) == 2
213 |
214 | mean = y_hat[:, :, :1]
215 | log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
216 | dist = Normal(mean, torch.exp(log_std))
217 | sample = dist.sample()
218 | sample = torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor)
219 | del dist
220 | return sample
221 |
222 |
223 |
224 |
225 |
226 | def test_gaussian():
227 |
228 | y_hat = torch.rand(16, 120, 2)
229 | y_true = torch.rand(16, 120, 1)
230 | out = sample_from_gaussian(y_hat)
231 | loss = gaussian_loss(y_hat, y_true)
232 | loss_mean = loss.mean()
233 | print(out.shape, loss.shape, loss_mean.item())
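

# Minimal usage sketches for the Beta and mixture-of-logistics outputs, mirroring
# test_gaussian above; the tensor shapes follow the docstrings of the functions.
def test_beta():
    y_hat = torch.rand(16, 120, 2)
    y_true = torch.rand(16, 120, 1) * 2. - 1.  # raw audio targets in [-1, 1]
    out = sample_from_beta_dist(y_hat)
    loss = beta_mle_loss(y_hat, y_true)
    print(out.shape, loss.item())


def test_mixture():
    # 10 logistic mixture components -> 3 * 10 = 30 output channels
    y_hat = torch.rand(16, 120, 30)
    y_true = torch.rand(16, 120, 1) * 2. - 1.
    loss = discretized_mix_logistic_loss(y_hat, y_true)
    sample = sample_from_discretized_mix_logistic(y_hat.transpose(1, 2))
    print(sample.shape, loss.item())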
--------------------------------------------------------------------------------
/hparams.py:
--------------------------------------------------------------------------------
1 | class hparams:
2 |
3 | # option parameters
4 |
5 | # Input type:
6 | # 1. raw [-1, 1]
7 | # 2. mixture [-1, 1]
8 | # 3. bits [0, 2**bits - 1]
9 | # 4. mulaw [0, mulaw_quantize_channels - 1]
10 | #
11 | input_type = 'raw'
12 | #
13 | # distribution type used when input_type is 'raw'; currently supports 'beta' and 'gaussian'
14 | distribution = 'gaussian' # or 'beta'
15 | log_scale_min = -32.23619130191664 # = float(np.log(1e-7))
16 | quantize_channels = 65536 # quantization channels used to compute the mixture-of-logistics loss
17 | #
18 | # for Fatcord's original 9-bit audio, specify the audio bit depth. Note this corresponds to a network output
19 | # of size 2**bits, so 9 bits would be 512 output classes, etc.
20 | bits = 10
21 | # for mu-law
22 | mulaw_quantize_channels = 512
23 | # note: r9r9's deepvoice3 preprocessing is used instead of Fatcord's original.
24 | #--------------
25 | # audio processing parameters
26 | num_mels = 80
27 | fmin = 125
28 | fmax = 7600
29 | fft_size = 1024
30 | hop_size = 256
31 | win_length = 1024
32 | sample_rate = 22050
33 | preemphasis = 0.97
34 | min_level_db = -100
35 | ref_level_db = 20
36 | rescaling = False
37 | rescaling_max = 0.999
38 | allow_clipping_in_normalization = True
39 | #----------------
40 | #
41 | #----------------
42 | # model parameters
43 | rnn_dims = 600
44 | fc_dims = 512
45 | pad = 2
46 | # note upsample factors must multiply out to be equal to hop_size, so adjust
47 | # if necessary (i.e 4 x 4 x 16 = 256)
48 | upsample_factors = (4, 4, 16)
49 | compute_dims = 128
50 | res_out_dims = 128
51 | res_blocks = 10
52 | #----------------
53 | #
54 | #----------------
55 | # training parameters
56 | batch_size = 32
57 | nepochs = 5000
58 | save_every_step = 10000
59 | evaluate_every_step = 5000
60 | # seq_len_factor can be adjusted to increase training sequence length (will increase GPU usage)
61 | seq_len_factor = 5
62 | seq_len = seq_len_factor * hop_size
63 | grad_norm = 10
64 | #learning rate parameters
65 | initial_learning_rate=1e-3
66 | lr_schedule_type = 'step' # or 'noam'
67 | # for noam learning rate schedule
68 | noam_warm_up_steps = 2000 * (batch_size // 16)
69 | # for step learning rate schedule
70 | step_gamma = 0.5
71 | lr_step_interval = 15000
72 |
73 | adam_beta1=0.9
74 | adam_beta2=0.999
75 | adam_eps=1e-8
76 | amsgrad=False
77 | weight_decay = 0.0
78 | fix_learning_rate = None # set to a float to use a fixed learning rate; set to None to use the schedule selected by lr_schedule_type
79 | #-----------------
80 |
--------------------------------------------------------------------------------
/inputs/sample.wav:
--------------------------------------------------------------------------------
(binary .wav audio file; content not shown)
--------------------------------------------------------------------------------
/loss_function.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 |
5 | def nll_loss(y_hat, y, reduce=True):
6 | y_hat = y_hat.permute(0,2,1)
7 | y = y.squeeze(-1)
8 | loss = F.nll_loss(y_hat, y)
9 | return loss
10 |
11 | def test_loss():
12 |     # log-probabilities over 54 classes, matching the model's log_softmax output
13 |     yhat = F.log_softmax(torch.rand(16, 100, 54), dim=-1)
14 |     # integer class targets of shape (batch, time, 1); nll_loss squeezes the last dim
15 |     y = torch.randint(0, 54, (16, 100, 1))
16 |     loss = nll_loss(yhat, y)
17 |     print(loss.item())
--------------------------------------------------------------------------------
/lrschedule.py:
--------------------------------------------------------------------------------
1 | # reference: https://raw.githubusercontent.com/r9y9/wavenet_vocoder/master/lrschedule.py
2 |
3 | import numpy as np
4 |
5 |
6 | # https://github.com/tensorflow/tensor2tensor/issues/280#issuecomment-339110329
7 | def noam_learning_rate_decay(init_lr, global_step, warmup_steps=4000):
8 | # Noam scheme from tensor2tensor:
9 | warmup_steps = float(warmup_steps)
10 | step = global_step + 1.
11 | lr = init_lr * warmup_steps**0.5 * np.minimum(
12 | step * warmup_steps**-1.5, step**-0.5)
13 | return lr
14 |
15 |
16 | def step_learning_rate_decay(init_lr, global_step,
17 | anneal_rate=0.98,
18 | anneal_interval=30000):
19 | return init_lr * anneal_rate ** (global_step // anneal_interval)
20 |
21 |
22 | def cyclic_cosine_annealing(init_lr, global_step, T, M):
23 | """Cyclic cosine annealing
24 |
25 | https://arxiv.org/pdf/1704.00109.pdf
26 |
27 | Args:
28 | init_lr (float): Initial learning rate
29 | global_step (int): Current iteration number
30 | T (int): Total iteration number (i.e. nepochs)
31 | M (int): Number of ensembles we want
32 |
33 | Returns:
34 | float: Annealed learning rate
35 | """
36 | TdivM = T // M
37 | return init_lr / 2.0 * (np.cos(np.pi * ((global_step - 1) % TdivM) / TdivM) + 1.0)
38 |
39 |
40 | def test_noam():
41 | lr = 1e-3
42 | init_lr = 1e-3
43 | for i in range(50000):
44 | print(i, lr)
45 | lr = noam_learning_rate_decay(init_lr, i)
46 |
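# A small sketch of the step schedule with the defaults train.py passes in
# (hparams: initial_learning_rate=1e-3, step_gamma=0.5, lr_step_interval=15000).
def test_step():
    init_lr = 1e-3
    for step in [0, 15000, 30000, 45000]:
        lr = step_learning_rate_decay(init_lr, step, anneal_rate=0.5, anneal_interval=15000)
        print(step, lr)  # 1e-3, 5e-4, 2.5e-4, 1.25e-4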
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from hparams import hparams as hp
5 | from torch.utils.data import DataLoader, Dataset
6 | from distributions import *
7 | from utils import num_params, mulaw_quantize, inv_mulaw_quantize
8 |
9 | from tqdm import tqdm
10 | import numpy as np
11 |
12 | class ResBlock(nn.Module) :
13 | def __init__(self, dims) :
14 | super().__init__()
15 | self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
16 | self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
17 | self.batch_norm1 = nn.BatchNorm1d(dims)
18 | self.batch_norm2 = nn.BatchNorm1d(dims)
19 |
20 | def forward(self, x) :
21 | residual = x
22 | x = self.conv1(x)
23 | x = self.batch_norm1(x)
24 | x = F.relu(x)
25 | x = self.conv2(x)
26 | x = self.batch_norm2(x)
27 | return x + residual
28 |
29 | class MelResNet(nn.Module) :
30 | def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims) :
31 | super().__init__()
32 | self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=5, bias=False)
33 | self.batch_norm = nn.BatchNorm1d(compute_dims)
34 | self.layers = nn.ModuleList()
35 | for i in range(res_blocks) :
36 | self.layers.append(ResBlock(compute_dims))
37 | self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
38 |
39 | def forward(self, x) :
40 | x = self.conv_in(x)
41 | x = self.batch_norm(x)
42 | x = F.relu(x)
43 | for f in self.layers : x = f(x)
44 | x = self.conv_out(x)
45 | return x
46 |
47 | class Stretch2d(nn.Module) :
48 | def __init__(self, x_scale, y_scale) :
49 | super().__init__()
50 | self.x_scale = x_scale
51 | self.y_scale = y_scale
52 |
53 | def forward(self, x) :
54 | b, c, h, w = x.size()
55 | x = x.unsqueeze(-1).unsqueeze(3)
56 | x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
57 | return x.view(b, c, h * self.y_scale, w * self.x_scale)
58 |
59 | class UpsampleNetwork(nn.Module) :
60 | def __init__(self, feat_dims, upsample_scales, compute_dims,
61 | res_blocks, res_out_dims, pad) :
62 | super().__init__()
63 | total_scale = np.cumproduct(upsample_scales)[-1]
64 | self.indent = pad * total_scale
65 | self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims)
66 | self.resnet_stretch = Stretch2d(total_scale, 1)
67 | self.up_layers = nn.ModuleList()
68 | for scale in upsample_scales :
69 | k_size = (1, scale * 2 + 1)
70 | padding = (0, scale)
71 | stretch = Stretch2d(scale, 1)
72 | conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
73 | conv.weight.data.fill_(1. / k_size[1])
74 | self.up_layers.append(stretch)
75 | self.up_layers.append(conv)
76 |
77 | def forward(self, m) :
78 | aux = self.resnet(m).unsqueeze(1)
79 | aux = self.resnet_stretch(aux)
80 | aux = aux.squeeze(1)
81 | m = m.unsqueeze(1)
82 | for f in self.up_layers : m = f(m)
83 | m = m.squeeze(1)[:, :, self.indent:-self.indent]
84 | return m.transpose(1, 2), aux.transpose(1, 2)
85 |
86 |
87 | class Model(nn.Module) :
88 | def __init__(self, rnn_dims, fc_dims, bits, pad, upsample_factors,
89 | feat_dims, compute_dims, res_out_dims, res_blocks):
90 | super().__init__()
91 | if hp.input_type == 'raw':
92 | self.n_classes = 2
93 | elif hp.input_type == 'mixture':
94 | # mixture requires multiple of 3, default at 10 component mixture, i.e 3 x 10 = 30
95 | self.n_classes = 30
96 | elif hp.input_type == 'mulaw':
97 | self.n_classes = hp.mulaw_quantize_channels
98 | elif hp.input_type == 'bits':
99 | self.n_classes = 2**bits
100 | else:
101 | raise ValueError(f"input_type: {hp.input_type} not supported")
102 | self.rnn_dims = rnn_dims
103 | self.aux_dims = res_out_dims // 4
104 | self.upsample = UpsampleNetwork(feat_dims, upsample_factors, compute_dims,
105 | res_blocks, res_out_dims, pad)
106 | self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
107 | self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
108 | self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
109 | self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
110 | self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
111 | self.fc3 = nn.Linear(fc_dims, self.n_classes)
112 | num_params(self)
113 |
114 | def forward(self, x, mels) :
115 | bsize = x.size(0)
116 | h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
117 | h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
118 | mels, aux = self.upsample(mels)
119 |
120 | aux_idx = [self.aux_dims * i for i in range(5)]
121 | a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
122 | a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
123 | a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
124 | a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
125 |
126 | x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
127 | x = self.I(x)
128 | res = x
129 | x, _ = self.rnn1(x, h1)
130 |
131 | x = x + res
132 | res = x
133 | x = torch.cat([x, a2], dim=2)
134 | x, _ = self.rnn2(x, h2)
135 |
136 | x = x + res
137 | x = torch.cat([x, a3], dim=2)
138 | x = F.relu(self.fc1(x))
139 |
140 | x = torch.cat([x, a4], dim=2)
141 | x = F.relu(self.fc2(x))
142 |
143 | x = self.fc3(x)
144 |
145 | if hp.input_type == 'raw':
146 | return x
147 | elif hp.input_type == 'mixture':
148 | return x
149 | elif hp.input_type == 'bits' or hp.input_type == 'mulaw':
150 | return F.log_softmax(x, dim=-1)
151 | else:
152 | raise ValueError(f"input_type: {hp.input_type} not supported")
153 |
154 |
155 | def preview_upsampling(self, mels) :
156 | mels, aux = self.upsample(mels)
157 | return mels, aux
158 |
159 | def generate(self, mels) :
160 | self.eval()
161 | output = []
162 | rnn1 = self.get_gru_cell(self.rnn1)
163 | rnn2 = self.get_gru_cell(self.rnn2)
164 |
165 | with torch.no_grad() :
166 | x = torch.zeros(1, 1).cuda()
167 | h1 = torch.zeros(1, self.rnn_dims).cuda()
168 | h2 = torch.zeros(1, self.rnn_dims).cuda()
169 |
170 | mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
171 | mels, aux = self.upsample(mels)
172 |
173 | aux_idx = [self.aux_dims * i for i in range(5)]
174 | a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
175 | a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
176 | a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
177 | a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
178 |
179 | seq_len = mels.size(1)
180 |
181 | for i in tqdm(range(seq_len)) :
182 |
183 | m_t = mels[:, i, :]
184 | a1_t = a1[:, i, :]
185 | a2_t = a2[:, i, :]
186 | a3_t = a3[:, i, :]
187 | a4_t = a4[:, i, :]
188 |
189 | x = torch.cat([x, m_t, a1_t], dim=1)
190 | x = self.I(x)
191 | h1 = rnn1(x, h1)
192 |
193 | x = x + h1
194 | inp = torch.cat([x, a2_t], dim=1)
195 | h2 = rnn2(inp, h2)
196 |
197 | x = x + h2
198 | x = torch.cat([x, a3_t], dim=1)
199 | x = F.relu(self.fc1(x))
200 |
201 | x = torch.cat([x, a4_t], dim=1)
202 | x = F.relu(self.fc2(x))
203 | x = self.fc3(x)
204 | if hp.input_type == 'raw':
205 | if hp.distribution == 'beta':
206 | sample = sample_from_beta_dist(x.unsqueeze(0))
207 | elif hp.distribution == 'gaussian':
208 | sample = sample_from_gaussian(x.unsqueeze(0))
209 | elif hp.input_type == 'mixture':
210 | sample = sample_from_discretized_mix_logistic(x.unsqueeze(-1),hp.log_scale_min)
211 | elif hp.input_type == 'bits':
212 | posterior = F.softmax(x, dim=1).view(-1)
213 | distrib = torch.distributions.Categorical(posterior)
214 | sample = 2 * distrib.sample().float() / (self.n_classes - 1.) - 1.
215 | elif hp.input_type == 'mulaw':
216 | posterior = F.softmax(x, dim=1).view(-1)
217 | distrib = torch.distributions.Categorical(posterior)
218 | sample = inv_mulaw_quantize(distrib.sample(), hp.mulaw_quantize_channels, True)
219 | output.append(sample.view(-1))
220 | x = torch.FloatTensor([[sample]]).cuda()
221 | output = torch.stack(output).cpu().numpy()
222 | self.train()
223 | return output
224 |
225 |
226 | def batch_generate(self, mels) :
227 | """mel should be of shape [batch_size x 80 x mel_length]
228 | """
229 | self.eval()
230 | output = []
231 | rnn1 = self.get_gru_cell(self.rnn1)
232 | rnn2 = self.get_gru_cell(self.rnn2)
233 | b_size = mels.shape[0]
234 | assert len(mels.shape) == 3, "mels should have shape [batch_size x 80 x mel_length]"
235 |
236 | with torch.no_grad() :
237 | x = torch.zeros(b_size, 1).cuda()
238 | h1 = torch.zeros(b_size, self.rnn_dims).cuda()
239 | h2 = torch.zeros(b_size, self.rnn_dims).cuda()
240 |
241 | mels = torch.FloatTensor(mels).cuda()
242 | mels, aux = self.upsample(mels)
243 |
244 | aux_idx = [self.aux_dims * i for i in range(5)]
245 | a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
246 | a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
247 | a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
248 | a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
249 |
250 | seq_len = mels.size(1)
251 |
252 | for i in tqdm(range(seq_len)) :
253 |
254 | m_t = mels[:, i, :]
255 | a1_t = a1[:, i, :]
256 | a2_t = a2[:, i, :]
257 | a3_t = a3[:, i, :]
258 | a4_t = a4[:, i, :]
259 |
260 | x = torch.cat([x, m_t, a1_t], dim=1)
261 | x = self.I(x)
262 | h1 = rnn1(x, h1)
263 |
264 | x = x + h1
265 | inp = torch.cat([x, a2_t], dim=1)
266 | h2 = rnn2(inp, h2)
267 |
268 | x = x + h2
269 | x = torch.cat([x, a3_t], dim=1)
270 | x = F.relu(self.fc1(x))
271 |
272 | x = torch.cat([x, a4_t], dim=1)
273 | x = F.relu(self.fc2(x))
274 | x = self.fc3(x)
275 | if hp.input_type == 'raw':
276 | sample = sample_from_beta_dist(x.unsqueeze(0))
277 | elif hp.input_type == 'mixture':
278 | sample = sample_from_discretized_mix_logistic(x.unsqueeze(-1),hp.log_scale_min)
279 | elif hp.input_type == 'bits':
280 | posterior = F.softmax(x, dim=1).view(b_size, -1)
281 | distrib = torch.distributions.Categorical(posterior)
282 | sample = 2 * distrib.sample().float() / (self.n_classes - 1.) - 1.
283 | elif hp.input_type == 'mulaw':
284 | posterior = F.softmax(x, dim=1).view(b_size, -1)
285 | distrib = torch.distributions.Categorical(posterior)
286 | print(type(distrib.sample()))
287 | sample = inv_mulaw_quantize(distrib.sample(), hp.mulaw_quantize_channels, True)
288 | output.append(sample.view(-1))
289 | x = sample.view(b_size,1)
290 | output = torch.stack(output).cpu().numpy()
291 | self.train()
292 | # output is a batch of wav segments of shape [batch_size x seq_len]
293 | # will need to merge into one wav of size [batch_size * seq_len]
294 | assert output.shape[1] == b_size
295 | output = (output.swapaxes(1,0)).reshape(-1)
296 | return output
297 |
298 | def get_gru_cell(self, gru) :
299 | gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
300 | gru_cell.weight_hh.data = gru.weight_hh_l0.data
301 | gru_cell.weight_ih.data = gru.weight_ih_l0.data
302 | gru_cell.bias_hh.data = gru.bias_hh_l0.data
303 | gru_cell.bias_ih.data = gru.bias_ih_l0.data
304 | return gru_cell
305 |
306 |
307 | def build_model():
308 | """build model with hparams settings
309 |
310 | """
311 | if hp.input_type == 'raw':
312 | print('building model with Beta distribution output')
313 | elif hp.input_type == 'mixture':
314 | print("building model with mixture of logistic output")
315 | elif hp.input_type == 'bits':
316 | print("building model with quantized bit audio")
317 | elif hp.input_type == 'mulaw':
318 | print("building model with quantized mulaw encoding")
319 | else:
320 | raise ValueError('input_type provided not supported')
321 | model = Model(hp.rnn_dims, hp.fc_dims, hp.bits,
322 | hp.pad, hp.upsample_factors, hp.num_mels,
323 | hp.compute_dims, hp.res_out_dims, hp.res_blocks)
324 |
325 | return model
326 |
327 | def no_test_build_model():
328 | model = Model(hp.rnn_dims, hp.fc_dims, hp.bits,
329 | hp.pad, hp.upsample_factors, hp.num_mels,
330 | hp.compute_dims, hp.res_out_dims, hp.res_blocks).cuda()
331 | print(vars(model))
332 |
333 |
334 | def test_batch_generate():
335 | model = Model(hp.rnn_dims, hp.fc_dims, hp.bits,
336 | hp.pad, hp.upsample_factors, hp.num_mels,
337 | hp.compute_dims, hp.res_out_dims, hp.res_blocks).cuda()
338 | print(vars(model))
339 | batch_mel = torch.rand(3, 80, 100)
340 | output = model.batch_generate(batch_mel)
341 | print(output.shape)
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess dataset
3 |
4 | usage: preprocess.py [options] <wav-dir>
5 |
6 | options:
7 | --output-dir=<dir>       Directory where processed outputs are saved [default: data_dir].
8 | -h, --help Show help message.
9 | """
10 | import os
11 | from docopt import docopt
12 | import numpy as np
13 | import math, pickle, os
14 | from audio import *
15 | from hparams import hparams as hp
16 | from utils import *
17 | from tqdm import tqdm
18 |
19 | def get_wav_mel(path):
20 | """Given path to .wav file, get the quantized wav and mel spectrogram as numpy vectors
21 |
22 | """
23 | wav = load_wav(path)
24 | mel = melspectrogram(wav)
25 | if hp.input_type == 'raw':
26 | return wav.astype(np.float32), mel
27 | elif hp.input_type == 'mulaw':
28 | quant = mulaw_quantize(wav, hp.mulaw_quantize_channels)
29 | return quant.astype(np.int), mel
30 | elif hp.input_type == 'bits':
31 | quant = quantize(wav)
32 | return quant.astype(np.int), mel
33 | else:
34 | raise ValueError("hp.input_type {} not recognized".format(hp.input_type))
35 |
36 |
37 |
38 |
39 |
40 | def process_data(wav_dir, output_path, mel_path, wav_path):
41 | """
42 | given wav directory and output directory, process wav files and save quantized wav and mel
43 | spectrogram to output directory
44 | """
45 | dataset_ids = []
46 | # get list of wav files
47 | wav_files = os.listdir(wav_dir)
48 | # check wav_file
49 | assert len(wav_files) != 0 and wav_files[0][-4:] == '.wav', "no wav files found!"
50 | # create training and testing splits
51 | test_wav_files = wav_files[:4]
52 | wav_files = wav_files[4:]
53 | for i, wav_file in enumerate(tqdm(wav_files)):
54 | # get the file id
55 | file_id = '{:d}'.format(i).zfill(5)
56 | wav, mel = get_wav_mel(os.path.join(wav_dir,wav_file))
57 | # save
58 | np.save(os.path.join(mel_path,file_id+".npy"), mel)
59 | np.save(os.path.join(wav_path,file_id+".npy"), wav)
60 | # add to dataset_ids
61 | dataset_ids.append(file_id)
62 |
63 | # save dataset_ids
64 | with open(os.path.join(output_path,'dataset_ids.pkl'), 'wb') as f:
65 | pickle.dump(dataset_ids, f)
66 |
67 | # process testing_wavs
68 | test_path = os.path.join(output_path,'test')
69 | os.makedirs(test_path, exist_ok=True)
70 | for i, wav_file in enumerate(test_wav_files):
71 | wav, mel = get_wav_mel(os.path.join(wav_dir,wav_file))
72 | # save test_wavs
73 | np.save(os.path.join(test_path,"test_{}_mel.npy".format(i)),mel)
74 | np.save(os.path.join(test_path,"test_{}_wav.npy".format(i)),wav)
75 |
76 |
77 | print("\npreprocessing done, total processed wav files:{}.\nProcessed files are located in:{}".format(len(wav_files), os.path.abspath(output_path)))
78 |
79 |
80 |
81 | if __name__=="__main__":
82 | args = docopt(__doc__)
83 | wav_dir = args["<wav-dir>"]
84 | output_dir = args["--output-dir"]
85 |
86 | # create paths
87 | output_path = os.path.join(output_dir,"")
88 | mel_path = os.path.join(output_dir,"mel")
89 | wav_path = os.path.join(output_dir,"wav")
90 |
91 | # create dirs
92 | os.makedirs(output_path, exist_ok=True)
93 | os.makedirs(mel_path, exist_ok=True)
94 | os.makedirs(wav_path, exist_ok=True)
95 |
96 | # process data
97 | process_data(wav_dir, output_path, mel_path, wav_path)
98 |
99 |
100 |
101 | def test_get_wav_mel():
102 | wav, mel = get_wav_mel('sample.wav')
103 | print(wav.shape, mel.shape)
104 | print(wav)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | docopt
2 | librosa
3 | nnmnkwii
4 | tqdm
5 | lws
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """Training WaveRNN Model.
2 |
3 | usage: train.py [options] <data-root>
4 |
5 | options:
6 | --checkpoint-dir=<dir>   Directory where to save model checkpoints [default: checkpoints].
7 | --checkpoint=<path>      Restore model from checkpoint path if given.
8 | -h, --help Show this help message and exit
9 | """
10 | from docopt import docopt
11 |
12 | import os
13 | from os.path import dirname, join, expanduser
14 | from tqdm import tqdm
15 |
16 | import numpy as np
17 | import matplotlib.pyplot as plt
18 | import librosa
19 |
20 | from model import build_model
21 |
22 | import torch
23 | from torch import nn
24 | import torch.nn.functional as F
25 | from torch import optim
26 | from torch.utils.data import DataLoader
27 |
29 | from distributions import *
30 | from loss_function import nll_loss
31 | from dataset import raw_collate, discrete_collate, AudiobookDataset
32 | from hparams import hparams as hp
33 | from lrschedule import noam_learning_rate_decay, step_learning_rate_decay
34 |
35 | global_step = 0
36 | global_epoch = 0
37 | global_test_step = 0
38 | use_cuda = torch.cuda.is_available()
39 |
40 | def save_checkpoint(device, model, optimizer, step, checkpoint_dir, epoch):
41 | checkpoint_path = join(
42 | checkpoint_dir, "checkpoint_step{:09d}.pth".format(step))
43 | optimizer_state = optimizer.state_dict()
44 | global global_test_step
45 | torch.save({
46 | "state_dict": model.state_dict(),
47 | "optimizer": optimizer_state,
48 | "global_step": step,
49 | "global_epoch": epoch,
50 | "global_test_step": global_test_step,
51 | }, checkpoint_path)
52 | print("Saved checkpoint:", checkpoint_path)
53 |
54 |
55 | def _load(checkpoint_path):
56 | if use_cuda:
57 | checkpoint = torch.load(checkpoint_path)
58 | else:
59 | checkpoint = torch.load(checkpoint_path,
60 | map_location=lambda storage, loc: storage)
61 | return checkpoint
62 |
63 |
64 | def load_checkpoint(path, model, optimizer, reset_optimizer):
65 | global global_step
66 | global global_epoch
67 | global global_test_step
68 |
69 | print("Load checkpoint from: {}".format(path))
70 | checkpoint = _load(path)
71 | model.load_state_dict(checkpoint["state_dict"])
72 | if not reset_optimizer:
73 | optimizer_state = checkpoint["optimizer"]
74 | if optimizer_state is not None:
75 | print("Load optimizer state from {}".format(path))
76 | optimizer.load_state_dict(checkpoint["optimizer"])
77 | global_step = checkpoint["global_step"]
78 | global_epoch = checkpoint["global_epoch"]
79 | global_test_step = checkpoint.get("global_test_step", 0)
80 |
81 | return model
82 |
83 |
84 | def test_save_checkpoint():
85 | checkpoint_path = "checkpoints/"
86 | device = torch.device("cuda" if use_cuda else "cpu")
87 | model = build_model()
88 | optimizer = optim.Adam(model.parameters(), lr=1e-4)
89 | global global_step, global_epoch, global_test_step
90 | save_checkpoint(device, model, optimizer, global_step, checkpoint_path, global_epoch)
91 |
92 | model = load_checkpoint(checkpoint_path+"checkpoint_step000000000.pth", model, optimizer, False)
93 |
94 |
95 | def evaluate_model(model, data_loader, checkpoint_dir, limit_eval_to=5):
96 | """evaluate model and save generated wav and plot
97 |
98 | """
99 | test_path = data_loader.dataset.test_path
100 | test_files = os.listdir(test_path)
101 | counter = 0
102 | output_dir = os.path.join(checkpoint_dir,'eval')
103 | for f in test_files:
104 | if f[-7:] == "mel.npy":
105 | mel = np.load(os.path.join(test_path,f))
106 | wav = model.generate(mel)
107 | # save wav
108 | wav_path = os.path.join(output_dir,"checkpoint_step{:09d}_wav_{}.wav".format(global_step,counter))
109 | librosa.output.write_wav(wav_path, wav, sr=hp.sample_rate)
110 | # save wav plot
111 | fig_path = os.path.join(output_dir,"checkpoint_step{:09d}_wav_{}.png".format(global_step,counter))
112 | fig = plt.plot(wav.reshape(-1))
113 | plt.savefig(fig_path)
114 | # clear the figure so we don't keep drawing onto the same plot
115 | plt.clf()
116 | counter += 1
117 | # stop evaluation early via limit_eval_to
118 | if counter >= limit_eval_to:
119 | break
120 |
121 |
122 | def train_loop(device, model, data_loader, optimizer, checkpoint_dir):
123 | """Main training loop.
124 |
125 | """
126 | # create loss and put on device
127 | if hp.input_type == 'raw':
128 | if hp.distribution == 'beta':
129 | criterion = beta_mle_loss
130 | elif hp.distribution == 'gaussian':
131 | criterion = gaussian_loss
132 | elif hp.input_type == 'mixture':
133 | criterion = discretized_mix_logistic_loss
134 | elif hp.input_type in ["bits", "mulaw"]:
135 | criterion = nll_loss
136 | else:
137 | raise ValueError("input_type:{} not supported".format(hp.input_type))
138 |
139 |
140 |
141 | global global_step, global_epoch, global_test_step
142 | while global_epoch < hp.nepochs:
143 | running_loss = 0
144 | for i, (x, m, y) in enumerate(tqdm(data_loader)):
145 | x, m, y = x.to(device), m.to(device), y.to(device)
146 | y_hat = model(x, m)
147 | y = y.unsqueeze(-1)
148 | loss = criterion(y_hat, y)
149 | # calculate learning rate and update learning rate
150 | if hp.fix_learning_rate:
151 | current_lr = hp.fix_learning_rate
152 | elif hp.lr_schedule_type == 'step':
153 | current_lr = step_learning_rate_decay(hp.initial_learning_rate, global_step, hp.step_gamma, hp.lr_step_interval)
154 | else:
155 | current_lr = noam_learning_rate_decay(hp.initial_learning_rate, global_step, hp.noam_warm_up_steps)
156 | for param_group in optimizer.param_groups:
157 | param_group['lr'] = current_lr
158 | optimizer.zero_grad()
159 | loss.backward()
160 | # clip gradient norm
161 | nn.utils.clip_grad_norm_(model.parameters(), hp.grad_norm)
162 | optimizer.step()
163 |
164 | running_loss += loss.item()
165 | avg_loss = running_loss / (i+1)
166 | # saving checkpoint if needed
167 | if global_step != 0 and global_step % hp.save_every_step == 0:
168 | save_checkpoint(device, model, optimizer, global_step, checkpoint_dir, global_epoch)
169 | # evaluate model if needed
170 | if global_step != 0 and global_test_step !=True and global_step % hp.evaluate_every_step == 0:
171 | print("step {}, evaluating model: generating wav from mel...".format(global_step))
172 | evaluate_model(model, data_loader, checkpoint_dir)
173 | print("evaluation finished, resuming training...")
174 |
175 | # reset global_test_step status after evaluation
176 | if global_test_step is True:
177 | global_test_step = False
178 | global_step += 1
179 |
180 | print("epoch:{}, running loss:{}, average loss:{}, current lr:{}".format(global_epoch, running_loss, avg_loss, current_lr))
181 | global_epoch += 1
182 |
183 |
184 |
185 | if __name__=="__main__":
186 | args = docopt(__doc__)
187 | #print("Command line args:\n", args)
188 | checkpoint_dir = args["--checkpoint-dir"]
189 | checkpoint_path = args["--checkpoint"]
190 | data_root = args["<data-root>"]
191 |
192 | # make dirs, load dataloader and set up device
193 | os.makedirs(checkpoint_dir, exist_ok=True)
194 | os.makedirs(os.path.join(checkpoint_dir,'eval'), exist_ok=True)
195 | dataset = AudiobookDataset(data_root)
196 | if hp.input_type == 'raw':
197 | collate_fn = raw_collate
198 | elif hp.input_type == 'mixture':
199 | collate_fn = raw_collate
200 | elif hp.input_type in ['bits', 'mulaw']:
201 | collate_fn = discrete_collate
202 | else:
203 | raise ValueError("input_type:{} not supported".format(hp.input_type))
204 | data_loader = DataLoader(dataset, collate_fn=collate_fn, shuffle=True, num_workers=0, batch_size=hp.batch_size)
205 | device = torch.device("cuda" if use_cuda else "cpu")
206 | print("using device:{}".format(device))
207 |
208 | # build model, create optimizer
209 | model = build_model().to(device)
210 | optimizer = optim.Adam(model.parameters(),
211 | lr=hp.initial_learning_rate, betas=(
212 | hp.adam_beta1, hp.adam_beta2),
213 | eps=hp.adam_eps, weight_decay=hp.weight_decay,
214 | amsgrad=hp.amsgrad)
215 |
216 | if hp.fix_learning_rate:
217 | print("using fixed learning rate of :{}".format(hp.fix_learning_rate))
218 | elif hp.lr_schedule_type == 'step':
219 | print("using exponential learning rate decay")
220 | elif hp.lr_schedule_type == 'noam':
221 | print("using noam learning rate decay")
222 |
223 | # load checkpoint
224 | if checkpoint_path is None:
225 | print("no checkpoint specified as --checkpoint argument, creating new model...")
226 | else:
227 | model = load_checkpoint(checkpoint_path, model, optimizer, False)
228 | print("loading model from checkpoint:{}".format(checkpoint_path))
229 | # set global_test_step to True so we don't evaluate right when we load in the model
230 | global_test_step = True
231 |
232 | # main train loop
233 | try:
234 | train_loop(device, model, data_loader, optimizer, checkpoint_dir)
235 | except KeyboardInterrupt:
236 | print("Interrupted!")
237 | pass
238 | finally:
239 | print("saving model....")
240 | save_checkpoint(device, model, optimizer, global_step, checkpoint_dir, global_epoch)
241 |
242 |
243 | def test_eval():
244 | data_root = "data_dir"
245 | dataset = AudiobookDataset(data_root)
246 | if hp.input_type == 'raw':
247 | collate_fn = raw_collate
248 | elif hp.input_type == 'bits':
249 | collate_fn = discrete_collate
250 | else:
251 | raise ValueError("input_type:{} not supported".format(hp.input_type))
252 | data_loader = DataLoader(dataset, collate_fn=collate_fn, shuffle=True, num_workers=0, batch_size=hp.batch_size)
253 | device = torch.device("cuda" if use_cuda else "cpu")
254 | print("using device:{}".format(device))
255 |
256 | # build model, create optimizer
257 | model = build_model().to(device)
258 |
259 | evaluate_model(model, data_loader, "checkpoints")
260 |
261 |
262 |
263 |
264 |
265 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | def num_params(model) :
5 | parameters = filter(lambda p: p.requires_grad, model.parameters())
6 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1000000
7 | print('Trainable Parameters: %.3f million' % parameters)
8 |
9 |
10 | # for mulaw encoding and decoding in torch tensors, modified from: https://github.com/pytorch/audio/blob/master/torchaudio/transforms.py
11 | def mulaw_quantize(x, quantization_channels=256):
12 | """Encode signal based on mu-law companding. For more info see the
13 | Wikipedia entry on the mu-law algorithm.
14 |
15 | This algorithm assumes the signal has been scaled to between -1 and 1 and
16 | returns a signal encoded with values from 0 to quantization_channels - 1
17 |
18 | Args:
19 | quantization_channels (int): Number of channels. default: 256
20 |
21 | """
22 | mu = quantization_channels - 1
23 | if isinstance(x, np.ndarray):
24 | x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
25 | x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int)
26 | elif isinstance(x, (torch.Tensor, torch.LongTensor)):
27 |
28 | if isinstance(x, torch.LongTensor):
29 | x = x.float()
30 | mu = torch.FloatTensor([mu])
31 | x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
32 | x_mu = ((x_mu + 1) / 2 * mu + 0.5).long()
33 | return x_mu
34 |
35 |
36 | def inv_mulaw_quantize(x_mu, quantization_channels=256, cuda=False):
37 | """Decode mu-law encoded signal. For more info see the
38 | Wikipedia entry on the mu-law algorithm.
39 |
40 | This expects an input with values between 0 and quantization_channels - 1
41 | and returns a signal scaled between -1 and 1.
42 |
43 | Args:
44 | quantization_channels (int): Number of channels. default: 256
45 |
46 | """
47 | mu = quantization_channels - 1.
48 | if isinstance(x_mu, np.ndarray):
49 | x = ((x_mu) / mu) * 2 - 1.
50 | x = np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu
51 | elif isinstance(x_mu, (torch.Tensor, torch.LongTensor)):
52 | if isinstance(x_mu, (torch.LongTensor, torch.cuda.LongTensor)):
53 | x_mu = x_mu.float()
54 | if cuda:
55 | mu = (torch.FloatTensor([mu])).cuda()
56 | else:
57 | mu = torch.FloatTensor([mu])
58 | x = ((x_mu) / mu) * 2 - 1.
59 | x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu
60 | return x
61 |
62 |
63 | def test_inv_mulaw():
64 | wav = torch.rand(5, 5000)
65 | wav = wav.cuda()
66 | de_quant = inv_mulaw_quantize(wav, 512, True)
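

# A minimal numpy round-trip sketch for the mu-law pair above: encode a [-1, 1]
# signal to integers in [0, channels - 1], decode, and print the reconstruction error.
def test_mulaw_roundtrip():
    x = np.random.uniform(-1.0, 1.0, size=1000).astype(np.float32)
    encoded = mulaw_quantize(x, 512)
    decoded = inv_mulaw_quantize(encoded, 512)
    print(encoded.min(), encoded.max(), np.abs(x - decoded).max())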
--------------------------------------------------------------------------------