├── .gitignore ├── README.md ├── configs ├── transformer_cat.yml ├── transformer_dmol.yml └── transformer_tiny.yml ├── image_transformer.py ├── requirements.txt └── train_transformer.py /.gitignore: -------------------------------------------------------------------------------- 1 | transformer_logs/ 2 | venv/ 3 | datasets/ 4 | __pycache__ 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image Transformer (PyTorch) 2 | 3 | A PyTorch implementation of the [Image Transformer](https://arxiv.org/abs/1802.05751). Code adapted from the official implementation in the [tensor2tensor](https://github.com/tensorflow/tensor2tensor/) library. 4 | 5 | Currently supports unconditional image generation on CIFAR10, where the per-pixel distribution can be either categorical or a discretized mixture of logistics (as in PixelCNN++). Supports block-wise attention using Local 1D blocks, which perform best in evaluations on CIFAR10. 6 | 7 | Pull requests are welcome for supporting Local 2D blocked attention, image-to-image super-resolution, or class-conditional generation! 8 | 9 | ### Running the code 10 | 11 | Install the requirements with `pip install -r requirements.txt`. Then run the training script, passing the optional `--sample` flag to generate samples during training: 12 | 13 | ``` 14 | python3 train_transformer.py --doc run_name --config transformer_cat.yml --sample 15 | ``` 16 | -------------------------------------------------------------------------------- /configs/transformer_cat.yml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | 3 | train: 4 | epochs: 100 5 | batch_size: 4 6 | sample_iter: 1000 7 | sample_size: 4 8 | log_iter: 1 9 | clip_grad_norm: 0 10 | 11 | optim: 12 | lr: 0.1 13 | warmup: 4000 14 | 15 | model: 16 | # Using cifar10_base hparams -- ideally gets 2.90 bits/dim 17 | channels: 3 18 | image_size: 32 19 | hidden_size: 512 20 | nlayers: 12 21 | num_heads: 4 22 | total_key_depth: 0 # If 0, uses hidden_size instead 23 | total_value_depth: 0 24 | attn_type: "local_1d" 25 | distr: "cat" 26 | filter_size: 2048 # in ffn 27 | dropout: 0.3 28 | block_length: 256 29 | initializer_gain: 0.2 30 | -------------------------------------------------------------------------------- /configs/transformer_dmol.yml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | 3 | train: 4 | epochs: 100 5 | batch_size: 8 6 | sample_iter: 1000 7 | sample_size: 8 8 | log_iter: 1 9 | clip_grad_norm: 0 10 | 11 | optim: 12 | lr: 0.1 13 | warmup: 4000 14 | 15 | model: 16 | # Using cifar10_base hparams -- ideally gets 2.90 bits/dim 17 | image_size: 32 18 | channels: 3 19 | hidden_size: 256 20 | nlayers: 14 21 | num_heads: 8 22 | total_key_depth: 512 # If 0, uses hidden_size instead 23 | total_value_depth: 512 24 | attn_type: "local_1d" 25 | distr: "dmol" 26 | num_mixtures: 10 # for dmol only 27 | filter_size: 512 # in ffn 28 | dropout: 0.1 29 | block_length: 256 30 | initializer_gain: 0.2 31 | 32 | 33 | -------------------------------------------------------------------------------- /configs/transformer_tiny.yml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | 3 | train: 4 | epochs: 100 5 | batch_size: 8 6 | sample_iter: 1000 7 | sample_size: 8 8 | log_iter: 10 9 | clip_grad_norm: 0.0 10 | 11 | optim: 12 | lr: 0.1 13 | warmup: 4000 14 | 15 | model:
16 | image_size: 8 17 | channels: 3 18 | hidden_size: 64 19 | nlayers: 12 20 | num_heads: 4 21 | total_key_depth: 0 # If 0, uses hidden_size instead 22 | total_value_depth: 0 23 | attn_type: "global" 24 | distr: "cat" 25 | num_mixtures: 5 # for dmol only 26 | filter_size: 128 # in ffn 27 | dropout: 0.1 28 | initializer_gain: 0.2 29 | 30 | 31 | -------------------------------------------------------------------------------- /image_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import logging 6 | from tqdm import tqdm 7 | import matplotlib.pyplot as plt 8 | import seaborn as sns 9 | 10 | NUM_PIXELS = 256 11 | 12 | # Numerically stable implementations 13 | def logsoftmax(x): 14 | m = torch.max(x, -1, keepdim=True).values 15 | return x - m - torch.log(torch.exp(x - m).sum(-1, keepdim=True)) 16 | 17 | def logsumexp(x): 18 | m = x.max(-1).values 19 | return m + torch.log(torch.exp(x - m[...,None]).sum(-1)) 20 | 21 | class ImageTransformer(nn.Module): 22 | """ImageTransformer with DMOL or categorical distribution.""" 23 | def __init__(self, hparams): 24 | super().__init__() 25 | self.hparams = hparams 26 | self.layers = nn.ModuleList([DecoderLayer(hparams) for _ in range(hparams.nlayers)]) 27 | self.input_dropout = nn.Dropout(p=hparams.dropout) 28 | if self.hparams.distr == "dmol": # Discretized mixture of logistic, for ordinal valued inputs 29 | assert self.hparams.channels == 3, "Only supports 3 channels for DML" 30 | size = (1, self.hparams.channels) 31 | self.embedding_conv = nn.Conv2d(1, self.hparams.hidden_size, 32 | kernel_size=size, stride=size) 33 | # 10 = 1 + 2c + c(c-1)/2; if only 1 channel, then 3 total 34 | depth = self.hparams.num_mixtures * 10 35 | self.output_dense = nn.Linear(self.hparams.hidden_size, depth, bias=False) 36 | elif self.hparams.distr == "cat": # Categorical 37 | self.embeds = nn.Embedding(NUM_PIXELS * self.hparams.channels, self.hparams.hidden_size) 38 | self.output_dense = nn.Linear(self.hparams.hidden_size, NUM_PIXELS, bias=True) 39 | else: 40 | raise ValueError("Only dmol or categorical distributions") 41 | 42 | # TODO: can probably cache this computation. (Be wary of shapes for train vs. 
predict) 43 | def add_timing_signal(self, X, min_timescale=1.0, max_timescale=1.0e4): 44 | num_dims = len(X.shape) - 2 # 2 corresponds to batch and hidden_size dimensions 45 | num_timescales = self.hparams.hidden_size // (num_dims * 2) 46 | log_timescale_increment = np.log(max_timescale / min_timescale) / (num_timescales - 1) 47 | inv_timescales = min_timescale * torch.exp((torch.arange(num_timescales).float() * -log_timescale_increment)) 48 | inv_timescales = inv_timescales.to(X.device) 49 | total_signal = torch.zeros_like(X) # Only for debugging purposes 50 | for dim in range(num_dims): 51 | length = X.shape[dim + 1] # add 1 to exclude batch dim 52 | position = torch.arange(length).float().to(X.device) 53 | scaled_time = position.view(-1, 1) * inv_timescales.view(1, -1) 54 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 1) 55 | prepad = dim * 2 * num_timescales 56 | postpad = self.hparams.hidden_size - (dim + 1) * 2 * num_timescales 57 | signal = F.pad(signal, (prepad, postpad)) 58 | for _ in range(1 + dim): 59 | signal = signal.unsqueeze(0) 60 | for _ in range(num_dims - 1 - dim): 61 | signal = signal.unsqueeze(-2) 62 | X += signal 63 | total_signal += signal 64 | return X 65 | 66 | def shift_and_pad_(self, X): 67 | # Shift inputs over by 1 and pad 68 | shape = X.shape 69 | X = X.view(shape[0], shape[1] * shape[2], shape[3]) 70 | X = X[:,:-1,:] 71 | X = F.pad(X, (0, 0, 1, 0)) # Pad second to last dimension 72 | X = X.view(shape) 73 | return X 74 | 75 | def forward(self, X, sampling=False): 76 | # Reshape inputs 77 | if sampling: 78 | curr_infer_length = X.shape[1] 79 | row_size = self.hparams.image_size * self.hparams.channels 80 | nrows = curr_infer_length // row_size + 1 81 | X = F.pad(X, (0, nrows * row_size - curr_infer_length)) 82 | X = X.view(X.shape[0], -1, row_size) 83 | else: 84 | X = X.permute([0, 2, 3, 1]).contiguous() 85 | X = X.view(X.shape[0], X.shape[1], X.shape[2] * X.shape[3]) # Flatten channels into width 86 | 87 | # Inputs -> embeddings 88 | if self.hparams.distr == "dmol": 89 | # Create a "channel" dimension for the 1x3 convolution 90 | # (NOTE: can apply a 1x1 convolution and not reshape, this is for consistency) 91 | X = X.unsqueeze(1) 92 | X = F.relu(self.embedding_conv(X)) 93 | X = X.permute([0, 2, 3, 1]) # move channels to the end 94 | elif self.hparams.distr == "cat": 95 | # Convert to indexes, and use separate embeddings for different channels 96 | X = (X * (NUM_PIXELS - 1)).long() 97 | channel_addition = (torch.tensor([0, 1, 2]) * NUM_PIXELS).to(X.device).repeat(X.shape[2] // 3).view(1, 1, -1) 98 | X += channel_addition 99 | X = self.embeds(X) * (self.hparams.hidden_size ** 0.5) 100 | 101 | X = self.shift_and_pad_(X) 102 | X = self.add_timing_signal(X) 103 | shape = X.shape 104 | X = X.view(shape[0], -1, shape[3]) 105 | 106 | X = self.input_dropout(X) 107 | for layer in self.layers: 108 | X = layer(X) 109 | X = self.layers[-1].preprocess_(X) # NOTE: this is identity (exists to replicate tensorflow code) 110 | X = self.output_dense(X).view(shape[:3] + (-1,)) 111 | 112 | if not sampling and self.hparams.distr == "cat": # Unpack the channels 113 | X = X.view(X.shape[0], X.shape[1], X.shape[2] // self.hparams.channels, self.hparams.channels, X.shape[3]) 114 | X = X.permute([0, 3, 1, 2, 4]) 115 | 116 | return X 117 | 118 | def split_to_dml_params(self, preds, targets=None, sampling=False): 119 | nm = self.hparams.num_mixtures 120 | mix_logits, locs, log_scales, coeffs = torch.split(preds, [nm, nm * 3, nm * 3, nm * 3], dim=-1) 121 | new_shape 
= preds.shape[:-1] + (3, nm) 122 | locs = locs.view(new_shape) 123 | coeffs = torch.tanh(coeffs.view(new_shape)) 124 | log_scales = torch.clamp(log_scales.view(new_shape), min=-7.) 125 | if not sampling: 126 | targets = targets.unsqueeze(-1) 127 | locs1 = locs[...,1,:] + coeffs[...,0,:] * targets[:,0,...] 128 | locs2 = locs[...,2,:] + coeffs[...,1,:] * targets[:,0,...] + coeffs[...,2,:] * targets[:,1,...] 129 | locs = torch.stack([locs[...,0,:], locs1, locs2], dim=-2) 130 | return mix_logits, locs, log_scales 131 | else: 132 | return mix_logits, locs, log_scales, coeffs 133 | 134 | # Modified from official PixCNN++ code 135 | def dml_logp(self, logits, means, log_scales, targets): 136 | targets = targets.unsqueeze(-1) 137 | centered_x = targets - means 138 | inv_stdv = torch.exp(-log_scales) 139 | plus_in = inv_stdv * (centered_x + 1. / 255.) 140 | cdf_plus = torch.sigmoid(plus_in) 141 | min_in = inv_stdv * (centered_x - 1. / 255.) 142 | cdf_min = torch.sigmoid(min_in) 143 | log_cdf_plus = plus_in - F.softplus(plus_in) # log probability for edge case of 0 (before scaling) 144 | log_one_minus_cdf_min = -F.softplus(min_in) # log probability for edge case of 255 (before scaling) 145 | cdf_delta = cdf_plus - cdf_min # probability for all other cases 146 | mid_in = inv_stdv * centered_x 147 | log_pdf_mid = mid_in - log_scales - 2. * F.softplus( 148 | mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code) 149 | 150 | # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us) 151 | log_probs = torch.where(targets < -0.999, log_cdf_plus, torch.where(targets > 0.999, log_one_minus_cdf_min, 152 | torch.where(cdf_delta > 1e-5, 153 | torch.log(torch.clamp(cdf_delta, min=1e-12)), 154 | log_pdf_mid - np.log(127.5)))) 155 | log_probs = log_probs.sum(3) + logsoftmax(logits) 156 | log_probs = logsumexp(log_probs) 157 | return log_probs 158 | 159 | # Assumes targets have been rescaled to [-1., 1.] 160 | def loss(self, preds, targets): 161 | if self.hparams.distr == "dmol": 162 | # Assumes 3 channels. Input: [batch_size, height, width, 10 * 10] 163 | logits, locs, log_scales = self.split_to_dml_params(preds, targets) 164 | targets = targets.permute([0, 2, 3, 1]) 165 | log_probs = self.dml_logp(logits, locs, log_scales, targets) 166 | return -log_probs 167 | elif self.hparams.distr == "cat": 168 | targets = (targets * (NUM_PIXELS - 1)).long() 169 | ce = F.cross_entropy(preds.permute(0, 4, 1, 2, 3), targets, reduction='none') 170 | return ce 171 | 172 | def accuracy(self, preds, targets): 173 | if self.hparams.distr == "cat": 174 | targets = (targets * (NUM_PIXELS - 1)).long() 175 | argmax_preds = torch.argmax(preds, dim=-1) 176 | acc = torch.eq(argmax_preds, targets).float().sum() / np.prod(argmax_preds.shape) 177 | return acc 178 | else: 179 | # Computing accuracy for dmol is more computationally intensive, so we skip it 180 | return torch.zeros((1,)) 181 | 182 | def sample_from_dmol(self, outputs): 183 | logits, locs, log_scales, coeffs = self.split_to_dml_params(outputs, sampling=True) 184 | gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) * (1. 
- 2 * 1e-5) + 1e-5)) 185 | sel = torch.argmax(logits + gumbel_noise, -1, keepdim=True) 186 | one_hot = torch.zeros_like(logits).scatter_(-1, sel, 1).unsqueeze(-2) 187 | locs = (locs * one_hot).sum(-1) 188 | log_scales = (log_scales * one_hot).sum(-1) 189 | coeffs = (coeffs * one_hot).sum(-1) 190 | unif = torch.rand_like(log_scales) * (1. - 2 * 1e-5) + 1e-5 191 | logistic_noise = torch.log(unif) - torch.log1p(-unif) 192 | x = locs + torch.exp(log_scales) * logistic_noise 193 | # NOTE: sampling analogously to pixcnn++, which clamps first, unlike image transformer 194 | x0 = torch.clamp(x[..., 0], -1., 1.) 195 | x1 = torch.clamp(x[..., 1] + coeffs[..., 0] * x0, -1., 1.) 196 | x2 = torch.clamp(x[..., 2] + coeffs[..., 1] * x0 + coeffs[..., 2] * x1, -1., 1.) 197 | x = torch.stack([x0, x1, x2], -1) 198 | return x 199 | 200 | def sample_from_cat(self, logits, argmax=False): 201 | if argmax: 202 | sel = torch.argmax(logits, -1, keepdim=False).float() / 255. 203 | else: 204 | gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) * (1. - 2 * 1e-5) + 1e-5)) 205 | sel = torch.argmax(logits + gumbel_noise, -1, keepdim=False).float() / 255. 206 | return sel 207 | 208 | def sample(self, n, device, argmax=False): 209 | total_len = (self.hparams.image_size ** 2) 210 | if self.hparams.distr == "cat": 211 | total_len *= self.hparams.channels 212 | samples = torch.zeros((n, 3)).to(device) 213 | for curr_infer_length in tqdm(range(total_len)): 214 | outputs = self.forward(samples, sampling=True) 215 | outputs = outputs.view(n, -1, outputs.shape[-1])[:,curr_infer_length:curr_infer_length+1,:] 216 | if self.hparams.distr == "dmol": 217 | x = self.sample_from_dmol(outputs).squeeze() 218 | elif self.hparams.distr == "cat": 219 | x = self.sample_from_cat(outputs, argmax=argmax) 220 | if curr_infer_length == 0: 221 | samples = x 222 | else: 223 | samples = torch.cat([samples, x], 1) 224 | samples = samples.view(n, self.hparams.image_size, self.hparams.image_size, self.hparams.channels) 225 | samples = samples.permute(0, 3, 1, 2) 226 | return samples 227 | 228 | def sample_from_preds(self, preds, argmax=False): 229 | if self.hparams.distr == "dmol": 230 | samples = self.sample_from_dmol(preds) 231 | samples = samples.permute(0, 3, 1, 2) 232 | elif self.hparams.distr == "cat": 233 | samples = self.sample_from_cat(preds, argmax=argmax) 234 | return samples 235 | 236 | 237 | class DecoderLayer(nn.Module): 238 | """Implements a single layer of an unconditional ImageTransformer""" 239 | def __init__(self, hparams): 240 | super().__init__() 241 | self.attn = Attn(hparams) 242 | self.hparams = hparams 243 | self.dropout = nn.Dropout(p=hparams.dropout) 244 | self.layernorm_attn = nn.LayerNorm([self.hparams.hidden_size], eps=1e-6, elementwise_affine=True) 245 | self.layernorm_ffn = nn.LayerNorm([self.hparams.hidden_size], eps=1e-6, elementwise_affine=True) 246 | self.ffn = nn.Sequential(nn.Linear(self.hparams.hidden_size, self.hparams.filter_size, bias=True), 247 | nn.ReLU(), 248 | nn.Linear(self.hparams.filter_size, self.hparams.hidden_size, bias=True)) 249 | 250 | def preprocess_(self, X): 251 | return X 252 | 253 | # Takes care of the "postprocessing" from tensorflow code with the layernorm and dropout 254 | def forward(self, X): 255 | X = self.preprocess_(X) 256 | y = self.attn(X) 257 | X = self.layernorm_attn(self.dropout(y) + X) 258 | y = self.ffn(self.preprocess_(X)) 259 | X = self.layernorm_ffn(self.dropout(y) + X) 260 | return X 261 | 262 | class Attn(nn.Module): 263 | def __init__(self, hparams): 264 | 
super().__init__() 265 | self.hparams = hparams 266 | self.kd = self.hparams.total_key_depth or self.hparams.hidden_size 267 | self.vd = self.hparams.total_value_depth or self.hparams.hidden_size 268 | self.q_dense = nn.Linear(self.hparams.hidden_size, self.kd, bias=False) 269 | self.k_dense = nn.Linear(self.hparams.hidden_size, self.kd, bias=False) 270 | self.v_dense = nn.Linear(self.hparams.hidden_size, self.vd, bias=False) 271 | self.output_dense = nn.Linear(self.vd, self.hparams.hidden_size, bias=False) 272 | assert self.kd % self.hparams.num_heads == 0 273 | assert self.vd % self.hparams.num_heads == 0 274 | 275 | def dot_product_attention(self, q, k, v, bias=None): 276 | logits = torch.einsum("...kd,...qd->...qk", k, q) 277 | if bias is not None: 278 | logits += bias 279 | weights = F.softmax(logits, dim=-1) 280 | return weights @ v 281 | 282 | def forward(self, X): 283 | q = self.q_dense(X) 284 | k = self.k_dense(X) 285 | v = self.v_dense(X) 286 | # Split to shape [batch_size, num_heads, len, depth / num_heads] 287 | q = q.view(q.shape[:-1] + (self.hparams.num_heads, self.kd // self.hparams.num_heads)).permute([0, 2, 1, 3]) 288 | k = k.view(k.shape[:-1] + (self.hparams.num_heads, self.kd // self.hparams.num_heads)).permute([0, 2, 1, 3]) 289 | v = v.view(v.shape[:-1] + (self.hparams.num_heads, self.vd // self.hparams.num_heads)).permute([0, 2, 1, 3]) 290 | q *= (self.kd // self.hparams.num_heads) ** (-0.5) 291 | 292 | if self.hparams.attn_type == "global": 293 | bias = -1e9 * torch.triu(torch.ones(X.shape[1], X.shape[1]), 1).to(X.device) 294 | result = self.dot_product_attention(q, k, v, bias=bias) 295 | elif self.hparams.attn_type == "local_1d": 296 | len = X.shape[1] 297 | blen = self.hparams.block_length 298 | pad = (0, 0, 0, (-len) % self.hparams.block_length) # Append to multiple of block length 299 | q = F.pad(q, pad) 300 | k = F.pad(k, pad) 301 | v = F.pad(v, pad) 302 | 303 | bias = -1e9 * torch.triu(torch.ones(blen, blen), 1).to(X.device) 304 | first_output = self.dot_product_attention( 305 | q[:,:,:blen,:], k[:,:,:blen,:], v[:,:,:blen,:], bias=bias) 306 | 307 | if q.shape[2] > blen: 308 | q = q.view(q.shape[0], q.shape[1], -1, blen, q.shape[3]) 309 | k = k.view(k.shape[0], k.shape[1], -1, blen, k.shape[3]) 310 | v = v.view(v.shape[0], v.shape[1], -1, blen, v.shape[3]) 311 | local_k = torch.cat([k[:,:,:-1], k[:,:,1:]], 3) # [batch, nheads, (nblocks - 1), blen * 2, depth] 312 | local_v = torch.cat([v[:,:,:-1], v[:,:,1:]], 3) 313 | tail_q = q[:,:,1:] 314 | bias = -1e9 * torch.triu(torch.ones(blen, 2 * blen), blen + 1).to(X.device) 315 | tail_output = self.dot_product_attention(tail_q, local_k, local_v, bias=bias) 316 | tail_output = tail_output.view(tail_output.shape[0], tail_output.shape[1], -1, tail_output.shape[4]) 317 | result = torch.cat([first_output, tail_output], 2) 318 | result = result[:,:,:X.shape[1],:] 319 | else: 320 | result = first_output[:,:,:X.shape[1],:] 321 | 322 | result = result.permute([0, 2, 1, 3]).contiguous() 323 | result = result.view(result.shape[0:2] + (-1,)) 324 | result = self.output_dense(result) 325 | return result -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.0.3 2 | PyYAML==5.4 3 | seaborn==0.9.0 4 | tensorboardX==1.8 5 | torch==1.2.0 6 | torchvision==0.4.0 7 | tqdm==4.34.0 8 | torchviz==0.0.1 9 | -------------------------------------------------------------------------------- 
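Before the training script, here is a minimal smoke test, not part of the repository, showing how `ImageTransformer` from `image_transformer.py` can be driven directly: it builds an `argparse.Namespace` with the hparams of `configs/transformer_tiny.yml` (standing in for the `dict2namespace` helper used by `train_transformer.py` below) and runs one forward/loss pass on random data.

```
# Hypothetical smoke test -- not included in the repo. Hparams mirror
# configs/transformer_tiny.yml; the Namespace stands in for dict2namespace.
from argparse import Namespace
import torch
from image_transformer import ImageTransformer

hparams = Namespace(image_size=8, channels=3, hidden_size=64, nlayers=12,
                    num_heads=4, total_key_depth=0, total_value_depth=0,
                    attn_type="global", distr="cat", num_mixtures=5,
                    filter_size=128, dropout=0.1, initializer_gain=0.2)

model = ImageTransformer(hparams)
imgs = torch.rand(2, 3, 8, 8)    # fake batch of 2 RGB images in [0, 1]
preds = model(imgs)              # [2, 3, 8, 8, 256] per-subpixel logits
loss = model.loss(preds, imgs)   # per-subpixel cross-entropy, shape [2, 3, 8, 8]
print(loss.mean().item())        # average nats per subpixel
```

Dividing that mean by log(2) gives the bits/dim figure that `train_transformer.py` logs.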
/train_transformer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import sys 4 | import time 5 | import os 6 | import logging 7 | import yaml 8 | import shutil 9 | import numpy as np 10 | import tensorboardX 11 | import torch.optim as optim 12 | import torchvision 13 | from image_transformer import ImageTransformer 14 | import matplotlib 15 | import itertools 16 | from torch.utils.data import DataLoader 17 | from torchvision import datasets, transforms 18 | from torchviz import make_dot 19 | from tqdm import tqdm 20 | import torch.nn as nn 21 | 22 | matplotlib.use('Agg') 23 | import seaborn as sns 24 | import matplotlib.pyplot as plt 25 | 26 | sns.set() 27 | 28 | def dict2namespace(config): 29 | namespace = argparse.Namespace() 30 | for key, value in config.items(): 31 | if isinstance(value, dict): 32 | new_value = dict2namespace(value) 33 | else: 34 | new_value = value 35 | setattr(namespace, key, new_value) 36 | return namespace 37 | 38 | def parse_args_and_config(): 39 | """ 40 | :return args, config: namespace objects that stores information in args and config files. 41 | """ 42 | parser = argparse.ArgumentParser(description=globals()['__doc__']) 43 | 44 | parser.add_argument('--config', type=str, default='transformer_dmol.yml', help='Path to the config file') 45 | parser.add_argument('--doc', type=str, default='0', help='A string for documentation purpose') 46 | parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical') 47 | parser.add_argument('--sample', action='store_true', help='Sample at train time') 48 | 49 | args = parser.parse_args() 50 | args.log = os.path.join('transformer_logs', args.doc) 51 | # parse config file 52 | with open(os.path.join('configs', args.config), 'r') as f: 53 | config = yaml.load(f, Loader=yaml.FullLoader) 54 | new_config = dict2namespace({**config, **vars(args)}) 55 | 56 | if os.path.exists(args.log): 57 | shutil.rmtree(args.log) 58 | 59 | os.makedirs(args.log) 60 | 61 | with open(os.path.join(args.log, 'config.yml'), 'w') as f: 62 | yaml.dump(new_config, f, default_flow_style=False) 63 | 64 | # setup logger 65 | level = getattr(logging, args.verbose.upper(), None) 66 | if not isinstance(level, int): 67 | raise ValueError('level {} not supported'.format(args.verbose)) 68 | 69 | handler1 = logging.StreamHandler() 70 | handler2 = logging.FileHandler(os.path.join(args.log, 'stdout.txt')) 71 | formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s') 72 | handler1.setFormatter(formatter) 73 | handler2.setFormatter(formatter) 74 | logger = logging.getLogger() 75 | logger.addHandler(handler1) 76 | logger.addHandler(handler2) 77 | logger.setLevel(level) 78 | 79 | # add device information to args 80 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 81 | logging.info("Using device: {}".format(device)) 82 | new_config.device = device 83 | 84 | # set random seed 85 | torch.manual_seed(new_config.seed) 86 | torch.cuda.manual_seed_all(new_config.seed) 87 | np.random.seed(new_config.seed) 88 | logging.info("Run name: {}".format(args.doc)) 89 | 90 | return args, new_config 91 | 92 | def get_lr(step, config): 93 | warmup_steps = config.optim.warmup 94 | lr_base = config.optim.lr * 0.002 # for Adam correction 95 | ret = 5000. 
* config.model.hidden_size ** (-0.5) * \ 96 | np.min([(step + 1) * warmup_steps ** (-1.5), (step + 1) ** (-0.5)]) 97 | return ret * lr_base 98 | 99 | def main(): 100 | args, config = parse_args_and_config() 101 | tb_logger = tensorboardX.SummaryWriter(log_dir=os.path.join('transformer_logs', args.doc)) 102 | 103 | if config.model.distr == "dmol": 104 | # Scale size and rescale data to [-1, 1] 105 | transform = transforms.Compose([ 106 | transforms.Resize(config.model.image_size), 107 | transforms.ToTensor(), 108 | transforms.Normalize(mean=[0.5, 0.5, 0.5], 109 | std=[0.5, 0.5, 0.5]) 110 | ]) 111 | else: 112 | transform = transforms.Compose([ 113 | transforms.Resize(config.model.image_size), 114 | transforms.ToTensor() 115 | ]) 116 | dataset = datasets.CIFAR10('datasets/transformer', transform=transform, download=True) 117 | loader = DataLoader(dataset, batch_size=config.train.batch_size, shuffle=True, num_workers=4) 118 | input_dim = config.model.image_size ** 2 * config.model.channels 119 | model = ImageTransformer(config.model).to(config.device) 120 | optimizer = optim.Adam(model.parameters(), lr=1., betas=(0.9, 0.98), eps=1e-9) 121 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: get_lr(step, config)) 122 | 123 | # Initialize as in their code 124 | gain = config.model.initializer_gain 125 | for name, p in model.named_parameters(): 126 | if "layernorm" in name: 127 | continue 128 | # This is from a pytorch implementation of the language transformer, but is not needed/in TF code. 129 | # if "attn" in name and "output" not in name: 130 | # nn.init.xavier_normal_(p) 131 | if p.dim() > 1: 132 | nn.init.xavier_uniform_(p, gain=np.sqrt(gain)) # Need sqrt for inconsistency between pytorch / TF 133 | else: 134 | a = np.sqrt(3. * gain / p.shape[0]) 135 | nn.init.uniform_(p, -a, a) 136 | 137 | # Accumulate data statistics for debugging purposes, e.g. to analyze the entropy of the first dimension 138 | # data_avgs = torch.zeros(config.model.channels, config.model.image_size, config.model.image_size, 256) 139 | # for i, (imgs, l) in tqdm(enumerate(loader)): 140 | # one_hot_data = torch.zeros(imgs.shape + (256,)).scatter_(-1, (imgs * 255).long().unsqueeze(-1), 1) 141 | # data_avgs += one_hot_data.mean(0) 142 | # data_avgs /= i 143 | 144 | def revert_samples(input): 145 | if config.model.distr == "cat": 146 | return input 147 | elif config.model.distr == "dmol": 148 | return input * 0.5 + 0.5 149 | 150 | step = 0 151 | losses_per_dim = torch.zeros(config.model.channels, config.model.image_size, config.model.image_size).to(config.device) 152 | for _ in range(config.train.epochs): 153 | for _, (imgs, l) in enumerate(loader): 154 | imgs = imgs.to(config.device) 155 | model.train() 156 | 157 | scheduler.step() 158 | optimizer.zero_grad() 159 | preds = model(imgs) 160 | loss = model.loss(preds, imgs) 161 | decay = 0. 
if step == 0 else 0.99 162 | if config.model.distr == "dmol": 163 | losses_per_dim[0,:,:] = losses_per_dim[0,:,:] * decay + (1 - decay) * loss.detach().mean(0) / np.log(2) 164 | else: 165 | losses_per_dim = losses_per_dim * decay + (1 - decay) * loss.detach().mean(0) / np.log(2) 166 | loss = loss.view(loss.shape[0], -1).sum(1) 167 | loss = loss.mean(0) 168 | 169 | # Show computational graph 170 | # dot = make_dot(loss, dict(model.named_parameters())) 171 | # dot.render('test.gv', view=True) 172 | 173 | loss.backward() 174 | 175 | total_norm = 0 176 | for p in model.parameters(): 177 | param_norm = p.grad.data.norm(2) 178 | total_norm += param_norm.item() ** 2 179 | total_norm = (total_norm ** (1. / 2)) 180 | 181 | if config.train.clip_grad_norm > 0.0: 182 | nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad_norm) 183 | 184 | total_norm_post = 0 185 | for p in model.parameters(): 186 | param_norm = p.grad.data.norm(2) 187 | total_norm_post += param_norm.item() ** 2 188 | total_norm_post = (total_norm_post ** (1. / 2)) 189 | 190 | optimizer.step() 191 | bits_per_dim = loss / (np.log(2.) * input_dim) 192 | acc = model.accuracy(preds, imgs) 193 | 194 | if step % config.train.log_iter == 0: 195 | logging.info('step: {}; loss: {:.3f}; bits_per_dim: {:.3f}, acc: {:.3f}, grad norm pre: {:.3f}, post: {:.3f}' 196 | .format(step, loss.item(), bits_per_dim.item(), acc.item(), total_norm, total_norm_post)) 197 | tb_logger.add_scalar('loss', loss.item(), global_step=step) 198 | tb_logger.add_scalar('bits_per_dim', bits_per_dim.item(), global_step=step) 199 | tb_logger.add_scalar('acc', acc.item(), global_step=step) 200 | tb_logger.add_scalar('grad_norm', total_norm, global_step=step) 201 | 202 | if step % config.train.sample_iter == 0: 203 | logging.info("Sampling from model: {}".format(args.doc)) 204 | if config.model.distr == "cat": 205 | channels = ['r','g','b'] 206 | color_codes = ['Reds', "Greens", 'Blues'] 207 | for idx, c in enumerate(channels): 208 | ax = sns.heatmap(losses_per_dim[idx,:,:].cpu().numpy(), linewidth=0.5, cmap=color_codes[idx]) 209 | tb_logger.add_figure("losses_per_dim/{}".format(c), ax.get_figure(), close=True, global_step=step) 210 | else: 211 | ax = sns.heatmap(losses_per_dim[0,:,:].cpu().numpy(), linewidth=0.5, cmap='Blues') 212 | tb_logger.add_figure("losses_per_dim", ax.get_figure(), close=True, global_step=step) 213 | 214 | model.eval() 215 | with torch.no_grad(): 216 | imgs = revert_samples(imgs) 217 | imgs_grid = torchvision.utils.make_grid(imgs[:8, ...], 3) 218 | tb_logger.add_image('imgs', imgs_grid, global_step=step) 219 | 220 | # Evaluate model predictions for the input 221 | pred_samples = revert_samples(model.sample_from_preds(preds)) 222 | pred_samples_grid = torchvision.utils.make_grid(pred_samples[:8, ...], 3) 223 | tb_logger.add_image('pred_samples/random', pred_samples_grid, global_step=step) 224 | pred_samples = revert_samples(model.sample_from_preds(preds, argmax=True)) 225 | pred_samples_grid = torchvision.utils.make_grid(pred_samples[:8, ...], 3) 226 | tb_logger.add_image('pred_samples/argmax', pred_samples_grid, global_step=step) 227 | 228 | if args.sample: 229 | samples = revert_samples(model.sample(config.train.sample_size, config.device)) 230 | samples_grid = torchvision.utils.make_grid(samples[:8, ...], 3) 231 | tb_logger.add_image('samples', samples_grid, global_step=step) 232 | 233 | # Argmax samples are not useful for unconditional generation 234 | # if config.model.distr == "cat": 235 | # argmax_samples = model.sample(1, 
config.device, argmax=True) 236 | # samples_grid = torchvision.utils.make_grid(argmax_samples[:8, ...], 3) 237 | # tb_logger.add_image('argmax_samples', samples_grid, global_step=step) 238 | torch.save(model.state_dict(), os.path.join('transformer_logs', args.doc, "model.pth")) 239 | step += 1 240 | 241 | return 0 242 | 243 | if __name__ == '__main__': 244 | sys.exit(main()) 245 | --------------------------------------------------------------------------------
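After training, the checkpoint written to `transformer_logs/<doc>/model.pth` can be reloaded for sampling. The sketch below is hypothetical and not part of the repository: the run name `run_name` and output file are placeholders, the small `to_namespace` helper simply mirrors `dict2namespace` above, and it assumes a model trained with `configs/transformer_cat.yml`.

```
# Hypothetical sampling script -- not included in the repo.
import argparse
import os
import torch
import torchvision
import yaml
from image_transformer import ImageTransformer

def to_namespace(d):
    # Mirrors dict2namespace in train_transformer.py
    ns = argparse.Namespace()
    for k, v in d.items():
        setattr(ns, k, to_namespace(v) if isinstance(v, dict) else v)
    return ns

with open(os.path.join("configs", "transformer_cat.yml")) as f:
    config = to_namespace(yaml.load(f, Loader=yaml.FullLoader))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImageTransformer(config.model).to(device)
state = torch.load(os.path.join("transformer_logs", "run_name", "model.pth"),
                   map_location=device)
model.load_state_dict(state)
model.eval()

with torch.no_grad():
    samples = model.sample(4, device)   # [4, channels, image_size, image_size]
if config.model.distr == "dmol":
    samples = samples * 0.5 + 0.5       # dmol samples live in [-1, 1]; map back to [0, 1]
torchvision.utils.save_image(samples, "samples.png", nrow=2)
```

Sampling is autoregressive over all image_size² × channels positions (3072 forward passes for 32×32 CIFAR10 with the categorical model), so expect it to be slow on CPU.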