├── nice
│   ├── __init__.py
│   ├── utils.py
│   ├── loss.py
│   ├── models.py
│   └── layers.py
├── .gitignore
├── LICENSE
├── make_datasets.py
├── README.md
├── sample.py
└── train.py
--------------------------------------------------------------------------------
/nice/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*/__pycache__/**
__pycache__/
datasets/*
saved_models/
samples.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2018, Paul Tang
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of nice_pytorch nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/nice/utils.py:
--------------------------------------------------------------------------------
"""
Utilities for loading, rescaling, and image processing.
"""
import torch


def unflatten_images(input_batch, depth, height, width):
    """
    Take a batch of flattened images and unflatten each into a DxHxW grid.
    Nearly an inverse of `flatten_images`. (`flatten_images` assumes a list of tensors, not a tensor.)

    Args:
    * input_batch: a tensor of dtype=float and shape (bsz, d*h*w).
    * depth: int
    * height: int
    * width: int
    """
    return input_batch.view(input_batch.shape[0], depth, height, width)


def rescale(x, lo, hi):
    """Linearly rescale a tensor to the range [lo, hi]."""
    assert (lo < hi), "[rescale] lo={0} must be smaller than hi={1}".format(lo, hi)
    old_width = torch.max(x) - torch.min(x)
    old_center = torch.min(x) + (old_width / 2.)
    new_width = float(hi - lo)
    new_center = lo + (new_width / 2.)
    # shift everything back to zero:
    x = x - old_center
    # rescale to the correct width:
    x = x * (new_width / old_width)
    # shift everything to the new center:
    x = x + new_center
    return x


def l1_norm(mdl, include_bias=True, device=(torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu'))):
    """Compute the L1 norm of all the weights of `mdl` (optionally excluding bias vectors)."""
    # (n.b. the accumulator lives on `device` in both branches, so CUDA models work either way:)
    _norm = torch.tensor(0.0, device=device)
    if include_bias:
        for w in mdl.parameters():
            _norm = _norm + w.norm(p=1)
    else:
        for w in mdl.parameters():
            if len(w.shape) > 1:
                _norm = _norm + w.norm(p=1)
    return _norm
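# (Usage sketch, not part of the original module; shapes and values below are hypothetical.)
#
#   batch = torch.rand(16, 1*28*28) * 255.       # fake flattened "images" with values in [0, 255)
#   batch = rescale(batch, 0., 1.)               # now batch.min() == 0. and batch.max() == 1.
#   imgs = unflatten_images(batch, 1, 28, 28)    # shape: (16, 1, 28, 28)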
--------------------------------------------------------------------------------
/make_datasets.py:
--------------------------------------------------------------------------------
"""
Download all datasets and compute preprocessing structures (whitening matrices, etc.).
"""
import torch
import torchvision
from nice.utils import rescale


def zca_matrix(data_tensor):
    """
    Helper function: compute ZCA whitening matrix across a dataset ~ (N, C, H, W).
    """
    # 1. flatten dataset:
    X = data_tensor.view(data_tensor.shape[0], -1)

    # 2. center the data range at zero by rescaling to [-1,1]:
    X = rescale(X, -1., 1.)

    # 3. compute (unnormalized) covariance:
    cov = torch.t(X) @ X

    # 4. compute ZCA(X) == U @ diag(1/sqrt(S)) @ torch.t(V), where U, S, V = SVD(cov).
    # (The singular values of the covariance are variances, so whitening divides by
    # their square roots.)
    U, S, V = torch.svd(cov)
    return (U @ torch.diag(torch.rsqrt(S)) @ torch.t(V))


def main():
    ### download training datasets:
    print("Downloading CIFAR10...")
    cifar10 = torchvision.datasets.CIFAR10(root="./datasets/cifar", train=True,
                                           transform=torchvision.transforms.ToTensor(), download=True)
    print("Downloading SVHN...")
    svhn = torchvision.datasets.SVHN(root="./datasets/svhn", split='train',
                                     transform=torchvision.transforms.ToTensor(), download=True)
    print("Downloading MNIST...")
    mnist = torchvision.datasets.MNIST(root="./datasets/mnist", train=True,
                                       transform=torchvision.transforms.ToTensor(), download=True)

    ### save ZCA whitening matrices:
    # (n.b. `torch.stack` keeps the (N, C, H, W) layout that `zca_matrix` expects;
    # `torch.cat` along dim 0 would merge the batch and channel dimensions and yield
    # a whitening matrix of the wrong size for the flattened images used in train.py.)
    print("Computing CIFAR10 ZCA matrix...")
    torch.save(zca_matrix(torch.stack([x for (x, _) in cifar10], dim=0)), "./datasets/cifar/zca_matrix.pt")
    print("Computing SVHN ZCA matrix...")
    torch.save(zca_matrix(torch.stack([x for (x, _) in svhn], dim=0)), "./datasets/svhn/zca_matrix.pt")

    print("...All done.")

if __name__ == '__main__':
    main()
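# (Sketch, not part of the original script: how a saved whitening matrix gets consumed
# downstream. train.py applies it to flattened images via
# torchvision.transforms.LinearTransformation, which computes the same product as below.)
#
#   W = torch.load("./datasets/cifar/zca_matrix.pt")    # (D, D), with D = 3*32*32
#   x_white = x.view(1, -1) @ W                         # whiten one flattened image x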
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Nonlinear Independent Components Estimation
===========================================

An implementation of the NICE model from Dinh et al. (2014) in PyTorch.

I was only able to find [the original theano-based repo](https://github.com/laurent-dinh/nice) from the first author,
and I figured it would be good practice to re-implement the architecture in PyTorch.

Please cite the paper by the original authors and credit them (not me or this repo) if any of the code in this repo
ends up being useful to you in a publication:

["NICE: Non-linear independent components estimation"](http://arxiv.org/abs/1410.8516), Laurent Dinh, David Krueger, Yoshua Bengio. ArXiv 2014.


Requirements
------------
* PyTorch 0.4.1+
* NumPy 1.14.5+
* tqdm 4.15.0+ (though any version should work --- we primarily just use the main `tqdm` and `trange` wrappers)


Benchmarks
----------
We plan to use the same four datasets as in the original paper (MNIST, TFD, SVHN, and CIFAR-10) and attempt to reproduce the results in the paper. At present, MNIST, SVHN, and CIFAR10 are supported; TFD is a bit harder to get access to (due to privacy issues regarding the faces, etc.).

Running `python make_datasets.py` will download the relevant datasets and store them in the appropriate directories the first time
you run it; subsequent runs will not re-download the datasets if they already exist. Additionally, the ZCA matrices will be
computed for the datasets that require them (CIFAR10, SVHN).

`(TBD: comparisons to original repo & paper results here --- once I find the time to run on 1500 epochs.)`


License
-------
The license for this repository is the 3-clause BSD, as in the theano-based implementation.


Status
------
* Training on MNIST, CIFAR10, and SVHN currently works; trained models can be sampled via `python sample.py [--args]`.
* Training on GPU currently works. (Sampling is still CPU-only, but this is by design.)
* Benchmarks are still forthcoming.
* Toronto Face Dataset support is still something I'm considering, if I can find a place to download it.

Future To-Do List
-----------------
+ Implement inpainting from trained model.
+ Toronto Face Dataset? (See remark about privacy issues above.)
+ Implement affine coupling law.
+ Allow arbitrary partitions of the input in coupling layers?
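Example usage
-------------
For illustration, a hypothetical sampling invocation (flag names are as defined in `sample.py`;
the checkpoint filename follows the pattern `train.py` uses when saving, so substitute whichever
file your training run produced):

    python sample.py --dataset mnist --prior logistic --nrows 8 --ncols 8 \
        --model ./saved_models/nice.mnist.l_5.h_1000.p_logistic.e_10.cpu.pt \
        --save_image ./samples.png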
35 | """ 36 | # \sum^D_i s_{ii} - { \sum^D_i log(exp(h)+1) + torch.log(exp(-h)+1) } 37 | return (torch.sum(diag) - (torch.sum(torch.log1p(torch.exp(h)) + torch.log1p(torch.exp(-h)), dim=1))) 38 | 39 | # wrap above loss functions in Modules: 40 | class GaussianPriorNICELoss(nn.Module): 41 | def __init__(self, size_average=True): 42 | super(GaussianPriorNICELoss, self).__init__() 43 | self.size_average = size_average 44 | 45 | def forward(self, fx, diag): 46 | if self.size_average: 47 | return torch.mean(-gaussian_nice_loglkhd(fx, diag)) 48 | else: 49 | return torch.sum(-gaussian_nice_loglkhd(fx, diag)) 50 | 51 | class LogisticPriorNICELoss(nn.Module): 52 | def __init__(self, size_average=True): 53 | super(LogisticPriorNICELoss, self).__init__() 54 | self.size_average = size_average 55 | 56 | def forward(self, fx, diag): 57 | if self.size_average: 58 | return torch.mean(-logistic_nice_loglkhd(fx, diag)) 59 | else: 60 | return torch.sum(-logistic_nice_loglkhd(fx, diag)) 61 | -------------------------------------------------------------------------------- /nice/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of models from paper. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | from .layers import AdditiveCouplingLayer 8 | 9 | def _build_relu_network(latent_dim, hidden_dim, num_layers): 10 | """Helper function to construct a ReLU network of varying number of layers.""" 11 | _modules = [ nn.Linear(latent_dim, hidden_dim) ] 12 | for _ in range(num_layers): 13 | _modules.append( nn.Linear(hidden_dim, hidden_dim) ) 14 | _modules.append( nn.ReLU() ) 15 | _modules.append( nn.BatchNorm1d(hidden_dim) ) 16 | _modules.append( nn.Linear(hidden_dim, latent_dim) ) 17 | return nn.Sequential( *_modules ) 18 | 19 | 20 | class NICEModel(nn.Module): 21 | """ 22 | Replication of model from the paper: 23 | "Nonlinear Independent Components Estimation", 24 | Laurent Dinh, David Krueger, Yoshua Bengio (2014) 25 | https://arxiv.org/abs/1410.8516 26 | 27 | Contains the following components: 28 | * four additive coupling layers with nonlinearity functions consisting of 29 | five-layer RELUs 30 | * a diagonal scaling matrix output layer 31 | """ 32 | def __init__(self, input_dim, hidden_dim, num_layers): 33 | super(NICEModel, self).__init__() 34 | assert (input_dim % 2 == 0), "[NICEModel] only even input dimensions supported for now" 35 | assert (num_layers > 2), "[NICEModel] num_layers must be at least 3" 36 | self.input_dim = input_dim 37 | half_dim = int(input_dim / 2) 38 | self.layer1 = AdditiveCouplingLayer(input_dim, 'odd', _build_relu_network(half_dim, hidden_dim, num_layers)) 39 | self.layer2 = AdditiveCouplingLayer(input_dim, 'even', _build_relu_network(half_dim, hidden_dim, num_layers)) 40 | self.layer3 = AdditiveCouplingLayer(input_dim, 'odd', _build_relu_network(half_dim, hidden_dim, num_layers)) 41 | self.layer4 = AdditiveCouplingLayer(input_dim, 'even', _build_relu_network(half_dim, hidden_dim, num_layers)) 42 | self.scaling_diag = nn.Parameter(torch.ones(input_dim)) 43 | 44 | # randomly initialize weights: 45 | for p in self.layer1.parameters(): 46 | if len(p.shape) > 1: 47 | init.kaiming_uniform_(p, nonlinearity='relu') 48 | else: 49 | init.normal_(p, mean=0., std=0.001) 50 | for p in self.layer2.parameters(): 51 | if len(p.shape) > 1: 52 | init.kaiming_uniform_(p, nonlinearity='relu') 53 | else: 54 | init.normal_(p, mean=0., std=0.001) 55 | for p in self.layer3.parameters(): 56 | if 
--------------------------------------------------------------------------------
/nice/models.py:
--------------------------------------------------------------------------------
"""
Implementation of models from the paper.
"""
import torch
import torch.nn as nn
import torch.nn.init as init
from .layers import AdditiveCouplingLayer

def _build_relu_network(latent_dim, hidden_dim, num_layers):
    """Helper function to construct a ReLU network with a varying number of layers."""
    _modules = [ nn.Linear(latent_dim, hidden_dim) ]
    for _ in range(num_layers):
        _modules.append( nn.Linear(hidden_dim, hidden_dim) )
        _modules.append( nn.ReLU() )
        _modules.append( nn.BatchNorm1d(hidden_dim) )
    _modules.append( nn.Linear(hidden_dim, latent_dim) )
    return nn.Sequential( *_modules )


class NICEModel(nn.Module):
    """
    Replication of the model from the paper:
      "Nonlinear Independent Components Estimation",
      Laurent Dinh, David Krueger, Yoshua Bengio (2014)
      https://arxiv.org/abs/1410.8516

    Contains the following components:
    * four additive coupling layers, with nonlinearity functions consisting of
      five-layer ReLU networks;
    * a diagonal scaling matrix output layer.
    """
    def __init__(self, input_dim, hidden_dim, num_layers):
        super(NICEModel, self).__init__()
        assert (input_dim % 2 == 0), "[NICEModel] only even input dimensions supported for now"
        assert (num_layers > 2), "[NICEModel] num_layers must be at least 3"
        self.input_dim = input_dim
        half_dim = int(input_dim / 2)
        self.layer1 = AdditiveCouplingLayer(input_dim, 'odd', _build_relu_network(half_dim, hidden_dim, num_layers))
        self.layer2 = AdditiveCouplingLayer(input_dim, 'even', _build_relu_network(half_dim, hidden_dim, num_layers))
        self.layer3 = AdditiveCouplingLayer(input_dim, 'odd', _build_relu_network(half_dim, hidden_dim, num_layers))
        self.layer4 = AdditiveCouplingLayer(input_dim, 'even', _build_relu_network(half_dim, hidden_dim, num_layers))
        self.scaling_diag = nn.Parameter(torch.ones(input_dim))

        # randomly initialize weights (Kaiming-uniform for weight matrices, small Gaussians for biases):
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
            for p in layer.parameters():
                if len(p.shape) > 1:
                    init.kaiming_uniform_(p, nonlinearity='relu')
                else:
                    init.normal_(p, mean=0., std=0.001)


    def forward(self, xs):
        """
        Forward pass through all invertible coupling layers.

        Args:
        * xs: float tensor of shape (B,dim).

        Returns:
        * ys: float tensor of shape (B,dim).
        """
        ys = self.layer1(xs)
        ys = self.layer2(ys)
        ys = self.layer3(ys)
        ys = self.layer4(ys)
        # apply the diagonal scaling matrix exp(S):
        ys = torch.matmul(ys, torch.diag(torch.exp(self.scaling_diag)))
        return ys


    def inverse(self, ys):
        """Invert a batch of draws from the prior, running the layers in reverse with gradients off."""
        with torch.no_grad():
            xs = torch.matmul(ys, torch.diag(torch.reciprocal(torch.exp(self.scaling_diag))))
            xs = self.layer4.inverse(xs)
            xs = self.layer3.inverse(xs)
            xs = self.layer2.inverse(xs)
            xs = self.layer1.inverse(xs)
        return xs
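# (Invertibility sketch, not part of the original module: since every coupling layer is
# bijective and the scaling is diagonal, `inverse` should undo `forward` up to float32
# round-off. Dimensions below are hypothetical; eval mode keeps the BatchNorm layers
# inside the nonlinearities deterministic.)
#
#   model = NICEModel(28*28, 1000, 5).eval()
#   x = torch.randn(8, 28*28)
#   x_rec = model.inverse(model(x))
#   assert torch.allclose(x_rec, x, atol=1e-4)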
--------------------------------------------------------------------------------
/nice/layers.py:
--------------------------------------------------------------------------------
"""
Implementation of NICE bijective triangular-Jacobian layers.
"""
import torch
import torch.nn as nn
import numpy as np


# ===== ===== Coupling Layer Implementations ===== =====

_get_even = lambda xs: xs[:,0::2]
_get_odd = lambda xs: xs[:,1::2]

def _interleave(first, second, order):
    """
    Given two rank-2 tensors with the same batch dimension, interleave their columns.

    The tensors "first" and "second" are assumed to be of shape (B,M) and (B,N)
    respectively, where M == N or M == N+1.
    """
    cols = []
    if order == 'even':
        for k in range(second.shape[1]):
            cols.append(first[:,k])
            cols.append(second[:,k])
        if first.shape[1] > second.shape[1]:
            cols.append(first[:,-1])
    else:
        for k in range(first.shape[1]):
            cols.append(second[:,k])
            cols.append(first[:,k])
        if second.shape[1] > first.shape[1]:
            cols.append(second[:,-1])
    return torch.stack(cols, dim=1)


class _BaseCouplingLayer(nn.Module):
    def __init__(self, dim, partition, nonlinearity):
        """
        Base coupling layer that handles the partitioning of the inputs and wraps
        an instance of torch.nn.Module.

        Usage:
        >> layer = AdditiveCouplingLayer(1000, 'even', nn.Sequential(...))

        Args:
        * dim: dimension of the inputs.
        * partition: str, 'even' or 'odd'. If 'even', the even-valued columns are sent to
          pass through the activation module.
        * nonlinearity: an instance of torch.nn.Module.
        """
        super(_BaseCouplingLayer, self).__init__()
        # store input dimension of incoming values:
        self.dim = dim
        # store partition choice and make shorthands for the first and second partitions:
        assert (partition in ['even', 'odd']), "[_BaseCouplingLayer] Partition type must be `even` or `odd`!"
        self.partition = partition
        if (partition == 'even'):
            self._first = _get_even
            self._second = _get_odd
        else:
            self._first = _get_odd
            self._second = _get_even
        # store nonlinear function module:
        # (n.b. this can be a complex instance of torch.nn.Module, e.g. a deep ReLU network)
        self.add_module('nonlinearity', nonlinearity)

    def forward(self, x):
        """Map an input through the partition and nonlinearity."""
        return _interleave(
            self._first(x),
            self.coupling_law(self._second(x), self.nonlinearity(self._first(x))),
            self.partition
        )

    def inverse(self, y):
        """Inverse mapping through the layer. Gradients should be turned off for this pass."""
        return _interleave(
            self._first(y),
            self.anticoupling_law(self._second(y), self.nonlinearity(self._first(y))),
            self.partition
        )

    def coupling_law(self, a, b):
        # (a,b) --> g(a,b)
        raise NotImplementedError("[_BaseCouplingLayer] Don't call the abstract base layer!")

    def anticoupling_law(self, a, b):
        # (a,b) --> g^{-1}(a,b)
        raise NotImplementedError("[_BaseCouplingLayer] Don't call the abstract base layer!")


class AdditiveCouplingLayer(_BaseCouplingLayer):
    """Layer with coupling law g(a;b) := a + b."""
    def coupling_law(self, a, b):
        return (a + b)
    def anticoupling_law(self, a, b):
        return (a - b)


class MultiplicativeCouplingLayer(_BaseCouplingLayer):
    """Layer with coupling law g(a;b) := a .* b."""
    def coupling_law(self, a, b):
        return torch.mul(a, b)
    def anticoupling_law(self, a, b):
        return torch.mul(a, torch.reciprocal(b))


class AffineCouplingLayer(_BaseCouplingLayer):
    """Layer with coupling law g(a;b) := a .* b1 + b2, where (b1,b2) is a partition of b."""
    def coupling_law(self, a, b):
        return torch.mul(a, self._first(b)) + self._second(b)
    def anticoupling_law(self, a, b):
        # TODO
        raise NotImplementedError("TODO: AffineCouplingLayer (sorry!)")
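# (Sketch, not part of the original module: the partition/interleave contract the coupling
# layers rely on -- splitting into even/odd columns and re-interleaving is the identity.)
#
#   xs = torch.arange(12.).view(2, 6)
#   assert torch.equal(_interleave(_get_even(xs), _get_odd(xs), 'even'), xs)
#   assert torch.equal(_interleave(_get_odd(xs), _get_even(xs), 'odd'), xs)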
Sorry!") 38 | input_dim = None 39 | img_height = None 40 | img_width = None 41 | img_depth = None 42 | 43 | # shut off gradients for sampling: 44 | torch.set_grad_enabled(False) 45 | 46 | # build model & load state dict: 47 | nice = NICEModel(input_dim, args.nhidden, args.nlayers) 48 | if args.model_path is not None: 49 | nice.load_state_dict(torch.load(args.model_path, map_location='cpu')) 50 | print("[sample] Loaded model from file.") 51 | nice.eval() 52 | 53 | # sample a batch: 54 | if args.prior == 'logistic': 55 | LOGISTIC_LOC = 0.0 56 | LOGISTIC_SCALE = (3. / (np.pi**2)) # (sets variance to 1) 57 | logistic = dist.TransformedDistribution( 58 | dist.Uniform(0.0, 1.0), 59 | [dist.SigmoidTransform().inv, dist.AffineTransform(loc=LOGISTIC_LOC, scale=LOGISTIC_SCALE)] 60 | ) 61 | print("[sample] sampling from logistic prior with loc={0:.4f}, scale={1:.4f}.".format(LOGISTIC_LOC,LOGISTIC_SCALE)) 62 | ys = logistic.sample(torch.Size([args.nrows*args.ncols, input_dim])) 63 | xs = nice.inverse(ys) 64 | if args.prior == 'gaussian': 65 | print("[sample] sampling from gaussian prior with loc=0.0, scale=1.0.") 66 | ys = torch.randn(args.nrows*args.ncols, input_dim) 67 | xs = nice.inverse(ys) 68 | 69 | # format sample into images of correct shape: 70 | image_batch = unflatten_images(xs, img_depth, img_height, img_width) 71 | 72 | # arrange into a grid and save to file: 73 | torchvision.utils.save_image(image_batch, args.save_image_path, nrow=args.nrows) 74 | print("[sample] Saved {0}-by-{1} sampled images to {2}.".format(args.nrows, args.ncols, args.save_image_path)) 75 | 76 | 77 | # ===== ===== ===== ===== ===== ===== ===== ===== ===== ===== 78 | if __name__ == '__main__': 79 | parser = argparse.ArgumentParser(description="Sample from a trained NICE model.") 80 | # --- sampling options: 81 | parser.add_argument("--dataset", dest='dataset', choices=('tfd', 'cifar10', 'svhn', 'mnist'), required=True, 82 | help="Which dataset to use; determines height and width") 83 | parser.add_argument("--prior", choices=('logistic', 'gaussian'), default="logistic", 84 | help="Prior distribution of latent space components. [logistic]") 85 | parser.add_argument("--nrows", dest='nrows', default=1, type=int, 86 | help="Number of rows in grid of output images. [1]") 87 | parser.add_argument("--ncols", dest='ncols', default=1, type=int, 88 | help="Number of columns in grid of output images. [1]") 89 | parser.add_argument("--save_image", default="./samples.png", dest='save_image_path', 90 | help="Where to save the grid of samples. [./samples.png]") 91 | # --- model options: 92 | parser.add_argument("--model", dest='model_path', default=None, 93 | help="Path to trained model. [None/untrained model]") 94 | parser.add_argument("--nonlinearity_layers", dest='nlayers', default=5, type=int, 95 | help="Number of layers in the nonlinearity. [5]") 96 | parser.add_argument("--nonlinearity_hiddens", dest='nhidden', default=1000, type=int, 97 | help="Hidden size of inner layers of nonlinearity. [1000]") 98 | args = parser.parse_args() 99 | sample(args) 100 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training loop for NICEModel. Attempts to replicate the conditions in the NICE paper. 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
"""
Training loop for NICEModel. Attempts to replicate the conditions in the NICE paper.

Supports the following datasets:
* MNIST (LeCun & Cortes, 1998);
* Toronto Face Dataset (Susskind et al, 2010);
* CIFAR-10 (Krizhevsky, 2010);
* Street View House Numbers (Netzer et al, 2011).

We apply dequantization to MNIST, TFD, and SVHN as follows (following the NICE authors):
1. Add uniform noise ~ Unif([0, 1/256]);
2. Rescale the data to lie in [0,1] in each dimension.

For CIFAR10, we instead do:
1. Add uniform noise ~ Unif([-1/256, 1/256]);
2. Rescale the data to lie in [-1,1] in each dimension.

Additionally, we perform:
* approximate whitening for TFD;
* exact ZCA on SVHN and CIFAR10;
* no additional preprocessing for MNIST.

Finally, images are flattened from (H,W) to (H*W,).
"""
# numeric/nn libraries:
import torch
import torchvision
import torch.optim as optim
import torch.utils.data as data
import numpy as np
# models/losses/image utils:
from nice.models import NICEModel
from nice.loss import LogisticPriorNICELoss, GaussianPriorNICELoss
from nice.utils import rescale, l1_norm
# python/os utils:
import argparse
import os
from tqdm import tqdm, trange

# set CUDA training on if detected:
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
    CUDA = True
else:
    DEVICE = torch.device('cpu')
    CUDA = False

# ===== ===== ===== ===== ===== ===== ===== ===== ===== =====
# Dataset loaders: each of these helper functions does the following:
# 1) loads the corresponding dataset from its folder (run `python make_datasets.py` to download it first);
# 2) adds the corresponding whitening & rescaling transforms;
# 3) returns a dataloader for that dataset.

def load_mnist(train=True, batch_size=1, num_workers=0):
    """Rescale and preprocess MNIST dataset."""
    mnist_transform = torchvision.transforms.Compose([
        # convert PIL image to tensor:
        torchvision.transforms.ToTensor(),
        # flatten:
        torchvision.transforms.Lambda(lambda x: x.view(-1)),
        # add uniform noise ~ Unif([0, 1/256]):
        torchvision.transforms.Lambda(lambda x: (x + torch.rand_like(x).div_(256.))),
        # rescale to [0,1]:
        torchvision.transforms.Lambda(lambda x: rescale(x, 0., 1.))
    ])
    return data.DataLoader(
        torchvision.datasets.MNIST(root="./datasets/mnist", train=train, transform=mnist_transform, download=False),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=train,
        pin_memory=CUDA,
        drop_last=train
    )

def load_svhn(train=True, batch_size=1, num_workers=0):
    """Rescale and preprocess SVHN dataset."""
    # check that the ZCA matrix has been built for this dataset:
    assert os.path.exists("./datasets/svhn/zca_matrix.pt"), \
        "[load_svhn] ZCA whitening matrix not built! Run `python make_datasets.py` first."
    zca_matrix = torch.load("./datasets/svhn/zca_matrix.pt")

    svhn_transform = torchvision.transforms.Compose([
        # convert PIL image to tensor:
        torchvision.transforms.ToTensor(),
        # flatten:
        torchvision.transforms.Lambda(lambda x: x.view(-1)),
        # add uniform noise ~ Unif([0, 1/256]):
        torchvision.transforms.Lambda(lambda x: (x + torch.rand_like(x).div_(256.))),
        # rescale to [0,1]:
        torchvision.transforms.Lambda(lambda x: rescale(x, 0., 1.)),
        # exact ZCA:
        torchvision.transforms.LinearTransformation(zca_matrix)
    ])
    _mode = 'train' if train else 'test'
    return data.DataLoader(
        torchvision.datasets.SVHN(root="./datasets/svhn", split=_mode, transform=svhn_transform, download=False),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=train,
        pin_memory=CUDA,
        drop_last=train
    )

def load_cifar10(train=True, batch_size=1, num_workers=0):
    """Rescale and preprocess CIFAR10 dataset."""
    # check that the ZCA matrix has been built for this dataset:
    assert os.path.exists("./datasets/cifar/zca_matrix.pt"), \
        "[load_cifar10] ZCA whitening matrix not built! Run `python make_datasets.py` first."
    zca_matrix = torch.load("./datasets/cifar/zca_matrix.pt")

    cifar10_transform = torchvision.transforms.Compose([
        # convert PIL image to tensor:
        torchvision.transforms.ToTensor(),
        # flatten:
        torchvision.transforms.Lambda(lambda x: x.view(-1)),
        # add uniform noise ~ Unif([-1/256, +1/256]):
        torchvision.transforms.Lambda(lambda x: (x + torch.rand_like(x).div_(128.).add_(-1./256.))),
        # rescale to [-1,1]:
        torchvision.transforms.Lambda(lambda x: rescale(x, -1., 1.)),
        # exact ZCA:
        torchvision.transforms.LinearTransformation(zca_matrix)
    ])
    return data.DataLoader(
        torchvision.datasets.CIFAR10(root="./datasets/cifar", train=train, transform=cifar10_transform, download=False),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=train,
        pin_memory=CUDA,
        drop_last=train
    )

def load_tfd(train=True, batch_size=1, num_workers=0):
    """Rescale and preprocess TFD dataset."""
    raise NotImplementedError("[load_tfd] Toronto Faces Dataset unsupported right now. Sorry!")
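# (Smoke-test sketch, not part of the original script; assumes MNIST has already been
# downloaded via `python make_datasets.py`. Each MNIST batch should consist of flat
# 784-dim vectors rescaled into [0, 1].)
#
#   loader = load_mnist(train=True, batch_size=4)
#   xs, _ = next(iter(loader))
#   assert xs.shape == (4, 28*28)
#   assert float(xs.min()) >= 0.0 and float(xs.max()) <= 1.0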
# ===== ===== ===== ===== ===== ===== ===== ===== ===== =====
# Training loop: return a NICE model trained over a number of epochs.
def train(args):
    """Construct a NICE model and train it over a number of epochs."""
    # === choose which dataset to build:
    if args.dataset == 'mnist':
        dataloader_fn = load_mnist
        input_dim = 28*28
    elif args.dataset == 'svhn':
        dataloader_fn = load_svhn
        input_dim = 32*32*3
    elif args.dataset == 'cifar10':
        dataloader_fn = load_cifar10
        input_dim = 32*32*3
    elif args.dataset == 'tfd':
        raise NotImplementedError("[train] Toronto Faces Dataset unsupported right now. Sorry!")

    # === build model & optimizer:
    model = NICEModel(input_dim, args.nhidden, args.nlayers)
    if (args.model_path is not None):
        assert (os.path.exists(args.model_path)), "[train] model does not exist at specified location"
        model.load_state_dict(torch.load(args.model_path, map_location='cpu'))
    model.to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), eps=args.eps)

    # === choose which loss function to build:
    if args.prior == 'logistic':
        nice_loss_fn = LogisticPriorNICELoss(size_average=True)
    else:
        nice_loss_fn = GaussianPriorNICELoss(size_average=True)

    def loss_fn(fx):
        """Compute the NICE loss w/r/t a prior, with optional L1 regularization."""
        if args.lmbda == 0.0:
            return nice_loss_fn(fx, model.scaling_diag)
        else:
            return nice_loss_fn(fx, model.scaling_diag) + args.lmbda*l1_norm(model, include_bias=True)

    # === train over a number of epochs; perform validation after each:
    for t in range(args.num_epochs):
        print("* Epoch {0}:".format(t))
        dataloader = dataloader_fn(train=True, batch_size=args.batch_size)
        for inputs, _ in tqdm(dataloader):
            opt.zero_grad()
            loss_fn(model(inputs.to(DEVICE))).backward()
            opt.step()

        # save model to disk and delete dataloader to save memory:
        if t % args.save_epoch == 0:
            _dev = 'cuda' if CUDA else 'cpu'
            _fn = "nice.{0}.l_{1}.h_{2}.p_{3}.e_{4}.{5}.pt".format(args.dataset, args.nlayers, args.nhidden, args.prior, t, _dev)
            torch.save(model.state_dict(), os.path.join(args.savedir, _fn))
            print(">>> Saved file: {0}".format(_fn))
        del dataloader

        # perform validation loop:
        vmin, vmed, vmean, vmax = validate(model, dataloader_fn, nice_loss_fn)
        print(">>> Validation Loss Statistics: min={0}, med={1}, mean={2}, max={3}".format(vmin, vmed, vmean, vmax))

    # return the trained model (`__main__` below expects this):
    return model
Sorry!") 147 | dataloader_fn = load_tfd 148 | input_dim = None 149 | 150 | # === build model & optimizer: 151 | model = NICEModel(input_dim, args.nhidden, args.nlayers) 152 | if (args.model_path is not None): 153 | assert(os.path.exists(args.model_path)), "[train] model does not exist at specified location" 154 | model.load_state_dict(torch.load(args.model_path, map_location='cpu')) 155 | model.to(DEVICE) 156 | opt = optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1,args.beta2), eps=args.eps) 157 | 158 | # === choose which loss function to build: 159 | if args.prior == 'logistic': 160 | nice_loss_fn = LogisticPriorNICELoss(size_average=True) 161 | else: 162 | nice_loss_fn = GaussianPriorNICELoss(size_average=True) 163 | def loss_fn(fx): 164 | """Compute NICE loss w/r/t a prior and optional L1 regularization.""" 165 | if args.lmbda == 0.0: 166 | return nice_loss_fn(fx, model.scaling_diag) 167 | else: 168 | return nice_loss_fn(fx, model.scaling_diag) + args.lmbda*l1_norm(model, include_bias=True) 169 | 170 | # === train over a number of epochs; perform validation after each: 171 | for t in range(args.num_epochs): 172 | print("* Epoch {0}:".format(t)) 173 | dataloader = dataloader_fn(train=True, batch_size=args.batch_size) 174 | for inputs, _ in tqdm(dataloader): 175 | opt.zero_grad() 176 | loss_fn(model(inputs.to(DEVICE))).backward() 177 | opt.step() 178 | 179 | # save model to disk and delete dataloader to save memory: 180 | if t % args.save_epoch == 0: 181 | _dev = 'cuda' if CUDA else 'cpu' 182 | _fn = "nice.{0}.l_{1}.h_{2}.p_{3}.e_{4}.{5}.pt".format(args.dataset, args.nlayers, args.nhidden, args.prior, t, _dev) 183 | torch.save(model.state_dict(), os.path.join(args.savedir, _fn)) 184 | print(">>> Saved file: {0}".format(_fn)) 185 | del dataloader 186 | 187 | # perform validation loop: 188 | vmin, vmed, vmean, vmax = validate(model, dataloader_fn, nice_loss_fn) 189 | print(">>> Validation Loss Statistics: min={0}, med={1}, mean={2}, max={3}".format(vmin,vmed,vmean,vmax)) 190 | 191 | # ===== ===== ===== ===== ===== ===== ===== ===== ===== ===== 192 | # Validation loop: set gradient-tracking off with model in eval mode: 193 | def validate(model, dataloader_fn, loss_fn): 194 | """Perform validation on a dataset.""" 195 | # set model to eval mode (turns batch norm training off) 196 | model.eval() 197 | 198 | # build dataloader in eval mode: 199 | dataloader = dataloader_fn(train=False, batch_size=args.batch_size) 200 | 201 | # turn gradient-tracking off (for speed) during validation: 202 | validation_losses = [] 203 | with torch.no_grad(): 204 | for inputs,_ in tqdm(dataloader): 205 | validation_losses.append(loss_fn(model(inputs.to(DEVICE)), model.scaling_diag).item()) 206 | 207 | # delete dataloader to save memory: 208 | del dataloader 209 | 210 | # set model back in train mode: 211 | model.train() 212 | 213 | # return validation loss summary statistics: 214 | return (np.amin(validation_losses), 215 | np.median(validation_losses), 216 | np.mean(validation_losses), 217 | np.amax(validation_losses)) 218 | 219 | # ===== ===== ===== ===== ===== ===== ===== ===== ===== ===== 220 | if __name__ == '__main__': 221 | # ----- parse training settings: 222 | parser = argparse.ArgumentParser(description="Train a fresh NICE model and save.") 223 | # configuration settings: 224 | parser.add_argument("--dataset", required=True, dest='dataset', choices=('tfd', 'cifar10', 'svhn', 'mnist'), 225 | help="Dataset to train the NICE model on.") 226 | parser.add_argument("--epochs", dest='num_epochs', 
# ===== ===== ===== ===== ===== ===== ===== ===== ===== =====
if __name__ == '__main__':
    # ----- parse training settings:
    parser = argparse.ArgumentParser(description="Train a fresh NICE model and save.")
    # configuration settings:
    parser.add_argument("--dataset", required=True, dest='dataset', choices=('tfd', 'cifar10', 'svhn', 'mnist'),
                        help="Dataset to train the NICE model on.")
    parser.add_argument("--epochs", dest='num_epochs', default=1500, type=int,
                        help="Number of epochs to train on. [1500]")
    parser.add_argument("--batch_size", dest="batch_size", default=16, type=int,
                        help="Number of examples per batch. [16]")
    parser.add_argument("--save_epoch", dest="save_epoch", default=10, type=int,
                        help="Number of epochs between saves. [10]")
    parser.add_argument("--savedir", dest='savedir', default="./saved_models",
                        help="Where to save the trained model. [./saved_models]")
    # model settings:
    parser.add_argument("--nonlinearity_layers", dest='nlayers', default=5, type=int,
                        help="Number of layers in the nonlinearity. [5]")
    parser.add_argument("--nonlinearity_hiddens", dest='nhidden', default=1000, type=int,
                        help="Hidden size of inner layers of nonlinearity. [1000]")
    parser.add_argument("--prior", choices=('logistic', 'gaussian'), default="logistic",
                        help="Prior distribution of latent space components. [logistic]")
    parser.add_argument("--model_path", dest='model_path', default=None, type=str,
                        help="Continue from a pretrained model. [None]")
    # optimization settings:
    parser.add_argument("--lr", default=0.001, dest='lr', type=float,
                        help="Learning rate for ADAM optimizer. [0.001]")
    parser.add_argument("--beta1", default=0.9, dest='beta1', type=float,
                        help="Momentum for ADAM optimizer. [0.9]")
    parser.add_argument("--beta2", default=0.01, dest='beta2', type=float,
                        help="Beta2 for ADAM optimizer. [0.01]")
    parser.add_argument("--eps", default=0.0001, dest='eps', type=float,
                        help="Epsilon for ADAM optimizer. [0.0001]")
    parser.add_argument("--lambda", default=0.0, dest='lmbda', type=float,
                        help="L1 weight decay coefficient. [0.0]")
    args = parser.parse_args()
    # ----- run training loop over several epochs & save models for each epoch:
    model = train(args)
--------------------------------------------------------------------------------