├── LICENSE
├── README.md
├── cluttered_mnist.py
├── data
│   ├── cat.jpg
│   └── mnist_sequence1_sample_5distortions5x5.npz
├── example.py
├── spatial_transformer.py
└── tf_utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2016 David Dao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spatial Transformer Network
2 |
3 | [Green AI](https://github.com/daviddao/green-ai)
4 |
5 | The Spatial Transformer Network [1] allows the spatial manipulation of data within the network.
6 |
7 |
8 |
9 |
10 |
11 | ### API
12 |
13 | A Spatial Transformer Network implemented in TensorFlow 0.7, based on [2].
14 |
15 | #### How to use
16 |
17 |
18 |
19 |
20 |
21 | ```python
22 | transformer(U, theta, out_size)
23 | ```
24 |
25 | #### Parameters
26 |
27 | U : float
28 |     The output of a convolutional net; it should have
29 |     shape [num_batch, height, width, num_channels].
30 | theta : float
31 |     The output of the localisation network;
32 |     it should have shape [num_batch, 6].
33 | out_size : tuple of two ints
34 |     The size (height, width) of the output of the network.
35 |
36 |
37 | #### Notes
38 | To initialize the network to the identity transform, initialize ``theta`` as follows:
39 |
40 | ```python
41 | identity = np.array([[1., 0., 0.],
42 |                      [0., 1., 0.]])
43 | identity = identity.flatten()
44 | theta = tf.Variable(initial_value=identity)
45 | ```
46 |
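Below is a minimal sketch of a full call, written for the TF 0.x graph mode used elsewhere in this repo; the 40x40 placeholder shape and the tiling of a single ``theta`` across the batch are illustrative choices, not part of the API:

```python
import numpy as np
import tensorflow as tf
from spatial_transformer import transformer

# a batch of 40x40 single-channel inputs
U = tf.placeholder(tf.float32, [None, 40, 40, 1])

# identity transform, tiled to [num_batch, 6]
identity = np.array([[1., 0., 0.],
                     [0., 1., 0.]], dtype=np.float32).flatten()
theta = tf.tile(tf.expand_dims(tf.constant(identity), 0),
                tf.stack([tf.shape(U)[0], 1]))

h_trans = transformer(U, theta, out_size=(40, 40))  # [num_batch, 40, 40, 1]
```
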
47 | #### Experiments
48 |
49 |
50 |
*(figure: input images alongside the STN-attended crops on cluttered MNIST)*
51 |
52 |
53 | We used cluttered MNIST. The left column shows the input images; the right column shows the parts of the image the STN attends to.
54 |
55 | All experiments were run in TensorFlow 0.7.
56 |
57 | ### References
58 |
59 | [1] Jaderberg, Max, et al. "Spatial Transformer Networks." arXiv preprint arXiv:1506.02025 (2015)
60 |
61 | [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
62 |
--------------------------------------------------------------------------------
/cluttered_mnist.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # =============================================================================
15 | import tensorflow as tf
16 | from spatial_transformer import transformer
17 | import numpy as np
18 | from tf_utils import weight_variable, bias_variable, dense_to_one_hot
19 |
20 | # %% Load data
21 | mnist_cluttered = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz')
22 |
23 | X_train = mnist_cluttered['X_train']
24 | y_train = mnist_cluttered['y_train']
25 | X_valid = mnist_cluttered['X_valid']
26 | y_valid = mnist_cluttered['y_valid']
27 | X_test = mnist_cluttered['X_test']
28 | y_test = mnist_cluttered['y_test']
29 |
30 | # %% Turn labels from dense to one-hot representation
31 | Y_train = dense_to_one_hot(y_train, n_classes=10)
32 | Y_valid = dense_to_one_hot(y_valid, n_classes=10)
33 | Y_test = dense_to_one_hot(y_test, n_classes=10)
34 |
35 | # %% Graph representation of our network
36 |
37 | # %% Placeholders for 40x40 resolution
38 | x = tf.placeholder(tf.float32, [None, 1600])
39 | y = tf.placeholder(tf.float32, [None, 10])
40 |
41 | # %% Since x is currently [batch, height*width], we need to reshape to a
42 | # 4-D tensor to use it in a convolutional graph. If one component of
43 | # `shape` is the special value -1, the size of that dimension is
44 | # computed so that the total size remains constant. Since we haven't
45 | # defined the batch dimension's shape yet, we use -1 to denote this
46 | # dimension should not change size.
47 | x_tensor = tf.reshape(x, [-1, 40, 40, 1])
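# e.g. a feed of 256 flattened images, shape [256, 1600], becomes [256, 40, 40, 1]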
48 |
49 | # %% We'll set up the two-layer localisation network to figure out the
50 | # %% parameters for an affine transformation of the input
51 | # %% Create variables for fully connected layer
52 | W_fc_loc1 = weight_variable([1600, 20])
53 | b_fc_loc1 = bias_variable([20])
54 |
55 | W_fc_loc2 = weight_variable([20, 6])
56 | # Use identity transformation as starting point
57 | initial = np.array([[1., 0, 0], [0, 1., 0]])
58 | initial = initial.astype('float32')
59 | initial = initial.flatten()
60 | b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2')
61 |
62 | # %% Define the two layer localisation network
63 | h_fc_loc1 = tf.nn.tanh(tf.matmul(x, W_fc_loc1) + b_fc_loc1)
64 | # %% We can add dropout to regularize and reduce overfitting like so:
65 | keep_prob = tf.placeholder(tf.float32)
66 | h_fc_loc1_drop = tf.nn.dropout(h_fc_loc1, keep_prob)
67 | # %% Second layer
68 | h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2)
69 |
70 | # %% We'll create a spatial transformer module to identify discriminative
71 | # %% patches
72 | out_size = (40, 40)
73 | h_trans = transformer(x_tensor, h_fc_loc2, out_size)
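# h_trans resamples a 40x40 window from x_tensor under the predicted
# affine transform; its shape is [num_batch, 40, 40, 1]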
74 |
75 | # %% We'll set up the first convolutional layer
76 | # Weight matrix is [height x width x input_channels x output_channels]
77 | filter_size = 3
78 | n_filters_1 = 16
79 | W_conv1 = weight_variable([filter_size, filter_size, 1, n_filters_1])
80 |
81 | # %% Bias is [output_channels]
82 | b_conv1 = bias_variable([n_filters_1])
83 |
84 | # %% Now we can build a graph which does the first layer of convolution:
85 | # we define our stride as batch x height x width x channels
86 | # instead of pooling, we use strides of 2 and more layers
87 | # with smaller filters.
88 |
89 | h_conv1 = tf.nn.relu(
90 | tf.nn.conv2d(input=h_trans,
91 | filter=W_conv1,
92 | strides=[1, 2, 2, 1],
93 | padding='SAME') +
94 | b_conv1)
95 |
96 | # %% And just like the first layer, add additional layers to create
97 | # a deep net
98 | n_filters_2 = 16
99 | W_conv2 = weight_variable([filter_size, filter_size, n_filters_1, n_filters_2])
100 | b_conv2 = bias_variable([n_filters_2])
101 | h_conv2 = tf.nn.relu(
102 | tf.nn.conv2d(input=h_conv1,
103 | filter=W_conv2,
104 | strides=[1, 2, 2, 1],
105 | padding='SAME') +
106 | b_conv2)
107 |
108 | # %% We'll now reshape so we can connect to a fully-connected layer:
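# the two stride-2 convolutions above have reduced 40x40 down to 10x10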
109 | h_conv2_flat = tf.reshape(h_conv2, [-1, 10 * 10 * n_filters_2])
110 |
111 | # %% Create a fully-connected layer:
112 | n_fc = 1024
113 | W_fc1 = weight_variable([10 * 10 * n_filters_2, n_fc])
114 | b_fc1 = bias_variable([n_fc])
115 | h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, W_fc1) + b_fc1)
116 |
117 | h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
118 |
119 | # %% And finally our softmax layer:
120 | W_fc2 = weight_variable([n_fc, 10])
121 | b_fc2 = bias_variable([10])
122 | y_logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
123 |
124 | # %% Define loss/eval/training functions
125 | cross_entropy = tf.reduce_mean(
126 | tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_logits, labels=y))
127 | opt = tf.train.AdamOptimizer()
128 | optimizer = opt.minimize(cross_entropy)
129 | grads = opt.compute_gradients(cross_entropy, [b_fc_loc2])
130 |
131 | # %% Monitor accuracy
132 | correct_prediction = tf.equal(tf.argmax(y_logits, 1), tf.argmax(y, 1))
133 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
134 |
135 | # %% We now create a new session to actually perform the initialization of
136 | # the variables:
137 | sess = tf.Session()
138 | sess.run(tf.global_variables_initializer())
139 |
140 |
141 | # %% We'll now train in minibatches and report accuracy, loss:
142 | iter_per_epoch = 100
143 | n_epochs = 500
144 | train_size = 10000
145 |
146 | indices = np.linspace(0, train_size - 1, iter_per_epoch)
147 | indices = indices.astype('int')
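# consecutive pairs of indices delimit minibatches of roughly
# train_size / iter_per_epoch = 100 examples each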
148 |
149 | for epoch_i in range(n_epochs):
150 | for iter_i in range(iter_per_epoch - 1):
151 | batch_xs = X_train[indices[iter_i]:indices[iter_i+1]]
152 | batch_ys = Y_train[indices[iter_i]:indices[iter_i+1]]
153 |
154 | if iter_i % 10 == 0:
155 | loss = sess.run(cross_entropy,
156 | feed_dict={
157 | x: batch_xs,
158 | y: batch_ys,
159 | keep_prob: 1.0
160 | })
161 | print('Iteration: ' + str(iter_i) + ' Loss: ' + str(loss))
162 |
163 | sess.run(optimizer, feed_dict={
164 | x: batch_xs, y: batch_ys, keep_prob: 0.8})
165 |
166 | print('Accuracy (%d): ' % epoch_i + str(sess.run(accuracy,
167 | feed_dict={
168 | x: X_valid,
169 | y: Y_valid,
170 | keep_prob: 1.0
171 | })))
172 | # theta = sess.run(h_fc_loc2, feed_dict={
173 | # x: batch_xs, keep_prob: 1.0})
174 | # print(theta[0])
175 |
--------------------------------------------------------------------------------
/data/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daviddao/spatial-transformer-tensorflow/7d954f8d57f75c1787b489328c8fb4201ed2fb05/data/cat.jpg
--------------------------------------------------------------------------------
/data/mnist_sequence1_sample_5distortions5x5.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daviddao/spatial-transformer-tensorflow/7d954f8d57f75c1787b489328c8fb4201ed2fb05/data/mnist_sequence1_sample_5distortions5x5.npz
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | from scipy import ndimage
16 | import tensorflow as tf
17 | from spatial_transformer import transformer
18 | import numpy as np
19 | import matplotlib.pyplot as plt
20 |
21 | # %% Create a batch of three images (1600 x 1200)
22 | # %% Image retrieved from:
23 | # %% https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
24 | im = ndimage.imread('./data/cat.jpg')
25 | im = im / 255.
26 | im = im.reshape(1, 1200, 1600, 3)
27 | im = im.astype('float32')
28 |
29 | # %% Let the output size of the transformer be half the image size.
30 | out_size = (600, 800)
31 |
32 | # %% Simulate batch
33 | batch = np.append(im, im, axis=0)
34 | batch = np.append(batch, im, axis=0)
35 | num_batch = 3
36 |
37 | x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])
38 | # the placeholder is fed with the simulated batch below via feed_dict
39 |
40 | # %% Create localisation network and convolutional layer
41 | with tf.variable_scope('spatial_transformer_0'):
42 |
43 | # %% Create a fully-connected layer with 6 output nodes
44 | n_fc = 6
45 | W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1')
46 |
47 | # %% Zoom into the image
48 | initial = np.array([[0.5, 0, 0], [0, 0.5, 0]])
49 | initial = initial.astype('float32')
50 | initial = initial.flatten()
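# scaling output coordinates by 0.5 samples the central half of the
# input in each dimension, i.e. a 2x zoom towards the image center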
51 |
52 | b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
53 | h_fc1 = tf.matmul(tf.zeros([num_batch, 1200 * 1600 * 3]), W_fc1) + b_fc1
54 | h_trans = transformer(x, h_fc1, out_size)
55 |
56 | # %% Run session
57 | sess = tf.Session()
58 | sess.run(tf.global_variables_initializer())
59 | y = sess.run(h_trans, feed_dict={x: batch})
60 |
61 | plt.imshow(y[0])
62 | plt.show()
--------------------------------------------------------------------------------
/spatial_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | import tensorflow as tf
16 |
17 |
18 | def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
19 | """Spatial Transformer Layer
20 |
21 | Implements a spatial transformer layer as described in [1]_.
22 | Based on [2]_ and edited by David Dao for Tensorflow.
23 |
24 | Parameters
25 | ----------
26 | U : float
27 |     The output of a convolutional net; it should have
28 |     shape [num_batch, height, width, num_channels].
29 | theta : float
30 |     The output of the localisation network;
31 |     it should have shape [num_batch, 6].
32 | out_size : tuple of two ints
33 |     The size of the output of the network (height, width)
34 |
35 | References
36 | ----------
37 | .. [1] Spatial Transformer Networks
38 | Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
39 | Submitted on 5 Jun 2015
40 | .. [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
41 |
42 | Notes
43 | -----
44 | To initialize the network to the identity transform, initialize
45 | ``theta`` as follows:
46 |     identity = np.array([[1., 0., 0.],
47 |                          [0., 1., 0.]])
48 |     identity = identity.flatten()
49 |     theta = tf.Variable(initial_value=identity)
50 |
51 | """
52 |
53 | def _repeat(x, n_repeats):
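# repeats each element of x n_repeats times, e.g. [0, 7] with
# n_repeats=3 gives [0, 0, 0, 7, 7, 7]; used below to build
# per-image offsets into the flattened batch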
54 | with tf.variable_scope('_repeat'):
55 | rep = tf.transpose(
56 | tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), [1, 0])
57 | rep = tf.cast(rep, 'int32')
58 | x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
59 | return tf.reshape(x, [-1])
60 |
61 | def _interpolate(im, x, y, out_size):
62 | with tf.variable_scope('_interpolate'):
63 | # constants
64 | num_batch = tf.shape(im)[0]
65 | height = tf.shape(im)[1]
66 | width = tf.shape(im)[2]
67 | channels = tf.shape(im)[3]
68 |
69 | x = tf.cast(x, 'float32')
70 | y = tf.cast(y, 'float32')
71 | height_f = tf.cast(height, 'float32')
72 | width_f = tf.cast(width, 'float32')
73 | out_height = out_size[0]
74 | out_width = out_size[1]
75 | zero = tf.zeros([], dtype='int32')
76 | max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
77 | max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
78 |
79 | # scale indices from [-1, 1] to [0, width/height]
80 | x = (x + 1.0)*(width_f) / 2.0
81 | y = (y + 1.0)*(height_f) / 2.0
82 |
83 | # do sampling
84 | x0 = tf.cast(tf.floor(x), 'int32')
85 | x1 = x0 + 1
86 | y0 = tf.cast(tf.floor(y), 'int32')
87 | y1 = y0 + 1
88 |
89 | x0 = tf.clip_by_value(x0, zero, max_x)
90 | x1 = tf.clip_by_value(x1, zero, max_x)
91 | y0 = tf.clip_by_value(y0, zero, max_y)
92 | y1 = tf.clip_by_value(y1, zero, max_y)
93 | dim2 = width
94 | dim1 = width*height
95 | base = _repeat(tf.range(num_batch)*dim1, out_height*out_width)
96 | base_y0 = base + y0*dim2
97 | base_y1 = base + y1*dim2
98 | idx_a = base_y0 + x0
99 | idx_b = base_y1 + x0
100 | idx_c = base_y0 + x1
101 | idx_d = base_y1 + x1
102 |
103 | # use indices to lookup pixels in the flat image and restore
104 | # channels dim
105 | im_flat = tf.reshape(im, tf.stack([-1, channels]))
106 | im_flat = tf.cast(im_flat, 'float32')
107 | Ia = tf.gather(im_flat, idx_a)
108 | Ib = tf.gather(im_flat, idx_b)
109 | Ic = tf.gather(im_flat, idx_c)
110 | Id = tf.gather(im_flat, idx_d)
111 |
112 | # and finally calculate interpolated values
113 | x0_f = tf.cast(x0, 'float32')
114 | x1_f = tf.cast(x1, 'float32')
115 | y0_f = tf.cast(y0, 'float32')
116 | y1_f = tf.cast(y1, 'float32')
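# bilinear weights: each corner pixel is weighted by the area of the
# rectangle spanned between the sample point and the opposite corner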
117 | wa = tf.expand_dims(((x1_f-x) * (y1_f-y)), 1)
118 | wb = tf.expand_dims(((x1_f-x) * (y-y0_f)), 1)
119 | wc = tf.expand_dims(((x-x0_f) * (y1_f-y)), 1)
120 | wd = tf.expand_dims(((x-x0_f) * (y-y0_f)), 1)
121 | output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
122 | return output
123 |
124 | def _meshgrid(height, width):
125 | with tf.variable_scope('_meshgrid'):
126 | # This should be equivalent to:
127 | # x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
128 | # np.linspace(-1, 1, height))
129 | # ones = np.ones(np.prod(x_t.shape))
130 | # grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
131 | x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])),
132 | tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
133 | y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
134 | tf.ones(shape=tf.stack([1, width])))
135 |
136 | x_t_flat = tf.reshape(x_t, (1, -1))
137 | y_t_flat = tf.reshape(y_t, (1, -1))
138 |
139 | ones = tf.ones_like(x_t_flat)
140 | grid = tf.concat([x_t_flat, y_t_flat, ones], 0)
141 | return grid
142 |
143 | def _transform(theta, input_dim, out_size):
144 | with tf.variable_scope('_transform'):
145 | num_batch = tf.shape(input_dim)[0]
146 | height = tf.shape(input_dim)[1]
147 | width = tf.shape(input_dim)[2]
148 | num_channels = tf.shape(input_dim)[3]
149 | theta = tf.reshape(theta, (-1, 2, 3))
150 | theta = tf.cast(theta, 'float32')
151 |
152 | # grid of (x_t, y_t, 1), eq (1) in ref [1]
153 | height_f = tf.cast(height, 'float32')
154 | width_f = tf.cast(width, 'float32')
155 | out_height = out_size[0]
156 | out_width = out_size[1]
157 | grid = _meshgrid(out_height, out_width)
158 | grid = tf.expand_dims(grid, 0)
159 | grid = tf.reshape(grid, [-1])
160 | grid = tf.tile(grid, tf.stack([num_batch]))
161 | grid = tf.reshape(grid, tf.stack([num_batch, 3, -1]))
162 |
163 | # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
164 | T_g = tf.matmul(theta, grid)
165 | x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
166 | y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
167 | x_s_flat = tf.reshape(x_s, [-1])
168 | y_s_flat = tf.reshape(y_s, [-1])
169 |
170 | input_transformed = _interpolate(
171 | input_dim, x_s_flat, y_s_flat,
172 | out_size)
173 |
174 | output = tf.reshape(
175 | input_transformed, tf.stack([num_batch, out_height, out_width, num_channels]))
176 | return output
177 |
178 | with tf.variable_scope(name):
179 | output = _transform(theta, U, out_size)
180 | return output
181 |
182 |
183 | def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
184 | """Batch Spatial Transformer Layer
185 |
186 | Parameters
187 | ----------
188 |
189 | U : float
190 | tensor of inputs [num_batch,height,width,num_channels]
191 | thetas : float
192 | a set of transformations for each input [num_batch,num_transforms,6]
193 | out_size : tuple of two ints
194 |     the size of the output [out_height, out_width]
195 |
196 | Returns: float
197 | Tensor of size [num_batch*num_transforms,out_height,out_width,num_channels]
198 | """
199 | with tf.variable_scope(name):
200 | num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
201 | indices = [[i] * num_transforms for i in range(num_batch)]
202 | input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
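# theta reshapes to (-1, 2, 3) inside transformer, so the
# [num_batch, num_transforms, 6] layout lines up with input_repeated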
203 | return transformer(input_repeated, thetas, out_size)
204 |
--------------------------------------------------------------------------------
/tf_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | # %% Borrowed utils from here: https://github.com/pkmital/tensorflow_tutorials/
17 | import tensorflow as tf
18 | import numpy as np
19 |
20 | def conv2d(x, n_filters,
21 | k_h=5, k_w=5,
22 | stride_h=2, stride_w=2,
23 | stddev=0.02,
24 | activation=lambda x: x,
25 | bias=True,
26 | padding='SAME',
27 | name="Conv2D"):
28 | """2D Convolution with options for kernel size, stride, and init deviation.
29 | Parameters
30 | ----------
31 | x : Tensor
32 | Input tensor to convolve.
33 | n_filters : int
34 | Number of filters to apply.
35 | k_h : int, optional
36 | Kernel height.
37 | k_w : int, optional
38 | Kernel width.
39 | stride_h : int, optional
40 | Stride in rows.
41 | stride_w : int, optional
42 | Stride in cols.
43 | stddev : float, optional
44 | Initialization's standard deviation.
45 | activation : arguments, optional
46 | Function which applies a nonlinearity
47 | padding : str, optional
48 | 'SAME' or 'VALID'
49 | name : str, optional
50 | Variable scope to use.
51 | Returns
52 | -------
53 | x : Tensor
54 | Convolved input.
55 | """
56 | with tf.variable_scope(name):
57 | w = tf.get_variable(
58 | 'w', [k_h, k_w, x.get_shape()[-1], n_filters],
59 | initializer=tf.truncated_normal_initializer(stddev=stddev))
60 | conv = tf.nn.conv2d(
61 | x, w, strides=[1, stride_h, stride_w, 1], padding=padding)
62 | if bias:
63 | b = tf.get_variable(
64 | 'b', [n_filters],
65 | initializer=tf.truncated_normal_initializer(stddev=stddev))
66 | conv = conv + b
67 | return activation(conv)
68 |
69 | def linear(x, n_units, scope=None, stddev=0.02,
70 | activation=lambda x: x):
71 | """Fully-connected network.
72 | Parameters
73 | ----------
74 | x : Tensor
75 | Input tensor to the network.
76 | n_units : int
77 | Number of units to connect to.
78 | scope : str, optional
79 | Variable scope to use.
80 | stddev : float, optional
81 | Initialization's standard deviation.
82 | activation : arguments, optional
83 | Function which applies a nonlinearity
84 | Returns
85 | -------
86 | x : Tensor
87 | Fully-connected output.
88 | """
89 | shape = x.get_shape().as_list()
90 |
91 | with tf.variable_scope(scope or "Linear"):
92 | matrix = tf.get_variable("Matrix", [shape[1], n_units], tf.float32,
93 | tf.random_normal_initializer(stddev=stddev))
94 | return activation(tf.matmul(x, matrix))
95 |
96 | # %%
97 | def weight_variable(shape):
98 | '''Helper function to create a weight variable initialized to
99 | zeros (a random-normal initializer is left commented out below).
100 | Parameters
101 | ----------
102 | shape : list
103 | Size of weight variable
104 | '''
105 | #initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
106 | initial = tf.zeros(shape)
107 | return tf.Variable(initial)
108 |
109 | # %%
110 | def bias_variable(shape):
111 | '''Helper function to create a bias variable initialized with
112 | small random-normal values.
113 | Parameters
114 | ----------
115 | shape : list
116 | Size of bias variable
117 | '''
118 | initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
119 | return tf.Variable(initial)
120 |
121 | # %%
122 | def dense_to_one_hot(labels, n_classes=2):
123 | """Convert class labels from scalars to one-hot vectors."""
124 | labels = np.array(labels)
125 | n_labels = labels.shape[0]
126 | index_offset = np.arange(n_labels) * n_classes
127 | labels_one_hot = np.zeros((n_labels, n_classes), dtype=np.float32)
128 | labels_one_hot.flat[index_offset + labels.ravel().astype(int)] = 1
129 | return labels_one_hot
130 |
--------------------------------------------------------------------------------