├── README.md ├── images ├── after_stn.png ├── before_stn.png ├── formula.jpeg ├── network.jpeg └── transform.jpeg └── main.py /README.md: -------------------------------------------------------------------------------- 1 | # [MOVED TO HERE](https://github.com/tensorlayer/tensorlayer/tree/master/examples/spatial_transformer_network) 2 | 3 | # Spatial Transformer Networks 4 | 5 | [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025) (STN) is a dynamic mechanism that produces transformations of input images (or feature maps), including scaling, cropping, rotations, as well as non-rigid deformations. This enables the network to not only select regions of an image that are most relevant (attention), but also to transform those regions to simplify recognition in the following layers. 6 | 7 | Video of different transformations: [click me](https://drive.google.com/file/d/0B1nQa_sA3W2iN3RQLXVFRkNXN0k/view). 8 | 9 | In this repository, we implemented an STN for [2D Affine Transformation](https://en.wikipedia.org/wiki/Affine_transformation) on the MNIST dataset. We generated images of size 40x40 from the original MNIST dataset, and distorted the images by random rotation, shifting, shearing and zoom in/out. The STN was able to learn to automatically apply transformations on distorted images via the classification task. 10 | 11 | 12 |
13 | 14 |
15 | Fig 1: Transformation 16 |
17 | 18 | 19 |
20 | 21 |
22 | Fig 2: Network 23 |
24 | 25 |
26 | 27 |
28 | Fig 3: Formula 29 |
30 | 31 | ## Result 32 | 33 | After training on the classification task, the STN is able to transform the distorted images shown in Fig 4 into the corrected images shown in Fig 5. 34 | 35 |
36 | 37 |
38 | Fig 4: Input 39 |
40 | 41 |
42 | 43 |
44 | Fig 5: Output 45 |
46 | 47 | -------------------------------------------------------------------------------- /images/after_stn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zsdonghao/Spatial-Transformer-Nets/abfdbf9ec26891b4b051f4ec0d4124e2c8d6a834/images/after_stn.png -------------------------------------------------------------------------------- /images/before_stn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zsdonghao/Spatial-Transformer-Nets/abfdbf9ec26891b4b051f4ec0d4124e2c8d6a834/images/before_stn.png -------------------------------------------------------------------------------- /images/formula.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zsdonghao/Spatial-Transformer-Nets/abfdbf9ec26891b4b051f4ec0d4124e2c8d6a834/images/formula.jpeg -------------------------------------------------------------------------------- /images/network.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zsdonghao/Spatial-Transformer-Nets/abfdbf9ec26891b4b051f4ec0d4124e2c8d6a834/images/network.jpeg -------------------------------------------------------------------------------- /images/transform.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zsdonghao/Spatial-Transformer-Nets/abfdbf9ec26891b4b051f4ec0d4124e2c8d6a834/images/transform.jpeg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #! 
#! /usr/bin/python
# -*- coding: utf8 -*-
"""Train a Spatial Transformer Network (2D affine) on distorted 40x40 MNIST.

The script zero-pads every 28x28 MNIST digit to 40x40, distorts it with
TensorLayer's preprocessing helpers, and trains a classifier whose input first
passes through a `SpatialTransformer2dAffineLayer`. Loss and accuracy are
printed periodically and the transformer output for the first test images is
saved to PNG files.
"""

import time

import numpy as np
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *

##================== PREPARE DATA ============================================##
sess = tf.InteractiveSession()
X_train, y_train, X_val, y_val, X_test, y_test = \
    tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))


def pad_distort_im_fn(x):
    """ Zero pads an image to 40x40, and distort it.

    The 28x28 input is copied into the centre of a 40x40x1 zero array and then
    passed through rotation, shear, shift and zoom, in that order.

    Parameters
    ----------
    x : array that can be assigned into a (28, 28, 1) slice.

    Returns
    -------
    The padded image after the four `tl.prepro` calls.

    Examples
    ---------
    x = pad_distort_im_fn(X_train[0])
    print(x, x.shape, x.max())
    tl.vis.save_image(x, '_xd.png')
    tl.vis.save_image(X_train[0], '_x.png')
    """
    b = np.zeros((40, 40, 1))
    o = int((40 - 28) / 2)  # border width that centres the 28x28 digit
    b[o:o + 28, o:o + 28] = x
    x = b
    x = tl.prepro.rotation(x, rg=30, is_random=True, fill_mode='constant')
    x = tl.prepro.shear(x, 0.05, is_random=True, fill_mode='constant')
    x = tl.prepro.shift(x, wrg=0.25, hrg=0.25, is_random=True, fill_mode='constant')
    # NOTE(review): unlike the three calls above, no `is_random` flag is passed
    # here -- confirm whether a non-random zoom is intended.
    x = tl.prepro.zoom(x, zoom_range=(0.95, 1.05), fill_mode='constant')
    return x


def pad_distort_ims_fn(X):
    """ Zero pads images to 40x40, and distort them.

    Images are processed in unshuffled minibatches of 50 through
    `tl.prepro.threading_data`, applying `pad_distort_im_fn` to each one.

    Parameters
    ----------
    X : sequence of images accepted by `pad_distort_im_fn`.

    Returns
    -------
    numpy array built from all the distorted images.
    """
    X_40 = []
    # `X` is passed as both inputs and targets; only the inputs are used.
    for X_a, _ in tl.iterate.minibatches(X, X, 50, shuffle=False):
        X_40.extend(tl.prepro.threading_data(X_a, fn=pad_distort_im_fn))
    X_40 = np.asarray(X_40)
    return X_40


# create dataset with size of 40x40 with distortion
X_train_40 = pad_distort_ims_fn(X_train)
X_val_40 = pad_distort_ims_fn(X_val)
X_test_40 = pad_distort_ims_fn(X_test)

tl.vis.save_images(X_test[0:32], [4, 8], '_imgs_original.png')
tl.vis.save_images(X_test_40[0:32], [4, 8], '_imgs_distorted.png')

##================== DEFINE MODEL ============================================##
batch_size = 64
x = tf.placeholder(tf.float32, shape=[batch_size, 40, 40, 1], name='x')
y_ = tf.placeholder(tf.int64, shape=[batch_size, ], name='y_')


def model(x, is_train, reuse):
    """Build the STN classifier graph inside variable scope "STN".

    Parameters
    ----------
    x : input tensor fed to the `InputLayer` (the 40x40x1 image placeholder).
    is_train : forwarded to the `DropoutLayer` of the localisation network.
    reuse : forwarded to `tf.variable_scope` and `tl.layers.set_name_reuse`,
        so a second call with `reuse=True` shares the first call's variables.

    Returns
    -------
    n : the final (10-unit) dense layer of the classifier.
    s : the spatial transformer layer, i.e. the transformed images.
    cost : cross-entropy between `n.outputs` and the module-level `y_`.
    acc : mean of the per-sample `argmax(n.outputs, 1) == y_` comparison.
    """
    with tf.variable_scope("STN", reuse=reuse):
        tl.layers.set_name_reuse(reuse)
        nin = InputLayer(x, name='in')
        ## 1. Localisation network
        # use MLP as the localisation net
        nt = FlattenLayer(nin, name='tf')
        nt = DenseLayer(nt, n_units=20, act=tf.nn.tanh, name='td1')
        nt = DropoutLayer(nt, 0.8, True, is_train, name='tdrop')
        # you can also use CNN instead for MLP as the localisation net
        # nt = Conv2d(nin, 16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME', name='tc1')
        # nt = Conv2d(nt, 8, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME', name='tc2')
        ## 2. Spatial transformer module (sampler)
        n = SpatialTransformer2dAffineLayer(nin, nt, out_size=[40, 40], name='ST')
        s = n  # keep a handle on the transformer output for visualisation
        ## 3. Classifier
        n = Conv2d(n, 16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME', name='c1')
        n = Conv2d(n, 16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME', name='c2')
        n = FlattenLayer(n, name='f')
        n = DenseLayer(n, n_units=1024, act=tf.nn.relu, name='d1')
        n = DenseLayer(n, n_units=10, act=tf.identity, name='do')
        ## 4. Cost function and Accuracy
        y = n.outputs
        cost = tl.cost.cross_entropy(y, y_, 'cost')
        correct_prediction = tf.equal(tf.argmax(y, 1), y_)
        acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        return n, s, cost, acc


net_train, _, cost, _ = model(x, is_train=True, reuse=False)
net_test, net_trans, cost_test, acc = model(x, is_train=False, reuse=True)

##================== DEFINE TRAIN OPS ========================================##
n_epoch = 500
learning_rate = 0.0001
print_freq = 10

train_params = tl.layers.get_variables_with_name('STN', train_only=True, printable=True)
train_op = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999,
                                  epsilon=1e-08, use_locking=False).minimize(cost, var_list=train_params)


def _mean_loss_and_acc(images, labels):
    """Return (mean loss, mean accuracy) of the test graph over a dataset.

    Runs `cost_test` and `acc` on unshuffled minibatches of `batch_size` and
    averages the per-batch values. Sharing one helper guarantees that every
    split is evaluated on its own batches.
    """
    total_loss, total_acc, n_batch = 0, 0, 0
    for X_a, y_a in tl.iterate.minibatches(images, labels, batch_size, shuffle=False):
        err, ac = sess.run([cost_test, acc], feed_dict={x: X_a, y_: y_a})
        total_loss += err
        total_acc += ac
        n_batch += 1
    return total_loss / n_batch, total_acc / n_batch


##================== TRAINING ================================================##
tl.layers.initialize_global_variables(sess)
net_train.print_params()
net_train.print_layers()

for epoch in range(n_epoch):
    start_time = time.time()
    ## you can use continuous data augmentation
    # for X_train_a, y_train_a in tl.iterate.minibatches(
    #         X_train, y_train, batch_size, shuffle=True):
    #     X_train_a = tl.prepro.threading_data(X_train_a, fn=pad_distort_im_fn)
    #     sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a})
    ## or use pre-distorted images (faster)
    for X_train_a, y_train_a in tl.iterate.minibatches(
            X_train_40, y_train, batch_size, shuffle=True):
        sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a})

    if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
        print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
        train_loss, train_acc = _mean_loss_and_acc(X_train_40, y_train)
        print(" train loss: %f" % train_loss)
        print(" train acc: %f" % train_acc)
        # Validation metrics are computed on the validation batches themselves.
        val_loss, val_acc = _mean_loss_and_acc(X_val_40, y_val)
        print(" val loss: %f" % val_loss)
        print(" val acc: %f" % val_acc)

        # net_train.print_params()
        # net_test.print_params()
        # net_trans.print_params()
        print('save images')
        # The placeholder has a fixed batch dimension, so feed exactly one batch.
        trans_imgs = sess.run(net_trans.outputs, {x: X_test_40[0:batch_size]})
        tl.vis.save_images(trans_imgs[0:32], [4, 8], '_imgs_distorted_after_stn_%s.png' % epoch)

##================== EVALUATION ==============================================##
print('Evaluation')
test_loss, test_acc = _mean_loss_and_acc(X_test_40, y_test)
print(" test loss: %f" % test_loss)
print(" test acc: %f" % test_acc)