"""Spatial Transformer layer for Keras (Theano backend).

This module is a Keras port of skaae's Lasagne transformer layer [1]_ and
implements the spatial transformer described in [2]_.  A simple usage example
lives in the accompanying notebook (``Spatial_test.ipynb``).

References
----------
.. [1] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py

.. [2] Jaderberg, Max, et al. "Spatial Transformer Networks."
       arXiv preprint arXiv:1506.02025 (2015).
"""
import theano.tensor as T
from keras import backend as K
from keras.engine.topology import Layer


class SpatialTransformerLayer(Layer):
    """Spatial Transformer Layer.

    This file is highly based on [1]_, written by skaae.

    Implements a spatial transformer layer as described in [2]_.

    The layer is called on a list of two tensors ``[conv_input, theta]``:

    * ``conv_input`` -- feature maps with shape
      ``[num_batch, num_channels, height, width]``.
    * ``theta`` -- affine transformation parameters with shape
      ``[num_batch, 6]`` (reshaped internally to ``[num_batch, 2, 3]``).

    Parameters
    ----------
    downsample_factor : float
        A value of 1 will keep the original size of the image.
        Values larger than 1 will downsample the image. Values below 1 will
        upsample the image.
        example image: height = 100, width = 200
        downsample_factor = 2
        output image will then be 50, 100
    **kwargs
        Forwarded unchanged to the Keras ``Layer`` constructor.

    References
    ----------
    .. [1] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py

    .. [2] Spatial Transformer Networks
           Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
           Submitted on 5 Jun 2015

    """

    def __init__(self, downsample_factor=1, **kwargs):
        super(SpatialTransformerLayer, self).__init__(**kwargs)
        self.downsample_factor = downsample_factor

    def get_output_shape_for(self, input_shapes):
        """Return the shape of the transformed feature maps.

        Parameters
        ----------
        input_shapes : list of shape tuples
            ``input_shapes[0]`` is the feature-map shape
            ``(batch, channels, height, width)``; ``input_shapes[1]`` is the
            shape of the transformation parameters and does not influence
            the result.

        Returns
        -------
        tuple
            ``(batch, channels, height // factor, width // factor)``.
            Spatial dimensions that are unknown (``None``) stay ``None``.
        """
        feature_shape = input_shapes[0]
        # ``call`` produces a single tensor, so exactly one shape tuple is
        # returned.  Returning ``(shape, theta_shape)`` would be stored by
        # Keras as the (malformed) shape of that single output.
        spatial = tuple(
            None if dim is None else int(dim // self.downsample_factor)
            for dim in feature_shape[2:])
        return tuple(feature_shape[:2]) + spatial

    def call(self, x, mask=None):
        """Apply the affine transformations ``theta`` to the feature maps.

        ``x`` is the pair ``[conv_input, theta]``; see eq. (1) and sec. 3.1
        in ref [2].  ``mask`` is accepted for API compatibility and ignored.
        """
        conv_input, theta = x
        return _transform(theta, conv_input, self.downsample_factor)

    def get_config(self):
        """Return the layer configuration, including ``downsample_factor``."""
        config = {'downsample_factor': self.downsample_factor}
        base_config = super(SpatialTransformerLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


##########################
#   TRANSFORMER LAYERS   #
##########################


def _repeat(x, n_repeats):
    """Repeat every element of the vector ``x`` ``n_repeats`` times in a row.

    ``[a, b]`` with ``n_repeats == 3`` becomes ``[a, a, a, b, b, b]``.  The
    repetition is expressed as an outer product with a row of ones.
    """
    rep = T.ones((n_repeats,), dtype='int32').dimshuffle('x', 0)
    x = K.dot(x.reshape((-1, 1)), rep)
    return x.flatten()


def _interpolate(im, x, y, downsample_factor):
    """Bilinearly sample ``im`` at the normalised coordinates ``(x, y)``.

    Parameters
    ----------
    im : tensor
        Images laid out as ``(num_batch, height, width, channels)``.
    x, y : tensor
        Flat vectors of sampling coordinates, one entry per output pixel of
        every batch item.  ``-1`` addresses the first column/row of the
        image and ``+1`` the last one.
    downsample_factor : float
        Ratio between input and output resolution; determines how many
        output pixels belong to each batch item.

    Returns
    -------
    tensor
        Sampled values with shape ``(num_output_pixels, channels)``.
        Coordinates outside the image replicate the nearest border pixel.
    """
    # constants
    num_batch, height, width, channels = im.shape
    height_f = K.cast(height, 'float32')
    width_f = K.cast(width, 'float32')
    out_height = K.cast(height_f // downsample_factor, 'int64')
    out_width = K.cast(width_f // downsample_factor, 'int64')

    # Scale coordinates from [-1, 1] to [0, width - 1] / [0, height - 1] so
    # that the corners of the sampling grid coincide with the corner pixels
    # (an identity transform then reproduces the input exactly).
    x = (x + 1.0) * (width_f - 1.0) / 2.0
    y = (y + 1.0) * (height_f - 1.0) / 2.0

    # Corners of the 2x2 neighbourhood surrounding every sample point.  The
    # float versions are kept unclipped for the interpolation weights, which
    # therefore always sum to one.
    x0_f = T.floor(x)
    y0_f = T.floor(y)
    x1_f = x0_f + 1.0
    y1_f = y0_f + 1.0

    # Only the lookup indices are clipped to the valid pixel range.
    x0 = K.cast(T.clip(x0_f, 0.0, width_f - 1.0), 'int64')
    x1 = K.cast(T.clip(x1_f, 0.0, width_f - 1.0), 'int64')
    y0 = K.cast(T.clip(y0_f, 0.0, height_f - 1.0), 'int64')
    y1 = K.cast(T.clip(y1_f, 0.0, height_f - 1.0), 'int64')

    # The lookup happens in the image flattened to
    # (num_batch * height * width, channels); ``base`` is the offset of the
    # first pixel of the batch item each sample belongs to.
    dim2 = width
    dim1 = width * height
    base = _repeat(
        T.arange(num_batch, dtype='int32') * dim1, out_height * out_width)
    base_y0 = base + y0 * dim2
    base_y1 = base + y1 * dim2
    idx_a = base_y0 + x0
    idx_b = base_y1 + x0
    idx_c = base_y0 + x1
    idx_d = base_y1 + x1

    # use indices to lookup pixels in the flat image and restore channels dim
    im_flat = K.reshape(im, (-1, channels))
    Ia = im_flat[idx_a]
    Ib = im_flat[idx_b]
    Ic = im_flat[idx_c]
    Id = im_flat[idx_d]

    # and finally calculate interpolated values
    wa = ((x1_f - x) * (y1_f - y)).dimshuffle(0, 'x')
    wb = ((x1_f - x) * (y - y0_f)).dimshuffle(0, 'x')
    wc = ((x - x0_f) * (y1_f - y)).dimshuffle(0, 'x')
    wd = ((x - x0_f) * (y - y0_f)).dimshuffle(0, 'x')
    return wa * Ia + wb * Ib + wc * Ic + wd * Id


def _linspace(start, stop, num):
    """Symbolic counterpart of ``np.linspace(start, stop, num)``.

    The step is ``(stop - start) / (num - 1)``, so ``num`` must be larger
    than one.
    """
    start = K.cast(start, 'float32')
    stop = K.cast(stop, 'float32')
    num = K.cast(num, 'float32')
    step = (stop - start) / (num - 1)
    return T.arange(num, dtype='float32') * step + start


def _meshgrid(height, width):
    """Build the homogeneous target grid of shape ``(3, height * width)``.

    This should be equivalent to::

        x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
                               np.linspace(-1, 1, height))
        ones = np.ones(np.prod(x_t.shape))
        grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
    """
    x_t = K.dot(T.ones((height, 1)),
                _linspace(-1.0, 1.0, width).dimshuffle('x', 0))
    y_t = K.dot(_linspace(-1.0, 1.0, height).dimshuffle(0, 'x'),
                T.ones((1, width)))

    x_t_flat = x_t.reshape((1, -1))
    y_t_flat = y_t.reshape((1, -1))
    ones = K.ones_like(x_t_flat)
    return K.concatenate([x_t_flat, y_t_flat, ones], axis=0)


def _transform(theta, input, downsample_factor):
    """Warp ``input`` with the affine transformations ``theta``.

    Parameters
    ----------
    theta : tensor
        Affine parameters; reshaped to ``(num_batch, 2, 3)``.
    input : tensor
        Feature maps with shape ``(num_batch, num_channels, height, width)``.
        (The name shadows the builtin but is kept for existing callers.)
    downsample_factor : float
        Ratio between input and output resolution.

    Returns
    -------
    tensor
        Transformed feature maps with shape
        ``(num_batch, num_channels, out_height, out_width)``.
    """
    num_batch, num_channels, height, width = input.shape
    theta = K.reshape(theta, (-1, 2, 3))

    # grid of (x_t, y_t, 1), eq (1) in ref [2]
    height_f = K.cast(height, 'float32')
    width_f = K.cast(width, 'float32')
    out_height = K.cast(height_f // downsample_factor, 'int64')
    out_width = K.cast(width_f // downsample_factor, 'int64')
    grid = _meshgrid(out_height, out_width)

    # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
    T_g = K.dot(theta, grid)
    x_s, y_s = T_g[:, 0], T_g[:, 1]
    x_s_flat = x_s.flatten()
    y_s_flat = y_s.flatten()

    # move channels last: (bs, height, width, channels)
    input_dim = input.transpose(0, 2, 3, 1)
    input_transformed = _interpolate(
        input_dim, x_s_flat, y_s_flat,
        downsample_factor)

    output = K.reshape(input_transformed,
                       (num_batch, out_height, out_width, num_channels))
    # back to channels first: (bs, channels, out_height, out_width)
    return output.transpose(0, 3, 1, 2)