├── .gitignore
├── setup.sh
├── LICENSE
├── tensorflow
│   ├── tools.py
│   ├── rot_mnist12K_model.py
│   └── rot_mnist12K.py
├── torch
│   ├── rot_mnist12K_model.lua
│   ├── tools.lua
│   └── rot_mnist12K.lua
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | *.amat
2 | 
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | wget http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip
3 | unzip mnist_rotation_new.zip
4 | rm mnist_rotation_new.zip
5 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 | 
3 | Copyright (c) 2016 Dmitry Laptev http://laptev.ch, Nikolay Savinov
4 | http://people.inf.ethz.ch/nsavinov/, ETH Zurich http://ethz.ch.
5 | 
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 | 
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 | 
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE. 
23 | -------------------------------------------------------------------------------- /tensorflow/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | from scipy.ndimage.interpolation import rotate 4 | 5 | class DataLoader: 6 | def __init__(self, 7 | name, 8 | number_of_classes, 9 | number_of_transformations, 10 | loaded_size, 11 | desired_size, 12 | max_size=None): 13 | loaded = np.loadtxt(name) 14 | if max_size is not None: 15 | subset = np.random.choice(loaded.shape[0], max_size, replace=False) 16 | loaded = loaded[subset, :] 17 | padded_x = self._pad(loaded[:, :-1], loaded_size, desired_size) 18 | self._x = self._transform(padded_x, number_of_transformations) 19 | self._y = self._int_labels_to_one_hot(loaded[:, -1], number_of_classes) 20 | self._completed_epochs = -1 21 | self._new_epoch = False 22 | self._start_new_epoch() 23 | 24 | def _pad(self, loaded_x, loaded_size, desired_size): 25 | padding_size = (desired_size - loaded_size) / 2 26 | padding_list = [[0, 0], 27 | [padding_size, padding_size], 28 | [padding_size, padding_size], 29 | [0, 0]] 30 | return np.pad(np.reshape(loaded_x, [-1, loaded_size, loaded_size, 1]), 31 | padding_list, 32 | 'constant', 33 | constant_values=0) 34 | 35 | def _transform(self, padded, number_of_transformations): 36 | tiled = np.tile(np.expand_dims(padded, 4), [number_of_transformations]) 37 | for transformation_index in xrange(number_of_transformations): 38 | angle = 360.0 * transformation_index / float(number_of_transformations) 39 | tiled[:, :, :, :, transformation_index] = rotate( 40 | tiled[:, :, :, :, transformation_index], 41 | angle, 42 | axes=[1, 2], 43 | reshape=False) 44 | print('finished transforming') 45 | return tiled 46 | 47 | def _int_labels_to_one_hot(self, int_labels, number_of_classes): 48 | offsets = np.arange(self._size()) * number_of_classes 49 | one_hot_labels = np.zeros((self._size(), number_of_classes)) 50 | flat_iterator = one_hot_labels.flat 51 | for index in xrange(self._size()): 52 | flat_iterator[offsets[index] + int(int_labels[index])] = 1 53 | return one_hot_labels 54 | 55 | def _size(self): 56 | return self._x.shape[0] 57 | 58 | def _start_new_epoch(self): 59 | permuted_indexes = np.random.permutation(self._size()) 60 | self._x = self._x[permuted_indexes, :] 61 | self._y = self._y[permuted_indexes] 62 | self._completed_epochs += 1 63 | self._index = 0 64 | self._new_epoch = True 65 | 66 | def get_completed_epochs(self): 67 | return self._completed_epochs 68 | 69 | def is_new_epoch(self): 70 | return self._new_epoch 71 | 72 | def next_batch(self, batch_size): 73 | if (self._new_epoch): 74 | self._new_epoch = False 75 | start = self._index 76 | end = start + batch_size 77 | if (end > self._size()): 78 | assert batch_size <= self._size() 79 | self._start_new_epoch() 80 | start = 0 81 | end = start + batch_size 82 | self._index += batch_size 83 | return self._x[start:end, :], self._y[start:end] 84 | 85 | def all(self): 86 | return self._x, self._y 87 | -------------------------------------------------------------------------------- /torch/rot_mnist12K_model.lua: -------------------------------------------------------------------------------- 1 | -- Initialize dataset and method properties. 2 | function define_constants() 3 | torch.setdefaulttensortype('torch.DoubleTensor') 4 | torch.manualSeed(1) 5 | 6 | local opt = {} 7 | opt.input_size = 32 8 | opt.real_size = 28 9 | opt.n_transformations = 24 -- TI-pooling parameter. 
10 | 11 | opt.printing_interval = 2 -- Debug parameter 12 | opt.model_dump_name = 'saved_model' 13 | 14 | -- Optimization parameters. 15 | opt.batch_size = 64 16 | opt.weight_decay = 0 -- Regularization (0 means "not used"). 17 | opt.adadelta_rho = 0.9 18 | opt.adadelta_eps = 1e-6 19 | opt.decrease_step_size = 200 -- Decrease step size every 200 epochs. 20 | 21 | return opt 22 | end 23 | 24 | -- Define the topology of the network. TI-pooling takes the maximum of a 25 | -- feature over the transformed instances. 26 | function define_model(input_size, n_transformations, n_outputs) 27 | local fully_connected_multiplier = 128 28 | local model = nn.Sequential() 29 | local number_of_filters = 40 30 | 31 | -- Standard model definition: stacked convolutions, ReLU and max-pooling. 32 | model:add(nn.SpatialConvolution(1, number_of_filters, 33 | 3, 3, 1, 1, 1, 1)) 34 | model:add(nn.ReLU()) 35 | model:add(nn.SpatialMaxPooling(2, 2, 2, 2)) 36 | model:add(nn.SpatialConvolution(number_of_filters, 2*number_of_filters, 37 | 3, 3, 1, 1, 1, 1)) 38 | model:add(nn.ReLU()) 39 | model:add(nn.SpatialMaxPooling(2, 2, 2, 2)) 40 | model:add(nn.SpatialConvolution(2*number_of_filters, 4*number_of_filters, 41 | 3, 3, 1, 1, 1, 1)) 42 | model:add(nn.ReLU()) 43 | model:add(nn.SpatialMaxPooling(2, 2, 2, 2)) 44 | model:add(nn.Reshape(4*number_of_filters*input_size*input_size / (4*4*4))) 45 | model:add(nn.Linear(4*number_of_filters*input_size*input_size / (4*4*4), 46 | fully_connected_multiplier*number_of_filters)) 47 | model:add(nn.ReLU()) 48 | model:add(nn.Reshape(1, fully_connected_multiplier*number_of_filters, 1)) 49 | 50 | -- Put siamese replicas in parallel (replicate n_transformations times). 51 | local parallel_model = nn.Parallel(2, 4) 52 | for rotation_index = 1, n_transformations do 53 | parallel_model:add(model:clone()) 54 | end 55 | 56 | -- TI-pooling (transformation-invariance pooling) layer. 57 | local full_model = nn.Sequential() 58 | full_model:add(parallel_model) 59 | -- Take the maximum output of siamese replicas over the transformations. 60 | full_model:add(nn.SpatialMaxPooling(n_transformations, 1, 1, 1)) 61 | 62 | -- Add fully-connected and output layers on top of TI-pooling features. 63 | full_model:add(nn.Reshape(fully_connected_multiplier*number_of_filters)) 64 | full_model:add(nn.Dropout()) 65 | full_model:add(nn.Linear(fully_connected_multiplier*number_of_filters, 66 | n_outputs)) 67 | full_model:add(nn.LogSoftMax()) 68 | full_model = full_model:cuda() 69 | 70 | -- Share all the parameters between siamese replicas. 71 | parallel_model = full_model:get(1) 72 | for rotation_index = 2, n_transformations do 73 | local current_module = parallel_model:get(rotation_index) 74 | current_module:share(parallel_model:get(1), 'weight', 'bias', 75 | 'gradWeight', 'gradBias') 76 | end 77 | 78 | full_model:training() 79 | return full_model 80 | end 81 | -------------------------------------------------------------------------------- /torch/tools.lua: -------------------------------------------------------------------------------- 1 | -- This file contains the following functions: 2 | -- load_rotated_mnist(file_name, count) 3 | -- get_transformed(batch_inputs, opt) 4 | -- calculate_error(model, data_to_check, opt) 5 | 6 | -- Loads the dataset from an .amat file. 
7 | function load_rotated_mnist(file_name, count) 8 | local loaded_data = {} 9 | for line in io.lines(file_name) do 10 | local chunks = {} 11 | for w in line:gmatch("%S+") do chunks[#chunks + 1] = tonumber(w) end 12 | loaded_data[#loaded_data + 1] = chunks 13 | end 14 | local loaded_data = torch.Tensor(loaded_data) 15 | local data = {} 16 | data.data = loaded_data[{{1, count}, {1, -2}}] 17 | data.labels = loaded_data[{{1, count}, {-1, -1}}] 18 | local shuffled_indices = torch.randperm(data.data:size(1)):long() 19 | data.data = data.data:index(1, shuffled_indices) 20 | data.labels = data.labels:index(1, shuffled_indices) 21 | data.labels:add(1) 22 | local real_size = math.sqrt(data.data:size(2)) 23 | data.data = data.data:reshape(data.data:size(1), 1, real_size, real_size) 24 | print('--------------------------------') 25 | print('inputs', data.data:size()) 26 | print('targets', data.labels:size()) 27 | print('min target', data.labels:min()) 28 | print('max target', data.labels:max()) 29 | print('--------------------------------') 30 | return data 31 | end 32 | 33 | -- Augments the tensor along the second dimension with transformed instances 34 | -- (rotation is used here, but various transformations can be used). 35 | function get_transformed(batch_inputs, opt) 36 | local st = torch.LongStorage(5) 37 | st[1] = batch_inputs:size(1) -- the number of images 38 | st[2] = opt.n_transformations 39 | st[3] = 1 -- the number of channels is 1 40 | st[4] = opt.input_size 41 | st[5] = opt.input_size 42 | local result = torch.Tensor(st) 43 | for index = 1, batch_inputs:size(1) do 44 | local padded_sample = torch.Tensor(opt.input_size, opt.input_size):zero() 45 | local offset = (opt.input_size - opt.real_size) / 2 46 | padded_sample[{{1 + offset, opt.input_size - offset}, 47 | {1 + offset, opt.input_size - offset}}] = 48 | batch_inputs[index]:squeeze() 49 | padded_sample = padded_sample:t():contiguous() 50 | for angle_index = 1, opt.n_transformations do 51 | result[index][angle_index][1] = image.rotate(padded_sample, 52 | 2 * math.pi * angle_index / opt.n_transformations, 'bilinear') 53 | end 54 | end 55 | return result 56 | end 57 | 58 | -- Calculates the number of mispredictions of the trained model. 
59 | function calculate_error(model, data_to_check, opt) 60 | model:evaluate() 61 | local data_size = data_to_check.data:size(1) 62 | local batches_per_dataset = math.ceil(data_size / opt.batch_size) 63 | local error = 0 64 | for batch_index = 0, (batches_per_dataset - 1) do 65 | local start_index = batch_index * opt.batch_size + 1 66 | local end_index = math.min(data_size, (batch_index + 1) * opt.batch_size) 67 | local batch_targets = 68 | data_to_check.labels[{{start_index, end_index},1}]:cuda() 69 | local transformed_batch = data_to_check.data[{{start_index, end_index}}] 70 | local batch_inputs = transformed_batch:cuda() 71 | local logProbs = model:forward(batch_inputs) 72 | local classProbabilities = torch.exp(logProbs) 73 | local _, max_inds = torch.max(classProbabilities, 2) 74 | classPredictions = torch.Tensor():resize(max_inds:size(1)) 75 | :copy(max_inds[{{1,max_inds:size(1)},1}]):cuda() 76 | error = error + classPredictions:ne(batch_targets):sum() 77 | end 78 | model:training() 79 | return error / data_size 80 | end 81 | 82 | -------------------------------------------------------------------------------- /tensorflow/rot_mnist12K_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import tensorflow as tf 3 | 4 | def conv2d(x, W): 5 | return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 6 | 7 | def max_pool_2x2(x): 8 | return tf.nn.max_pool(x, 9 | ksize=[1, 2, 2, 1], 10 | strides=[1, 2, 2, 1], 11 | padding='SAME') 12 | 13 | # xavier-like initializer 14 | def weights_biases(kernel_shape, bias_shape): 15 | in_variables = 1 16 | for index in xrange(len(kernel_shape) - 1): 17 | in_variables *= kernel_shape[index] 18 | stdv = 1.0 / math.sqrt(in_variables) 19 | weights = tf.get_variable( 20 | 'weights', 21 | kernel_shape, 22 | initializer=tf.random_uniform_initializer(-stdv, stdv)) 23 | biases = tf.get_variable( 24 | 'biases', 25 | bias_shape, 26 | initializer=tf.random_uniform_initializer(-stdv, stdv)) 27 | return weights, biases 28 | 29 | def conv_relu_maxpool(input, kernel_shape, bias_shape): 30 | weights, biases = weights_biases(kernel_shape, bias_shape) 31 | return max_pool_2x2(tf.nn.relu(conv2d(input, weights) + biases)) 32 | 33 | def fc_relu(input, kernel_shape, bias_shape): 34 | weights, biases = weights_biases(kernel_shape, bias_shape) 35 | return tf.nn.relu(tf.matmul(input, weights) + biases) 36 | 37 | def fc(input, kernel_shape, bias_shape): 38 | weights, biases = weights_biases(kernel_shape, bias_shape) 39 | return tf.matmul(input, weights) + biases 40 | 41 | # x should already be reshaped as a 32x32x1 image 42 | def single_branch(x, number_of_filters, number_of_fc_features): 43 | with tf.variable_scope('conv1'): 44 | max_pool1 = conv_relu_maxpool(x, 45 | [3, 46 | 3, 47 | 1, 48 | number_of_filters], 49 | [number_of_filters]) 50 | with tf.variable_scope('conv2'): 51 | max_pool2 = conv_relu_maxpool(max_pool1, 52 | [3, 53 | 3, 54 | number_of_filters, 55 | 2 * number_of_filters], 56 | [2 * number_of_filters]) 57 | with tf.variable_scope('conv3'): 58 | max_pool3 = conv_relu_maxpool(max_pool2, 59 | [3, 60 | 3, 61 | 2 * number_of_filters, 62 | 4 * number_of_filters], 63 | [4 * number_of_filters]) 64 | flattened_size = ((32 / 8) ** 2) * 4 * number_of_filters 65 | flattened = tf.reshape(max_pool3, [-1, flattened_size]) 66 | with tf.variable_scope('fc1'): 67 | fc1 = fc_relu(flattened, 68 | [flattened_size, number_of_fc_features], 69 | [number_of_fc_features]) 70 | return fc1 71 | 72 | # x are batches 
nx32x32x1xnumber_of_transformations 73 | def define_model(x, 74 | keep_prob, 75 | number_of_classes, 76 | number_of_filters, 77 | number_of_fc_features): 78 | splitted = tf.unpack(x, axis=4) 79 | branches = [] 80 | with tf.variable_scope('branches') as scope: 81 | for index, tensor_slice in enumerate(splitted): 82 | branches.append(single_branch(splitted[index], 83 | number_of_filters, 84 | number_of_fc_features)) 85 | if (index == 0): 86 | scope.reuse_variables() 87 | concatenated = tf.pack(branches, axis=2) 88 | ti_pooled = tf.reduce_max(concatenated, reduction_indices=[2]) 89 | drop = tf.nn.dropout(ti_pooled, keep_prob) 90 | with tf.variable_scope('fc2'): 91 | logits = fc(drop, 92 | [number_of_fc_features, number_of_classes], 93 | [number_of_classes]) 94 | return logits 95 | -------------------------------------------------------------------------------- /tensorflow/rot_mnist12K.py: -------------------------------------------------------------------------------- 1 | import tools 2 | import rot_mnist12K_model 3 | import tensorflow as tf 4 | import sys 5 | import numpy as np 6 | 7 | # TI-pooling example code for rot_mnist12k classification dataset. 8 | 9 | # The implementation mainly consists of two parts: 10 | # 1. the dataset is augmented with transformed samples 11 | # (see tools.DataLoader._transform); 12 | # 2. the model contains max pooling, that selects the maximum output of 13 | # siamese network replicas over the transformations 14 | # (see rot_mnist12K_model.define_model). 15 | 16 | # For further details and more experiments please refer to the original paper: 17 | # "TI-pooling: transformation-invariant pooling for feature learning in 18 | # Convolutional Neural Networks" 19 | # D. Laptev, N. Savinov, J.M. Buhmann, M. Pollefeys, CVPR 2016. 
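# Shape summary for orientation (the values follow from the constants defined
# below and from the model in rot_mnist12K_model.py):
#   x placeholder:        [batch, 32, 32, 1, 24]  (one 32x32 slice per rotation)
#   each siamese branch:  [batch, 32, 32, 1] -> [batch, 5120]
#   TI-pooling:           element-wise max over the 24 branches -> [batch, 5120]
#   output logits:        [batch, 10]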
20 | 21 | # input data 22 | TRAIN_FILENAME = '../mnist_all_rotation_normalized_float_train_valid.amat' 23 | TEST_FILENAME = '../mnist_all_rotation_normalized_float_test.amat' 24 | LOADED_SIZE = 28 25 | DESIRED_SIZE = 32 26 | # model constants 27 | NUMBER_OF_CLASSES = 10 28 | NUMBER_OF_FILTERS = 40 29 | NUMBER_OF_FC_FEATURES = 5120 30 | NUMBER_OF_TRANSFORMATIONS = 24 31 | # optimization constants 32 | BATCH_SIZE = 64 33 | TEST_CHUNK_SIZE = 1000 34 | ADAM_LEARNING_RATE = 1e-4 35 | PRINTING_INTERVAL = 10 36 | # set seeds 37 | np.random.seed(100) 38 | tf.set_random_seed(100) 39 | # set up training graph 40 | x = tf.placeholder(tf.float32, shape=[None, 41 | DESIRED_SIZE, 42 | DESIRED_SIZE, 43 | 1, 44 | NUMBER_OF_TRANSFORMATIONS]) 45 | y_gt = tf.placeholder(tf.float32, shape=[None, NUMBER_OF_CLASSES]) 46 | keep_prob = tf.placeholder(tf.float32) 47 | logits = rot_mnist12K_model.define_model(x, 48 | keep_prob, 49 | NUMBER_OF_CLASSES, 50 | NUMBER_OF_FILTERS, 51 | NUMBER_OF_FC_FEATURES) 52 | cross_entropy = tf.reduce_mean( 53 | tf.nn.softmax_cross_entropy_with_logits(logits, y_gt)) 54 | train_step = tf.train.AdamOptimizer(ADAM_LEARNING_RATE).minimize(cross_entropy) 55 | correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_gt, 1)) 56 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 57 | # run training 58 | session = tf.Session() 59 | session.run(tf.initialize_all_variables()) 60 | train_data_loader = tools.DataLoader(TRAIN_FILENAME, 61 | NUMBER_OF_CLASSES, 62 | NUMBER_OF_TRANSFORMATIONS, 63 | LOADED_SIZE, 64 | DESIRED_SIZE) 65 | test_data_loader = tools.DataLoader(TEST_FILENAME, 66 | NUMBER_OF_CLASSES, 67 | NUMBER_OF_TRANSFORMATIONS, 68 | LOADED_SIZE, 69 | DESIRED_SIZE) 70 | test_size = test_data_loader.all()[1].shape[0] 71 | assert test_size % TEST_CHUNK_SIZE == 0 72 | number_of_test_chunks = test_size / TEST_CHUNK_SIZE 73 | while (True): 74 | batch = train_data_loader.next_batch(BATCH_SIZE) 75 | if (train_data_loader.is_new_epoch()): 76 | train_accuracy = session.run(accuracy, feed_dict={x : batch[0], 77 | y_gt : batch[1], 78 | keep_prob : 1.0}) 79 | print("completed_epochs %d, training accuracy %g" % 80 | (train_data_loader.get_completed_epochs(), train_accuracy)) 81 | sys.stdout.flush() 82 | if (train_data_loader.get_completed_epochs() % PRINTING_INTERVAL == 0): 83 | sum = 0.0 84 | for chunk_index in xrange(number_of_test_chunks): 85 | chunk = test_data_loader.next_batch(TEST_CHUNK_SIZE) 86 | sum += session.run(accuracy, feed_dict={x : chunk[0], 87 | y_gt : chunk[1], 88 | keep_prob : 1.0}) 89 | test_accuracy = sum / number_of_test_chunks 90 | print("testing accuracy %g" % test_accuracy) 91 | sys.stdout.flush() 92 | session.run(train_step, feed_dict={x : batch[0], 93 | y_gt : batch[1], 94 | keep_prob : 0.5}) 95 | 96 | 97 | -------------------------------------------------------------------------------- /torch/rot_mnist12K.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'math' 3 | require 'nn' 4 | require 'cunn' 5 | require 'image' 6 | require 'optim' 7 | 8 | require 'rot_mnist12K_model' 9 | require 'tools' 10 | 11 | -- TI-pooling example code for rot_mnist12k classification dataset. 12 | 13 | -- The implementation mainly consists of two parts: 14 | -- 1. the dataset is augmented with transformed samples (see get_transformed); 15 | -- 2. the model contains SpatialMaxPooling, that selects the maximum output of 16 | -- siamese network replicas over the transformations (see define_model). 
17 | 18 | -- For further details and more experiments please refer to the original paper: 19 | -- "TI-pooling: transformation-invariant pooling for feature learning in 20 | -- Convolutional Neural Networks" 21 | -- D. Laptev, N. Savinov, J.M. Buhmann, M. Pollefeys, CVPR 2016. 22 | 23 | local opt = define_constants() 24 | 25 | -- Load and unzip "Rotated MNIST digits" dataset from the following address: 26 | -- http://www.iro.umontreal.ca/~lisa/twiki/bin/view.cgi/Public/MnistVariations 27 | local train_file = '../mnist_all_rotation_normalized_float_train_valid.amat' 28 | local test_file = '../mnist_all_rotation_normalized_float_test.amat' 29 | local train = load_rotated_mnist(train_file, -1) 30 | local test = load_rotated_mnist(test_file, -1) 31 | local n_train_data = train.data:size(1) -- Number of training samples. 32 | local n_outputs = train.labels:max() -- Number of classes. 33 | 34 | -- Augment the dataset with transformed (rotated) images. 35 | train.data = get_transformed(train.data, opt) 36 | test.data = get_transformed(test.data, opt) 37 | 38 | -- Define the model and the objective function. 39 | local model = define_model(opt.input_size, opt.n_transformations, n_outputs) 40 | local criterion = nn.ClassNLLCriterion() 41 | local criterion = criterion:cuda() 42 | 43 | local parameters, gradParameters = model:getParameters() 44 | local counter = 0 45 | local epoch = 1 46 | local new_epoch = false 47 | 48 | -- The value and the gradient of the functional on one batch. 49 | local batch_feval = function(x) 50 | -- Get the batch number and update the counters. 51 | if new_epoch then 52 | epoch = epoch + 1 53 | new_epoch = false 54 | end 55 | if x ~= parameters then 56 | parameters:copy(x) 57 | end 58 | local start_index = counter * opt.batch_size + 1 59 | local end_index = math.min((counter + 1) * opt.batch_size, n_train_data) 60 | if end_index == n_train_data then 61 | counter = 0 62 | new_epoch = true 63 | else 64 | counter = counter + 1 65 | end 66 | local batch_inputs = train.data[{{start_index, end_index}}]:cuda() 67 | local batch_targets = train.labels[{{start_index, end_index},1}]:cuda() 68 | 69 | -- Forward pass the batch, compute regularized loss. 70 | gradParameters:zero() 71 | local batch_outputs = model:forward(batch_inputs) 72 | local batch_loss = criterion:forward(batch_outputs, batch_targets) + 73 | opt.weight_decay * (parameters:norm()^2) / (2 * parameters:size(1)) 74 | 75 | -- Backward pass the loss, compute gradients. 76 | local dloss_doutput = criterion:backward(batch_outputs, batch_targets) 77 | model:backward(batch_inputs, dloss_doutput) 78 | gradParameters:add(parameters * opt.weight_decay / parameters:size(1)) 79 | 80 | return batch_loss, gradParameters 81 | end 82 | 83 | -- Main optimization cycle: iterate through epochs. 84 | local test_errors = {} 85 | local train_errors = {} 86 | local optim_state = {} 87 | optim_state.rho = opt.adadelta_rho 88 | optim_state.eps = opt.adadelta_eps 89 | while true do -- Cycle through the batches. 90 | if (counter == 0) and (epoch % opt.printing_interval == 0) then 91 | torch.manualSeed(100) 92 | train_errors[#train_errors + 1] = calculate_error(model, train, opt) 93 | test_errors[#test_errors + 1] = calculate_error(model, test, opt) 94 | print(string.format("epoch: %6s, train_error = %6.6f, test_error = %6.6f", 95 | epoch, train_errors[#train_errors], 96 | test_errors[#test_errors])) 97 | -- Save the model for testing or for further training. 98 | torch.save(opt.model_dump_name, model) 99 | torch.save(opt.model_dump_name .. 
'_state', optim_state)
100 | torch.manualSeed(epoch)
101 | end
102 | -- Make a step using AdaDelta optimization algorithm (updates parameters).
103 | optim.adadelta(batch_feval, parameters, optim_state)
104 | collectgarbage()
105 | end
106 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | TI-pooling
2 | ==========
3 | 
4 | This repository contains TensorFlow and Torch7 implementations of TI-pooling (transformation-invariant pooling) from the following paper:
5 | - "TI-pooling: transformation-invariant pooling for feature learning in Convolutional Neural Networks" D. Laptev, N. Savinov, J.M. Buhmann, M. Pollefeys, CVPR 2016. [Pdf](http://laptev.ch/files/laptev16_cvpr.pdf), bibtex:
6 | 
7 | ```
8 | @inproceedings{laptev2016ti,
9 | title={TI-POOLING: transformation-invariant pooling for feature learning in Convolutional Neural Networks},
10 | author={Laptev, Dmitry and Savinov, Nikolay and Buhmann, Joachim M and Pollefeys, Marc},
11 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
12 | pages={289--297},
13 | year={2016}
14 | }
15 | ```
16 | 
17 | **Update February 2017.** The TensorFlow implementation is ready! You can independently use either the Torch7 or the TensorFlow implementation, or both: the code is structured very similarly. Scroll to "Instructions for Linux" for the details.
18 | 
19 | The original paper provides experimental evaluation on three datasets. This repository contains source code for one of these experiments: the [mnist-rot dataset](http://www.iro.umontreal.ca/~lisa/twiki/bin/view.cgi/Public/MnistVariations), consisting of 12k training images of randomly rotated digits.
20 | 
21 | ### What is TI-pooling?
22 | TI-pooling is a simple technique that makes a Convolutional Neural Network (CNN) transformation-invariant: given a set of nuisance transformations (such as rotations, scaling, shifts, illumination changes, etc.), TI-pooling guarantees that the output of the network does not depend on whether the input image was transformed or not.
23 | 
24 | ### Why TI-pooling and not data augmentation?
25 | Compared to the very commonly used data augmentation, TI-pooling finds a canonical orientation of the input samples and learns mostly from these samples. This means that the network does not have to learn different paths (features) for different transformations of the same object, which results in the following effects:
26 | * A CNN with TI-pooling achieves similar or better results with smaller models.
27 | * Training is often faster than for networks with augmentation.
28 | * It imposes internal regularization, making it harder to overfit.
29 | * It has theoretical guarantees on transformation-invariance.
30 | 
31 | ### How does TI-pooling work?
32 | ![TI-pooling pipeline](https://img-fotki.yandex.ru/get/133056/10605357.9/0_907fc_3c7328bc_XL.png "TI-pooling pipeline")
33 | 
34 | * First, the input image (a) is transformed according to the considered set of transformations to obtain a set of new image instances (b).
35 | * For every transformed image, a parallel instance of a partial siamese network is initialized, consisting only of convolutional and subsampling layers (two copies are shown in the top and in the bottom of the figure).
36 | * Every instance is then passed through a sequence of convolutional (c, e) and subsampling layers (d), until a vector of scalars is obtained (e). This vector of scalars is composed of image features learned by the network.
37 | * Then TI-pooling (element-wise maximum) (g) is applied to the feature vectors to obtain a vector of transformation-invariant features (h); a minimal code sketch of this step is given at the end of this document.
38 | * This vector then serves as an input to a fully-connected layer (i), possibly with dropout, and further propagates to the network output (j).
39 | * Because of the weight-sharing between parallel siamese layers, the actual model requires the same amount of memory as just one convolutional neural network.
40 | * TI-pooling ensures that the actual training of each feature's parameters is performed on the most representative instance.
41 | 
42 | ### Any caveats?
43 | One needs to be really sure before introducing transformation-invariance: in some real-world problems a transformation can seem like a nuisance factor, yet in fact be useful. E.g. rotation-invariance does not work well for natural objects, because most natural objects have a "default" orientation, which helps us to recognize them (an upside-down animal is usually harder to recognize, not only for a CNN, but also for a human being). The same rotation-invariance proved to be very useful for cell recognition, where orientation is essentially random.
44 | 
45 | Also, while training time is comparable to (and usually shorter than) training with data augmentation, testing time increases linearly with the number of transformations.
46 | 
47 | ### Instructions for Linux
48 | First run `./setup.sh` to download the dataset; it will be stored in the root directory. Then, depending on the framework you want to use, navigate to the corresponding directory and start training by running the `rot_mnist12K` file:
49 | 
50 | * For TensorFlow: `cd tensorflow; python rot_mnist12K.py`
51 | * For Torch7: `cd torch; th rot_mnist12K.lua`
52 | 
53 | The code was tested for the following configuration:
54 | 
55 | * TensorFlow version: 0.11.0rc0 with Python 2.7.13, NumPy 1.11.3, SciPy 0.18.1.
56 | * Nvidia Titan X, cuda/7.5.18, cudnn/v5.1.
57 | * Torch7 commit ed547376d552346afc69a937c6b36cf9ea9d1135 (12 September 2016).
58 | 
--------------------------------------------------------------------------------
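For intuition, the TI-pooling step described in the README pipeline above reduces to an element-wise maximum of shared-weight branch outputs over the set of transformed copies of an image. Below is a minimal, framework-independent NumPy sketch of that idea; it is an illustration only, not part of the repository code, and the `branch` argument (together with the toy random-projection branch in the usage example) is a hypothetical stand-in for the convolutional sub-networks defined in the Torch and TensorFlow models.

```
# TI-pooling sketch: element-wise max of branch features over rotated copies.
import numpy as np
from scipy.ndimage import rotate


def ti_pooled_features(image, branch, n_transformations=24):
    # `image` is a 2-D array (H x W); `branch` maps an image to a 1-D feature
    # vector and is shared across all transformed copies (siamese replicas).
    feature_stack = []
    for k in range(n_transformations):
        angle = 360.0 * k / n_transformations
        transformed = rotate(image, angle, reshape=False)  # one copy per rotation
        feature_stack.append(branch(transformed))          # same branch for every copy
    # The TI-pooling step itself: element-wise maximum over the transformations.
    return np.max(np.stack(feature_stack, axis=0), axis=0)


if __name__ == '__main__':
    # Toy branch: flatten, project with a fixed random matrix and apply ReLU,
    # so that the sketch runs without any deep-learning framework.
    rng = np.random.RandomState(0)
    projection = rng.randn(32 * 32, 128)

    def branch(img):
        return np.maximum(img.reshape(-1).dot(projection), 0.0)

    features = ti_pooled_features(rng.rand(32, 32), branch)
    print(features.shape)  # (128,)
```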