├── .gitignore ├── LICENSE ├── README.md ├── ckpts └── .gitkeep ├── data └── .gitkeep ├── environment.yml ├── models ├── __init__.py └── glow │ ├── __init__.py │ ├── act_norm.py │ ├── coupling.py │ ├── glow.py │ └── inv_conv.py ├── samples ├── .gitkeep ├── epoch_0.png ├── epoch_1.png ├── epoch_10.png ├── epoch_11.png ├── epoch_12.png ├── epoch_13.png ├── epoch_14.png ├── epoch_15.png ├── epoch_16.png ├── epoch_17.png ├── epoch_18.png ├── epoch_19.png ├── epoch_2.png ├── epoch_20.png ├── epoch_21.png ├── epoch_22.png ├── epoch_23.png ├── epoch_24.png ├── epoch_25.png ├── epoch_26.png ├── epoch_27.png ├── epoch_28.png ├── epoch_29.png ├── epoch_3.png ├── epoch_30.png ├── epoch_31.png ├── epoch_32.png ├── epoch_33.png ├── epoch_34.png ├── epoch_35.png ├── epoch_36.png ├── epoch_37.png ├── epoch_38.png ├── epoch_39.png ├── epoch_4.png ├── epoch_40.png ├── epoch_41.png ├── epoch_42.png ├── epoch_43.png ├── epoch_44.png ├── epoch_45.png ├── epoch_46.png ├── epoch_47.png ├── epoch_48.png ├── epoch_49.png ├── epoch_5.png ├── epoch_50.png ├── epoch_51.png ├── epoch_52.png ├── epoch_53.png ├── epoch_54.png ├── epoch_55.png ├── epoch_56.png ├── epoch_57.png ├── epoch_58.png ├── epoch_59.png ├── epoch_6.png ├── epoch_60.png ├── epoch_61.png ├── epoch_62.png ├── epoch_63.png ├── epoch_64.png ├── epoch_65.png ├── epoch_66.png ├── epoch_67.png ├── epoch_68.png ├── epoch_69.png ├── epoch_7.png ├── epoch_70.png ├── epoch_71.png ├── epoch_72.png ├── epoch_73.png ├── epoch_74.png ├── epoch_75.png ├── epoch_76.png ├── epoch_77.png ├── epoch_78.png ├── epoch_79.png ├── epoch_8.png ├── epoch_80.png └── epoch_9.png ├── train.py └── util ├── __init__.py ├── array_util.py ├── optim_util.py └── shell_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/ 3 | __pycache__ 4 | .texpadtmp/ 5 | data/ 6 | logs/ 7 | ckpts/ -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2019 Christopher Chute http://chrischute.com 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Glow in PyTorch 2 | 3 | ![CIFAR-10 Samples](/samples/epoch_80.png?raw=true "CIFAR-10 Samples") 4 | 5 | Implementation of Glow in PyTorch. Based on the paper: 6 | 7 | > [Glow: Generative Flow with Invertible 1x1 Convolutions](https://arxiv.org/abs/1807.03039)\ 8 | > Diederik P. Kingma, Prafulla Dhariwal\ 9 | > _arXiv:1807.03039_ 10 | 11 | The training script and hyperparameters are designed to match the 12 | CIFAR-10 experiments described in Table 4 of the paper. 13 | 14 | 15 | ## Usage 16 | 17 | ### Environment Setup 18 | 1.
Make sure you have [Anaconda or Miniconda](https://conda.io/docs/download.html) 19 | installed. 20 | 2. Clone the repo with `git clone https://github.com/chrischute/glow.git glow`. 21 | 3. Go into the cloned repo: `cd glow`. 22 | 4. Create the environment: `conda env create -f environment.yml`. 23 | 5. Activate the environment: `source activate glow`. 24 | 25 | ### Train 26 | 1. Make sure you've created and activated the conda environment as described above. 27 | 2. Run `python train.py -h` to see options. 28 | 3. Run `python train.py [FLAGS]` to train. *E.g.,* run 29 | `python train.py` for the default configuration, or run 30 | `python train.py --gpu_ids=0,1` to run on 31 | 2 GPUs instead of the default single GPU. This also doubles the batch size. 32 | 4. At the end of each epoch, samples from the model are saved to 33 | `samples/epoch_N.png`, where `N` is the epoch number. 34 | 35 | 36 | A single epoch takes about 30 minutes with the default hyperparameters (K=32, L=3, C=512) on two 1080 Ti GPUs.
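
`train.py` is not shown here, so the following is only a sketch of how samples like `samples/epoch_N.png` could be produced, assuming a standard normal prior over 3x32x32 latents and a model whose reverse pass returns logits, as `Glow.forward(x, reverse=True)` in `models/glow/glow.py` suggests. The helper names `logits_to_image` and `sample` are hypothetical; `logits_to_image` inverts the logit transform in `Glow._pre_process` (with the same `bounds = 0.9`), ignoring the dequantization noise.

```python
import torch


def logits_to_image(y: torch.Tensor, bounds: float = 0.9) -> torch.Tensor:
    """Hypothetical inverse of Glow._pre_process, ignoring dequantization noise.

    _pre_process maps x in [0, 1] to logit(((2x - 1) * bounds + 1) / 2),
    so apply a sigmoid and then undo the affine `bounds` rescaling.
    """
    p = torch.sigmoid(y)                     # ((2x - 1) * bounds + 1) / 2
    x = ((2. * p - 1.) / bounds + 1.) / 2.   # undo the rescaling
    return x.clamp(0., 1.)                   # sampled logits may leave the range


@torch.no_grad()
def sample(net: torch.nn.Module, batch_size: int, device: str = 'cpu') -> torch.Tensor:
    """Draw images from a Glow-style model (assumes 3x32x32 CIFAR-10 inputs)."""
    z = torch.randn(batch_size, 3, 32, 32, device=device)  # assumed N(0, I) prior
    y, _ = net(z, reverse=True)              # Glow.forward returns (x, sldj)
    return logits_to_image(y)
```

A plain `torch.sigmoid` on the logits would leave pixel values in roughly [0.05, 0.95] because of the `bounds` rescaling; the helper above also undoes that step.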
37 | 38 | 39 | ## Samples (K=16, L=3, C=512) 40 | 41 | ### Epoch 10 42 | 43 | ![Samples at Epoch 10](/samples/epoch_10.png?raw=true "Samples at Epoch 10") 44 | 45 | 46 | ### Epoch 20 47 | 48 | ![Samples at Epoch 20](/samples/epoch_20.png?raw=true "Samples at Epoch 20") 49 | 50 | 51 | ### Epoch 30 52 | 53 | ![Samples at Epoch 30](/samples/epoch_30.png?raw=true "Samples at Epoch 30") 54 | 55 | 56 | ### Epoch 40 57 | 58 | ![Samples at Epoch 40](/samples/epoch_40.png?raw=true "Samples at Epoch 40") 59 | 60 | 61 | ### Epoch 50 62 | 63 | ![Samples at Epoch 50](/samples/epoch_50.png?raw=true "Samples at Epoch 50") 64 | 65 | 66 | ### Epoch 60 67 | 68 | ![Samples at Epoch 60](/samples/epoch_60.png?raw=true "Samples at Epoch 60") 69 | 70 | 71 | ### Epoch 70 72 | 73 | ![Samples at Epoch 70](/samples/epoch_70.png?raw=true "Samples at Epoch 70") 74 | 75 | 76 | ### Epoch 80 77 | 78 | ![Samples at Epoch 80](/samples/epoch_80.png?raw=true "Samples at Epoch 80") 79 | 80 | 81 | More samples can be found in the `samples` folder. 82 | 83 | 84 | ## Results (K=32, L=3, C=512) 85 | 86 | ### Bits per Dimension 87 | 88 | | Epoch | Train | Valid | 89 | |-------|-------|-------| 90 | | 10 | 3.64 | 3.63 | 91 | | 20 | 3.51 | 3.56 | 92 | | 30 | 3.46 | 3.53 | 93 | | 40 | 3.43 | 3.51 | 94 | | 50 | 3.42 | 3.50 | 95 | | 60 | 3.40 | 3.51 | 96 | | 70 | 3.39 | 3.49 | 97 | | 80 | 3.38 | 3.49 | 98 | 99 | ## Gradient Checkpointing 100 | 101 | As pointed out by [AlexanderMath](https://github.com/AlexanderMath), you can use gradient checkpointing to reduce memory consumption in the coupling layers. For details, see [this issue](https://github.com/chrischute/glow/issues/8).
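
The linked issue is not reproduced here; what follows is a minimal sketch, assuming the goal is to recompute the coupling network's activations during the backward pass instead of storing them. It wraps a hypothetical stand-in network `SmallNN` (not the repository's `NN`) with `torch.utils.checkpoint.checkpoint`, inside a simplified forward-only coupling (no learned `scale`). The `use_reentrant=False` argument assumes a newer PyTorch release than the 1.0.0 pinned in `environment.yml`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class SmallNN(nn.Module):
    """Hypothetical stand-in for the coupling network `NN` (no batch norm)."""
    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.net(x)


class CheckpointedCoupling(nn.Module):
    """Forward-direction affine coupling whose NN is gradient-checkpointed."""
    def __init__(self, in_channels, mid_channels, use_checkpoint=True):
        super().__init__()
        self.nn = SmallNN(in_channels, mid_channels, 2 * in_channels)
        self.use_checkpoint = use_checkpoint

    def forward(self, x, ldj):
        x_change, x_id = x.chunk(2, dim=1)
        if self.use_checkpoint:
            # Activations inside self.nn are not stored; they are recomputed
            # during backward, trading extra compute for lower memory.
            st = checkpoint(self.nn, x_id, use_reentrant=False)
        else:
            st = self.nn(x_id)
        s, t = st[:, 0::2, ...], st[:, 1::2, ...]
        s = torch.tanh(s)
        x_change = (x_change + t) * s.exp()
        ldj = ldj + s.flatten(1).sum(-1)
        return torch.cat((x_change, x_id), dim=1), ldj
```

Outputs and gradients should match the non-checkpointed version; only peak memory changes. With the repository's default `nn.BatchNorm2d` layers, the recomputed forward pass may update the running statistics a second time, so treat this as a starting point rather than a drop-in change.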
102 | -------------------------------------------------------------------------------- /ckpts/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/ckpts/.gitkeep -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/data/.gitkeep -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: glow 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - blas=1.0=mkl 7 | - ca-certificates=2019.1.23=0 8 | - certifi=2019.3.9=py37_0 9 | - cffi=1.12.3=py37h2e261b9_0 10 | - cuda90=1.0=h6433d27_0 11 | - cudatoolkit=9.0=h13b8566_0 12 | - freetype=2.9.1=h8a8886c_1 13 | - intel-openmp=2019.3=199 14 | - jpeg=9b=h024ee3a_2 15 | - libedit=3.1.20181209=hc058e9b_0 16 | - libffi=3.2.1=hd88cf55_4 17 | - libgcc-ng=8.2.0=hdf63c60_1 18 | - libgfortran-ng=7.3.0=hdf63c60_0 19 | - libpng=1.6.37=hbc83047_0 20 | - libstdcxx-ng=8.2.0=hdf63c60_1 21 | - libtiff=4.0.10=h2733197_2 22 | - mkl=2019.3=199 23 | - mkl_fft=1.0.12=py37ha843d7b_0 24 | - mkl_random=1.0.2=py37hd81dba3_0 25 | - ncurses=6.1=he6710b0_1 26 | - ninja=1.9.0=py37hfd86e86_0 27 | - numpy=1.16.3=py37h7e9f1db_0 28 | - numpy-base=1.16.3=py37hde5b4d6_0 29 | - olefile=0.46=py37_0 30 | - openssl=1.1.1b=h7b6447c_1 31 | - pillow=6.0.0=py37h34e0f95_0 32 | - pip=19.1.1=py37_0 33 | - pycparser=2.19=py37_0 34 | - python=3.7.3=h0371630_0 35 | - pytorch=1.0.0=py3.7_cuda9.0.176_cudnn7.4.1_1 36 | - readline=7.0=h7b6447c_5 37 | - setuptools=41.0.1=py37_0 38 | - six=1.12.0=py37_0 39 | - sqlite=3.28.0=h7b6447c_0 40 | - tk=8.6.8=hbc83047_0 41 | - torchvision=0.2.2=py_3 42 | - 
tqdm=4.31.1=py37_1 43 | - wheel=0.33.4=py37_0 44 | - xz=5.2.4=h14c3975_4 45 | - zlib=1.2.11=h7b6447c_3 46 | - zstd=1.3.7=h0b5b093_0 47 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.glow import Glow 2 | -------------------------------------------------------------------------------- /models/glow/__init__.py: -------------------------------------------------------------------------------- 1 | from models.glow.glow import Glow 2 | -------------------------------------------------------------------------------- /models/glow/act_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from util import mean_dim 5 | 6 | 7 | class ActNorm(nn.Module): 8 | """Activation normalization for 2D inputs. 9 | 10 | The bias and scale get initialized using the mean and variance of the 11 | first mini-batch. After the init, bias and scale are trainable parameters. 
12 | 13 | Adapted from: 14 | > https://github.com/openai/glow 15 | """ 16 | def __init__(self, num_features, scale=1., return_ldj=False): 17 | super(ActNorm, self).__init__() 18 | self.register_buffer('is_initialized', torch.zeros(1)) 19 | self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 20 | self.logs = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 21 | 22 | self.num_features = num_features 23 | self.scale = float(scale) 24 | self.eps = 1e-6 25 | self.return_ldj = return_ldj 26 | 27 | def initialize_parameters(self, x): 28 | if not self.training: 29 | return 30 | 31 | with torch.no_grad(): 32 | bias = -mean_dim(x.clone(), dim=[0, 2, 3], keepdims=True) 33 | v = mean_dim((x.clone() + bias) ** 2, dim=[0, 2, 3], keepdims=True) 34 | logs = (self.scale / (v.sqrt() + self.eps)).log() 35 | self.bias.data.copy_(bias.data) 36 | self.logs.data.copy_(logs.data) 37 | self.is_initialized += 1. 38 | 39 | def _center(self, x, reverse=False): 40 | if reverse: 41 | return x - self.bias 42 | else: 43 | return x + self.bias 44 | 45 | def _scale(self, x, sldj, reverse=False): 46 | logs = self.logs 47 | if reverse: 48 | x = x * logs.mul(-1).exp() 49 | else: 50 | x = x * logs.exp() 51 | 52 | if sldj is not None: 53 | ldj = logs.sum() * x.size(2) * x.size(3) 54 | if reverse: 55 | sldj = sldj - ldj 56 | else: 57 | sldj = sldj + ldj 58 | 59 | return x, sldj 60 | 61 | def forward(self, x, ldj=None, reverse=False): 62 | if not self.is_initialized: 63 | self.initialize_parameters(x) 64 | 65 | if reverse: 66 | x, ldj = self._scale(x, ldj, reverse) 67 | x = self._center(x, reverse) 68 | else: 69 | x = self._center(x, reverse) 70 | x, ldj = self._scale(x, ldj, reverse) 71 | 72 | if self.return_ldj: 73 | return x, ldj 74 | 75 | return x 76 | -------------------------------------------------------------------------------- /models/glow/coupling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.nn.functional as F 4 | 5 | from models.glow.act_norm import ActNorm 6 | 7 | 8 | class Coupling(nn.Module): 9 | """Affine coupling layer originally used in Real NVP and described by Glow. 10 | 11 | Note: The official Glow implementation (https://github.com/openai/glow) 12 | uses a different affine coupling formulation than described in the paper. 13 | This implementation follows the paper and Real NVP. 14 | 15 | Args: 16 | in_channels (int): Number of channels in the input. 17 | mid_channels (int): Number of channels in the intermediate activation 18 | in NN. 19 | """ 20 | def __init__(self, in_channels, mid_channels): 21 | super(Coupling, self).__init__() 22 | self.nn = NN(in_channels, mid_channels, 2 * in_channels) 23 | self.scale = nn.Parameter(torch.ones(in_channels, 1, 1)) 24 | 25 | def forward(self, x, ldj, reverse=False): 26 | x_change, x_id = x.chunk(2, dim=1) 27 | 28 | st = self.nn(x_id) 29 | s, t = st[:, 0::2, ...], st[:, 1::2, ...] 30 | s = self.scale * torch.tanh(s) 31 | 32 | # Scale and translate 33 | if reverse: 34 | x_change = x_change * s.mul(-1).exp() - t 35 | ldj = ldj - s.flatten(1).sum(-1) 36 | else: 37 | x_change = (x_change + t) * s.exp() 38 | ldj = ldj + s.flatten(1).sum(-1) 39 | 40 | x = torch.cat((x_change, x_id), dim=1) 41 | 42 | return x, ldj 43 | 44 | 45 | class NN(nn.Module): 46 | """Small convolutional network used to compute scale and translate factors. 47 | 48 | Args: 49 | in_channels (int): Number of channels in the input. 50 | mid_channels (int): Number of channels in the hidden activations. 51 | out_channels (int): Number of channels in the output. 52 | use_act_norm (bool): Use activation norm rather than batch norm. 
53 | """ 54 | def __init__(self, in_channels, mid_channels, out_channels, 55 | use_act_norm=False): 56 | super(NN, self).__init__() 57 | norm_fn = ActNorm if use_act_norm else nn.BatchNorm2d 58 | 59 | self.in_norm = norm_fn(in_channels) 60 | self.in_conv = nn.Conv2d(in_channels, mid_channels, 61 | kernel_size=3, padding=1, bias=False) 62 | nn.init.normal_(self.in_conv.weight, 0., 0.05) 63 | 64 | self.mid_norm = norm_fn(mid_channels) 65 | self.mid_conv = nn.Conv2d(mid_channels, mid_channels, 66 | kernel_size=1, padding=0, bias=False) 67 | nn.init.normal_(self.mid_conv.weight, 0., 0.05) 68 | 69 | self.out_norm = norm_fn(mid_channels) 70 | self.out_conv = nn.Conv2d(mid_channels, out_channels, 71 | kernel_size=3, padding=1, bias=True) 72 | nn.init.zeros_(self.out_conv.weight) 73 | nn.init.zeros_(self.out_conv.bias) 74 | 75 | def forward(self, x): 76 | x = self.in_norm(x) 77 | x = F.relu(x) 78 | x = self.in_conv(x) 79 | 80 | x = self.mid_norm(x) 81 | x = F.relu(x) 82 | x = self.mid_conv(x) 83 | 84 | x = self.out_norm(x) 85 | x = F.relu(x) 86 | x = self.out_conv(x) 87 | 88 | return x 89 | -------------------------------------------------------------------------------- /models/glow/glow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.glow.act_norm import ActNorm 6 | from models.glow.coupling import Coupling 7 | from models.glow.inv_conv import InvConv 8 | 9 | 10 | class Glow(nn.Module): 11 | """Glow Model 12 | 13 | Based on the paper: 14 | "Glow: Generative Flow with Invertible 1x1 Convolutions" 15 | by Diederik P. Kingma, Prafulla Dhariwal 16 | (https://arxiv.org/abs/1807.03039). 17 | 18 | Args: 19 | num_channels (int): Number of channels in middle convolution of each 20 | step of flow. 21 | num_levels (int): Number of levels in the entire model. 22 | num_steps (int): Number of steps of flow for each level. 
23 | """ 24 | def __init__(self, num_channels, num_levels, num_steps): 25 | super(Glow, self).__init__() 26 | 27 | # Use bounds to rescale images before converting to logits, not learned 28 | self.register_buffer('bounds', torch.tensor([0.9], dtype=torch.float32)) 29 | self.flows = _Glow(in_channels=4 * 3, # RGB image after squeeze 30 | mid_channels=num_channels, 31 | num_levels=num_levels, 32 | num_steps=num_steps) 33 | 34 | def forward(self, x, reverse=False): 35 | if reverse: 36 | sldj = torch.zeros(x.size(0), device=x.device) 37 | else: 38 | # Expect inputs in [0, 1] 39 | if x.min() < 0 or x.max() > 1: 40 | raise ValueError('Expected x in [0, 1], got min/max {}/{}' 41 | .format(x.min(), x.max())) 42 | 43 | # De-quantize and convert to logits 44 | x, sldj = self._pre_process(x) 45 | 46 | x = squeeze(x) 47 | x, sldj = self.flows(x, sldj, reverse) 48 | x = squeeze(x, reverse=True) 49 | 50 | return x, sldj 51 | 52 | def _pre_process(self, x): 53 | """Dequantize the input image `x` and convert to logits. 54 | 55 | See Also: 56 | - Dequantization: https://arxiv.org/abs/1511.01844, Section 3.1 57 | - Modeling logits: https://arxiv.org/abs/1605.08803, Section 4.1 58 | 59 | Args: 60 | x (torch.Tensor): Input image. 61 | 62 | Returns: 63 | y (torch.Tensor): Dequantized logits of `x`. 64 | """ 65 | y = (x * 255. + torch.rand_like(x)) / 256. 66 | y = (2 * y - 1) * self.bounds 67 | y = (y + 1) / 2 68 | y = y.log() - (1. - y).log() 69 | 70 | # Save log-determinant of Jacobian of initial transform 71 | ldj = F.softplus(y) + F.softplus(-y) \ 72 | - F.softplus((1. - self.bounds).log() - self.bounds.log()) 73 | sldj = ldj.flatten(1).sum(-1) 74 | 75 | return y, sldj 76 | 77 | 78 | class _Glow(nn.Module): 79 | """Recursive constructor for a Glow model. Each call creates a single level. 80 | 81 | Args: 82 | in_channels (int): Number of channels in the input. 83 | mid_channels (int): Number of channels in hidden layers of each step. 
84 | num_levels (int): Number of levels to construct. Counter for recursion. 85 | num_steps (int): Number of steps of flow for each level. 86 | """ 87 | def __init__(self, in_channels, mid_channels, num_levels, num_steps): 88 | super(_Glow, self).__init__() 89 | self.steps = nn.ModuleList([_FlowStep(in_channels=in_channels, 90 | mid_channels=mid_channels) 91 | for _ in range(num_steps)]) 92 | 93 | if num_levels > 1: 94 | self.next = _Glow(in_channels=2 * in_channels, 95 | mid_channels=mid_channels, 96 | num_levels=num_levels - 1, 97 | num_steps=num_steps) 98 | else: 99 | self.next = None 100 | 101 | def forward(self, x, sldj, reverse=False): 102 | if not reverse: 103 | for step in self.steps: 104 | x, sldj = step(x, sldj, reverse) 105 | 106 | if self.next is not None: 107 | x = squeeze(x) 108 | x, x_split = x.chunk(2, dim=1) 109 | x, sldj = self.next(x, sldj, reverse) 110 | x = torch.cat((x, x_split), dim=1) 111 | x = squeeze(x, reverse=True) 112 | 113 | if reverse: 114 | for step in reversed(self.steps): 115 | x, sldj = step(x, sldj, reverse) 116 | 117 | return x, sldj 118 | 119 | 120 | class _FlowStep(nn.Module): 121 | def __init__(self, in_channels, mid_channels): 122 | super(_FlowStep, self).__init__() 123 | 124 | # Activation normalization, invertible 1x1 convolution, affine coupling 125 | self.norm = ActNorm(in_channels, return_ldj=True) 126 | self.conv = InvConv(in_channels) 127 | self.coup = Coupling(in_channels // 2, mid_channels) 128 | 129 | def forward(self, x, sldj=None, reverse=False): 130 | if reverse: 131 | x, sldj = self.coup(x, sldj, reverse) 132 | x, sldj = self.conv(x, sldj, reverse) 133 | x, sldj = self.norm(x, sldj, reverse) 134 | else: 135 | x, sldj = self.norm(x, sldj, reverse) 136 | x, sldj = self.conv(x, sldj, reverse) 137 | x, sldj = self.coup(x, sldj, reverse) 138 | 139 | return x, sldj 140 | 141 | 142 | def squeeze(x, reverse=False): 143 | """Trade spatial extent for channels. 
In the forward direction, convert each 144 | 1x2x2 volume of input into a 4x1x1 volume of output. 145 | 146 | Args: 147 | x (torch.Tensor): Input to squeeze or unsqueeze. 148 | reverse (bool): Reverse the operation, i.e., unsqueeze. 149 | 150 | Returns: 151 | x (torch.Tensor): Squeezed or unsqueezed tensor. 152 | """ 153 | b, c, h, w = x.size() 154 | if reverse: 155 | # Unsqueeze 156 | x = x.view(b, c // 4, 2, 2, h, w) 157 | x = x.permute(0, 1, 4, 2, 5, 3).contiguous() 158 | x = x.view(b, c // 4, h * 2, w * 2) 159 | else: 160 | # Squeeze 161 | x = x.view(b, c, h // 2, 2, w // 2, 2) 162 | x = x.permute(0, 1, 3, 5, 2, 4).contiguous() 163 | x = x.view(b, c * 2 * 2, h // 2, w // 2) 164 | 165 | return x 166 | -------------------------------------------------------------------------------- /models/glow/inv_conv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class InvConv(nn.Module): 8 | """Invertible 1x1 Convolution for 2D inputs. Originally described in Glow 9 | (https://arxiv.org/abs/1807.03039). Does not support the LU-decomposed version. 10 | 11 | Args: 12 | num_channels (int): Number of channels in the input and output.
13 | """ 14 | def __init__(self, num_channels): 15 | super(InvConv, self).__init__() 16 | self.num_channels = num_channels 17 | 18 | # Initialize with a random orthogonal matrix 19 | w_init = np.random.randn(num_channels, num_channels) 20 | w_init = np.linalg.qr(w_init)[0].astype(np.float32) 21 | self.weight = nn.Parameter(torch.from_numpy(w_init)) 22 | 23 | def forward(self, x, sldj, reverse=False): 24 | ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3) 25 | 26 | if reverse: 27 | weight = torch.inverse(self.weight.double()).float() 28 | sldj = sldj - ldj 29 | else: 30 | weight = self.weight 31 | sldj = sldj + ldj 32 | 33 | weight = weight.view(self.num_channels, self.num_channels, 1, 1) 34 | z = F.conv2d(x, weight) 35 | 36 | return z, sldj 37 | -------------------------------------------------------------------------------- /samples/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/.gitkeep -------------------------------------------------------------------------------- /samples/epoch_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_0.png -------------------------------------------------------------------------------- /samples/epoch_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_1.png -------------------------------------------------------------------------------- /samples/epoch_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_10.png 
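
The log-determinant in `InvConv.forward` (`models/glow/inv_conv.py` above) uses the fact that a 1x1 convolution applies the same C x C matrix W at every one of the h x w spatial positions, so log|det J| = h * w * log|det W|. The following sketch checks that shortcut numerically against the full Jacobian, and checks that convolving with `torch.inverse(W)` undoes the forward pass. The helper names are hypothetical, the brute-force Jacobian is only practical for tiny tensors, and `torch.linalg.qr` and `torch.autograd.functional.jacobian` assume a newer PyTorch than the pinned 1.0.0.

```python
import torch
import torch.nn.functional as F
from torch.autograd.functional import jacobian


def inv_conv_logdet_check(c=3, h=2, w=2, seed=0):
    """Return (log|det J| of the full 1x1-conv Jacobian, h * w * log|det W|)."""
    torch.manual_seed(seed)
    # Orthogonal init as in InvConv, then perturbed so that |det W| != 1.
    weight = torch.linalg.qr(torch.randn(c, c, dtype=torch.float64))[0]
    weight = weight + 0.1 * torch.randn(c, c, dtype=torch.float64)

    def conv(x_flat):
        z = F.conv2d(x_flat.view(1, c, h, w), weight.view(c, c, 1, 1))
        return z.flatten()

    x = torch.randn(c * h * w, dtype=torch.float64)
    full_jac = jacobian(conv, x)                  # (c*h*w, c*h*w) matrix
    full_logdet = torch.slogdet(full_jac)[1]      # brute-force log|det J|
    shortcut = torch.slogdet(weight)[1] * h * w   # what InvConv.forward computes
    return full_logdet.item(), shortcut.item()


def inv_conv_round_trip_error(weight, x):
    """Max abs error of applying W and then W^{-1} as 1x1 convolutions."""
    c = weight.size(0)
    z = F.conv2d(x, weight.view(c, c, 1, 1))
    x_rec = F.conv2d(z, torch.inverse(weight).view(c, c, 1, 1))
    return (x - x_rec).abs().max().item()
```

Because the Jacobian is block-structured (W repeated once per pixel), the shortcut costs one C x C determinant instead of one (C*h*w) x (C*h*w) determinant.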
-------------------------------------------------------------------------------- /samples/epoch_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_11.png -------------------------------------------------------------------------------- /samples/epoch_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_12.png -------------------------------------------------------------------------------- /samples/epoch_13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_13.png -------------------------------------------------------------------------------- /samples/epoch_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_14.png -------------------------------------------------------------------------------- /samples/epoch_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_15.png -------------------------------------------------------------------------------- /samples/epoch_16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_16.png -------------------------------------------------------------------------------- /samples/epoch_17.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_17.png -------------------------------------------------------------------------------- /samples/epoch_18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_18.png -------------------------------------------------------------------------------- /samples/epoch_19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_19.png -------------------------------------------------------------------------------- /samples/epoch_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_2.png -------------------------------------------------------------------------------- /samples/epoch_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_20.png -------------------------------------------------------------------------------- /samples/epoch_21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_21.png -------------------------------------------------------------------------------- /samples/epoch_22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_22.png -------------------------------------------------------------------------------- /samples/epoch_23.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_23.png -------------------------------------------------------------------------------- /samples/epoch_24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_24.png -------------------------------------------------------------------------------- /samples/epoch_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_25.png -------------------------------------------------------------------------------- /samples/epoch_26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_26.png -------------------------------------------------------------------------------- /samples/epoch_27.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_27.png -------------------------------------------------------------------------------- /samples/epoch_28.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_28.png -------------------------------------------------------------------------------- /samples/epoch_29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_29.png 
-------------------------------------------------------------------------------- /samples/epoch_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_3.png -------------------------------------------------------------------------------- /samples/epoch_30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_30.png -------------------------------------------------------------------------------- /samples/epoch_31.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_31.png -------------------------------------------------------------------------------- /samples/epoch_32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_32.png -------------------------------------------------------------------------------- /samples/epoch_33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_33.png -------------------------------------------------------------------------------- /samples/epoch_34.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_34.png -------------------------------------------------------------------------------- /samples/epoch_35.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_35.png -------------------------------------------------------------------------------- /samples/epoch_36.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_36.png -------------------------------------------------------------------------------- /samples/epoch_37.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_37.png -------------------------------------------------------------------------------- /samples/epoch_38.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_38.png -------------------------------------------------------------------------------- /samples/epoch_39.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_39.png -------------------------------------------------------------------------------- /samples/epoch_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_4.png -------------------------------------------------------------------------------- /samples/epoch_40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_40.png -------------------------------------------------------------------------------- /samples/epoch_41.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_41.png -------------------------------------------------------------------------------- /samples/epoch_42.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_42.png -------------------------------------------------------------------------------- /samples/epoch_43.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_43.png -------------------------------------------------------------------------------- /samples/epoch_44.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_44.png -------------------------------------------------------------------------------- /samples/epoch_45.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_45.png -------------------------------------------------------------------------------- /samples/epoch_46.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_46.png -------------------------------------------------------------------------------- /samples/epoch_47.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_47.png 
-------------------------------------------------------------------------------- /samples/epoch_48.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_48.png -------------------------------------------------------------------------------- /samples/epoch_49.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_49.png -------------------------------------------------------------------------------- /samples/epoch_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_5.png -------------------------------------------------------------------------------- /samples/epoch_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_50.png -------------------------------------------------------------------------------- /samples/epoch_51.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_51.png -------------------------------------------------------------------------------- /samples/epoch_52.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_52.png -------------------------------------------------------------------------------- /samples/epoch_53.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_53.png -------------------------------------------------------------------------------- /samples/epoch_54.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_54.png -------------------------------------------------------------------------------- /samples/epoch_55.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_55.png -------------------------------------------------------------------------------- /samples/epoch_56.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_56.png -------------------------------------------------------------------------------- /samples/epoch_57.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_57.png -------------------------------------------------------------------------------- /samples/epoch_58.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_58.png -------------------------------------------------------------------------------- /samples/epoch_59.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_59.png -------------------------------------------------------------------------------- /samples/epoch_6.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_6.png -------------------------------------------------------------------------------- /samples/epoch_60.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_60.png -------------------------------------------------------------------------------- /samples/epoch_61.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_61.png -------------------------------------------------------------------------------- /samples/epoch_62.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_62.png -------------------------------------------------------------------------------- /samples/epoch_63.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_63.png -------------------------------------------------------------------------------- /samples/epoch_64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_64.png -------------------------------------------------------------------------------- /samples/epoch_65.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_65.png 
-------------------------------------------------------------------------------- /samples/epoch_66.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_66.png -------------------------------------------------------------------------------- /samples/epoch_67.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_67.png -------------------------------------------------------------------------------- /samples/epoch_68.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_68.png -------------------------------------------------------------------------------- /samples/epoch_69.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_69.png -------------------------------------------------------------------------------- /samples/epoch_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_7.png -------------------------------------------------------------------------------- /samples/epoch_70.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_70.png -------------------------------------------------------------------------------- /samples/epoch_71.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_71.png -------------------------------------------------------------------------------- /samples/epoch_72.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_72.png -------------------------------------------------------------------------------- /samples/epoch_73.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_73.png -------------------------------------------------------------------------------- /samples/epoch_74.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_74.png -------------------------------------------------------------------------------- /samples/epoch_75.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_75.png -------------------------------------------------------------------------------- /samples/epoch_76.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_76.png -------------------------------------------------------------------------------- /samples/epoch_77.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_77.png -------------------------------------------------------------------------------- /samples/epoch_78.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_78.png
--------------------------------------------------------------------------------
/samples/epoch_79.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_79.png
--------------------------------------------------------------------------------
/samples/epoch_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_8.png
--------------------------------------------------------------------------------
/samples/epoch_80.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_80.png
--------------------------------------------------------------------------------
/samples/epoch_9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrischute/glow/faffa5ba02f878902a211db76c0bd4ea074b39f7/samples/epoch_9.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """Train Glow on CIFAR-10.
2 | 
3 | Train script adapted from: https://github.com/kuangliu/pytorch-cifar/
4 | """
5 | import argparse
6 | import numpy as np
7 | import os
8 | import random
9 | import torch
10 | import torch.optim as optim
11 | import torch.optim.lr_scheduler as sched
12 | import torch.backends.cudnn as cudnn
13 | import torch.utils.data as data
14 | import torchvision
15 | import torchvision.transforms as transforms
16 | import util
17 | 
18 | from models import Glow
19 | from tqdm import tqdm
20 | 
21 | 
22 | def main(args):
23 |     # Set up main device and scale batch size
24 |     device = 'cuda' if torch.cuda.is_available() and args.gpu_ids else 'cpu'
25 |     args.batch_size *= max(1, len(args.gpu_ids))
26 | 
27 |     # Set random seeds
28 |     random.seed(args.seed)
29 |     np.random.seed(args.seed)
30 |     torch.manual_seed(args.seed)
31 |     torch.cuda.manual_seed_all(args.seed)
32 | 
33 |     # No normalization applied, since Glow expects inputs in (0, 1)
34 |     transform_train = transforms.Compose([
35 |         transforms.RandomHorizontalFlip(),
36 |         transforms.ToTensor()
37 |     ])
38 | 
39 |     transform_test = transforms.Compose([
40 |         transforms.ToTensor()
41 |     ])
42 | 
43 |     trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform_train)
44 |     trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
45 | 
46 |     testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)
47 |     testloader = data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
48 | 
49 |     # Model
50 |     print('Building model..')
51 |     net = Glow(num_channels=args.num_channels,
52 |                num_levels=args.num_levels,
53 |                num_steps=args.num_steps)
54 |     net = net.to(device)
55 |     if device == 'cuda':
56 |         net = torch.nn.DataParallel(net, args.gpu_ids)
57 |         cudnn.benchmark = args.benchmark
58 | 
59 |     start_epoch = 0
60 |     if args.resume:
61 |         # Load checkpoint.
62 |         print('Resuming from checkpoint at ckpts/best.pth.tar...')
63 |         assert os.path.isdir('ckpts'), 'Error: no checkpoint directory found!'
64 |         checkpoint = torch.load('ckpts/best.pth.tar')
65 |         net.load_state_dict(checkpoint['net'])
66 |         global best_loss
67 |         global global_step
68 |         best_loss = checkpoint['test_loss']
69 |         start_epoch = checkpoint['epoch']
70 |         global_step = start_epoch * len(trainset)
71 | 
72 |     loss_fn = util.NLLLoss().to(device)
73 |     optimizer = optim.Adam(net.parameters(), lr=args.lr)
74 |     scheduler = sched.LambdaLR(optimizer, lambda s: min(1., s / args.warm_up))
75 | 
76 |     for epoch in range(start_epoch, start_epoch + args.num_epochs):
77 |         train(epoch, net, trainloader, device, optimizer, scheduler,
78 |               loss_fn, args.max_grad_norm)
79 |         test(epoch, net, testloader, device, loss_fn, args.num_samples)
80 | 
81 | 
82 | @torch.enable_grad()
83 | def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn, max_grad_norm):
84 |     global global_step
85 |     print('\nEpoch: %d' % epoch)
86 |     net.train()
87 |     loss_meter = util.AverageMeter()
88 |     with tqdm(total=len(trainloader.dataset)) as progress_bar:
89 |         for x, _ in trainloader:
90 |             x = x.to(device)
91 |             optimizer.zero_grad()
92 |             z, sldj = net(x, reverse=False)
93 |             loss = loss_fn(z, sldj)
94 |             loss_meter.update(loss.item(), x.size(0))
95 |             loss.backward()
96 |             if max_grad_norm > 0:
97 |                 util.clip_grad_norm(optimizer, max_grad_norm)
98 |             optimizer.step()
99 |             scheduler.step(global_step)
100 | 
101 |             progress_bar.set_postfix(nll=loss_meter.avg,
102 |                                      bpd=util.bits_per_dim(x, loss_meter.avg),
103 |                                      lr=optimizer.param_groups[0]['lr'])
104 |             progress_bar.update(x.size(0))
105 |             global_step += x.size(0)
106 | 
107 | 
108 | @torch.no_grad()
109 | def sample(net, batch_size, device):
110 |     """Sample from the Glow model.
111 | 
112 |     Args:
113 |         net (torch.nn.DataParallel): The Glow model wrapped in DataParallel.
114 |         batch_size (int): Number of samples to generate.
115 |         device (torch.device): Device to use.
116 |     """
117 |     z = torch.randn((batch_size, 3, 32, 32), dtype=torch.float32, device=device)
118 |     x, _ = net(z, reverse=True)
119 |     x = torch.sigmoid(x)
120 | 
121 |     return x
122 | 
123 | 
124 | @torch.no_grad()
125 | def test(epoch, net, testloader, device, loss_fn, num_samples):
126 |     global best_loss
127 |     net.eval()
128 |     loss_meter = util.AverageMeter()
129 |     with tqdm(total=len(testloader.dataset)) as progress_bar:
130 |         for x, _ in testloader:
131 |             x = x.to(device)
132 |             z, sldj = net(x, reverse=False)
133 |             loss = loss_fn(z, sldj)
134 |             loss_meter.update(loss.item(), x.size(0))
135 |             progress_bar.set_postfix(nll=loss_meter.avg,
136 |                                      bpd=util.bits_per_dim(x, loss_meter.avg))
137 |             progress_bar.update(x.size(0))
138 | 
139 |     # Save checkpoint
140 |     if loss_meter.avg < best_loss:
141 |         print('Saving...')
142 |         state = {
143 |             'net': net.state_dict(),
144 |             'test_loss': loss_meter.avg,
145 |             'epoch': epoch,
146 |         }
147 |         os.makedirs('ckpts', exist_ok=True)
148 |         torch.save(state, 'ckpts/best.pth.tar')
149 |         best_loss = loss_meter.avg
150 | 
151 |     # Save samples and data
152 |     images = sample(net, num_samples, device)
153 |     os.makedirs('samples', exist_ok=True)
154 |     images_concat = torchvision.utils.make_grid(images, nrow=int(num_samples ** 0.5), padding=2, pad_value=255)
155 |     torchvision.utils.save_image(images_concat, 'samples/epoch_{}.png'.format(epoch))
156 | 
157 | 
158 | if __name__ == '__main__':
159 |     parser = argparse.ArgumentParser(description='Glow on CIFAR-10')
160 | 
161 |     def str2bool(s):
162 |         return s.lower().startswith('t')
163 | 
164 |     parser.add_argument('--batch_size', default=64, type=int, help='Batch size per GPU')
165 |     parser.add_argument('--benchmark', type=str2bool, default=True, help='Turn on CUDNN benchmarking')
166 |     parser.add_argument('--gpu_ids', default=[0], type=eval, help='IDs of GPUs to use')
167 |     parser.add_argument('--lr', default=1e-3, type=float, help='Learning rate')
168 |     parser.add_argument('--max_grad_norm', type=float, default=-1., help='Max gradient norm for clipping')
169 |     parser.add_argument('--num_channels', '-C', default=512, type=int, help='Number of channels in hidden layers')
170 |     parser.add_argument('--num_levels', '-L', default=3, type=int, help='Number of levels in the Glow model')
171 |     parser.add_argument('--num_steps', '-K', default=32, type=int, help='Number of steps of flow in each level')
172 |     parser.add_argument('--num_epochs', default=100, type=int, help='Number of epochs to train')
173 |     parser.add_argument('--num_samples', default=64, type=int, help='Number of samples at test time')
174 |     parser.add_argument('--num_workers', default=8, type=int, help='Number of data loader threads')
175 |     parser.add_argument('--resume', type=str2bool, default=False, help='Resume from checkpoint')
176 |     parser.add_argument('--seed', type=int, default=0, help='Random seed for reproducibility')
177 |     parser.add_argument('--warm_up', default=500000, type=int, help='Number of steps for lr warm-up')
178 | 
179 |     best_loss = float('inf')  # NLL is positive, so starting at 0 would never save a checkpoint
180 |     global_step = 0
181 | 
182 |     main(parser.parse_args())
183 | 
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | from util.array_util import mean_dim
2 | from util.optim_util import bits_per_dim, clip_grad_norm, NLLLoss
3 | from util.shell_util import AverageMeter
4 | 
--------------------------------------------------------------------------------
/util/array_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def mean_dim(tensor, dim=None, keepdims=False):
5 |     """Take the mean along multiple dimensions.
6 | 
7 |     Args:
8 |         tensor (torch.Tensor): Tensor of values to average.
9 |         dim (list): List of dimensions along which to take the mean.
10 |         keepdims (bool): Keep dimensions rather than squeezing.
11 | 
12 |     Returns:
13 |         mean (torch.Tensor): New tensor of mean value(s).
14 |     """
15 |     if dim is None:
16 |         return tensor.mean()
17 |     else:
18 |         if isinstance(dim, int):
19 |             dim = [dim]
20 |         dim = sorted(dim)
21 |         for d in dim:
22 |             tensor = tensor.mean(dim=d, keepdim=True)
23 |         if not keepdims:
24 |             for i, d in enumerate(dim):
25 |                 tensor.squeeze_(d-i)
26 |         return tensor
27 | 
--------------------------------------------------------------------------------
/util/optim_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | import torch.nn.utils as utils
4 | 
5 | 
6 | def bits_per_dim(x, nll):
7 |     """Get the bits per dimension implied by using a model with loss `nll`
8 |     for compressing `x`, where each entry of `x` takes one of `k` discrete values.
9 | 
10 |     Args:
11 |         x (torch.Tensor): Input to the model. Just used for dimensions.
12 |         nll (torch.Tensor): Scalar negative log-likelihood loss tensor.
13 | 
14 |     Returns:
15 |         bpd (torch.Tensor): Bits per dimension implied if compressing `x`.
16 |     """
17 |     dim = np.prod(x.size()[1:])
18 |     bpd = nll / (np.log(2) * dim)
19 | 
20 |     return bpd
21 | 
22 | 
23 | def clip_grad_norm(optimizer, max_norm, norm_type=2):
24 |     """Clip the gradient norm of each parameter group under `optimizer`.
25 | 
26 |     Args:
27 |         optimizer (torch.optim.Optimizer): Optimizer whose parameter groups are clipped separately.
28 |         max_norm (float): The maximum allowable norm of gradients.
29 |         norm_type (int): The type of norm to use in computing gradient norms.
30 |     """
31 |     for group in optimizer.param_groups:
32 |         utils.clip_grad_norm_(group['params'], max_norm, norm_type)
33 | 
34 | 
35 | class NLLLoss(nn.Module):
36 |     """Negative log-likelihood loss assuming an isotropic Gaussian prior with unit variance.
37 | 
38 |     Args:
39 |         k (int or float): Number of discrete values in each input dimension.
40 |             E.g., `k` is 256 for natural images.
41 | 
42 |     See Also:
43 |         Equation (3) in the RealNVP paper: https://arxiv.org/abs/1605.08803
44 |     """
45 |     def __init__(self, k=256):
46 |         super(NLLLoss, self).__init__()
47 |         self.k = k
48 | 
49 |     def forward(self, z, sldj):
50 |         prior_ll = -0.5 * (z ** 2 + np.log(2 * np.pi))
51 |         prior_ll = prior_ll.flatten(1).sum(-1) \
52 |             - np.log(self.k) * np.prod(z.size()[1:])
53 |         ll = prior_ll + sldj
54 |         nll = -ll.mean()
55 | 
56 |         return nll
57 | 
--------------------------------------------------------------------------------
/util/shell_util.py:
--------------------------------------------------------------------------------
1 | class AverageMeter(object):
2 |     """Computes and stores the average and current value.
3 | 
4 |     Adapted from: https://github.com/pytorch/examples/blob/master/imagenet/train.py
5 |     """
6 |     def __init__(self):
7 |         self.val = 0.
8 |         self.avg = 0.
9 |         self.sum = 0.
10 |         self.count = 0.
11 | 
12 |     def reset(self):
13 |         self.val = 0.
14 |         self.avg = 0.
15 |         self.sum = 0.
16 |         self.count = 0.
17 | 
18 |     def update(self, val, n=1):
19 |         self.val = val
20 |         self.sum += val * n
21 |         self.count += n
22 |         self.avg = self.sum / self.count
23 | 
--------------------------------------------------------------------------------
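The loss arithmetic in `NLLLoss.forward` and `bits_per_dim` can be checked without a GPU or a trained model. The sketch below is a NumPy transcription made under stated assumptions: `z` is a plain array of shape `(N, C, H, W)`, `sldj` is an array of shape `(N,)`, and the helper names `nll_numpy` and `bpd_numpy` are hypothetical, not part of this repository. NumPy stands in for torch only so the example runs anywhere.

```python
import numpy as np


def nll_numpy(z, sldj, k=256):
    """Hypothetical NumPy mirror of NLLLoss.forward (not part of the repo).

    z: latents of shape (N, C, H, W); sldj: per-example sum of log-det-Jacobians, shape (N,).
    """
    # Elementwise log-density of each latent entry under a standard normal prior.
    prior_ll = -0.5 * (z ** 2 + np.log(2 * np.pi))
    # Sum over all non-batch dimensions, then subtract the discretization term D * log(k).
    dim = np.prod(z.shape[1:])
    prior_ll = prior_ll.reshape(z.shape[0], -1).sum(-1) - np.log(k) * dim
    # Change of variables: log p(x) = log p(z) + log|det dz/dx|.
    ll = prior_ll + sldj
    # Batch-averaged negative log-likelihood, in nats.
    return -ll.mean()


def bpd_numpy(x_shape, nll):
    """Hypothetical mirror of bits_per_dim: per-example NLL in nats -> bits per dimension."""
    dim = np.prod(x_shape[1:])
    return nll / (np.log(2) * dim)


# Usage: an all-zero latent with zero log-determinant on CIFAR-10-shaped data.
z = np.zeros((2, 3, 32, 32))
sldj = np.zeros(2)
nll = nll_numpy(z, sldj)
print(bpd_numpy(z.shape, nll))  # about 9.33
```

With `z = 0` and `sldj = 0`, each dimension contributes 0.5 * log2(2π) ≈ 1.33 bits from the Gaussian normalizer plus log2(256) = 8 bits from the discretization term, giving about 9.33 bits/dim. Raising every entry of `sldj` by a constant `c` lowers the NLL by exactly `c`, which is how the log-determinant term enters the objective.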