├── LICENSE
├── README.md
├── cluttered_mnist.py
├── data
│   ├── cat.jpg
│   └── mnist_sequence1_sample_5distortions5x5.npz
├── example.py
├── spatial_transformer.py
└── tf_utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 | 
3 | Copyright (c) 2016 David Dao
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spatial Transformer Network
2 | 
3 | [![](https://tinyurl.com/greenai-pledge)](https://github.com/daviddao/green-ai)
4 | 
5 | The Spatial Transformer Network [1] allows the spatial manipulation of data within the network.
6 | 
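Under the hood, ``theta`` parameterises a 2x3 affine matrix that maps output (target) coordinates back to input (source) coordinates, which are then bilinearly sampled (eq. (1) in [1]). A plain NumPy sketch of just that coordinate mapping (the sample point is illustrative):

```python
import numpy as np

# 2x3 affine matrix A_theta; here the identity transform
A = np.array([[1., 0., 0.],
              [0., 1., 0.]])
# a target grid point in normalised [-1, 1] coordinates, homogeneous form
xy_t = np.array([0.5, -0.25, 1.0])
x_s, y_s = A.dot(xy_t)  # source coordinates to sample the input at
```
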
10 | 11 | ### API 12 | 13 | A Spatial Transformer Network implemented in Tensorflow 0.7 and based on [2]. 14 | 15 | #### How to use 16 | 17 |
18 |

19 |
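Import ``transformer`` from ``spatial_transformer.py`` and call it on a 4-D input tensor: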
20 | 21 | ```python 22 | transformer(U, theta, out_size) 23 | ``` 24 | 25 | #### Parameters 26 | 27 | U : float 28 | The output of a convolutional net should have the 29 | shape [num_batch, height, width, num_channels]. 30 | theta: float 31 | The output of the 32 | localisation network should be [num_batch, 6]. 33 | out_size: tuple of two ints 34 | The size of the output of the network 35 | 36 | 37 | #### Notes 38 | To initialize the network to the identity transform init ``theta`` to : 39 | 40 | ```python 41 | identity = np.array([[1., 0., 0.], 42 | [0., 1., 0.]]) 43 | identity = identity.flatten() 44 | theta = tf.Variable(initial_value=identity) 45 | ``` 46 | 47 | #### Experiments 48 | 49 |
50 |

51 |
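For example, a minimal end-to-end sketch (the single-layer localisation network and the variable names here are illustrative; `cluttered_mnist.py` does the same with two layers and dropout):

```python
import numpy as np
import tensorflow as tf
from spatial_transformer import transformer

x = tf.placeholder(tf.float32, [None, 40, 40, 1])  # input feature map
x_flat = tf.reshape(x, [-1, 40 * 40])

# Localisation network: one fully-connected layer regressing the 6 affine parameters
W_loc = tf.Variable(tf.zeros([40 * 40, 6]))
identity = np.array([[1., 0., 0.], [0., 1., 0.]], dtype='float32').flatten()
b_loc = tf.Variable(initial_value=identity)  # start at the identity transform
theta = tf.matmul(x_flat, W_loc) + b_loc     # shape [num_batch, 6]

h_trans = transformer(x, theta, out_size=(40, 40))  # [num_batch, 40, 40, 1]
```
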
52 | 53 | We used cluttered MNIST. Left column are the input images, right are the attended parts of the image by an STN. 54 | 55 | All experiments were run in Tensorflow 0.7. 56 | 57 | ### References 58 | 59 | [1] Jaderberg, Max, et al. "Spatial Transformer Networks." arXiv preprint arXiv:1506.02025 (2015) 60 | 61 | [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py 62 | -------------------------------------------------------------------------------- /cluttered_mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================= 15 | import tensorflow as tf 16 | from spatial_transformer import transformer 17 | import numpy as np 18 | from tf_utils import weight_variable, bias_variable, dense_to_one_hot 19 | 20 | # %% Load data 21 | mnist_cluttered = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz') 22 | 23 | X_train = mnist_cluttered['X_train'] 24 | y_train = mnist_cluttered['y_train'] 25 | X_valid = mnist_cluttered['X_valid'] 26 | y_valid = mnist_cluttered['y_valid'] 27 | X_test = mnist_cluttered['X_test'] 28 | y_test = mnist_cluttered['y_test'] 29 | 30 | # % turn from dense to one hot representation 31 | Y_train = dense_to_one_hot(y_train, n_classes=10) 32 | Y_valid = dense_to_one_hot(y_valid, n_classes=10) 33 | Y_test = dense_to_one_hot(y_test, n_classes=10) 34 | 35 | # %% Graph representation of our network 36 | 37 | # %% Placeholders for 40x40 resolution 38 | x = tf.placeholder(tf.float32, [None, 1600]) 39 | y = tf.placeholder(tf.float32, [None, 10]) 40 | 41 | # %% Since x is currently [batch, height*width], we need to reshape to a 42 | # 4-D tensor to use it in a convolutional graph. If one component of 43 | # `shape` is the special value -1, the size of that dimension is 44 | # computed so that the total size remains constant. Since we haven't 45 | # defined the batch dimension's shape yet, we use -1 to denote this 46 | # dimension should not change size. 
47 | x_tensor = tf.reshape(x, [-1, 40, 40, 1])
48 | 
49 | # %% We'll set up the two-layer localisation network to figure out the
50 | # %% parameters for an affine transformation of the input
51 | # %% Create variables for fully connected layer
52 | W_fc_loc1 = weight_variable([1600, 20])
53 | b_fc_loc1 = bias_variable([20])
54 | 
55 | W_fc_loc2 = weight_variable([20, 6])
56 | # Use identity transformation as starting point
57 | initial = np.array([[1., 0, 0], [0, 1., 0]])
58 | initial = initial.astype('float32')
59 | initial = initial.flatten()
60 | b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2')
61 | 
62 | # %% Define the two-layer localisation network
63 | h_fc_loc1 = tf.nn.tanh(tf.matmul(x, W_fc_loc1) + b_fc_loc1)
64 | # %% We can add dropout to regularize and reduce overfitting, like so:
65 | keep_prob = tf.placeholder(tf.float32)
66 | h_fc_loc1_drop = tf.nn.dropout(h_fc_loc1, keep_prob)
67 | # %% Second layer
68 | h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2)
69 | 
70 | # %% We'll create a spatial transformer module to identify discriminative
71 | # %% patches
72 | out_size = (40, 40)
73 | h_trans = transformer(x_tensor, h_fc_loc2, out_size)
74 | 
75 | # %% We'll set up the first convolutional layer
76 | # Weight matrix is [height x width x input_channels x output_channels]
77 | filter_size = 3
78 | n_filters_1 = 16
79 | W_conv1 = weight_variable([filter_size, filter_size, 1, n_filters_1])
80 | 
81 | # %% Bias is [output_channels]
82 | b_conv1 = bias_variable([n_filters_1])
83 | 
84 | # %% Now we can build a graph which does the first layer of convolution:
85 | # we define our stride as batch x height x width x channels
86 | # instead of pooling, we use strides of 2 and more layers
87 | # with smaller filters.
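# With 'SAME' padding and stride 2, each of the two convolutions below halves
# the spatial resolution: 40x40 -> 20x20 -> 10x10, which is why the flattening
# step further down uses 10 * 10 * n_filters_2 features.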
88 | 
89 | h_conv1 = tf.nn.relu(
90 |     tf.nn.conv2d(input=h_trans,
91 |                  filter=W_conv1,
92 |                  strides=[1, 2, 2, 1],
93 |                  padding='SAME') +
94 |     b_conv1)
95 | 
96 | # %% And just like the first layer, add additional layers to create
97 | # a deep net
98 | n_filters_2 = 16
99 | W_conv2 = weight_variable([filter_size, filter_size, n_filters_1, n_filters_2])
100 | b_conv2 = bias_variable([n_filters_2])
101 | h_conv2 = tf.nn.relu(
102 |     tf.nn.conv2d(input=h_conv1,
103 |                  filter=W_conv2,
104 |                  strides=[1, 2, 2, 1],
105 |                  padding='SAME') +
106 |     b_conv2)
107 | 
108 | # %% We'll now reshape so we can connect to a fully-connected layer:
109 | h_conv2_flat = tf.reshape(h_conv2, [-1, 10 * 10 * n_filters_2])
110 | 
111 | # %% Create a fully-connected layer:
112 | n_fc = 1024
113 | W_fc1 = weight_variable([10 * 10 * n_filters_2, n_fc])
114 | b_fc1 = bias_variable([n_fc])
115 | h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, W_fc1) + b_fc1)
116 | 
117 | h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
118 | 
119 | # %% And finally our softmax layer:
120 | W_fc2 = weight_variable([n_fc, 10])
121 | b_fc2 = bias_variable([10])
122 | y_logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
123 | 
124 | # %% Define loss/eval/training functions
125 | cross_entropy = tf.reduce_mean(
126 |     tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_logits, labels=y))
127 | opt = tf.train.AdamOptimizer()
128 | optimizer = opt.minimize(cross_entropy)
129 | grads = opt.compute_gradients(cross_entropy, [b_fc_loc2])
130 | 
131 | # %% Monitor accuracy
132 | correct_prediction = tf.equal(tf.argmax(y_logits, 1), tf.argmax(y, 1))
133 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
134 | 
135 | # %% We now create a session and initialize
136 | # the variables:
137 | sess = tf.Session()
138 | sess.run(tf.global_variables_initializer())
139 | 
140 | 
141 | # %% We'll now train in minibatches and report accuracy and loss:
142 | iter_per_epoch = 100
143 | n_epochs = 500
144 | train_size = 10000
145 | 
146 | indices = np.linspace(0, 10000 - 1, iter_per_epoch)
147 | indices = indices.astype('int')
148 | 
149 | for epoch_i in range(n_epochs):
150 |     for iter_i in range(iter_per_epoch - 1):
151 |         batch_xs = X_train[indices[iter_i]:indices[iter_i+1]]
152 |         batch_ys = Y_train[indices[iter_i]:indices[iter_i+1]]
153 | 
154 |         if iter_i % 10 == 0:
155 |             loss = sess.run(cross_entropy,
156 |                             feed_dict={
157 |                                 x: batch_xs,
158 |                                 y: batch_ys,
159 |                                 keep_prob: 1.0
160 |                             })
161 |             print('Iteration: ' + str(iter_i) + ' Loss: ' + str(loss))
162 | 
163 |         sess.run(optimizer, feed_dict={
164 |             x: batch_xs, y: batch_ys, keep_prob: 0.8})
165 | 
166 |     print('Accuracy (%d): ' % epoch_i + str(sess.run(accuracy,
167 |                                                      feed_dict={
168 |                                                          x: X_valid,
169 |                                                          y: Y_valid,
170 |                                                          keep_prob: 1.0
171 |                                                      })))
172 |     # theta = sess.run(h_fc_loc2, feed_dict={
173 |     #     x: batch_xs, keep_prob: 1.0})
174 |     # print(theta[0])
175 | 
--------------------------------------------------------------------------------
/data/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daviddao/spatial-transformer-tensorflow/7d954f8d57f75c1787b489328c8fb4201ed2fb05/data/cat.jpg
--------------------------------------------------------------------------------
/data/mnist_sequence1_sample_5distortions5x5.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daviddao/spatial-transformer-tensorflow/7d954f8d57f75c1787b489328c8fb4201ed2fb05/data/mnist_sequence1_sample_5distortions5x5.npz
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | from scipy import ndimage
16 | import tensorflow as tf
17 | from spatial_transformer import transformer
18 | import numpy as np
19 | import matplotlib.pyplot as plt
20 | 
21 | # %% Create a batch of three images (1600 x 1200)
22 | # %% Image retrieved from:
23 | # %% https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
24 | im = ndimage.imread('./data/cat.jpg')
25 | im = im / 255.
26 | im = im.reshape(1, 1200, 1600, 3)
27 | im = im.astype('float32')
28 | 
29 | # %% Let the output size of the transformer be half the image size.
30 | out_size = (600, 800)
31 | 
32 | # %% Simulate batch
33 | batch = np.append(im, im, axis=0)
34 | batch = np.append(batch, im, axis=0)
35 | num_batch = 3
36 | 
37 | x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])
38 | # (x is fed with the numpy batch via feed_dict below)
39 | 
40 | # %% Create localisation network and convolutional layer
41 | with tf.variable_scope('spatial_transformer_0'):
42 | 
43 |     # %% Create a fully-connected layer with 6 output nodes
44 |     n_fc = 6
45 |     W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1')
46 | 
47 |     # %% Zoom into the image
48 |     initial = np.array([[0.5, 0, 0], [0, 0.5, 0]])
49 |     initial = initial.astype('float32')
50 |     initial = initial.flatten()
51 | 
52 |     b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
53 |     h_fc1 = tf.matmul(tf.zeros([num_batch, 1200 * 1600 * 3]), W_fc1) + b_fc1
54 |     h_trans = transformer(x, h_fc1, out_size)
55 | 
56 | # %% Run session
57 | sess = tf.Session()
58 | sess.run(tf.global_variables_initializer())
59 | y = sess.run(h_trans, feed_dict={x: batch})
60 | 
61 | # plt.imshow(y[0])
62 | 
--------------------------------------------------------------------------------
/spatial_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | import tensorflow as tf
16 | 
17 | 
18 | def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
19 |     """Spatial Transformer Layer
20 | 
21 |     Implements a spatial transformer layer as described in [1]_.
22 |     Based on [2]_ and edited by David Dao for TensorFlow.
23 | 
24 |     Parameters
25 |     ----------
26 |     U : float
27 |         The output of a convolutional net, with shape
28 |         [num_batch, height, width, num_channels].
29 |     theta : float
30 |         The output of the localisation network,
31 |         with shape [num_batch, 6].
32 |     out_size : tuple of two ints
33 |         The size of the output of the network (height, width)
34 | 
35 |     References
36 |     ----------
37 |     .. [1] Spatial Transformer Networks
38 |            Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
39 |            Submitted on 5 Jun 2015
40 |     .. [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
41 | 
42 |     Notes
43 |     -----
44 |     To initialize the network to the identity transform, initialize
45 |     ``theta`` to:
46 |         identity = np.array([[1., 0., 0.],
47 |                              [0., 1., 0.]])
48 |         identity = identity.flatten()
49 |         theta = tf.Variable(initial_value=identity)
50 | 
51 |     """
52 | 
53 |     def _repeat(x, n_repeats):
54 |         with tf.variable_scope('_repeat'):
55 |             rep = tf.transpose(
56 |                 tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), [1, 0])
57 |             rep = tf.cast(rep, 'int32')
58 |             x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
59 |             return tf.reshape(x, [-1])
60 | 
61 |     def _interpolate(im, x, y, out_size):
62 |         with tf.variable_scope('_interpolate'):
63 |             # constants
64 |             num_batch = tf.shape(im)[0]
65 |             height = tf.shape(im)[1]
66 |             width = tf.shape(im)[2]
67 |             channels = tf.shape(im)[3]
68 | 
69 |             x = tf.cast(x, 'float32')
70 |             y = tf.cast(y, 'float32')
71 |             height_f = tf.cast(height, 'float32')
72 |             width_f = tf.cast(width, 'float32')
73 |             out_height = out_size[0]
74 |             out_width = out_size[1]
75 |             zero = tf.zeros([], dtype='int32')
76 |             max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
77 |             max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
78 | 
79 |             # scale indices from [-1, 1] to [0, width/height]
80 |             x = (x + 1.0)*(width_f) / 2.0
81 |             y = (y + 1.0)*(height_f) / 2.0
82 | 
83 |             # do sampling
84 |             x0 = tf.cast(tf.floor(x), 'int32')
85 |             x1 = x0 + 1
86 |             y0 = tf.cast(tf.floor(y), 'int32')
87 |             y1 = y0 + 1
88 | 
89 |             x0 = tf.clip_by_value(x0, zero, max_x)
90 |             x1 = tf.clip_by_value(x1, zero, max_x)
91 |             y0 = tf.clip_by_value(y0, zero, max_y)
92 |             y1 = tf.clip_by_value(y1, zero, max_y)
93 |             dim2 = width
94 |             dim1 = width*height
95 |             base = _repeat(tf.range(num_batch)*dim1, out_height*out_width)
96 |             base_y0 = base + y0*dim2
97 |             base_y1 = base + y1*dim2
98 |             idx_a = base_y0 + x0
99 |             idx_b = base_y1 + x0
100 |             idx_c = base_y0 + x1
101 |             idx_d = base_y1 + x1
102 | 
103 |             # use indices to lookup pixels in the flat image and restore
104 |             # channels dim
105 |             im_flat = tf.reshape(im, tf.stack([-1, channels]))
106 |             im_flat = tf.cast(im_flat, 'float32')
107 |             Ia = tf.gather(im_flat, idx_a)
108 |             Ib = tf.gather(im_flat, idx_b)
109 |             Ic = tf.gather(im_flat, idx_c)
110 |             Id = tf.gather(im_flat, idx_d)
111 | 
112 |             # and finally calculate interpolated values
113 |             x0_f = tf.cast(x0, 'float32')
114 |             x1_f = tf.cast(x1, 'float32')
115 |             y0_f = tf.cast(y0, 'float32')
116 |             y1_f = tf.cast(y1, 'float32')
117 |             wa = tf.expand_dims(((x1_f-x) * (y1_f-y)), 1)
118 |             wb = tf.expand_dims(((x1_f-x) * (y-y0_f)), 1)
119 |             wc = tf.expand_dims(((x-x0_f) * (y1_f-y)), 1)
120 |             wd = tf.expand_dims(((x-x0_f) * (y-y0_f)), 1)
121 |             output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
122 |             return output
123 | 
124 |     def _meshgrid(height, width):
125 |         with tf.variable_scope('_meshgrid'):
126 |             # This should be equivalent to:
127 |             #  x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
128 |             #                         np.linspace(-1, 1, height))
129 |             #  ones = np.ones(np.prod(x_t.shape))
130 |             #  grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
131 |             x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])),
132 |                             tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
133 |             y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
134 |                             tf.ones(shape=tf.stack([1, width])))
135 | 
136 |             x_t_flat = tf.reshape(x_t, (1, -1))
137 |             y_t_flat = tf.reshape(y_t, (1, -1))
138 | 
139 |             ones = tf.ones_like(x_t_flat)
140 |             grid = tf.concat([x_t_flat, y_t_flat, ones], 0)
141 |             return grid
142 | 
143 |     def _transform(theta, input_dim, out_size):
144 |         with tf.variable_scope('_transform'):
145 |             num_batch = tf.shape(input_dim)[0]
146 |             height = tf.shape(input_dim)[1]
147 |             width = tf.shape(input_dim)[2]
148 |             num_channels = tf.shape(input_dim)[3]
149 |             theta = tf.reshape(theta, (-1, 2, 3))
150 |             theta = tf.cast(theta, 'float32')
151 | 
152 |             # grid of (x_t, y_t, 1), eq (1) in ref [1]
153 |             height_f = tf.cast(height, 'float32')
154 |             width_f = tf.cast(width, 'float32')
155 |             out_height = out_size[0]
156 |             out_width = out_size[1]
157 |             grid = _meshgrid(out_height, out_width)
158 |             grid = tf.expand_dims(grid, 0)
159 |             grid = tf.reshape(grid, [-1])
160 |             grid = tf.tile(grid, tf.stack([num_batch]))
161 |             grid = tf.reshape(grid, tf.stack([num_batch, 3, -1]))
162 | 
163 |             # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
164 |             T_g = tf.matmul(theta, grid)
165 |             x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
166 |             y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
167 |             x_s_flat = tf.reshape(x_s, [-1])
168 |             y_s_flat = tf.reshape(y_s, [-1])
169 | 
170 |             input_transformed = _interpolate(
171 |                 input_dim, x_s_flat, y_s_flat,
172 |                 out_size)
173 | 
174 |             output = tf.reshape(
175 |                 input_transformed, tf.stack([num_batch, out_height, out_width, num_channels]))
176 |             return output
177 | 
178 |     with tf.variable_scope(name):
179 |         output = _transform(theta, U, out_size)
180 |         return output
181 | 
182 | 
183 | def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
184 |     """Batch Spatial Transformer Layer
185 | 
186 |     Parameters
187 |     ----------
188 | 
189 |     U : float
190 |         tensor of inputs [num_batch, height, width, num_channels]
191 |     thetas : float
192 |         a set of transformations for each input [num_batch, num_transforms, 6]
193 |     out_size : tuple of two ints
194 |         the size of the output [out_height, out_width]
195 | 
196 |     Returns: float
197 |         Tensor of size [num_batch*num_transforms, out_height, out_width, num_channels]
198 |     """
199 |     with tf.variable_scope(name):
200 |         num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
201 |         indices = [[i]*num_transforms for i in range(num_batch)]
202 |         input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
203 |         return transformer(input_repeated, thetas, out_size)
204 | 
--------------------------------------------------------------------------------
/tf_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | # %% Borrowed utils from here: https://github.com/pkmital/tensorflow_tutorials/
17 | import tensorflow as tf
18 | import numpy as np
19 | 
20 | def conv2d(x, n_filters,
21 |            k_h=5, k_w=5,
22 |            stride_h=2, stride_w=2,
23 |            stddev=0.02,
24 |            activation=lambda x: x,
25 |            bias=True,
26 |            padding='SAME',
27 |            name="Conv2D"):
28 |     """2D Convolution with options for kernel size, stride, and init deviation.
29 |     Parameters
30 |     ----------
31 |     x : Tensor
32 |         Input tensor to convolve.
33 |     n_filters : int
34 |         Number of filters to apply.
35 |     k_h : int, optional
36 |         Kernel height.
37 |     k_w : int, optional
38 |         Kernel width.
39 |     stride_h : int, optional
40 |         Stride in rows.
41 |     stride_w : int, optional
42 |         Stride in cols.
43 |     stddev : float, optional
44 |         Initialization's standard deviation.
45 |     activation : arguments, optional
46 |         Function which applies a nonlinearity
47 |     bias : bool, optional
48 |         Whether to add a bias term.
49 |     padding : str, optional
50 |         'SAME' or 'VALID'
51 |     name : str, optional
52 |         Variable scope to use.
53 |     Returns
54 |     -------
55 |     x : Tensor
56 |         Convolved input.
57 |     """
58 |     with tf.variable_scope(name):
59 |         w = tf.get_variable(
60 |             'w', [k_h, k_w, x.get_shape()[-1], n_filters],
61 |             initializer=tf.truncated_normal_initializer(stddev=stddev))
62 |         conv = tf.nn.conv2d(
63 |             x, w, strides=[1, stride_h, stride_w, 1], padding=padding)
64 |         if bias:
65 |             b = tf.get_variable(
66 |                 'b', [n_filters],
67 |                 initializer=tf.truncated_normal_initializer(stddev=stddev))
68 |             conv = conv + b
69 |         # apply the nonlinearity (the identity by default)
70 |         return activation(conv)
71 | 
72 | def linear(x, n_units, scope=None, stddev=0.02,
73 |            activation=lambda x: x):
74 |     """Fully-connected network.
75 |     Parameters
76 |     ----------
77 |     x : Tensor
78 |         Input tensor to the network.
79 |     n_units : int
80 |         Number of units to connect to.
81 |     scope : str, optional
82 |         Variable scope to use.
83 |     stddev : float, optional
84 |         Initialization's standard deviation.
85 |     activation : arguments, optional
86 |         Function which applies a nonlinearity
87 |     Returns
88 |     -------
89 |     x : Tensor
90 |         Fully-connected output.
91 |     """
92 |     shape = x.get_shape().as_list()
93 | 
94 |     with tf.variable_scope(scope or "Linear"):
95 |         matrix = tf.get_variable("Matrix", [shape[1], n_units], tf.float32,
96 |                                  tf.random_normal_initializer(stddev=stddev))
97 |         return activation(tf.matmul(x, matrix))
98 | 
99 | # %%
100 | def weight_variable(shape):
101 |     '''Helper function to create a weight variable, initialized here to zero
102 |     (a random-normal initializer is left commented out).
103 |     Parameters
104 |     ----------
105 |     shape : list
106 |         Size of weight variable
107 |     '''
108 |     # initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
109 |     initial = tf.zeros(shape)
110 |     return tf.Variable(initial)
111 | 
112 | # %%
113 | def bias_variable(shape):
114 |     '''Helper function to create a bias variable initialized with
115 |     small random values.
116 |     Parameters
117 |     ----------
118 |     shape : list
119 |         Size of bias variable
120 |     '''
121 |     initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
122 |     return tf.Variable(initial)
123 | 
124 | # %%
125 | def dense_to_one_hot(labels, n_classes=2):
126 |     """Convert class labels from scalars to one-hot vectors."""
127 |     labels = np.array(labels)
128 |     n_labels = labels.shape[0]
129 |     index_offset = np.arange(n_labels) * n_classes
130 |     labels_one_hot = np.zeros((n_labels, n_classes), dtype=np.float32)
131 |     labels_one_hot.flat[index_offset + labels.ravel()] = 1
132 |     return labels_one_hot
--------------------------------------------------------------------------------