├── README.md
├── images
│   ├── after_stn.png
│   ├── before_stn.png
│   ├── formula.jpeg
│   ├── network.jpeg
│   └── transform.jpeg
└── main.py
/README.md:
--------------------------------------------------------------------------------
1 | # [MOVED TO HERE](https://github.com/tensorlayer/tensorlayer/tree/master/examples/spatial_transformer_network)
2 |
3 | # Spatial Transformer Networks
4 |
5 | [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025) (STN) is a dynamic mechanism that produces transformations of input images (or feature maps), including scaling, cropping, rotations, as well as non-rigid deformations. This enables the network not only to select the regions of an image that are most relevant (attention), but also to transform those regions to simplify recognition in the following layers.
6 |
7 | Video of different transformations: [click me](https://drive.google.com/file/d/0B1nQa_sA3W2iN3RQLXVFRkNXN0k/view).
8 |
9 | In this repository, we implement an STN for [2D Affine Transformation](https://en.wikipedia.org/wiki/Affine_transformation) on the MNIST dataset. We generated 40x40 images from the original MNIST dataset by zero-padding the 28x28 digits, and distorted them by random rotation, shifting, shearing and zooming in/out. Supervised only by the classification task, the STN learns to automatically apply transformations to the distorted images.
10 |
11 |
12 |
13 |

14 |
15 |
Fig 1: Transformation
16 |
17 |
18 |
19 |
20 |

21 |
22 |
Fig 2: Network
23 |
24 |
25 |
26 |

27 |
28 |
Fig 3: Formula
29 |
30 |
31 | ## Result
32 |
33 | After training on the classification task, the STN is able to transform the distorted images in Fig 4 back into the rectified images in Fig 5.
34 |
35 |
36 |

37 |
38 |
Fig 4: Input
39 |
40 |
41 |
42 |

43 |
44 |
Fig 5: Output
45 |
46 |
47 |
--------------------------------------------------------------------------------
/images/after_stn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsdonghao/Spatial-Transformer-Nets/abfdbf9ec26891b4b051f4ec0d4124e2c8d6a834/images/after_stn.png
--------------------------------------------------------------------------------
/images/before_stn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsdonghao/Spatial-Transformer-Nets/abfdbf9ec26891b4b051f4ec0d4124e2c8d6a834/images/before_stn.png
--------------------------------------------------------------------------------
/images/formula.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsdonghao/Spatial-Transformer-Nets/abfdbf9ec26891b4b051f4ec0d4124e2c8d6a834/images/formula.jpeg
--------------------------------------------------------------------------------
/images/network.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsdonghao/Spatial-Transformer-Nets/abfdbf9ec26891b4b051f4ec0d4124e2c8d6a834/images/network.jpeg
--------------------------------------------------------------------------------
/images/transform.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsdonghao/Spatial-Transformer-Nets/abfdbf9ec26891b4b051f4ec0d4124e2c8d6a834/images/transform.jpeg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- coding: utf8 -*-
3 |
import time

import numpy as np
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *
8 |
##================== PREPARE DATA ============================================##
# An InteractiveSession installs itself as the default session.
sess = tf.InteractiveSession()
# MNIST splits with images reshaped to (N, 28, 28, 1).
(X_train, y_train,
 X_val, y_val,
 X_test, y_test) = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
13 |
def pad_distort_im_fn(x, size=40):
    """Zero-pad one image to ``size`` x ``size`` and randomly distort it.

    The image is pasted into the centre of a zero canvas. Then a random
    rotation (``rg=30``), shear (``0.05``), shift (``wrg=hrg=0.25``) and
    zoom (factors drawn from ``(0.95, 1.05)``) are applied, each filling
    uncovered pixels with a constant (``fill_mode='constant'``).

    Parameters
    ----------
    x : numpy.ndarray
        Image of shape (height, width, channels), e.g. one MNIST digit of
        shape (28, 28, 1). Height and width must not exceed ``size``.
    size : int
        Side length of the square output canvas. Defaults to 40.

    Returns
    -------
    numpy.ndarray
        Distorted image of shape (size, size, channels).

    Raises
    ------
    ValueError
        If ``x`` does not fit into the ``size`` x ``size`` canvas.

    Examples
    ---------
    x = pad_distort_im_fn(X_train[0])
    print(x, x.shape, x.max())
    tl.vis.save_image(x, '_xd.png')
    tl.vis.save_image(X_train[0], '_x.png')
    """
    h, w = x.shape[0], x.shape[1]
    if h > size or w > size:
        raise ValueError(
            "image of size %dx%d does not fit a %dx%d canvas" % (h, w, size, size))
    # Float canvas as before, but float32 unless the input is float64:
    # halves the memory of the pre-distorted datasets vs numpy's float64 default.
    canvas = np.zeros((size, size) + x.shape[2:],
                      dtype=np.result_type(x.dtype, np.float32))
    top, left = (size - h) // 2, (size - w) // 2
    canvas[top:top + h, left:left + w] = x
    x = tl.prepro.rotation(canvas, rg=30, is_random=True, fill_mode='constant')
    x = tl.prepro.shear(x, 0.05, is_random=True, fill_mode='constant')
    x = tl.prepro.shift(x, wrg=0.25, hrg=0.25, is_random=True, fill_mode='constant')
    # is_random=True draws fresh zoom factors for each image; without it
    # tl.prepro.zoom applies the fixed factors (0.95, 1.05) to every image.
    x = tl.prepro.zoom(x, zoom_range=(0.95, 1.05), is_random=True, fill_mode='constant')
    return x
33 |
def pad_distort_ims_fn(X, chunk_size=50):
    """Zero-pad every image in ``X`` to 40x40 and randomly distort it.

    Each chunk of ``chunk_size`` images is handed to
    ``tl.prepro.threading_data``, which applies ``pad_distort_im_fn`` to the
    images of the chunk.

    Chunks are taken by plain slicing rather than with
    ``tl.iterate.minibatches``, which only yields full batches. With
    ``minibatches``, a trailing partial chunk would be silently dropped. The
    result would then be shorter than, and misaligned with, the label array
    whenever ``len(X)`` is not a multiple of the chunk size.

    Parameters
    ----------
    X : numpy.ndarray
        Images of shape (n, height, width, channels).
    chunk_size : int
        Number of images passed to ``threading_data`` at a time.

    Returns
    -------
    numpy.ndarray
        All ``n`` distorted images, in the same order as ``X``; an empty
        array when ``X`` is empty.
    """
    chunks = [
        tl.prepro.threading_data(X[start:start + chunk_size], fn=pad_distort_im_fn)
        for start in range(0, len(X), chunk_size)
    ]
    if not chunks:
        return np.asarray([])
    return np.concatenate(chunks)
41 |
# create dataset with size of 40x40 with distortion
# (train, then val, then test -- each split distorted independently).
X_train_40, X_val_40, X_test_40 = (
    pad_distort_ims_fn(split) for split in (X_train, X_val, X_test))

# Visual sanity check: the first 32 test digits tiled into a 4x8 grid,
# once undistorted and once padded + distorted.
tl.vis.save_images(X_test[0:32], [4, 8], '_imgs_original.png')
tl.vis.save_images(X_test_40[0:32], [4, 8], '_imgs_distorted.png')
49 |
##================== DEFINE MODEL ============================================##
batch_size = 64
# The batch dimension is fixed, so every feed_dict must supply exactly
# batch_size examples.
x = tf.placeholder(tf.float32, [batch_size, 40, 40, 1], 'x')  # 40x40, 1 channel
y_ = tf.placeholder(tf.int64, [batch_size], 'y_')             # integer class labels
54 |
def model(x, is_train, reuse):
    """Build the STN classifier graph inside the "STN" variable scope.

    Parameters
    ----------
    x : tf.Tensor
        Input images; in this script the (batch_size, 40, 40, 1) placeholder.
    is_train : bool
        Forwarded to the dropout layer of the localisation network.
    reuse : bool
        False to create the "STN" variables, True to share existing ones.

    Returns
    -------
    tuple
        (classifier output layer, spatial-transformer layer, cost, accuracy).
        Cost and accuracy are computed against the module-level label
        placeholder ``y_``, which must be defined before this is called.
    """
    with tf.variable_scope("STN", reuse=reuse):
        tl.layers.set_name_reuse(reuse)
        net_in = InputLayer(x, name='in')
        ## 1. Localisation network: an MLP on the raw input whose output feeds
        # the spatial transformer below.
        theta = FlattenLayer(net_in, name='tf')
        theta = DenseLayer(theta, n_units=20, act=tf.nn.tanh, name='td1')
        theta = DropoutLayer(theta, 0.8, True, is_train, name='tdrop')
        # A CNN can replace the MLP as the localisation net:
        # theta = Conv2d(net_in, 16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME', name='tc1')
        # theta = Conv2d(theta, 8, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME', name='tc2')
        ## 2. Spatial transformer module (sampler), producing a 40x40 output.
        stn = SpatialTransformer2dAffineLayer(net_in, theta, out_size=[40, 40], name='ST')
        ## 3. Classifier operating on the transformed image.
        net = Conv2d(stn, 16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME', name='c1')
        net = Conv2d(net, 16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME', name='c2')
        net = FlattenLayer(net, name='f')
        net = DenseLayer(net, n_units=1024, act=tf.nn.relu, name='d1')
        net = DenseLayer(net, n_units=10, act=tf.identity, name='do')
        ## 4. Cost function and Accuracy
        logits = net.outputs
        cost = tl.cost.cross_entropy(logits, y_, 'cost')
        is_correct = tf.equal(tf.argmax(logits, 1), y_)
        acc = tf.reduce_mean(tf.cast(is_correct, tf.float32))
    return net, stn, cost, acc
82 |
# Two graphs over one set of "STN" variables: the first call creates them
# (reuse=False) with dropout in training mode; the second shares them
# (reuse=True) with is_train=False and is used for evaluation and the
# transformed-image snapshots.
net_train, _, cost, _ = model(x, reuse=False, is_train=True)
net_test, net_trans, cost_test, acc = model(x, reuse=True, is_train=False)
85 |
##================== DEFINE TRAIN OPS ========================================##
n_epoch = 500
learning_rate = 1e-4
print_freq = 10  # evaluate and save STN snapshots every print_freq epochs

# Optimise only the trainable variables created under the "STN" scope.
train_params = tl.layers.get_variables_with_name('STN', train_only=True, printable=True)
optimizer = tf.train.AdamOptimizer(
    learning_rate=learning_rate,
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-08,
    use_locking=False,
)
train_op = optimizer.minimize(cost, var_list=train_params)
94 |
##================== TRAINING ================================================##
tl.layers.initialize_global_variables(sess)
net_train.print_params()
net_train.print_layers()


def _evaluate(X_data, y_data):
    """Return (mean loss, mean accuracy) of the test-mode graph on a dataset.

    The ``x``/``y_`` placeholders have a fixed batch dimension of
    ``batch_size``, so only full batches are evaluated;
    ``tl.iterate.minibatches`` skips a trailing partial batch.
    Returns ``(nan, nan)`` when ``X_data`` holds fewer than ``batch_size``
    examples, instead of dividing by zero.
    """
    total_loss, total_acc, n_batch = 0, 0, 0
    for X_a, y_a in tl.iterate.minibatches(X_data, y_data, batch_size, shuffle=False):
        err, ac = sess.run([cost_test, acc], feed_dict={x: X_a, y_: y_a})
        total_loss += err
        total_acc += ac
        n_batch += 1
    if n_batch == 0:
        return float('nan'), float('nan')
    return total_loss / n_batch, total_acc / n_batch


for epoch in range(n_epoch):
    start_time = time.time()
    ## you can use continuous data augmentation
    # for X_train_a, y_train_a in tl.iterate.minibatches(
    #         X_train, y_train, batch_size, shuffle=True):
    #     X_train_a = tl.prepro.threading_data(X_train_a, fn=pad_distort_im_fn)
    #     sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a})
    ## or use pre-distorted images (faster)
    for X_train_a, y_train_a in tl.iterate.minibatches(
            X_train_40, y_train, batch_size, shuffle=True):
        sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a})

    if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
        print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
        train_loss, train_acc = _evaluate(X_train_40, y_train)
        print(" train loss: %f" % train_loss)
        print(" train acc: %f" % train_acc)
        # Feed the validation split itself -- not the last training batch.
        val_loss, val_acc = _evaluate(X_val_40, y_val)
        print(" val loss: %f" % val_loss)
        print(" val acc: %f" % val_acc)

        print('save images')
        # net_trans is built on the fixed-size x placeholder, so feed exactly
        # batch_size images.
        trans_imgs = sess.run(net_trans.outputs, {x: X_test_40[0:batch_size]})
        tl.vis.save_images(trans_imgs[0:32], [4, 8], '_imgs_distorted_after_stn_%s.png' % epoch)
135 |
##================== EVALUATION ==============================================##
print('Evaluation')
# Collect per-batch metrics on the distorted test set, then average them.
batch_losses, batch_accs = [], []
for X_test_a, y_test_a in tl.iterate.minibatches(
        X_test_40, y_test, batch_size, shuffle=False):
    err, ac = sess.run([cost_test, acc], feed_dict={x: X_test_a, y_: y_test_a})
    batch_losses.append(err)
    batch_accs.append(ac)
n_batch = len(batch_losses)
test_loss, test_acc = sum(batch_losses), sum(batch_accs)
print(" test loss: %f" % (test_loss/n_batch))
print(" test acc: %f" % (test_acc/n_batch))
145 |
--------------------------------------------------------------------------------