├── README.md ├── __init__.py ├── ae_bn.py ├── autoencoder_model.py ├── chassis.py ├── checkpoint.py ├── dat ├── example_train.log ├── librispeech.dev-clean.rdb ├── librispeech.some.dat └── librispeech.test-clean.rdb ├── data.py ├── doc ├── combining_vae_and_ar.txt ├── commitment_loss_and_batching.txt ├── generalized_batching.txt ├── loss_terms.txt ├── mfcc_inverter_notes.txt ├── notes.txt ├── padding_notes.txt ├── rfield_notes.txt ├── rfield_notes2.txt ├── todo.txt ├── upsampling_notes.txt ├── vae_and_ar_issues.txt ├── vconv_notes.old.txt └── vconv_notes.txt ├── grad_analysis.py ├── hparams.py ├── jitter.py ├── logs └── vq.log ├── mfcc.py ├── mfcc_inverter.py ├── netmisc.py ├── par ├── arch.ae.json ├── arch.basic.json ├── arch.mi.json ├── arch.vae.json ├── arch.vqvae-ema.json ├── train.basic.json ├── train.mi.json └── train.vae.json ├── parse_tools.py ├── preprocess.py ├── results ├── a39d.png └── a39d.txt ├── scripts ├── librispeech_to_rdb.sh ├── train_plot.py └── viewlog.sh ├── test.py ├── test_data.py ├── test_model.py ├── test_vconv.old.py ├── test_vconv.py ├── train.py ├── util.py ├── vae_bn.py ├── vconv.py ├── vq_bn.py ├── vqema_bn.py ├── wave_encoder.py └── wavenet.py /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch implementation of Jan Chorowski, Jan 2019 paper" 2 | 3 | This is a PyTorch implementation of https://arxiv.org/abs/1901.08810. 4 | 5 | [Under Construction] 6 | 7 | ## Update June 14, 2020 8 | 9 | Training a simpler model to perform the "mfcc inversion" task. The idea is: 10 | 11 | 1. preprocess: wav -> mfcc 12 | 2. vqvae: mfcc -> z -> mfcc 13 | 3. mfcc-inverter: mfcc -> wav 14 | 15 | The mfcc-inverter model is just a wavenet conditioned on mfcc vectors (1 every 16 | 160 timesteps) which produces the original wav used to compute the mfcc 17 | vectors. It is a probabilistic inverse of the preprocessing step. 18 | 19 | It should be noted that there is loss of information in the preprocessing step, 20 | so the inverter cannot attain 100% accuracy unless it overfits the data. 21 | 22 | Once the mfcc inverter model is trained, it can be used in conjunction with 23 | a vq-vae model that starts and ends with MFCC. One advantage to this is that 24 | the training of the vq-vae model may be slightly less compute intensive, since 25 | there are only 39 components (one mfcc vector plus first and second 26 | derivatives) every 160 timesteps, instead of 160. 27 | 28 | See results directory for some preliminary training results. 29 | 30 | ## Update April 14, 2019 31 | 32 | Began training on Librispeech dev (http://www.openslr.org/resources/12/dev-clean.tar.gz), 33 | see dat/example\_train.log 34 | 35 | ## Update May 12, 2019 36 | 37 | First runs using vqvae mode. After ~200 iterations, only one quantized vector is 38 | used as a representative. Currently troubleshooting. 39 | 40 | ## Update Nov 7, 2019 41 | 42 | Resumed work as of Sept, 2019. Implemented EMA for updates. Fixed a bug in 43 | the VQVAE Loss function in vq_bn.py:312 44 | 45 | * Was: l2_loss_embeds = self.l2(self.bn.ze, self.bn.emb) 46 | * Now: l2_loss_embeds = self.l2(self.bn.sg(self.bn.ze), self.bn.emb) 47 | 48 | Training still exhibits codebook collapse. This seems due to the phenomenon of 49 | WaveNet learning to rely exclusively on the autoregressive input and ignore the 50 | conditioning input. 51 | 52 | 53 | # TODO 54 | 1. VAE and VQVAE versions of the bottleneck / training objectives [DONE] 55 | 2. 
Inference mode 56 | 57 | # Example training setup 58 | 59 | ```sh 60 | code_dir=/path/to/ae-wavenet 61 | run_dir=/path/to/my_runs 62 | 63 | # Get the data 64 | cd $run_dir 65 | wget http://www.openslr.org/resources/12/dev-clean.tar.gz 66 | tar zxvf dev-clean.tar.gz 67 | $code_dir/scripts/librispeech_to_rdb.sh LibriSpeech/dev-clean > librispeech.dev-clean.rdb 68 | 69 | # Preprocess the data 70 | # This stores a flattened, indexed copy of the sound data plus calculated MFCCs 71 | python preprocess.py librispeech.dev-clean.rdb librispeech.dev-clean.dat -nq 256 -sr 16000 72 | 73 | # Train 74 | # New mode 75 | cd $code_dir 76 | python train.py new -af par/arch.basic.json -tf par/train.basic.json -nb 4 -si 1000 \ 77 | -vqn 1000 $run_dir/model%.ckpt $run_dir/librispeech.dev-clean.dat $run_dir/data_slices.dat 78 | 79 | # Resume mode - resume from step 10000, save every 1000 steps 80 | python train.py resume -nb 4 -si 1000 $run_dir/model%.ckpt $run_dir/model10000.ckpt 81 | 82 | ``` 83 | 84 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrbigelow/ae-wavenet/80b9c46637151f053f74728fc756f1d01ab0aa69/__init__.py -------------------------------------------------------------------------------- /ae_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import netmisc 4 | 5 | class AE(nn.Module): 6 | def __init__(self, n_in, n_out, bias=True): 7 | super(AE, self).__init__() 8 | self.linear = nn.Conv1d(n_in, n_out, 1, bias=bias) 9 | netmisc.xavier_init(self.linear) 10 | 11 | def forward(self, x): 12 | """ 13 | ze: (B, Q, N) 14 | """ 15 | self.ze = self.linear(x) 16 | return self.ze 17 | 18 | 19 | class AELoss(nn.Module): 20 | def __init__(self, bottleneck, norm_gamma): 21 | super(AELoss, self).__init__() 22 | self.logsoftmax = nn.LogSoftmax(1) # input is (B, Q, N) 23 | self.bottleneck = bottleneck 24 | self.register_buffer('norm_gamma', torch.tensor(norm_gamma)) 25 | # self.register_buffer('two', torch.tensor(2, dtype=torch.int32)) 26 | self.register_buffer('two', torch.tensor(2, dtype=torch.float32)) 27 | self.register_buffer('one', torch.tensor(1.0)) 28 | 29 | def forward(self, quant_pred, target_wav): 30 | 31 | log_pred = self.logsoftmax(quant_pred) 32 | target_wav_gather = target_wav.long().unsqueeze(1) 33 | log_pred_target = torch.gather(log_pred, 1, target_wav_gather) 34 | 35 | rec_loss = - log_pred_target.mean() 36 | ze_norm = (self.bottleneck.ze ** self.two).sum(dim=1).sqrt() 37 | 38 | norm_loss = self.norm_gamma * torch.abs(ze_norm - self.one).mean() 39 | total_loss = rec_loss + norm_loss 40 | 41 | self.metrics = { 42 | 'rec': rec_loss, 43 | 'norm': norm_loss 44 | } 45 | 46 | return total_loss 47 | 48 | -------------------------------------------------------------------------------- /autoencoder_model.py: -------------------------------------------------------------------------------- 1 | # Full Autoencoder model 2 | from sys import stderr 3 | from hashlib import md5 4 | import numpy as np 5 | from pickle import dumps 6 | import torch 7 | from torch import nn 8 | from torch.nn.modules import loss 9 | from scipy.cluster.vq import kmeans 10 | 11 | import ae_bn 12 | import mfcc 13 | import parse_tools 14 | import vconv 15 | import util 16 | import vq_bn 17 | import vqema_bn 18 | import vae_bn 19 | import wave_encoder as enc 20 | import wavenet as dec 21 | 22 | 
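# The full model below composes three stages: an encoder over MFCC frames
# (wave_encoder.Encoder), one of the bottlenecks selected by bn_type
# (ae_bn.AE, vae_bn.VAE, vq_bn.VQ or vqema_bn.VQEMA), and a WaveNet decoder
# (wavenet.WaveNet) locally conditioned on the bottleneck output.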
23 | class AutoEncoder(nn.Module): 24 | """ 25 | Full Autoencoder model. The _initialize method allows us to seamlessly initialize 26 | from __init__ or __setstate__ 27 | """ 28 | def __init__(self, opts, dataset): 29 | opts_dict = vars(opts) 30 | enc_params = parse_tools.get_prefixed_items(opts_dict, 'enc_') 31 | bn_params = parse_tools.get_prefixed_items(opts_dict, 'bn_') 32 | dec_params = parse_tools.get_prefixed_items(opts_dict, 'dec_') 33 | dec_params['n_speakers'] = dataset.num_speakers() 34 | 35 | self.init_args = { 36 | 'enc_params': enc_params, 37 | 'bn_params': bn_params, 38 | 'dec_params': dec_params, 39 | 'n_mel_chan': dataset.num_mel_chan(), 40 | 'training': opts.training 41 | } 42 | self._initialize() 43 | 44 | def _initialize(self): 45 | super(AutoEncoder, self).__init__() 46 | enc_params = self.init_args['enc_params'] 47 | bn_params = self.init_args['bn_params'] 48 | dec_params = self.init_args['dec_params'] 49 | n_mel_chan = self.init_args['n_mel_chan'] 50 | training = self.init_args['training'] 51 | 52 | self.encoder = enc.Encoder(n_in=n_mel_chan, parent_vc=None, **enc_params) 53 | 54 | bn_type = bn_params['type'] 55 | bn_extra = dict((k, v) for k, v in bn_params.items() if k != 'type') 56 | 57 | # In each case, the objective function's 'forward' method takes the 58 | # same arguments. 59 | if bn_type == 'vqvae': 60 | self.bottleneck = vq_bn.VQ(**bn_extra, n_in=enc_params['n_out']) 61 | self.objective = vq_bn.VQLoss(self.bottleneck) 62 | 63 | elif bn_type == 'vqvae-ema': 64 | self.bottleneck = vqema_bn.VQEMA(**bn_extra, n_in=enc_params['n_out'], 65 | training=training) 66 | self.objective = vqema_bn.VQEMALoss(self.bottleneck) 67 | 68 | elif bn_type == 'vae': 69 | # mu and sigma members 70 | self.bottleneck = vae_bn.VAE(n_in=enc_params['n_out'], 71 | n_out=bn_params['n_out']) 72 | self.objective = vae_bn.SGVBLoss(self.bottleneck, 73 | free_nats=bn_params['free_nats']) 74 | 75 | elif bn_type == 'ae': 76 | self.bottleneck = ae_bn.AE(n_out=bn_extra['n_out'], n_in=enc_params['n_out']) 77 | self.objective = ae_bn.AELoss(self.bottleneck, 0.001) 78 | 79 | else: 80 | raise InvalidArgument('bn_type must be one of "ae", "vae", or "vqvae"') 81 | 82 | self.bn_type = bn_type 83 | self.decoder = dec.WaveNet( 84 | **dec_params, 85 | parent_vc=self.encoder.vc['end'], 86 | n_lc_in=bn_params['n_out'] 87 | ) 88 | self.vc = self.decoder.vc 89 | self.decoder.post_init() 90 | 91 | def post_init(self, dataset): 92 | self.encoder.set_parent_vc(dataset.mfcc_vc) 93 | self._init_geometry(dataset.window_batch_size) 94 | 95 | def _init_geometry(self, batch_win_size): 96 | """ 97 | Initializes lengths and trimming needed to produce batch_win_size 98 | output 99 | 100 | self.enc_in_len - encoder input length (timesteps) 101 | self.dec_in_len - decoder input length (timesteps) 102 | self.trim_ups_out - trims decoder lc_dense before use 103 | self.trim_dec_out - trims wav_dec_input to wav_dec_output 104 | self.trim_dec_in - trims wav_enc_input to wav_dec_input 105 | 106 | The trimming vectors are needed because, due to striding geometry, 107 | output tensors cannot be produced in single-increment sizes, therefore 108 | must be over-produced in some cases. 
109 | """ 110 | # Calculate max length of mfcc encoder input and wav decoder input 111 | w = batch_win_size 112 | mfcc_vc = self.encoder.vc['beg'].parent 113 | end_enc_vc = self.encoder.vc['end'] 114 | end_ups_vc = self.decoder.vc['last_upsample'] 115 | beg_grcc_vc = self.decoder.vc['beg_grcc'] 116 | end_grcc_vc = self.decoder.vc['end_grcc'] 117 | 118 | # naming: (d: decoder, e: encoder, u: upsample), (o: output, i:input) 119 | do = vconv.GridRange((0, 100000), (0, w), 1) 120 | di = vconv.input_range(beg_grcc_vc, end_grcc_vc, do) 121 | ei = vconv.input_range(mfcc_vc, end_grcc_vc, do) 122 | mi = vconv.input_range(mfcc_vc.child, end_grcc_vc, do) 123 | eo = vconv.output_range(mfcc_vc, end_enc_vc, ei) 124 | uo = vconv.output_range(mfcc_vc, end_ups_vc, ei) 125 | 126 | # Needed for trimming various tensors 127 | self.enc_in_len = ei.sub_length() 128 | self.enc_in_mel_len = mi.sub_length() 129 | # used by jitter_index 130 | self.embed_len = eo.sub_length() 131 | 132 | # sets size for wav_dec_in 133 | self.dec_in_len = di.sub_length() 134 | 135 | # trims wav_enc_input to wav_dec_input 136 | self.trim_dec_in = torch.tensor([di.sub[0] - ei.sub[0], di.sub[1] - 137 | ei.sub[0]], dtype=torch.long) 138 | 139 | # needed by wavenet to trim upsampled local conditioning tensor 140 | self.decoder.trim_ups_out = torch.tensor([di.sub[0] - uo.sub[0], 141 | di.sub[1] - uo.sub[0]], dtype=torch.long) 142 | 143 | # 144 | self.trim_dec_out = torch.tensor( 145 | [do.sub[0] - di.sub[0], do.sub[1] - di.sub[0]], 146 | dtype=torch.long) 147 | 148 | def print_geometry(self): 149 | """ 150 | Print the convolutional geometry 151 | """ 152 | vc = self.encoder.vc['beg'].parent 153 | while vc is not None: 154 | print(vc) 155 | vc = vc.child 156 | 157 | 158 | def __getstate__(self): 159 | state = { 160 | 'init_args': self.init_args, 161 | # 'state_dict': self.state_dict() 162 | } 163 | return state 164 | 165 | def __setstate__(self, state): 166 | self.init_args = state['init_args'] 167 | self._initialize() 168 | # self.load_state_dict(state['state_dict']) 169 | 170 | 171 | def init_codebook(self, data_source, n_samples): 172 | """ 173 | Initialize the VQ Embedding with samples from the encoder 174 | """ 175 | if self.bn_type not in ('vqvae', 'vqvae-ema'): 176 | raise RuntimeError('init_vq_embed only applies to the vqvae model type') 177 | 178 | bn = self.bottleneck 179 | e = 0 180 | n_codes = bn.emb.shape[0] 181 | k = bn.emb.shape[1] 182 | samples = np.empty((n_samples, k), dtype=np.float) 183 | 184 | with torch.no_grad(): 185 | while e != n_samples: 186 | vbatch = next(data_source) 187 | encoding = self.encoder(vbatch.mel_enc_input) 188 | ze = self.bottleneck.linear(encoding) 189 | ze = ze.permute(0, 2, 1).flatten(0, 1) 190 | c = min(n_samples - e, ze.shape[0]) 191 | samples[e:e + c,:] = ze.cpu()[0:c,:] 192 | e += c 193 | 194 | km, __ = kmeans(samples, n_codes) 195 | bn.emb[...] 
= torch.from_numpy(km) 196 | 197 | if self.bn_type == 'vqvae-ema': 198 | bn.ema_numer = bn.emb * bn.ema_gamma_comp 199 | bn.ema_denom = bn.n_sum_ones * bn.ema_gamma_comp 200 | 201 | def checksum(self): 202 | """Return checksum of entire set of model parameters""" 203 | return util.tensor_digest(self.parameters()) 204 | 205 | 206 | def forward(self, mels, wav_dec, voice_inds, jitter_index): 207 | """ 208 | B: n_batch 209 | M: n_mels 210 | T: receptive field of autoencoder 211 | T': receptive field of decoder 212 | R: size of local conditioning output of encoder (T - encoder.vc.total()) 213 | N: n_win (# consecutive samples processed in one batch channel) 214 | Q: n_quant 215 | mels: (B, M, T) 216 | wav_compand: (B, T) 217 | wav_dec: (B, T') 218 | Outputs: 219 | quant_pred (B, Q, N) # predicted wav amplitudes 220 | """ 221 | encoding = self.encoder(mels) 222 | self.encoding_bn = self.bottleneck(encoding) 223 | quant = self.decoder(wav_dec, self.encoding_bn, voice_inds, 224 | jitter_index) 225 | return quant 226 | 227 | def run(self, vbatch): 228 | """ 229 | Run the model on one batch, returning the predicted and 230 | actual output 231 | B, T, Q: n_batch, n_timesteps, n_quant 232 | Outputs: 233 | quant_pred: (B, Q, T) (the prediction from the model) 234 | wav_batch_out: (B, T) (the actual data from the same timesteps) 235 | """ 236 | # Slice each wav input 237 | trim = self.trim_dec_out 238 | wav_batch_out = vbatch.wav_dec_input[:,trim[0]:trim[1]] 239 | # wav_batch_out = torch.take(vbatch.wav_dec_input, vbatch.loss_wav_slice) 240 | #for b, (sl_b, sl_e) in enumerate(vbatch.loss_wav_slice): 241 | # wav_batch_out[b] = vbatch.wav_dec_input[b,sl_b:sl_e] 242 | 243 | # self.wav_batch_out = wav_batch_out 244 | # self.wav_onehot_dec = wav_onehot_dec 245 | 246 | quant = self.forward(vbatch.mel_enc_input, vbatch.wav_dec_input, 247 | vbatch.voice_index, vbatch.jitter_index) 248 | 249 | pred, target = quant[...,:-1], wav_batch_out[...,1:] 250 | 251 | loss = self.objective(pred, target) 252 | ag_inputs = (vbatch.mel_enc_input, self.encoding_bn) 253 | mel_grad, bn_grad = torch.autograd.grad(loss, ag_inputs, retain_graph=True) 254 | self.objective.metrics.update({ 255 | 'mel_grad_sd': mel_grad.std(), 256 | 'bn_grad_sd': bn_grad.std() 257 | }) 258 | # loss.backward(create_graph=True, retain_graph=True) 259 | return pred, target, loss 260 | 261 | -------------------------------------------------------------------------------- /chassis.py: -------------------------------------------------------------------------------- 1 | from sys import stderr 2 | import torch as t 3 | from tensorboardX import SummaryWriter 4 | # this SummaryWriter doesn't work with torch_xla, causes crash 5 | # from torch.utils.tensorboard import SummaryWriter 6 | import data 7 | import autoencoder_model as ae 8 | import mfcc_inverter as mi 9 | import checkpoint as ckpt 10 | import util 11 | import netmisc 12 | import librosa 13 | import os.path 14 | import time 15 | 16 | try: 17 | import torch_xla 18 | import torch_xla.core.xla_model as xm 19 | import torch_xla.distributed.parallel_loader as pl 20 | except ModuleNotFoundError: 21 | pass 22 | 23 | 24 | class GPULoaderIter(object): 25 | def __init__(self, loader, device): 26 | self.loader_iter = iter(loader) 27 | self.device = device 28 | 29 | def __iter__(self): 30 | return self 31 | 32 | def __next__(self): 33 | items = next(self.loader_iter) 34 | return tuple(item.to(self.device) if isinstance(item, t.Tensor) else 35 | item for item in items) 36 | 37 | 38 | def reduce_add(vlist): 39 | 
return t.stack(vlist).sum(dim=0) 40 | 41 | def reduce_mean(vlist): 42 | return t.stack(vlist).mean(dim=0) 43 | 44 | class Chassis(object): 45 | """ 46 | Coordinates the construction of the model, dataset, optimizer, 47 | checkpointing state, and GPU/TPU iterator wrappers. 48 | 49 | Provides a single function for training the model from the constructed 50 | setup. 51 | 52 | """ 53 | def __init__(self, device, index, hps, dat_file): 54 | self.is_tpu = (hps.hw in ('TPU', 'TPU-single')) 55 | if self.is_tpu: 56 | num_replicas = xm.xrt_world_size() 57 | rank = xm.get_ordinal() 58 | elif hps.hw == 'GPU': 59 | if not t.cuda.is_available(): 60 | raise RuntimeError('GPU requested but not available') 61 | num_replicas = 1 62 | rank = 0 63 | elif hps.hw == 'CPU': 64 | num_replicas = 1 65 | rank = 0 66 | else: 67 | raise ValueError(f'Chassis: Invalid device "{hps.hw}" requested') 68 | 69 | self.replica_index = index 70 | 71 | self.state = ckpt.Checkpoint(hps, dat_file, train_mode=True, 72 | ckpt_file=hps.get('ckpt_file', None), 73 | num_replicas=num_replicas, rank=rank) 74 | 75 | hps = self.state.hps 76 | if not self.is_tpu or xm.is_master_ordinal(): 77 | print('Hyperparameters:\n', file=stderr) 78 | print('\n'.join(f'{k} = {v}' for k, v in hps.items()), file=stderr) 79 | 80 | self.learning_rates = dict(zip(hps.learning_rate_steps, 81 | hps.learning_rate_rates)) 82 | 83 | if self.state.model.bn_type == 'vae': 84 | self.anneal_schedule = dict(zip(hps.bn_anneal_weight_steps, 85 | hps.bn_anneal_weight_vals)) 86 | 87 | self.ckpt_path = util.CheckpointPath(hps.ckpt_template, not self.is_tpu 88 | or xm.is_master_ordinal()) 89 | 90 | self.softmax = t.nn.Softmax(1) # input to this is (B, Q, N) 91 | self.hw = hps.hw 92 | 93 | if hps.hw == 'GPU': 94 | self.device_loader = GPULoaderIter(self.state.data.loader, device) 95 | self.state.to(device) 96 | else: 97 | para_loader = pl.ParallelLoader(self.state.data.loader, [device]) 98 | self.device_loader = para_loader.per_device_loader(device) 99 | self.num_devices = xm.xrt_world_size() 100 | self.state.to(device) 101 | 102 | self.state.init_torch_generator() 103 | 104 | if not self.is_tpu or xm.is_master_ordinal(): 105 | self.writer = SummaryWriter(log_dir=hps.log_dir) 106 | else: 107 | self.writer = None 108 | 109 | def train(self): 110 | hps = self.state.hps 111 | ss = self.state 112 | current_stats = {} 113 | writer_stats = {} 114 | 115 | # for resuming the learning rate 116 | sorted_lr_steps = sorted(self.learning_rates.keys()) 117 | lr_index = util.greatest_lower_bound(sorted_lr_steps, ss.data.global_step) 118 | ss.update_learning_rate(self.learning_rates[sorted_lr_steps[lr_index]]) 119 | 120 | if ss.model.bn_type != 'none': 121 | sorted_as_steps = sorted(self.anneal_schedule.keys()) 122 | as_index = util.greatest_lower_bound(sorted_as_steps, 123 | ss.data.global_step) 124 | ss.model.objective.update_anneal_weight(self.anneal_schedule[sorted_as_steps[as_index]]) 125 | 126 | if ss.model.bn_type in ('vqvae', 'vqvae-ema'): 127 | ss.model.init_codebook(self.data_iter, 10000) 128 | 129 | start_time = time.time() 130 | 131 | for batch_num, batch in enumerate(self.device_loader): 132 | wav, mel, voice, jitter, position = batch 133 | global_step = len(ss.data.dataset) * position[0] + position[1] 134 | 135 | # print(f'replica {self.replica_index}, batch {batch_num}', file=stderr) 136 | # stderr.flush() 137 | if (batch_num % hps.save_interval == 0 and batch_num != 0): 138 | self.save_checkpoint(position) 139 | 140 | if hps.skip_loop_body: 141 | continue 142 | 143 | 
lr_index = util.greatest_lower_bound(sorted_lr_steps, global_step) 144 | ss.update_learning_rate(self.learning_rates[sorted_lr_steps[lr_index]]) 145 | # if ss.data.global_step in self.learning_rates: 146 | # ss.update_learning_rate(self.learning_rates[ss.data.global_step]) 147 | 148 | if ss.model.bn_type == 'vae' and ss.step in self.anneal_schedule: 149 | ss.model.objective.update_anneal_weight(self.anneal_schedule[ss.data.global_step]) 150 | 151 | ss.optim.zero_grad() 152 | quant, self.target, loss = self.state.model.run(wav, mel, voice, jitter) 153 | self.probs = self.softmax(quant) 154 | self.mel_enc_input = mel 155 | # print(f'after model.run', file=stderr) 156 | # stderr.flush() 157 | loss.backward() 158 | 159 | # print(f'after loss.backward()', file=stderr) 160 | # stderr.flush() 161 | 162 | if batch_num % hps.progress_interval == 0: 163 | pars_copy = [p.data.clone() for p in ss.model.parameters()] 164 | 165 | # print(f'after pars_copy', file=stderr) 166 | # stderr.flush() 167 | 168 | if self.is_tpu: 169 | xm.optimizer_step(ss.optim) 170 | else: 171 | ss.optim.step() 172 | 173 | ss.optim_step += 1 174 | 175 | if ss.model.bn_type == 'vqvae-ema' and ss.data.global_step == 10000: 176 | ss.model.bottleneck.update_codebook() 177 | 178 | tprb_m = self.avg_prob_target() 179 | 180 | if batch_num % hps.progress_interval == 0: 181 | iterator = zip(pars_copy, ss.model.named_parameters()) 182 | uw_ratio = { np[0]: t.norm(c - np[1].data) / c.norm() for c, np 183 | in iterator } 184 | 185 | writer_stats.update({ 'uwr': uw_ratio }) 186 | 187 | if self.is_tpu: 188 | count = torch_xla._XLAC._xla_get_replication_devices_count() 189 | loss_red, tprb_red = xm.all_reduce('sum', [loss, tprb_m], 190 | scale=1.0 / count) 191 | # loss_red = xm.all_reduce('all_loss', loss, reduce_mean) 192 | # tprb_red = xm.all_reduce('all_tprb', tprb_m, reduce_mean) 193 | else: 194 | loss_red = loss 195 | tprb_red = tprb_m 196 | 197 | writer_stats.update({ 198 | 'loss_r': loss_red, 199 | 'tprb_r': tprb_red, 200 | 'optim_step': ss.optim_step 201 | }) 202 | 203 | 204 | current_stats.update({ 205 | 'optim_step': ss.optim_step, 206 | 'gstep': global_step, 207 | # 'gstep': ss.data.global_step, 208 | 'epoch': position[0], 209 | 'step': position[1], 210 | # 'loss': loss, 211 | 'lrate': ss.optim.param_groups[0]['lr'], 212 | # 'tprb_m': tprb_m, 213 | # 'pk_d_m': avg_peak_dist 214 | }) 215 | current_stats.update(ss.model.objective.metrics) 216 | 217 | if ss.model.bn_type in ('vae'): 218 | current_stats['free_nats'] = ss.model.objective.free_nats 219 | current_stats['anneal_weight'] = \ 220 | ss.model.objective.anneal_weight.item() 221 | 222 | if ss.model.bn_type in ('vqvae', 'vqvae-ema', 'ae', 'vae'): 223 | current_stats.update(ss.model.encoder.metrics) 224 | 225 | if self.is_tpu: 226 | xm.add_step_closure( 227 | self.train_update, 228 | args=(writer_stats, current_stats)) 229 | else: 230 | self.train_update(writer_stats, current_stats) 231 | 232 | # if not self.is_tpu or xm.is_master_ordinal(): 233 | # if batch_num in range(25, 50) or batch_num in range(75, 100): 234 | stderr.flush() 235 | elapsed = time.time() - start_time 236 | # print(f'{elapsed}, worker {self.replica_index}, batch {batch_num}', file=stderr) 237 | # stderr.flush() 238 | 239 | def train_update(self, writer_stats, stdout_stats): 240 | if self.replica_index == 0: 241 | netmisc.print_metrics(stdout_stats, self.replica_index, 100) 242 | if self.writer: 243 | self.writer.add_scalars('metrics', { k: writer_stats[k].item() for k 244 | in ('loss_r', 'tprb_r') }, 
writer_stats['optim_step']) 245 | 246 | self.writer.add_scalars('uw ratio', writer_stats['uwr'], writer_stats['optim_step']) 247 | self.writer.flush() 248 | 249 | 250 | def save_checkpoint(self, position): 251 | global_step = len(self.state.data.dataset) * position[0] + position[1] 252 | ckpt_file = self.ckpt_path.path(global_step.item()) 253 | self.state.save(ckpt_file, position[0], position[1]) 254 | 255 | if not self.is_tpu or xm.is_master_ordinal(): 256 | print('Saved checkpoint to {}'.format(ckpt_file), file=stderr) 257 | stderr.flush() 258 | 259 | def avg_max(self): 260 | """Average max value for the predictions. As the prediction becomes 261 | more peaked, this should go up""" 262 | max_val, max_ind = t.max(self.probs, dim=1) 263 | mean = t.mean(max_val) 264 | return mean 265 | 266 | def avg_prob_target(self): 267 | """Average probability given to target""" 268 | target_probs = t.gather(self.probs, 1, self.target.long().unsqueeze(1)) 269 | mean = t.mean(target_probs) 270 | return mean 271 | 272 | 273 | class DataContainer(t.nn.Module): 274 | def __init__(self, my_values): 275 | super().__init__() 276 | for key in my_values: 277 | setattr(self, key, my_values[key]) 278 | 279 | def forward(self): 280 | pass 281 | 282 | 283 | class InferenceChassis(object): 284 | """ 285 | Coordinates construction of model and dataset for running inference 286 | """ 287 | def __init__(self, device, index, hps, dat_file): 288 | self.output_dir = hps.output_dir 289 | self.n_replicas = hps.dec_n_replicas 290 | try: 291 | self.data_write_tmpl = hps.data_write_tmpl 292 | except AttributeError: 293 | self.data_write_tmpl = None 294 | 295 | self.state = ckpt.InferenceState(hps, dat_file, hps.ckpt_file) 296 | self.state.model.wavenet.set_n_replicas(self.n_replicas) 297 | self.state.model.eval() 298 | self.sample_rate = hps.sample_rate 299 | 300 | if hps.hw in ('GPU', 'CPU'): 301 | self.device_loader = GPULoaderIter(self.state.data.loader, device) 302 | self.state.to(device) 303 | else: 304 | import torch_xla.core.xla_model as xm 305 | import torch_xla.distributed.parallel_loader as pl 306 | para_loader = pl.ParallelLoader(self.state.data.loader, [device]) 307 | self.device_loader = para_loader.per_device_loader(device) 308 | self.num_devices = xm.xrt_world_size() 309 | self.state.to(device) 310 | 311 | def infer(self, model_scr=None): 312 | n_quant = self.state.model.wavenet.n_quant 313 | 314 | for batch in self.device_loader: 315 | wav, mel, voice_idx, jitter_idx, file_paths, position = batch 316 | if self.data_write_tmpl: 317 | dc = t.jit.script(DataContainer({ 318 | 'mel': mel, 319 | 'wav': wav, 320 | 'voice': voice_idx, 321 | 'jitter': jitter_idx 322 | })) 323 | dc.save(self.data_write_tmpl) 324 | print('saved {}'.format(self.data_write_tmpl)) 325 | 326 | out_template = os.path.join(self.output_dir, 327 | os.path.basename(os.path.splitext(file_paths[0])[0]) 328 | + '.{}.wav') 329 | 330 | if model_scr: 331 | with t.no_grad(): 332 | wav = model_scr(wav, mel, voice_idx, jitter_idx) 333 | else: 334 | wav = self.state.model(wav, mel, voice_idx, jitter_idx) 335 | 336 | wav_orig, wav_sample = wav[0,...], wav[1:,...] 
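            # wav_orig (row 0) is the original waveform; wav_sample (rows
            # 1..n_replicas) are the sampled reconstructions. Both are
            # mu-decoded below and written out as 'orig' and 'rep<i>'.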
337 | 338 | # save results to specified files 339 | for i in range(self.n_replicas): 340 | wav_final = util.mu_decode_torch(wav_sample[i], n_quant) 341 | path = out_template.format('rep' + str(i)) 342 | librosa.output.write_wav(path, wav_final.cpu().numpy(), self.sample_rate) 343 | 344 | wav_final = util.mu_decode_torch(wav_orig, n_quant) 345 | path = out_template.format('orig') 346 | librosa.output.write_wav(path, wav_final.cpu().numpy(), self.sample_rate) 347 | 348 | print('Wrote {}'.format( 349 | out_template.format('0-'+str(self.n_replicas-1)))) 350 | 351 | -------------------------------------------------------------------------------- /checkpoint.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import io 3 | import torch as t 4 | from sys import stderr 5 | import util 6 | import data 7 | import mfcc_inverter as mi 8 | import hparams 9 | 10 | try: 11 | import torch_xla.core.xla_model as xm 12 | except ModuleNotFoundError: 13 | pass 14 | 15 | class Checkpoint(object): 16 | ''' 17 | Encapsulates full state of training 18 | ''' 19 | 20 | def __init__(self, override_hps, dat_file, train_mode=True, ckpt_file=None, 21 | num_replicas=1, rank=0): 22 | """ 23 | Initialize total state 24 | """ 25 | if ckpt_file is not None: 26 | ckpt = t.load(ckpt_file) 27 | if 'hps' in ckpt: 28 | hps = hparams.Hyperparams(**ckpt['hps']) 29 | else: 30 | hps = hparams.Hyperparams() 31 | hps.update(override_hps) 32 | 33 | t.manual_seed(hps.random_seed) 34 | 35 | if hps.global_model == 'autoencoder': 36 | self.model = ae.AutoEncoder(hps) 37 | elif hps.global_model == 'mfcc_inverter': 38 | self.model = mi.MfccInverter(hps) 39 | 40 | slice_size = self.model.get_input_size(hps.n_win_batch) 41 | 42 | self.data = data.DataProcessor(hps, dat_file, self.model.mfcc, 43 | slice_size, train_mode, start_epoch=0, start_step=0, 44 | num_replicas=num_replicas, rank=rank) 45 | 46 | self.model.override(hps.n_win_batch) 47 | 48 | if ckpt_file is None: 49 | self.optim = t.optim.Adam(params=self.model.parameters(), 50 | lr=hps.learning_rate_rates[0]) 51 | self.optim_step = 0 52 | 53 | else: 54 | sub_state = { k: v for k, v in ckpt['model_state_dict'].items() if 55 | '_lead' not in k and 'left_wing_size' not in k } 56 | self.model.load_state_dict(sub_state, strict=False) 57 | if 'epoch' in ckpt: 58 | self.data.dataset.set_pos(ckpt['epoch'], ckpt['step']) 59 | else: 60 | global_step = ckpt['step'] 61 | epoch = global_step // len(self.data.dataset) 62 | step = global_step % len(self.data.dataset) 63 | self.data.dataset.set_pos(epoch, step) 64 | 65 | self.optim = t.optim.Adam(self.model.parameters()) 66 | self.optim.load_state_dict(ckpt['optim']) 67 | self.optim_step = ckpt['optim_step'] 68 | # self.torch_rng_state = ckpt['rand_state'] 69 | # self.torch_cuda_rng_states = ckpt['cuda_rand_states'] 70 | 71 | 72 | self.device = None 73 | self.torch_rng_state = t.get_rng_state() 74 | if t.cuda.is_available(): 75 | self.torch_cuda_rng_states = t.cuda.get_rng_state_all() 76 | else: 77 | self.torch_cuda_rng_states = None 78 | 79 | self.hps = hps 80 | 81 | 82 | def save(self, ckpt_file, epoch, step): 83 | # cur_device = self.device 84 | old_device = self.to(t.device('cpu')) 85 | mstate_dict = self.model.state_dict() 86 | ostate = self.optim.state_dict() 87 | state = { 88 | 'hps': self.hps, 89 | 'epoch': epoch, 90 | 'step': step, 91 | 'optim_step': self.optim_step, 92 | 'model_state_dict': mstate_dict, 93 | 'optim': ostate, 94 | 'rand_state': t.get_rng_state(), 95 | 'cuda_rand_states': 
(t.cuda.get_rng_state_all() if 96 | t.cuda.is_available() else None) 97 | } 98 | if self.hps.hw in ('GPU', 'CPU'): 99 | t.save(state, ckpt_file) 100 | else: 101 | xm.save(state, ckpt_file, master_only=True) 102 | self.to(old_device) 103 | 104 | def to(self, device): 105 | """Hack to move both model and optimizer to device""" 106 | old_device = self.device 107 | self.device = device 108 | self.model.to(device) 109 | ostate = self.optim.state_dict() 110 | self.optim = t.optim.Adam(self.model.parameters()) 111 | self.optim.load_state_dict(ostate) 112 | return old_device 113 | 114 | def optim_checksum(self): 115 | return util.digest(self.optim.state_dict()) 116 | 117 | def init_torch_generator(self): 118 | """Hack to set the generator state""" 119 | t.set_rng_state(self.torch_rng_state) 120 | #print('saved generator state: {}'.format( 121 | # util.tensor_digest(self.torch_cuda_rng_states))) 122 | #t.cuda.set_rng_state_all(self.torch_cuda_rng_states) 123 | if t.cuda.is_available(): 124 | if self.torch_cuda_rng_states is not None: 125 | t.cuda.set_rng_state(self.torch_cuda_rng_states[0]) 126 | ndiff = t.cuda.get_rng_state().ne(self.torch_cuda_rng_states[0]).sum() 127 | if ndiff != 0: 128 | print(('Warning: restored and checkpointed ' 129 | 'GPU state differs in {} positions').format(ndiff), file=stderr) 130 | stderr.flush() 131 | 132 | def update_learning_rate(self, learning_rate): 133 | for g in self.optim.param_groups: 134 | g['lr'] = learning_rate 135 | 136 | 137 | class InferenceState(object): 138 | """ 139 | Restores a trained model for inference 140 | """ 141 | 142 | def __init__(self, override_hps, dat_file, ckpt_file): 143 | 144 | ckpt = t.load(ckpt_file) 145 | if 'hps' in ckpt: 146 | hps = hparams.Hyperparams(**ckpt['hps']) 147 | hps.update(override_hps) 148 | 149 | if hps.global_model == 'autoencoder': 150 | self.model = ae.AutoEncoder(hps) 151 | elif hps.global_model == 'mfcc_inverter': 152 | self.model = mi.MfccInverter(hps) 153 | 154 | sub_state = { k: v for k, v in ckpt['model_state_dict'].items() if 155 | '_lead' not in k and 'left_wing_size' not in k } 156 | self.model.load_state_dict(sub_state, strict=False) 157 | self.model.override(n_win_batch=1) 158 | 159 | self.data = data.DataProcessor(hps, dat_file, self.model.mfcc, 160 | slice_size=None, train_mode=False) 161 | 162 | self.device = None 163 | 164 | def to(self, device): 165 | self.device = device 166 | self.model.to(device) 167 | 168 | -------------------------------------------------------------------------------- /dat/example_train.log: -------------------------------------------------------------------------------- 1 | nohup python train.py new -nb 32 -af par/arch.basic.json -tf par/train.basic.json -si 1000 -rws 1000000 -fpu 1.0 /mnt/data/ckpt/ae-wavenet/full%.ckpt /mnt/data/librispeech.dev-clean.rdb > ../../full.log 2>& 1 & 2 | nohup: ignoring input 3 | Initializing model parameters 4 | Starting training... 
5 | Step Loss AvgProbTarget PeakDist AvgMax 6 | 0 6.66073 0.00394 65.54933 0.09090 7 | 10 5.70685 0.00496 49.70761 0.03354 8 | 20 5.54577 0.00504 47.05232 0.02857 9 | 30 5.43756 0.00510 47.42469 0.01912 10 | 40 5.39568 0.00528 46.57148 0.02194 11 | 50 5.43824 0.00505 47.94815 0.03081 12 | 60 5.22184 0.00660 39.68941 0.03081 13 | 70 5.04230 0.00778 37.66391 0.04293 14 | 80 5.01674 0.00807 34.91858 0.05784 15 | 90 5.01576 0.00638 41.40374 0.03709 16 | 100 4.82980 0.00700 42.97318 0.03141 17 | 110 4.61712 0.01027 34.41427 0.05701 18 | 120 4.65827 0.00853 41.67912 0.05928 19 | 130 4.69492 0.00709 43.26987 0.03658 20 | 140 4.53711 0.00957 38.61590 0.05383 21 | 150 4.62835 0.00739 37.57878 0.05389 22 | 160 5.03130 0.00572 47.09746 0.02914 23 | 170 4.58089 0.00890 37.63877 0.05653 24 | 180 4.50866 0.01026 32.46947 0.04569 25 | 190 4.48754 0.00953 34.47150 0.06440 26 | 200 4.65966 0.00822 38.55568 0.05021 27 | 210 4.63757 0.00766 47.55101 0.05572 28 | 220 4.42633 0.00879 42.35860 0.04989 29 | 230 4.29379 0.01204 34.89320 0.07602 30 | 240 4.37704 0.01092 39.20797 0.07502 31 | 250 4.28365 0.01019 35.05424 0.07012 32 | 260 4.75136 0.00561 47.53735 0.02373 33 | 270 4.47797 0.00601 42.63039 0.03843 34 | 280 4.41622 0.00568 43.76664 0.02510 35 | 290 4.56792 0.00602 41.12871 0.03287 36 | 300 4.36622 0.00586 44.38063 0.03196 37 | 310 4.39905 0.00558 43.79729 0.02890 38 | 320 4.22855 0.00618 38.92936 0.03909 39 | 330 4.12058 0.00713 34.71312 0.05359 40 | 340 3.98082 0.00675 35.73479 0.05900 41 | 350 3.84604 0.00731 28.41164 0.05512 42 | 360 4.15422 0.00547 44.91571 0.04695 43 | 370 3.99540 0.00688 37.48479 0.05936 44 | 380 4.03482 0.00623 31.56238 0.05785 45 | 390 3.97674 0.00579 23.76999 0.06681 46 | 400 3.42920 0.00710 10.98240 0.10305 47 | 410 3.58615 0.01038 15.36434 0.12466 48 | 420 3.76125 0.00750 16.68941 0.11072 49 | 430 3.60955 0.00704 16.94899 0.10144 50 | 440 3.08182 0.01285 10.10728 0.14432 51 | 450 3.70023 0.00537 13.85010 0.07373 52 | 460 3.38741 0.00754 11.04873 0.11829 53 | 470 3.44488 0.00572 12.23707 0.09205 54 | 480 3.21584 0.00569 8.11255 0.09598 55 | 490 3.35601 0.00588 11.77981 0.12003 56 | 500 3.25750 0.00527 10.38757 0.10333 57 | 510 3.33255 0.00467 12.84028 0.08828 58 | 520 3.16834 0.00590 5.73348 0.09262 59 | 530 3.20742 0.00781 8.48767 0.10524 60 | 540 3.38442 0.00780 9.18702 0.10728 61 | 550 3.29940 0.00632 12.48779 0.11790 62 | 560 3.30615 0.00544 11.53520 0.10171 63 | 570 3.24176 0.00577 12.32447 0.10986 64 | 580 3.18723 0.00596 9.05520 0.12297 65 | 590 2.83200 0.00605 5.10560 0.13738 66 | 600 2.98084 0.00712 7.83776 0.14564 67 | 610 2.71847 0.00974 6.09064 0.18782 68 | 620 2.60462 0.01193 4.55316 0.21492 69 | 630 2.56853 0.01185 6.30951 0.21405 70 | 640 3.02742 0.00725 8.61039 0.14214 71 | 650 2.52465 0.00804 3.63374 0.18722 72 | 660 2.44871 0.00659 3.27610 0.19042 73 | 670 2.75736 0.00647 4.63063 0.15198 74 | 680 3.09209 0.00600 9.12428 0.14674 75 | 690 3.03678 0.00598 9.00970 0.13737 76 | 700 2.68628 0.00599 4.88578 0.15989 77 | 710 3.04777 0.00545 8.42217 0.12913 78 | 720 3.04865 0.00568 8.44540 0.14802 79 | 730 3.08064 0.00562 8.53616 0.14018 80 | 740 3.18428 0.00902 9.89943 0.12933 81 | 750 2.88879 0.01007 6.12201 0.13028 82 | 760 2.68636 0.00871 4.67301 0.17382 83 | 770 3.03822 0.00585 8.68355 0.12803 84 | 780 2.84694 0.00625 7.02359 0.14687 85 | 790 2.80414 0.00617 6.16020 0.16501 86 | 800 3.01958 0.00561 7.26880 0.14766 87 | 810 2.89497 0.00602 6.64691 0.15850 88 | 820 2.72111 0.00632 5.97689 0.16337 89 | 830 2.93951 0.00675 8.26796 0.14349 90 | 840 3.14187 0.00866 
8.63182 0.14626 91 | 850 3.13478 0.00553 8.35021 0.13230 92 | 860 3.03846 0.00920 8.96348 0.16501 93 | 870 3.19560 0.00748 9.75778 0.13621 94 | 880 3.14135 0.00719 8.95714 0.13182 95 | 890 3.00126 0.00599 6.54945 0.13239 96 | 900 2.91641 0.00561 6.87955 0.13830 97 | 910 2.94487 0.00594 6.99282 0.14684 98 | 920 3.08267 0.00683 8.64104 0.12598 99 | 930 3.07904 0.00531 8.42158 0.12234 100 | 940 3.04743 0.00539 8.69109 0.12576 101 | 950 3.23008 0.00606 8.84710 0.12495 102 | 960 3.12322 0.00542 8.80400 0.12674 103 | 970 3.11117 0.00729 7.81777 0.12648 104 | 980 3.18191 0.00658 9.66248 0.13055 105 | 990 2.76402 0.01067 6.31956 0.16577 106 | 1000 2.68206 0.00579 4.89368 0.16145 107 | Saved checkpoint to /mnt/data/ckpt/ae-wavenet/full1000.ckpt 108 | 1010 3.01859 0.00592 7.73563 0.13317 109 | 1020 3.04369 0.00583 8.96636 0.14326 110 | 1030 2.91307 0.00548 6.85656 0.14569 111 | 1040 2.87419 0.00724 7.15362 0.16057 112 | 1050 2.76969 0.00838 6.20438 0.16831 113 | 1060 2.94483 0.00670 8.89775 0.14244 114 | 1070 2.87404 0.00693 7.11829 0.16822 115 | 1080 2.77498 0.00886 5.56956 0.16203 116 | 1090 3.33083 0.00597 10.08393 0.11377 117 | 1100 3.18288 0.00591 9.46564 0.12195 118 | 1110 2.87133 0.01001 6.84842 0.17776 119 | 1120 2.79739 0.00662 6.84950 0.17672 120 | 1130 2.88968 0.00480 7.50335 0.15091 121 | 1140 2.61876 0.00825 8.11111 0.20580 122 | 1150 2.78869 0.00584 6.41786 0.17552 123 | 1160 2.98100 0.00703 7.92541 0.15671 124 | 1170 2.74401 0.00575 5.07196 0.16881 125 | 1180 2.67681 0.00882 5.53137 0.18611 126 | 1190 2.76825 0.00799 6.20306 0.18375 127 | 1200 2.79742 0.00704 7.63170 0.18185 128 | 1210 2.46790 0.00627 3.78927 0.19234 129 | 1220 2.90801 0.00728 8.20881 0.16132 130 | 1230 2.94563 0.00660 6.31334 0.15890 131 | 1240 3.01796 0.00650 7.23300 0.15346 132 | 1250 2.93731 0.00712 7.43499 0.16446 133 | 1260 3.25980 0.00537 10.42385 0.10912 134 | 1270 3.14164 0.00538 8.52323 0.13414 135 | 1280 2.81126 0.00542 5.84614 0.15412 136 | 1290 2.75833 0.00725 6.37057 0.17811 137 | 1300 2.53782 0.00976 5.23707 0.21064 138 | 1310 2.59489 0.00702 4.31298 0.17110 139 | 1320 2.86540 0.00625 6.86219 0.15155 140 | 1330 2.71670 0.00683 5.17648 0.17313 141 | 1340 2.75158 0.00577 5.43894 0.16158 142 | 1350 2.87513 0.01035 7.31082 0.18132 143 | 1360 2.61028 0.00953 5.10716 0.19037 144 | 1370 2.72034 0.00663 5.33321 0.16042 145 | 1380 2.48532 0.00928 4.43499 0.20115 146 | 1390 2.83902 0.00678 5.61147 0.16047 147 | 1400 2.63733 0.01322 5.32795 0.20977 148 | 1410 2.78052 0.01005 7.09447 0.18074 149 | 1420 2.60327 0.00967 4.64033 0.19721 150 | 1430 3.04716 0.00643 7.88733 0.14687 151 | 1440 2.74379 0.00738 6.05795 0.18900 152 | 1450 2.99852 0.00678 7.33118 0.14244 153 | 1460 3.06846 0.00662 7.91104 0.13758 154 | 1470 2.85562 0.00680 6.11267 0.15456 155 | 1480 2.72737 0.00593 6.75599 0.17853 156 | 1490 2.31471 0.00755 3.51557 0.21285 157 | 1500 2.62450 0.00694 4.73288 0.18410 158 | 1510 3.22426 0.00710 9.86794 0.12179 159 | 1520 3.28708 0.00624 11.16499 0.11534 160 | 1530 3.12092 0.00752 8.91595 0.14254 161 | 1540 3.04308 0.00757 6.85536 0.14137 162 | 1550 3.06093 0.00783 7.07974 0.13139 163 | 1560 3.23815 0.00725 8.66403 0.11636 164 | 1570 2.74154 0.00877 5.78963 0.17411 165 | 1580 2.42347 0.01071 4.75970 0.22790 166 | 1590 2.76834 0.00827 6.08166 0.17412 167 | 1600 2.77775 0.00729 7.38685 0.18044 168 | 1610 2.76105 0.00613 6.33285 0.16992 169 | 1620 2.72077 0.00841 6.74593 0.20177 170 | 1630 2.70875 0.00683 6.76664 0.18193 171 | 1640 2.93674 0.00829 7.46085 0.17790 172 | 1650 2.73079 0.00665 5.49054 0.16728 173 | 1660 
2.60637 0.00796 4.47102 0.18272 174 | 1670 2.59005 0.00896 6.81609 0.23257 175 | 1680 2.61570 0.00618 5.58178 0.19591 176 | 1690 2.39183 0.00828 3.72557 0.22694 177 | 1700 2.79557 0.00653 6.42565 0.17750 178 | 1710 3.11821 0.00635 7.95989 0.14060 179 | 1720 2.97524 0.00679 7.80747 0.16221 180 | 1730 2.85840 0.00591 6.71755 0.15744 181 | 1740 2.47712 0.01213 5.14116 0.21832 182 | 1750 2.88432 0.00767 7.10704 0.15732 183 | 1760 2.66427 0.00820 4.57399 0.17766 184 | 1770 3.01803 0.00582 7.88685 0.14526 185 | 1780 2.80344 0.00781 6.49306 0.16621 186 | 1790 2.73260 0.00570 6.61063 0.15813 187 | 1800 2.45691 0.00597 4.95474 0.20820 188 | 1810 2.81347 0.00633 6.31058 0.16622 189 | 1820 2.68905 0.00646 5.41104 0.17391 190 | 1830 2.55818 0.00708 5.58214 0.19659 191 | 1840 2.53600 0.01061 5.03784 0.21235 192 | 1850 2.47984 0.00880 4.96648 0.19531 193 | 1860 2.76543 0.01000 6.41403 0.18302 194 | 1870 2.57636 0.01034 4.38745 0.20862 195 | 1880 2.79167 0.00781 5.66798 0.16167 196 | 1890 2.62978 0.01026 5.07818 0.20027 197 | 1900 2.95898 0.00715 6.23467 0.13612 198 | 1910 2.75976 0.00647 6.91140 0.18266 199 | 1920 2.93959 0.00722 6.63446 0.15195 200 | 1930 2.86457 0.00791 5.42337 0.14320 201 | 1940 3.00291 0.00743 6.25431 0.13613 202 | 1950 2.90500 0.00734 5.90218 0.13987 203 | 1960 2.62112 0.00821 4.44995 0.18828 204 | 1970 2.49094 0.00929 3.86015 0.20534 205 | 1980 2.46517 0.00837 3.63805 0.20196 206 | 1990 2.32326 0.00668 3.29430 0.20933 207 | 2000 2.59924 0.00607 5.69864 0.19824 208 | Saved checkpoint to /mnt/data/ckpt/ae-wavenet/full2000.ckpt 209 | 2010 2.47777 0.00696 4.15374 0.19653 210 | 2020 2.64001 0.00590 6.59363 0.19983 211 | 2030 2.65116 0.00888 5.37512 0.18759 212 | 2040 2.72883 0.00593 6.57639 0.17616 213 | 2050 2.60458 0.00582 4.11135 0.17200 214 | 2060 2.62336 0.01420 5.46863 0.19670 215 | 2070 2.63927 0.00821 4.63625 0.19126 216 | 2080 2.55294 0.00923 4.49988 0.19589 217 | 2090 2.89332 0.00666 7.76808 0.17455 218 | 2100 2.68968 0.00602 5.60357 0.18078 219 | 2110 2.56659 0.00623 4.82627 0.19576 220 | 2120 3.11226 0.00724 8.40757 0.12526 221 | 2130 3.01763 0.00751 6.68858 0.12949 222 | 2140 2.97086 0.00793 6.17002 0.13460 223 | 2150 3.17475 0.00564 10.82076 0.12155 224 | 2160 3.61512 0.00616 11.06873 0.11316 225 | 2170 3.11458 0.00591 6.81322 0.11305 226 | 2180 3.41845 0.00628 10.56130 0.10135 227 | 2190 3.05680 0.00553 7.37213 0.13642 228 | 2200 2.71041 0.00584 5.92277 0.17291 229 | 2210 2.84167 0.00623 6.15888 0.15410 230 | 2220 2.58086 0.00751 6.17780 0.21406 231 | 2230 2.65336 0.00890 5.43774 0.19140 232 | 2240 2.70330 0.00787 5.45283 0.17768 233 | 2250 2.72741 0.00680 5.97019 0.17529 234 | 2260 2.63708 0.00878 6.49054 0.20948 235 | 2270 2.51103 0.00781 4.69097 0.21464 236 | 2280 2.71962 0.00687 6.83261 0.18552 237 | 2290 2.67264 0.00708 6.18977 0.18443 238 | 2300 2.74768 0.00727 8.11566 0.20395 239 | 2310 3.01899 0.00566 8.53676 0.14185 240 | 2320 2.98119 0.00572 8.53317 0.14350 241 | 2330 2.64680 0.00723 6.14871 0.21424 242 | 2340 2.69388 0.01163 5.68666 0.20296 243 | 2350 3.13728 0.00559 9.14009 0.13116 244 | 2360 3.24074 0.00551 9.66966 0.11507 245 | 2370 3.05343 0.00533 8.23755 0.12897 246 | 2380 2.60543 0.00573 4.72557 0.17856 247 | 2390 2.65020 0.00640 6.53855 0.18530 248 | 2400 2.66701 0.00574 5.77275 0.17889 249 | 2410 2.78377 0.00704 6.84339 0.17638 250 | 2420 2.78696 0.00556 7.35129 0.18201 251 | 2430 2.55613 0.00623 5.66056 0.20140 252 | 2440 2.94221 0.00688 9.38146 0.16041 253 | 2450 2.93039 0.00582 6.96336 0.14550 254 | 2460 3.01077 0.00577 8.90649 0.14891 255 | 2470 
2.85082 0.00562 6.69097 0.15403 256 | 2480 2.65768 0.00615 4.50587 0.17242 257 | 2490 2.48694 0.00609 5.07028 0.21959 258 | 2500 2.78177 0.00577 6.62356 0.17209 259 | 2510 3.25047 0.00603 8.13135 0.10615 260 | 2520 3.16453 0.00495 9.37213 0.11942 261 | 2530 3.04452 0.00764 8.83094 0.12263 262 | 2540 3.04047 0.00633 8.38985 0.14717 263 | 2550 3.10786 0.00576 8.58465 0.13274 264 | 2560 3.07376 0.00565 8.52814 0.13471 265 | 2570 2.52240 0.01181 6.08034 0.22120 266 | 2580 2.43037 0.00714 4.54430 0.21038 267 | 2590 2.58623 0.00716 5.79969 0.21061 268 | 2600 2.56126 0.00794 4.49306 0.20453 269 | 2610 2.52293 0.00658 3.78053 0.19078 270 | 2620 2.56859 0.00764 4.34507 0.19926 271 | 2630 2.76011 0.00712 6.52179 0.17587 272 | 2640 2.57426 0.01008 5.67792 0.22149 273 | 2650 2.74378 0.00638 5.22067 0.16952 274 | 2660 2.69279 0.00813 5.15398 0.19908 275 | 2670 2.76138 0.00539 6.13350 0.16482 276 | 2680 2.55628 0.00722 5.63434 0.22609 277 | 2690 2.55226 0.00690 5.63937 0.21755 278 | 2700 3.07587 0.00750 8.16212 0.12975 279 | -------------------------------------------------------------------------------- /dat/librispeech.some.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrbigelow/ae-wavenet/80b9c46637151f053f74728fc756f1d01ab0aa69/dat/librispeech.some.dat -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # Preprocess Data 2 | from sys import stderr, exit 3 | import pickle 4 | import numpy as np 5 | import torch as t 6 | from torch.utils.data import Dataset, DataLoader, Sampler, SequentialSampler 7 | import jitter 8 | from torch import nn 9 | import vconv 10 | import copy 11 | import parse_tools 12 | from hparams import setup_hparams 13 | from collections import namedtuple 14 | 15 | import util 16 | import mfcc 17 | 18 | 19 | def parse_catalog(sam_file): 20 | try: 21 | catalog = [] 22 | with open(sam_file) as sam_fh: 23 | for s in sam_fh.readlines(): 24 | (vid, wav_path) = s.strip().split('\t') 25 | catalog.append([int(vid), wav_path]) 26 | except (FileNotFoundError, IOError): 27 | raise RuntimeError("Couldn't open or read samples file {}".format(sam_file)) 28 | return catalog 29 | 30 | def convert(catalog, dat_file, n_quant, sample_rate=16000): 31 | """ 32 | Convert all input data and save a dat file 33 | """ 34 | import librosa 35 | if n_quant <= 2**8: 36 | snd_dtype = np.uint8 37 | elif n_quant <= 2**15: 38 | snd_dtype = np.int16 39 | else: 40 | snd_dtype = np.int32 41 | 42 | n_mel_chan = None 43 | speaker_ids = set(id for id,__ in catalog) 44 | speaker_id_map = dict((v,k) for k,v in enumerate(speaker_ids)) 45 | snd_data = np.empty((0), dtype=snd_dtype) 46 | samples = [] 47 | 48 | for (voice_id, snd_path) in catalog: 49 | snd, _ = librosa.load(snd_path, sample_rate) 50 | snd_mu = util.mu_encode_np(snd, n_quant).astype(snd_dtype) 51 | wav_b = len(snd_data) 52 | wav_e = wav_b + len(snd_mu) 53 | snd_data.resize(wav_e) 54 | snd_data[wav_b:wav_e] = snd_mu 55 | samples.append( 56 | SpokenSample( 57 | voice_index=speaker_id_map[voice_id], 58 | wav_b=wav_b, wav_e=wav_e, 59 | # mel_b=mel_b, mel_e=mel_e, 60 | file_path=snd_path 61 | ) 62 | ) 63 | if len(samples) % 100 == 0: 64 | print('Converted {} files of {}.'.format(len(samples), 65 | len(catalog), file=stderr)) 66 | stderr.flush() 67 | 68 | with open(dat_file, 'wb') as dat_fh: 69 | state = { 70 | 'samples': samples, 71 | 'snd_dtype': snd_dtype, 72 | 
'snd_data': snd_data 73 | } 74 | pickle.dump(state, dat_fh) 75 | 76 | 77 | 78 | SpokenSample = namedtuple('SpokenSample', [ 79 | 'voice_index', # index of the speaker for this sample 80 | 'wav_b', # start position of sample in full wav data buffer 81 | 'wav_e', # end position of sample in full wav data buffer 82 | 'file_path' # path to .wav file for this sample 83 | ] 84 | ) 85 | 86 | 87 | class LoopingRandomSampler(Sampler): 88 | def __init__(self, dataset, num_replicas=1, rank=0, start_epoch=0): 89 | super().__init__(dataset) 90 | self.dataset = dataset 91 | self.num_replicas = num_replicas 92 | self.rank = rank 93 | self.epoch = start_epoch 94 | print(f'LoopingRandomSampler with {self.rank} out of {self.num_replicas}', file=stderr) 95 | 96 | def __iter__(self): 97 | def _gen(): 98 | while True: 99 | g = t.Generator() 100 | g.manual_seed(self.epoch * self.num_replicas + self.rank) 101 | n = len(self.dataset) 102 | vals = list(range(self.rank, n, self.num_replicas)) 103 | perms = t.randperm(len(vals), generator=g).tolist() 104 | print(f'LoopingRandomSampler: first 10 perms: {perms[:10]}', 105 | file=stderr) 106 | indices = [vals[i] for i in perms] 107 | for i in indices: 108 | yield i 109 | self.epoch += 1 110 | 111 | return _gen() 112 | 113 | def __len__(self): 114 | return int(2**31) 115 | 116 | 117 | 118 | 119 | def load_data(dat_file): 120 | try: 121 | with open(dat_file, 'rb') as dat_fh: 122 | dat = pickle.load(dat_fh) 123 | except IOError: 124 | print(f'Could not open preprocessed data file {dat_file}.', file=stderr) 125 | stderr.flush() 126 | return dat 127 | 128 | class TrackerDataset(Dataset): 129 | """ 130 | Tracks and provides the epoch and step. 131 | If using with replicas and a subsetting sampler that samples 132 | 1/sampling_freq of the dataset 133 | """ 134 | def __init__(self, dataset, start_epoch=0, start_step=0, sampling_freq=1): 135 | self.dataset = dataset 136 | self.epoch = start_epoch 137 | self.step = start_step 138 | self.sampling_freq = sampling_freq 139 | self.len = None 140 | 141 | def __len__(self): 142 | if self.len is None: 143 | self.len = len(self.dataset) 144 | return self.len 145 | 146 | def __getitem__(self, item): 147 | self.step += self.sampling_freq 148 | if self.step >= len(self): 149 | self.epoch += 1 150 | self.step = 0 151 | return self.dataset[item], self.epoch, self.step 152 | 153 | def set_pos(self, epoch, step): 154 | self.epoch = epoch 155 | self.step = step 156 | 157 | 158 | class SliceDataset(Dataset): 159 | """ 160 | Return slices of wav files of fixed size 161 | """ 162 | def __init__(self, slice_size, n_win_batch): 163 | self.slice_size = slice_size 164 | self.n_win_batch = n_win_batch 165 | self.in_start = [] 166 | 167 | 168 | def load_data(self, dat_file): 169 | dat = load_data(dat_file) 170 | self.samples = dat['samples'] 171 | self.snd_data = dat['snd_data'].astype(dat['snd_dtype']) 172 | 173 | w = self.n_win_batch 174 | for sam in self.samples: 175 | for b in range(sam.wav_b, sam.wav_e - self.slice_size, w): 176 | self.in_start.append((b, sam.voice_index)) 177 | 178 | def num_speakers(self): 179 | ns = max(s.voice_index for s in self.samples) + 1 180 | return ns 181 | 182 | def __len__(self): 183 | return len(self.in_start) 184 | 185 | def __getitem__(self, item): 186 | s, voice_ind = self.in_start[item] 187 | return self.snd_data[s:s + self.slice_size], voice_ind 188 | 189 | 190 | 191 | class WavFileDataset(Dataset): 192 | """ 193 | Returns entire wav files 194 | """ 195 | def __init__(self): 196 | super().__init__() 197 | 198 | 
def load_data(self, dat_file): 199 | dat = load_data(dat_file) 200 | self.samples = dat['samples'] 201 | self.snd_data = dat['snd_data'].astype(dat['snd_dtype']) 202 | 203 | def num_speakers(self): 204 | ns = max(s.voice_index for s in self.samples) + 1 205 | return ns 206 | 207 | def __len__(self): 208 | return len(self.samples) 209 | 210 | def __getitem__(self, item): 211 | sam = self.samples[item] 212 | return (self.snd_data[sam.wav_b:sam.wav_e], 213 | sam.voice_index, 214 | sam.file_path) 215 | 216 | 217 | class Collate(): 218 | def __init__(self, mfcc, jitter, train_mode): 219 | self.train_mode = train_mode 220 | self.mfcc = mfcc 221 | self.jitter = jitter 222 | 223 | def __call__(self, batch): 224 | data = [b[0] for b in batch] 225 | 226 | # epoch, step 227 | position = t.tensor(batch[-1][1:]) 228 | 229 | wav = t.stack([t.from_numpy(d[0]) for d in data]).float() 230 | mel = t.stack([t.from_numpy(self.mfcc(d[0])) for d in 231 | data]).float() 232 | voice = t.tensor([d[1] for d in data]).long() 233 | jitter = t.stack([t.from_numpy(self.jitter(mel.size()[2])) for _ in 234 | range(len(data))]).long() 235 | 236 | if self.train_mode: 237 | return wav, mel, voice, jitter, position 238 | else: 239 | paths = [b[0][2] for b in batch] 240 | return wav, mel, voice, jitter, paths, position 241 | 242 | 243 | class DataProcessor(): 244 | def __init__(self, hps, dat_file, mfcc_func, slice_size, train_mode, 245 | start_epoch=0, start_step=0, num_replicas=1, rank=0): 246 | super().__init__() 247 | jitter_func = jitter.Jitter(hps.jitter_prob) 248 | 249 | train_collate_fn = Collate(mfcc_func, jitter_func, train_mode=True) 250 | test_collate_fn = Collate(mfcc_func, jitter_func, train_mode=False) 251 | 252 | if train_mode: 253 | slice_dataset = SliceDataset(slice_size, hps.n_win_batch) 254 | slice_dataset.load_data(dat_file) 255 | stderr.flush() 256 | self.dataset = TrackerDataset(slice_dataset, start_epoch, 257 | start_step, sampling_freq=num_replicas) 258 | self.sampler = LoopingRandomSampler(self.dataset, num_replicas, 259 | rank, start_epoch) 260 | self.loader = DataLoader(self.dataset, sampler=self.sampler, 261 | # If set >0, multiprocessing is used, which prevents 262 | # getting accurate position information 263 | num_workers=hps.n_loader_workers, 264 | batch_size=hps.n_batch, pin_memory=False, 265 | collate_fn=train_collate_fn) 266 | else: 267 | wav_dataset = WavFileDataset() 268 | wav_dataset.load_data(dat_file) 269 | self.dataset = TrackerDataset(wav_dataset, 0, 0) 270 | self.sampler = SequentialSampler(self.dataset) 271 | self.loader = DataLoader(self.dataset, batch_size=1, 272 | sampler=self.sampler, pin_memory=False, drop_last=False, 273 | collate_fn=test_collate_fn) 274 | 275 | @property 276 | def global_step(self): 277 | return len(self.dataset) * self.dataset.epoch + self.dataset.step 278 | 279 | """ 280 | @property 281 | def epoch(self): 282 | return self.dataset.epoch 283 | 284 | @property 285 | def step(self): 286 | return self.dataset.step 287 | """ 288 | 289 | 290 | 291 | -------------------------------------------------------------------------------- /doc/combining_vae_and_ar.txt: -------------------------------------------------------------------------------- 1 | The original formulation of the autoencoder is: 2 | 3 | Q(z|x) 4 | P(x|z) 5 | 6 | However, since WaveNet is autoregressive, this formulation doesn't fit. 7 | Rather, WaveNet models: 8 | 9 | P(x_t|x_1..x_(t-1),c_1..c_(t-1)) 10 | 11 | where c_t means "the local conditioning vector at time t". 
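
Equivalently, over a window x_1..x_T the joint likelihood factorizes as

    P(x_1..x_T | c) = prod_t P(x_t | x_1..x_(t-1), c_1..c_(t-1))

so the log-likelihood is a sum of per-timestep terms, which is what makes the
timestep batching discussed below possible.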
12 | 13 | I'm not sure how to reconcile the VAE formula with the Autoregressive structure 14 | of WaveNet. It's not clear what the paper does, nor are there many details 15 | that I could find in several references. Here is my best guess. 16 | 17 | Let's make the following shorthand, separating notation: 18 | 19 | x: x_t 20 | a: x_1..x_(t-1) 21 | z: c_1..c_(t-1) 22 | 23 | Then, we can write the encoder and decoder as: 24 | 25 | Q(z|x,a) 26 | P(x|z,a) 27 | 28 | I think that introducing a common condition 'a' for both the encoder and 29 | decoder doesn't change the essential theory. I don't understand VAE theory 30 | well enough to know this for sure though. Assuming this is okay, if we let x 31 | be just one timestep, then, in this particular architecture, z is {c_(t-2000), 32 | c_(t-1680), ..., c_(t-80)}, because 1) the encoder produces one conditioning 33 | vector every 320 timesteps, and 2) WaveNet's receptive field is the window 34 | [t-2000, t-1] roughly. In this formulation, then, for the ELBO, we model P(z) 35 | with a diagonal multivariate Gaussian N(0, I), which will have ~ 64 * 6 36 | dimensions (64 channels for each local conditioning vector). 37 | 38 | But, there is still the issue of how to combine multiple consecutive timesteps. 39 | Note that this training style is NOT equivalent to letting x = {x_t, ..., 40 | x_(t+b)} for b a batch size. If we did that, then P(x|z,a) would be a product 41 | of individual regressive steps. Instead, we want to average each of 42 | P(x_t|z,a), P(x_(t+1)|z,a), ... P(x_(t+b)|z,a) together for the SGD gradient 43 | calculation. 44 | 45 | Logically, it seems consistent to define z as the exact set of local 46 | conditioning vectors that fall within WaveNet's receptive field. Because of 47 | the one-every-320 steps frequency, there will be 320 distinct training examples 48 | that all use the same set of local conditioning vectors. Then, the next sample 49 | will drop the oldest one and pick up a new one, and use these for the next 319. 50 | 51 | The VAE training objective as given in the paper (equation 3) is: 52 | 53 | L = - log p(x | z_q(x)) 54 | + ||sg(z_e(x)) - e_(q(x))||^2 55 | + gamma * || z_e(x) - sg(e_q(x))||^2 56 | 57 | Here, the notation is a bit confusing: 58 | 59 | q(x) = argmin_i|| z_e(x) - e_i||^2 60 | z_e(x) means Q(z|x,a) in the above. 61 | 62 | I simply rewrite in their terms, but separating out the x into the current 63 | timestep x and the autoregressive context a. 64 | 65 | So: 66 | 67 | We now write z_e(x,a) instead of z_e(x), and: 68 | q(x,a) = argmin_i|| z_e(x,a) - e_i||^2 69 | 70 | L = -log p(x|z_q(x,a)) 71 | + ||sg(z_e(x,a)) - e_(q(x,a))||^2 72 | + gamma * ||z_e(x,a) - sg(e_q(x,a))||^2 73 | 74 | (NOTE: I've added a minus sign to the first term. They say the first term is 75 | the "negative log-likelihood of the reconstruction", so I believe there should 76 | be a minus sign. It is given as log p(x | z_q(x)) in the paper and in the 77 | VQ-VAE paper. I think this is a mistake) 78 | 79 | Back up a bit, and consider how WaveNet is trained alone. A single logical sample 80 | from WaveNet is sampling a 256-way softmax representing the amplitude at a single 81 | timestep. This is supervised by providing the one-hot encoded target amplitude, 82 | and minimizing the cross-entropy H(t,p), where t is the one-hot target and p is the 83 | 256-way softmax. The next timestep overlaps almost completely in the activations 84 | across the stacks. 
So, instead of re-calculating all of that, timesteps are batched 85 | together, and the cross-entropies are individually taken at each timestep, H(t_i,p_i) 86 | of the batch i=(1..b). 87 | 88 | This technique seems to be nothing more than a caching mechanism combined with 89 | a batching mechanism, in which just another dimension of every calculation is 90 | used to batch all calculations of activations and gradients, just like the 91 | typical batching. The only tricky thing is that the second batch dimension 92 | (the time dimension) happens also to overlap with the caching. 93 | 94 | Would this form of timestep batching also work for the full VQ-VAE? What are the 95 | requirements of a loss function in order for such a thing to work? 96 | 97 | I believe it will work naturally as long as all of the loss terms are properly 98 | propagated forward. So, how does the decoder receive its conditioning inputs? 99 | 100 | 1. The Cross Entropy Loss is averaged across all timesteps. 101 | 2. Each individual CE loss term draws its input down through WaveNet's 102 | dilated stack, which reaches further back in time at each level. 103 | 3. Each level receives the upsampled conditioning vectors 104 | 4. The upsampling module draws on lower-frequency conditioning vectors 105 | 106 | So, propagating the derivatives of each CE loss term automatically accumulates 107 | the accumulated gradients to each of these vectors. 108 | 109 | The question is 110 | 111 | Rewriting this in the separating notation 112 | 113 | L = - log p(x | z_q(x,a)) 114 | 115 | Note that in this expression, we can batch together overlapping contexts. 116 | 117 | If we let: 118 | 119 | x_1, a_1 = x_t, x_1..x_(t-1) 120 | x_2, a_2 = x_(t+1), x_2..x_t 121 | ... 122 | x_b, a_b = x_(t+b-1), x_b..x_(t+b-2) 123 | 124 | for some batch size b 125 | 126 | For the SGD calculation, we need to compute the gradient of: 127 | 128 | grad J_VAE(theta, phi, x_1, a_1) + 129 | grad J_VAE(theta, phi, x_2, a_2) + 130 | ... 131 | grad J_VAE(theta, phi, x_b, a_b) 132 | 133 | Note that we can calculatej 134 | But, note that the first term, E_q(z|x,a;phi)[log p(x|z,a; theta)], consists 135 | only of 136 | 137 | When averaging gradients through the model for each of these timesteps, note 138 | that the first term (the expectation) the appropriate set of embedding vectors 139 | will automatically be discovered due to the connectivity. But, for the 140 | KL-divergence term, the number of times each local vector is used will follow a 141 | truncated triangle pattern. The full divergence penalty should be the 142 | per-sample average, where each 320 samples use a different set of vectors. 143 | 144 | The solution seems to be to let PyTorch's autograd do the work. Propagate the 145 | two norms through WaveNet in such a way that each one is multiplied by the 146 | number of times it is actually used within WaveNet's stacks. I need to clarify 147 | the notion of "number of times used". 148 | 149 | First, let's outline how the local conditioning vectors make their way through 150 | WaveNet. All filters have size 2, and a dilation that increases in powers of 2 151 | as the layer goes up. So, for instance, layer 10 (in either block) uses h[t] 152 | and h[t-512] to produce output at h[t]. The filter just below uses h[t], 153 | h[t-256], h[t-512], h[t-768]. 
For i = t - s, we show s for the top three 154 | layers: 155 | 156 | 157 | 0 0 0 0 158 | 64 159 | 128 128 160 | 192 161 | 256 256 256 162 | 320 163 | 384 384 164 | 448 165 | 512 512 512 512 166 | 576 167 | 640 640 168 | 704 169 | 768 768 768 170 | 832 171 | 896 896 172 | 173 | So, each local conditioning vector gets used a different number of times by a 174 | particular timestep of WaveNet. On average, each one is used the same number 175 | of times as the prediction position moves forward. But, within a small batch, 176 | the particular ones that get used depend on the exact positions. 177 | 178 | These local conditioning vectors are the output of an upsampling procedure. 179 | So, each up-sampled vector represents the result of four transpose 180 | convolutions. But, what grounds do you define the notion of the number of 181 | times each quantized vector is used in a particular upsampled vector? For one 182 | thing, the upsampling (like any convolution) does a transformation on each 183 | input vector, and then does element-wise sum of the result of the transform. 184 | If we regard the determinant of the transformation as the "amount" of the 185 | vector used, then, we could also multiply the two l2-norm components by this value. 186 | 187 | But, this is a bit odd, because, whether WaveNet uses "more" of one vector or another 188 | in a particular upsampled conditioning vector, may not be the right interpretation. 189 | If the conditioning elements were simple scalars, then it might be sensible to 190 | use the scalar filter elememts as the "times". Since they are transformations, 191 | we need a determinant. 192 | 193 | What would happen if we instead did the upsampling with the original vectors, and with the 194 | quantized vectors, and recorded the l2 norm between those two quantities? Would this 195 | be the same as taking an L2 Norm of the representative vectors and then taking some 196 | combination of them? 197 | 198 | Let: 199 | 200 | {l0, l5, l10, l15, l20} be the computed local condition vectors 201 | {q0, q5, q10, q15, q20} be the nearest neighbor quantized vectors 202 | {F0, F1, ..., F25} be the filter matrices 203 | 204 | The transpose convolution will be: 205 | 206 | Uq = F0 q0 + F5 q5 + F10 q10 + F15 q15 + F20 q20 207 | Ul = F0 l0 + F5 l5 + F10 l10 + F15 l15 + F20 l20 208 | 209 | Uq - Ul = Conv(F, (q - l)) 210 | 211 | We want to know the overall L2 norm of the difference between the q and l 212 | vectors. Note that: 213 | 214 | Uq = Conv(F, q) 215 | Ul = Conv(F, l) 216 | Uq - Ul = Conv(F, (q - l)) 217 | ||Uq - Ul|| = ||Conv(F, (q - l))|| 218 | 219 | 220 | ||Fg||^2 = |F| ||g||^2 221 | 222 | ||sum_i(F_i g_i)||^2 = 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | -------------------------------------------------------------------------------- /doc/commitment_loss_and_batching.txt: -------------------------------------------------------------------------------- 1 | Commitment Loss and batching in the Autoregressive dimension 2 | 3 | The Commitment loss in the model is: 4 | 5 | || z_{e}(x) - e_{q(x)} || ^2 6 | 7 | (which comes in two forms due to technical reasons for gradient calculation: 8 | 9 | CL1 = || sg(z_{e}(x) - e_{q(x)} ||^2 10 | CL2 = || z_{e}(x) - sg(e_{q(x)} ||^2 11 | 12 | In these formulas, z will consist of a range of timesteps at stride of 320. A 13 | single "sample" in this context is a window of wav data the length of one 14 | receptive field of wavenet, with its 256-way softmax output being the 15 | prediction for the sample. 
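As a minimal sketch of these two terms (with sg(.) realized as detach(); the actual
implementation in this repo's VQ bottleneck may differ in detail):

import torch

def commitment_terms(ze, emb, gamma):
    # ze:  encoder outputs,             shape (B, D, N)
    # emb: nearest-neighbor codewords,  shape (B, D, N)
    cl1 = ((ze.detach() - emb) ** 2).mean()   # updates only the codebook
    cl2 = ((ze - emb.detach()) ** 2).mean()   # updates only the encoder
    return cl1 + gamma * cl2

Because sg(.) is just detach(), both terms flow through autograd together with the
reconstruction loss, which is where the batching question below comes in.
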
But, we want to batch multiple samples in the time 16 | dimension, in order to take advantage of their shared activations. 17 | 18 | But, due to the structure, there will exist 320 consecutive samples that all 19 | use the same collection of z vectors, and the next set of 320 will drop the 20 | first and add one at the end. If the batch size is greater than 320, then this 21 | pattern needs to be taken into account in order to preserve what I would call 22 | the "stochastic batch-size invariance" property: 23 | 24 | Stochastic Batch-size invariance 25 | 26 | In SGD, the average gradient calculated for a batch size of N should be the 27 | same as the average gradient that would be calculated if you averaged the 28 | gradients from N individual samples. This is trivial for the main batch 29 | dimension across different source data, because in this case, all calculations 30 | throughout the network are independent from each other. The only thing they 31 | have in common in the dependency graph are network weights as source nodes. 32 | 33 | But, in the timestep batch dimension, consecutive samples share activations. 34 | Thus, each different sample may use a particular activation a different number 35 | of times. This isn't a problem, because the gradients are all calculated 36 | automatically. However, if the model structure doesn't have a mechanism for 37 | routing activations such as the z conditioning vectors, we need to manually 38 | account for their use or non-use in the model across different samples. 39 | 40 | Is there a simple way to integrate the commitment loss into the autograd in 41 | such a way that it can be accounted for correctly? 42 | 43 | 44 | -------------------------------------------------------------------------------- /doc/generalized_batching.txt: -------------------------------------------------------------------------------- 1 | The nice thing is that we can define the windows used for gradient calculation 2 | by first selecting a sufficient aligned window, and then sub-selecting the 3 | set of predictions we want, which can be a subset of the total number of 4 | predictions arising from the window. 5 | 6 | There are a couple design decisions to be made. 7 | 8 | 1. Truncate each wav file to the nearest maximal receptive field that is needed 9 | to produce an integer number of mels. 10 | 11 | 2. Choose some number of consecutive windows to use. (Or, we may even want to 12 | use skipping windows) 13 | 14 | 3. Back-calculate the receptive field needed for the desired timestep calculations, 15 | using rfield. 16 | 17 | 4. The mask can be used to achieve out-of-register predictions 18 | 19 | 5. The picking logic could be designed in such a way to cluster, or not. 20 | 21 | Either way, the effect of "clumpiness" of data, due to the batch size N, is 22 | one that should be explored. 23 | 24 | rfield has init_nv, but it doesn't have the analogous function to compute the 25 | number of output elements that result from a given number of input elements. 26 | 27 | But, let's review how init_nv works: 28 | 29 | init_nv moves up the chain of parents. on each node, it calls _num_in_elem, 30 | which returns the required number of input elements to achieve the requested 31 | number of output elements. It repeats this until it gets to the end, and then 32 | calls _expand_stats_. 33 | 34 | _expand_stats_ does the opposite: it moves down the chain towards the children, 35 | calling _num_out_elem for the given input elements. 
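The same two-pass idea, written out generically (illustrative only: these are not
the rfield class names, and padding is ignored):

# Each stage maps a requested output count to the input count it needs
# (upward pass, like init_nv), and an input count to the output count it
# actually produces (downward pass, like _expand_stats_).
class Geometry:
    def __init__(self, filt_sz, stride):
        self.filt_sz, self.stride = filt_sz, stride

    def num_in_elem(self, n_out):
        # minimal input for n_out outputs of a valid strided convolution
        return (n_out - 1) * self.stride + self.filt_sz

    def num_out_elem(self, n_in):
        # outputs produced from n_in inputs
        return (n_in - self.filt_sz) // self.stride + 1

def required_input(chain, n_out):
    for g in reversed(chain):          # walk up towards the source (parents)
        n_out = g.num_in_elem(n_out)
    return n_out

def actual_outputs(chain, n_in):
    for g in chain:                    # walk back down towards the children
        n_in = g.num_out_elem(n_in)
    return n_in

chain = [Geometry(400, 160), Geometry(3, 1), Geometry(2, 2)]   # made-up geometry
n_in = required_input(chain, n_out=100)        # upward pass
assert actual_outputs(chain, n_in) == 100      # downward pass recovers the request
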
36 | 37 | We can do the reverse as well: 38 | 39 | start at the parent with the 'available number', call 'expand_stats', then call 40 | init_nv with the final number 41 | 42 | Is there a better way to do that? 43 | 44 | We really only need a forward and then reverse pass. 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /doc/loss_terms.txt: -------------------------------------------------------------------------------- 1 | What is happening in the VQVAE? The encoder is presented with windows of the 2 | wav data. For each window, it produces 14 vectors (64-dimensional), spaced one 3 | every 320 timesteps. Each one is matched with its L2-nearest neighbor in a 4 | dictionary of 4096 embedding vectors. The original outputs from the encoder 5 | are named ze. The nearest neighbor embeddings are named zq (q = quantized). 6 | 7 | These zq are used as input to condition the decoder. The decoder also receives 8 | an overlapping window, similar but not identical to the wav data window input 9 | to the encoder. Together with the conditioning zq, the decoder produces 10 | prediction output in autoregressive fashion. 11 | 12 | The log probability given to the correct output is the first loss term. Note 13 | that, because this log probability is derived from logsoftmax, its gradient 14 | will affect all of the incoming logits. 15 | 16 | The gradient from the logsoftmax flows backward through the decoder, but when 17 | it reaches the conditioning inputs zq, the gradient is applied in a 18 | pass-through scheme to the ze corresponding to them. In this way, the encoder 19 | gets an approximate signal as to how it should improve its outputs. 20 | 21 | The pass-through gradient and quantization has two effects. First is that the 22 | decoder sees the same quantized vector zq for multiple timesteps, even though 23 | it may supply a gradient consistently in a particular direction. Meanwhile, 24 | the gradient is passed on to the corresponding ze, which then drifts in the 25 | gradient direction. This drift may bring it closer to the zq or farther away. 26 | But, ultimately, it will end up drifting into another Voronoi cell, and the 27 | decoder will then receive a different zq as input. Theoretically, this process 28 | will find the optimal zq for the decoder for that context. 29 | 30 | The structure of the embedding vectors themselves may be suboptimal as well. 31 | Apart from linear independence, though, it's hard to imagine what properties 32 | are desirable for the distribution of an embedding space. And, whether each 33 | embedding vector is moving so much that its "role" changes from one usage to 34 | the next. 35 | 36 | However, the ability of the model to move the embedding vectors gives it 37 | freedom during the learning process as well. Perhaps the decoder has 38 | difficulty with a particular vector, and needs it to move a bit. This makes 39 | some sense. 40 | 41 | 42 | To effect this, there are loss terms two and three. Loss term two is L2 error, 43 | which is the average L2 distance between the zq and ze. This loss trains the 44 | embedding vectors themselves specifically, and does NOT train the encoder 45 | output ze. 46 | 47 | The third term is the same loss (L2 squared distance) as the second, except it 48 | is scaled by gamma. And, it does NOT train the embedding vectors zq. It does 49 | train the encoder output ze however. 50 | 51 | They call this the "commitment loss". 
They say the commitment loss is 52 | "introduced to encourage the encoder to produce vectors which lie close to 53 | prototypes. Without the commitment loss, VQ-VAE training can diverge by 54 | emitting representations with unbounded magnitude." 55 | 56 | 57 | So, at any given training step, the following things happen: 58 | 59 | Forward: 60 | 61 | 1. encoder consumes wav snippets and outputs ze 2. the nearest neighbor zq are 62 | found from the embedding table 3. the zq, plus the wav snippet are fed to the 63 | decoder 4. the decoder outputs the logsoftmax probabilities for the next 64 | timestep 65 | 66 | Loss terms: 67 | 68 | 5. the actual next timestep value is compared with the logsoftmax corresponding 69 | to that value. this is the reconstruction loss. 70 | 71 | 6. the squared L2 distance between the zq and ze is used as a loss and 72 | propagated to the embedding vectors zq. this is the L2 error loss. 73 | 74 | 7. the gamma-scaled squared L2 distance between the zq and ze is used as a loss 75 | and propagated to the encoder outputs ze. This is the commitment loss. 76 | 77 | Backward: 78 | 79 | 8. the reconstruction loss produces gradients through the parameters of the 80 | decoder. when the gradients reach the zq inputs to the decoder, they are 81 | propagated directly to the representative ze's. These gradients are modest, 82 | since the reconstruction loss value is in the range of 5-10. 83 | 84 | 9. the L2 error loss value is very large at first, around 50000. Thus it 85 | produces very large gradients, passing them on to the embedding table vectors. 86 | This produces pressure to move the zq towards their respective ze. 87 | 88 | 10. the commitment loss value is also very large, differing only by the gamma 89 | factor. These very large gradients are passed to the original encoder outputs 90 | ze. This produces pressure to move them to be close to their representatives 91 | zq. 92 | 93 | The encoder parameters receive the combined gradients from both the commitment 94 | loss and the reconstruction loss. These may oppose each other. For instance, 95 | suppose the reconstruction loss gradient wants to push a zq to the left, and 96 | the representative ze is to left of the zq. The ze will receive the 97 | pass-through gradient, which pushes it to the left. But the commitment loss 98 | will push the ze towards the zq (i.e. to the right). 99 | 100 | In short, the ze gets pushed by reconstruction (via pass-through) and 101 | commitment loss. 102 | 103 | The zq gets pushed only by the l2 error, towards its representative ze. 104 | 105 | Should we expect the set of 14 embedding vectors to be different? 106 | 107 | Collapse 108 | 109 | During training, the ze vectors output by the encoder have a very high range in 110 | their component values, as given by ze.min(), ze.max(). Meanwhile, emb is 111 | initialized with a xavier_uniform, which is a uniform within a range of xmin, 112 | xmax, which also takes a multiplier term which affects the range. I've tried 113 | multiplier terms of 1, 10, 100, 1000. With a gain=10, the min/max values are 114 | around +/- 0.38, while the min/max values for ze start out around +/- 25. 115 | 116 | Over time, emb min/max values slowly expand, while ze min/max rapidly shrink. 117 | For a few hundred steps, the number of distinct zq vectors mapped stay around 118 | 10-12. As the ze min/max values approach about 3x the range of emb min/max, 119 | the number of distinct zq vectors mapped starts to shrink down to one. 
120 | Ultimately, this one is the same vector at each timestep. At that point, the
121 | encoder has no expressive power, and the decoder's gradients w.r.t. it
122 | presumably shrink to zero, and thus the weights stay put, effectively treating
123 | it as a bias term.
124 | 
125 | One possible issue is that, while all of the encoder parameters receive
126 | gradients at every timestep, only at most 12 of the 4096 embedding vectors
127 | receive gradients. So, if there is a scale mismatch, where the majority of ze
128 | outputs lie well outside the region of embedding vectors, then they all rapidly
129 | shrink. Even so, it doesn't seem as though any particular embedding vector is
130 | singled out at this stage due to its being "pulled out" of the cloud.
131 | 
132 | Instead, one of the vectors gets pulled in to be about 10x shorter than the
133 | rest, and then it becomes the single representative for all of them.
134 | 
135 | I think this happens because, as the training starts out, and all of the ze are
136 | much longer on average than the emb, the ze are under a very intense gradient
137 | to become shorter. Meanwhile, only at most 12 of the 4096 emb vectors
138 | experience any gradient to become longer at each timestep. So, a majority of
139 | the 4096 vectors are exactly where they started, at a moderate average length.
140 | Some have been pulled outward, but only once or twice.
141 | 
142 | Question: if you have an N-dimensional hypercube volume [-1, 1]^n and a
143 | uniform distribution of vectors within it, what is the distribution of their L2
144 | lengths? It will be strongly peaked towards larger lengths (short vectors are
145 | vanishingly rare), because the "surface area" grows as radius^(n-1).
146 | 
147 | 
148 | The main problem is that one of the failure modes of the encoder is that it
149 | just dies out. All of the weights go to zero, and its outputs go to zero, and
150 | the one vector which happens to be closest to zero becomes the representative
151 | vector for all outputs. The rep vector receives the L2 error signal compounded
152 | 12 times each timestep, so gets strongly pulled in towards zero. The L2 error
153 | then shrinks to zero, and the commitment loss as well goes to zero. And, even
154 | if the decoder is providing a pass-through gradient to the encoder output, with
155 | 12 repeated uses, the gradients may tend to cancel each other. And, in any
156 | case, the commitment loss counters any reconstruction loss that is trying to
157 | diversify consecutive outputs.
158 | 
159 | What to do about this? There can't be any hard-and-fast rule constraining the
160 | encoder to maximize distances between its output vectors for a given input
161 | window, because in cases of silence or slow speech, repeating a vector seems
162 | like the right thing to do. However, this vulnerability leads me to question
163 | whether Jitter is really the best approach. It doesn't seem like it is.
164 | 
165 | In general, we could assume that the encoder outputs embedding vectors of a
166 | certain length range. The range itself is not really relevant, but the various
167 | length ratios and directions are. In any case, at any point in the training,
168 | we would like the overall distribution of the embedding vectors to roughly
169 | coincide with the distribution of output vectors (were we to run them on a
170 | large batch of data).
171 | 
172 | At a uniform density, the number of vectors that exist within a certain length
173 | range will depend on r^n in n dimensions.
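A quick numerical check of the question above (a throwaway sketch, not part of
the codebase):

import numpy as np

# L2 norms of vectors drawn uniformly from [-1, 1]^n.  Since E[x_i^2] = 1/3,
# the norm concentrates near sqrt(n/3) as n grows, and vectors much shorter
# than the typical length become vanishingly rare.
rng = np.random.default_rng(0)
for n in (2, 16, 64, 256):
    x = rng.uniform(-1.0, 1.0, size=(100000, n))
    norms = np.linalg.norm(x, axis=1)
    print(n, norms.mean() / np.sqrt(n / 3), norms.std() / norms.mean())

The relative spread shrinks roughly as 1/sqrt(n), which is the radius^(n-1)
effect described above.
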
174 | 175 | Another way of thinking of this is that the overall density of vectors output 176 | by the encoder should match the density of the embedding. It's not so much 177 | that the encoder outputs zero or close-to-zero vectors that is the problem. 178 | The problem is that these vectors are distributed much more densely than the 179 | embedding vectors are distributed, so they all map to the same nearest 180 | neighbor. 181 | 182 | It is this mapping to the same nearest neighbor that is the root of the 183 | problem, because that is what allows the two L2 loss terms to go to zero and 184 | the decoder gradients for the conditioning vector to go to zero. 185 | 186 | What can we say about the stability of this setup? Given a particular 187 | initialization of weights and biases for the encoder, the distribution of 188 | output embeddings is determined from the distribution of inputs. If we could 189 | characterize this, and initialize the embedding dictionary to a similar 190 | distribution, would this be stable? Let's see...well, if this were the case, 191 | then it is likely that the sets of output embeddings would find distinct 192 | representatives. Their movement towards the representatives would not perturb 193 | them too much. Also, given that the relative distribution is uniform, a random 194 | sampling from it would have a much better chance of identifying distinct 195 | vectors. (This might require a batch size > 1 though, since the 196 | window-batching exhibits correlation). But, given the overall similarity of 197 | the two distributions (i.e. relative uniformity) it doesn't seem like the L2 198 | loss terms would change the *shape* of either distribution, but rather the fine 199 | structure. 200 | 201 | But, the problem is there is no way to know what sort of shape this 202 | distribution takes on. One could make ad-hoc arguments that it has a certain 203 | range, and perhaps symmetry between different dimensions (i.e. every dimension 204 | has the same marginal statistics). And, perhaps symmetry in the marginal 205 | distribution of each dimension itself, around zero. 206 | 207 | These would be sufficient statistics-y types of information. But, they would 208 | only be known after running the naive encoder on lots of input data. It would 209 | be much more desirable to modify the loss function so that it can't fall into 210 | this degenerate state. 211 | 212 | So, let's test this out. How about let's initialize the 4096 embedding vectors 213 | with a sampling from the encoder before any training. To do this, we would need 214 | to pre-run the encoder on some data. Perhaps it would be good to use different data 215 | at the outset. (It will be revisited later, anyhow) 216 | 217 | So let's try it! 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | -------------------------------------------------------------------------------- /doc/mfcc_inverter_notes.txt: -------------------------------------------------------------------------------- 1 | the plan would be: 2 | 3 | 1. produce labeled data (x, y) pairs, where x is MFCC coefficients, and y is 4 | the wave window they are derived from. 5 | 2. feed wavenet the wave window as input, and use the MFCC coefficients as the 6 | local conditioning. 7 | 3. since the goal is to learn an NN version that inverts the function, train 8 | this to zero error. 9 | 10 | Potentially we could re-use the data.py module, and just not use the 11 | wav_dec_input. 12 | 13 | 14 | factor out of model.py the architecture of the model from the API. 
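Returning to step 1 of the plan above, a minimal sketch of building one (x, y)
pair with librosa (window and hop sizes as in hparams.py; the real preprocessing
in mfcc.py / preprocess.py may differ in detail, e.g. in how it pads):

import librosa
import numpy as np

def mfcc_with_deltas(wav, sr=16000, n_mfcc=13, win=400, hop=160):
    # 13 MFCCs plus first and second derivatives -> 39 conditioning channels
    m = librosa.feature.mfcc(y=wav, sr=sr, n_mfcc=n_mfcc,
                             n_fft=win, hop_length=hop)
    d1 = librosa.feature.delta(m, order=1)
    d2 = librosa.feature.delta(m, order=2)
    return np.concatenate([m, d1, d2], axis=0)    # shape (39, n_frames)

wav = np.random.randn(16000).astype(np.float32)   # stand-in for one wav window
x = mfcc_with_deltas(wav)                         # local conditioning input
y = wav                                           # target the inverter must reproduce
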
15 | 16 | class Metrics is a nice encapsulation. it should be its own class. 17 | 18 | 19 | model.Preprocess is just one of the modules needed by WaveNet, but it is 20 | not part of the model in the sense of having any learnable parameters. 21 | Not sure why it needs to be an nn.Module. 22 | 23 | model.AutoEncoder is the main model, with a forward method. It also provides a 24 | 'run' method, which coordinates the actual and target output. 25 | 26 | **** 27 | model.Metrics manages the coordinated construction of the data, model, 28 | checkpoint state, data_loader and data_iter. 29 | 30 | it provides a main method called 'train', which encapsulates all that is needed 31 | during the training 32 | 33 | How does Metrics interact with the model? It calls: 34 | 35 | model.post_init 36 | model.objective.update_anneal_weight 37 | model.init_codebook 38 | model.bottleneck.update_codebook 39 | model.objectivve.metrics.update 40 | model.run 41 | 42 | dataset calls: dataset.post_init(model) 43 | 44 | 45 | Metrics provides several metrics functions, including: 46 | 47 | peak_dist, avg_max, avg_prob_target 48 | 49 | These are specific to WaveNet. 50 | 51 | Some good new names for 'Metrics' 52 | 53 | Chassis 54 | 55 | Is it possible to really do this? 56 | It doesn't seem so, because there are too many idiosyncrasies of the 57 | initialization. 58 | 59 | We need the post_init function in both the model and data. One could ask: 60 | why not just for the data. In that way, we would have: 61 | 62 | partially construct data object 63 | fully construct model object, using partially constructed data object 64 | post-initialize data object using fully constructed model object 65 | 66 | However, we want to be able to save a model trained with one window batch size, 67 | and then resume it and train with another window batch size. Since the data 68 | is 69 | 70 | 71 | Perhaps the easiest would be to just accept an option, and explicitly use 72 | whatever appropriate constructor 73 | 74 | The _init_geometry method of autoencoder_model mostly initializes geometry 75 | within the decoder, and most of this is needed for the 76 | 77 | 78 | 79 | chassis.loss_fn: 80 | calls run_batch 81 | calls model.objective 82 | computes gradients 83 | sets metrics 84 | calls loss.backward() 85 | 86 | 87 | chassis.run_batch: 88 | gets batch 89 | calls model.run 90 | collects output 91 | 92 | 93 | model.run: 94 | consumes batch 95 | outputs y', y pairs 96 | 97 | -------------------------------------------------------------------------------- /doc/padding_notes.txt: -------------------------------------------------------------------------------- 1 | Notes on librosa.feature.mfcc 2 | 3 | librosa.feature.mfcc considers all positions as valid, those in which the 4 | center element of the window covers one of the input elements. In the case of 5 | an even-lengthed window, if either the left or right center element covers one 6 | of the input elements, the position is considered valid. For this reason, we 7 | see I+1 outputs for an even sized window, hop length 1 and input length I. 8 | 9 | However, we only consider window positions valid if the full window overlaps 10 | input. So, the code in mfcc.py:ProcessWav::func corrects for this by first 11 | doing left-padding of the input, so that one of the window positions will 12 | have its left edge overlap the first element of the input. And, we then trim 13 | the output of librosa.feature.mfcc of its invalid window positions. 14 | 15 | The calculations are as follows: 16 | 17 | 18 | 1. 
calculate left_wing_sz, relative to the left center element. for odd-length 19 | filters, regard the center element as both left and right center. for even-length 20 | filters, 21 | 22 | 1. L' = L if odd-length, L + 1 for even-length filters 23 | R' = R for both odd-length and even-length filters 24 | 25 | 2. pad the left input by (L' % H) 26 | 27 | 3. trim left of output by L' // H. 28 | 29 | 4. assume we have an input size such that one of the outputs aligns with the 30 | end of the input at the right side of the window 31 | 32 | 5. trim right of output by R' // H # The number of additional invalid positions 33 | that librosa will use. 34 | 35 | 36 | So, now that we have trimmed the output and padded the input, we need to indicate what 37 | the left and right offsets are. But, these are simple, because we've already basically 38 | wrapped the librosa procedure in such a way that it generates valid, maximal convolutions. 39 | 40 | 5. foff.left = left_wing_sz 41 | 6. foff.right = right_wing_sz (includes right center element if it exists) 42 | 43 | 44 | F = filter_length 45 | H = hop_length 46 | L = left wing size 47 | R = right wing size (including the right center element if it exists) 48 | I = input length 49 | I' = I + (1 - F % 2) # This is the number of positions that librosa may place 50 | # the center element. 51 | C means index of the left center positions of the window 52 | B means index of the start of the window 53 | E means index of the end of the window 54 | D means index in the original input of the start of the window 55 | Q means index in the original input of the end of the window 56 | 57 | Z = ((I' - 1)// H # The last multiple 58 | 59 | C: 0 H 2H 3H 4H ... ZH 60 | B: -L H-L 2H-L 3H-L 4H-L ... ZH-L 61 | E: R H+R 2H+R 3H+R 4H+R ... ZH+R 62 | D: -P-L -P+H-L -P+2H-L -P+3H-L -P+4H-L ... -P+ZH-L 63 | Q: -P+R -P+H+R -P+2H+R -P+3H+R -P+4H+R ... -P+ZH+R 64 | 65 | 66 | Calculation of padding 67 | 68 | (P + nH - L) % H = 0 for P "and some n" 69 | 70 | (P - L) % H = 0 71 | P % H = H - (L % H) 72 | P = H - (L % H) 73 | 74 | 75 | Calculation of left trimming. The first valid position places the left center element 76 | at index L. We thus have L more positions, and are consuming them at a hop size of H, 77 | so L // H is the number. 78 | 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /doc/rfield_notes.txt: -------------------------------------------------------------------------------- 1 | Notes on rfield.py 2 | 3 | I = input size 4 | Or = requested output size 5 | Oa = actual output size (>= Or) 6 | S = stride (or inverse stride, if upsampling) 7 | LW = left wing size (the number of elements to the left of the filter central element) 8 | RW = right wing size ( the number of elements to the right of the filter central element) 9 | LP = left padding 10 | RP = right padding 11 | 12 | spaced(N, S) := any value in the range [N*S-(S-1), N*S] 13 | This is the number of total elements in a strided arrangement, 14 | starting at a value element, and using exactly N value elements. 
15 | 16 | 17 | downsampling: LW + spaced(Or, S) + RW = LP + spaced(I, 1) + RP 18 | upsampling : LW + spaced(Or, 1) + RW = LP + spaced(I, S) + RP 19 | 20 | Illustration: 21 | 22 | 23 | if upsampling, solving for I: 24 | (I-1)*S+1 + LP + RP = Or + LW + RW 25 | IS - S + 1 + LP + RP = Or + LW + RW 26 | IS = Or + LW + RW + S - 1 -LP -RP 27 | I = ceil((S - 1 - LP - RP + Or + LW + RW) / S) 28 | Oa = spaced(I,S) + LP + RP - LW - RW 29 | 30 | if downsampling: 31 | 32 | solving for I: 33 | spaced(Or,S) + LW + RW = I + LP + RP 34 | I = spaced(Or,S) + LW + RW - LP - RP 35 | 36 | solving for Oa: 37 | (Oa-1)*S + 1 = I + LP + RP - LW - RW 38 | (Oa*S - S + 1 = I + LP + RP - LW - RW 39 | Oa*S = I + LP + RP - LW - RW + S - 1 40 | Oa = (I + LP + RP - LW - RW + S - 1) // S 41 | 42 | 43 | This last formula is integral, since: 44 | 45 | I-1 = spaced(Or,S) - 1 46 | = (Or-1)*S + 1 - 1 47 | = (Or-1)*S 48 | 49 | So, 50 | 51 | Or = (LP + RP - LW - RW + Or*S) / S 52 | and 53 | 54 | Oa = Or 55 | 56 | 57 | 58 | input_stride 59 | C **** 1 60 | B pp*-*-*-* 2 61 | A ********** 1 62 | 63 | 64 | In the above diagram, data A begins at one position to the left of data B, in 65 | terms of A's stride. Then, things get more complicated. B's stride is 2, because 66 | LayerAB is downsampling. 67 | 68 | data B begins 1/2 of a position to the left of data C, in terms of B's stride. 69 | 70 | But the intuition behind the formula was pretty simple. Each new offset builds on 71 | the previous one additively. At each new step, you are given an offset in terms of 72 | the input, which in general is both padded (self.left_pad) and dilated (self.stride_ratio) 73 | 74 | But, we don't want to confuse the stride of the output with the coordinate system that 75 | 76 | left_ind spacing pad_stride sr pos pos_formula 77 | C * * * 1 1 1/2 pos(B) + left_ind(C) * 78 | B * - * - * 1 1 3 2 pos(A) + left_ind(B) * pad_stride(A) 79 | A * * * * * * * * * * 0 2/3 1 0 0 80 | | 81 | 0 82 | 83 | 84 | Central Idea: 85 | 86 | An input tensor of elements is transformed in a series of steps, each producing 87 | a new tensor. At each step, all of the tensor elements have associated with 88 | them a physical coordinate (such as 'x'), and are regularly spaced. 89 | 90 | The difference between consecutive elements of a given tensor is called its 91 | 'spacing'. But, tensors may have two types of elements: 'value elements' and 92 | 'padding elements'. The physical distance between a pair of consecutive value 93 | elements (which, there may be intervening padding elements between this pair of 94 | value elements) is called 'value_spacing'. The physical distance between a 95 | pair of any two consecutive elements, ignoring their type, is called just 96 | 'spacing'. 97 | 98 | A transformation is characterized by its stride_ratio, which equals 99 | out.value_spacing / in.value_spacing. 100 | 101 | Here, there are a few important concepts. First is the spacing between 102 | elements in spacing_ratio. This equals output_spacing / input_spacing. Note, 103 | though that there are two kinds of spacing for each tensor: unpadded_spacing, 104 | and padded_spacing. 105 | 106 | The initial input spacing should be such that the output stride is a whole 107 | number. So, this reduces to the problem of finding the multiple for the 108 | overall stride that brings the output to a whole number. Actually, not only 109 | that, we need every intermediate number to be a whole number. 110 | 111 | So, how to do that? store each stride as a reduced ratio. 
For instance, your 112 | strides might be: 113 | 114 | 1, 2/3, 1/2, 1/4 115 | 116 | And, so, you'd need to find the least common multiple of the denominators. 117 | Actually, though, since the strides are all running products of either integers 118 | or reciprocal integers, then we have: 119 | 120 | We want to find the LCM of the denominators of all of the strides. This implies we need to 121 | accumulate them as we return. Could they be calculated in another way, on the way down? 122 | 123 | A simple rule might be, on the way down: 124 | 125 | Start assuming you are at unit stride. On the way down, an upsampling layer is no problem, 126 | because it dilates the stride in the layer below. However, a downsampling layer contracts 127 | the stride. So, you need to increase the multiple. Is there any reason why this wouldn't 128 | work? 129 | 130 | I think so. Essentially, you want your low watermark to be 1. This is the most dense 131 | layer, and it could be anywhere. So, if you simply bump up the stride every time it 132 | goes below 1, you will by definition have the densist layer at 1. 133 | 134 | So, the print function will not print on the way down. It will maintain a density 135 | 136 | 137 | There are several tasks needed. But, mainly, we need to accumulate the following data for 138 | every layer: 139 | 140 | left_pad, right_pad, stride, input_size, local_bounds 141 | 142 | left_pad and right_pad are easy 143 | And, maintain the min_stride along with that. 144 | 145 | At the end, start printing. Simple enough. 146 | 147 | 148 | 149 | all of these are available in one form or another. 150 | 151 | Another difficulty is, the recursion logic is intermixed with the measurement 152 | logic. But, for the print function we need to use the measurement logic 153 | in different ways. 154 | 155 | So, what is the overall blueprint for print? 156 | 157 | As we recurse, collect 158 | 159 | 160 | Another issue is that of keeping track of strides. Each operation imposes a stride ratio factor 161 | to the existing stride. Some intermediate calculations involve fractional 162 | 163 | Should the l_pos and r_pos fields be corrected somehow? Right now, l_pos is in the 164 | opposite direction that it should be. We can correct this by first reversing the sense 165 | of l_pos, and then by left-aligning it. 166 | 167 | 168 | The next two issues are: 169 | 170 | 1. If a transformation specifies certain combinations of padding and upsampling, and then 171 | the user requests the input size for too small of an output size, the number of input elements 172 | can be zero or negative. This shouldn't be allowed because it doesn't make any sense. 173 | 174 | So, where would it be detected that the number of input elements is non-positive? 175 | 176 | Is there any reason NOT to instead accumulate the tensor shapes as properties of rfield? 177 | One drawback is that then, there can only be one version of the stats. However, it seems 178 | to me that there isn't a very meaningful use case other than that. Just have one 179 | instance of a model, and one calculation of dimensions needed. 180 | 181 | It does make the model "sticky" though. Whereas, before, the model is basically a set of 182 | pipes that can expand and contract as necessary, with no extra structures that reflect 183 | any particular choice of dimensions, and the stats list is a separate entity that can be 184 | discarded. 
Now, the stats are integrated into the model - the important distinction 185 | is that the stats are now in one-to-one existence with the model instance. 186 | 187 | For instance, now you can't have two separate stats lists for two different input 188 | dimensions for your model. You have to choose one. 189 | 190 | But, again, this doesn't seem like it interferes with any reasonable use case. The main 191 | use case is to calculate the input needed for a particular desired output, and, second, 192 | in the case where intermediate inputs are needed, their offsets relative to the main 193 | input are easily calculated as well. 194 | 195 | 196 | Rfield.dst -> Stats 197 | Rfield.src -> Stats 198 | Rfield.parent -> Rfield 199 | Stats.src -> Rfield 200 | Stats.dst -> Rfield 201 | 202 | 203 | Now, we need a simple way to check that the input and output sizes match 204 | 205 | -------------------------------------------------------------------------------- /doc/rfield_notes2.txt: -------------------------------------------------------------------------------- 1 | I believe that by PyTorch and Tensorflow treat transpose convolutions as 2 | follows: 3 | 4 | 1. For S=inverse_stride, add S-1 zero-valued spacing elements between each 5 | pair of input elements. 6 | 2. Add the desired left and right padding elements to each end. 7 | 8 | For example: 9 | 10 | @@@#***#***#***#***#***#@@@@ 11 | 12 | What they do NOT do is something like: 13 | 14 | @@@*#***#***#***#***#***#*@@@@ 15 | 16 | In other words, the non-padded region always begins and ends with a value 17 | element '#'. 18 | 19 | This creates the following problem. Suppose I want to calculate the receptive 20 | field of a length 1 output for a transpose convolution with: 21 | 22 | left_wing_size=12 23 | right_wing_size=12 24 | left_pad = 4 25 | right_pad = 4 26 | stride = 1/5 27 | 28 | We would have: 29 | 30 | @@@@#****#****#****#@@@@ 31 | [ | ] 32 | l k r 33 | 34 | This would not accommodate an output of 1 since the total input size is one 35 | smaller than the filter size of 25. However: 36 | 37 | @@@@#****#****#****#*@@@@ 38 | @@@@#****#****#****#****#@@@@ 39 | [ | ] 40 | l k r 41 | 42 | Both of these would. And, note that the receptive fields for them would be: 43 | 44 | [0,4) 45 | [0,5) 46 | 47 | In the first case, all of the padding is used, and in the second, one 48 | additional value element is used, and padding is ignored. 49 | 50 | In computing receptive fields and influence fields, we adopt the assumption 51 | that spacing is only applied between pairs of input value elements. 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /doc/todo.txt: -------------------------------------------------------------------------------- 1 | Write Documentation and usage examples for rfield.py 2 | Implement VAE and VQVAE bottlenecks 3 | Implement inference mode 4 | 5 | -------------------------------------------------------------------------------- /doc/upsampling_notes.txt: -------------------------------------------------------------------------------- 1 | Design of the local conditioning upsampling module 2 | 3 | The paper mentions (p. 5, first paragraph), "The representation was then 4 | upsampled 320 times (to match the 16kHz audio sampling rate)..." It isn't 5 | specified in the paper how upsampling is performed. 
In the original wavenet 6 | paper (https://arxiv.org/pdf/1609.03499.pdf) mentions that the upsampling uses 7 | transposed convolutions, but does not give any detail about how many transposed 8 | convolutions, strides, filter sizes, or padding strategies. 9 | 10 | Here, I adopt a few strategies: 11 | 12 | Strategy 1: I think obvious, the product of strides needs to be 320. My 13 | default is to use strides [5, 4, 4, 4]. 14 | 15 | Strategy 2: Use filter sizes large enough to cover at least a few 16 | non-zero-padding inputs. For example, for a stride of 5, if we want to use 5 17 | non-zero inputs, the filter must be at least 21 units long. 18 | 19 | Strategy 3: To make successive convolutions use a consistent number of non-zero 20 | inputs, the length of the filter must be a multiple of the stride. Thus, using 21 | exactly 5 non-zero inputs with a stride of 5 implies a filter size of 25. 22 | 23 | Strategy 4: Only use convolution positions that span filter_sz // stride 24 | non-zero inputs. In the above example, this would be 25 // 5 = 5 non-zero 25 | inputs. 26 | 27 | Strategy 5: For symmetry, have the same number of convolutions at each of the 5 28 | phases in the stride. This is admittedly a weakly supported strategy, but to 29 | me it feels like the best way to minimize any "edge effects" 30 | 31 | In order to achieve all of these goals using PyTorch's nn.ConvTranspose1d, we 32 | need to make careful use of the 'padding' parameter, and also trim the output 33 | of invalid filter positions. 34 | 35 | Here's an experiment with the padding parameter, with an input size of 7 and a 36 | stride of 5: 37 | 38 | x = torch.randn(1, 1, 7) 39 | for pad in range(28): 40 | tconv = nn.ConvTranspose1d(1, 1, kernel_size=25, stride=5, padding=pad) 41 | result = tconv(x) 42 | print(pad, result.shape) 43 | 44 | 0 torch.Size([1, 1, 55]) 45 | 1 torch.Size([1, 1, 53]) 46 | 2 torch.Size([1, 1, 51]) 47 | 3 torch.Size([1, 1, 49]) 48 | 4 torch.Size([1, 1, 47]) 49 | 5 torch.Size([1, 1, 45]) 50 | 6 torch.Size([1, 1, 43]) 51 | 7 torch.Size([1, 1, 41]) 52 | 8 torch.Size([1, 1, 39]) 53 | 9 torch.Size([1, 1, 37]) 54 | 10 torch.Size([1, 1, 35]) 55 | 11 torch.Size([1, 1, 33]) 56 | 12 torch.Size([1, 1, 31]) 57 | 13 torch.Size([1, 1, 29]) 58 | 14 torch.Size([1, 1, 27]) 59 | 15 torch.Size([1, 1, 25]) 60 | 16 torch.Size([1, 1, 23]) 61 | 17 torch.Size([1, 1, 21]) 62 | 18 torch.Size([1, 1, 19]) 63 | 19 torch.Size([1, 1, 17]) 64 | 20 torch.Size([1, 1, 15]) 65 | 21 torch.Size([1, 1, 13]) 66 | 22 torch.Size([1, 1, 11]) 67 | 23 torch.Size([1, 1, 9]) 68 | 24 torch.Size([1, 1, 7]) 69 | 25 torch.Size([1, 1, 5]) 70 | 26 torch.Size([1, 1, 3]) 71 | 27 torch.Size([1, 1, 1]) 72 | 73 | What seems to be happening is, between each neighboring pair of input elements, 74 | PyTorch adds (stride-1) zero-valued elements. With a padding=0 argument, at 75 | each end, it adds filt_sz - 1 zero-valued elements. Each additional padding 76 | value reduces by one the padding on the ends of the input. Here is the same 77 | experiment, showing first and last filter positions as [ ], and { }, with ^ 78 | marking the central position of the filter and also the "position" of the 79 | convolution result. 
80 | 81 | 82 | position: 0 10 20 30 40 50 60 70 83 | position: | | | | | | | | 84 | spaced input: ------------------------*----*----*----*----*----*----*------------------------ 85 | (pad=0) [ ^ ] { ^ } 86 | (pad=1) [ ^ ] { ^ } 87 | (pad=2) [ ^ ] { ^ } 88 | (pad=3) [ ^ ] { ^ } 89 | (pad=4) [ ^ ] { ^ } 90 | (pad=5) [ ^ ] { ^ } 91 | (pad=6) [ ^ ] { ^ } 92 | (pad=7) [ ^ ] { ^ } 93 | (pad=8) [ ^ ] { ^ } 94 | (pad=9) [ ^ ] { ^ } 95 | (pad=10) [ ^ ] { ^ } 96 | (pad=11) [ ^ ] { ^ } 97 | (pad=12) [ ^ ] { ^ } 98 | (pad=13) [ ^ ] { ^ } 99 | (pad=14) [ ^ ] { ^ } 100 | (pad=15) [ ^ I ^ } 101 | (pad=16) [ ^ { ] ^ } 102 | (pad=17) [ ^ { ] ^ } 103 | (pad=18) [ ^ { ] ^ } 104 | (pad=19) [ ^ { ] ^ } 105 | (pad=20) [ ^ { ] ^ } 106 | (pad=21) [ { ] } 107 | (pad=22) [ { ^ ^ ] } 108 | (pad=23) [ { ^ ^ ] } 109 | (pad=24) [ { ^ ^ ] } 110 | (pad=25) [ { ^ ^ ] } 111 | (pad=26) [ { ^ ^ ] } 112 | (pad=27) I ^ I 113 | spaced input: ------------------------*----*----*----*----*----*----*------------------------ 114 | position: | | | | | | | | 115 | position: 0 10 20 30 40 50 60 70 116 | 117 | Note that only padding >= 20 have all filter positions covering exactly 5 118 | non-zero inputs (marked '*'). Also, note that the filter positions all have a 119 | "phase", defined by where the central position is in relation to the greatest 120 | upper bound non-zero element. The number of filter positions of each phase for 121 | padding >= 20 are shown here: 122 | 123 | pad phase_pattern ph0 ph1 ph2 ph3 ph4 out_sz pad_in_sz 124 | 20 340123401234012 3 3 3 3 3 15 39 125 | 21 4012340123401 3 3 2 2 3 13 37 126 | 22 01234012340 3 2 2 2 1 11 35 127 | 23 123401234 1 2 2 2 1 9 33 128 | 24 2340123 1 1 2 2 1 7 31 129 | 25 34012 1 1 1 1 1 5 29 130 | 26 401 1 1 0 0 1 3 27 131 | 27 0 1 0 0 0 0 1 25 132 | 133 | Only paddings 20 and 25 satisfy strategy 5. For these, output size is an even 134 | multiple of the stride. Note also that the starting and ending phases of the 135 | pattern is 3, 2. 136 | 137 | PyTorch adds filter_sz - 1 - padding zero-valued elements to each end of the 138 | input. In the below equations, 'padding' refers to the nn.ConvTranspose1d 139 | formal parameter called 'padding', NOT the number of elements actually added to 140 | the input on either side. 141 | 142 | Any filter position that starts more than stride - 1 positions before the first 143 | non-zero element will not cover the maximal number of non-zero elements in the 144 | input, and thus will not satisfy strategy 3. To summarize: 145 | 146 | end_padding = filter_sz - 1 - padding 147 | end_padding <= stride - 1 148 | 149 | filter_sz - 1 - padding <= stride - 1 150 | filter_sz - padding <= stride 151 | -padding <= stride - filter_sz 152 | padding >= filter_sz - stride (Criterion 1) 153 | 154 | So with a filter_sz = 25, stride = 5, padding >= 20, as we saw. 155 | 156 | Next, we want out_sz % stride == 0. To summarize: 157 | 158 | out_sz = pad_in_sz - filter_sz + 1 159 | pad_in_sz = 2 * end_padding + (num_nonzero - 1) * stride + 1 160 | 161 | pad_in_sz = 2 * (filter_sz - 1 - padding) + (num_nonzero - 1) * stride + 1 162 | out_sz = 2 * (filter_sz - 1 - padding) + (num_nonzero - 1) * stride + 1 - filter_sz + 1 163 | out_sz = filter_sz - 2 * padding + (num_nonzero - 1) * stride 164 | out_sz % stride = (filter_sz - 2 * padding) % stride 165 | out_sz % stride = (-2 * padding) % stride # filter_sz % stride == 0 166 | 167 | which implies padding % stride == 0 (Criterion 2) 168 | 169 | Finally, to maximize output, we minimize padding. 
The unique value of padding which 170 | satisfies both critera is just: 171 | 172 | padding = filter_sz - stride (Result) 173 | 174 | With this choice: 175 | 176 | end_padding = filter_sz - 1 - padding 177 | = filter_sz - 1 - (filter_sz - stride) 178 | = stride - 1 179 | 180 | as can be seen in the diagram above for padding=20. 181 | 182 | 183 | Timestep coordinates 184 | 185 | In the context of this paper, the embedding vectors output by the encoder occur 186 | one every 320 timesteps. It is a reasonable idea to assign each one to a 187 | particular timestep relative to the input wav file. Likewise, in the decoder's 188 | block of upsampling layers, we take the approach of assigning explicit 189 | timesteps to each output of the transpose convolution. 190 | 191 | Thus, adopting the specific set of strides [5, 4, 4, 4] as mentioned above, the 192 | input to the upsampling block will have one input for each 320 timesteps. The next 193 | layer will have one output for each 64 timesteps, then 16, then 4, then 1. 194 | 195 | Also, we can adopt the convention that the output timestep should correspond to 196 | the position of the central filter element. This implies an offset between the 197 | first input position and first output position, and between the last input 198 | position and last output position. In the code, these offsets are represented 199 | using the rfield.Fieldoffset class. 200 | 201 | Here is a diagram of the output region produced with padding=20, stride=5, 202 | filter_sz=25, and input_sz=7. The input elements '*' are taken to be at a 5x 203 | lower frequency than the output. This could represent a frequency from 1/320 204 | timesteps to 1/64 timesteps, or could be a later layer, from 1/5 timesteps to 205 | 1/1 timesteps. 206 | 207 | output +++++++++++++++ 208 | (pad=20) [ ^ { ] ^ } 209 | spaced input: ----*----*----*----*----*----*----*---- 210 | position: | | | | | 211 | position: 0 10 20 30 40 212 | 213 | We want to calculate the offset, in timesteps, between the coordinate of the 214 | leftmost non-zero element and the left-most output element, and the same idea 215 | on the right side. In this case, the offset is 12-4=8 on the left, and 34-26=8 216 | on the right. 217 | 218 | The offset is readily calculated as: 219 | 220 | left_wing_sz = (filter_sz - 1) // 2 # distance from left end of filter to center element 221 | right_wing_sz = (filter_sz - 1) - left_wing_sz 222 | end_padding = stride - 1 223 | 224 | left_offset = left_wing_sz - end_padding (Result 2) 225 | right_offset = right_wing_sz - end_padding 226 | 227 | 228 | Since the filter_sz is odd, the right_offset is the same. We now have a way of 229 | calculating the total size and positioning of output elements in the stack of 230 | upsampling units. This allows the decoder (WaveNet) to back-calculate how many 231 | 1/320 timestep embedding vectors are needed for conditioning on the receptive 232 | field for its own dilated convolutional stack. 233 | 234 | In wavenet.py::Upsampling, the full left and right offsets are recorded in the 235 | 'foff' member variable. 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | -------------------------------------------------------------------------------- /doc/vae_and_ar_issues.txt: -------------------------------------------------------------------------------- 1 | How to apply VAE training to an auto-regressive model? 2 | 3 | The VAE training objective involves two pieces of information from the 4 | autoencoder. 
The first is the mu and sigma vectors from the encoder - these 5 | are used to analytically compute the KL divergence from the standard normal 6 | prior p(z). The second is p(x|z), the posterior from the decoder (At least, I 7 | think this is the proper terminology). 8 | 9 | This training objective implies that there is a one-to-one relationship between 10 | hidden state z and input to the model. At least, for a *single* sample, the 11 | objective requires exactly one z and one x (and the single (mu, sigma) pair of 12 | vectors that gave rise to the z. It also seems to imply that there the 13 | individual samples are independent; there is no shared information, in the form 14 | of activations or inputs, between one (z, x, mu, sigma) tuple and another, 15 | because the decoder can assign a probability p(x|z), with no other information 16 | except z and x. 17 | 18 | However, in the WaveNet autoencoder setting, a few questions arise. First, the 19 | autoencoder consumes a window of W timesteps in order to produce one hidden 20 | vector z, which, it is only sensible, it would be assigned to a particular 21 | timestep in the middle of the window of wav data. In principle, the 22 | autoencoder could be used to produce encoding vectors at every timestep, by 23 | moving the input window over one timestep at a time. That isn't what the study 24 | does, but let's come back to that. 25 | 26 | Then, the decoder also cannot assign a probability to any output whatsoever 27 | with just one z vector. Its receptive field is ~2048 timesteps, which need ~6 28 | z vectors. 29 | 30 | So, in formula 8 of "Auto-encoding Variational Bayes", what would be the 31 | appropriate choice of z or x? 32 | 33 | In principle, we could view x as some window of wav data, and z as a collection 34 | of vectors. The probability of the reconstructed x would then just be the 35 | autoregressive probability assigned by WaveNet. 36 | -------------------------------------------------------------------------------- /doc/vconv_notes.old.txt: -------------------------------------------------------------------------------- 1 | What is shadow doing? 2 | 3 | It first calculates the induced field range, then 4 | calculates the position offsets. It's probably best to redo this. 5 | 6 | Steps needed to calculate induced field: 7 | 8 | For a given input range, the induced field range is also expressed as a range on the input. 9 | It is the set of elements in the input that are "covered" by the output. Covered means 10 | that the key element of the filter is above them. 11 | 12 | It is also known as the shadow. 13 | 14 | The IFR of a single element i, which is "far" from the edges, is [i-rw, i+lw]. 15 | The IFR of the entire range [o, l) is [lw, l-1-rw]. (Note the closed interval style) 16 | 17 | The IFR of a general single element, which may be close to the edges, is: 18 | [max(lw, i-rw), min(l-1-rw, i+lw)] 19 | 20 | If this range is empty (or inverted), there is no output. 21 | 22 | So, now that we have the IFR, we need to translate it into the output 23 | element. The output is related to the input in an affine way. It is 24 | first offset by lw, then spaced by osp. For the begin element, 25 | we want to choose the 26 | 27 | 28 | 1. translate the input index in_b to the spaced index si_b 29 | 2. calculate the induced field range of the entire spaced input, if_min, if_max 30 | 3. calculate the induced field range of the begin element, [bf_min, bf_max] 31 | 4. 
take the minimum position in the intersection, which is: 32 | max(if_min, bf_min), min(if_max, bf_max) 33 | bf_min_adj 34 | if it is empty, return the empty set 35 | 5. 36 | 37 | 38 | Steps for rfield: 39 | 40 | 1. calculate the maximal shadow in index coordinates: 41 | [if_min, if_max] = [lw, (in_l-1) * isp - rw] 42 | 43 | 2. for out_b, calculate its shadow, assuming no limits: 44 | [b_si_min0, b_si_max0] = [out_b * osp, out_b * osp + rw + lw] 45 | 46 | 3. calculate the restricted range of out_b's shadow: 47 | [b_si_min, b_si_max] = [max(if_min, b_si_min0), min(if_max, b_si_max0)] 48 | 49 | 4. for out_e - 1, calculate its shadow, assuming no limits: 50 | [e_si_min0, e_si_max0] = [(out_e - 1) * osp, (out_e - 1) * osp + rw + lw] 51 | 52 | 5. calculate the restricted range 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /doc/vconv_notes.txt: -------------------------------------------------------------------------------- 1 | NOTES on vconv module 2 | 3 | The vconv module contains two functions, output_range and input_range. 4 | 5 | Note that for convolutions of stride S > 1, there are S different inputs which 6 | all produce the same output. These different inputs differ by having [0, ..., 7 | S - 1) extra elements on the end relative to a minimal input. One feature of 8 | input_range is that it reports this minimal input when given an output. 9 | 10 | For convolutions of inverse stride, where the reciprocal S > 1, there are S 11 | different outputs which all *could* be prodouced from the same input 12 | information. However, the one among those which is produced by PyTorch or 13 | TensorFlow convolutions is the maximal one. 14 | 15 | So, when computing geometries, the following workflow is used: 16 | 17 | 1. Start with the total available input, and call output_range. 18 | 19 | stride remark 20 | 1 one possible output geometry 21 | S/1 one possible output geometry; up to S-1 input elements may be 22 | unused 23 | 1/S maximal output among S possible outputs that use the same input 24 | 25 | 2. Call input_range on the resulting output geometry. 26 | 27 | stride remark 28 | 1 one possible input geometry 29 | S/1 the minimal input among the S possible inputs is reported 30 | 1/S one possible input geometry, but up to S-1 additional output 31 | elements may be missing from the maximal output for this input 32 | geometry 33 | 34 | what happens in each of these three cases when we complete a round trip: 35 | 36 | x = initial input range 37 | y = output_range(x) 38 | xp = input_range(y) 39 | 40 | Stride 1: 41 | x == input_range(output_range(x)) for all x 42 | 43 | Stride S/1: 44 | Let x = input_range(output_range(x_initial)) 45 | Then: x == input_range(output_range(x)) 46 | x will be smaller than x_initial by [0, S) elements 47 | 48 | Stride 1/S: 49 | x == input_range(output_range(x)) for all x 50 | 51 | 52 | I believe these identities should be valid not just for the full range, 53 | but for any subranges as well. So, what regression tests do we need? 54 | 55 | Inputs 1: range of (lp, rp, lw, rw) for (full, sub, gs) 56 | Inputs 2: given (lp, rp, lw, rw), range of (full, sub, gs) 57 | 58 | Test 1 (strides 1, 1/S) 59 | xn = input_range(output_range(x)) 60 | assert xn.full == x.full 61 | assert xn.sub == x.sub 62 | 63 | Test 2 (strides S/1) 64 | xn = input_range(output_range(x)) 65 | xt = input_range(output_range(xn)) 66 | assert xn.full == xt.full 67 | assert xn.sub == xt.sub 68 | 69 | 1. test 1: stride = 1, inputs 1 70 | 2. 
test 1: stride = 1, inputs 2 71 | 3. test 2: stride = S, inputs 1 72 | 4. test 2: stride = S, inputs 2 73 | 5. test 1: stride = 1/S, inputs 1 74 | 6. test 1: stride = 1/S, inputs 2 75 | 76 | 77 | 78 | PROCEDURE FOR PREPARING WINDOW SLICES 79 | # define the convolutional chains 80 | # Note: the upsampling block is considered part of the encoder for these 81 | purposes 82 | enc = (mfcc_vc, last_upsample_vc) 83 | dec = (wavenet_beg_vc, wavenet_end_vc) 84 | autoenc = (mfcc_vc, wavenet_end_vc) 85 | 86 | # define complete input and output dimensions 87 | # This can be done during Slice initialization 88 | w = 2568938 89 | full_in = ((0, w), (0, w), 1) 90 | full_out = output_range(autoenc, *full_in) 91 | 92 | # decide on some desired slice of the output 93 | s = 1028539 94 | out_req = (full_out[0], (s, s + 100), 1) 95 | 96 | # decoder required input 97 | mid_req = input_range(dec, *out_req) 98 | 99 | # encoder required input 100 | in_req = input_range(enc, *mid_req) 101 | 102 | in_act = in_req 103 | 104 | # encoder actual output 105 | mid_act = output_range(enc, *in_act) 106 | 107 | # wav -> wav_mid 108 | trim_b, trim_e = tensor_slice(in_act, mid_act) 109 | wav_mid_ten = wav_ten[trim_b:trim_e] 110 | 111 | # lcond -> lcond_trim 112 | trim_b, trim_e = tensor_slice(mid_act, mid_req) 113 | lcond_trim_ten = lcond_ten[trim_b:trim_e] 114 | 115 | # wav -> wav_out 116 | # +1 since it is predicting the next step 117 | trim_b, trim_e = tensor_slice(out_req, out_req) 118 | wav_out_ten = wav_ten[trim_b+1:trim_e+1] 119 | 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /grad_analysis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # Functions for analyzing the gradients during training 4 | 5 | # Approach: At a given moment during training, calculate the gradients on each 6 | # weight for N minibatches of data. Then, calculate the standard deviation on 7 | # each weight over the N minibatches. Finally, report the average standard 8 | # deviation, and perhaps quantiles, over the weights in a given layer 9 | 10 | # This will inform whether the batch size is too small and thus too noisy 11 | # for a given learning rate 12 | 13 | # We need to have a function for copying the gradients after a call to 14 | # backward, into some larger vector with an extra dimension. 15 | 16 | # Due to memory constraints, we cannot store all N sets of gradients for all 17 | # parameters, nor is it efficient to store one at a time and re-run the 18 | # forward/backward pass for each parameter set. Instead, we use an incremental 19 | # formula for the variance, from 20 | # http://datagenetics.com/blog/november22017/index.html: 21 | 22 | # mu_0 = x_0, S_0 = 0 23 | # mu_n = mu_(n-1) + (x_n - mu_(n-1)) / n 24 | # S_n = S_(n-1) + (x_n - mu_(n-1)) (x_n - mu_n) 25 | # sigma_n = sqrt(S_n / n) 26 | def mu_s_incr(x_cur, n, mu_pre, s_pre): 27 | """ 28 | Calculate current mu and s from previous values using incremental formula. 29 | All three arguments are assumed to have the same shape and are computed 30 | elementwise 31 | """ 32 | if n == 0: 33 | return x_cur, x_cur.new_zeros(x_cur.shape) 34 | 35 | assert x_cur.shape == mu_pre.shape 36 | assert x_cur.shape == s_pre.shape 37 | mu_cur = mu_pre + (x_cur - mu_pre) / n 38 | s_cur = s_pre + (x_cur - mu_pre) * (x_cur - mu_cur) 39 | return mu_cur, s_cur 40 | 41 | 42 | def quantiles(x, quantiles): 43 | """ 44 | Return the quantiles of x. 
quantiles are given in [0, 1] 45 | """ 46 | qv = [0] * len(quantiles) 47 | for i, q in enumerate(quantiles): 48 | k = 1 + round(float(q) * (x.numel() - 1)) 49 | qv[i] = x.view(-1).kthvalue(k)[0].item() 50 | return qv 51 | 52 | 53 | def grad_stats(model, update_model_closure, n_batch, report_quantiles): 54 | """ 55 | Run n_batch'es of data through the model, accumulating an incremental 56 | mean and sd of the gradients. Report the quantiles of these sigma values 57 | per parameter. 58 | model is a torch.Module 59 | update_model_closure should fetch a new batch of data, then run 60 | forward()/backward() to update the gradients 61 | """ 62 | mu = {} 63 | s = {} 64 | 65 | update_model_closure() 66 | 67 | for name, par in model.named_parameters(): 68 | if par.grad is None: 69 | continue 70 | mu[name] = None 71 | s[name] = None 72 | 73 | for b in range(n_batch): 74 | update_model_closure() 75 | for name, par in model.named_parameters(): 76 | if par.grad is None: 77 | continue 78 | mu[name], s[name] = mu_s_incr(par.grad, b, mu[name], s[name]) 79 | 80 | quantile_values = {} 81 | for name, sval in s.items(): 82 | sig = (sval / n_batch).sqrt().cpu() 83 | quantile_values[name] = quantiles(sig, report_quantiles) 84 | 85 | return quantile_values 86 | 87 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | # All keys that appear in each entry of HPARAMS_REGISTRY must also appear in 2 | # some entry of DEFAULTS 3 | HPARAMS_REGISTRY = {} 4 | DEFAULTS = {} 5 | 6 | class Hyperparams(dict): 7 | def __getattr__(self, attr): 8 | try: 9 | return self[attr] 10 | except KeyError: 11 | raise AttributeError(f'attribute {attr} undefined') 12 | 13 | def __setattr__(self, attr, value): 14 | self[attr] = value 15 | 16 | def __getstate__(self): 17 | return self 18 | 19 | def __setstate__(self, state): 20 | self.update(state) 21 | 22 | 23 | def setup_hparams(hparam_set_names, kwargs): 24 | H = Hyperparams() 25 | if not isinstance(hparam_set_names, tuple): 26 | hparam_set_names = hparam_set_names.split(",") 27 | hparam_sets = [HPARAMS_REGISTRY[x.strip()] for x in hparam_set_names if x] + [kwargs] 28 | for k, v in DEFAULTS.items(): 29 | H.update(v) 30 | for hps in hparam_sets: 31 | for k in hps: 32 | if k not in H: 33 | raise ValueError(f"{k} not in default args") 34 | H.update(**hps) 35 | H.update(**kwargs) 36 | return H 37 | 38 | 39 | mfcc = Hyperparams( 40 | sample_rate = 16000, 41 | mfcc_win_sz = 400, 42 | mfcc_hop_sz = 160, 43 | n_mels = 80, 44 | n_mfcc = 13, 45 | n_lc_in = 39 46 | ) 47 | 48 | HPARAMS_REGISTRY["mfcc"] = mfcc 49 | DEFAULTS["mfcc"] = mfcc 50 | 51 | wavenet = Hyperparams( 52 | filter_sz = 2, 53 | n_lc_out = 128, 54 | lc_upsample_strides = [5, 4, 4, 2], 55 | lc_upsample_filt_sizes = [25, 16, 16, 16], 56 | n_res = 368, 57 | n_dil = 256, 58 | n_skp = 256, 59 | n_post = 256, 60 | n_quant = 256, 61 | n_blocks = 2, 62 | n_block_layers = 10, 63 | n_global_embed = 10, 64 | n_speakers = 40, 65 | jitter_prob = 0.0, 66 | free_nats = 9, 67 | bias = True 68 | ) 69 | 70 | 71 | HPARAMS_REGISTRY["wavenet"] = wavenet 72 | DEFAULTS["wavenet"] = wavenet 73 | 74 | mfcc_inverter = Hyperparams( 75 | global_model = 'mfcc_inverter' 76 | ) 77 | 78 | mfcc_inverter.update(wavenet) 79 | HPARAMS_REGISTRY['mfcc_inverter'] = mfcc_inverter 80 | DEFAULTS['mfcc_inverter'] = mfcc_inverter 81 | 82 | train_tpu = Hyperparams( 83 | hw = 'TPU', 84 | n_batch = 16, 85 | n_win_batch = 5000, 86 | n_epochs = 10, 87 | 
save_interval = 1000, 88 | progress_interval = 1, 89 | skip_loop_body = False, 90 | n_loader_workers = 4, 91 | log_dir = '/tmp', 92 | random_seed = 2507, 93 | learning_rate_steps = [ 0, 4e6, 6e6, 8e6 ], 94 | learning_rate_rates = [ 1e-4, 5e-5, 5e-5, 5e-5 ], 95 | ckpt_template = '%.ckpt', 96 | ckpt_file = None 97 | ) 98 | 99 | HPARAMS_REGISTRY["train"] = train_tpu 100 | DEFAULTS["train"] = train_tpu 101 | 102 | test = Hyperparams( 103 | sample_rate = 16000, 104 | output_dir = '/tmp', 105 | dec_n_replicas = 1, 106 | jit_script_path = None 107 | ) 108 | 109 | HPARAMS_REGISTRY["test"] = test 110 | DEFAULTS["test"] = test 111 | 112 | -------------------------------------------------------------------------------- /jitter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Jitter(object): 4 | """Time-jitter regularization. With probability [p, (1-2p), p], replace 5 | element i with element [i-1, i, i+1] respectively. Disallow a run of 3 6 | identical elements in the output. Let p = replacement probability, s = 7 | "stay probability" = (1-2p). 8 | 9 | To prevent three-in-a-rows, P(x_t=0|x_(t-2)=2, x_(t-1)=1) = 0 and is 10 | renormalized. Otherwise, all conditional distributions have the same 11 | shape, [p, (1-2p), p]. 12 | """ 13 | def __init__(self, replace_prob): 14 | """replace_prob gives the probability of replacing an element with one 15 | of its neighbors""" 16 | super(Jitter, self).__init__() 17 | p, s = replace_prob, (1 - 2 * replace_prob) 18 | self.cond2d = np.tile([p, s, p], 9).reshape(3, 3, 3) 19 | self.cond2d[2][1] = [0, s/(p+s), p/(p+s)] 20 | 21 | def __call__(self, win_size): 22 | """ 23 | builds an integer index array that applies the jitter to the 24 | next window of win_size elements 25 | """ 26 | index = np.ones((win_size + 1), dtype=np.int32) 27 | for t in range(2, win_size): 28 | p2 = index[t-2] 29 | p1 = index[t-1] 30 | index[t] = np.random.choice([0,1,2], 1, False, self.cond2d[p2][p1]) 31 | index[win_size] = 1 32 | index += np.arange(-1, win_size) 33 | return index[:-1] 34 | -------------------------------------------------------------------------------- /mfcc.py: -------------------------------------------------------------------------------- 1 | # Functions for extracting Mel and MFCC information from a raw Wave file 2 | 3 | # T = time step, F = frame (every ~100 timesteps or so) 4 | # Waveform: shape(T), values (limited range integers) 5 | # MFCC + d + a: shape(F, 13 * 3) (see figure 1) 6 | # Output of Layer 1 Conv: shape(F, 39, 768) 7 | 8 | # From paper: 9 | # 80 log-mel filterbank features extracted every 10ms from 25ms-long windows 10 | # 13 MFCC features 11 | 12 | # From librosa.feature.melspectrogram: 13 | # n_fft (# timesteps in FFT window) 14 | # hop_length (# timesteps between successive window positions) 15 | 16 | # From librosa.filters.mel: 17 | # n_fft ("# FFT components" (is this passed on to melspectrogram?) 
18 | # n_mels (# mel bands to generate) 19 | 20 | # From librosa.feature.mfcc): 21 | # n_mfcc (# of MFCCs to return) 22 | 23 | import numpy as np 24 | import vconv 25 | import math 26 | 27 | class ProcessWav(object): 28 | def __init__(self, sample_rate=16000, win_sz=400, hop_sz=160, n_mels=80, 29 | n_mfcc=13, name=None): 30 | self.sample_rate = sample_rate 31 | self.window_sz = win_sz 32 | self.hop_sz = hop_sz 33 | self.n_mels = n_mels 34 | self.n_mfcc = n_mfcc 35 | self.n_out = n_mfcc * 3 36 | self.vc = vconv.VirtualConv(filter_info=self.window_sz, stride=self.hop_sz, 37 | parent=None, name=name) 38 | 39 | def __call__(self, wav): 40 | import librosa 41 | # See padding_notes.txt 42 | # NOTE: This function can't be executed on GPU due to the use of 43 | # librosa.feature.mfcc 44 | # C, T: n_mels, n_timesteps 45 | # Output: C, T 46 | # This assert doesn't seem to work when we just want to process an entire wav file 47 | adj = 1 if self.window_sz % 2 == 0 else 0 48 | adj_l_wing_sz = self.vc.l_wing_sz + adj 49 | 50 | left_pad = adj_l_wing_sz % self.hop_sz 51 | trim_left = adj_l_wing_sz // self.hop_sz 52 | trim_right = self.vc.r_wing_sz // self.hop_sz 53 | 54 | # wav = wav.numpy() 55 | wav_pad = np.concatenate((np.zeros(left_pad), wav), axis=0) 56 | mfcc = librosa.feature.mfcc(y=wav_pad, sr=self.sample_rate, 57 | n_fft=self.window_sz, hop_length=self.hop_sz, 58 | n_mels=self.n_mels, n_mfcc=self.n_mfcc) 59 | 60 | def mfcc_pred_output_size(in_sz, window_sz, hop_sz): 61 | '''Reverse-engineered output size calculation derived by observing the 62 | behavior of librosa.feature.mfcc''' 63 | n_extra = 1 if window_sz % 2 == 0 else 0 64 | n_pos = in_sz + n_extra 65 | return n_pos // hop_sz + (1 if n_pos % hop_sz > 0 else 0) 66 | 67 | assert mfcc.shape[1] == mfcc_pred_output_size(wav_pad.shape[0], 68 | self.window_sz, self.hop_sz) 69 | 70 | mfcc_trim = mfcc[:,trim_left:-trim_right or None] 71 | 72 | mfcc_delta = librosa.feature.delta(mfcc_trim) 73 | mfcc_delta2 = librosa.feature.delta(mfcc_trim, order=2) 74 | mfcc_and_derivatives = np.concatenate((mfcc_trim, mfcc_delta, mfcc_delta2), axis=0) 75 | 76 | return mfcc_and_derivatives 77 | 78 | -------------------------------------------------------------------------------- /mfcc_inverter.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import vconv 4 | import parse_tools 5 | import wavenet as wn 6 | import data 7 | import mfcc 8 | 9 | class MfccInverter(nn.Module): 10 | """ 11 | WaveNet model for inverting the wave to mfcc function. 
12 | Autoregressively generates wave data using MFCC local conditioning vectors 13 | does not use global condition vectors 14 | """ 15 | def __init__(self, hps): 16 | super(MfccInverter, self).__init__() 17 | self.bn_type = 'none' 18 | self.mfcc = mfcc.ProcessWav( 19 | sample_rate=hps.sample_rate, win_sz=hps.mfcc_win_sz, 20 | hop_sz=hps.mfcc_hop_sz, n_mels=hps.n_mels, n_mfcc=hps.n_mfcc) 21 | 22 | mfcc_vc = vconv.VirtualConv(filter_info=hps.mfcc_win_sz, 23 | stride=hps.mfcc_hop_sz, parent=None, name='MFCC') 24 | 25 | self.wavenet = wn.WaveNet(hps, parent_vc=mfcc_vc) 26 | self.objective = wn.RecLoss() 27 | self._init_geometry(hps.n_win_batch) 28 | 29 | 30 | def override(self, n_win_batch=None): 31 | """ 32 | override values from checkpoints 33 | """ 34 | if n_win_batch is not None: 35 | self.window_batch_size = n_win_batch 36 | 37 | 38 | def _init_geometry(self, n_win_batch): 39 | end_gr = vconv.GridRange((0, 100000), (0, n_win_batch), 1) 40 | end_vc = self.wavenet.vc['end_grcc'] 41 | end_gr_actual = vconv.compute_inputs(end_vc, end_gr) 42 | 43 | mfcc_vc = self.wavenet.vc['beg'].parent 44 | beg_grcc_vc = self.wavenet.vc['beg_grcc'] 45 | 46 | self.enc_in_len = mfcc_vc.in_len() 47 | self.enc_in_mel_len = self.embed_len = mfcc_vc.child.in_len() 48 | self.dec_in_len = beg_grcc_vc.in_len() 49 | 50 | di = beg_grcc_vc.input_gr 51 | wi = mfcc_vc.input_gr 52 | 53 | self.trim_dec_in = torch.tensor( 54 | [di.sub[0] - wi.sub[0], di.sub[1] - wi.sub[0] ], 55 | dtype=torch.long) 56 | 57 | # subrange on the wav input which corresponds to the output 58 | self.trim_dec_out = torch.tensor( 59 | [end_gr.sub[0] - wi.sub[0], end_gr.sub[1] - wi.sub[0]], 60 | dtype=torch.long) 61 | 62 | self.wavenet.trim_ups_out = torch.tensor([0, beg_grcc_vc.in_len()], 63 | dtype=torch.long) 64 | 65 | self.wavenet.post_init(n_win_batch) 66 | 67 | def get_input_size(self, output_size): 68 | return self.wavenet.get_input_size(output_size) 69 | 70 | def print_geometry(self): 71 | vc = self.wavenet.vc['beg'].parent 72 | while vc: 73 | print(vc) 74 | vc = vc.child 75 | 76 | print('trim_dec_in: {}'.format(self.trim_dec_in)) 77 | print('trim_dec_out: {}'.format(self.trim_dec_out)) 78 | print('trim_ups_out: {}'.format(self.wavenet.trim_ups_out)) 79 | 80 | 81 | def forward(self, wav, mel, voice, jitter): 82 | if self.training: 83 | return self.wavenet(wav, mel, voice, jitter) 84 | else: 85 | with torch.no_grad(): 86 | return self.wavenet(wav, mel, voice, jitter) 87 | 88 | 89 | def run(self, *inputs): 90 | """ 91 | """ 92 | wav, mel, voice, jitter = inputs 93 | mel.requires_grad_(True) 94 | 95 | trim = self.trim_dec_out 96 | wav_batch_out = wav[:,trim[0]:trim[1]] 97 | quant = self.forward(*inputs) 98 | 99 | pred, target = quant[...,:-1], wav_batch_out[...,1:] 100 | 101 | loss = self.objective(pred, target) 102 | ag_inputs = (mel) 103 | (mel_grad, ) = torch.autograd.grad(loss, ag_inputs, retain_graph=True) 104 | self.objective.metrics.update({ 105 | 'mel_grad_sd': mel_grad.std(), 106 | 'mel_grad_mean': mel_grad.mean() 107 | }) 108 | return pred, target, loss 109 | 110 | -------------------------------------------------------------------------------- /netmisc.py: -------------------------------------------------------------------------------- 1 | # Miscellaneous functions for the network 2 | import torch 3 | from torch import nn 4 | import vconv 5 | from sys import stderr 6 | import sys 7 | import re 8 | import collections as col 9 | 10 | def xavier_init(mod): 11 | if hasattr(mod, 'weight') and mod.weight is not None: 12 | 
nn.init.xavier_uniform_(mod.weight) 13 | if hasattr(mod, 'bias') and mod.bias is not None: 14 | nn.init.constant_(mod.bias, 0) 15 | 16 | 17 | this = sys.modules[__name__] 18 | this.print_iter = 0 19 | def set_print_iter(pos): 20 | this.print_iter = pos 21 | 22 | 23 | def print_metrics(metrics, worker_index, hdr_frequency): 24 | """ 25 | Flexibly prints a polymorphic set of metrics 26 | """ 27 | nlstrip = re.compile('\\n\s+') 28 | sep = '' 29 | h = '' 30 | s = '' 31 | d = col.OrderedDict({'w_idx': worker_index}) 32 | d.update(metrics) 33 | max_width = 12 34 | 35 | for k, v in d.items(): 36 | if isinstance(v, torch.Tensor) and v.numel() == 1: 37 | v = v.item() 38 | if isinstance(v, int): 39 | fmt = '{:d}' 40 | elif isinstance(v, float): 41 | fmt = '{:.3}' if v < 1e-2 else '{:.3f}' 42 | else: 43 | fmt = '{}' 44 | val = nlstrip.sub(' ', fmt.format(v)) 45 | if len(val) > max_width and not isinstance(v, torch.Tensor): 46 | val = '~' + val[-(max_width-1):] 47 | 48 | s += sep + val 49 | h += f'{sep}{k}' 50 | sep = '\t' 51 | 52 | if this.print_iter % hdr_frequency == 0 and worker_index == 0: 53 | print(h, file=stderr) 54 | 55 | print(s, file=stderr) 56 | this.print_iter += 1 57 | stderr.flush() 58 | 59 | -------------------------------------------------------------------------------- /par/arch.ae.json: -------------------------------------------------------------------------------- 1 | { 2 | "pre_sample_rate": 16000, 3 | "pre_mfcc_win_sz": 400, 4 | "pre_mfcc_hop_sz": 160, 5 | "pre_n_mels": 80, 6 | "pre_n_mfcc": 13, 7 | "enc_n_out": 768, 8 | "bn_type": "ae", 9 | "bn_n_out": 64, 10 | "bn_vq_gamma": 0.25, 11 | "dec_filter_sz": 2, 12 | "dec_n_lc_out": 128, 13 | "dec_lc_upsample_strides": [5, 4, 4, 4], 14 | "dec_lc_upsample_filt_sizes": [25, 16, 16, 16], 15 | "dec_n_res": 368, 16 | "dec_n_dil": 256, 17 | "dec_n_skp": 256, 18 | "dec_n_post": 256, 19 | "dec_n_quant": 256, 20 | "dec_n_blocks": 2, 21 | "dec_n_block_layers": 10, 22 | "dec_n_global_embed": 10 23 | } 24 | 25 | -------------------------------------------------------------------------------- /par/arch.basic.json: -------------------------------------------------------------------------------- 1 | { 2 | "pre_sample_rate": 16000, 3 | "pre_win_sz": 400, 4 | "pre_hop_sz": 160, 5 | "pre_n_mels": 80, 6 | "pre_n_mfcc": 13, 7 | "enc_n_out": 768, 8 | "bn_type": "vqvae", 9 | "bn_n_out": 64, 10 | "bn_vq_gamma": 0.25, 11 | "dec_filter_sz": 2, 12 | "dec_n_lc_out": 128, 13 | "dec_lc_upsample_strides": [5, 4, 4, 4], 14 | "dec_lc_upsample_filt_sizes": [25, 16, 16, 16], 15 | "dec_n_res": 368, 16 | "dec_n_dil": 256, 17 | "dec_n_skp": 256, 18 | "dec_n_post": 256, 19 | "dec_n_quant": 256, 20 | "dec_n_blocks": 2, 21 | "dec_n_block_layers": 10, 22 | "dec_n_global_embed": 10 23 | } 24 | 25 | -------------------------------------------------------------------------------- /par/arch.mi.json: -------------------------------------------------------------------------------- 1 | { 2 | "global_model": "mfcc_inverter", 3 | "pre_sample_rate": 16000, 4 | "pre_mfcc_win_sz": 400, 5 | "pre_mfcc_hop_sz": 160, 6 | "pre_n_mels": 80, 7 | "pre_n_mfcc": 13, 8 | "mi_n_lc_in": 39, 9 | "dec_filter_sz": 2, 10 | "dec_n_lc_out": 128, 11 | "dec_lc_upsample_strides": [5, 4, 4, 2], 12 | "dec_lc_upsample_filt_sizes": [25, 16, 16, 16], 13 | "dec_n_res": 368, 14 | "dec_n_dil": 256, 15 | "dec_n_skp": 256, 16 | "dec_n_post": 256, 17 | "dec_n_quant": 256, 18 | "dec_n_blocks": 2, 19 | "dec_n_block_layers": 10, 20 | "dec_n_global_embed": 10 21 | } 22 | 23 | 
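A note on these arch.*.json files: each key carries a prefix (pre_, enc_, bn_, dec_, mi_, global_) that routes it to one part of the model. parse_tools.two_stage_parse() (further below) merges the JSON over the argparse defaults, and parse_tools.get_prefixed_items() strips a prefix to collect one module's keyword arguments. The sketch below is illustrative only; the exact call sites live in the model-construction code, which is not shown in this section.

```python
# Illustrative sketch: turning par/arch.mi.json into per-module kwargs.
# get_prefixed_items() is defined in parse_tools.py below; handing the result
# directly to mfcc.ProcessWav here is an assumption made for illustration.
import json
import parse_tools
import mfcc

with open('par/arch.mi.json') as fp:
    arch = json.load(fp)

# all "pre_*" keys become preprocessing arguments
pre = parse_tools.get_prefixed_items(arch, 'pre_')
# pre == {'sample_rate': 16000, 'mfcc_win_sz': 400, 'mfcc_hop_sz': 160,
#         'n_mels': 80, 'n_mfcc': 13}
pw = mfcc.ProcessWav(sample_rate=pre['sample_rate'], win_sz=pre['mfcc_win_sz'],
                     hop_sz=pre['mfcc_hop_sz'], n_mels=pre['n_mels'],
                     n_mfcc=pre['n_mfcc'])

# likewise, "dec_*" keys collect the WaveNet decoder hyperparameters
dec = parse_tools.get_prefixed_items(arch, 'dec_')
print(sorted(dec.keys()))   # ['filter_sz', 'lc_upsample_filt_sizes', ...]
```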
-------------------------------------------------------------------------------- /par/arch.vae.json: -------------------------------------------------------------------------------- 1 | { 2 | "pre_sample_rate": 16000, 3 | "pre_mfcc_win_sz": 400, 4 | "pre_mfcc_hop_sz": 160, 5 | "pre_n_mels": 80, 6 | "pre_n_mfcc": 13, 7 | "enc_n_out": 768, 8 | "bn_type": "ae", 9 | "bn_n_out": 64, 10 | "bn_vq_gamma": 0.25, 11 | "dec_filter_sz": 2, 12 | "dec_n_lc_out": 128, 13 | "dec_lc_upsample_strides": [5, 4, 4, 4], 14 | "dec_lc_upsample_filt_sizes": [25, 16, 16, 16], 15 | "dec_n_res": 368, 16 | "dec_n_dil": 256, 17 | "dec_n_skp": 256, 18 | "dec_n_post": 256, 19 | "dec_n_quant": 256, 20 | "dec_n_blocks": 2, 21 | "dec_n_block_layers": 10, 22 | "dec_n_global_embed": 10 23 | } 24 | 25 | -------------------------------------------------------------------------------- /par/arch.vqvae-ema.json: -------------------------------------------------------------------------------- 1 | { 2 | "pre_sample_rate": 16000, 3 | "pre_win_sz": 400, 4 | "pre_hop_sz": 160, 5 | "pre_n_mels": 80, 6 | "pre_n_mfcc": 13, 7 | "enc_n_out": 768, 8 | "bn_type": "vqvae-ema", 9 | "bn_n_out": 32, 10 | "bn_vq_gamma": 0.25, 11 | "bn_vq_ema_gamma": 0.99, 12 | "dec_filter_sz": 2, 13 | "dec_n_lc_out": 128, 14 | "dec_lc_upsample_strides": [5, 4, 4, 4], 15 | "dec_lc_upsample_filt_sizes": [25, 16, 16, 16], 16 | "dec_n_res": 368, 17 | "dec_n_dil": 256, 18 | "dec_n_skp": 256, 19 | "dec_n_post": 256, 20 | "dec_n_quant": 256, 21 | "dec_n_blocks": 2, 22 | "dec_n_block_layers": 10, 23 | "dec_n_global_embed": 10 24 | } 25 | 26 | -------------------------------------------------------------------------------- /par/train.basic.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_batch": 16, 3 | "n_win_batch": 100, 4 | "jitter_prob": 0.12, 5 | "learning_rate_steps": [ 0, 4e6, 6e6, 8e6 ], 6 | "learning_rate_rates": [ 2e-5, 2e-4, 1e-4, 5e-5 ] 7 | } 8 | -------------------------------------------------------------------------------- /par/train.mi.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_batch": 16, 3 | "n_win_batch": 100, 4 | "jitter_prob": 0.0, 5 | "learning_rate_steps": [ 0, 4e6, 6e6, 8e6 ], 6 | "learning_rate_rates": [ 2e-5, 2e-4, 1e-4, 5e-5 ] 7 | } 8 | -------------------------------------------------------------------------------- /par/train.vae.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_batch": 16, 3 | "n_win_batch": 100, 4 | "jitter_prob": 0.12, 5 | "learning_rate_steps": [ 0, 4e6, 6e6, 8e6 ], 6 | "learning_rate_rates": [ 2e-5, 2e-4, 1e-4, 5e-5 ], 7 | "bn_anneal_weight_steps": [0, 10, 100, 200, 1000], 8 | "bn_anneal_weight_vals": [0, 0.1, 0.2, 0.3, 0.4] 9 | } 10 | -------------------------------------------------------------------------------- /parse_tools.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | top_usage = """ 4 | Usage: train.py {new|resume} [options] 5 | 6 | train.py new [options] 7 | -- train a new model 8 | train.py resume [options] 9 | -- resume training from .ckpt file 10 | """ 11 | 12 | test_usage = """ 13 | Usage: test.py {inverter} [options] 14 | 15 | test.py inverter [options] 16 | -- generate samples from the mfcc_inverter model 17 | """ 18 | 19 | 20 | # Training options common to both "new" and "resume" training modes 21 | def train_parser(): 22 | train = argparse.ArgumentParser(add_help=False) 23 | 24 | # 
integer arguments 25 | iargs = [ 26 | ('nb', 'n-batch', None), 27 | ('nw', 'n-win-batch', 100), 28 | ('ms', 'max-steps', 1e20), 29 | ('si', 'save-interval', 1000), 30 | ('pi', 'progress-interval', 1), 31 | ('rnd', 'random-seed', 2507), 32 | # VAE-specific Bottleneck 33 | ('fn', 'bn-free-nats', 9) 34 | ] 35 | 36 | 37 | # other arguments 38 | args = [ 39 | ('hw', 'hwtype', str, None, 'STR', 'GPU'), 40 | ('lrs', 'learning-rate-steps', int, '+', 'INT', [0, 4e6, 6e6, 8e6]), 41 | ('lrr', 'learning-rate-rates', float, '+', 'FLOAT', [4e-4, 2e-4, 1e-4, 5e-5]), 42 | ('aws', 'bn-anneal-weight-steps', int, '+', 43 | 'INT', [0, 2e3, 4e3, 6e3, 8e3, 1e4, 2e4, 3e4, 4e4, 5e4, 6e4]), 44 | ('awv', 'bn-anneal-weight-vals', float, '+', 45 | 'FLOAT', [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] 46 | ) 47 | ] 48 | 49 | # help messages 50 | hmsg = { 51 | 'nb': 'Batch size', 52 | 'nw': '# of consecutive window samples in one slice', 53 | 'ms': 'Maximum number of training steps', 54 | 'si': 'Save a checkpoint after this many steps each time', 55 | 'pi': 'Print a progress message at this interval', 56 | 'rnd': 'Random seed for weights initialization etc', 57 | 'fn': 'number of free nats in KL divergence that are not penalized', 58 | 'hw': 'Harware target, one of CPU, GPU, or TPU', 59 | 'lrs': 'Learning rate starting steps to apply --learning-rate-rates', 60 | 'lrr': 'Each of these learning rates will be applied at the ' 61 | 'corresponding value for --learning-rate-steps', 62 | 'aws': 'Learning rate starting steps to apply --anneal-weight-vals', 63 | 'awv': 'Each of these anneal weights will be applied at the ' 64 | 'corresponding step for --anneal-weight-steps' 65 | } 66 | 67 | for sopt, lopt, t, n, meta, d in args: 68 | train.add_argument('--' + lopt, '-' + sopt, type=t, nargs=n, 69 | metavar=meta, default=d, help=hmsg[sopt]) 70 | 71 | for sopt, lopt, d in iargs: 72 | train.add_argument('--' + lopt, '-' + sopt, type=int, nargs=None, 73 | metavar='INT', default=d, help=hmsg[sopt]) 74 | 75 | train.add_argument('ckpt_template', type=str, 76 | metavar='CHECKPOINT_TEMPLATE', 77 | help="Full or relative path, including a filename template, containing " 78 | "a single %%, which will be replaced by the step number.") 79 | 80 | return train 81 | 82 | # Complete parser for cold-start mode 83 | def cold_parser(): 84 | tp = train_parser() 85 | cold = argparse.ArgumentParser(parents=[tp]) 86 | 87 | cold.add_argument('--arch-file', '-af', type=str, metavar='ARCH_FILE', 88 | help='INI file specifying architectural parameters') 89 | cold.add_argument('--train-file', '-tf', type=str, metavar='TRAIN_FILE', 90 | help='INI file specifying training and other hyperparameters') 91 | 92 | # Preprocessing parameters 93 | cold.add_argument('--pre-sample-rate', '-sr', type=int, metavar='INT', default=16000, 94 | help='# samples per second in input wav files') 95 | cold.add_argument('--pre-mfcc-win-sz', '-wl', type=int, metavar='INT', default=400, 96 | help='size of the MFCC window length in timesteps') 97 | cold.add_argument('--pre-mfcc-hop-sz', '-hl', type=int, metavar='INT', default=160, 98 | help='size of the hop length for MFCC preprocessing, in timesteps') 99 | cold.add_argument('--pre-n-mels', '-nm', type=int, metavar='INT', default=80, 100 | help='number of mel frequency values to calculate') 101 | cold.add_argument('--pre-n-mfcc', '-nf', type=int, metavar='INT', default=13, 102 | help='number of mfcc values to calculate') 103 | cold.prog += ' new' 104 | 105 | # Encoder architectural parameters 106 | 
cold.add_argument('--enc-n-out', '-no', type=int, metavar='INT', default=768, 107 | help='number of output channels') 108 | 109 | cold.add_argument('--global-model', '-gm', type=str, metavar='STR', 110 | default='autoencoder', 111 | help='type of model (autoencoder or mfcc_inverter)') 112 | 113 | # Bottleneck architectural parameters 114 | cold.add_argument('--bn-type', '-bt', type=str, metavar='STR', 115 | default='none', 116 | help='bottleneck type (one of "ae", "vae", "vqvae", or "none")') 117 | cold.add_argument('--bn-n-out', '-bo', type=int, metavar='INT', default=64, 118 | help='number of output channels for the bottleneck') 119 | cold.add_argument('--bn-vq-gamma', '-vqb', type=float, metavar='FLOAT', default=0.25, 120 | help='beta multiplier for commitment loss term, Eq 3 from Chorowski et al.') 121 | cold.add_argument('--bn-vq-n-embed', '-vqn', type=int, metavar='INT', default=4096, 122 | help='number of embedding vectors, K, in section 3.1 of VQVAE paper') 123 | 124 | # Parameters exclusive to Mfcc Inverter 125 | cold.add_argument('--mi-n-lc-in', '-mli', type=int, metavar='INT', default=-1, 126 | help='decoder number of local conditioning input channels') 127 | 128 | # Decoder architectural parameters (also used for mfccInverter) 129 | cold.add_argument('--jitter-prob', '-jp', type=float, metavar='FLOAT', 130 | default=0.12, 131 | help='replacement probability for time-jitter regularization') 132 | cold.add_argument('--dec-filter-sz', '-dfs', type=int, metavar='INT', default=2, 133 | help='decoder number of dilation kernel elements') 134 | # !!! This is set equal to --bn-n-out 135 | cold.add_argument('--dec-n-lc-out', '-dlo', type=int, metavar='INT', default=-1, 136 | help='decoder number of local conditioning output channels') 137 | cold.add_argument('--dec-n-res', '-dnr', type=int, metavar='INT', default=-1, 138 | help='decoder number of residual channels') 139 | cold.add_argument('--dec-n-dil', '-dnd', type=int, metavar='INT', default=-1, 140 | help='decoder number of dilation channels') 141 | cold.add_argument('--dec-n-skp', '-dns', type=int, metavar='INT', default=-1, 142 | help='decoder number of skip channels') 143 | cold.add_argument('--dec-n-post', '-dnp', type=int, metavar='INT', default=-1, 144 | help='decoder number of post-processing channels') 145 | cold.add_argument('--dec-n-quant', '-dnq', type=int, metavar='INT', 146 | help='decoder number of input channels') 147 | cold.add_argument('--dec-n-blocks', '-dnb', type=int, metavar='INT', 148 | help='decoder number of dilation blocks') 149 | cold.add_argument('--dec-n-block-layers', '-dnl', type=int, metavar='INT', 150 | help='decoder number of power-of-two dilated ' 151 | 'convolutions in each layer') 152 | cold.add_argument('--dec-n-global-embed', '-dng', type=int, metavar='INT', 153 | help='decoder number of global embedding channels') 154 | 155 | # MFCC parameters 156 | cold.add_argument('--win-size', '-ws', type=int, metavar='INT', 157 | default=400, 158 | help='Number of timesteps used to calculate MFCC coefficients') 159 | cold.add_argument('--hop-size', '-hs', type=int, metavar='INT', 160 | default=160, 161 | help='Number of timesteps to hop between consecutive MFCC coefficients') 162 | 163 | # positional arguments 164 | cold.add_argument('dat_file', type=str, metavar='DAT_FILE', 165 | help='File created by preprocess.py') 166 | return cold 167 | 168 | # Complete parser for resuming from Checkpoint 169 | def resume_parser(): 170 | tp = train_parser() 171 | resume = argparse.ArgumentParser(parents=[tp], 
add_help=True) 172 | resume.add_argument('ckpt_file', type=str, metavar='CHECKPOINT_FILE', 173 | help="""Checkpoint file generated from a previous run. Restores model 174 | architecture, model parameters, and data generator state.""") 175 | resume.add_argument('dat_file', type=str, metavar='DAT_FILE', 176 | help='File created by preprocess.py') 177 | resume.prog += ' resume' 178 | return resume 179 | 180 | 181 | 182 | def wav_gen_parser(): 183 | wp = argparse.ArgumentParser(parents=[]) 184 | wp.add_argument('ckpt_file', type=str, metavar='CHECKPOINT_FILE', 185 | help="""Checkpoint file generated from a previous run. Restores model 186 | architecture, model parameters, and data generator state.""") 187 | wp.add_argument('dat_file', type=str, metavar='DAT_FILE', 188 | help='File created by preprocess.py') 189 | wp.add_argument('--dec-n-replicas', '-nsr', type=int, metavar='INT', 190 | default=1, 191 | help='Number of output to generate for each input datum') 192 | wp.add_argument('--output-dir', '-od', type=str, metavar='STR', 193 | default='.', 194 | help="Directory to write output .wav files") 195 | wp.add_argument('--hwtype', '-hw', type=str, metavar='STR', 196 | default='GPU', 197 | help='Hardware type (GPU, TPU-single or TPU)') 198 | wp.add_argument('--jit-script-path', '-js', type=str, metavar='STR', 199 | default=None, 200 | help='If provided, save jit script for the wavenet model here, and exit') 201 | wp.add_argument('--data-write-tmpl', '-dw', type=str, metavar='STR', 202 | default=None, 203 | help='If provided, save data batch tensors here') 204 | wp.add_argument('--n-timesteps', '-nt', type=int, metavar='INT', 205 | default=None, 206 | help='If provided, only infer for this many timesteps') 207 | 208 | return wp 209 | 210 | 211 | def two_stage_parse(cold_parser, args=None): 212 | '''wrapper for parse_args for overriding options from file''' 213 | default_opts = cold_parser.parse_args(args) 214 | 215 | cli_parser = argparse.ArgumentParser(parents=[cold_parser], add_help=False) 216 | dests = {co.dest:argparse.SUPPRESS for co in cli_parser._actions} 217 | cli_parser.set_defaults(**dests) 218 | cli_parser._defaults = {} # hack to overcome bug in set_defaults 219 | cli_opts = cli_parser.parse_args(args) 220 | 221 | # Each option follows the rule: 222 | # Use JSON file setting if present. 
Otherwise, use command-line argument, 223 | # Otherwise, use command-line default 224 | import json 225 | try: 226 | with open(cli_opts.arch_file) as fp: 227 | arch_opts = json.load(fp) 228 | except AttributeError: 229 | arch_opts = {} 230 | except FileNotFoundError: 231 | print("Error: Couldn't open arch parameters file {}".format(cli_opts.arch_file)) 232 | exit(1) 233 | 234 | try: 235 | with open(cli_opts.train_file) as fp: 236 | train_opts = json.load(fp) 237 | except AttributeError: 238 | train_opts = {} 239 | except FileNotFoundError: 240 | print("Error: Couldn't open train parameters file {}".format(cli_opts.train_file)) 241 | exit(1) 242 | 243 | # Override with command-line settings, then defaults 244 | merged_opts = vars(default_opts) 245 | merged_opts.update(arch_opts) 246 | merged_opts.update(train_opts) 247 | merged_opts.update(vars(cli_opts)) 248 | 249 | # Convert back to a Namespace object 250 | return argparse.Namespace(**merged_opts) 251 | # return cli_opts 252 | 253 | 254 | def get_prefixed_items(d, pfx): 255 | '''select all items whose keys start with pfx, and strip that prefix''' 256 | return { k[len(pfx):]:v for k,v in d.items() if k.startswith(pfx) } 257 | 258 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | from sys import stderr 2 | import argparse 3 | import data 4 | 5 | def make_parser(): 6 | p = argparse.ArgumentParser() 7 | p.add_argument('--n-quant', '-nq', type=int, metavar='INT', 8 | default=256, help='Number of quantization levels for Mu-law companding') 9 | p.add_argument('--sample-rate', '-sr', type=int, metavar='INT', 10 | default=16000, help='Number of samples per second for parsing sound files') 11 | 12 | # positional arguments 13 | p.add_argument('sam_file', type=str, metavar='SAMPLES_FILE', 14 | help='File containing lines:\n' 15 | + '\t/path/to/sample1.flac\n' 16 | + '\t/path/to/sample2.flac\n') 17 | p.add_argument('dat_file', type=str, metavar='OUTPUT_DAT_FILENAME', 18 | help='Name for output file to produce') 19 | return p 20 | 21 | def main(): 22 | parser = make_parser() 23 | opts = parser.parse_args() 24 | 25 | print('Starting...', file=stderr) 26 | stderr.flush() 27 | 28 | catalog = data.parse_catalog(opts.sam_file) 29 | data.convert(catalog, opts.dat_file, opts.n_quant, opts.sample_rate) 30 | print('Wrote catalog to {}'.format(opts.dat_file), 31 | file=stderr) 32 | return 0 33 | 34 | 35 | if __name__ == '__main__': 36 | main() 37 | 38 | 39 | -------------------------------------------------------------------------------- /results/a39d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrbigelow/ae-wavenet/80b9c46637151f053f74728fc756f1d01ab0aa69/results/a39d.png -------------------------------------------------------------------------------- /scripts/librispeech_to_rdb.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -eq 0 ] 4 | then 5 | echo 'Usage:' 6 | echo "$(basename $0) /path/to/LibriSpeech/{dev-clean,test-clean} > librispeech.rdb" 7 | exit 1 8 | fi 9 | 10 | datadir=$(readlink -m $1) 11 | for dirpath in $(find $datadir -maxdepth 1 -type d -regex '.+[0-9]+') 12 | do 13 | speaker_id=$(basename $dirpath) 14 | for filepath in $(find $dirpath -type f -name '*.flac') 15 | do 16 | echo -en "$speaker_id\t$filepath\n" 17 | done 18 | done 19 | 20 | 
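The .rdb catalog produced by this script is a two-column, tab-separated list: a speaker id and an absolute path to one flac file per line. preprocess.py (above) passes such a file to data.parse_catalog(); data.py is not included in this section, so the function below is only a hypothetical stand-in showing how a file in that format could be read.

```python
# Hypothetical stand-in for data.parse_catalog() -- the real implementation in
# data.py is not shown here and may differ.  It only illustrates the line
# format written by librispeech_to_rdb.sh: "<speaker_id>\t<path/to/file.flac>".
def parse_catalog_sketch(rdb_path):
    catalog = []
    with open(rdb_path) as fh:
        for line in fh:
            speaker_id, flac_path = line.rstrip('\n').split('\t')
            catalog.append((speaker_id, flac_path))
    return catalog

# usage:
#   catalog = parse_catalog_sketch('librispeech.dev-clean.rdb')
#   catalog[0] -> ('<speaker_id>', '/abs/path/LibriSpeech/dev-clean/<speaker_id>/<chapter>/<utt>.flac')
```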
-------------------------------------------------------------------------------- /scripts/train_plot.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import io 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from matplotlib import cm 6 | 7 | 8 | def read_files(*filenames): 9 | lines = {} 10 | for filename in filenames: 11 | with open(filename,'r') as fh: 12 | for line in fh: 13 | fields = line.split('\t') 14 | try: 15 | step = int(fields[0]) 16 | except ValueError: 17 | continue 18 | lines[step] = line 19 | return list(map(lambda i: i[1], sorted(lines.items(), key=lambda i: i[0]))) 20 | 21 | 22 | def main(): 23 | lines = read_files(*sys.argv[1:]) 24 | buf = io.StringIO() 25 | for line in lines: 26 | buf.write(line) 27 | buf.seek(0) 28 | data = np.loadtxt(buf, delimiter='\t') 29 | cms = cm.ScalarMappable(cmap=cm.Reds) 30 | cms.set_clim(10, 18) 31 | for n in range(10, 18): 32 | l = 'layer_{}'.format(n-9) 33 | plt.plot(data[:,0], data[:,n], color=cms.to_rgba(n), label=l) 34 | plt.legend() 35 | plt.show() 36 | plt.plot(data[:,0], data[:,6]) 37 | plt.show() 38 | plt.plot(data[:,0], data[:,3]) 39 | plt.show() 40 | 41 | 42 | if __name__ == '__main__': 43 | main() 44 | -------------------------------------------------------------------------------- /scripts/viewlog.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -eq 0 ] 4 | then 5 | echo 'Usage: ' 6 | echo "$(basename $0) file.log [M|S]" 7 | exit 1 8 | fi 9 | 10 | file=$1 11 | mode=$2 12 | tab=$(echo -en '\t') 13 | grep -E "^$mode$tab" $file | column -t -s "$tab" 14 | 15 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Test trained MFCC inverter model 2 | # Initialize the geometry so that the trim_ups_out is zero 3 | 4 | 5 | import sys 6 | from sys import stderr 7 | import torch as t 8 | import fire 9 | import parse_tools 10 | import checkpoint 11 | import chassis 12 | from hparams import setup_hparams, Hyperparams 13 | 14 | 15 | def run(dat_file, hps='mfcc_inverter,mfcc,test', **kwargs): 16 | hps = setup_hparams(hps, kwargs) 17 | assert hps.hw in ('GPU', 'CPU'), 'Currently, Only GPU or CPU supported for sampling' 18 | 19 | if 'random_seed' not in hps: 20 | hps.random_seed = 2507 21 | 22 | if hps.hw == 'GPU': 23 | if not t.cuda.is_available(): 24 | raise RuntimeError('GPU requested but not available') 25 | # elif hps.hw in ('TPU', 'TPU-single'): 26 | # import torch_xla.distributed.xla_multiprocessing as xmp 27 | elif hps.hw == 'CPU': 28 | pass 29 | else: 30 | raise RuntimeError( 31 | ('Invalid device {} requested. 
' 32 | + 'Must be GPU or TPU').format(hps.hw)) 33 | 34 | print('Using {}'.format(hps.hw), file=stderr) 35 | stderr.flush() 36 | 37 | # generate requested data 38 | # n_quant = ch.state.model.wavenet.n_quant 39 | 40 | 41 | if hps.hw in ('CPU', 'GPU'): 42 | if hps.hw == 'GPU': 43 | device = t.device('cuda') 44 | hps.n_loader_workers = 0 45 | else: 46 | device = t.device('cpu') 47 | 48 | chs = chassis.InferenceChassis(device, 0, hps, dat_file) 49 | if hps.jit_script_path: 50 | # data_scr = t.jit.script(chs.state.data_loader.dataset) 51 | model_scr = t.jit.script(chs.state.model.wavenet) 52 | model_scr.save(hps.jit_script_path) 53 | model_scr.to(chs.device) 54 | # print(model_scr.code) 55 | print('saved {}'.format(hps.jit_script_path)) 56 | chs.infer(model_scr) 57 | return 58 | 59 | # chs.state.model.print_geometry() 60 | chs.infer() 61 | # elif hps.hw == 'TPU': 62 | # def _mp_fn(index, mode, hps): 63 | # m = chassis.InferenceChassis(mode, hps) 64 | # m.infer(index) 65 | # xmp.spawn(_mp_fn, args=(mode, hps), nprocs=1, start_method='fork') 66 | # elif hps.hw == 'TPU-single': 67 | # chs = chassis.InferenceChassis(mode, hps) 68 | # chs.infer() 69 | 70 | 71 | if __name__ == '__main__': 72 | print(sys.executable, ' '.join(arg for arg in sys.argv), file=stderr, 73 | flush=True) 74 | fire.Fire(run) 75 | 76 | -------------------------------------------------------------------------------- /test_data.py: -------------------------------------------------------------------------------- 1 | import data 2 | import pickle 3 | 4 | sample_rate = 16000 5 | frac_perm_use = 0.1 6 | req_wav_buf_sz = 1e7 7 | sam_file = '/home/henry/ai/data/librispeech.dev-clean.rdb' 8 | n_batch = 4 9 | input_size = 4000 10 | output_size = 2000 11 | 12 | sample_catalog = data.parse_sample_catalog(sam_file) 13 | dwav = data.WavSlices(sample_catalog, sample_rate, frac_perm_use, req_wav_buf_sz) 14 | dwav.set_geometry(n_batch, input_size, output_size) 15 | 16 | batch_gen = dwav.batch_slice_gen_fn() 17 | 18 | for i in range(1000): 19 | __, voice_inds, wav = next(batch_gen) 20 | # print(dwav, end='') 21 | 22 | dwav_state = pickle.dumps(dwav) 23 | print('Yielding: ', dwav, end='') 24 | __, voice_inds, wav = next(batch_gen) 25 | 26 | dwav_r = pickle.loads(dwav_state) 27 | print('Restored: ', dwav_r, end='') 28 | dwav_r.set_geometry(n_batch, input_size, output_size) 29 | 30 | batch_gen_r = dwav_r.batch_slice_gen_fn() 31 | print('Yielding: ', dwav_r, end='') 32 | __, voice_inds_r, wav_r = next(batch_gen_r) 33 | 34 | assert (voice_inds_r == voice_inds).all() 35 | assert (wav_r == wav).all() 36 | 37 | -------------------------------------------------------------------------------- /test_model.py: -------------------------------------------------------------------------------- 1 | import model 2 | 3 | def main(ind_pfx, batch_size): 4 | data_source = data.Slice(ind_pfx, batch_size) 5 | 6 | -------------------------------------------------------------------------------- /test_vconv.old.py: -------------------------------------------------------------------------------- 1 | from sys import exit 2 | import ast 3 | import argparse 4 | import fractions 5 | import numpy as np 6 | import util 7 | import copy 8 | 9 | def _round_up(val, step): 10 | """Round up to nearest step at phase""" 11 | return val + (-val) % step 12 | 13 | class VirtualConv(object): 14 | def __init__(self, lw, rw, osp, isp, parent=None, name=None): 15 | self.lw = lw 16 | self.rw = rw 17 | self.osp = osp 18 | self.isp = isp 19 | self.parent = parent 20 | if self.parent is not None: 
21 | self.parent.child = self 22 | self.child = None 23 | self.name = name 24 | 25 | def __repr__(self): 26 | return '({},{},{},{}) : {}'.format(self.lw, self.rw, self.osp, self.isp, self.name) 27 | 28 | def mul(self, factor): 29 | self.lw *= factor 30 | self.rw *= factor 31 | self.isp *= factor 32 | self.osp *= factor 33 | 34 | def reduce(self): 35 | pass 36 | 37 | 38 | def rfield(self, out_b, out_e, in_l): 39 | """ 40 | Returns receptive field range [in_b, in_e) 41 | from output range [out_b, out_e). in_l is the total length 42 | of input. 43 | """ 44 | if out_b == out_e: 45 | return 0, 0 46 | if_min = 0 47 | if_max = (in_l - 1) * self.isp 48 | b_si_min = max(if_min, out_b * self.osp) 49 | b_si_max = min(if_max, out_b * self.osp + self.lw + self.rw) 50 | if b_si_min > b_si_max: 51 | return 0, 0 52 | b_ii = _round_up(b_si_min, self.isp) // self.isp 53 | 54 | e_si_min = max(if_min, (out_e - 1) * self.osp) 55 | e_si_max = min(if_max, (out_e - 1) * self.osp + self.lw + self.rw) 56 | if e_si_min > e_si_max: 57 | return 0, 0 58 | e_ii = e_si_max // self.isp + 1 59 | return b_ii, e_ii 60 | 61 | def ifield(self, in_b, in_e, in_l): 62 | """ 63 | Returns induced field range [out_b, out_e) 64 | from input range [in_b, in_e) 65 | """ 66 | if in_b == in_e: 67 | return 0, 0 68 | if_min = self.lw 69 | if_max = (in_l - 1) * self.isp - self.rw 70 | b_si_min = max(if_min, in_b * self.isp - self.rw) 71 | b_si_max = min(if_max, in_b * self.isp + self.lw) 72 | if b_si_min > b_si_max: 73 | return 0, 0 74 | b_oi = _round_up(b_si_min - self.lw, self.osp) // self.osp 75 | 76 | e_si_min = max(if_min, (in_e - 1) * self.isp - self.rw) 77 | e_si_max = min(if_max, (in_e - 1) * self.isp + self.lw) 78 | if e_si_min > e_si_max: 79 | return 0, 0 80 | e_oi = (e_si_max - self.lw) // self.osp + 1 81 | return b_oi, e_oi 82 | 83 | def shadow(self, in_b, in_e, in_l): 84 | """ 85 | Return the index range [shadow_in_b, shadow_in_e), which is the largest 86 | range of input that lies underneath the induced output [out_b, out_e). 87 | "underneath" here is based on the physical position induced by the 88 | structure of the filter, as defined by the left and right wing sizes. 89 | """ 90 | out_b, out_e = self.ifield(in_b, in_e, in_l) 91 | if out_b == out_e: 92 | return 0, 0 93 | b_si = self.lw + out_b * self.osp 94 | b_ii = _round_up(b_si, self.isp) // self.isp 95 | e_si = self.lw + (out_e - 1) * self.osp 96 | e_ii = e_si // self.isp + 1 97 | return b_ii, e_ii 98 | 99 | 100 | def merge_child(vc): 101 | if vc.child is None: 102 | raise RuntimeError('Cannot merge vc. 
No child node') 103 | 104 | lcm = np.lcm.reduce([vc.osp, vc.child.isp]) 105 | m1 = lcm // vc.osp 106 | m2 = lcm // vc.child.isp 107 | n1 = copy.copy(vc) 108 | n2 = copy.copy(vc.child) 109 | n1.mul(m1) 110 | n2.mul(m2) 111 | n1.lw = n1.lw + n2.lw 112 | n1.rw = n1.rw + n2.rw 113 | n1.osp = n2.osp 114 | n1.parent = None 115 | n1.child = vc.child.child 116 | return n1 117 | 118 | def merge_range(source, dest): 119 | if source is dest: 120 | return copy.copy(source) 121 | 122 | vc = source 123 | while vc.child is not dest: 124 | vc = merge_child(vc) 125 | 126 | vc = merge_child(vc) 127 | return vc 128 | 129 | 130 | class VConvNode(object): 131 | def __init__(self, index, position=None): 132 | self.index = index 133 | self.position = position 134 | self.parents = [] 135 | self.children = [] 136 | self.left = None 137 | self.right = None 138 | self.is_output = False 139 | 140 | def __repr__(self): 141 | return '{}:{}'.format(self.index, self.position) 142 | 143 | def pad(vc, input, spacing): 144 | """Add padding to either side of input, initialized with space between 145 | each element""" 146 | pad_input = [] 147 | for i in range(vc.l_pad): 148 | n = VConvNode(-1) 149 | n.position = input[0].position - spacing 150 | pad_input.insert(0, n) 151 | pad_input.extend(input) 152 | for i in range(vc.r_pad): 153 | n = VConvNode(-1) 154 | n.position = input[-1].position + spacing 155 | pad_input.append(n) 156 | return pad_input 157 | 158 | def space(input, n_nodes): 159 | """Add n_nodes spacing elements between each element of input. 160 | Preserve the original input spacing""" 161 | if len(input) < 2: 162 | return 163 | spaced_input = [] 164 | old_space = input[1].position - input[0].position 165 | new_space = old_space / (n_nodes + 1) 166 | for i, n in enumerate(input): 167 | spaced_input.append(n) 168 | if i < len(input) - 1: 169 | for p in range(n_nodes): 170 | ne = VConvNode(-1) 171 | ne.position = spaced_input[-1].position + new_space 172 | spaced_input.append(ne) 173 | return spaced_input 174 | 175 | def init_neighbors(nodes): 176 | pn = None 177 | for n in nodes: 178 | n.left = pn 179 | pn = n 180 | pn = None 181 | for n in reversed(nodes): 182 | n.right = pn 183 | pn = n 184 | 185 | def build_graph(n_input, source, dest): 186 | unit = fractions.Fraction(1, 1) 187 | input = list(map(lambda i: VConvNode(i, unit * i), range(n_input))) 188 | in_layer = input 189 | vc = source 190 | 191 | while True: 192 | #input = pad(vc, input, spacing) 193 | if vc.isp > 1: 194 | input = space(input, vc.isp - 1) 195 | if in_layer is None: 196 | in_layer = input 197 | 198 | init_neighbors(input) 199 | 200 | step = vc.osp 201 | w = vc.lw + vc.rw 202 | output = [] 203 | # locate indices of first and last value elements 204 | n_input = len(input) 205 | for i, n in enumerate(input): 206 | if n.index != -1: 207 | first_val_index = i 208 | break 209 | 210 | for i in reversed(range(n_input)): 211 | n = input[i] 212 | if n.index != -1: 213 | last_val_index = i 214 | break 215 | 216 | for oi, ii in enumerate(range(0, n_input - w, step)): 217 | result = VConvNode(oi) 218 | result.parents = input[ii:ii + w + 1] 219 | result.position = result.parents[vc.lw].position 220 | # add in pseudo-parents if all parents are fill-values 221 | #if i < first_val_index: 222 | # if all(map(lambda p: p.index == -1, result.parents)): 223 | # result.parents.append(input[first_val_index]) 224 | #if i > last_val_index: 225 | # if all(map(lambda p: p.index == -1, result.parents)): 226 | # result.parents.append(input[last_val_index]) 227 | for p in 
result.parents: 228 | p.children.append(result) 229 | output.append(result) 230 | 231 | input = output 232 | if vc is dest: 233 | out_layer = output 234 | init_neighbors(out_layer) 235 | for n in out_layer: 236 | n.is_output = True 237 | break 238 | vc = vc.child 239 | return in_layer, out_layer 240 | 241 | 242 | def graph_rfield(out_layer, out_b, out_e): 243 | # Successively search lower bound receptive field for out_b 244 | if out_b == out_e: 245 | return 0, 0 246 | 247 | n = out_layer[out_b] 248 | assert n.index == out_b 249 | 250 | while len(n.parents) > 0: 251 | # find first non-filler parent 252 | for c in n.parents: 253 | if c.index != -1: 254 | n = c 255 | break 256 | b = n.index 257 | n = out_layer[out_e-1] 258 | while len(n.parents) > 0: 259 | # find last non-filler parent 260 | for c in reversed(n.parents): 261 | if c.index != -1: 262 | n = c 263 | break 264 | e = n.index 265 | return b, e + 1 266 | 267 | 268 | def graph_ifield(in_layer, in_b, in_e): 269 | """Search up through the graph for field of influence of the input range 270 | [in_b, in_e) 271 | """ 272 | if in_b == in_e: 273 | return 0, 0 274 | 275 | n = in_layer[in_b] 276 | while not n.is_output: 277 | # Traverse the layers upwards through first-child links 278 | # and to the right through layer indices 279 | while n is not None and len(n.children) == 0: 280 | n = n.right 281 | if n is None: 282 | return 0, 0 283 | else: 284 | n = n.children[0] 285 | b = n.index 286 | 287 | n = in_layer[in_e - 1] 288 | while not n.is_output: 289 | # Traverse the layers upwards through first-child links 290 | # and to the right through layer indices 291 | while len(n.children) == 0 and n.left is not None: 292 | n = n.left 293 | if n is None: 294 | return 0, 0 295 | else: 296 | n = n.children[-1] 297 | e = n.index 298 | return b, e + 1 299 | 300 | 301 | def graph_shadow(in_layer, out_layer, in_b, in_e): 302 | out_b, out_e = graph_ifield(in_layer, in_b, in_e) 303 | if out_b == out_e: 304 | return 0, 0 305 | # search through the in_layer until the matching position is found 306 | out_b_pos = out_layer[out_b].position 307 | out_e_pos = out_layer[out_e-1].position 308 | positions = list(map(lambda n: n.position, in_layer)) 309 | lb_b = util.greatest_lower_bound(positions, out_b_pos) 310 | for i in range(lb_b, len(in_layer)): 311 | n = in_layer[i] 312 | if n.position >= out_b_pos: 313 | shadow_b = i 314 | break 315 | lb_e = util.greatest_lower_bound(positions, out_e_pos) 316 | shadow_e = lb_e 317 | #for i in range(lb_e, len(in_layer)): 318 | # n = in_layer[i] 319 | # if n.position <= out_e_pos: 320 | # shadow_e = i 321 | # break 322 | return shadow_b, shadow_e + 1 323 | 324 | def get_parser(): 325 | p = argparse.ArgumentParser() 326 | p.add_argument('--n-input', '-n', type=int, metavar='INT', 327 | help='Number of input elements to the transformations') 328 | p.add_argument('--model-file', '-f', type=str, metavar='STR', 329 | help='File with the structure of each transformation') 330 | p.add_argument('--print-override', '-p', action='store_true', default=False, 331 | help='If given, print comparisons even if they do not differ') 332 | return p 333 | 334 | 335 | def main(): 336 | 337 | parser = get_parser() 338 | opts = parser.parse_args() 339 | 340 | n_input = opts.n_input 341 | 342 | source = None 343 | cur_vc = None 344 | with open(opts.model_file) as fh: 345 | for line in fh.readlines(): 346 | vals = ast.literal_eval(line) 347 | vals.insert(4, cur_vc) 348 | cur_vc = VirtualConv(*tuple(vals)) 349 | if source is None: 350 | source = cur_vc 
351 | dest = cur_vc 352 | 353 | vc = merge_range(source, dest) 354 | __, min_input = vc.rfield(0, 1, 100000000) 355 | if n_input < min_input: 356 | print('Given n_input {} less than minimum input {} required for any output'.format( 357 | n_input, min_input)) 358 | exit(1) 359 | 360 | print('Original range: ') 361 | cur_vc = source 362 | while True: 363 | print(cur_vc) 364 | if cur_vc is dest: 365 | break 366 | cur_vc = cur_vc.child 367 | print('') 368 | print('Merged range:\n{}'.format(vc)) 369 | print('') 370 | 371 | in_layer, out_layer = build_graph(n_input, source, dest) 372 | 373 | 374 | # Test combinations of intervals 375 | for in_b in range(0, n_input): 376 | if in_b % 100 == 0: 377 | print('ifield start range {}'.format(in_b)) 378 | for in_e in range(in_b + 1, n_input + 1): 379 | t_out = vc.ifield(in_b, in_e, n_input) 380 | a_out = graph_ifield(in_layer, in_b, in_e) 381 | #if t_out != a_out and t_out[1] != t_out[0] and a_out[1] != a_out[0]: 382 | if t_out != a_out or opts.print_override: 383 | print('ifield: in: {}, test: {}, act: {}'.format( 384 | (in_b, in_e), t_out, a_out)) 385 | 386 | __, n_output = vc.ifield(0, n_input, n_input) 387 | for out_b in range(0, n_output): 388 | if out_b % 100 == 0: 389 | print('rfield start range {}'.format(out_b)) 390 | for out_e in range(out_b, n_output + 1): 391 | t_in = vc.rfield(out_b, out_e, n_input) 392 | a_in = graph_rfield(out_layer, out_b, out_e) 393 | #if t_in != a_in and t_in[0] != t_in[1] and a_in[0] != a_in[1]: 394 | if t_in != a_in or opts.print_override: 395 | print('rfield: out: {}, test: {}, act: {}'.format( 396 | (out_b, out_e), t_in, a_in)) 397 | 398 | for in_b in range(0, n_input): 399 | if in_b % 10 == 0: 400 | print('shadow start range {}'.format(in_b)) 401 | for in_e in range(in_b, n_input + 1): 402 | t_s = vc.shadow(in_b, in_e, n_input) 403 | a_s = graph_shadow(in_layer, out_layer, in_b, in_e) 404 | #if t_s != a_s and t_s[0] != t_s[1] and a_s[0] != a_s[1]: 405 | if t_s != a_s or opts.print_override: 406 | print('shadow: in: {}, test: {}, act: {}'.format( 407 | (in_b, in_e), t_s, a_s)) 408 | 409 | print('Finished') 410 | 411 | if __name__ == '__main__': 412 | main() 413 | 414 | -------------------------------------------------------------------------------- /test_vconv.py: -------------------------------------------------------------------------------- 1 | import vconv 2 | from enum import Enum 3 | from collections import Counter 4 | from fractions import Fraction 5 | from collections import namedtuple 6 | import itertools 7 | 8 | TestInput = namedtuple('TestInput', 9 | [ 10 | 'name', 'lw', 'rw', 'lp', 'rp', 'start', 'l1', 11 | 'l2', 'l3', 'gs', 'strides', 'inv_strides', 12 | 'report_freq' 13 | ] 14 | ) 15 | 16 | 17 | t1 = TestInput( 18 | name='Many Convolutions', 19 | lw=range(0, 20), 20 | rw=range(0, 20), 21 | lp=range(0, 8), 22 | rp=range(0, 8), 23 | start=range(0, 1), 24 | l1=range(25, 26), 25 | l2=range(25, 26), 26 | l3=range(50, 51), 27 | gs=range(1, 2), 28 | strides=[1,2,3,4,5], 29 | inv_strides=[2,3,4,5], 30 | report_freq=10000 31 | ) 32 | 33 | skip = 5 34 | 35 | t2 = TestInput( 36 | name='Many inputs', 37 | lw=range(3, 4), 38 | rw=range(3, 4), 39 | lp=range(0, 1), 40 | rp=range(0, 1), 41 | start=range(0, 200, skip), 42 | l1=range(0, 200, skip), 43 | l2=range(0, 200, skip), 44 | l3=range(0, 200, skip), 45 | gs=range(1, 10), 46 | strides=[1,2,3,4,5], 47 | inv_strides=[2,3,4,5], 48 | report_freq=10000 49 | ) 50 | 51 | class Result(Enum): 52 | NO_OUTPUT = 1 53 | NO_INPUT = 2 54 | UNEQUAL = 3 55 | SUCCESS = 4 56 | 57 | 
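# Orientation for the `model` table just below (an informal sketch; the field
# names here are guesses): each entry appears to be
#   ((left_wing, right_wing), (left_pad, right_pad), stride, is_downsample, name)
# matching the VirtualConv(...) calls in make_vcs() and input_gen(), where the
# boolean distinguishes an integer (downsampling) stride from an inverse
# (upsampling) stride.  The two GRCC blocks stack left wings of 1, 2, 4, ..., 512,
# so the decoder's autoregressive receptive field comes to roughly
#   2 * sum(2 ** i for i in range(10)) == 2046
# samples, which is the "~2048 timesteps" figure quoted in the doc/ notes.
assert 2 * sum(2 ** i for i in range(10)) == 2046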
model = [ 58 | # lw, rw, numer_stride, denom_stride, lp, rp 59 | ((199, 200), (0, 0), 160, True, "MFCC"), 60 | ((1, 1), (0, 0), 1, True, "CRR_0"), 61 | ((1, 1), (0, 0), 1, True, "CRR_1"), 62 | ((1, 2), (0, 0), 2, True, "CRR_2"), 63 | ((1, 1), (0, 0), 1, True, "CRR_3"), 64 | ((1, 1), (0, 0), 1, True, "CRR_4"), 65 | ((0, 0), (0, 0), 1, True, "CRR_5"), 66 | ((0, 0), (0, 0), 1, True, "CRR_6"), 67 | ((0, 0), (0, 0), 1, True, "CRR_7"), 68 | ((0, 0), (0, 0), 1, True, "CRR_7"), 69 | ((1, 1), (0, 0), 1, True, "LC_Conv"), 70 | ((12, 12), (4, 4), 5, False, "Upsampling_0"), 71 | ((7, 8), (3, 3), 4, False, "Upsampling_1"), 72 | ((7, 8), (3, 3), 4, False, "Upsampling_2"), 73 | ((7, 8), (3, 3), 4, False, "Upsampling_3"), 74 | ((1, 0), (0, 0), 1, True, "GRCC_0,0"), 75 | ((2, 0), (0, 0), 1, True, "GRCC_0,1"), 76 | ((4, 0), (0, 0), 1, True, "GRCC_0,2"), 77 | ((8, 0), (0, 0), 1, True, "GRCC_0,3"), 78 | ((16, 0), (0, 0), 1, True, "GRCC_0,4"), 79 | ((32, 0), (0, 0), 1, True, "GRCC_0,5"), 80 | ((64, 0), (0, 0), 1, True, "GRCC_0,6"), 81 | ((128, 0), (0, 0), 1, True, "GRCC_0,7"), 82 | ((256, 0), (0, 0), 1, True, "GRCC_0,8"), 83 | ((512, 0), (0, 0), 1, True, "GRCC_0,9"), 84 | ((1, 0), (0, 0), 1, True, "GRCC_1,0"), 85 | ((2, 0), (0, 0), 1, True, "GRCC_1,1"), 86 | ((4, 0), (0, 0), 1, True, "GRCC_1,2"), 87 | ((8, 0), (0, 0), 1, True, "GRCC_1,3"), 88 | ((16, 0), (0, 0), 1, True, "GRCC_1,4"), 89 | ((32, 0), (0, 0), 1, True, "GRCC_1,5"), 90 | ((64, 0), (0, 0), 1, True, "GRCC_1,6"), 91 | ((128, 0), (0, 0), 1, True, "GRCC_1,7"), 92 | ((256, 0), (0, 0), 1, True, "GRCC_1,8"), 93 | ((512, 0), (0, 0), 1, True, "GRCC_1,9") 94 | ] 95 | 96 | 97 | def make_vcs(): 98 | vc = None 99 | vcs = {} 100 | for m in model: 101 | vc = vconv.VirtualConv(*m, parent=vc) 102 | vcs[vc.name] = vc 103 | return vcs 104 | 105 | vcs = make_vcs() 106 | 107 | 108 | def same_or_upsample_test(vc, x): 109 | try: 110 | y = vconv.output_range(vc, vc, x) 111 | except RuntimeError: 112 | return Result.NO_OUTPUT 113 | try: 114 | xn = vconv.input_range(vc, vc, y) 115 | except RuntimeError: 116 | return Result.NO_INPUT 117 | 118 | if xn != x: 119 | return Result.UNEQUAL 120 | else: 121 | return Result.SUCCESS 122 | 123 | 124 | def downsample_test(vc, x): 125 | try: 126 | y = vconv.output_range(vc, vc, x) 127 | except RuntimeError: 128 | return Result.NO_OUTPUT 129 | try: 130 | xn = vconv.input_range(vc, vc, y) 131 | except RuntimeError: 132 | return Result.NO_INPUT 133 | 134 | try: 135 | yt = vconv.output_range(vc, vc, xn) 136 | except RuntimeError: 137 | return Result.NO_OUTPUT 138 | try: 139 | xt = vconv.input_range(vc, vc, yt) 140 | except RuntimeError: 141 | return Result.NO_INPUT 142 | 143 | if xn != xt: 144 | return Result.UNEQUAL 145 | else: 146 | return Result.SUCCESS 147 | 148 | 149 | def grid_range(f_b, l1, l2, l3, gs, inv_stride): 150 | gs *= inv_stride 151 | s_b = f_b + l1 * gs 152 | s_e = s_b + l2 * gs + 1 153 | f_e = s_e + l3 * gs 154 | return vconv.GridRange((f_b, f_e), (s_b, s_e), gs) 155 | 156 | 157 | def input_gen(t): 158 | for lw, rw, lp, rp in itertools.product(t.lw, t.rw, t.lp, t.rp): 159 | for st in t.strides: 160 | try: 161 | vc = vconv.VirtualConv((lw, rw), (lp, rp), st, True, 'Conv', None) 162 | except RuntimeError: 163 | continue 164 | print('lw: {}, rw: {}, lp: {}, rp: {}, st: {}'.format(lw, rw, lp, 165 | rp, st)) 166 | for spec in itertools.product(t.start, t.l1, t.l2, t.l3, t.gs): 167 | yield vc, grid_range(*spec, 1) 168 | for ist in t.inv_strides: 169 | try: 170 | vc = vconv.VirtualConv((lw, rw), (lp, rp), ist, False, 'Conv', None) 
171 | except RuntimeError: 172 | continue 173 | print('lw: {}, rw: {}, lp: {}, rp: {}, ist: {}'.format(lw, rw, lp, 174 | rp, ist)) 175 | for spec in itertools.product(t.start, t.l1, t.l2, t.l3, t.gs): 176 | yield vc, grid_range(*spec, vc.stride_ratio.denominator) 177 | 178 | 179 | def main_test(inputs): 180 | t = inputs 181 | c = 0 182 | results = Counter() 183 | print('Test: {}'.format(t.name)) 184 | for vc, x in input_gen(t): 185 | if vc.stride_ratio.numerator > 1: 186 | res = downsample_test(vc, x) 187 | else: 188 | res = same_or_upsample_test(vc, x) 189 | results[res] += 1 190 | if c > 0 and c % t.report_freq == 0: 191 | print(results) 192 | c += 1 193 | 194 | print('Finished') 195 | print('Results: {}'.format(results)) 196 | 197 | 198 | x = vconv.GridRange((0, 250000), (0, 250000), 1) 199 | y = vconv.output_range(vcs['MFCC'], vcs['GRCC_1,9'], x) 200 | xi = vconv.input_range(vcs['MFCC'], vcs['GRCC_1,9'], y) 201 | 202 | #print('x0: {}'.format(x)) 203 | #print('y0: {}'.format(y)) 204 | #print('xi: {}'.format(xi)) 205 | 206 | 207 | def autoenc_test(vcs, in_len, slice_beg): 208 | enc = vcs['MFCC'], vcs['Upsampling_3'] 209 | dec = vcs['GRCC_0,0'], vcs['GRCC_1,9'] 210 | mfcc = vcs['MFCC'], vcs['MFCC'] 211 | autoenc = vcs['MFCC'], vcs['GRCC_1,9'] 212 | 213 | full_in = vconv.GridRange((0, in_len), (0, in_len), 1) 214 | full_mfcc = vconv.output_range(*mfcc, full_in) 215 | full_out = vconv.output_range(*autoenc, full_in) 216 | 217 | out_req = vconv.GridRange(full_out.full, (slice_beg, slice_beg + 100), 1) 218 | mid_req = vconv.input_range(*dec, out_req) 219 | in_req = vconv.input_range(*enc, mid_req) 220 | in_act = in_req 221 | mfcc_act = vconv.output_range(*mfcc, in_act) 222 | mid_act = vconv.output_range(*enc, in_act) 223 | 224 | # wav -> wav_mid 225 | wav_mid_sl = vconv.tensor_slice(in_act, mid_req.sub) 226 | # wav_mid_ten = wav_ten[wav_mid_sl] 227 | 228 | # lcond -> lcond_sl 229 | lcond_sl = vconv.tensor_slice(mid_act, mid_req.sub) 230 | # lcond_sl_ten = lcond_ten[lcond_sl] 231 | 232 | # wav -> wav_out 233 | # +1 since it is predicting the next step 234 | wav_out_sl = vconv.tensor_slice(in_act, out_req.sub) 235 | # wav_out_ten = wav_ten[sl_b+1:sl_e+1] 236 | 237 | mfcc_in_sl = vconv.tensor_slice(full_mfcc, mfcc_act.sub) 238 | 239 | print('{:10}: {}'.format('full_in', full_in)) 240 | print('{:10}: {}'.format('full_mfcc', full_mfcc)) 241 | print('{:10}: {}'.format('in_req', in_req)) 242 | print('{:10}: {}'.format('mfcc_req', mfcc_act)) 243 | print('{:10}: {}'.format('mid_req', mid_req)) 244 | print('{:10}: {}'.format('mid_act', mid_act)) 245 | print('{:10}: {}'.format('out_req', out_req)) 246 | print('{:10}: {}'.format('full_out', full_out)) 247 | 248 | print('wav_mid_sl: {} len: {}'.format(wav_mid_sl, wav_mid_sl[1] - 249 | wav_mid_sl[0])) 250 | print('mfcc_in_sl: {} len: {}'.format(mfcc_in_sl, mfcc_in_sl[1] - 251 | mfcc_in_sl[0])) 252 | print('lcond_sl: {} len: {}'.format(lcond_sl, lcond_sl[1] - lcond_sl[0])) 253 | print('wav_out_sl: {} len: {}'.format(wav_out_sl, wav_out_sl[1] - wav_out_sl[0])) 254 | 255 | 256 | encoder = vcs['MFCC'], vcs['LC_Conv'] 257 | encoder_clip = encoder[0].child, encoder[1] 258 | upsample = vcs['Upsampling_0'], vcs['Upsampling_3'] 259 | half_upsample = vcs['Upsampling_2'], vcs['Upsampling_3'] 260 | decoder = vcs['GRCC_0,0'], vcs['GRCC_1,9'] 261 | autoenc_clip = encoder[0].child, decoder[1] 262 | 263 | def phase_test(vc_range, n_sub_win, winsize): 264 | c = Counter() 265 | for b in range(n_sub_win): 266 | out = vconv.GridRange((0, 90000), (b, b + winsize), 1) 267 | 
input = vconv.input_range(*vc_range, out) 268 | c[input.sub_length()] += 1 269 | # print(mfcc.sub_length(), end=' ') 270 | print(c) 271 | 272 | 273 | #print('Phase test for autoencoder') 274 | #phase_test(autoenc_clip, 100) 275 | 276 | print('Phase test for upsample') 277 | phase_test(upsample, 20, 2146) 278 | print() 279 | 280 | print('Phase test for half upsample') 281 | phase_test(half_upsample, 20, 2146) 282 | print() 283 | 284 | print('Phase test for encoder_clip + upsample') 285 | phase_test((encoder_clip[0], upsample[1]), 6000, 2146) 286 | print() 287 | 288 | print('Phase test for decoder') 289 | phase_test(decoder, 6000, 100) 290 | print() 291 | 292 | 293 | def usage_test(vc_range, winsize): 294 | c = Counter() 295 | for b in range(winsize): 296 | out = vconv.GridRange((0, 100000), (b, b + 1), 1) 297 | input = vconv.input_range(*vc_range, out) 298 | slice = vconv.tensor_slice(input, input.sub) 299 | c[slice] += 1 300 | print(c) 301 | 302 | winsize = 10000 303 | print('Usage test for window size {}'.format(winsize)) 304 | usage_test((upsample[0], decoder[1]), winsize) 305 | 306 | # for s in range(56730, 57073, 30): 307 | # autoenc_test(vcs, 100000, s) 308 | 309 | 310 | #for t in (t2, t1): 311 | # main_test(t) 312 | 313 | 314 | 315 | #vc = mfcc_vc 316 | #while vc.child is not None: 317 | # f, s = (0, 1000), (150, 850) 318 | # forward = vconv.output_range(vc, vc, f, s, gs) 319 | # f, s, gs = forward[-1] 320 | # backward = vconv.input_range(vc, vc, f, s, gs) 321 | # print('f_in: {}, f_out: {}, {}'.format(forward[0][0], forward[1][0], vc)) 322 | # print('b_in: {}, b_out: {}, {}'.format(backward[1][0], backward[0][0], vc)) 323 | # vc = vc.child 324 | # print("") 325 | 326 | 327 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | try: 3 | import torch_xla.distributed.xla_multiprocessing as xmp 4 | import torch_xla.core.xla_model as xm 5 | except ModuleNotFoundError: 6 | pass 7 | 8 | import sys 9 | from sys import stderr 10 | from pprint import pprint 11 | import fire 12 | import torch as t 13 | 14 | import autoencoder_model as ae 15 | import chassis as ch 16 | import parse_tools 17 | import netmisc 18 | from hparams import setup_hparams, Hyperparams 19 | import time 20 | 21 | 22 | def _mp_fn(index, _hps, _dat_file): 23 | t.manual_seed(_hps.random_seed) 24 | 25 | # Acquires the (unique) Cloud TPU core corresponding to this process's index 26 | pre_dev_time = time.time() 27 | device = xm.xla_device() 28 | device_str = xm.xla_real_devices([str(device)])[0] 29 | elapsed = time.time() - pre_dev_time 30 | print(f'process {index} acquired {device_str} in {elapsed} seconds', 31 | file=stderr, flush=True) 32 | 33 | pre_inst_time = time.time() 34 | m = ch.Chassis(device, index, _hps, _dat_file) 35 | print(f'Created Chassis in {time.time() - pre_inst_time:3.5} seconds.', file=stderr, flush=True) 36 | xm.rendezvous('init') 37 | m.train() 38 | 39 | def run(dat_file, hps='mfcc_inverter,mfcc,train', **kwargs): 40 | if 'ckpt_file' in kwargs: 41 | hps = Hyperparams(kwargs) 42 | if 'random_seed' not in hps: 43 | hps.random_seed = 2507 44 | else: 45 | hps = setup_hparams(hps, kwargs) 46 | 47 | netmisc.set_print_iter(0) 48 | 49 | if hps.hw in ('GPU', 'TPU-single'): 50 | if hps.hw == 'GPU': 51 | device = t.device('cuda') 52 | hps.n_loader_workers = 0 53 | else: 54 | device = xm.xla_device() 55 | chs = ch.Chassis(device, 0, hps, dat_file) 56 | # 
chs.state.model.print_geometry() 57 | chs.train() 58 | elif hps.hw == 'TPU': 59 | print('Spawning new processes.', file=stderr, flush=True) 60 | xmp.spawn(_mp_fn, args=(hps, dat_file), nprocs=8, start_method='fork') 61 | 62 | 63 | if __name__ == '__main__': 64 | print(sys.executable, ' '.join(arg for arg in sys.argv), file=stderr, 65 | flush=True) 66 | fire.Fire(run) 67 | 68 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from hashlib import md5 2 | from pickle import dumps 3 | import numpy as np 4 | import torch 5 | from typing import Tuple 6 | 7 | def digest(obj): 8 | return md5(dumps(obj)).hexdigest() 9 | 10 | def tensor_digest(tensors): 11 | try: 12 | it = iter(tensors) 13 | except TypeError: 14 | tensors = list(tensors) 15 | 16 | vals = list(map(lambda t: t.flatten().detach().cpu().numpy().tolist(), tensors)) 17 | return digest(vals) 18 | 19 | def _validate_checkpoint_info(ckpt_dir, ckpt_file_template): 20 | # Unfortunately, Python doesn't provide a way to hold an open directory 21 | # handle, so we just check whether the directory path exists and is 22 | # writable during this call. 23 | import os 24 | if not os.access(ckpt_dir, os.R_OK|os.W_OK): 25 | raise ValueError('Cannot read and write checkpoint directory {}'.format(ckpt_dir)) 26 | # test if ckpt_file_template is valid 27 | try: 28 | test_file = ckpt_file_template.replace('%', '1000') 29 | except IndexError: 30 | test_file = '' 31 | # '1000' is 3 longer than '%' 32 | if len(test_file) != len(ckpt_file_template) + 3: 33 | raise ValueError('Checkpoint template "{}" ill-formed. ' 34 | '(should have exactly one "%")'.format(ckpt_file_template)) 35 | try: 36 | test_path = '{}/{}'.format(ckpt_dir, test_file) 37 | if not os.access(test_path, os.R_OK): 38 | fp = open(test_path, 'w') 39 | fp.close() 40 | os.remove(fp.name) 41 | except IOError: 42 | raise ValueError('Cannot create a test checkpoint file {}'.format(test_path)) 43 | 44 | 45 | class CheckpointPath(object): 46 | def __init__(self, path_template, validate=True): 47 | import os.path 48 | _dir = os.path.dirname(path_template) 49 | _base = os.path.basename(path_template) 50 | if _dir == '' or _base == '': 51 | raise ValueError('path_template "{}" does not contain both ' 52 | 'directory and file'.format(path_template)) 53 | self.dir = _dir.rstrip('/') 54 | self.file_template = _base 55 | if validate: 56 | _validate_checkpoint_info(self.dir, self.file_template) 57 | 58 | def path(self, step): 59 | return '{}/{}'.format(self.dir, self.file_template.replace('%', str(step))) 60 | 61 | 62 | def mu_encode_np(x, n_quanta): 63 | '''mu-law encode and quantize''' 64 | mu = n_quanta - 1 65 | amp = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu) 66 | quant = (amp + 1) * 0.5 * mu + 0.5 67 | return quant.astype(np.int32) 68 | 69 | 70 | def mu_decode_np(quant, n_quanta): 71 | '''accept an integer mu-law encoded quant, and convert 72 | it back to the pre-encoded value''' 73 | mu = n_quanta - 1 74 | qf = quant.astype(np.float32) 75 | inv_mu = 1.0 / mu 76 | a = (2 * qf - 1) * inv_mu - 1 77 | x = np.sign(a) * ((1 + mu)**np.fabs(a) - 1) * inv_mu 78 | return x 79 | 80 | 81 | def mu_encode_torch(x, n_quanta): 82 | '''mu-law encode and quantize''' 83 | mu = torch.tensor(float(n_quanta - 1), device=x.device) 84 | amp = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu) 85 | quant = (amp + 1) * 0.5 * mu + 0.5 86 | return quant.round_().to(dtype=torch.long) 
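# --- Editorial sketch (not part of the original util.py) ---------------------
# A minimal round-trip example for the mu-law helpers above, relying on the
# module-level `import torch`. The function name and sample values are
# hypothetical, chosen only to illustrate that mu_encode_torch maps amplitudes
# in (-1, 1) to integer codes in [0, n_quanta) and that mu_decode_torch
# recovers the amplitude up to quantization error (coarser near +/-1, finer
# near 0).
def _mu_law_roundtrip_example():
    x = torch.tensor([-0.5, -0.1, 0.0, 0.1, 0.5])
    q = mu_encode_torch(x, 256)       # dtype torch.long, values in [0, 256)
    x_hat = mu_decode_torch(q, 256)   # approximately x, within quantization error
    assert torch.allclose(x, x_hat, atol=0.05)
    return q, x_hat
# ------------------------------------------------------------------------------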
87 | 88 | def mu_decode_torch(quant, n_quanta): 89 | '''accept an integer mu-law encoded quant, and convert 90 | it back to the pre-encoded value''' 91 | mu = torch.tensor(float(n_quanta - 1), device=quant.device) 92 | qf = quant.to(dtype=torch.float32) 93 | inv_mu = mu.reciprocal() 94 | a = (2 * qf - 1) * inv_mu - 1 95 | x = torch.sign(a) * ((1 + mu)**torch.abs(a) - 1) * inv_mu 96 | return x 97 | 98 | def entropy(ten, do_norm=True): 99 | if do_norm: 100 | s = ten.sum() 101 | n = ten / s 102 | else: 103 | n = ten 104 | lv = torch.where(n == 0, n.new_zeros(n.size()), torch.log2(n)) 105 | return - (n * lv).sum() 106 | 107 | def int_hist(ten, ignore_val=None, accu=None): 108 | """Return a histogram of the integral-valued tensor""" 109 | if ten.is_floating_point(): 110 | raise RuntimeError('int_hist only works for non-floating-point tensors') 111 | 112 | if ignore_val is not None: 113 | mask = ten.ne(ignore_val) 114 | ten = ten.masked_select(mask) 115 | 116 | ne = max(ten.max() + 1, ten.nelement()) 117 | o = ten.new_ones(ne, dtype=torch.float) 118 | if accu is None: 119 | z = o.new_zeros(ne) 120 | else: 121 | z = accu 122 | z.scatter_add_(0, ten.flatten(), o) 123 | return z 124 | 125 | 126 | """ 127 | torch.index_select(input, d, query), expressed as SQL: 128 | 129 | d: integer in (1..k) 130 | input: i_(1..k), ival 131 | query: q_1, qval 132 | 133 | SELECT (i_1..i_k q_1/i_d), ival 134 | from input, query 135 | where i_d = qval 136 | 137 | notation: (1..k q/d) means "values 1 through k, replacing d with q" 138 | """ 139 | 140 | """ 141 | torch.gather(input, d, query), expressed as SQL: 142 | d: integer in (1..k) 143 | input: i_(1..k), ival 144 | query: q_(1..k), qval 145 | NOTE: max(q_j) = max(i_j) for all j != d 146 | 147 | SELECT (i_1 .. i_k qval/i_d), ival 148 | from index, query 149 | where i_d = qval 150 | 151 | The output has the same shape as query. 152 | All values of the output are values from input. 153 | It's like a multi-dimensional version of torch.take. 154 | """ 155 | 156 | # !!! this doesn't generalize to other dimensions anymore 157 | def gather_md_jit(input, dim: int, perm: Tuple[int, int], query): 158 | """ 159 | torchscript jit version 160 | """ 161 | k = input.dim() 162 | if dim < 0 or dim >= k: 163 | raise ValueError('dim {} must be in [0, {})'.format(dim, k)) 164 | 165 | # Q = prod(q_(1..m)) 166 | # x: (i_1..i_k Q/i_d) 167 | x = torch.index_select(input, dim, query.flatten()) 168 | 169 | # print('type of dim is: ', type(dim)) 170 | # x_perm: (i_1..i_k / q) + Q. In other words, move dimension Q to the end 171 | # t = list(range(dim)) + list(range(dim+1, k)) + [dim] 172 | # t = (0,1) 173 | # t = tuple(range(dim)) + tuple(range(dim+1, k)) + (dim,) 174 | # x_perm = x.permute(*t) 175 | # !!! 
original 176 | # t = tuple(range(dim)) + tuple(range(dim+1, k)) + (dim,) 177 | # print('permutation:', *perm) 178 | x_perm = x.permute(*perm) 179 | 180 | # for example, expand (i_1, i_2, i_3, Q) to (i_1, i_2, i_3, q_1, q_2, q_3) 181 | out_size = input.size()[:dim] + input.size()[dim+1:] + query.size() 182 | return x_perm.reshape(out_size) 183 | 184 | 185 | def gather_md(input, dim, query): 186 | ''' 187 | You can view a K-dimensional tensor entry: input[i1,i2,...,ik] = cell_value 188 | as a SQL table record with fields : i1, i2, ..., ik, cell_value 189 | 190 | Then, this function logically executes the following query: 191 | 192 | d: integer in (1..k) 193 | input: i_1, i_2, ..., i_k, ival 194 | query: q_1, q_2, ..., q_m, qval 195 | 196 | SELECT i_(1..k / d), q_(1..m), ival 197 | from input, query 198 | where i_d = qval 199 | 200 | (1..k / d) means "values 1 through k, excluding d" 201 | 202 | It is the same as torch.index_select, except that 'query' may have more 203 | than one dimension, and its dimension(s) are placed at the end of the 204 | result tensor rather than replacing input dimension 'dim' 205 | ''' 206 | k = input.dim() 207 | tup = tuple(range(dim)) + tuple(range(dim+1, k)) + (dim,) 208 | return gather_md_jit(input, dim, tup, query) 209 | 210 | 211 | def greatest_lower_bound(a, q): 212 | '''return largest i such that a[i] <= q. assume a is sorted. 213 | if q < a[0], return -1''' 214 | l, u = 0, len(a) - 1 215 | while (l < u): 216 | m = u - (u - l) // 2 217 | if a[m] <= q: 218 | l = m 219 | else: 220 | u = m - 1 221 | return l or -1 + (a[l] <= q) 222 | 223 | 224 | 225 | def sigfig(f, s, m): 226 | """format a floating point value in fixed point notation but 227 | with a fixed number of significant figures. 228 | Examples with nsigfig=3, maxwidth= 229 | Rule is: 230 | 1. If f < 1.0e-s, render with {:.ne} where n = s-1 231 | 2. If f > 1.0e+l, render with {:.ne} where n = s-1 232 | 3. Otherwise, render with {:0.ge} where g is: 233 | s if f in (0.1, 234 | 235 | f {:2e} final 236 | 1.23456e-04 => 1.23e-04 => unchanged 237 | 1.23456e-03 => 1.23e-03 => unchanged 238 | 1.23456e-02 => 1.23e-02 => unchanged 239 | 1.23456e-01 => 1.23e-01 => 0.123 240 | 1.23456e+00 => 1.23e+00 => 1.230 241 | 1.23456e+01 => 1.23e+01 => 12.30 242 | 1.23456e+02 => 1.23e+02 => 123.0 243 | 1.23456e+03 => 1.23e+03 => 1230.
244 | 1.23456e+04 => 1.23e+04 => 12300 245 | 1.23456e+05 => 1.23e+05 => unchanged 246 | 1.23456e+06 => 1.23e+06 => unchanged 247 | 248 | """ 249 | pass 250 | 251 | -------------------------------------------------------------------------------- /vae_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from sys import stderr 4 | import netmisc 5 | 6 | 7 | class VAE(nn.Module): 8 | def __init__(self, n_in, n_out, n_sam_per_datapoint=1, bias=True): 9 | '''n_sam_per_datapoint is L from equation 7, 10 | https://arxiv.org/pdf/1312.6114.pdf''' 11 | super(VAE, self).__init__() 12 | self.linear = nn.Conv1d(n_in, n_out * 2, 1, bias=False) 13 | self.tanh = nn.Tanh() 14 | # self.linear_mu = nn.Conv1d(n_out, n_out, 1, bias=bias) 15 | # self.linear_sigma = nn.Conv1d(n_out, n_out, 1, bias=bias) 16 | self.n_sam_per_datapoint = n_sam_per_datapoint 17 | self.n_out_chan = n_out 18 | netmisc.xavier_init(self.linear) 19 | # netmisc.xavier_init(self.linear_mu) 20 | # netmisc.xavier_init(self.linear_sigma) 21 | 22 | # Cache these values for later access by the objective function 23 | self.mu = None 24 | self.sigma = None 25 | 26 | def forward(self, z): 27 | # B, T, I, C: n_batch, n_timesteps, n_in_chan, n_out_chan 28 | # L: n_sam_per_datapoint 29 | # Input: (B, I, T) 30 | # Output: (B * L, C, T) 31 | # lin is the output of 'Linear(128)' from Figure 1 of Chorowski Jan 2019. 32 | lin = self.linear(z) 33 | 34 | # Chorowski doesn't specify anything between lin and mu/sigma. But, at 35 | # the very least, sigma must be positive. So, I adopt techniques from 36 | # Appendix C.2, Gaussian MLP as encoder or decoder" from Kingma VAE 37 | # paper. 38 | mu, log_sigma_sq = torch.split(lin, self.n_out_chan, dim=1) 39 | sigma = torch.exp(0.5 * log_sigma_sq) 40 | # sigma_sq = mss[:,n_out_chan:,:] 41 | #sigma = torch.sqrt(sigma_sq) 42 | 43 | L = self.n_sam_per_datapoint 44 | sample_sz = (mu.size()[0] * L,) + mu.size()[1:] 45 | if L > 1: 46 | sigma = sigma.repeat(L, 1, 1) 47 | log_sigma_sq = log_sigma_sq.repeat(L, 1, 1) 48 | mu = mu.repeat(L, 1, 1) 49 | 50 | # epsilon is the randomness injected here 51 | samples = torch.randn_like(mu) 52 | samples.mul_(sigma) 53 | samples.add_(mu) 54 | 55 | # Cache mu and sigma for objective function later 56 | self.mu = mu 57 | self.sigma_sq = torch.pow(sigma, 2.0) 58 | self.log_sigma_sq = log_sigma_sq 59 | #print(('linmu: {:.3}, linsd: {:.3}, zmu: {:.3}, zsd: {:.3}, mmu: {:.3}, msd: {:.3}, smu:' 60 | # '{:.3}, ssd: {:.3}').format(lin.mean(), lin.std(), z.mean(), 61 | # z.std(), mu.mean(), mu.std(), sigma.mean(), sigma.std())) 62 | return samples 63 | 64 | class SGVBLoss(nn.Module): 65 | def __init__(self, bottleneck, free_nats): 66 | super(SGVBLoss, self).__init__() 67 | self.bottleneck = bottleneck 68 | self.register_buffer('free_nats', torch.tensor(free_nats)) 69 | self.register_buffer('anneal_weight', torch.tensor(0.0)) 70 | self.logsoftmax = nn.LogSoftmax(1) # input is (B, Q, N) 71 | 72 | def update_anneal_weight(self, anneal_weight): 73 | self.anneal_weight.fill_(anneal_weight) 74 | 75 | 76 | def forward(self, quant_pred, target_wav): 77 | ''' 78 | Compute SGVB estimator from equation 8 in 79 | https://arxiv.org/pdf/1312.6114.pdf 80 | Uses formulas from "Autoencoding Variational Bayes", 81 | Appendix B, "Solution of -D_KL(q_phi(z) || p_theta(z)), Gaussian Case" 82 | ''' 83 | # B, T, Q, L: n_batch, n_timesteps, n_quant, n_samples_per_datapoint 84 | # K: n_bottleneck_channels 85 | # log_pred: (L * B, T, 
Q), the companded, quantized waveforms. 86 | # target_wav: (B, T) 87 | # mu, log_sigma_sq: (B, T, K), the vectors output by the bottleneck 88 | # Output: scalar, L(theta, phi, x) 89 | # log_sigma_sq = self.bottleneck.log_sigma_sq 90 | log_pred = self.logsoftmax(quant_pred) 91 | sigma_sq = self.bottleneck.sigma_sq 92 | mu = self.bottleneck.mu 93 | log_sigma_sq = torch.log(sigma_sq) 94 | mu_sq = mu * mu 95 | 96 | # neg_kl_div_gaussian: (B, K) (from Appendix B at end of derivation) 97 | channel_terms = 1.0 + log_sigma_sq - mu_sq - sigma_sq 98 | neg_kl_div_gauss = 0.5 * torch.sum(channel_terms) 99 | 100 | L = self.bottleneck.n_sam_per_datapoint 101 | BL = log_pred.size(0) 102 | assert BL % L == 0 103 | 104 | target_wav_aug = target_wav.repeat(L, 1).unsqueeze(1).long() 105 | log_pred_target = torch.gather(log_pred, 1, target_wav_aug) 106 | log_pred_target_avg = torch.mean(log_pred_target) 107 | 108 | log_pred_loss = - log_pred_target_avg 109 | kl_div_loss = - neg_kl_div_gauss 110 | 111 | # "For the VAE, this collapse can be prevented by annealing the weight 112 | # of the KL term and using the free-information formulation in Eq. (2)" 113 | # (See p 3 Section C second paragraph) 114 | total_loss = ( 115 | log_pred_loss + self.anneal_weight 116 | * torch.clamp(kl_div_loss, min=self.free_nats)) 117 | 118 | self.metrics = { 119 | 'kl_div_loss': kl_div_loss, 120 | 'log_pred_loss': log_pred_loss, 121 | # 'mu_abs_max': mu.abs().max(), 122 | # 's_sq_abs_max': sigma_sq.abs().max() 123 | } 124 | 125 | return total_loss 126 | 127 | -------------------------------------------------------------------------------- /vq_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import netmisc 4 | import util 5 | 6 | 7 | class VQ(nn.Module): 8 | def __init__(self, n_in, n_out, vq_gamma, vq_n_embed): 9 | super(VQ, self).__init__() 10 | self.d = n_out 11 | self.gamma = vq_gamma 12 | self.k = vq_n_embed 13 | self.linear = nn.Conv1d(n_in, self.d, 1, bias=False) 14 | self.sg = StopGrad() 15 | self.rg = ReplaceGrad() 16 | self.ze = None 17 | self.min_dist = None 18 | self.register_buffer('ind_hist', torch.zeros(self.k)) 19 | self.circ_inds = None 20 | self.emb = nn.Parameter(data=torch.empty(self.k, self.d)) 21 | nn.init.xavier_uniform_(self.emb, gain=1) 22 | 23 | netmisc.xavier_init(self.linear) 24 | 25 | # Shows how many of the embedding vectors have non-zero gradients 26 | #self.emb.register_hook(lambda k: print(k.sum(dim=1).unique(sorted=True))) 27 | 28 | def forward(self, z): 29 | """ 30 | B, Q, K, N: n_batch, n_quant_dims, n_quant_vecs, n_timesteps 31 | ze: (B, Q, N) 32 | emb: (K, Q) 33 | """ 34 | ze = self.linear(z) 35 | 36 | self.ze = ze 37 | 38 | sg_emb = self.sg(self.emb) 39 | l2norm_sq = ((ze.unsqueeze(1) - sg_emb.unsqueeze(2)) ** 2).sum(dim=2) # B, K, N 40 | self.min_dist, min_ind = l2norm_sq.min(dim=1) # B, N 41 | zq = util.gather_md(sg_emb, 0, min_ind).permute(1, 0, 2) 42 | zq_rg, __ = self.rg(zq, self.ze) 43 | 44 | # Diagnostics 45 | ni = min_ind.nelement() 46 | if self.circ_inds is None: 47 | self.write_pos = 0 48 | self.circ_inds = ze.new_full((100, ni), -1, dtype=torch.long) 49 | 50 | self.circ_inds[self.write_pos,0:ni] = min_ind.flatten(0) 51 | self.circ_inds[self.write_pos,ni:] = -1 52 | self.write_pos += 1 53 | self.write_pos = self.write_pos % 100 54 | 55 | ones = self.emb.new_ones(ni) 56 | util.int_hist(min_ind, accu=self.ind_hist) 57 | self.uniq = min_ind.unique(sorted=False) 58 | self.ze_norm = (self.ze ** 
2).sum(dim=1).sqrt() 59 | self.emb_norm = (self.emb ** 2).sum(dim=1).sqrt() 60 | 61 | return zq_rg 62 | 63 | class VQLoss(nn.Module): 64 | def __init__(self, bottleneck): 65 | super(VQLoss, self).__init__() 66 | self.bn = bottleneck 67 | self.logsoftmax = nn.LogSoftmax(1) # input is (B, Q, N) 68 | # self.combine = netmisc.LCCombine('LCCombine') 69 | # self.usage_adjust = netmisc.EmbedLossAdjust('EmbedLossAdjust') 70 | self.l2 = L2Error() 71 | 72 | def forward(self, quant_pred, target_wav): 73 | """ 74 | quant_pred: 75 | target_wav: B, 76 | """ 77 | # Loss per embedding vector 78 | l2_loss_embeds = self.l2(self.bn.sg(self.bn.ze), self.bn.emb) 79 | # l2_loss_embeds = scaled_l2_norm(self.bn.sg(self.bn.ze), self.bn.emb) 80 | com_loss_embeds = self.bn.min_dist * self.bn.gamma 81 | 82 | log_pred = self.logsoftmax(quant_pred) 83 | log_pred_target = torch.gather(log_pred, 1, 84 | target_wav.long().unsqueeze(1)) 85 | 86 | # Loss per timestep 87 | # !!! We don't need a 'loss per timestep'. We only need 88 | # to adjust the l2 and com losses by usage weight of each 89 | # code. (The codes at the two ends of the window will be 90 | # used less) 91 | rec_loss_ts = - log_pred_target 92 | 93 | # Use only a subset of the overlapping windows 94 | #sl = slice(0, 1) 95 | #rec_loss_sel = rec_loss_ts[...,sl] 96 | #l2_loss_sel = l2_loss_ts[...,sl] 97 | #com_loss_sel = com_loss_ts[...,sl] 98 | 99 | # total_loss_sel = rec_loss_sel + l2_loss_sel + com_loss_sel 100 | # total_loss_ts = l2_loss_ts 101 | # total_loss_ts = com_loss_ts 102 | # total_loss_ts = com_loss_ts + l2_loss_ts 103 | # total_loss_ts = log_pred_loss_ts + l2_loss_ts 104 | # total_loss_ts = log_pred_loss_ts 105 | # total_loss_ts = com_loss_ts - com_loss_ts 106 | 107 | # total_loss = total_loss_sel.mean() 108 | 109 | # We use sum here for each of the three loss terms because each element 110 | # should affect the total loss equally. For a typical WaveNet 111 | # architecture, there will be only one l2 loss term (or com_loss term) 112 | # per 320 rec_loss terms, due to upsampling. We could adjust for that. 113 | # Implicitly, com_loss is already adjusted by gamma. Perhaps l2_loss 114 | # should also be adjusted, but at the moment it is not. 
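        # (Editorial note, illustrating the ratio stated above with hypothetical
        # numbers: a window contributing 3,200 rec_loss terms would contribute
        # only about 10 l2/com terms at one conditioning vector per 320 output
        # timesteps, so .sum() keeps each individual term's weight equal, while
        # .mean() would weight the l2/com terms ~320x more heavily per term.)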
115 | total_loss = rec_loss_ts.sum() + l2_loss_embeds.sum() + com_loss_embeds.sum() 116 | 117 | nh = self.bn.ind_hist / self.bn.ind_hist.sum() 118 | 119 | self.metrics = { 120 | 'rec': rec_loss_ts.mean(), 121 | 'l2': l2_loss_embeds.mean(), 122 | 'com': com_loss_embeds.mean(), 123 | #'ze_rng': self.bn.ze.max() - self.bn.ze.min(), 124 | #'emb_rng': self.bn.emb.max() - self.bn.emb.min(), 125 | 'min_ze': self.bn.ze_norm.min(), 126 | 'max_ze': self.bn.ze_norm.max(), 127 | 'min_emb': self.bn.emb_norm.min(), 128 | 'max_emb': self.bn.emb_norm.max(), 129 | 'hst_ent': util.entropy(self.bn.ind_hist, True), 130 | 'hst_100': util.entropy(util.int_hist(self.bn.circ_inds, -1), True), 131 | #'p_m': log_pred.max(dim=1)[0].to(torch.float).mean(), 132 | #'p_sd': log_pred.max(dim=1)[0].to(torch.float).std(), 133 | 'nunq': self.bn.uniq.nelement(), 134 | 'pk_m': log_pred.max(dim=1)[0].to(torch.float).mean(), 135 | 'pk_nuq': log_pred.max(dim=1)[1].unique().nelement(), 136 | # 'peak_unq': log_pred.max(dim=1)[1].unique(), 137 | 'pk_sd': log_pred.max(dim=1)[0].to(torch.float).std(), 138 | # 'unq': self.bn.uniq, 139 | #'m_ze': self.bn.ze_norm.max(), 140 | #'m_emb': self.bn.emb_norm.max() 141 | #emb0 = emb - emb.mean(dim=0) 142 | #chan_var = (emb0 ** 2).sum(dim=0) 143 | #chan_covar = torch.matmul(emb0.transpose(1, 0), emb0) - torch.diag(chan_var) 144 | } 145 | # netmisc.print_metrics(losses, 10000000) 146 | 147 | return total_loss 148 | 149 | -------------------------------------------------------------------------------- /vqema_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import netmisc 4 | import util 5 | 6 | 7 | class StopGradFn(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, src): 10 | return src 11 | 12 | @staticmethod 13 | def backward(ctx, src): 14 | return src.new_zeros(src.size()) 15 | 16 | class StopGrad(nn.Module): 17 | """Implements the StopGradient operation. 
18 | Usage: 19 | sg = StopGrad() 20 | a = Tensor(..., requires_grad=True) 21 | a.grad.zero_() 22 | b = sg(a).sum() 23 | b.backward() 24 | assert (a.grad == 0).all().item() 25 | """ 26 | def __init__(self): 27 | super(StopGrad, self).__init__() 28 | 29 | def forward(self, src): 30 | return StopGradFn.apply(src) 31 | 32 | 33 | class ReplaceGradFn(torch.autograd.Function): 34 | """ 35 | This is like a StopGradient operation, except that instead 36 | of assigning zero gradient to src, assigns gradient of trg to src 37 | """ 38 | @staticmethod 39 | def forward(ctx, src, trg): 40 | assert src.size() == trg.size() 41 | return src, trg 42 | 43 | @staticmethod 44 | def backward(ctx, src_grad, trg_grad): 45 | return src_grad.new_zeros(src_grad.size()), src_grad + trg_grad 46 | 47 | 48 | class ReplaceGrad(nn.Module): 49 | """ 50 | Usage: 51 | rg = ReplaceGrad() 52 | s1 = Tensor(..., requires_grad=True) 53 | t1 = Tensor(..., requires_grad=True) 54 | s2, t2 = rg(s1, t1) 55 | 56 | s1 receives the zero gradient 57 | t1 receives the sum of s2's and t2's gradient 58 | 59 | """ 60 | def __init__(self): 61 | super(ReplaceGrad, self).__init__() 62 | 63 | def forward(self, src, trg): 64 | return ReplaceGradFn.apply(src, trg) 65 | 66 | 67 | def scaled_l2_norm(z, q): 68 | """ 69 | Computes a distance D(z, q) with properties: 70 | D(lambda * z, lambda * q) = D(z, q) 71 | D(z, 0) = D(0, z) = 1 72 | D(z, lambda*z) = |1-lambda| / (1 + |lambda|) 73 | """ 74 | num = ((z - q) ** 2).sum(dim=2).sqrt() 75 | den = (z ** 2).sum(dim=2).sqrt() + (q ** 2).sum(dim=2).sqrt() 76 | return num / den 77 | 78 | 79 | class VQEMA(nn.Module): 80 | """ 81 | Vector Quantization bottleneck using Exponential Moving Average 82 | updates of the Codebook vectors. 83 | """ 84 | def __init__(self, n_in, n_out, vq_gamma, vq_ema_gamma, vq_n_embed, training): 85 | super(VQEMA, self).__init__() 86 | self.training = training 87 | self.d = n_out 88 | self.gamma = vq_gamma 89 | self.ema_gamma = vq_ema_gamma 90 | self.ema_gamma_comp = 1.0 - self.ema_gamma 91 | self.k = vq_n_embed 92 | self.linear = nn.Conv1d(n_in, self.d, 1, bias=False) 93 | self.sg = StopGrad() 94 | self.rg = ReplaceGrad() 95 | self.ze = None 96 | self.register_buffer('emb', torch.empty(self.k, self.d)) 97 | nn.init.xavier_uniform_(self.emb, gain=10) 98 | 99 | if self.ema_gamma >= 1.0 or self.ema_gamma <= 0: 100 | raise RuntimeError('VQEMA must use an EMA-gamma value in (0, 1)') 101 | 102 | if self.training: 103 | self.min_dist = None 104 | self.circ_inds = None 105 | self.register_buffer('ind_hist', torch.zeros(self.k)) 106 | self.register_buffer('ema_numer', torch.empty(self.k, self.d)) 107 | self.register_buffer('ema_denom', torch.empty(self.k)) 108 | self.register_buffer('z_sum', torch.empty(self.k, self.d)) 109 | self.register_buffer('n_sum', torch.empty(self.k)) 110 | self.register_buffer('n_sum_ones', torch.ones(self.k)) 111 | #self.ema_numer.detach_() 112 | #self.ema_denom.detach_() 113 | #self.z_sum.detach_() 114 | #self.n_sum.detach_() 115 | #self.emb.detach_() 116 | #nn.init.ones_(self.ema_denom) 117 | self.ema_numer = self.emb * self.ema_gamma_comp 118 | self.ema_denom = self.n_sum_ones * self.ema_gamma_comp 119 | 120 | netmisc.xavier_init(self.linear) 121 | 122 | # Shows how many of the embedding vectors have non-zero gradients 123 | #self.emb.register_hook(lambda k: print(k.sum(dim=1).unique(sorted=True))) 124 | 125 | def forward(self, z): 126 | """ 127 | B, Q, K, N: n_batch, n_quant_dims, n_quant_vecs, n_timesteps 128 | ze: (B, Q, N) 129 | emb: (K, Q) 130 | """ 131 | ze 
= self.linear(z) 132 | self.ze = ze 133 | sg_emb = self.sg(self.emb) 134 | 135 | l2norm_sq = ((ze.unsqueeze(1) - sg_emb.unsqueeze(2)) ** 2).sum(dim=2) # B, K, N 136 | # self.min_dist, min_ind = l2norm_sq.min(dim=1) # B, N 137 | 138 | snorm = scaled_l2_norm(ze.unsqueeze(1), 139 | sg_emb.unsqueeze(2).unsqueeze(0)) 140 | #print('snorm: ', snorm) 141 | self.min_dist, min_ind = snorm.min(dim=1) # B, N 142 | zq = util.gather_md(sg_emb, 0, min_ind).permute(1, 0, 2) 143 | 144 | if self.training: 145 | # Diagnostics 146 | ni = min_ind.nelement() 147 | #if self.circ_inds is None: 148 | # self.write_pos = 0 149 | # self.circ_inds = ze.new_full((100, ni), -1, dtype=torch.long) 150 | 151 | #self.circ_inds[self.write_pos,0:ni] = min_ind.flatten(0) 152 | #self.circ_inds[self.write_pos,ni:] = -1 153 | #self.write_pos += 1 154 | #self.write_pos = self.write_pos % 100 155 | ones = self.emb.new_ones(ni) 156 | util.int_hist(min_ind, accu=self.ind_hist) 157 | self.uniq = min_ind.unique(sorted=False) 158 | self.ze_norm = (self.ze ** 2).sum(dim=1).sqrt() 159 | self.emb_norm = (self.emb ** 2).sum(dim=1).sqrt() 160 | self.min_ind = min_ind 161 | 162 | # EMA statistics 163 | # min_ind: B, W 164 | # ze: B, D, W 165 | # z_sum: K, D 166 | # n_sum: K 167 | # scatter_add has the limitation that the size of the indexing 168 | # vector cannot exceed that of the destination (even in the target 169 | # indexing dimension, which doesn't make much sense) 170 | # In this case, K is the indexing dimension 171 | # batch_size * window_batch_size 172 | flat_ind = min_ind.flatten(0, 1) 173 | idim = max(flat_ind.shape[0], self.k) 174 | 175 | z_tmp_shape = [idim, self.d] 176 | n_sum_tmp = self.n_sum.new_zeros(idim) 177 | 178 | z_sum_tmp = self.z_sum.new_zeros(z_tmp_shape) 179 | z_sum_tmp.scatter_add_(0, 180 | flat_ind.unsqueeze(1).repeat(1, self.d), 181 | self.ze.permute(0,2,1).flatten(0, 1) 182 | ) 183 | self.z_sum[...] = z_sum_tmp[0:self.k,:] 184 | 185 | self.n_sum.zero_() 186 | n_sum_ones = n_sum_tmp.new_ones((idim)) 187 | n_sum_tmp.scatter_add_(0, flat_ind, n_sum_ones) 188 | self.n_sum[...] 
= n_sum_tmp[0:self.k] 189 | 190 | self.ema_numer = ( 191 | self.ema_gamma * self.ema_numer + 192 | self.ema_gamma_comp * self.z_sum) 193 | self.ema_denom = ( 194 | self.ema_gamma * self.ema_denom + 195 | self.ema_gamma_comp * self.n_sum) 196 | 197 | # construct the straight-through estimator ('ReplaceGrad') 198 | # What I need is 199 | # cb_update = self.ema_numer / self.ema_denom.unsqueeze(1).repeat(1, 200 | # self.d) 201 | 202 | # print('z_sum_norm:', (self.z_sum ** 2).sum(dim=1).sqrt()) 203 | # print('n_sum_norm:', self.n_sum) 204 | print('ze_norm:', self.ze_norm) 205 | print('emb_norm:', (self.emb ** 2).sum(dim=1).sqrt()) 206 | print('min_ind:', self.min_ind) 207 | # print('cb_update_norm:', (cb_update ** 2).sum(dim=1).sqrt()) 208 | # print('ema_numer_norm:', 209 | # (self.ema_numer ** 2).sum(dim=1).sqrt().mean()) 210 | # print('ema_denom_norm:', 211 | # (self.ema_denom ** 2).sqrt().mean()) 212 | zq_rg, __ = self.rg(zq, self.ze) 213 | 214 | return zq_rg 215 | 216 | def update_codebook(self): 217 | """ 218 | Updates the codebook based on the EMA statistics 219 | """ 220 | self.emb = self.ema_numer / self.ema_denom.unsqueeze(1).repeat(1, 221 | self.d) 222 | self.emb.detach_() 223 | 224 | 225 | class VQEMALoss(nn.Module): 226 | def __init__(self, bottleneck): 227 | super(VQEMALoss, self).__init__() 228 | self.bn = bottleneck 229 | self.logsoftmax = nn.LogSoftmax(1) # input is (B, Q, N) 230 | 231 | def forward(self, quant_pred, target_wav): 232 | """ 233 | quant_pred: 234 | target_wav: B, 235 | """ 236 | # Loss per embedding vector 237 | com_loss_embeds = self.bn.min_dist * self.bn.gamma 238 | 239 | log_pred = self.logsoftmax(quant_pred) 240 | log_pred_target = torch.gather(log_pred, 1, 241 | target_wav.long().unsqueeze(1)) 242 | 243 | rec_loss_ts = - log_pred_target 244 | # total_loss = rec_loss_ts.sum() + com_loss_embeds.sum() 245 | # total_loss = rec_loss_ts.sum() 246 | total_loss = com_loss_embeds.sum() 247 | # total_loss = com_loss_embeds.sum() * 0.0 248 | 249 | nh = self.bn.ind_hist / self.bn.ind_hist.sum() 250 | 251 | self.metrics = { 252 | 'rec': rec_loss_ts.mean(), 253 | 'com': com_loss_embeds.mean(), 254 | 'min_ze': self.bn.ze_norm.min(), 255 | 'max_ze': self.bn.ze_norm.max(), 256 | 'min_emb': self.bn.emb_norm.min(), 257 | 'max_emb': self.bn.emb_norm.max(), 258 | 'hst_ent': util.entropy(self.bn.ind_hist, True), 259 | # 'hst_100': util.entropy(util.int_hist(self.bn.circ_inds, -1), True), 260 | 'nunq': self.bn.uniq.nelement(), 261 | 'pk_m': log_pred.max(dim=1)[0].to(torch.float).mean(), 262 | 'pk_nuq': log_pred.max(dim=1)[1].unique().nelement(), 263 | 'pk_sd': log_pred.max(dim=1)[0].to(torch.float).std() 264 | } 265 | 266 | return total_loss 267 | 268 | -------------------------------------------------------------------------------- /wave_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import vconv 4 | import numpy as np 5 | import netmisc 6 | from sys import stderr 7 | 8 | class ConvReLURes(nn.Module): 9 | def __init__(self, n_in_chan, n_out_chan, filter_sz, stride=1, do_res=True, 10 | parent_vc=None, name=None): 11 | super(ConvReLURes, self).__init__() 12 | self.n_in = n_in_chan 13 | self.n_out = n_out_chan 14 | self.conv = nn.Conv1d(n_in_chan, n_out_chan, filter_sz, stride, 15 | padding=0, bias=True) 16 | self.relu = nn.ReLU() 17 | self.name = name 18 | # self.bn = nn.BatchNorm1d(n_out_chan) 19 | 20 | self.vc = vconv.VirtualConv(filter_info=filter_sz, stride=stride, 21 | parent=parent_vc, 
name=name) 22 | 23 | self.do_res = do_res 24 | if self.do_res: 25 | if stride != 1: 26 | print('Stride must be 1 for residually connected convolution', 27 | file=stderr) 28 | raise ValueError 29 | l_off, r_off = vconv.output_offsets(self.vc, self.vc) 30 | self.register_buffer('residual_offsets', 31 | torch.tensor([l_off, r_off])) 32 | netmisc.xavier_init(self.conv) 33 | 34 | def forward(self, x): 35 | ''' 36 | B, C, T = n_batch, n_in_chan, n_win 37 | x: (B, C, T) 38 | ''' 39 | pre = self.conv(x) 40 | # out = self.bn(out) 41 | act = self.relu(pre) 42 | if self.do_res: 43 | act[...] += x[:,:,self.residual_offsets[0]:self.residual_offsets[1] or None] 44 | # act += x[:,:,self.residual_offsets[0]:self.residual_offsets[1] or None] 45 | #act_sum = act.sum() 46 | self.frac_zero_act = (act == 0.0).sum().double() / act.nelement() 47 | #if act_sum.eq(0.0): 48 | # print('encoder layer {}: {}'.format(self.name, act_sum)) 49 | #print('bias mean: {}'.format(self.conv.bias.mean())) 50 | return act 51 | 52 | 53 | class Encoder(nn.Module): 54 | def __init__(self, n_in, n_out, parent_vc): 55 | super(Encoder, self).__init__() 56 | 57 | # the "stack" 58 | stack_in_chan = [n_in, n_out, n_out, n_out, n_out, n_out, n_out, n_out, n_out] 59 | stack_filter_sz = [3, 3, 4, 3, 3, 1, 1, 1, 1] 60 | stack_strides = [1, 1, 2, 1, 1, 1, 1, 1, 1] 61 | stack_residual = [False, True, False, True, True, True, True, True, True] 62 | # stack_residual = [True] * 9 63 | # stack_residual = [False] * 9 64 | stack_info = zip(stack_in_chan, stack_filter_sz, stack_strides, stack_residual) 65 | self.net = nn.Sequential() 66 | self.vc = dict() 67 | 68 | for i, (in_chan, filt_sz, stride, do_res) in enumerate(stack_info): 69 | name = 'CRR_{}(filter_sz={}, stride={}, do_res={})'.format(i, 70 | filt_sz, stride, do_res) 71 | mod = ConvReLURes(in_chan, n_out, filt_sz, stride, do_res, 72 | parent_vc, name) 73 | self.net.add_module(str(i), mod) 74 | parent_vc = mod.vc 75 | 76 | self.vc['beg'] = self.net[0].vc 77 | self.vc['end'] = self.net[-1].vc 78 | 79 | def set_parent_vc(self, parent_vc): 80 | self.vc['beg'].parent = parent_vc 81 | parent_vc.child = self.vc['beg'] 82 | 83 | def update_metrics(self): 84 | self.metrics = {} 85 | for i, mod in enumerate(self.net): 86 | wkey = 'enc_wz_{}'.format(i) 87 | bkey = 'enc_bz_{}'.format(i) 88 | akey = 'enc_az_{}'.format(i) 89 | self.metrics[akey] = mod.frac_zero_act 90 | # self.metrics[wkey] = (mod.conv.weight == 0).sum() 91 | # self.metrics[bkey] = (mod.conv.bias == 0).sum() 92 | 93 | 94 | def forward(self, mels): 95 | ''' 96 | B, M, C, T = n_batch, n_mels, n_channels, n_timesteps 97 | mels: (B, M, T) (torch.tensor) 98 | outputs: (B, C, T) 99 | ''' 100 | out = self.net(mels) 101 | self.update_metrics() 102 | #out = torch.tanh(out * 10.0) 103 | return out 104 | 105 | --------------------------------------------------------------------------------
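As a companion to `wave_encoder.py`, the following is a minimal, self-contained sketch (not part of the repository) of the convolution geometry that `Encoder` sets up: the same filter sizes and strides as `stack_filter_sz` / `stack_strides` above, but built from plain `nn.Conv1d` layers without the residual connections or the `vconv` bookkeeping. The channel counts (39 in, matching the README's 39 MFCC components, and an assumed 768 out) and the 1000-frame input are illustrative assumptions only; the point is to see how the unpadded filters trim the time axis and how the single stride-2 layer halves the frame rate.

```python
import torch
from torch import nn

# Same filter sizes / strides as the Encoder "stack"; channel counts are
# illustrative assumptions, not values taken from the repository's json configs.
n_in, n_out = 39, 768
filter_sz = [3, 3, 4, 3, 3, 1, 1, 1, 1]
strides   = [1, 1, 2, 1, 1, 1, 1, 1, 1]

layers = []
in_chan = n_in
for k, s in zip(filter_sz, strides):
    layers += [nn.Conv1d(in_chan, n_out, k, stride=s, padding=0), nn.ReLU()]
    in_chan = n_out
stack = nn.Sequential(*layers)

# With no padding, each filter of width k trims (k - 1) timesteps, and the
# stride-2 layer halves the remaining length: 1000 -> 998 -> 996 -> 497
# -> 495 -> 493; the four 1x1 layers leave it at 493.
mels = torch.randn(2, n_in, 1000)   # (n_batch, n_mels, n_timesteps)
out = stack(mels)
print(out.shape)                    # torch.Size([2, 768, 493])
```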