├── README.md
├── invertible_layers.py
├── layers.py
├── test_flows.py
├── train.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # Glow
2 | PyTorch implementation of OpenAI's generative model [GLOW](https://github.com/openai/glow). This repo provides a modular approach for stacking invertible transformations.
3 | 
4 | ## Running Code
5 | ```
6 | python train.py
7 | ```
8 | e.g.
9 | ```
10 | CUDA_VISIBLE_DEVICES=0 python train.py --depth 10 --coupling affine --batch_size 64 --print_every 100 --permutation conv
11 | ```
12 | ## TODOs
13 | - [ ] Multi-GPU support. If performance is an issue for you, I encourage you to check out [this](https://github.com/chaiyujin/glow-pytorch) PyTorch implementation.
14 | - [ ] Support for more datasets
15 | - [ ] LU-decomposed invertible convolution.
16 | 
17 | ### Contact
18 | This repository is no longer maintained. Feel free to file an issue if need be; however, responses may be slow.
19 | 
--------------------------------------------------------------------------------
/invertible_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.nn.utils.weight_norm as wn
5 | 
6 | import numpy as np
7 | import pdb
8 | 
9 | from layers import *
10 | from utils import *
11 | 
12 | # ------------------------------------------------------------------------------
13 | # Abstract Classes to define common interface for invertible functions
14 | # ------------------------------------------------------------------------------
15 | 
16 | # Abstract Class for bijective functions
17 | class Layer(nn.Module):
18 |     def __init__(self):
19 |         super(Layer, self).__init__()
20 | 
21 |     def forward_(self, x, objective):
22 |         raise NotImplementedError
23 | 
24 |     def reverse_(self, y, objective):
25 |         raise NotImplementedError
26 | 
27 | # Wrapper for stacking multiple layers
28 | class LayerList(Layer):
29 |     def __init__(self, list_of_layers=None):
30 |         super(LayerList, self).__init__()
31 |         self.layers = nn.ModuleList(list_of_layers)
32 | 
33 |     def __getitem__(self, i):
34 |         return self.layers[i]
35 | 
36 |     def forward_(self, x, objective):
37 |         for layer in self.layers:
38 |             x, objective = layer.forward_(x, objective)
39 |         return x, objective
40 | 
41 |     def reverse_(self, x, objective):
42 |         for layer in reversed(self.layers):
43 |             x, objective = layer.reverse_(x, objective)
44 |         return x, objective
45 | 
46 | 
47 | # ------------------------------------------------------------------------------
48 | # Permutation Layers
49 | # ------------------------------------------------------------------------------
50 | 
51 | # Shuffling on the channel axis
52 | class Shuffle(Layer):
53 |     def __init__(self, num_channels):
54 |         super(Shuffle, self).__init__()
55 |         indices = np.arange(num_channels)
56 |         np.random.shuffle(indices)
57 |         rev_indices = np.zeros_like(indices)
58 |         for i in range(num_channels):
59 |             rev_indices[indices[i]] = i
60 | 
61 |         indices = torch.from_numpy(indices).long()
62 |         rev_indices = torch.from_numpy(rev_indices).long()
63 |         self.register_buffer('indices', indices)
64 |         self.register_buffer('rev_indices', rev_indices)
65 |         # self.indices, self.rev_indices = indices.cuda(), rev_indices.cuda()
66 | 
67 |     def forward_(self, x, objective):
68 |         return x[:, self.indices], objective
69 | 
70 |     def reverse_(self, x, objective):
71 |         return x[:, self.rev_indices], objective
72 | 
73 | # Reversing on the
channel axis 74 | class Reverse(Shuffle): 75 | def __init__(self, num_channels): 76 | super(Reverse, self).__init__(num_channels) 77 | indices = np.copy(np.arange(num_channels)[::-1]) 78 | indices = torch.from_numpy(indices).long() 79 | self.indices.copy_(indices) 80 | self.rev_indices.copy_(indices) 81 | 82 | # Invertible 1x1 convolution 83 | class Invertible1x1Conv(Layer, nn.Conv2d): 84 | def __init__(self, num_channels): 85 | self.num_channels = num_channels 86 | nn.Conv2d.__init__(self, num_channels, num_channels, 1, bias=False) 87 | 88 | def reset_parameters(self): 89 | # initialization done with rotation matrix 90 | w_init = np.linalg.qr(np.random.randn(self.num_channels, self.num_channels))[0] 91 | w_init = torch.from_numpy(w_init.astype('float32')) 92 | w_init = w_init.unsqueeze(-1).unsqueeze(-1) 93 | self.weight.data.copy_(w_init) 94 | 95 | def forward_(self, x, objective): 96 | dlogdet = torch.det(self.weight.squeeze()).abs().log() * x.size(-2) * x.size(-1) 97 | objective += dlogdet 98 | output = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, \ 99 | self.dilation, self.groups) 100 | 101 | return output, objective 102 | 103 | def reverse_(self, x, objective): 104 | dlogdet = torch.det(self.weight.squeeze()).abs().log() * x.size(-2) * x.size(-1) 105 | objective -= dlogdet 106 | weight_inv = torch.inverse(self.weight.squeeze()).unsqueeze(-1).unsqueeze(-1) 107 | output = F.conv2d(x, weight_inv, self.bias, self.stride, self.padding, \ 108 | self.dilation, self.groups) 109 | return output, objective 110 | 111 | 112 | # ------------------------------------------------------------------------------ 113 | # Layers involving squeeze operations defined in RealNVP / Glow. 114 | # ------------------------------------------------------------------------------ 115 | 116 | # Trades space for depth and vice versa 117 | class Squeeze(Layer): 118 | def __init__(self, input_shape, factor=2): 119 | super(Squeeze, self).__init__() 120 | assert factor > 1 and isinstance(factor, int), 'no point of using this if factor <= 1' 121 | self.factor = factor 122 | self.input_shape = input_shape 123 | 124 | def squeeze_bchw(self, x): 125 | bs, c, h, w = x.size() 126 | assert h % self.factor == 0 and w % self.factor == 0, pdb.set_trace() 127 | 128 | # taken from https://github.com/chaiyujin/glow-pytorch/blob/master/glow/modules.py 129 | x = x.view(bs, c, h // self.factor, self.factor, w // self.factor, self.factor) 130 | x = x.permute(0, 1, 3, 5, 2, 4).contiguous() 131 | x = x.view(bs, c * self.factor * self.factor, h // self.factor, w // self.factor) 132 | 133 | return x 134 | 135 | def unsqueeze_bchw(self, x): 136 | bs, c, h, w = x.size() 137 | assert c >= 4 and c % 4 == 0 138 | 139 | # taken from https://github.com/chaiyujin/glow-pytorch/blob/master/glow/modules.py 140 | x = x.view(bs, c // self.factor ** 2, self.factor, self.factor, h, w) 141 | x = x.permute(0, 1, 4, 2, 5, 3).contiguous() 142 | x = x.view(bs, c // self.factor ** 2, h * self.factor, w * self.factor) 143 | return x 144 | 145 | def forward_(self, x, objective): 146 | if len(x.size()) != 4: 147 | raise NotImplementedError # Maybe ValueError would be more appropriate 148 | 149 | return self.squeeze_bchw(x), objective 150 | 151 | def reverse_(self, x, objective): 152 | if len(x.size()) != 4: 153 | raise NotImplementedError 154 | 155 | return self.unsqueeze_bchw(x), objective 156 | 157 | 158 | # ------------------------------------------------------------------------------ 159 | # Layers involving prior 160 | # 
------------------------------------------------------------------------------
161 | 
162 | # Split Layer for multi-scale architecture. Factor of 2 hardcoded.
163 | class Split(Layer):
164 |     def __init__(self, input_shape):
165 |         super(Split, self).__init__()
166 |         bs, c, h, w = input_shape
167 |         self.conv_zero = Conv2dZeroInit(c // 2, c, 3, padding=(3 - 1) // 2)
168 | 
169 |     def split2d_prior(self, x):
170 |         h = self.conv_zero(x)
171 |         mean, logs = h[:, 0::2], h[:, 1::2]
172 |         return gaussian_diag(mean, logs)
173 | 
174 |     def forward_(self, x, objective):
175 |         bs, c, h, w = x.size()
176 |         z1, z2 = torch.chunk(x, 2, dim=1)
177 |         pz = self.split2d_prior(z1)
178 |         self.sample = z2
179 |         objective += pz.logp(z2)
180 |         return z1, objective
181 | 
182 |     def reverse_(self, x, objective, use_stored_sample=False):
183 |         pz = self.split2d_prior(x)
184 |         z2 = self.sample if use_stored_sample else pz.sample()
185 |         z = torch.cat([x, z2], dim=1)
186 |         objective -= pz.logp(z2)
187 |         return z, objective
188 | 
189 | # Gaussian Prior that's compatible with the Layer framework
190 | class GaussianPrior(Layer):
191 |     def __init__(self, input_shape, args):
192 |         super(GaussianPrior, self).__init__()
193 |         self.input_shape = input_shape
194 |         if args.learntop:
195 |             self.conv = Conv2dZeroInit(2 * input_shape[1], 2 * input_shape[1], 3, padding=(3 - 1) // 2)
196 |         else:
197 |             self.conv = None
198 | 
199 |     def forward_(self, x, objective):
200 |         mean_and_logsd = torch.cat([torch.zeros_like(x) for _ in range(2)], dim=1)
201 | 
202 |         if self.conv:
203 |             mean_and_logsd = self.conv(mean_and_logsd)
204 | 
205 |         mean, logsd = torch.chunk(mean_and_logsd, 2, dim=1)
206 | 
207 |         pz = gaussian_diag(mean, logsd)
208 |         objective += pz.logp(x)
209 | 
210 |         # this way, you can encode and decode back the same image.
211 |         return x, objective
212 | 
213 |     def reverse_(self, x, objective):
214 |         bs, c, h, w = self.input_shape
215 |         mean_and_logsd = torch.cuda.FloatTensor(bs, 2 * c, h, w).fill_(0.)
216 | 
217 |         if self.conv:
218 |             mean_and_logsd = self.conv(mean_and_logsd)
219 | 
220 |         mean, logsd = torch.chunk(mean_and_logsd, 2, dim=1)
221 |         pz = gaussian_diag(mean, logsd)
222 |         z = pz.sample() if x is None else x
223 |         objective -= pz.logp(z)
224 | 
225 |         # this way, you can encode and decode back the same image.
226 |         return z, objective
227 | 
228 | 
229 | # ------------------------------------------------------------------------------
230 | # Coupling Layers
231 | # ------------------------------------------------------------------------------
232 | 
233 | # Additive Coupling Layer
234 | class AdditiveCoupling(Layer):
235 |     def __init__(self, num_features):
236 |         super(AdditiveCoupling, self).__init__()
237 |         assert num_features % 2 == 0
238 |         self.NN = NN(num_features // 2)
239 | 
240 |     def forward_(self, x, objective):
241 |         z1, z2 = torch.chunk(x, 2, dim=1)
242 |         z2 += self.NN(z1)
243 |         return torch.cat([z1, z2], dim=1), objective
244 | 
245 |     def reverse_(self, x, objective):
246 |         z1, z2 = torch.chunk(x, 2, dim=1)
247 |         z2 -= self.NN(z1)
248 |         return torch.cat([z1, z2], dim=1), objective
249 | 
250 | # Affine Coupling Layer
251 | class AffineCoupling(Layer):
252 |     def __init__(self, num_features):
253 |         super(AffineCoupling, self).__init__()
254 |         # assert num_features % 2 == 0
255 |         self.NN = NN(num_features // 2, channels_out=num_features)
256 | 
257 |     def forward_(self, x, objective):
258 |         z1, z2 = torch.chunk(x, 2, dim=1)
259 |         h = self.NN(z1)
260 |         shift = h[:, 0::2]
261 |         scale = torch.sigmoid(h[:, 1::2] + 2.)
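        # Note: because the last conv of NN() is zero-initialized, h starts at 0,
        # so scale = sigmoid(0 + 2.) ~= 0.88 and shift = 0; the coupling therefore
        # begins close to the identity map, which helps stabilize early training.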
262 |         z2 += shift
263 |         z2 *= scale
264 |         objective += flatten_sum(torch.log(scale))
265 | 
266 |         return torch.cat([z1, z2], dim=1), objective
267 | 
268 |     def reverse_(self, x, objective):
269 |         z1, z2 = torch.chunk(x, 2, dim=1)
270 |         h = self.NN(z1)
271 |         shift = h[:, 0::2]
272 |         scale = torch.sigmoid(h[:, 1::2] + 2.)
273 |         z2 /= scale
274 |         z2 -= shift
275 |         objective -= flatten_sum(torch.log(scale))
276 |         return torch.cat([z1, z2], dim=1), objective
277 | 
278 | 
279 | # ------------------------------------------------------------------------------
280 | # Normalizing Layers
281 | # ------------------------------------------------------------------------------
282 | 
283 | # ActNorm Layer with data-dependent init
284 | class ActNorm(Layer):
285 |     def __init__(self, num_features, logscale_factor=1., scale=1.):
286 |         super(ActNorm, self).__init__()
287 |         self.initialized = False
288 |         self.logscale_factor = logscale_factor
289 |         self.scale = scale
290 |         self.register_parameter('b', nn.Parameter(torch.zeros(1, num_features, 1)))
291 |         self.register_parameter('logs', nn.Parameter(torch.zeros(1, num_features, 1)))
292 | 
293 |     def forward_(self, input, objective):
294 |         input_shape = input.size()
295 |         input = input.view(input_shape[0], input_shape[1], -1)
296 | 
297 |         if not self.initialized:
298 |             self.initialized = True
299 |             unsqueeze = lambda x: x.unsqueeze(0).unsqueeze(-1).detach()
300 | 
301 |             # Compute the mean and variance
302 |             sum_size = input.size(0) * input.size(-1)
303 |             b = -torch.sum(input, dim=(0, -1)) / sum_size
304 |             vars = unsqueeze(torch.sum((input + unsqueeze(b)) ** 2, dim=(0, -1))/sum_size)
305 |             logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) / self.logscale_factor
306 | 
307 |             self.b.data.copy_(unsqueeze(b).data)
308 |             self.logs.data.copy_(logs.data)
309 | 
310 |         logs = self.logs * self.logscale_factor
311 |         b = self.b
312 | 
313 |         output = (input + b) * torch.exp(logs)
314 |         dlogdet = torch.sum(logs) * input.size(-1) # c x h
315 | 
316 |         return output.view(input_shape), objective + dlogdet
317 | 
318 |     def reverse_(self, input, objective):
319 |         assert self.initialized
320 |         input_shape = input.size()
321 |         input = input.view(input_shape[0], input_shape[1], -1)
322 |         logs = self.logs * self.logscale_factor
323 |         b = self.b
324 |         output = input * torch.exp(-logs) - b
325 |         dlogdet = torch.sum(logs) * input.size(-1) # c x h
326 | 
327 |         return output.view(input_shape), objective - dlogdet
328 | 
329 | # (Note: a BatchNorm layer can be found in previous commits)
330 | 
331 | 
332 | # ------------------------------------------------------------------------------
333 | # Stacked Layers
334 | # ------------------------------------------------------------------------------
335 | 
336 | # 1 step of the flow (see Figure 2 a) in the original paper)
337 | class RevNetStep(LayerList):
338 |     def __init__(self, num_channels, args):
339 |         super(RevNetStep, self).__init__()
340 |         self.args = args
341 |         layers = []
342 |         if args.norm == 'actnorm':
343 |             layers += [ActNorm(num_channels)]
344 |         else:
345 |             assert not args.norm
346 | 
347 |         if args.permutation == 'reverse':
348 |             layers += [Reverse(num_channels)]
349 |         elif args.permutation == 'shuffle':
350 |             layers += [Shuffle(num_channels)]
351 |         elif args.permutation == 'conv':
352 |             layers += [Invertible1x1Conv(num_channels)]
353 |         else:
354 |             raise ValueError
355 | 
356 |         if args.coupling == 'additive':
357 |             layers += [AdditiveCoupling(num_channels)]
358 |         elif args.coupling == 'affine':
359 |             layers += [AffineCoupling(num_channels)]
360 | 
else: 361 | raise ValueError 362 | 363 | self.layers = nn.ModuleList(layers) 364 | 365 | 366 | # Full model 367 | class Glow_(LayerList, nn.Module): 368 | def __init__(self, input_shape, args): 369 | super(Glow_, self).__init__() 370 | layers = [] 371 | output_shapes = [] 372 | _, C, H, W = input_shape 373 | 374 | for i in range(args.n_levels): 375 | # Squeeze Layer 376 | layers += [Squeeze(input_shape)] 377 | C, H, W = C * 4, H // 2, W // 2 378 | output_shapes += [(-1, C, H, W)] 379 | 380 | # RevNet Block 381 | layers += [RevNetStep(C, args) for _ in range(args.depth)] 382 | output_shapes += [(-1, C, H, W) for _ in range(args.depth)] 383 | 384 | if i < args.n_levels - 1: 385 | # Split Layer 386 | layers += [Split(output_shapes[-1])] 387 | C = C // 2 388 | output_shapes += [(-1, C, H, W)] 389 | 390 | layers += [GaussianPrior((args.batch_size, C, H, W), args)] 391 | output_shapes += [output_shapes[-1]] 392 | 393 | self.layers = nn.ModuleList(layers) 394 | self.output_shapes = output_shapes 395 | self.args = args 396 | self.flatten() 397 | 398 | def forward(self, *inputs): 399 | return self.forward_(*inputs) 400 | 401 | def sample(self): 402 | with torch.no_grad(): 403 | samples = self.reverse_(None, 0.)[0] 404 | return samples 405 | 406 | def flatten(self): 407 | # flattens the list of layers to avoid recursive call every time. 408 | processed_layers = [] 409 | to_be_processed = [self] 410 | while len(to_be_processed) > 0: 411 | current = to_be_processed.pop(0) 412 | if isinstance(current, LayerList): 413 | to_be_processed = [x for x in current.layers] + to_be_processed 414 | elif isinstance(current, Layer): 415 | processed_layers += [current] 416 | 417 | self.layers = nn.ModuleList(processed_layers) 418 | 419 | 420 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.utils.weight_norm as wn 5 | 6 | import numpy as np 7 | import pdb 8 | 9 | ''' 10 | Convolution Layer with zero initialisation 11 | ''' 12 | class Conv2dZeroInit(nn.Conv2d): 13 | def __init__(self, channels_in, channels_out, filter_size, stride=1, padding=0, logscale=3.): 14 | super().__init__(channels_in, channels_out, filter_size, stride=stride, padding=padding) 15 | self.register_parameter("logs", nn.Parameter(torch.zeros(channels_out, 1, 1))) 16 | self.logscale_factor = logscale 17 | 18 | def reset_parameters(self): 19 | self.weight.data.zero_() 20 | self.bias.data.zero_() 21 | 22 | def forward(self, input): 23 | out = super().forward(input) 24 | return out * torch.exp(self.logs * self.logscale_factor) 25 | 26 | ''' 27 | Convolution Interlaced with Actnorm 28 | ''' 29 | class Conv2dActNorm(nn.Module): 30 | def __init__(self, channels_in, channels_out, filter_size, stride=1, padding=None): 31 | from invertible_layers import ActNorm 32 | super(Conv2dActNorm, self).__init__() 33 | padding = (filter_size - 1) // 2 or padding 34 | self.conv = nn.Conv2d(channels_in, channels_out, filter_size, padding=padding, bias=False) 35 | self.actnorm = ActNorm(channels_out) 36 | 37 | def forward(self, x): 38 | x = self.conv(x) 39 | x = self.actnorm.forward_(x, -1)[0] 40 | return x 41 | 42 | ''' 43 | Linear layer zero initialization 44 | ''' 45 | class LinearZeroInit(nn.Linear): 46 | def reset_parameters(self): 47 | self.weight.data.fill_(0.) 48 | self.bias.data.fill_(0.) 
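# Zero-initializing the last convolution of each coupling network (the final
# Conv2dZeroInit in NN() below) makes every coupling layer act as an identity
# map at the start of training, as recommended in the Glow paper.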
49 | 50 | ''' 51 | Shallow NN used for skip connection. Labelled `f` in the original repo. 52 | ''' 53 | def NN(in_channels, hidden_channels=512, channels_out=None): 54 | channels_out = channels_out or in_channels 55 | return nn.Sequential( 56 | Conv2dActNorm(in_channels, hidden_channels, 3, stride=1, padding=1), 57 | nn.ReLU(inplace=True), 58 | Conv2dActNorm(hidden_channels, hidden_channels, 1, stride=1, padding=0), 59 | nn.ReLU(inplace=True), 60 | Conv2dZeroInit(hidden_channels, channels_out, 3, stride=1, padding=1)) 61 | -------------------------------------------------------------------------------- /test_flows.py: -------------------------------------------------------------------------------- 1 | # inspired by https://github.com/ikostrikov/pytorch-flows/blob/master/flow_test.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.utils.weight_norm as wn 7 | 8 | import numpy as np 9 | import unittest 10 | import pdb 11 | 12 | from invertible_layers import * 13 | 14 | EPS = 1e-5 15 | BATCH_SIZE = 17 16 | NUM_CHANNELS = 64 17 | H = 32 18 | W = 48 19 | 20 | 21 | class TestFlow(unittest.TestCase): 22 | def test_shuffle(self): 23 | x = torch.randn(BATCH_SIZE, NUM_CHANNELS, H, W) 24 | layer = Shuffle(NUM_CHANNELS) 25 | log_det = torch.randn(BATCH_SIZE) 26 | 27 | y, inv_log_det = layer.forward_(x.clone(), log_det.clone()) 28 | x_, log_det_ = layer.reverse_(y.clone(), inv_log_det.clone()) 29 | 30 | self.assertTrue((log_det_ - log_det).abs().max() < EPS, 31 | 'Shuffle Layer det is not zero.') 32 | 33 | self.assertTrue((x - x_).abs().max() < EPS, 'Shuffle Layer is wrong') 34 | 35 | def test_reverse(self): 36 | x = torch.randn(BATCH_SIZE, NUM_CHANNELS, H, W) 37 | layer = Reverse(NUM_CHANNELS) 38 | log_det = torch.randn(BATCH_SIZE) 39 | 40 | y, inv_log_det = layer.forward_(x.clone(), log_det.clone()) 41 | x_, log_det_ = layer.reverse_(y.clone(), inv_log_det.clone()) 42 | 43 | self.assertTrue((log_det_ - log_det).abs().max() < EPS, 44 | 'Shuffle Layer det is not zero.') 45 | 46 | self.assertTrue((x - x_).abs().max() < EPS, 'Shuffle Layer is wrong') 47 | 48 | def test_conv(self): 49 | x = torch.randn(BATCH_SIZE, NUM_CHANNELS, H, W) 50 | layer = Invertible1x1Conv(NUM_CHANNELS) 51 | log_det = torch.randn(BATCH_SIZE) 52 | 53 | y, inv_log_det = layer.forward_(x.clone(), log_det.clone()) 54 | x_, log_det_ = layer.reverse_(y.clone(), inv_log_det.clone()) 55 | 56 | self.assertTrue((log_det_ - log_det).abs().max() < EPS, 57 | 'Conv Layer det is not zero.') 58 | 59 | self.assertTrue((x - x_).abs().max() < EPS, 'Conv Layer is wrong') 60 | self.assertTrue((log_det - inv_log_det).abs().max() > 0.01 * EPS, 'Determinant was not changed!') 61 | def test_squeeze(self): 62 | x = torch.randn(BATCH_SIZE, NUM_CHANNELS, H, W) 63 | layer = Squeeze([int(y) for y in x.size()]) 64 | log_det = torch.randn(BATCH_SIZE) 65 | 66 | y, inv_log_det = layer.forward_(x.clone(), log_det.clone()) 67 | x_, log_det_ = layer.reverse_(y.clone(), inv_log_det.clone()) 68 | 69 | self.assertTrue((log_det_ - log_det).abs().max() < EPS, 70 | 'Squeeze Layer det is not zero.') 71 | 72 | self.assertTrue((x - x_).abs().max() < EPS, 'Squeeze Layer is wrong') 73 | 74 | def test_split(self): 75 | x = torch.randn(BATCH_SIZE, NUM_CHANNELS, H, W) 76 | layer = Split([int(y) for y in x.size()]) 77 | log_det = torch.randn(BATCH_SIZE) 78 | 79 | y, inv_log_det = layer.forward_(x.clone(), log_det.clone()) 80 | x_, log_det_ = layer.reverse_(y.clone(), inv_log_det.clone(), use_stored_sample=True) 81 | 82 | 
        self.assertTrue((log_det_ - log_det).abs().max() < 1e-2,
83 |             'Split Layer det is not zero.')
84 | 
85 |         self.assertTrue((x - x_).abs().max() < EPS, 'Split Layer is wrong')
86 |         self.assertTrue((log_det - inv_log_det).abs().max() > EPS, 'Determinant was not changed!')
87 |     def test_add(self):
88 |         x = torch.randn(BATCH_SIZE, NUM_CHANNELS, H, W)
89 |         layer = AdditiveCoupling(NUM_CHANNELS)
90 |         log_det = torch.randn(BATCH_SIZE)
91 | 
92 |         y, inv_log_det = layer.forward_(x.clone(), log_det.clone())
93 |         x_, log_det_ = layer.reverse_(y.clone(), inv_log_det.clone())
94 | 
95 |         self.assertTrue((log_det_ - log_det).abs().max() < EPS,
96 |             'Additive Coupling Layer det is not zero.')
97 | 
98 |         self.assertTrue((x - x_).abs().max() < EPS, 'Additive Coupling Layer is wrong')
99 | 
100 |     def test_affine(self):
101 |         x = torch.randn(BATCH_SIZE, NUM_CHANNELS, H, W)
102 |         layer = AffineCoupling(NUM_CHANNELS)
103 |         log_det = torch.randn(BATCH_SIZE)
104 | 
105 |         # import pdb; pdb.set_trace()
106 |         y, inv_log_det = layer.forward_(x.clone(), log_det.clone())
107 |         x_, log_det_ = layer.reverse_(y.clone(), inv_log_det.clone())
108 | 
109 |         self.assertTrue((log_det_ - log_det).abs().max() < 1e-3,
110 |             'affine coupling layer det is not zero.')
111 | 
112 |         self.assertTrue((x - x_).abs().max() < EPS, 'affine coupling layer is wrong')
113 |         self.assertTrue((log_det - inv_log_det).abs().max() > EPS, 'determinant was not changed!')
114 | 
115 |     def test_actnorm(self):
116 |         x = torch.randn(BATCH_SIZE, NUM_CHANNELS, H, W)
117 |         layer = ActNorm(NUM_CHANNELS)
118 |         log_det = torch.randn(BATCH_SIZE)
119 | 
120 |         # import pdb; pdb.set_trace()
121 |         y, inv_log_det = layer.forward_(x.clone(), log_det.clone())
122 |         x_, log_det_ = layer.reverse_(y.clone(), inv_log_det.clone())
123 | 
124 |         self.assertTrue((log_det_ - log_det).abs().max() < EPS,
125 |             'actnorm layer det is not zero.')
126 | 
127 |         self.assertTrue((x - x_).abs().max() < EPS, 'actnorm layer is wrong')
128 |         self.assertTrue((log_det - inv_log_det).abs().max() > EPS, 'determinant was not changed!')
129 | 
130 | 
131 | 
132 | if __name__ == '__main__':
133 |     unittest.main()
134 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | from torch.optim import lr_scheduler
6 | from torchvision import datasets, transforms, utils
7 | 
8 | import numpy as np
9 | import pdb
10 | import argparse
11 | import time
12 | 
13 | from invertible_layers import *
14 | from utils import *
15 | 
16 | # ------------------------------------------------------------------------------
17 | parser = argparse.ArgumentParser()
18 | # training
19 | parser.add_argument('--batch_size', type=int, default=64)
20 | parser.add_argument('--depth', type=int, default=32)
21 | parser.add_argument('--n_levels', type=int, default=3)
22 | parser.add_argument('--norm', type=str, default='actnorm')
23 | parser.add_argument('--permutation', type=str, default='conv')
24 | parser.add_argument('--coupling', type=str, default='affine')
25 | parser.add_argument('--n_bits_x', type=int, default=8)
26 | parser.add_argument('--n_epochs', type=int, default=2000)
27 | parser.add_argument('--learntop', action='store_true')
28 | parser.add_argument('--n_warmup', type=int, default=20, help='number of warmup epochs')
29 | parser.add_argument('--lr', type=float, default=1e-3)
30 | # logging
31 | 
parser.add_argument('--print_every', type=int, default=500, help='print NLL every _ minibatches') 32 | parser.add_argument('--test_every', type=int, default=5, help='test on valid every _ epochs') 33 | parser.add_argument('--save_every', type=int, default=5, help='save model every _ epochs') 34 | parser.add_argument('--data_dir', type=str, default='../pixelcnn-pp') 35 | parser.add_argument('--save_dir', type=str, default='exps', help='directory for log / saving') 36 | parser.add_argument('--load_dir', type=str, default=None, help='directory from which to load existing model') 37 | args = parser.parse_args() 38 | args.n_bins = 2 ** args.n_bits_x 39 | 40 | # reproducibility is good 41 | np.random.seed(0) 42 | torch.manual_seed(0) 43 | torch.cuda.manual_seed_all(0) 44 | 45 | # loading / dataset preprocessing 46 | tf = transforms.Compose([transforms.ToTensor(), 47 | lambda x: x + torch.zeros_like(x).uniform_(0., 1./args.n_bins)]) 48 | 49 | train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=True, 50 | download=True, transform=tf), batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True) 51 | 52 | test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=False, 53 | transform=tf), batch_size=args.batch_size, shuffle=False, num_workers=10, drop_last=True) 54 | 55 | # construct model and ship to GPU 56 | model = Glow_((args.batch_size, 3, 32, 32), args).cuda() 57 | print(model) 58 | print("number of model parameters:", sum([np.prod(p.size()) for p in model.parameters()])) 59 | 60 | # set up the optimizer 61 | optim = optim.Adam(model.parameters(), lr=1e-3) 62 | scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=45, gamma=0.1) 63 | 64 | # data dependant init 65 | init_loader = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=True, 66 | download=True, transform=tf), batch_size=512, shuffle=True, num_workers=1) 67 | 68 | with torch.no_grad(): 69 | model.eval() 70 | for (img, _) in init_loader: 71 | img = img.cuda() 72 | objective = torch.zeros_like(img[:, 0, 0, 0]) 73 | _ = model(img, objective) 74 | break 75 | 76 | # once init is done, we leverage Data Parallel 77 | model = nn.DataParallel(model).cuda() 78 | start_epoch = 0 79 | 80 | # load trained model if necessary (must be done after DataParallel) 81 | if args.load_dir is not None: 82 | model, optim, start_epoch = load_session(model, optim, args) 83 | 84 | # training loop 85 | # ------------------------------------------------------------------------------ 86 | for epoch in range(start_epoch, args.n_epochs): 87 | print('epoch %s' % epoch) 88 | model.train() 89 | avg_train_bits_x = 0. 90 | num_batches = len(train_loader) 91 | for i, (img, label) in enumerate(train_loader): 92 | # if i > 10 : break 93 | 94 | t = time.time() 95 | img = img.cuda() 96 | objective = torch.zeros_like(img[:, 0, 0, 0]) 97 | 98 | # discretizing cost 99 | objective += float(-np.log(args.n_bins) * np.prod(img.shape[1:])) 100 | 101 | # log_det_jacobian cost (and some prior from Split OP) 102 | z, objective = model(img, objective) 103 | 104 | nll = (-objective) / float(np.log(2.) 
* np.prod(img.shape[1:])) 105 | 106 | # Generative loss 107 | nobj = torch.mean(nll) 108 | 109 | optim.zero_grad() 110 | nobj.backward() 111 | torch.nn.utils.clip_grad_value_(model.parameters(), 5) 112 | torch.nn.utils.clip_grad_norm_(model.parameters(), 100) 113 | optim.step() 114 | avg_train_bits_x += nobj.item() 115 | 116 | # update learning rate 117 | new_lr = float(args.lr * min(1., (i + epoch * num_batches) / (args.n_warmup * num_batches))) 118 | for pg in optim.param_groups: pg['lr'] = new_lr 119 | 120 | if (i + 1) % args.print_every == 0: 121 | print('avg train bits per pixel {:.4f}'.format(avg_train_bits_x / args.print_every)) 122 | avg_train_bits_x = 0. 123 | sample = model.module.sample() 124 | grid = utils.make_grid(sample) 125 | utils.save_image(grid, '../glow/samples/cifar_Test_{}_{}.png'.format(epoch, i // args.print_every)) 126 | 127 | print('iteration took {:.4f}'.format(time.time() - t)) 128 | 129 | # test loop 130 | # -------------------------------------------------------------------------- 131 | if (epoch + 1) % args.test_every == 0: 132 | model.eval() 133 | avg_test_bits_x = 0. 134 | with torch.no_grad(): 135 | for i, (img, label) in enumerate(test_loader): 136 | # if i > 10 : break 137 | img = img.cuda() 138 | objective = torch.zeros_like(img[:, 0, 0, 0]) 139 | 140 | # discretizing cost 141 | objective += float(-np.log(args.n_bins) * np.prod(img.shape[1:])) 142 | 143 | # log_det_jacobian cost (and some prior from Split OP) 144 | z, objective = model(img, objective) 145 | last_img = img 146 | 147 | nll = (-objective) / float(np.log(2.) * np.prod(img.shape[1:])) 148 | 149 | # Generative loss 150 | nobj = torch.mean(nll) 151 | avg_test_bits_x += nobj 152 | 153 | print('avg test bits per pixel {:.4f}'.format(avg_test_bits_x.item() / i)) 154 | 155 | sample = model.module.sample() 156 | grid = utils.make_grid(sample) 157 | utils.save_image(grid, '../glow/samples/cifar_Test_{}.png'.format(epoch)) 158 | 159 | # reconstruct 160 | x_hat = model.module.reverse_(z, objective)[0] 161 | grid = utils.make_grid(x_hat) 162 | utils.save_image(grid, '../glow/samples/cifar_Test_Recon{}.png'.format(epoch)) 163 | 164 | grid = utils.make_grid(last_img) 165 | utils.save_image(grid, '../glow/samples/cifar_Test_Target.png') 166 | 167 | 168 | if (epoch + 1) % args.save_every == 0: 169 | save_session(model, optim, args, epoch) 170 | 171 | 172 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.utils.weight_norm as wn 5 | from torch.nn.modules.batchnorm import _BatchNorm 6 | 7 | import numpy as np 8 | import pdb 9 | import os 10 | 11 | # ------------------------------------------------------------------------------ 12 | # Utility Methods 13 | # ------------------------------------------------------------------------------ 14 | 15 | def flatten_sum(logps): 16 | while len(logps.size()) > 1: 17 | logps = logps.sum(dim=-1) 18 | return logps 19 | 20 | # ------------------------------------------------------------------------------ 21 | # Logging 22 | # ------------------------------------------------------------------------------ 23 | 24 | def save_session(model, optim, args, epoch): 25 | path = os.path.join(args.save_dir, str(epoch)) 26 | if not os.path.exists(path): 27 | os.makedirs(path) 28 | 29 | # save the model and optimizer state 30 | torch.save(model.state_dict(), 
os.path.join(path, 'model.pth'))
31 |     torch.save(optim.state_dict(), os.path.join(path, 'optim.pth'))
32 |     print('Successfully saved model')
33 | 
34 | def load_session(model, optim, args):
35 |     try:
36 |         start_epoch = int(args.load_dir.split('/')[-1])
37 |         model.load_state_dict(torch.load(os.path.join(args.load_dir, 'model.pth')))
38 |         optim.load_state_dict(torch.load(os.path.join(args.load_dir, 'optim.pth')))
39 |         print('Successfully loaded model')
40 |     except Exception as e:
41 |         pdb.set_trace()
42 |         print('Could not restore session properly')
43 | 
44 |     return model, optim, start_epoch
45 | 
46 | 
47 | # ------------------------------------------------------------------------------
48 | # Distributions
49 | # ------------------------------------------------------------------------------
50 | 
51 | def standard_gaussian(shape):
52 |     mean, logsd = [torch.cuda.FloatTensor(shape).fill_(0.) for _ in range(2)]
53 |     return gaussian_diag(mean, logsd)
54 | 
55 | def gaussian_diag(mean, logsd):
56 |     class o(object):
57 |         Log2PI = float(np.log(2 * np.pi))
58 |         pass
59 | 
60 |         def logps(x):
61 |             return -0.5 * (o.Log2PI + 2. * logsd + ((x - mean) ** 2) / torch.exp(2. * logsd))
62 | 
63 |         def sample():
64 |             eps = torch.zeros_like(mean).normal_()
65 |             return mean + torch.exp(logsd) * eps
66 | 
67 |     o.logp = lambda x: flatten_sum(o.logps(x))
68 |     return o
69 | 
70 | 
71 | 
--------------------------------------------------------------------------------
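A minimal usage sketch (not part of the original code) of the `forward_` / `reverse_` interface shared by all layers: each call maps `(x, objective)` to `(y, objective ± log|det J|)`, so a forward/reverse round trip should return both the input and the original objective. It assumes you run it from the repo root so `invertible_layers` is importable.

```
import torch
from invertible_layers import Invertible1x1Conv

x = torch.randn(8, 16, 32, 32)         # (batch, channels, height, width)
objective = torch.zeros(8)              # running log-det / log-likelihood accumulator

layer = Invertible1x1Conv(16)
y, objective = layer.forward_(x, objective)      # adds  log|det W| * H * W
x_rec, objective = layer.reverse_(y, objective)  # subtracts it again

print((x - x_rec).abs().max())          # ~0: the transform is invertible
print(objective.abs().max())            # ~0: the log-dets cancel on the round trip
```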