├── LICENSE.txt
├── README.md
├── cifar10_gen.py
├── experiment.sh
├── generic_utils.py
├── install_dependencies.sh
├── layers.py
├── mnist_gen.py
├── models.py
├── output
│   ├── generated_only_images.jpg
│   └── training_images.jpg
└── plot_images.py

/LICENSE.txt:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2016 Kundan Kumar
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Generating images pixel by pixel
2 | ###### Theano implementation of pixelCNN architecture
3 | ### This repository contains code for training an image generator using a slight variant of the pixelCNN architecture, as described in [Conditional Image Generation with PixelCNN Decoders](https://arxiv.org/abs/1606.05328)
4 |
5 | Most of the code is in core Theano. 'keras' is used only for loading data, and the optimizer implementation comes from 'lasagne'.
6 |
7 | Dependencies:
8 |
9 | [theano](http://deeplearning.net/software/theano/install.html)
10 |
11 | [lasagne](http://lasagne.readthedocs.io/en/latest/user/installation.html)
12 |
13 | [keras](http://keras.io/#getting-started-30-seconds-to-keras)
14 |
15 | You can use [experiment.sh](experiment.sh) to train the model and [install_dependencies.sh](install_dependencies.sh) to install the dependencies.
16 |
17 | Notes on results:
18 |
19 | 1. Images with 2-bit depth have been used for training as well as generation, i.e. every pixel is quantized into four levels before being fed to the model. A four-way softmax is used to predict the quantization level of each pixel (see the sketch at the end of this README).
20 |
21 | 2. The following results were obtained after 60 epochs of training, which completed in about 10 hours on a K6000 GPU. No hyperparameter search has been performed.
22 |
23 | Generated images
24 |
25 | ![Generated images](output/generated_only_images.jpg)
26 |
27 | Training images
28 |
29 | ![Training images](output/training_images.jpg)
30 |
31 |
32 |
33 |
34 | Salient features: no blind spots, efficient implementation of the vertical and horizontal stacks, residual connections, and good generation results :D
35 |
36 |
37 | For any comments/feedback, feel free to email me at kundankumar2510@gmail.com or open an issue here.
38 |
39 | TODO: Implement gated activation and conditional generation.
40 |
41 | If you have GPU resources, feel free to train on CIFAR10; I have provided a training script for that as well. Let me know how it goes.
42 | Also, one can train with a 256-way softmax and perform a hyperparameter search on the MNIST dataset.
43 |
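For concreteness, here is a minimal NumPy sketch of the quantization pipeline used by the training scripts (it mirrors `downscale_images`/`upscale_images` from `generic_utils.py`; the pixel values below are just an example):

```python
import numpy as np

Q_LEVELS = 4                                  # 2-bit quantization
pixels = np.array([0, 63, 64, 128, 255])      # raw 8-bit intensities

# downscale to [0, 1), then upscale to integer targets in {0, ..., Q_LEVELS - 1}
targets = np.uint8(Q_LEVELS * (pixels / 256.0))   # -> [0, 0, 1, 2, 3]
# the network inputs are the targets rescaled back into [0, 1]
inputs = targets / np.float32(Q_LEVELS - 1)       # -> [0., 0., 0.333, 0.667, 1.]
```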
--------------------------------------------------------------------------------
/cifar10_gen.py:
--------------------------------------------------------------------------------
1 | from keras.datasets import cifar10
2 | import numpy
3 | from generic_utils import *
4 | from models import Model
5 | from layers import WrapperLayer, pixelConv, Softmax
6 | import theano
7 | import theano.tensor as T
8 | import lasagne
9 | import random
10 | from plot_images import plot_25_figure
11 |
12 | DIM = 32
13 | GRAD_CLIP = 1.
14 | Q_LEVELS = 256
15 | BATCH_SIZE = 20
16 | PRINT_EVERY = 250
17 | EPOCH = 100
18 |
19 | OUT_DIR = '/Tmp/kumarkun/cifar10'
20 | create_folder_if_not_there(OUT_DIR)
21 |
22 | model = Model(name = "CIFAR10.pixelCNN")
23 |
24 |
25 | is_train = T.scalar() # unused
26 | X = T.tensor4('X') # shape: (batchsize, channels, height, width)
27 | X_r = T.itensor4('X_r')
28 |
29 | X_transformed = X_r.dimshuffle(0,2,3,1) # unused
30 | input_layer = WrapperLayer(X.dimshuffle(0,2,3,1)) # input reshaped to (batchsize, height, width, 3)
31 |
32 | pixel_CNN = pixelConv(
33 |     input_layer,
34 |     3,
35 |     DIM,
36 |     Q_LEVELS = Q_LEVELS,
37 |     name = model.name + ".pxCNN",
38 |     num_layers = 12,
39 | )
40 |
41 | model.add_layer(pixel_CNN)
42 |
43 | output_probab = Softmax(pixel_CNN).output()
44 |
45 | cost = T.nnet.categorical_crossentropy(
46 |     output_probab.reshape((-1, output_probab.shape[output_probab.ndim - 1])),
47 |     X_r.flatten()
48 | ).mean()
49 | # in nats
50 | output_image = sample_from_softmax(output_probab)
51 |
52 | model.print_params()
53 |
54 | params = model.get_params()
55 |
56 | grads = T.grad(cost, wrt=params, disconnected_inputs='warn')
57 | grads = [T.clip(g, floatX(-GRAD_CLIP), floatX(GRAD_CLIP)) for g in grads]
58 |
59 | # learning_rate = T.scalar('learning_rate')
60 |
61 | updates = lasagne.updates.adam(grads, params, learning_rate = 1e-3)
62 |
63 | train_fn = theano.function([X, X_r], cost, updates = updates)
64 |
65 | valid_fn = theano.function([X, X_r], cost)
66 |
67 | generate_routine = theano.function([X], output_image)
68 |
69 | def generate_fn(generate_routine, HEIGHT, WIDTH, num):
70 |     X = floatX(numpy.zeros((num, 3, HEIGHT, WIDTH)))
71 |     out = numpy.zeros((num, HEIGHT, WIDTH, 3))
72 |
73 |     for i in range(HEIGHT):
74 |         for j in range(WIDTH):
75 |             samples = generate_routine(X)
76 |             out[:,i,j] = samples[:,i,j]
77 |             X[:,:,i,j] = downscale_images(samples[:,i,j,:], Q_LEVELS - 1)
78 |
79 |     return out
80 |
81 | (X_train_r, _), (X_test_r, _) = cifar10.load_data()
82 |
83 | X_train_r = upscale_images(downscale_images(X_train_r, 256), Q_LEVELS)
84 | X_test_r = upscale_images(downscale_images(X_test_r, 256), Q_LEVELS)
85 |
86 | X_train = downscale_images(X_train_r, Q_LEVELS - 1)
87 | X_test = downscale_images(X_test_r, Q_LEVELS - 1)
88 |
89 | errors = {'training' : [], 'validation' : []}
90 |
91 | num_iters = 0
92 | # init_learning_rate = floatX(0.001)
93 |
94 | print("Training")
95 | for i in range(EPOCH):
96 |     """Training"""
97 |     costs = []
98 |     num_batch_train = len(X_train)//BATCH_SIZE
99 |     for j in range(num_batch_train):
100 |
101 |         cost = train_fn(
102 |             X_train[j*BATCH_SIZE: (j+1)*BATCH_SIZE],
103 |             X_train_r[j*BATCH_SIZE: (j+1)*BATCH_SIZE]
104 |         )
105 |
106 |         costs.append(cost)
107 |
108 |         num_iters += 1
109 |
110 |         if (j+1) % PRINT_EVERY == 0:
111 |             print("Training: epoch {}, iter {}, cost {}".format(i, j+1, numpy.mean(costs)))
112 |
113 |     print("Training cost for epoch {}: {}".format(i+1, numpy.mean(costs)))
114 |     errors['training'].append(numpy.mean(costs))
115 |
116 |     costs = []
117 |     num_batch_valid = len(X_test)//BATCH_SIZE
118 |
119 |     for j in range(num_batch_valid):
120 |         cost = valid_fn(
121 |             X_test[j*BATCH_SIZE: (j+1)*BATCH_SIZE],
122 |             X_test_r[j*BATCH_SIZE: (j+1)*BATCH_SIZE]
123 |         )
124 |         costs.append(cost)
125 |
126 |         if (j+1) % PRINT_EVERY == 0:
127 |             print("Validation: epoch {}, iter {}, cost {}".format(i, j+1, numpy.mean(costs)))
128 |
129 |     model.save_params('{}/epoch_{}_val_error_{}.pkl'.format(OUT_DIR, i, numpy.mean(costs)))
130 |
131 |     X = generate_fn(generate_routine, 32, 32, 25)
132 |
133 |     reconstruction = generate_routine(X_test[:25]) # one-step-ahead predictions for the first 25 test images
134 |
135 |     plot_25_figure(X, '{}/epoch_{}_val_error_{}_gen_images.jpg'.format(OUT_DIR, i, numpy.mean(costs)), num_channels = 3)
136 |     plot_25_figure(reconstruction, '{}/epoch_{}_reconstructed.jpg'.format(OUT_DIR, i), num_channels = 3)
137 |
138 |     print("Validation cost after epoch {}: {}".format(i+1, numpy.mean(costs)))
139 |     errors['validation'].append(numpy.mean(costs))
140 |
141 |     if i % 2 == 0:
142 |         save(errors, '{}/epoch_{}_NLL.pkl'.format(OUT_DIR, i))
143 |
144 |
145 |
146 |
147 |
148 |
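The `# in nats` comment above matters when comparing against published numbers: `categorical_crossentropy` uses the natural logarithm, while pixelCNN results are usually reported in bits per dimension. A minimal sketch of the conversion (the cost value is hypothetical):

```python
import numpy

nll_nats = 3.1                              # hypothetical mean cost per pixel-channel
bits_per_dim = nll_nats / numpy.log(2.0)    # ~4.47 bits/dim
```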
("Training: epoch {}, iter {}, cost {}".format(i,j+1,numpy.mean(costs))) 112 | 113 | print("Training cost for epoch {}: {}".format(i+1, numpy.mean(costs))) 114 | errors['training'].append(numpy.mean(costs)) 115 | 116 | costs = [] 117 | num_batch_valid = len(X_test)//BATCH_SIZE 118 | 119 | for j in range(num_batch_valid): 120 | cost = valid_fn( 121 | X_test[j*BATCH_SIZE: (j+1)*BATCH_SIZE], 122 | X_test_r[j*BATCH_SIZE: (j+1)*BATCH_SIZE] 123 | ) 124 | costs.append(cost) 125 | 126 | if (j+1) % PRINT_EVERY == 0: 127 | print ("Validation: epoch {}, iter {}, cost {}".format(i,j+1,numpy.mean(costs))) 128 | 129 | model.save_params('{}/epoch_{}_val_error_{}.pkl'.format(OUT_DIR,i, numpy.mean(costs))) 130 | 131 | X = generate_fn(generate_routine, 32, 32, 25) 132 | 133 | reconstruction = generate_routine(X_test[:25]) 134 | 135 | plot_25_figure(X, '{}/epoch_{}_val_error_{}_gen_images.jpg'.format(OUT_DIR, i, numpy.mean(costs)), num_channels = 3) 136 | plot_25_figure(reconstruction, '{}/epoch_{}_reconstructed.jpg'.format(OUT_DIR, i), num_channels = 3) 137 | 138 | print("Validation cost after epoch {}: {}".format(i+1, numpy.mean(costs))) 139 | errors['validation'].append(numpy.mean(costs)) 140 | 141 | if i % 2 == 0: 142 | save(errors, '{}/epoch_{}_NLL.pkl'.format(OUT_DIR, i)) 143 | 144 | 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /experiment.sh: -------------------------------------------------------------------------------- 1 | THEANO_FLAGS=mode=FAST_RUN,device=gpu0,floatX=float32 python mnist_gen.py 2 | -------------------------------------------------------------------------------- /generic_utils.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import os 4 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 5 | import pickle 6 | import json 7 | import numpy 8 | 9 | srng = RandomStreams(seed=4884) 10 | def create_folder_if_not_there(folder): 11 | if not os.path.exists(folder): 12 | os.makedirs(folder) 13 | print "Created folder {}".format(folder) 14 | 15 | def floatX(num): 16 | if theano.config.floatX == 'float32': 17 | return numpy.float32(num) 18 | else: 19 | raise Exception("{} type not supported".format(theano.config.floatX)) 20 | 21 | 22 | def downscale_images(X, LEVEL): 23 | X = floatX(X)/floatX(LEVEL) 24 | return X 25 | 26 | def upscale_images(X, LEVEL): 27 | X = numpy.uint8(X*LEVEL) 28 | return X 29 | 30 | def stochastic_binarize(X): 31 | return (numpy.random.uniform(size=X.shape) < X).astype('float32') 32 | 33 | def sample_from_softmax(softmax_var): 34 | #softmax_var assumed to be of shape (batch_size, num_classes) 35 | old_shape = softmax_var.shape 36 | 37 | softmax_var_reshaped = softmax_var.reshape((-1,softmax_var.shape[softmax_var.ndim-1])) 38 | 39 | return T.argmax( 40 | T.cast( 41 | srng.multinomial(pvals=softmax_var_reshaped), 42 | theano.config.floatX 43 | ).reshape(old_shape), 44 | axis = softmax_var.ndim-1 45 | ) 46 | 47 | 48 | # 49 | def Skew(inputs, WIDTH, HEIGHT): 50 | """ 51 | input.shape: (batch size, HEIGHT, WIDTH, num_channels) 52 | """ 53 | buf = T.zeros( 54 | (inputs.shape[0], inputs.shape[1], 2*inputs.shape[2] - 1, inputs.shape[3]), 55 | theano.config.floatX 56 | ) 57 | 58 | for i in xrange(HEIGHT): 59 | buf = T.inc_subtensor(buf[:, i, i:i+WIDTH, :], inputs[:,i,:,:]) 60 | 61 | return buf 62 | 63 | def Unskew(padded, WIDTH, HEIGHT): 64 | """ 65 | input.shape: (batch size, HEIGHT, 2*WIDTH - 1, num_channels) 66 
| """ 67 | return T.stack([padded[:, i, i:i+WIDTH, :] for i in xrange(HEIGHT)], axis=1) 68 | 69 | def new_learning_time_decay(learning_rate, iter_num, k): 70 | return floatX(learning_rate/(1.0+ iter_num*k)) 71 | 72 | # 73 | 74 | def load(file_name): 75 | open_file = open(file_name, 'rb') 76 | if ".json" in file_name: 77 | obj = json.load(open_file) 78 | elif ".pkl" in file_name: 79 | obj = pickle.load(open_file) 80 | open_file.close() 81 | return obj 82 | 83 | def save(obj, file_name): 84 | open_file = open(file_name, 'wb') 85 | if ".json" in file_name: 86 | json.dump(obj,open_file) 87 | elif ".pkl" in file_name: 88 | pickle.dump(obj, open_file) 89 | open_file.close() 90 | 91 | -------------------------------------------------------------------------------- /install_dependencies.sh: -------------------------------------------------------------------------------- 1 | pip install theano --user 2 | pip install keras --user 3 | pip install lasagne --user 4 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy 4 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 5 | from theano.sandbox.cuda.dnn import dnn_conv 6 | from generic_utils import * 7 | 8 | srng = RandomStreams(seed=3732) 9 | T.nnet.relu = lambda x: T.switch(x > floatX(0.), x, floatX(0.00001)*x) #this helps avoid Nan 10 | 11 | 12 | def uniform(stdev, size): 13 | """uniform distribution with the given stdev and size""" 14 | return numpy.random.uniform( 15 | low=-stdev * numpy.sqrt(3), 16 | high=stdev * numpy.sqrt(3), 17 | size=size 18 | ).astype(theano.config.floatX) 19 | 20 | def bias_weights(length, initialization='zeros', param_list = None, name = ""): 21 | "theano shared variable for bias unit, given length and initialization" 22 | if initialization == 'zeros': 23 | bias_initialization = numpy.zeros(length).astype(theano.config.floatX) 24 | else: 25 | raise Exception("Not Implemented Error: {} initialization not implemented".format(initialization)) 26 | 27 | bias = theano.shared( 28 | bias_initialization, 29 | name=name 30 | ) 31 | if param_list is not None: 32 | param_list.append(bias) 33 | 34 | return bias 35 | 36 | ''' 37 | get_conv_2d_filter: Takes a filter_shape (a tuple/array of length 4) and returns corresponding convolution filter 38 | masktype is optional. 39 | param_list is mandatory. 
76 | class Layer:
77 |     '''
78 |     Generic Layer template which all layers should inherit.
79 |     Every layer should have a name and a params attribute containing all
80 |     trainable parameters for that layer.
81 |     '''
82 |     def __init__(self, name = ""):
83 |         self.name = name
84 |         self.params = []
85 |
86 |     def get_params(self):
87 |         return self.params
88 |
89 |
90 | class Conv2D(Layer):
91 |     """
92 |     Basic convolution layer
93 |
94 |     input_shape: (batch_size, input_channels, height, width)
95 |     filter_size: int or (row, column)
96 |
97 |     """
98 |     def __init__(self, input_layer, input_channels, output_channels, filter_size, subsample = (1,1), border_mode='half', masktype = None, activation = None, name = ""):
99 |         self.X = input_layer.output()
100 |         self.name = name
101 |         self.subsample = subsample
102 |         self.border_mode = border_mode
103 |
104 |         self.params = []
105 |
106 |         if isinstance(filter_size, tuple):
107 |             self.filter_shape = (output_channels, input_channels, filter_size[0], filter_size[1])
108 |         else:
109 |             self.filter_shape = (output_channels, input_channels, filter_size, filter_size)
110 |
111 |         self.filter = get_conv_2d_filter(self.filter_shape, param_list = self.params, masktype = masktype, name=name+'.filter')
112 |
113 |         self.bias = bias_weights((output_channels,), param_list = self.params, name = name+'.b')
114 |
115 |         self.activation = activation
116 |
117 |
118 |         conv_out = T.nnet.conv2d(self.X, self.filter, border_mode = self.border_mode, filter_flip=False) # filter_flip=False: cross-correlation rather than convolution
119 |         self.Y = conv_out + self.bias[None,:,None,None]
120 |         if self.activation is not None:
121 |             if self.activation == 'relu':
122 |                 self.Y = T.nnet.relu(self.Y)
123 |             elif self.activation == 'tanh':
124 |                 self.Y = T.tanh(self.Y)
125 |             else:
126 |                 raise Exception("Not Implemented Error: {} activation not allowed".format(activation))
127 |
128 |     def output(self):
129 |         return self.Y
130 |
131 |
132 | class pixelConv(Layer):
133 |     """
134 |     This layer implements the pixelCNN module described in https://arxiv.org/abs/1606.05328
135 |     Main difference: the activation is not gated in this implementation.
136 |
137 |     Masking is not used except for the first horizontal stack. Instead, an appropriate filter size
138 |     combined with appropriate shifting of the output feature maps is used to get the same effect as masking. This
139 |     is described in detail in the second paragraph of section 2.2 of the paper. There are no blind spots.
140 |
141 |     There are four convolutions per layer, as described in figure 2 of the paper. The left output and input in that figure correspond
142 |     to the vertical feature map, and the right output/input correspond to the horizontal feature map. A residual
143 |     connection is added on the horizontal stack.
144 |
145 |     input_shape: (batch_size, height, width, input_dim)
146 |     output_shape: (batch_size, height, width, input_dim)
147 |         when Q_LEVELS is None
148 |     else
149 |         (batch_size, height, width, input_dim, Q_LEVELS)
150 |     """
151 |     def __init__(self, input_layer, input_dim, DIM, Q_LEVELS = None, num_layers = 6, activation='relu', name=""):
152 |
153 |         if activation is None:
154 |             apply_act = lambda r: r
155 |         elif activation == 'relu':
156 |             apply_act = T.nnet.relu
157 |         elif activation == 'tanh':
158 |             apply_act = T.tanh
159 |         else:
160 |             raise Exception("{} activation not implemented!!".format(activation))
161 |
162 |
163 |         self.X = input_layer.output().dimshuffle(0,3,1,2)
164 |
165 |         ''' for the first layer, the filter size should be 7 x 7 '''
166 |         filter_size = 7 # for first layer
167 |
168 |         '''
169 |         The effect of a masked filter_size x filter_size convolution for the vertical stack can be achieved
170 |         by convolving the image with a ((filter_size // 2) + 1, filter_size) filter, after
171 |         padding the image with filter_size // 2 + 1 rows of 0s at the top and bottom and filter_size // 2 columns of 0s on both sides.
172 |
173 |         It is easy to see that in this case the first row of the output does not depend on the image at all,
174 |         the second row depends only on the first row of the image, and so on. The net effect is that the i'th row
175 |         of the output uses information only up to the (i-1)'th row of the input.
176 |
177 |         '''
178 |         vertical_stack = Conv2D(
179 |             WrapperLayer(self.X),
180 |             input_dim,
181 |             DIM,
182 |             ((filter_size // 2) + 1, filter_size),
183 |             masktype=None,
184 |             border_mode=(filter_size // 2 + 1, filter_size // 2),
185 |             name= name + ".vstack1",
186 |             activation = None
187 |         )
188 |
189 |         out_v = vertical_stack.output()
190 |
191 |         '''
192 |         While generating the i'th row we can only use information up to the (i-1)'th row in the vertical stack.
193 |         The horizontal stack gets input from the vertical stack as well as from the previous layer.
194 |
195 |         '''
196 |         vertical_and_input_stack = T.concatenate([out_v[:,:,:-(filter_size//2)-2,:], self.X], axis=1)
197 |
198 |         '''The horizontal stack is straightforward. For the first layer, I have used masked convolution, as
199 |         we are not allowed to see the pixel we are about to generate.
200 |
201 |         '''
202 |
203 |         horizontal_stack = Conv2D(
204 |             WrapperLayer(vertical_and_input_stack),
205 |             input_dim+DIM, DIM,
206 |             (1, filter_size),
207 |             border_mode = (0, filter_size//2),
208 |             masktype='a',
209 |             name = name + ".hstack1",
210 |             activation = None
211 |         )
212 |
213 |         self.params = vertical_stack.params + horizontal_stack.params
214 |
215 |         X_h = horizontal_stack.output() # horizontal stack output
216 |         X_v = out_v[:,:,1:-(filter_size//2) - 1,:] # vertical stack output
217 |
218 |         filter_size = 3 # all layers beyond the first have an effective filter size of 3
219 |
220 |         '''
221 |         One run of the loop implements the four convolutions shown in figure 2 of the paper,
222 |         with a residual connection added on the horizontal stack.
223 |
224 |         Two of the convolutions are just linear transformations of the feature maps, as their filter size is (1,1).
225 |         '''
226 |         for i in range(num_layers - 2):
227 |             vertical_stack = Conv2D(
228 |                 WrapperLayer(X_v),
229 |                 DIM,
230 |                 DIM,
231 |                 ((filter_size // 2) + 1, filter_size),
232 |                 masktype = None,
233 |                 border_mode = (filter_size // 2 + 1, filter_size // 2),
234 |                 name= name + ".vstack{}".format(i+2),
235 |                 activation = None
236 |             )
237 |             v2h = Conv2D(
238 |                 vertical_stack,
239 |                 DIM,
240 |                 DIM,
241 |                 (1,1),
242 |                 masktype = None,
243 |                 border_mode = 'valid',
244 |                 name= name + ".v2h{}".format(i+2),
245 |                 activation = None
246 |             )
247 |             out_v = v2h.output()
248 |             vertical_and_prev_stack = T.concatenate([out_v[:,:,:-(filter_size//2)-2,:], X_h], axis=1)
249 |
250 |             horizontal_stack = Conv2D(
251 |                 WrapperLayer(vertical_and_prev_stack),
252 |                 DIM*2,
253 |                 DIM,
254 |                 (1, (filter_size // 2) + 1),
255 |                 border_mode = (0, filter_size // 2),
256 |                 masktype = None,
257 |                 name = name + ".hstack{}".format(i+2),
258 |                 activation = activation
259 |             )
260 |
261 |             h2h = Conv2D(
262 |                 horizontal_stack,
263 |                 DIM,
264 |                 DIM,
265 |                 (1, 1),
266 |                 border_mode = 'valid',
267 |                 masktype = None,
268 |                 name = name + ".h2hstack{}".format(i+2),
269 |                 activation = activation
270 |             )
271 |
272 |             self.params += (vertical_stack.params + horizontal_stack.params + v2h.params + h2h.params)
273 |
274 |             X_v = apply_act(vertical_stack.output()[:,:,1:-(filter_size//2) - 1,:])
275 |             X_h = h2h.output()[:,:,:,:-(filter_size//2)] + X_h # residual connection added
276 |
277 |         '''fully connected output layers, implemented as (1, 1) convolutions'''
278 |
279 |         combined_stack1 = Conv2D(
280 |             WrapperLayer(X_h),
281 |             DIM,
282 |             DIM,
283 |             (1, 1),
284 |             masktype = None,
285 |             border_mode = 'valid',
286 |             name=name+".combined_stack1",
287 |             activation = activation
288 |         )
289 |
290 |         if Q_LEVELS is None:
291 |             out_dim = input_dim
292 |         else:
293 |             out_dim = input_dim*Q_LEVELS
294 |
295 |         combined_stack2 = Conv2D(
296 |             combined_stack1,
297 |             DIM,
298 |             out_dim,
299 |             (1, 1),
300 |             masktype = None,
301 |             border_mode = 'valid',
302 |             name=name+".combined_stack2",
303 |             activation = None
304 |         )
305 |
306 |         self.params += (combined_stack1.params + combined_stack2.params)
307 |
308 |         pre_final_out = combined_stack2.output().dimshuffle(0,2,3,1)
309 |
310 |         if Q_LEVELS is None:
311 |             self.Y = pre_final_out
312 |         else:
313 |             # pre_final_out = pre_final_out.dimshuffle(0,1,2,3,'x')
314 |             old_shape = pre_final_out.shape
315 |             self.Y = pre_final_out.reshape((old_shape[0], old_shape[1], old_shape[2], old_shape[3]//Q_LEVELS, -1))
316 |
317 |     def output(self):
318 |         return self.Y
319 |
320 | class WrapperLayer(Layer):
321 |     def __init__(self, X, name=""):
322 |         self.params = []
323 |         self.name = name
324 |         self.X = X
325 |
326 |     def output(self):
327 |         return self.X
328 |
329 | class Softmax(Layer):
330 |     def __init__(self, input_layer, name=""):
331 |         self.input_layer = input_layer
332 |         self.name = name
333 |         self.params = []
334 |         self.X = self.input_layer.output()
335 |         self.input_shape = self.X.shape
336 |
337 |     def output(self):
338 |         return T.nnet.softmax(self.X.reshape((-1, self.input_shape[self.X.ndim-1]))).reshape(self.input_shape)
339 |
340 |
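Since the docstring above claims causality without masks (beyond the first horizontal stack), a quick way to convince yourself is a receptive-field probe: perturb one input pixel and check which outputs change. A standalone sketch along those lines (not part of the repository; the variable names are illustrative):

```python
import numpy
import theano
import theano.tensor as T
from layers import WrapperLayer, pixelConv

X = T.tensor4()
net = pixelConv(WrapperLayer(X), 1, 8, Q_LEVELS=None, num_layers=4, name="probe")
f = theano.function([X], net.output())

base = numpy.zeros((1, 7, 7, 1), dtype=theano.config.floatX)
poke = base.copy()
poke[0, 3, 3, 0] = 1.                                   # perturb input pixel (3, 3)
changed = numpy.abs(f(poke) - f(base)).sum(axis=(0, 3)) > 1e-6

# outputs strictly above row 3, and in row 3 up to and including column 3, must
# not change: output (i, j) may not depend on input (i, j) or any later pixel
assert not changed[:3, :].any()
assert not changed[3, :4].any()
```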
--------------------------------------------------------------------------------
/mnist_gen.py:
--------------------------------------------------------------------------------
1 | from keras.datasets import mnist # for loading the dataset
2 | import numpy
3 | from generic_utils import * # some utility functions
4 | from models import Model # class for collecting parameters from layers, among other things
5 | from layers import WrapperLayer, pixelConv, Softmax # module that implements the various layers
6 | import theano
7 | import theano.tensor as T
8 | import lasagne
9 | import random
10 | from plot_images import plot_25_figure, plot_100_figure # visualization of images
11 | from sys import argv
12 |
13 | DIM = 32 # number of feature maps in the convolution layers
14 | GRAD_CLIP = 1.
15 | Q_LEVELS = 4 # number of quantization levels per channel; must be between 2 and 256
16 | TRAIN_BATCH_SIZE = 100
17 | VALIDATE_BATCH_SIZE = 200
18 | PRINT_EVERY = 100 # number of iterations after which stats are printed
19 | EPOCH = 1000
20 |
21 | PRETRAINED = True # whether to use pre-trained weights; if True, argv[1] is assumed to be the pickled weights of a pre-trained model
22 | GENERATE_ONLY = True # whether to only generate samples; useful when you want to generate from a pre-trained model
23 |
24 | OUT_DIR = '/Tmp/kumarkun/mnist_new_samples' # output folder for storing weights, generated and reconstructed samples
25 | create_folder_if_not_there(OUT_DIR)
26 |
27 |
28 | # Creating model.
29 |
30 | model = Model(name = "MNIST.pixelCNN")
31 |
32 |
33 | X = T.tensor3('X') # shape: (batchsize, height, width)
34 | X_r = T.itensor3('X_r') # shape: (batchsize, height, width)
35 |
36 | input_layer = WrapperLayer(X.dimshuffle(0,1,2,'x')) # input reshaped to (batchsize, height, width, 1)
37 |
38 | pixel_CNN = pixelConv(
39 |     input_layer,
40 |     1,
41 |     DIM,
42 |     name = model.name + ".pxCNN",
43 |     num_layers = 12,
44 |     Q_LEVELS = Q_LEVELS
45 | )
46 |
47 | model.add_layer(pixel_CNN)
48 |
49 | output_probab = Softmax(pixel_CNN).output()
50 |
51 | # in nats
52 | cost = T.nnet.categorical_crossentropy(
53 |     output_probab.reshape((-1, output_probab.shape[output_probab.ndim - 1])),
54 |     X_r.flatten()
55 | ).mean()
56 |
57 | output_image = sample_from_softmax(output_probab)
58 |
59 |
60 |
61 | model.print_params()
62 |
63 | params = model.get_params()
64 |
65 | grads = T.grad(cost, wrt=params, disconnected_inputs='warn')
66 | grads = [T.clip(g, floatX(-GRAD_CLIP), floatX(GRAD_CLIP)) for g in grads]
67 |
68 |
69 |
70 | updates = lasagne.updates.adam(grads, params, learning_rate = 1e-3)
71 |
72 | train_fn = theano.function([X, X_r], cost, updates = updates)
73 |
74 | valid_fn = theano.function([X, X_r], cost)
75 |
76 | generate_routine = theano.function([X], output_image)
77 |
78 | def generate_fn(generate_routine, HEIGHT, WIDTH, num):
79 |     X = floatX(numpy.zeros((num, HEIGHT, WIDTH)))
80 |     for i in range(HEIGHT):
81 |         for j in range(WIDTH):
82 |             samples = generate_routine(X)
83 |             X[:,i,j] = downscale_images(samples[:,i,j,0], Q_LEVELS-1)
84 |
85 |     return X
86 |
87 |
88 | if PRETRAINED:
89 |     model.load_params(argv[1])
90 |
91 |
92 | (X_train_r, _), (X_test_r, _) = mnist.load_data()
93 |
94 | '''
95 | First, downscale images from 0-255 to [0,1), then upscale to 0-(Q_LEVELS-1).
96 | This quantized image becomes the target.
97 | Targets are again downscaled to [0,1] to get the inputs.
98 | '''
99 |
100 | X_train_r = upscale_images(downscale_images(X_train_r, 256), Q_LEVELS)
101 | X_test_r = upscale_images(downscale_images(X_test_r, 256), Q_LEVELS)
102 |
103 | X_train = downscale_images(X_train_r, Q_LEVELS - 1)
104 | X_test = downscale_images(X_test_r, Q_LEVELS - 1)
105 |
106 |
107 | if GENERATE_ONLY:
108 |     X = generate_fn(generate_routine, 28, 28, 100)
109 |     plot_100_figure(X, '{}/generated_only_images.jpg'.format(OUT_DIR))
110 |     j = int(random.random() * (len(X_train) - 101))
111 |     plot_100_figure(X_train[j : j + 100], '{}/training_images.jpg'.format(OUT_DIR))
112 |     exit()
113 |
114 |
115 |
116 | errors = {'training' : [], 'validation' : []}
117 |
118 | num_iters = 0
119 |
120 | def validate():
121 |     costs = []
122 |     BATCH_SIZE = VALIDATE_BATCH_SIZE
123 |     num_batch_valid = len(X_test)//BATCH_SIZE
124 |
125 |     for j in range(num_batch_valid):
126 |         cost = valid_fn(X_test[j*BATCH_SIZE: (j+1)*BATCH_SIZE], X_test_r[j*BATCH_SIZE: (j+1)*BATCH_SIZE])
127 |         costs.append(cost)
128 |
129 |     return numpy.mean(costs)
130 |
131 |
132 | print("Training")
133 | for i in range(EPOCH):
134 |     """Training"""
135 |     costs = []
136 |     BATCH_SIZE = TRAIN_BATCH_SIZE
137 |     num_batch_train = len(X_train)//BATCH_SIZE
138 |     for j in range(num_batch_train):
139 |
140 |         cost = train_fn(X_train[j*BATCH_SIZE: (j+1)*BATCH_SIZE], X_train_r[j*BATCH_SIZE: (j+1)*BATCH_SIZE])
141 |
142 |         costs.append(cost)
143 |
144 |         num_iters += 1
145 |
146 |         if (j+1) % PRINT_EVERY == 0:
147 |             print("Training: epoch {}, iter {}, cost {}".format(i, j+1, numpy.mean(costs)))
148 |
149 |     print("Training cost for epoch {}: {}".format(i+1, numpy.mean(costs)))
150 |     errors['training'].append(numpy.mean(costs))
151 |
152 |     val_error = validate()
153 |     errors['validation'].append(val_error)
154 |
155 |     model.save_params('{}/epoch_{}_val_error_{}.pkl'.format(OUT_DIR, i, val_error)) # parameters are saved after every epoch
156 |
157 |     X = generate_fn(generate_routine, 28, 28, 25) # 25 images are generated after every epoch
158 |
159 |     reconstruction = generate_routine(X_test[:25])[:,:,:,0] # one-step-ahead predictions for the first 25 test images
160 |
161 |     plot_25_figure(X, '{}/epoch_{}_val_error_{}_gen_images.jpg'.format(OUT_DIR, i, val_error))
162 |     plot_25_figure(reconstruction, '{}/epoch_{}_reconstructed.jpg'.format(OUT_DIR, i))
163 |
164 |     print("Validation cost after epoch {}: {}".format(i+1, val_error))
165 |
166 |     if i % 2 == 0:
167 |         save(errors, '{}/epoch_{}_NLL.pkl'.format(OUT_DIR, i)) # NLL up to the i'th epoch is stored after every 2 epochs. Too much redundancy here.
168 |
169 |
170 |
171 |
172 |
173 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy
3 |
4 | class Model:
5 |     def __init__(self, name=""):
6 |         self.name = name
7 |         self.layers = []
8 |         self.params = []
9 |
10 |     def add_layer(self, layer):
11 |         self.layers.append(layer)
12 |         for p in layer.params:
13 |             self.params.append(p)
14 |
15 |     def print_layers(self):
16 |         for layer in self.layers:
17 |             print(layer.name)
18 |
19 |     def get_params(self):
20 |         return self.params
21 |
22 |     def print_params(self):
23 |         total_params = 0
24 |         for p in self.params:
25 |             curr_params = numpy.prod(numpy.shape(p.get_value()))
26 |             total_params += curr_params
27 |             print("{} ({})".format(p.name, curr_params))
28 |         print("total number of parameters: {}".format(total_params))
29 |         print("Note: Effective number of parameters might be less due to masking!!")
30 |
31 |     def save_params(self, file_name):
32 |         params = {}
33 |         for p in self.params:
34 |             params[p.name] = p.get_value()
35 |         pickle.dump(params, open(file_name, 'wb'))
36 |
37 |     def load_params(self, file_name):
38 |         params = pickle.load(open(file_name, 'rb'))
39 |         for p in self.params:
40 |             p.set_value(params[p.name])
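Parameters are saved and restored purely by name, so the architecture has to be rebuilt identically before `load_params` is called. A minimal sketch of the intended round trip (the file path is hypothetical):

```python
import theano.tensor as T
from models import Model
from layers import WrapperLayer, pixelConv

X = T.tensor3('X')
model = Model(name="MNIST.pixelCNN")
model.add_layer(pixelConv(WrapperLayer(X.dimshuffle(0, 1, 2, 'x')),
                          1, 32, Q_LEVELS=4, num_layers=12,
                          name=model.name + ".pxCNN"))

model.save_params('/tmp/pixelcnn_params.pkl')  # pickles {param_name: value}
model.load_params('/tmp/pixelcnn_params.pkl')  # names must match exactly
```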
--------------------------------------------------------------------------------
/output/generated_only_images.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kundan2510/pixelCNN/2cbb5a2db9260d3f9aa0edcc4b49c874b9ac3067/output/generated_only_images.jpg
--------------------------------------------------------------------------------
/output/training_images.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kundan2510/pixelCNN/2cbb5a2db9260d3f9aa0edcc4b49c874b9ac3067/output/training_images.jpg
--------------------------------------------------------------------------------
/plot_images.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import scipy.misc
3 | import numpy as np
4 | from sys import argv
5 |
6 | def plot_25_figure(images, output_name, num_channels = 1):
7 |     HEIGHT, WIDTH = images.shape[1], images.shape[2]
8 |     if num_channels == 1:
9 |         images = images.reshape((5,5,HEIGHT,WIDTH))
10 |         # rowx, rowy, height, width -> rowy, height, rowx, width
11 |         images = images.transpose(1,2,0,3)
12 |         images = images.reshape((5*HEIGHT, 5*WIDTH)) # was hard-coded to 5*28; using HEIGHT/WIDTH works for non-MNIST sizes too
13 |         scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(output_name)
14 |     elif num_channels == 3:
15 |         images = images.reshape((5,5,HEIGHT,WIDTH,3))
16 |         images = images.transpose(1,2,0,3,4)
17 |         images = images.reshape((5*HEIGHT, 5*WIDTH, 3))
18 |         scipy.misc.toimage(images).save(output_name)
19 |     else:
20 |         raise Exception("You should not be here!! Only 1 or 3 channels allowed for images!!")
21 |
22 |
23 | def plot_100_figure(images, output_name, num_channels = 1):
24 |     HEIGHT, WIDTH = images.shape[1], images.shape[2]
25 |     if num_channels == 1:
26 |         images = images.reshape((10,10,HEIGHT,WIDTH))
27 |         # rowx, rowy, height, width -> rowy, height, rowx, width
28 |         images = images.transpose(1,2,0,3)
29 |         images = images.reshape((10*HEIGHT, 10*WIDTH)) # was hard-coded to 10*28
30 |         scipy.misc.toimage(images, cmin=0.0, cmax=1.0).save(output_name)
31 |     elif num_channels == 3:
32 |         images = images.reshape((10,10,HEIGHT,WIDTH,3))
33 |         images = images.transpose(1,2,0,3,4)
34 |         images = images.reshape((10*HEIGHT, 10*WIDTH, 3))
35 |         scipy.misc.toimage(images).save(output_name)
36 |     else:
37 |         raise Exception("You should not be here!! Only 1 or 3 channels allowed for images!!")
38 |
39 |
40 | if __name__ == "__main__":
41 |     X = pickle.load(open(argv[1],'rb'))
42 |     output_name = argv[1].split('/')[-1].split('.')[0] + '.jpg'
43 |     plot_25_figure(X, output_name)
44 |
--------------------------------------------------------------------------------