├── .gitignore
├── LICENSE
├── README.md
├── dataset.py
├── datasets
│   └── download-from-youtube.sh
├── model.py
├── nn.py
├── optim.py
├── requirements.txt
├── train.py
├── trainer
│   ├── __init__.py
│   └── plugins.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | 
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 | 
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 | 
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 | 
48 | # Translations
49 | *.mo
50 | *.pot
51 | 
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 | 
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 | 
60 | # Scrapy stuff:
61 | .scrapy
62 | 
63 | # Sphinx documentation
64 | docs/_build/
65 | 
66 | # PyBuilder
67 | target/
68 | 
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 | 
72 | # pyenv
73 | .python-version
74 | 
75 | # celery beat schedule file
76 | celerybeat-schedule
77 | 
78 | # dotenv
79 | .env
80 | 
81 | # virtualenv
82 | venv/
83 | ENV/
84 | 
85 | # Spyder project settings
86 | .spyderproject
87 | 
88 | # Rope project settings
89 | .ropeproject
90 | 
91 | # vim temporary files
92 | *~
93 | *.swp
94 | *.swo
95 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Piotr Kozakowski
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # samplernn-pytorch
2 | 
3 | A PyTorch implementation of [SampleRNN: An Unconditional End-to-End Neural Audio Generation Model](https://arxiv.org/abs/1612.07837).
4 | 
5 | ![A visual representation of the SampleRNN architecture](http://deepsound.io/images/samplernn.png)
6 | 
7 | It's based on the reference implementation in Theano: https://github.com/soroushmehr/sampleRNN_ICLR2017. Unlike the Theano version, our code allows training models with an arbitrary number of tiers, whereas the original implementation supports at most 3. However, it supports only GRU units, not LSTM. For more details and the motivation behind rewriting this model in PyTorch, see our blog post: http://deepsound.io/samplernn_pytorch.html.
8 | 
9 | ## Dependencies
10 | 
11 | This code requires Python 3.5+ and PyTorch 0.1.12+. Installation instructions for PyTorch are available on their website: http://pytorch.org/. You can install the rest of the dependencies by running `pip install -r requirements.txt`.
12 | 
13 | ## Datasets
14 | 
15 | We provide a script for creating datasets from YouTube single-video mixes. It downloads a mix, converts it to wav and splits it into equal-length chunks. To run it you need youtube-dl (a recent version; the latest version from pip should be okay) and ffmpeg. To create an example dataset (4 hours of piano music split into 8-second chunks), run:
16 | 
17 | ```
18 | cd datasets
19 | ./download-from-youtube.sh "https://www.youtube.com/watch?v=EhO_MrRfftU" 8 piano
20 | ```
21 | 
22 | You can also prepare a dataset yourself. It should be a directory in `datasets/` filled with equal-length wav files. Or you can create your own dataset format by subclassing `torch.utils.data.Dataset`. It's easy; take a look at `dataset.FolderDataset` in this repo for an example, or at the sketch in the last section of this README.
23 | 
24 | ## Training
25 | 
26 | To train the model you need to run `train.py`. All model hyperparameters are settable from the command line. Most hyperparameters have sensible default values, so you don't need to provide all of them. Run `python train.py -h` for details. To train on the `piano` dataset using the best hyperparameters we've found, run:
27 | 
28 | ```
29 | python train.py --exp TEST --frame_sizes 16 4 --n_rnn 2 --dataset piano
30 | ```
31 | 
32 | The results (training log, loss plots, model checkpoints and generated samples) will be saved in `results/`.
33 | 
34 | We also have an option to monitor the metrics using [CometML](https://www.comet.ml/). To use it, just pass your API key to `train.py` as the `--comet_key` parameter.
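
## Custom dataset classes

If you write your own `Dataset` subclass, the contract that `dataset.DataLoader` relies on is simple: every item must be a 1-D `LongTensor` of the same length, consisting of `overlap_len` samples of `utils.q_zero(q_levels)` padding followed by the linearly quantized audio. Below is a minimal sketch of a dataset reading chunks from a single NumPy array; `NpyDataset` and its `.npy` layout are made up for illustration and are not part of this repo.

```
import numpy as np
import torch
from torch.utils.data import Dataset

import utils


class NpyDataset(Dataset):
    # Expects a .npy array of shape (n_chunks, chunk_len), values in [-1, 1].

    def __init__(self, npy_path, overlap_len, q_levels):
        super().__init__()
        self.overlap_len = overlap_len
        self.q_levels = q_levels
        self.chunks = torch.from_numpy(np.load(npy_path)).float()

    def __getitem__(self, index):
        # Same layout as dataset.FolderDataset produces: q_zero padding,
        # then the quantized audio chunk.
        return torch.cat([
            torch.LongTensor(self.overlap_len)
                .fill_(utils.q_zero(self.q_levels)),
            utils.linear_quantize(self.chunks[index], self.q_levels)
        ])

    def __len__(self):
        return self.chunks.size(0)
```

To use it, construct it in `make_data_loader` in `train.py` in place of `FolderDataset`.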
35 | 
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import utils
2 | 
3 | import torch
4 | from torch.utils.data import (
5 |     Dataset, DataLoader as DataLoaderBase
6 | )
7 | 
8 | from librosa.core import load
9 | from natsort import natsorted
10 | 
11 | from os import listdir
12 | from os.path import join
13 | 
14 | 
15 | class FolderDataset(Dataset):
16 | 
17 |     def __init__(self, path, overlap_len, q_levels, ratio_min=0, ratio_max=1):
18 |         super().__init__()
19 |         self.overlap_len = overlap_len
20 |         self.q_levels = q_levels
21 |         file_names = natsorted(
22 |             [join(path, file_name) for file_name in listdir(path)]
23 |         )
24 |         self.file_names = file_names[
25 |             int(ratio_min * len(file_names)) : int(ratio_max * len(file_names))
26 |         ]
27 | 
28 |     def __getitem__(self, index):
29 |         (seq, _) = load(self.file_names[index], sr=None, mono=True)
30 |         return torch.cat([  # prepend overlap_len of q_zero samples as warm-up context
31 |             torch.LongTensor(self.overlap_len) \
32 |                 .fill_(utils.q_zero(self.q_levels)),
33 |             utils.linear_quantize(
34 |                 torch.from_numpy(seq), self.q_levels
35 |             )
36 |         ])
37 | 
38 |     def __len__(self):
39 |         return len(self.file_names)
40 | 
41 | 
42 | class DataLoader(DataLoaderBase):
43 | 
44 |     def __init__(self, dataset, batch_size, seq_len, overlap_len,
45 |                  *args, **kwargs):
46 |         super().__init__(dataset, batch_size, *args, **kwargs)
47 |         self.seq_len = seq_len
48 |         self.overlap_len = overlap_len
49 | 
50 |     def __iter__(self):
51 |         for batch in super().__iter__():
52 |             (batch_size, n_samples) = batch.size()
53 | 
54 |             reset = True  # the first chunk of each batch tells the model to reset its hidden state
55 | 
56 |             for seq_begin in range(self.overlap_len, n_samples, self.seq_len):
57 |                 from_index = seq_begin - self.overlap_len
58 |                 to_index = seq_begin + self.seq_len
59 |                 sequences = batch[:, from_index : to_index]
60 |                 input_sequences = sequences[:, : -1]
61 |                 target_sequences = sequences[:, self.overlap_len :].contiguous()
62 | 
63 |                 yield (input_sequences, reset, target_sequences)
64 | 
65 |                 reset = False
66 | 
67 |     def __len__(self):
68 |         raise NotImplementedError()
69 | 
--------------------------------------------------------------------------------
/datasets/download-from-youtube.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | if [ "$#" -ne 3 ]; then
4 |     echo "Usage: $0 <youtube url> <chunk size in seconds> <dataset name>"
5 |     exit 1
6 | fi
7 | 
8 | url=$1
9 | chunk_size=$2
10 | dataset_path=$3
11 | 
12 | downloaded=".temp"
13 | rm -f $downloaded
14 | format=$(youtube-dl -F $url | grep audio | sed -r 's|([0-9]+).*|\1|g' | tail -n 1)  # pick the last listed (best) audio-only format id
15 | youtube-dl $url -f $format -o $downloaded
16 | 
17 | converted=".temp2.wav"
18 | rm -f $converted
19 | ffmpeg -i $downloaded -ac 1 -ab 16k -ar 16000 $converted
20 | rm -f $downloaded
21 | 
22 | mkdir $dataset_path
23 | length=$(ffprobe -i $converted -show_entries format=duration -v quiet -of csv="p=0")
24 | end=$(echo "$length / $chunk_size - 1" | bc)
25 | echo "splitting..."
26 | for i in $(seq 0 $end); do 27 | ffmpeg -hide_banner -loglevel error -ss $(($i * $chunk_size)) -t $chunk_size -i $converted "$dataset_path/$i.wav" 28 | done 29 | echo "done" 30 | rm -f $converted 31 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import nn 2 | import utils 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.nn import init 7 | 8 | import numpy as np 9 | 10 | 11 | class SampleRNN(torch.nn.Module): 12 | 13 | def __init__(self, frame_sizes, n_rnn, dim, learn_h0, q_levels, 14 | weight_norm): 15 | super().__init__() 16 | 17 | self.dim = dim 18 | self.q_levels = q_levels 19 | 20 | ns_frame_samples = map(int, np.cumprod(frame_sizes)) 21 | self.frame_level_rnns = torch.nn.ModuleList([ 22 | FrameLevelRNN( 23 | frame_size, n_frame_samples, n_rnn, dim, learn_h0, weight_norm 24 | ) 25 | for (frame_size, n_frame_samples) in zip( 26 | frame_sizes, ns_frame_samples 27 | ) 28 | ]) 29 | 30 | self.sample_level_mlp = SampleLevelMLP( 31 | frame_sizes[0], dim, q_levels, weight_norm 32 | ) 33 | 34 | @property 35 | def lookback(self): 36 | return self.frame_level_rnns[-1].n_frame_samples 37 | 38 | 39 | class FrameLevelRNN(torch.nn.Module): 40 | 41 | def __init__(self, frame_size, n_frame_samples, n_rnn, dim, 42 | learn_h0, weight_norm): 43 | super().__init__() 44 | 45 | self.frame_size = frame_size 46 | self.n_frame_samples = n_frame_samples 47 | self.dim = dim 48 | 49 | h0 = torch.zeros(n_rnn, dim) 50 | if learn_h0: 51 | self.h0 = torch.nn.Parameter(h0) 52 | else: 53 | self.register_buffer('h0', torch.autograd.Variable(h0)) 54 | 55 | self.input_expand = torch.nn.Conv1d( 56 | in_channels=n_frame_samples, 57 | out_channels=dim, 58 | kernel_size=1 59 | ) 60 | init.kaiming_uniform(self.input_expand.weight) 61 | init.constant(self.input_expand.bias, 0) 62 | if weight_norm: 63 | self.input_expand = torch.nn.utils.weight_norm(self.input_expand) 64 | 65 | self.rnn = torch.nn.GRU( 66 | input_size=dim, 67 | hidden_size=dim, 68 | num_layers=n_rnn, 69 | batch_first=True 70 | ) 71 | for i in range(n_rnn): 72 | nn.concat_init( 73 | getattr(self.rnn, 'weight_ih_l{}'.format(i)), 74 | [nn.lecun_uniform, nn.lecun_uniform, nn.lecun_uniform] 75 | ) 76 | init.constant(getattr(self.rnn, 'bias_ih_l{}'.format(i)), 0) 77 | 78 | nn.concat_init( 79 | getattr(self.rnn, 'weight_hh_l{}'.format(i)), 80 | [nn.lecun_uniform, nn.lecun_uniform, init.orthogonal] 81 | ) 82 | init.constant(getattr(self.rnn, 'bias_hh_l{}'.format(i)), 0) 83 | 84 | self.upsampling = nn.LearnedUpsampling1d( 85 | in_channels=dim, 86 | out_channels=dim, 87 | kernel_size=frame_size 88 | ) 89 | init.uniform( 90 | self.upsampling.conv_t.weight, -np.sqrt(6 / dim), np.sqrt(6 / dim) 91 | ) 92 | init.constant(self.upsampling.bias, 0) 93 | if weight_norm: 94 | self.upsampling.conv_t = torch.nn.utils.weight_norm( 95 | self.upsampling.conv_t 96 | ) 97 | 98 | def forward(self, prev_samples, upper_tier_conditioning, hidden): 99 | (batch_size, _, _) = prev_samples.size() 100 | 101 | input = self.input_expand( 102 | prev_samples.permute(0, 2, 1) 103 | ).permute(0, 2, 1) 104 | if upper_tier_conditioning is not None: 105 | input += upper_tier_conditioning 106 | 107 | reset = hidden is None 108 | 109 | if hidden is None: 110 | (n_rnn, _) = self.h0.size() 111 | hidden = self.h0.unsqueeze(1) \ 112 | .expand(n_rnn, batch_size, self.dim) \ 113 | .contiguous() 114 | 115 | (output, hidden) = self.rnn(input, hidden) 116 | 
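# Upsampling (kernel_size == stride == frame_size) stretches each frame's RNN
# output in time by frame_size, yielding one conditioning vector per position
# of the tier below.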
117 | output = self.upsampling( 118 | output.permute(0, 2, 1) 119 | ).permute(0, 2, 1) 120 | return (output, hidden) 121 | 122 | 123 | class SampleLevelMLP(torch.nn.Module): 124 | 125 | def __init__(self, frame_size, dim, q_levels, weight_norm): 126 | super().__init__() 127 | 128 | self.q_levels = q_levels 129 | 130 | self.embedding = torch.nn.Embedding( 131 | self.q_levels, 132 | self.q_levels 133 | ) 134 | 135 | self.input = torch.nn.Conv1d( 136 | in_channels=q_levels, 137 | out_channels=dim, 138 | kernel_size=frame_size, 139 | bias=False 140 | ) 141 | init.kaiming_uniform(self.input.weight) 142 | if weight_norm: 143 | self.input = torch.nn.utils.weight_norm(self.input) 144 | 145 | self.hidden = torch.nn.Conv1d( 146 | in_channels=dim, 147 | out_channels=dim, 148 | kernel_size=1 149 | ) 150 | init.kaiming_uniform(self.hidden.weight) 151 | init.constant(self.hidden.bias, 0) 152 | if weight_norm: 153 | self.hidden = torch.nn.utils.weight_norm(self.hidden) 154 | 155 | self.output = torch.nn.Conv1d( 156 | in_channels=dim, 157 | out_channels=q_levels, 158 | kernel_size=1 159 | ) 160 | nn.lecun_uniform(self.output.weight) 161 | init.constant(self.output.bias, 0) 162 | if weight_norm: 163 | self.output = torch.nn.utils.weight_norm(self.output) 164 | 165 | def forward(self, prev_samples, upper_tier_conditioning): 166 | (batch_size, _, _) = upper_tier_conditioning.size() 167 | 168 | prev_samples = self.embedding( 169 | prev_samples.contiguous().view(-1) 170 | ).view( 171 | batch_size, -1, self.q_levels 172 | ) 173 | 174 | prev_samples = prev_samples.permute(0, 2, 1) 175 | upper_tier_conditioning = upper_tier_conditioning.permute(0, 2, 1) 176 | 177 | x = F.relu(self.input(prev_samples) + upper_tier_conditioning) 178 | x = F.relu(self.hidden(x)) 179 | x = self.output(x).permute(0, 2, 1).contiguous() 180 | 181 | return F.log_softmax(x.view(-1, self.q_levels)) \ 182 | .view(batch_size, -1, self.q_levels) 183 | 184 | 185 | class Runner: 186 | 187 | def __init__(self, model): 188 | super().__init__() 189 | self.model = model 190 | self.reset_hidden_states() 191 | 192 | def reset_hidden_states(self): 193 | self.hidden_states = {rnn: None for rnn in self.model.frame_level_rnns} 194 | 195 | def run_rnn(self, rnn, prev_samples, upper_tier_conditioning): 196 | (output, new_hidden) = rnn( 197 | prev_samples, upper_tier_conditioning, self.hidden_states[rnn] 198 | ) 199 | self.hidden_states[rnn] = new_hidden.detach() 200 | return output 201 | 202 | 203 | class Predictor(Runner, torch.nn.Module): 204 | 205 | def __init__(self, model): 206 | super().__init__(model) 207 | 208 | def forward(self, input_sequences, reset): 209 | if reset: 210 | self.reset_hidden_states() 211 | 212 | (batch_size, _) = input_sequences.size() 213 | 214 | upper_tier_conditioning = None 215 | for rnn in reversed(self.model.frame_level_rnns): 216 | from_index = self.model.lookback - rnn.n_frame_samples 217 | to_index = -rnn.n_frame_samples + 1 218 | prev_samples = 2 * utils.linear_dequantize( 219 | input_sequences[:, from_index : to_index], 220 | self.model.q_levels 221 | ) 222 | prev_samples = prev_samples.contiguous().view( 223 | batch_size, -1, rnn.n_frame_samples 224 | ) 225 | 226 | upper_tier_conditioning = self.run_rnn( 227 | rnn, prev_samples, upper_tier_conditioning 228 | ) 229 | 230 | bottom_frame_size = self.model.frame_level_rnns[0].frame_size 231 | mlp_input_sequences = input_sequences \ 232 | [:, self.model.lookback - bottom_frame_size :] 233 | 234 | return self.model.sample_level_mlp( 235 | mlp_input_sequences, 
upper_tier_conditioning
236 |         )
237 | 
238 | 
239 | class Generator(Runner):
240 | 
241 |     def __init__(self, model, cuda=False):
242 |         super().__init__(model)
243 |         self.cuda = cuda
244 | 
245 |     def __call__(self, n_seqs, seq_len):
246 |         # generation doesn't work with cuDNN for some reason
247 |         torch.backends.cudnn.enabled = False
248 | 
249 |         self.reset_hidden_states()
250 | 
251 |         bottom_frame_size = self.model.frame_level_rnns[0].n_frame_samples  # == frame_size for the bottom tier
252 |         sequences = torch.LongTensor(n_seqs, self.model.lookback + seq_len) \
253 |                          .fill_(utils.q_zero(self.model.q_levels))
254 |         frame_level_outputs = [None for _ in self.model.frame_level_rnns]
255 | 
256 |         for i in range(self.model.lookback, self.model.lookback + seq_len):
257 |             for (tier_index, rnn) in \
258 |                     reversed(list(enumerate(self.model.frame_level_rnns))):
259 |                 if i % rnn.n_frame_samples != 0:
260 |                     continue
261 | 
262 |                 prev_samples = torch.autograd.Variable(
263 |                     2 * utils.linear_dequantize(
264 |                         sequences[:, i - rnn.n_frame_samples : i],
265 |                         self.model.q_levels
266 |                     ).unsqueeze(1),
267 |                     volatile=True
268 |                 )
269 |                 if self.cuda:
270 |                     prev_samples = prev_samples.cuda()
271 | 
272 |                 if tier_index == len(self.model.frame_level_rnns) - 1:
273 |                     upper_tier_conditioning = None
274 |                 else:
275 |                     frame_index = (i // rnn.n_frame_samples) % \
276 |                         self.model.frame_level_rnns[tier_index + 1].frame_size
277 |                     upper_tier_conditioning = \
278 |                         frame_level_outputs[tier_index + 1][:, frame_index, :] \
279 |                         .unsqueeze(1)
280 | 
281 |                 frame_level_outputs[tier_index] = self.run_rnn(
282 |                     rnn, prev_samples, upper_tier_conditioning
283 |                 )
284 | 
285 |             prev_samples = torch.autograd.Variable(
286 |                 sequences[:, i - bottom_frame_size : i],
287 |                 volatile=True
288 |             )
289 |             if self.cuda:
290 |                 prev_samples = prev_samples.cuda()
291 |             upper_tier_conditioning = \
292 |                 frame_level_outputs[0][:, i % bottom_frame_size, :] \
293 |                 .unsqueeze(1)
294 |             sample_dist = self.model.sample_level_mlp(
295 |                 prev_samples, upper_tier_conditioning
296 |             ).squeeze(1).exp_().data
297 |             sequences[:, i] = sample_dist.multinomial(1).squeeze(1)
298 | 
299 |         torch.backends.cudnn.enabled = True
300 | 
301 |         return sequences[:, self.model.lookback :]
302 | 
--------------------------------------------------------------------------------
/nn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | import math
5 | 
6 | 
7 | class LearnedUpsampling1d(nn.Module):
8 | 
9 |     def __init__(self, in_channels, out_channels, kernel_size, bias=True):
10 |         super().__init__()
11 | 
12 |         self.conv_t = nn.ConvTranspose1d(
13 |             in_channels=in_channels,
14 |             out_channels=out_channels,
15 |             kernel_size=kernel_size,
16 |             stride=kernel_size,
17 |             bias=False
18 |         )
19 | 
20 |         if bias:
21 |             self.bias = nn.Parameter(
22 |                 torch.FloatTensor(out_channels, kernel_size)
23 |             )
24 |         else:
25 |             self.register_parameter('bias', None)
26 | 
27 |         self.reset_parameters()
28 | 
29 |     def reset_parameters(self):
30 |         self.conv_t.reset_parameters()
31 |         if self.bias is not None: nn.init.constant(self.bias, 0)  # bias is None when bias=False
32 | 
33 |     def forward(self, input):
34 |         (batch_size, _, length) = input.size()
35 |         (kernel_size,) = self.conv_t.kernel_size
36 |         bias = 0 if self.bias is None else self.bias.unsqueeze(0).unsqueeze(2).expand(
37 |             batch_size, self.conv_t.out_channels,
38 |             length, kernel_size
39 |         ).contiguous().view(
40 |             batch_size, self.conv_t.out_channels,
41 |             length * kernel_size
42 |         )
43 |         return self.conv_t(input) + bias
44 | 
45 | 
46 | def lecun_uniform(tensor):
47 |     fan_in = 
nn.init._calculate_correct_fan(tensor, 'fan_in') 48 | nn.init.uniform(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in)) 49 | 50 | 51 | def concat_init(tensor, inits): 52 | try: 53 | tensor = tensor.data 54 | except AttributeError: 55 | pass 56 | 57 | (length, fan_out) = tensor.size() 58 | fan_in = length // len(inits) 59 | 60 | chunk = tensor.new(fan_in, fan_out) 61 | for (i, init) in enumerate(inits): 62 | init(chunk) 63 | tensor[i * fan_in : (i + 1) * fan_in, :] = chunk 64 | 65 | 66 | def sequence_nll_loss_bits(input, target, *args, **kwargs): 67 | (_, _, n_classes) = input.size() 68 | return nn.functional.nll_loss( 69 | input.view(-1, n_classes), target.view(-1), *args, **kwargs 70 | ) * math.log(math.e, 2) 71 | -------------------------------------------------------------------------------- /optim.py: -------------------------------------------------------------------------------- 1 | from torch.nn.functional import hardtanh 2 | 3 | 4 | def gradient_clipping(optimizer, min=-1, max=1): 5 | 6 | class OptimizerWrapper(object): 7 | 8 | def step(self, closure): 9 | def closure_wrapper(): 10 | loss = closure() 11 | for group in optimizer.param_groups: 12 | for p in group['params']: 13 | hardtanh(p.grad, min, max, inplace=True) 14 | return loss 15 | 16 | return optimizer.step(closure_wrapper) 17 | 18 | def __getattr__(self, attr): 19 | return getattr(optimizer, attr) 20 | 21 | return OptimizerWrapper() 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa==0.5.1 2 | matplotlib==2.1.0 3 | natsort==5.1.0 4 | torch==0.2.0.post3 5 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # CometML needs to be imported first. 
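# (comet_ml instruments other ML libraries as they are imported, so it has to
# be imported before torch; the try/except keeps the dependency optional.)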
2 | try: 3 | import comet_ml 4 | except ImportError: 5 | pass 6 | 7 | from model import SampleRNN, Predictor 8 | from optim import gradient_clipping 9 | from nn import sequence_nll_loss_bits 10 | from trainer import Trainer 11 | from trainer.plugins import ( 12 | TrainingLossMonitor, ValidationPlugin, AbsoluteTimeMonitor, SaverPlugin, 13 | GeneratorPlugin, StatsPlugin 14 | ) 15 | from dataset import FolderDataset, DataLoader 16 | 17 | import torch 18 | from torch.utils.trainer.plugins import Logger 19 | 20 | from natsort import natsorted 21 | 22 | from functools import reduce 23 | import os 24 | import shutil 25 | import sys 26 | from glob import glob 27 | import re 28 | import argparse 29 | 30 | 31 | default_params = { 32 | # model parameters 33 | 'n_rnn': 1, 34 | 'dim': 1024, 35 | 'learn_h0': True, 36 | 'q_levels': 256, 37 | 'seq_len': 1024, 38 | 'weight_norm': True, 39 | 'batch_size': 128, 40 | 'val_frac': 0.1, 41 | 'test_frac': 0.1, 42 | 43 | # training parameters 44 | 'keep_old_checkpoints': False, 45 | 'datasets_path': 'datasets', 46 | 'results_path': 'results', 47 | 'epoch_limit': 1000, 48 | 'resume': True, 49 | 'sample_rate': 16000, 50 | 'n_samples': 1, 51 | 'sample_length': 80000, 52 | 'loss_smoothing': 0.99, 53 | 'cuda': True, 54 | 'comet_key': None 55 | } 56 | 57 | tag_params = [ 58 | 'exp', 'frame_sizes', 'n_rnn', 'dim', 'learn_h0', 'q_levels', 'seq_len', 59 | 'batch_size', 'dataset', 'val_frac', 'test_frac' 60 | ] 61 | 62 | def param_to_string(value): 63 | if isinstance(value, bool): 64 | return 'T' if value else 'F' 65 | elif isinstance(value, list): 66 | return ','.join(map(param_to_string, value)) 67 | else: 68 | return str(value) 69 | 70 | def make_tag(params): 71 | return '-'.join( 72 | key + ':' + param_to_string(params[key]) 73 | for key in tag_params 74 | if key not in default_params or params[key] != default_params[key] 75 | ) 76 | 77 | def setup_results_dir(params): 78 | def ensure_dir_exists(path): 79 | if not os.path.exists(path): 80 | os.makedirs(path) 81 | 82 | tag = make_tag(params) 83 | results_path = os.path.abspath(params['results_path']) 84 | ensure_dir_exists(results_path) 85 | results_path = os.path.join(results_path, tag) 86 | if not os.path.exists(results_path): 87 | os.makedirs(results_path) 88 | elif not params['resume']: 89 | shutil.rmtree(results_path) 90 | os.makedirs(results_path) 91 | 92 | for subdir in ['checkpoints', 'samples']: 93 | ensure_dir_exists(os.path.join(results_path, subdir)) 94 | 95 | return results_path 96 | 97 | def load_last_checkpoint(checkpoints_path): 98 | checkpoints_pattern = os.path.join( 99 | checkpoints_path, SaverPlugin.last_pattern.format('*', '*') 100 | ) 101 | checkpoint_paths = natsorted(glob(checkpoints_pattern)) 102 | if len(checkpoint_paths) > 0: 103 | checkpoint_path = checkpoint_paths[-1] 104 | checkpoint_name = os.path.basename(checkpoint_path) 105 | match = re.match( 106 | SaverPlugin.last_pattern.format(r'(\d+)', r'(\d+)'), 107 | checkpoint_name 108 | ) 109 | epoch = int(match.group(1)) 110 | iteration = int(match.group(2)) 111 | return (torch.load(checkpoint_path), epoch, iteration) 112 | else: 113 | return None 114 | 115 | def tee_stdout(log_path): 116 | log_file = open(log_path, 'a', 1) 117 | stdout = sys.stdout 118 | 119 | class Tee: 120 | 121 | def write(self, string): 122 | log_file.write(string) 123 | stdout.write(string) 124 | 125 | def flush(self): 126 | log_file.flush() 127 | stdout.flush() 128 | 129 | sys.stdout = Tee() 130 | 131 | def make_data_loader(overlap_len, params): 132 | path = 
os.path.join(params['datasets_path'], params['dataset']) 133 | def data_loader(split_from, split_to, eval): 134 | dataset = FolderDataset( 135 | path, overlap_len, params['q_levels'], split_from, split_to 136 | ) 137 | return DataLoader( 138 | dataset, 139 | batch_size=params['batch_size'], 140 | seq_len=params['seq_len'], 141 | overlap_len=overlap_len, 142 | shuffle=(not eval), 143 | drop_last=(not eval) 144 | ) 145 | return data_loader 146 | 147 | def init_comet(params, trainer): 148 | if params['comet_key'] is not None: 149 | from comet_ml import Experiment 150 | from trainer.plugins import CometPlugin 151 | experiment = Experiment(api_key=params['comet_key'], log_code=False) 152 | hyperparams = { 153 | name: param_to_string(params[name]) for name in tag_params 154 | } 155 | experiment.log_multiple_params(hyperparams) 156 | trainer.register_plugin(CometPlugin( 157 | experiment, [ 158 | ('training_loss', 'epoch_mean'), 159 | 'validation_loss', 160 | 'test_loss' 161 | ] 162 | )) 163 | 164 | def main(exp, frame_sizes, dataset, **params): 165 | params = dict( 166 | default_params, 167 | exp=exp, frame_sizes=frame_sizes, dataset=dataset, 168 | **params 169 | ) 170 | 171 | results_path = setup_results_dir(params) 172 | tee_stdout(os.path.join(results_path, 'log')) 173 | 174 | model = SampleRNN( 175 | frame_sizes=params['frame_sizes'], 176 | n_rnn=params['n_rnn'], 177 | dim=params['dim'], 178 | learn_h0=params['learn_h0'], 179 | q_levels=params['q_levels'], 180 | weight_norm=params['weight_norm'] 181 | ) 182 | predictor = Predictor(model) 183 | if params['cuda']: 184 | model = model.cuda() 185 | predictor = predictor.cuda() 186 | 187 | optimizer = gradient_clipping(torch.optim.Adam(predictor.parameters())) 188 | 189 | data_loader = make_data_loader(model.lookback, params) 190 | test_split = 1 - params['test_frac'] 191 | val_split = test_split - params['val_frac'] 192 | 193 | trainer = Trainer( 194 | predictor, sequence_nll_loss_bits, optimizer, 195 | data_loader(0, val_split, eval=False), 196 | cuda=params['cuda'] 197 | ) 198 | 199 | checkpoints_path = os.path.join(results_path, 'checkpoints') 200 | checkpoint_data = load_last_checkpoint(checkpoints_path) 201 | if checkpoint_data is not None: 202 | (state_dict, epoch, iteration) = checkpoint_data 203 | trainer.epochs = epoch 204 | trainer.iterations = iteration 205 | predictor.load_state_dict(state_dict) 206 | 207 | trainer.register_plugin(TrainingLossMonitor( 208 | smoothing=params['loss_smoothing'] 209 | )) 210 | trainer.register_plugin(ValidationPlugin( 211 | data_loader(val_split, test_split, eval=True), 212 | data_loader(test_split, 1, eval=True) 213 | )) 214 | trainer.register_plugin(AbsoluteTimeMonitor()) 215 | trainer.register_plugin(SaverPlugin( 216 | checkpoints_path, params['keep_old_checkpoints'] 217 | )) 218 | trainer.register_plugin(GeneratorPlugin( 219 | os.path.join(results_path, 'samples'), params['n_samples'], 220 | params['sample_length'], params['sample_rate'] 221 | )) 222 | trainer.register_plugin( 223 | Logger([ 224 | 'training_loss', 225 | 'validation_loss', 226 | 'test_loss', 227 | 'time' 228 | ]) 229 | ) 230 | trainer.register_plugin(StatsPlugin( 231 | results_path, 232 | iteration_fields=[ 233 | 'training_loss', 234 | ('training_loss', 'running_avg'), 235 | 'time' 236 | ], 237 | epoch_fields=[ 238 | 'validation_loss', 239 | 'test_loss', 240 | 'time' 241 | ], 242 | plots={ 243 | 'loss': { 244 | 'x': 'iteration', 245 | 'ys': [ 246 | 'training_loss', 247 | ('training_loss', 'running_avg'), 248 | 'validation_loss', 249 | 
'test_loss',
250 |                 ],
251 |                 'log_y': True
252 |             }
253 |         }
254 |     ))
255 | 
256 |     init_comet(params, trainer)
257 | 
258 |     trainer.run(params['epoch_limit'])
259 | 
260 | 
261 | if __name__ == '__main__':
262 |     parser = argparse.ArgumentParser(
263 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
264 |         argument_default=argparse.SUPPRESS
265 |     )
266 | 
267 |     def parse_bool(arg):
268 |         arg = arg.lower()
269 |         if 'true'.startswith(arg):
270 |             return True
271 |         elif 'false'.startswith(arg):
272 |             return False
273 |         else:
274 |             raise ValueError()
275 | 
276 |     parser.add_argument('--exp', required=True, help='experiment name')
277 |     parser.add_argument(
278 |         '--frame_sizes', nargs='+', type=int, required=True,
279 |         help='frame sizes in terms of the number of lower tier frames, \
280 | starting from the lowest RNN tier'
281 |     )
282 |     parser.add_argument(
283 |         '--dataset', required=True,
284 |         help='dataset name - name of a directory in the datasets path \
285 | (settable by --datasets_path)'
286 |     )
287 |     parser.add_argument(
288 |         '--n_rnn', type=int, help='number of RNN layers in each tier'
289 |     )
290 |     parser.add_argument(
291 |         '--dim', type=int, help='number of neurons in every RNN and MLP layer'
292 |     )
293 |     parser.add_argument(
294 |         '--learn_h0', type=parse_bool,
295 |         help='whether to learn the initial states of RNNs'
296 |     )
297 |     parser.add_argument(
298 |         '--q_levels', type=int,
299 |         help='number of bins in quantization of audio samples'
300 |     )
301 |     parser.add_argument(
302 |         '--seq_len', type=int,
303 |         help='how many samples to include in each truncated BPTT pass'
304 |     )
305 |     parser.add_argument(
306 |         '--weight_norm', type=parse_bool,
307 |         help='whether to use weight normalization'
308 |     )
309 |     parser.add_argument('--batch_size', type=int, help='batch size')
310 |     parser.add_argument(
311 |         '--val_frac', type=float,
312 |         help='fraction of data to go into the validation set'
313 |     )
314 |     parser.add_argument(
315 |         '--test_frac', type=float,
316 |         help='fraction of data to go into the test set'
317 |     )
318 |     parser.add_argument(
319 |         '--keep_old_checkpoints', type=parse_bool,
320 |         help='whether to keep checkpoints from past epochs'
321 |     )
322 |     parser.add_argument(
323 |         '--datasets_path', help='path to the directory containing datasets'
324 |     )
325 |     parser.add_argument(
326 |         '--results_path', help='path to the directory to save the results to'
327 |     )
328 |     parser.add_argument('--epoch_limit', type=int, help='how many epochs to run')
329 |     parser.add_argument(
330 |         '--resume', type=parse_bool, default=True,
331 |         help='whether to resume training from the last checkpoint'
332 |     )
333 |     parser.add_argument(
334 |         '--sample_rate', type=int,
335 |         help='sample rate of the training data and generated sound'
336 |     )
337 |     parser.add_argument(
338 |         '--n_samples', type=int,
339 |         help='number of samples to generate in each epoch'
340 |     )
341 |     parser.add_argument(
342 |         '--sample_length', type=int,
343 |         help='length of each generated sample (in samples)'
344 |     )
345 |     parser.add_argument(
346 |         '--loss_smoothing', type=float,
347 |         help='smoothing parameter of the exponential moving average over \
348 | training loss, used in the log and in the loss plot'
349 |     )
350 |     parser.add_argument(
351 |         '--cuda', type=parse_bool,
352 |         help='whether to use CUDA'
353 |     )
354 |     parser.add_argument(
355 |         '--comet_key', help='comet.ml API key'
356 |     )
357 | 
358 |     parser.set_defaults(**default_params)
359 | 
360 |     main(**vars(parser.parse_args()))
361 | 
-------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | 4 | import heapq 5 | 6 | 7 | # Based on torch.utils.trainer.Trainer code. 8 | # Allows multiple inputs to the model, not all need to be Tensors. 9 | class Trainer(object): 10 | 11 | def __init__(self, model, criterion, optimizer, dataset, cuda=False): 12 | self.model = model 13 | self.criterion = criterion 14 | self.optimizer = optimizer 15 | self.dataset = dataset 16 | self.cuda = cuda 17 | self.iterations = 0 18 | self.epochs = 0 19 | self.stats = {} 20 | self.plugin_queues = { 21 | 'iteration': [], 22 | 'epoch': [], 23 | 'batch': [], 24 | 'update': [], 25 | } 26 | 27 | def register_plugin(self, plugin): 28 | plugin.register(self) 29 | 30 | intervals = plugin.trigger_interval 31 | if not isinstance(intervals, list): 32 | intervals = [intervals] 33 | for (duration, unit) in intervals: 34 | queue = self.plugin_queues[unit] 35 | queue.append((duration, len(queue), plugin)) 36 | 37 | def call_plugins(self, queue_name, time, *args): 38 | args = (time,) + args 39 | queue = self.plugin_queues[queue_name] 40 | if len(queue) == 0: 41 | return 42 | while queue[0][0] <= time: 43 | plugin = queue[0][2] 44 | getattr(plugin, queue_name)(*args) 45 | for trigger in plugin.trigger_interval: 46 | if trigger[1] == queue_name: 47 | interval = trigger[0] 48 | new_item = (time + interval, queue[0][1], plugin) 49 | heapq.heappushpop(queue, new_item) 50 | 51 | def run(self, epochs=1): 52 | for q in self.plugin_queues.values(): 53 | heapq.heapify(q) 54 | 55 | for self.epochs in range(self.epochs + 1, self.epochs + epochs + 1): 56 | self.train() 57 | self.call_plugins('epoch', self.epochs) 58 | 59 | def train(self): 60 | for (self.iterations, data) in \ 61 | enumerate(self.dataset, self.iterations + 1): 62 | batch_inputs = data[: -1] 63 | batch_target = data[-1] 64 | self.call_plugins( 65 | 'batch', self.iterations, batch_inputs, batch_target 66 | ) 67 | 68 | def wrap(input): 69 | if torch.is_tensor(input): 70 | input = Variable(input) 71 | if self.cuda: 72 | input = input.cuda() 73 | return input 74 | batch_inputs = list(map(wrap, batch_inputs)) 75 | 76 | batch_target = Variable(batch_target) 77 | if self.cuda: 78 | batch_target = batch_target.cuda() 79 | 80 | plugin_data = [None, None] 81 | 82 | def closure(): 83 | batch_output = self.model(*batch_inputs) 84 | 85 | loss = self.criterion(batch_output, batch_target) 86 | loss.backward() 87 | 88 | if plugin_data[0] is None: 89 | plugin_data[0] = batch_output.data 90 | plugin_data[1] = loss.data 91 | 92 | return loss 93 | 94 | self.optimizer.zero_grad() 95 | self.optimizer.step(closure) 96 | self.call_plugins( 97 | 'iteration', self.iterations, batch_inputs, batch_target, 98 | *plugin_data 99 | ) 100 | self.call_plugins('update', self.iterations, self.model) 101 | -------------------------------------------------------------------------------- /trainer/plugins.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | 4 | from model import Generator 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | from torch.utils.trainer.plugins.plugin import Plugin 9 | from torch.utils.trainer.plugins.monitor import Monitor 10 | from torch.utils.trainer.plugins import LossMonitor 11 | 12 | from librosa.output import write_wav 13 | from matplotlib 
import pyplot 14 | 15 | from glob import glob 16 | import os 17 | import pickle 18 | import time 19 | 20 | 21 | class TrainingLossMonitor(LossMonitor): 22 | 23 | stat_name = 'training_loss' 24 | 25 | 26 | class ValidationPlugin(Plugin): 27 | 28 | def __init__(self, val_dataset, test_dataset): 29 | super().__init__([(1, 'epoch')]) 30 | self.val_dataset = val_dataset 31 | self.test_dataset = test_dataset 32 | 33 | def register(self, trainer): 34 | self.trainer = trainer 35 | val_stats = self.trainer.stats.setdefault('validation_loss', {}) 36 | val_stats['log_epoch_fields'] = ['{last:.4f}'] 37 | test_stats = self.trainer.stats.setdefault('test_loss', {}) 38 | test_stats['log_epoch_fields'] = ['{last:.4f}'] 39 | 40 | def epoch(self, idx): 41 | self.trainer.model.eval() 42 | 43 | val_stats = self.trainer.stats.setdefault('validation_loss', {}) 44 | val_stats['last'] = self._evaluate(self.val_dataset) 45 | test_stats = self.trainer.stats.setdefault('test_loss', {}) 46 | test_stats['last'] = self._evaluate(self.test_dataset) 47 | 48 | self.trainer.model.train() 49 | 50 | def _evaluate(self, dataset): 51 | loss_sum = 0 52 | n_examples = 0 53 | for data in dataset: 54 | batch_inputs = data[: -1] 55 | batch_target = data[-1] 56 | batch_size = batch_target.size()[0] 57 | 58 | def wrap(input): 59 | if torch.is_tensor(input): 60 | input = Variable(input, volatile=True) 61 | if self.trainer.cuda: 62 | input = input.cuda() 63 | return input 64 | batch_inputs = list(map(wrap, batch_inputs)) 65 | 66 | batch_target = Variable(batch_target, volatile=True) 67 | if self.trainer.cuda: 68 | batch_target = batch_target.cuda() 69 | 70 | batch_output = self.trainer.model(*batch_inputs) 71 | loss_sum += self.trainer.criterion(batch_output, batch_target) \ 72 | .data[0] * batch_size 73 | 74 | n_examples += batch_size 75 | 76 | return loss_sum / n_examples 77 | 78 | 79 | class AbsoluteTimeMonitor(Monitor): 80 | 81 | stat_name = 'time' 82 | 83 | def __init__(self, *args, **kwargs): 84 | kwargs.setdefault('unit', 's') 85 | kwargs.setdefault('precision', 0) 86 | kwargs.setdefault('running_average', False) 87 | kwargs.setdefault('epoch_average', False) 88 | super(AbsoluteTimeMonitor, self).__init__(*args, **kwargs) 89 | self.start_time = None 90 | 91 | def _get_value(self, *args): 92 | if self.start_time is None: 93 | self.start_time = time.time() 94 | return time.time() - self.start_time 95 | 96 | 97 | class SaverPlugin(Plugin): 98 | 99 | last_pattern = 'ep{}-it{}' 100 | best_pattern = 'best-ep{}-it{}' 101 | 102 | def __init__(self, checkpoints_path, keep_old_checkpoints): 103 | super().__init__([(1, 'epoch')]) 104 | self.checkpoints_path = checkpoints_path 105 | self.keep_old_checkpoints = keep_old_checkpoints 106 | self._best_val_loss = float('+inf') 107 | 108 | def register(self, trainer): 109 | self.trainer = trainer 110 | 111 | def epoch(self, epoch_index): 112 | if not self.keep_old_checkpoints: 113 | self._clear(self.last_pattern.format('*', '*')) 114 | torch.save( 115 | self.trainer.model.state_dict(), 116 | os.path.join( 117 | self.checkpoints_path, 118 | self.last_pattern.format(epoch_index, self.trainer.iterations) 119 | ) 120 | ) 121 | 122 | cur_val_loss = self.trainer.stats['validation_loss']['last'] 123 | if cur_val_loss < self._best_val_loss: 124 | self._clear(self.best_pattern.format('*', '*')) 125 | torch.save( 126 | self.trainer.model.state_dict(), 127 | os.path.join( 128 | self.checkpoints_path, 129 | self.best_pattern.format( 130 | epoch_index, self.trainer.iterations 131 | ) 132 | ) 133 | ) 134 | 
self._best_val_loss = cur_val_loss 135 | 136 | def _clear(self, pattern): 137 | pattern = os.path.join(self.checkpoints_path, pattern) 138 | for file_name in glob(pattern): 139 | os.remove(file_name) 140 | 141 | 142 | class GeneratorPlugin(Plugin): 143 | 144 | pattern = 'ep{}-s{}.wav' 145 | 146 | def __init__(self, samples_path, n_samples, sample_length, sample_rate): 147 | super().__init__([(1, 'epoch')]) 148 | self.samples_path = samples_path 149 | self.n_samples = n_samples 150 | self.sample_length = sample_length 151 | self.sample_rate = sample_rate 152 | 153 | def register(self, trainer): 154 | self.generate = Generator(trainer.model.model, trainer.cuda) 155 | 156 | def epoch(self, epoch_index): 157 | samples = self.generate(self.n_samples, self.sample_length) \ 158 | .cpu().float().numpy() 159 | for i in range(self.n_samples): 160 | write_wav( 161 | os.path.join( 162 | self.samples_path, self.pattern.format(epoch_index, i + 1) 163 | ), 164 | samples[i, :], sr=self.sample_rate, norm=True 165 | ) 166 | 167 | 168 | class StatsPlugin(Plugin): 169 | 170 | data_file_name = 'stats.pkl' 171 | plot_pattern = '{}.svg' 172 | 173 | def __init__(self, results_path, iteration_fields, epoch_fields, plots): 174 | super().__init__([(1, 'iteration'), (1, 'epoch')]) 175 | self.results_path = results_path 176 | 177 | self.iteration_fields = self._fields_to_pairs(iteration_fields) 178 | self.epoch_fields = self._fields_to_pairs(epoch_fields) 179 | self.plots = plots 180 | self.data = { 181 | 'iterations': { 182 | field: [] 183 | for field in self.iteration_fields + [('iteration', 'last')] 184 | }, 185 | 'epochs': { 186 | field: [] 187 | for field in self.epoch_fields + [('iteration', 'last')] 188 | } 189 | } 190 | 191 | def register(self, trainer): 192 | self.trainer = trainer 193 | 194 | def iteration(self, *args): 195 | for (field, stat) in self.iteration_fields: 196 | self.data['iterations'][field, stat].append( 197 | self.trainer.stats[field][stat] 198 | ) 199 | 200 | self.data['iterations']['iteration', 'last'].append( 201 | self.trainer.iterations 202 | ) 203 | 204 | def epoch(self, epoch_index): 205 | for (field, stat) in self.epoch_fields: 206 | self.data['epochs'][field, stat].append( 207 | self.trainer.stats[field][stat] 208 | ) 209 | 210 | self.data['epochs']['iteration', 'last'].append( 211 | self.trainer.iterations 212 | ) 213 | 214 | data_file_path = os.path.join(self.results_path, self.data_file_name) 215 | with open(data_file_path, 'wb') as f: 216 | pickle.dump(self.data, f) 217 | 218 | for (name, info) in self.plots.items(): 219 | x_field = self._field_to_pair(info['x']) 220 | 221 | try: 222 | y_fields = info['ys'] 223 | except KeyError: 224 | y_fields = [info['y']] 225 | 226 | labels = list(map( 227 | lambda x: ' '.join(x) if type(x) is tuple else x, 228 | y_fields 229 | )) 230 | y_fields = self._fields_to_pairs(y_fields) 231 | 232 | try: 233 | formats = info['formats'] 234 | except KeyError: 235 | formats = [''] * len(y_fields) 236 | 237 | pyplot.gcf().clear() 238 | 239 | for (y_field, format, label) in zip(y_fields, formats, labels): 240 | if y_field in self.iteration_fields: 241 | part_name = 'iterations' 242 | else: 243 | part_name = 'epochs' 244 | 245 | xs = self.data[part_name][x_field] 246 | ys = self.data[part_name][y_field] 247 | 248 | pyplot.plot(xs, ys, format, label=label) 249 | 250 | if 'log_y' in info and info['log_y']: 251 | pyplot.yscale('log') 252 | 253 | pyplot.legend() 254 | pyplot.savefig( 255 | os.path.join(self.results_path, self.plot_pattern.format(name)) 256 | ) 
257 | 258 | @staticmethod 259 | def _field_to_pair(field): 260 | if type(field) is tuple: 261 | return field 262 | else: 263 | return (field, 'last') 264 | 265 | @classmethod 266 | def _fields_to_pairs(cls, fields): 267 | return list(map(cls._field_to_pair, fields)) 268 | 269 | 270 | class CometPlugin(Plugin): 271 | 272 | def __init__(self, experiment, fields): 273 | super().__init__([(1, 'epoch')]) 274 | 275 | self.experiment = experiment 276 | self.fields = [ 277 | field if type(field) is tuple else (field, 'last') 278 | for field in fields 279 | ] 280 | 281 | def register(self, trainer): 282 | self.trainer = trainer 283 | 284 | def epoch(self, epoch_index): 285 | for (field, stat) in self.fields: 286 | self.experiment.log_metric(field, self.trainer.stats[field][stat]) 287 | self.experiment.log_epoch_end(epoch_index) 288 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | 6 | EPSILON = 1e-2 7 | 8 | def linear_quantize(samples, q_levels): 9 | samples = samples.clone() 10 | samples -= samples.min(dim=-1)[0].expand_as(samples) 11 | samples /= samples.max(dim=-1)[0].expand_as(samples) 12 | samples *= q_levels - EPSILON 13 | samples += EPSILON / 2 14 | return samples.long() 15 | 16 | def linear_dequantize(samples, q_levels): 17 | return samples.float() / (q_levels / 2) - 1 18 | 19 | def q_zero(q_levels): 20 | return q_levels // 2 21 | --------------------------------------------------------------------------------
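
A closing note on `utils.py`: the quantization helpers define the model's input/output alphabet. Here is a minimal round-trip sketch (not part of the repo; it assumes a 1-D signal in [-1, 1], like the ones `FolderDataset` loads):

```
import torch

import utils

q_levels = 256
audio = torch.rand(16000) * 2 - 1  # fake 1 s of audio in [-1, 1]

codes = utils.linear_quantize(audio, q_levels)
assert 0 <= codes.min() and codes.max() < q_levels   # integer codes

restored = utils.linear_dequantize(codes, q_levels)  # floats back in [-1, 1)

# linear_quantize first rescales each sequence to span its full range, so the
# round trip matches the input only up to that rescaling plus a quantization
# error on the order of 2 / q_levels.
print((audio - restored).abs().max())
```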