├── .gitignore ├── README.md ├── __init__.py ├── affine.py ├── displacement.py ├── grid.py └── warp.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # spatial_transformer_network 2 | Implements a spatial transformer layer as described in [1]_. 3 | Based on [2]_, edited to fix some bugs and added support for non-rigid transformation and 3d images. 4 | 5 | References 6 | ---------- 7 | .. [1] Spatial Transformer Networks 8 | Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu 9 | Submitted on 5 Jun 2015 10 | .. [2] https://github.com/tensorflow/models/blob/master/transformer/spatial_transformer.py 11 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from affine import batch_affine_warp2d, batch_affine_warp3d 2 | from grid import batch_mgrid 3 | from warp import batch_warp2d, batch_warp3d 4 | from displacement import batch_displacement_warp2d, batch_displacement_warp3d 5 | __all__ = ['batch_affine_warp2d', 6 | 'batch_affine_warp3d', 7 | 'batch_mgrid', 8 | 'batch_warp2d', 9 | 'batch_warp3d', 10 | 'batch_displacement_warp2d', 11 | 'batch_displacement_warp3d'] 12 | -------------------------------------------------------------------------------- /affine.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from grid import batch_mgrid 3 | from warp import batch_warp2d, batch_warp3d 4 | 5 | 6 | def batch_affine_warp2d(imgs, theta): 7 | """ 8 | affine transforms 2d images 9 | 10 | Parameters 11 | ---------- 12 | imgs : tf.Tensor 13 | images to be warped 14 | [n_batch, xlen, ylen, n_channel] 15 | theta : tf.Tensor 16 | parameters of affine transformation 17 | [n_batch, 6] 18 | 19 | Returns 20 | ------- 21 | output : tf.Tensor 22 | warped images 23 | [n_batch, xlen, ylen, n_channel] 24 | """ 25 | n_batch = tf.shape(imgs)[0] 26 | xlen = tf.shape(imgs)[1] 27 | ylen = tf.shape(imgs)[2] 28 | theta = tf.reshape(theta, [-1, 2, 3]) 29 | matrix = tf.slice(theta, [0, 0, 0], [-1, -1, 2]) 30 | t = tf.slice(theta, [0, 0, 2], [-1, -1, -1]) 31 | 32 | grids = batch_mgrid(n_batch, xlen, ylen) 33 | coords = tf.reshape(grids, [n_batch, 2, -1]) 34 | 35 | T_g = tf.batch_matmul(matrix, coords) + t 36 | T_g = tf.reshape(T_g, [n_batch, 2, xlen, ylen]) 37 | output = batch_warp2d(imgs, T_g) 38 | return output 39 | 40 | 41 | def batch_affine_warp3d(imgs, theta): 42 | """ 43 | affine transforms 3d images 44 | 45 | Parameters 46 | ---------- 47 | imgs : tf.Tensor 48 | images to be warped 49 | [n_batch, xlen, ylen, zlen, n_channel] 50 | theta : tf.Tensor 51 | parameters of affine transformation 52 | [n_batch, 12] 53 | 54 | Returns 55 | ------- 56 | output : tf.Tensor 57 | warped images 58 | [n_batch, xlen, ylen, zlen, n_channel] 59 | """ 60 | n_batch = tf.shape(imgs)[0] 61 | xlen = tf.shape(imgs)[1] 62 | ylen = tf.shape(imgs)[2] 63 | zlen = tf.shape(imgs)[3] 64 | theta = tf.reshape(theta, [-1, 3, 4]) 65 | matrix = tf.slice(theta, [0, 0, 0], [-1, -1, 3]) 66 | t = tf.slice(theta, [0, 0, 3], [-1, -1, -1]) 67 | 68 | grids = batch_mgrid(n_batch, xlen, ylen, zlen) 69 | grids = tf.reshape(grids, [n_batch, 3, -1]) 70 | 71 | T_g = tf.batch_matmul(matrix, grids) + t 72 | T_g = tf.reshape(T_g, [n_batch, 3, xlen, ylen, zlen]) 73 | output = batch_warp3d(imgs, T_g) 74 | return output 75 | 76 | 77 | if __name__ == '__main__': 78 | """ 79 | for test 80 | 81 | the result will be 82 | 83 | the original image 84 | [[ 0. 1. 2. 3. 4.] 85 | [ 5. 6. 7. 8. 9.] 86 | [ 10. 11. 12. 13. 14.] 87 | [ 15. 16. 17. 18. 19.] 88 | [ 20. 21. 22. 23. 24.]] 89 | 90 | identity warped 91 | [[ 0. 1. 2. 3. 4.] 92 | [ 5. 6. 7. 8. 9.] 93 | [ 10. 11. 12. 13. 14.] 94 | [ 15. 16. 17. 18. 19.] 95 | [ 20. 21. 22. 23. 24.]] 96 | 97 | zoom in warped 98 | [[ 6. 6.5 7. 7.5 8. ] 99 | [ 8.5 9. 9.5 10. 10.5] 100 | [ 11. 11.5 12. 12.5 13. ] 101 | [ 13.5 14. 14.5 15. 15.5] 102 | [ 16. 16.5 17. 17.5 18. ]] 103 | """ 104 | import numpy as np 105 | img = tf.to_float(np.arange(25).reshape(1, 5, 5, 1)) 106 | identity_matrix = tf.to_float([1, 0, 0, 0, 1, 0]) 107 | zoom_in_matrix = identity_matrix * 0.5 108 | identity_warped = batch_affine_warp2d(img, identity_matrix) 109 | zoom_in_warped = batch_affine_warp2d(img, zoom_in_matrix) 110 | with tf.Session() as sess: 111 | print sess.run(img[0, :, :, 0]) 112 | print sess.run(identity_warped[0, :, :, 0]) 113 | print sess.run(zoom_in_warped[0, :, :, 0]) 114 | -------------------------------------------------------------------------------- /displacement.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from grid import batch_mgrid 3 | from warp import batch_warp2d, batch_warp3d 4 | 5 | 6 | def batch_displacement_warp2d(imgs, vector_fields): 7 | """ 8 | warp images by free form transformation 9 | 10 | Parameters 11 | ---------- 12 | imgs : tf.Tensor 13 | images to be warped 14 | [n_batch, xlen, ylen, n_channel] 15 | vector_fields : tf.Tensor 16 | [n_batch, 2, xlen, ylen] 17 | 18 | Returns 19 | ------- 20 | output : tf.Tensor 21 | warped imagees 22 | [n_batch, xlen, ylen, n_channel] 23 | """ 24 | n_batch = tf.shape(imgs)[0] 25 | xlen = tf.shape(imgs)[1] 26 | ylen = tf.shape(imgs)[2] 27 | 28 | grids = batch_mgrid(n_batch, xlen, ylen) 29 | 30 | T_g = grids + vector_fields 31 | output = batch_warp2d(imgs, T_g) 32 | return output 33 | 34 | 35 | def batch_displacement_warp3d(imgs, vector_fields): 36 | """ 37 | warp images by displacement vector fields 38 | 39 | Parameters 40 | ---------- 41 | imgs : tf.Tensor 42 | images to be warped 43 | [n_batch, xlen, ylen, zlen, n_channel] 44 | vector_fields : tf.Tensor 45 | [n_batch, 3, xlen, ylen, zlen] 46 | 47 | Returns 48 | ------- 49 | output : tf.Tensor 50 | warped imagees 51 | [n_batch, xlen, ylen, zlen, n_channel] 52 | """ 53 | n_batch = tf.shape(imgs)[0] 54 | xlen = tf.shape(imgs)[1] 55 | ylen = tf.shape(imgs)[2] 56 | zlen = tf.shape(imgs)[3] 57 | 58 | grids = batch_mgrid(n_batch, xlen, ylen, zlen) 59 | 60 | T_g = grids + vector_fields 61 | output = batch_warp3d(imgs, T_g) 62 | return output 63 | -------------------------------------------------------------------------------- /grid.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def mgrid(*args, **kwargs): 5 | """ 6 | create orthogonal grid 7 | similar to np.mgrid 8 | 9 | Parameters 10 | ---------- 11 | args : int 12 | number of points on each axis 13 | low : float 14 | minimum coordinate value 15 | high : float 16 | maximum coordinate value 17 | 18 | Returns 19 | ------- 20 | grid : tf.Tensor [len(args), args[0], ...] 21 | orthogonal grid 22 | """ 23 | low = kwargs.pop("low", -1) 24 | high = kwargs.pop("high", 1) 25 | low = tf.to_float(low) 26 | high = tf.to_float(high) 27 | coords = (tf.linspace(low, high, arg) for arg in args) 28 | grid = tf.pack(tf.meshgrid(*coords, indexing='ij')) 29 | return grid 30 | 31 | 32 | def batch_mgrid(n_batch, *args, **kwargs): 33 | """ 34 | create batch of orthogonal grids 35 | similar to np.mgrid 36 | 37 | Parameters 38 | ---------- 39 | n_batch : int 40 | number of grids to create 41 | args : int 42 | number of points on each axis 43 | low : float 44 | minimum coordinate value 45 | high : float 46 | maximum coordinate value 47 | 48 | Returns 49 | ------- 50 | grids : tf.Tensor [n_batch, len(args), args[0], ...] 51 | batch of orthogonal grids 52 | """ 53 | grid = mgrid(*args, **kwargs) 54 | grid = tf.expand_dims(grid, 0) 55 | grids = tf.tile(grid, [n_batch] + [1 for _ in xrange(len(args) + 1)]) 56 | return grids 57 | 58 | 59 | if __name__ == '__main__': 60 | with tf.Session() as sess: 61 | print sess.run(batch_mgrid(2, 5, 4)) 62 | -------------------------------------------------------------------------------- /warp.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def batch_warp2d(imgs, mappings): 5 | """ 6 | warp image using mapping function 7 | I(x) -> I(phi(x)) 8 | phi: mapping function 9 | 10 | Parameters 11 | ---------- 12 | imgs : tf.Tensor 13 | images to be warped 14 | [n_batch, xlen, ylen, n_channel] 15 | mapping : tf.Tensor 16 | grids representing mapping function 17 | [n_batch, xlen, ylen, 2] 18 | 19 | Returns 20 | ------- 21 | output : tf.Tensor 22 | warped images 23 | [n_batch, xlen, ylen, n_channel] 24 | """ 25 | n_batch = tf.shape(imgs)[0] 26 | coords = tf.reshape(mappings, [n_batch, 2, -1]) 27 | x_coords = tf.slice(coords, [0, 0, 0], [-1, 1, -1]) 28 | y_coords = tf.slice(coords, [0, 1, 0], [-1, 1, -1]) 29 | x_coords_flat = tf.reshape(x_coords, [-1]) 30 | y_coords_flat = tf.reshape(y_coords, [-1]) 31 | 32 | output = _interpolate2d(imgs, x_coords_flat, y_coords_flat) 33 | return output 34 | 35 | 36 | def batch_warp3d(imgs, mappings): 37 | """ 38 | warp image using mapping function 39 | I(x) -> I(phi(x)) 40 | phi: mapping function 41 | 42 | Parameters 43 | ---------- 44 | imgs : tf.Tensor 45 | images to be warped 46 | [n_batch, xlen, ylen, zlen, n_channel] 47 | mapping : tf.Tensor 48 | grids representing mapping function 49 | [n_batch, xlen, ylen, zlen, 3] 50 | 51 | Returns 52 | ------- 53 | output : tf.Tensor 54 | warped images 55 | [n_batch, xlen, ylen, zlen, n_channel] 56 | """ 57 | n_batch = tf.shape(imgs)[0] 58 | coords = tf.reshape(mappings, [n_batch, 3, -1]) 59 | x_coords = tf.slice(coords, [0, 0, 0], [-1, 1, -1]) 60 | y_coords = tf.slice(coords, [0, 1, 0], [-1, 1, -1]) 61 | z_coords = tf.slice(coords, [0, 2, 0], [-1, 1, -1]) 62 | x_coords_flat = tf.reshape(x_coords, [-1]) 63 | y_coords_flat = tf.reshape(y_coords, [-1]) 64 | z_coords_flat = tf.reshape(z_coords, [-1]) 65 | 66 | output = _interpolate3d(imgs, x_coords_flat, y_coords_flat, z_coords_flat) 67 | return output 68 | 69 | 70 | def _repeat(base_indices, n_repeats): 71 | base_indices = tf.matmul( 72 | tf.reshape(base_indices, [-1, 1]), 73 | tf.ones([1, n_repeats], dtype='int32')) 74 | return tf.reshape(base_indices, [-1]) 75 | 76 | 77 | def _interpolate2d(imgs, x, y): 78 | n_batch = tf.shape(imgs)[0] 79 | xlen = tf.shape(imgs)[1] 80 | ylen = tf.shape(imgs)[2] 81 | n_channel = tf.shape(imgs)[3] 82 | 83 | x = tf.to_float(x) 84 | y = tf.to_float(y) 85 | xlen_f = tf.to_float(xlen) 86 | ylen_f = tf.to_float(ylen) 87 | zero = tf.zeros([], dtype='int32') 88 | max_x = tf.cast(xlen - 1, 'int32') 89 | max_y = tf.cast(ylen - 1, 'int32') 90 | 91 | # scale indices from [-1, 1] to [0, xlen/ylen] 92 | x = (x + 1.) * (xlen_f - 1.) * 0.5 93 | y = (y + 1.) * (ylen_f - 1.) * 0.5 94 | 95 | # do sampling 96 | x0 = tf.cast(tf.floor(x), 'int32') 97 | x1 = x0 + 1 98 | y0 = tf.cast(tf.floor(y), 'int32') 99 | y1 = y0 + 1 100 | 101 | x0 = tf.clip_by_value(x0, zero, max_x) 102 | x1 = tf.clip_by_value(x1, zero, max_x) 103 | y0 = tf.clip_by_value(y0, zero, max_y) 104 | y1 = tf.clip_by_value(y1, zero, max_y) 105 | base = _repeat(tf.range(n_batch) * xlen * ylen, ylen * xlen) 106 | base_x0 = base + x0 * ylen 107 | base_x1 = base + x1 * ylen 108 | index00 = base_x0 + y0 109 | index01 = base_x0 + y1 110 | index10 = base_x1 + y0 111 | index11 = base_x1 + y1 112 | 113 | # use indices to lookup pixels in the flat image and restore 114 | # n_channel dim 115 | imgs_flat = tf.reshape(imgs, [-1, n_channel]) 116 | imgs_flat = tf.to_float(imgs_flat) 117 | I00 = tf.gather(imgs_flat, index00) 118 | I01 = tf.gather(imgs_flat, index01) 119 | I10 = tf.gather(imgs_flat, index10) 120 | I11 = tf.gather(imgs_flat, index11) 121 | 122 | # and finally calculate interpolated values 123 | dx = x - tf.to_float(x0) 124 | dy = y - tf.to_float(y0) 125 | w00 = tf.expand_dims((1. - dx) * (1. - dy), 1) 126 | w01 = tf.expand_dims((1. - dx) * dy, 1) 127 | w10 = tf.expand_dims(dx * (1. - dy), 1) 128 | w11 = tf.expand_dims(dx * dy, 1) 129 | output = tf.add_n([w00*I00, w01*I01, w10*I10, w11*I11]) 130 | 131 | # reshape 132 | output = tf.reshape(output, [n_batch, xlen, ylen, n_channel]) 133 | 134 | return output 135 | 136 | 137 | def _interpolate3d(imgs, x, y, z): 138 | n_batch = tf.shape(imgs)[0] 139 | xlen = tf.shape(imgs)[1] 140 | ylen = tf.shape(imgs)[2] 141 | zlen = tf.shape(imgs)[3] 142 | n_channel = tf.shape(imgs)[4] 143 | 144 | x = tf.to_float(x) 145 | y = tf.to_float(y) 146 | z = tf.to_float(z) 147 | xlen_f = tf.to_float(xlen) 148 | ylen_f = tf.to_float(ylen) 149 | zlen_f = tf.to_float(zlen) 150 | zero = tf.zeros([], dtype='int32') 151 | max_x = tf.cast(xlen - 1, 'int32') 152 | max_y = tf.cast(ylen - 1, 'int32') 153 | max_z = tf.cast(zlen - 1, 'int32') 154 | 155 | # scale indices from [-1, 1] to [0, xlen/ylen] 156 | x = (x + 1.) * (xlen_f - 1.) * 0.5 157 | y = (y + 1.) * (ylen_f - 1.) * 0.5 158 | z = (z + 1.) * (zlen_f - 1.) * 0.5 159 | 160 | # do sampling 161 | x0 = tf.cast(tf.floor(x), 'int32') 162 | x1 = x0 + 1 163 | y0 = tf.cast(tf.floor(y), 'int32') 164 | y1 = y0 + 1 165 | z0 = tf.cast(tf.floor(z), 'int32') 166 | z1 = z0 + 1 167 | 168 | x0 = tf.clip_by_value(x0, zero, max_x) 169 | x1 = tf.clip_by_value(x1, zero, max_x) 170 | y0 = tf.clip_by_value(y0, zero, max_y) 171 | y1 = tf.clip_by_value(y1, zero, max_y) 172 | z0 = tf.clip_by_value(z0, zero, max_z) 173 | z1 = tf.clip_by_value(z1, zero, max_z) 174 | base = _repeat(tf.range(n_batch) * xlen * ylen * zlen, 175 | xlen * ylen * zlen) 176 | base_x0 = base + x0 * ylen * zlen 177 | base_x1 = base + x1 * ylen * zlen 178 | base00 = base_x0 + y0 * zlen 179 | base01 = base_x0 + y1 * zlen 180 | base10 = base_x1 + y0 * zlen 181 | base11 = base_x1 + y1 * zlen 182 | index000 = base00 + z0 183 | index001 = base00 + z1 184 | index010 = base01 + z0 185 | index011 = base01 + z1 186 | index100 = base10 + z0 187 | index101 = base10 + z1 188 | index110 = base11 + z0 189 | index111 = base11 + z1 190 | 191 | # use indices to lookup pixels in the flat image and restore 192 | # n_channel dim 193 | imgs_flat = tf.reshape(imgs, [-1, n_channel]) 194 | imgs_flat = tf.to_float(imgs_flat) 195 | I000 = tf.gather(imgs_flat, index000) 196 | I001 = tf.gather(imgs_flat, index001) 197 | I010 = tf.gather(imgs_flat, index010) 198 | I011 = tf.gather(imgs_flat, index011) 199 | I100 = tf.gather(imgs_flat, index100) 200 | I101 = tf.gather(imgs_flat, index101) 201 | I110 = tf.gather(imgs_flat, index110) 202 | I111 = tf.gather(imgs_flat, index111) 203 | 204 | # and finally calculate interpolated values 205 | dx = x - tf.to_float(x0) 206 | dy = y - tf.to_float(y0) 207 | dz = z - tf.to_float(z0) 208 | w000 = tf.expand_dims((1. - dx) * (1. - dy) * (1. - dz), 1) 209 | w001 = tf.expand_dims((1. - dx) * (1. - dy) * dz, 1) 210 | w010 = tf.expand_dims((1. - dx) * dy * (1. - dz), 1) 211 | w011 = tf.expand_dims((1. - dx) * dy * dz, 1) 212 | w100 = tf.expand_dims(dx * (1. - dy) * (1. - dz), 1) 213 | w101 = tf.expand_dims(dx * (1. - dy) * dz, 1) 214 | w110 = tf.expand_dims(dx * dy * (1. - dz), 1) 215 | w111 = tf.expand_dims(dx * dy * dz, 1) 216 | output = tf.add_n([w000 * I000, w001 * I001, w010 * I010, w011 * I011, 217 | w100 * I100, w101 * I101, w110 * I110, w111 * I111]) 218 | 219 | # reshape 220 | output = tf.reshape(output, [n_batch, xlen, ylen, zlen, n_channel]) 221 | 222 | return output 223 | --------------------------------------------------------------------------------