├── README.md
├── ref_img2.png
├── spatial_transformer.py
├── test.py
├── train.py
└── weight file.txt

/README.md:
--------------------------------------------------------------------------------
1 | # ST-RoomNet: Learning Room Layout Estimation From Single Image Through Unsupervised Spatial Transformations
2 | 
3 | This is the official implementation of ST-RoomNet: https://openaccess.thecvf.com/content/CVPR2023W/VOCVALC/html/Ibrahem_ST-RoomNet_Learning_Room_Layout_Estimation_From_Single_Image_Through_Unsupervised_CVPRW_2023_paper.html
4 | 
5 | The spatial transformer module is based on this repo: https://github.com/dantkz/spatial-transformer-tensorflow
6 | 
7 | We modified the spatial transformer module to work with TensorFlow 2.x and added features such as nearest-neighbor interpolation alongside the original bilinear and bicubic interpolation.
8 | 
9 | Requirements:
10 | 
11 | opencv 4.4.1
12 | 
13 | tensorflow 2.9.1
14 | 
15 | If you use this code, please cite the paper as follows:
16 | 
17 | @InProceedings{Ibrahem_2023_CVPR,
18 | author = {Ibrahem, Hatem and Salem, Ahmed and Kang, Hyun-Soo},
19 | title = {ST-RoomNet: Learning Room Layout Estimation From Single Image Through Unsupervised Spatial Transformations},
20 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
21 | month = {June},
22 | year = {2023},
23 | pages = {3375-3383}
24 | }
25 | 
--------------------------------------------------------------------------------
/ref_img2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lab231/ST-RoomNet/14f9492bad6479a0d88aeb688de70bd1d0ac8d21/ref_img2.png
--------------------------------------------------------------------------------
/spatial_transformer.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of Spatial Transformer Networks
3 | 
4 | References
5 | ----------
6 | [1] Spatial Transformer Networks
7 |     Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
8 |     Submitted on 5 Jun 2015
9 | 
10 | [2] https://github.com/tensorflow/models/tree/master/transformer/transformerlayer.py
11 | 
12 | [3] https://github.com/daviddao/spatial-transformer-tensorflow
13 | 
14 | [4] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
15 | 
16 | [5] https://github.com/Lasagne/Lasagne/blob/master/lasagne/layers/special.py
17 | 
18 | [6] Fred L. Bookstein (1989):
19 |     Principal warps: thin-plate splines and the decomposition of deformations.
20 |     IEEE Transactions on Pattern Analysis and Machine Intelligence.
21 |     http://doi.org/10.1109/34.24792
22 | 
23 | """
24 | 
25 | import tensorflow as tf
26 | import math
27 | 
28 | 
29 | """
30 | Legacy Function
31 | 
32 | """
33 | def transformer(inp, theta, out_size, name='SpatialTransformer', **kwargs):
34 |     #with tf.variable_scope(name):
35 |     stl = AffineTransformer(out_size)
36 |     output = stl.transform(inp, theta)  # transform() takes only (inp, theta)
37 |     return output
38 | 
39 | class AffineVolumeTransformer(object):
40 |     """Spatial Affine Volume Transformer Layer
41 |     Implements a spatial transformer layer for volumetric 3D input.
42 |     Implemented by Daniyar Turmukhambetov.
43 |     """
44 | 
45 |     def __init__(self, out_size, name='SpatialAffineVolumeTransformer', interp_method='bilinear', **kwargs):
46 |         """
47 |         Parameters
48 |         ----------
49 |         out_size : tuple of three ints
50 |             The size of the output of the spatial network (depth, height, width), i.e. z, y, x
51 |         name : string
52 |             The scope name of the variables in this network.
53 | 
54 |         """
55 |         self.name = name
56 |         self.out_size = out_size
57 |         self.param_dim = 3*4
58 |         self.interp_method = interp_method
59 | 
60 |         #with tf.variable_scope(self.name):
61 |         self.voxel_grid = _meshgrid3d(self.out_size)
62 | 
63 | 
64 |     def transform(self, inp, theta):
65 |         """
66 |         Affine Transformation of input tensor inp with parameters theta
67 | 
68 |         Parameters
69 |         ----------
70 |         inp : float
71 |             The input tensor should have the shape
72 |             [batch_size, depth, height, width, in_channels].
73 |         theta: float
74 |             The output of the localisation network
75 |             should have the shape
76 |             [batch_size, 12].
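            Each row of ``theta`` is a 3x4 affine matrix flattened in
            row-major order.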
77 | Notes 78 | ----- 79 | To initialize the network to the identity transform initialize ``theta`` to : 80 | identity = np.array([[1., 0., 0., 0.], 81 | [0., 1., 0., 0.], 82 | [0., 0., 1., 0.]]) 83 | identity = identity.flatten() 84 | theta = tf.Variable(initial_value=identity) 85 | 86 | """ 87 | 88 | x_s, y_s, z_s = self._transform(inp, theta) 89 | 90 | output = _interpolate3d( 91 | inp, x_s, y_s, z_s, 92 | self.out_size, 93 | method=self.interp_method 94 | ) 95 | 96 | batch_size, _, _, _, num_channels = inp.get_shape().as_list() 97 | output = tf.reshape(output, [batch_size, self.out_size[0], self.out_size[1], self.out_size[2], num_channels]) 98 | 99 | return output 100 | 101 | 102 | 103 | def _transform(self, inp, theta): 104 | 105 | batch_size, _, _, _, num_channels = inp.get_shape().as_list() 106 | 107 | theta = tf.reshape(theta, (-1, 3, 4)) 108 | voxel_grid = tf.tile(self.voxel_grid, [batch_size]) 109 | voxel_grid = tf.reshape(voxel_grid, [batch_size, 4, -1]) 110 | 111 | # Transform A x (x_t, y_t, z_t, 1)^T -> (x_s, y_s, z_s) 112 | T_g = tf.matmul(theta, voxel_grid) 113 | x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1]) 114 | y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1]) 115 | z_s = tf.slice(T_g, [0, 2, 0], [-1, 1, -1]) 116 | 117 | x_s_flat = tf.reshape(x_s, [-1]) 118 | y_s_flat = tf.reshape(y_s, [-1]) 119 | z_s_flat = tf.reshape(z_s, [-1]) 120 | return x_s_flat, y_s_flat, z_s_flat 121 | 122 | 123 | class AffineTransformer(object): 124 | """Spatial Affine Transformer Layer 125 | 126 | Implements a spatial transformer layer as described in [1]_. 127 | Based on [2]_ and [3]_. Edited by Daniyar Turmukhambetov. 128 | 129 | """ 130 | 131 | def __init__(self, out_size, name='SpatialAffineTransformer', interp_method='bilinear', **kwargs): 132 | """ 133 | Parameters 134 | ---------- 135 | out_size : tuple of two ints 136 | The size of the output of the spatial network (height, width). 137 | name : string 138 | The scope name of the variables in this network. 139 | 140 | """ 141 | self.name = name 142 | self.out_size = out_size 143 | self.param_dim = 6 144 | self.interp_method=interp_method 145 | 146 | #with tf.variable_scope(self.name): 147 | self.pixel_grid = _meshgrid(self.out_size) 148 | 149 | 150 | def transform(self, inp, theta): 151 | """ 152 | Affine Transformation of input tensor inp with parameters theta 153 | 154 | Parameters 155 | ---------- 156 | inp : float 157 | The input tensor should have the shape 158 | [batch_size, height, width, num_channels]. 159 | theta: float 160 | The output of the localisation network 161 | should have the shape 162 | [batch_size, 6]. 
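            Each row of ``theta`` is a 2x3 affine matrix flattened in
            row-major order.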
163 | Notes 164 | ----- 165 | To initialize the network to the identity transform initialize ``theta`` to : 166 | identity = np.array([[1., 0., 0.], 167 | [0., 1., 0.]]) 168 | identity = identity.flatten() 169 | theta = tf.Variable(initial_value=identity) 170 | 171 | """ 172 | #with tf.variable_scope(self.name): 173 | x_s, y_s = self._transform(inp, theta) 174 | 175 | output = _interpolate( 176 | inp, x_s, y_s, 177 | self.out_size, 178 | method=self.interp_method 179 | ) 180 | 181 | batch_size, _, _, num_channels = inp.get_shape().as_list() 182 | output = tf.reshape(output, [batch_size, self.out_size[0], self.out_size[1], num_channels]) 183 | 184 | return output 185 | 186 | 187 | 188 | def _transform(self, inp, theta): 189 | #with tf.variable_scope(self.name + '_affine_transform'): 190 | batch_size, _, _, num_channels = inp.get_shape().as_list() 191 | 192 | theta = tf.reshape(theta, (-1, 2, 3)) 193 | pixel_grid = tf.tile(self.pixel_grid, [batch_size]) 194 | pixel_grid = tf.reshape(pixel_grid, [batch_size, 3, -1]) 195 | 196 | # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s) 197 | T_g = tf.matmul(theta, pixel_grid) 198 | x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1]) 199 | y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1]) 200 | x_s_flat = tf.reshape(x_s, [-1]) 201 | y_s_flat = tf.reshape(y_s, [-1]) 202 | return x_s_flat, y_s_flat 203 | 204 | class ProjectiveTransformer(object): 205 | """Spatial Projective Transformer Layer 206 | 207 | Implements a spatial transformer layer as described in [1]_. 208 | Based on [2]_ and [3]_. Edited by Daniyar Turmukhambetov. 209 | 210 | """ 211 | 212 | def __init__(self, out_size, name='SpatialProjectiveTransformer', interp_method='bilinear', **kwargs): #interp_method='nearest' can be only used in inference mode 213 | """ 214 | Parameters 215 | ---------- 216 | out_size : tuple of two ints 217 | The size of the output of the spatial network (height, width). 218 | name : string 219 | The scope name of the variables in this network. 220 | 221 | """ 222 | self.name = name 223 | self.out_size = out_size 224 | self.param_dim = 8 225 | self.interp_method=interp_method 226 | 227 | #with tf.variable_scope(self.name): 228 | self.pixel_grid = _meshgrid(self.out_size) 229 | 230 | 231 | def transform(self, inp, theta): 232 | """ 233 | Projective Transformation of input tensor inp with parameters theta 234 | 235 | Parameters 236 | ---------- 237 | inp : float 238 | The input tensor should have the shape 239 | [batch_size, height, width, num_channels]. 240 | theta: float 241 | The output of the localisation network 242 | should have the shape 243 | [batch_size, 8]. 
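            Each row of ``theta`` holds the first eight entries of a 3x3
            homography in row-major order; the ninth entry is fixed to 1
            internally (see ``_transform``).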
244 |         Notes
245 |         -----
246 |         To initialize the network to the identity transform, initialize ``theta`` to:
247 |             identity = np.array([1., 0., 0.,
248 |                                  0., 1., 0.,
249 |                                  0., 0.])
250 |             theta = tf.Variable(initial_value=identity)
251 | 
252 |         """
253 |         #with tf.variable_scope(self.name):
254 |         x_s, y_s = self._transform(inp, theta)
255 | 
256 |         output = _interpolate(
257 |             inp, x_s, y_s,
258 |             self.out_size,
259 |             method=self.interp_method
260 |             )
261 | 
262 |         batch_size, _, _, num_channels = inp.get_shape().as_list()
263 |         output = tf.reshape(output, [batch_size, self.out_size[0], self.out_size[1], num_channels])
264 | 
265 |         return output
266 | 
267 | 
268 | 
269 |     def _transform(self, inp, theta):
270 |         #with tf.variable_scope(self.name + '_projective_transform'):
271 |         batch_size, _, _, num_channels = inp.get_shape().as_list()
272 | 
273 |         theta = tf.reshape(theta, (batch_size, 8))
274 |         theta = tf.concat([theta, tf.ones([batch_size, 1])], 1)
275 |         theta = tf.reshape(theta, (batch_size, 3, 3))
276 | 
277 |         pixel_grid = tf.tile(self.pixel_grid, [batch_size])
278 |         pixel_grid = tf.reshape(pixel_grid, [batch_size, 3, -1])
279 |         # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
280 |         T_g = tf.matmul(theta, pixel_grid)
281 | 
282 |         x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
283 |         y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
284 |         z_s = tf.slice(T_g, [0, 2, 0], [-1, 1, -1])
285 | 
286 | 
287 |         z_s += 1e-7  # avoid division by zero
288 | 
289 |         x_s = x_s/z_s
290 |         y_s = y_s/z_s
291 | 
292 |         x_s_flat = tf.reshape(x_s, [-1])
293 |         y_s_flat = tf.reshape(y_s, [-1])
294 | 
295 |         return x_s_flat, y_s_flat
296 | 
297 | class ElasticTransformer(object):
298 |     """Spatial Elastic Transformer Layer with Thin Plate Spline deformations
299 | 
300 |     Implements a spatial transformer layer as described in [1]_.
301 |     Based on [4]_ and [5]_. Edited by Daniyar Turmukhambetov.
302 | 
303 |     """
304 | 
305 |     def __init__(self, out_size, param_dim=2*16, name='SpatialElasticTransformer', interp_method='bilinear', **kwargs):
306 |         """
307 |         Parameters
308 |         ----------
309 |         out_size : tuple of two ints
310 |             The size of the output of the spatial network (height, width).
311 |         param_dim: int
312 |             Two times the number of control points that define the
313 |             Thin Plate Spline deformation field.
314 |             The number of control points *MUST* be a square of an integer.
315 |             2 x 16 by default.
316 |         name : string
317 |             The scope name of the variables in this network.
318 | 
319 |         """
320 |         num_control_points = int(param_dim/2)
321 |         assert param_dim == 2*num_control_points, 'param_dim must be 2 times a square of an integer.'
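        # The control points are laid out on a square grid in [-1, 1]^2,
        # so num_control_points must be a perfect square (checked below).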
322 | 323 | self.name = name 324 | self.param_dim = param_dim 325 | self.interp_method=interp_method 326 | self.num_control_points = num_control_points 327 | self.out_size = out_size 328 | 329 | self.grid_size = math.floor(math.sqrt(self.num_control_points)) 330 | assert self.grid_size*self.grid_size == self.num_control_points, 'num_control_points must be a square of an int' 331 | 332 | #with tf.variable_scope(self.name): 333 | # Create source grid 334 | self.source_points = ElasticTransformer.get_meshgrid(self.grid_size, self.grid_size) 335 | # Construct pixel grid 336 | self.pixel_grid = ElasticTransformer.get_meshgrid(self.out_size[1], self.out_size[0]) 337 | self.num_pixels = self.out_size[0]*self.out_size[1] 338 | self.pixel_distances, self.L_inv = self._initialize_tps(self.source_points, self.pixel_grid) 339 | 340 | 341 | def transform(self, inp, theta, forward=True, **kwargs): 342 | """ 343 | Parameters 344 | ---------- 345 | inp : float 346 | The input tensor should have the shape 347 | [batch_size, height, width, num_channels]. 348 | theta: float 349 | Should have the shape of [batch_size, self.num_control_points x 2] 350 | Theta is the output of the localisation network, so it is 351 | the x and y offsets of the destination coordinates 352 | of each of the control points. 353 | Notes 354 | ----- 355 | To initialize the network to the identity transform initialize ``theta`` to zeros: 356 | identity = np.zeros(16*2) 357 | identity = identity.flatten() 358 | theta = tf.Variable(initial_value=identity) 359 | 360 | """ 361 | #with tf.variable_scope(self.name): 362 | # reshape destination offsets to be (batch_size, 2, num_control_points) 363 | # and add to source_points 364 | source_points = tf.expand_dims(self.source_points, 0) 365 | theta = source_points + tf.reshape(theta, [-1, 2, self.num_control_points]) 366 | 367 | x_s, y_s = self._transform( 368 | inp, theta, self.num_control_points, 369 | self.pixel_grid, self.num_pixels, 370 | self.pixel_distances, self.L_inv, 371 | self.name + '_elastic_transform', forward) 372 | if forward: 373 | output = _interpolate( 374 | inp, x_s, y_s, 375 | self.out_size, 376 | method=self.interp_method 377 | ) 378 | else: 379 | rx_s, ry_s = self._transform( 380 | inp, theta, self.num_control_points, 381 | self.pixel_grid, self.num_pixels, 382 | self.pixel_distances, self.L_inv, 383 | self.name + '_elastic_transform', forward) 384 | output = _interpolate( 385 | inp, rx_s, ry_s, 386 | self.out_size, 387 | method=self.interp_method 388 | ) 389 | pass 390 | 391 | 392 | batch_size, _, _, num_channels = inp.get_shape().as_list() 393 | output = tf.reshape(output, [batch_size, self.out_size[0], self.out_size[1], num_channels]) 394 | return output 395 | 396 | def _transform(self, inp, theta, num_control_points, pixel_grid, num_pixels, pixel_distances, L_inv, name, forward=True): 397 | #with tf.variable_scope(name): 398 | batch_size = inp.get_shape().as_list()[0] 399 | 400 | # Solve as in ref [2] 401 | theta = tf.reshape(theta, [-1, num_control_points]) 402 | coefficients = tf.matmul(theta, L_inv) 403 | coefficients = tf.reshape(coefficients, [-1, 2, num_control_points+3]) 404 | 405 | # Transform each point on the target grid (out_size) 406 | right_mat = tf.concat([pixel_grid, pixel_distances], 0) 407 | right_mat = tf.tile(tf.expand_dims(right_mat, 0), (batch_size, 1, 1)) 408 | transformed_points = tf.matmul(coefficients, right_mat) 409 | transformed_points = tf.reshape(transformed_points, [-1, 2, num_pixels]) 410 | 411 | x_s_flat = 
tf.reshape(transformed_points[:,0,:], [-1])
412 |         y_s_flat = tf.reshape(transformed_points[:,1,:], [-1])
413 | 
414 |         return x_s_flat, y_s_flat
415 | 
416 |     # U function for the new point and each source point
417 |     @staticmethod
418 |     def U_func(points1, points2):
419 |         # The U function is simply U(r) = r^2 * log(r^2), as in ref [5]_,
420 |         # where r is the euclidean distance
421 |         r_sq = tf.transpose(tf.reduce_sum(tf.square(points1 - points2), axis=0))
422 |         log_r = tf.math.log(r_sq)
423 |         log_r = tf.where(tf.math.is_inf(log_r), tf.zeros_like(log_r), log_r)
424 |         phi = r_sq * log_r
425 | 
426 |         # The U function is simply U(r) = r, where r is the euclidean distance
427 |         #phi = tf.sqrt(r_sq)
428 |         return phi
429 | 
430 | 
431 |     @staticmethod
432 |     def get_meshgrid(grid_size_x, grid_size_y):
433 |         # Create 2 x num_points array of source points
434 |         x_points, y_points = tf.meshgrid(
435 |             tf.linspace(-1.0, 1.0, int(grid_size_x)),
436 |             tf.linspace(-1.0, 1.0, int(grid_size_y)))
437 |         x_flat = tf.reshape(x_points, (1,-1))
438 |         y_flat = tf.reshape(y_points, (1,-1))
439 |         points = tf.concat([x_flat, y_flat], 0)
440 |         return points
441 | 
442 | 
443 |     def _initialize_tps(self, source_points, pixel_grid):
444 |         """
445 |         Initializes the thin plate spline calculation by creating the source
446 |         point array and the inverted L matrix used for calculating the
447 |         transformations as in ref [5]_
448 | 
449 |         Returns
450 |         ----------
451 |         pixel_distances : float
452 |             Tensor of shape [num_control_points + 1, out_height*out_width].
453 |         L_inv : float
454 |             Tensor of shape [num_control_points, num_control_points + 3].
455 | 
456 | 
457 | 
458 |         """
459 | 
460 |         tL = ElasticTransformer.U_func(tf.expand_dims(source_points, 2), tf.expand_dims(source_points, 1))
461 | 
462 |         # Initialize L
463 |         L_top = tf.concat([tf.zeros([2,3]), source_points], 1)
464 |         L_mid = tf.concat([tf.zeros([1, 2]), tf.ones([1, self.num_control_points+1])], 1)
465 |         L_bot = tf.concat([tf.transpose(source_points), tf.ones([self.num_control_points, 1]), tL], 1)
466 | 
467 |         L = tf.concat([L_top, L_mid, L_bot], 0)
468 |         L_inv = tf.linalg.inv(L)
469 | 
470 |         # Construct right mat
471 |         to_transform = tf.expand_dims(pixel_grid, 2)
472 |         stacked_source_points = tf.expand_dims(source_points, 1)
473 |         distances = ElasticTransformer.U_func(to_transform, stacked_source_points)
474 | 
475 |         # Add in the coefficients for the affine translation (1, x, and y,
476 |         # corresponding to a_1, a_x, and a_y)
477 |         ones = tf.ones(shape=[1, self.num_pixels])
478 |         pixel_distances = tf.concat([ones, distances], 0)
479 |         L_inv = tf.transpose(L_inv[:,3:])
480 | 
481 |         return pixel_distances, L_inv
482 | 
483 | 
484 | 
485 | """
486 | Common Functions
487 | 
488 | """
489 | 
490 | def _meshgrid3d(out_size):
491 |     """
492 |     the regular grid of coordinates to sample the values after the transformation
493 | 
494 |     """
495 |     #with tf.variable_scope('meshgrid3d'):
496 | 
497 |     # This should be equivalent to:
498 |     #  x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
499 |     #                         np.linspace(-1, 1, height))
500 |     #  ones = np.ones(np.prod(x_t.shape))
501 |     #  grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
502 | 
503 |     #z_t, y_t, x_t = tf.meshgrid(tf.linspace(0., out_size[0]-1., out_size[0]),
504 |     #                            tf.linspace(0., out_size[1]-1., out_size[1]),
505 |     #                            tf.linspace(0., out_size[2]-1., out_size[2]), indexing='ij')
506 | 
507 |     z_t, y_t, x_t = tf.meshgrid(tf.linspace(-1., 1., out_size[0]),
508 |                                 tf.linspace(-1., 1., out_size[1]),
509 |                                 
tf.linspace(-1., 1., out_size[2]), indexing='ij') 510 | 511 | x_t_flat = tf.reshape(x_t, (1, -1)) 512 | y_t_flat = tf.reshape(y_t, (1, -1)) 513 | z_t_flat = tf.reshape(z_t, (1, -1)) 514 | 515 | ones = tf.ones_like(x_t_flat) 516 | grid = tf.concat([x_t_flat, y_t_flat, z_t_flat, ones], 0) 517 | grid = tf.reshape(grid, [-1]) 518 | return grid 519 | 520 | def _meshgrid(out_size): 521 | """ 522 | the regular grid of coordinates to sample the values after the transformation 523 | 524 | """ 525 | #with tf.variable_scope('meshgrid'): 526 | 527 | # This should be equivalent to: 528 | # x_t, y_t = np.meshgrid(np.linspace(-1, 1, width), 529 | # np.linspace(-1, 1, height)) 530 | # ones = np.ones(np.prod(x_t.shape)) 531 | # grid = np.vstack([x_t.flatten(), y_t.flatten(), ones]) 532 | 533 | x_t, y_t = tf.meshgrid(tf.linspace(-1.0, 1.0, out_size[1]), 534 | tf.linspace(-1.0, 1.0, out_size[0])) 535 | x_t_flat = tf.reshape(x_t, (1, -1)) 536 | y_t_flat = tf.reshape(y_t, (1, -1)) 537 | 538 | 539 | grid = tf.concat([x_t_flat, y_t_flat, tf.ones_like(x_t_flat)], 0) 540 | grid = tf.reshape(grid, [-1]) 541 | 542 | return grid 543 | 544 | 545 | def _repeat(x, n_repeats): 546 | #with tf.variable_scope('_repeat'): 547 | rep = tf.tile(tf.expand_dims(x,1), [1, n_repeats]) 548 | return tf.reshape(rep, [-1]) 549 | 550 | def _interpolate(im, x, y, out_size, method): 551 | if method=='bilinear': 552 | return bilinear_interp(im, x, y, out_size) 553 | if method=='bicubic': 554 | return bicubic_interp(im, x, y, out_size) 555 | if method=='nearest': 556 | return Nearest_interp(im, x, y, out_size) 557 | return None 558 | 559 | def _interpolate3d(vol, x, y, z, out_size, method='bilinear'): 560 | return bilinear_interp3d(vol, x, y, z, out_size) 561 | 562 | def bilinear_interp3d(vol, x, y, z, out_size, edge_size=1): 563 | #with tf.variable_scope('bilinear_interp3d'): 564 | batch_size, depth, height, width, channels = vol.get_shape().as_list() 565 | 566 | if edge_size>0: 567 | vol = tf.pad(vol, [[0,0], [edge_size,edge_size], [edge_size,edge_size], [edge_size,edge_size], [0,0]], mode='CONSTANT') 568 | 569 | x = tf.cast(x, tf.float32) 570 | y = tf.cast(y, tf.float32) 571 | z = tf.cast(z, tf.float32) 572 | 573 | depth_f = tf.cast(depth, tf.float32) 574 | height_f = tf.cast(height, tf.float32) 575 | width_f = tf.cast(width, tf.float32) 576 | 577 | out_depth = out_size[0] 578 | out_height = out_size[1] 579 | out_width = out_size[2] 580 | 581 | # scale indices to [0, width/height/depth - 1] 582 | x = (x + 1.) / 2. * (width_f -1.) 583 | y = (y + 1.) / 2. * (height_f -1.) 584 | z = (z + 1.) / 2. * (depth_f -1.) 585 | 586 | # clip to to [0, width/height/depth - 1] +- edge_size 587 | x = tf.clip_by_value(x, -edge_size, width_f -1. + edge_size) 588 | y = tf.clip_by_value(y, -edge_size, height_f -1. + edge_size) 589 | z = tf.clip_by_value(z, -edge_size, depth_f -1. + edge_size) 590 | 591 | x += edge_size 592 | y += edge_size 593 | z += edge_size 594 | 595 | # do sampling 596 | x0_f = tf.floor(x) 597 | y0_f = tf.floor(y) 598 | z0_f = tf.floor(z) 599 | x1_f = x0_f + 1 600 | y1_f = y0_f + 1 601 | z1_f = z0_f + 1 602 | 603 | x0 = tf.cast(x0_f, tf.int32) 604 | y0 = tf.cast(y0_f, tf.int32) 605 | z0 = tf.cast(z0_f, tf.int32) 606 | 607 | x1 = tf.cast(tf.minimum(x1_f, width_f - 1. + 2*edge_size), tf.int32) 608 | y1 = tf.cast(tf.minimum(y1_f, height_f - 1. + 2*edge_size), tf.int32) 609 | z1 = tf.cast(tf.minimum(z1_f, depth_f - 1. 
+ 2*edge_size), tf.int32) 610 | 611 | dim3 = (width + 2*edge_size) 612 | dim2 = (width + 2*edge_size)*(height + 2*edge_size) 613 | dim1 = (width + 2*edge_size)*(height + 2*edge_size)*(depth + 2*edge_size) 614 | 615 | base = _repeat(tf.range(batch_size)*dim1, out_depth*out_height*out_width) 616 | base_z0 = base + z0*dim2 617 | base_z1 = base + z1*dim2 618 | 619 | base_y00 = base_z0 + y0*dim3 620 | base_y01 = base_z0 + y1*dim3 621 | base_y10 = base_z1 + y0*dim3 622 | base_y11 = base_z1 + y1*dim3 623 | 624 | idx_000 = base_y00 + x0 625 | idx_001 = base_y00 + x1 626 | idx_010 = base_y01 + x0 627 | idx_011 = base_y01 + x1 628 | idx_100 = base_y10 + x0 629 | idx_101 = base_y10 + x1 630 | idx_110 = base_y11 + x0 631 | idx_111 = base_y11 + x1 632 | 633 | # use indices to lookup pixels in the flat image and restore 634 | # channels dim 635 | vol_flat = tf.reshape(vol, [-1, channels]) 636 | I000 = tf.gather(vol_flat, idx_000) 637 | I001 = tf.gather(vol_flat, idx_001) 638 | I010 = tf.gather(vol_flat, idx_010) 639 | I011 = tf.gather(vol_flat, idx_011) 640 | I100 = tf.gather(vol_flat, idx_100) 641 | I101 = tf.gather(vol_flat, idx_101) 642 | I110 = tf.gather(vol_flat, idx_110) 643 | I111 = tf.gather(vol_flat, idx_111) 644 | 645 | # and finally calculate interpolated values 646 | w000 = tf.expand_dims((z1_f-z)*(y1_f-y)*(x1_f-x),1) 647 | w001 = tf.expand_dims((z1_f-z)*(y1_f-y)*(x-x0_f),1) 648 | w010 = tf.expand_dims((z1_f-z)*(y-y0_f)*(x1_f-x),1) 649 | w011 = tf.expand_dims((z1_f-z)*(y-y0_f)*(x-x0_f),1) 650 | w100 = tf.expand_dims((z-z0_f)*(y1_f-y)*(x1_f-x),1) 651 | w101 = tf.expand_dims((z-z0_f)*(y1_f-y)*(x-x0_f),1) 652 | w110 = tf.expand_dims((z-z0_f)*(y-y0_f)*(x1_f-x),1) 653 | w111 = tf.expand_dims((z-z0_f)*(y-y0_f)*(x-x0_f),1) 654 | 655 | output = tf.add_n([ 656 | w000*I000, 657 | w001*I001, 658 | w010*I010, 659 | w011*I011, 660 | w100*I100, 661 | w101*I101, 662 | w110*I110, 663 | w111*I111]) 664 | return output 665 | 666 | 667 | def bilinear_interp(im, x, y, out_size): 668 | #with tf.variable_scope('bilinear_interp'): 669 | batch_size, height, width, channels = im.get_shape().as_list() 670 | 671 | x = tf.cast(x, tf.float32) 672 | y = tf.cast(y, tf.float32) 673 | height_f = tf.cast(height, tf.float32) 674 | width_f = tf.cast(width, tf.float32) 675 | out_height = out_size[0] 676 | out_width = out_size[1] 677 | 678 | # scale indices from [-1, 1] to [0, width/height - 1] 679 | x = tf.clip_by_value(x, -1, 1) 680 | y = tf.clip_by_value(y, -1, 1) 681 | x = (x + 1.0) / 2.0 * (width_f-1.0) 682 | y = (y + 1.0) / 2.0 * (height_f-1.0) 683 | 684 | # do sampling 685 | x0_f = tf.floor(x) 686 | y0_f = tf.floor(y) 687 | x1_f = x0_f + 1 688 | y1_f = y0_f + 1 689 | 690 | x0 = tf.cast(x0_f, tf.int32) 691 | y0 = tf.cast(y0_f, tf.int32) 692 | x1 = tf.cast(tf.minimum(x1_f, width_f - 1), tf.int32) 693 | y1 = tf.cast(tf.minimum(y1_f, height_f - 1), tf.int32) 694 | 695 | dim2 = width 696 | dim1 = width*height 697 | 698 | base = _repeat(tf.range(batch_size)*dim1, out_height*out_width) 699 | 700 | base_y0 = base + y0*dim2 701 | base_y1 = base + y1*dim2 702 | 703 | idx_00 = base_y0 + x0 704 | idx_01 = base_y0 + x1 705 | idx_10 = base_y1 + x0 706 | idx_11 = base_y1 + x1 707 | 708 | # use indices to lookup pixels in the flat image and restore 709 | # channels dim 710 | im_flat = tf.reshape(im, [-1, channels]) 711 | 712 | I00 = tf.gather(im_flat, idx_00) 713 | I01 = tf.gather(im_flat, idx_01) 714 | I10 = tf.gather(im_flat, idx_10) 715 | I11 = tf.gather(im_flat, idx_11) 716 | 717 | # and finally calculate interpolated values 718 | 
w00 = tf.expand_dims(((x1_f-x) * (y1_f-y)), 1) 719 | w01 = tf.expand_dims(((x-x0_f) * (y1_f-y)), 1) 720 | w10 = tf.expand_dims(((x1_f-x) * (y-y0_f)), 1) 721 | w11 = tf.expand_dims(((x-x0_f) * (y-y0_f)), 1) 722 | 723 | output = tf.add_n([w00*I00, w01*I01, w10*I10, w11*I11]) 724 | return output 725 | 726 | 727 | def bicubic_interp(im, x, y, out_size): 728 | alpha = -0.75 # same as in tf.image.resize_images, see: 729 | # tensorflow/tensorflow/core/kernels/resize_bicubic_op.cc 730 | bicubic_coeffs = ( 731 | (1, 0, -(alpha+3), (alpha+2)), 732 | (0, alpha, -2*alpha, alpha ), 733 | (0, -alpha, 2*alpha+3, -alpha-2 ), 734 | (0, 0, alpha, -alpha ) 735 | ) 736 | 737 | #with tf.variable_scope('bicubic_interp'): 738 | batch_size, height, width, channels = im.get_shape().as_list() 739 | 740 | x = tf.cast(x, tf.float32) 741 | y = tf.cast(y, tf.float32) 742 | height_f = tf.cast(height, tf.float32) 743 | width_f = tf.cast(width, tf.float32) 744 | out_height = out_size[0] 745 | out_width = out_size[1] 746 | 747 | # scale indices from [-1, 1] to [0, width/height - 1] 748 | x = tf.clip_by_value(x, -1, 1) 749 | y = tf.clip_by_value(y, -1, 1) 750 | x = (x + 1.0) / 2.0 * (width_f-1.0) 751 | y = (y + 1.0) / 2.0 * (height_f-1.0) 752 | 753 | # do sampling 754 | # integer coordinates of 4x4 neighbourhood around (x0_f, y0_f) 755 | x0_f = tf.floor(x) 756 | y0_f = tf.floor(y) 757 | xm1_f = x0_f - 1 758 | ym1_f = y0_f - 1 759 | xp1_f = x0_f + 1 760 | yp1_f = y0_f + 1 761 | xp2_f = x0_f + 2 762 | yp2_f = y0_f + 2 763 | 764 | # clipped integer coordinates 765 | xs = [0, 0, 0, 0] 766 | ys = [0, 0, 0, 0] 767 | xs[0] = tf.cast(x0_f, tf.int32) 768 | ys[0] = tf.cast(y0_f, tf.int32) 769 | xs[1] = tf.cast(tf.maximum(xm1_f, 0), tf.int32) 770 | ys[1] = tf.cast(tf.maximum(ym1_f, 0), tf.int32) 771 | xs[2] = tf.cast(tf.minimum(xp1_f, width_f - 1), tf.int32) 772 | ys[2] = tf.cast(tf.minimum(yp1_f, height_f - 1), tf.int32) 773 | xs[3] = tf.cast(tf.minimum(xp2_f, width_f - 1), tf.int32) 774 | ys[3] = tf.cast(tf.minimum(yp2_f, height_f - 1), tf.int32) 775 | 776 | # indices of neighbours for the batch 777 | dim2 = width 778 | dim1 = width*height 779 | base = _repeat(tf.range(batch_size)*dim1, out_height*out_width) 780 | 781 | idx = [] 782 | for i in range(4): 783 | idx.append([]) 784 | for j in range(4): 785 | cur_idx = base + ys[i]*dim2 + xs[j] 786 | idx[i].append(cur_idx) 787 | 788 | # use indices to lookup pixels in the flat image and restore 789 | # channels dim 790 | im_flat = tf.reshape(im, [-1, channels]) 791 | 792 | Is = [] 793 | for i in range(4): 794 | Is.append([]) 795 | for j in range(4): 796 | Is[i].append(tf.gather(im_flat, idx[i][j])) 797 | 798 | def get_weights(x, x0_f): 799 | tx = (x-x0_f) 800 | tx2 = tx * tx 801 | tx3 = tx2 * tx 802 | t = [1, tx, tx2, tx3] 803 | weights = [] 804 | for i in range(4): 805 | result = 0 806 | for j in range(4): 807 | result = result + bicubic_coeffs[i][j]*t[j] 808 | result = tf.reshape(result, [-1, 1]) 809 | weights.append(result) 810 | return weights 811 | 812 | 813 | # to calculate interpolated values first, 814 | # interpolate in x dim 4 times for y=[0, -1, 1, 2] 815 | weights = get_weights(x, x0_f) 816 | x_interp = [] 817 | for i in range(4): 818 | result = [] 819 | for j in range(4): 820 | result = result + [weights[j]*Is[i][j]] 821 | x_interp.append(tf.add_n(result)) 822 | 823 | # finally, interpolate in y dim using interpolations in x dim 824 | weights = get_weights(y, y0_f) 825 | y_interp = [] 826 | for i in range(4): 827 | y_interp = y_interp + [weights[i]*x_interp[i]] 828 | 829 | 
output = tf.add_n(y_interp) 830 | return output 831 | 832 | 833 | def Nearest_interp(im, x, y, out_size): 834 | #with tf.variable_scope('bilinear_interp'): 835 | batch_size, height, width, channels = im.get_shape().as_list() 836 | 837 | x = tf.cast(x, tf.float32) 838 | y = tf.cast(y, tf.float32) 839 | height_f = tf.cast(height, tf.float32) 840 | width_f = tf.cast(width, tf.float32) 841 | out_height = out_size[0] 842 | out_width = out_size[1] 843 | 844 | # scale indices from [-1, 1] to [0, width/height - 1] 845 | x = tf.clip_by_value(x, -1, 1) 846 | y = tf.clip_by_value(y, -1, 1) 847 | x = (x + 1.0) / 2.0 * (width_f-1.0) 848 | y = (y + 1.0) / 2.0 * (height_f-1.0) 849 | 850 | # do sampling 851 | x0_f = tf.floor(x) 852 | y0_f = tf.floor(y) 853 | 854 | 855 | x0 = tf.cast(x0_f, tf.int32) 856 | y0 = tf.cast(y0_f, tf.int32) 857 | 858 | dim2 = width 859 | dim1 = width*height 860 | 861 | base = _repeat(tf.range(batch_size)*dim1, out_height*out_width) 862 | 863 | base_y0 = base + y0*dim2 864 | 865 | idx = base_y0 + x0 866 | 867 | # use indices to lookup pixels in the flat image and restore 868 | # channels dim 869 | im_flat = tf.reshape(im, [-1, channels]) 870 | 871 | I00 = tf.gather(im_flat, idx) 872 | 873 | output = I00 874 | return output 875 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from xml.etree import ElementTree 3 | import cv2 4 | import tensorflow as tf 5 | import tensorflow_addons as tfa 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from tensorflow.keras.callbacks import ModelCheckpoint 9 | from sklearn.model_selection import train_test_split 10 | from spatial_transformer import ProjectiveTransformer, AffineTransformer 11 | #from tensorflow.keras.applications.xception import preprocess_input, Xception 12 | from tensorflow.keras.applications.convnext import ConvNeXtTiny, preprocess_input 13 | from tensorflow.keras.utils import to_categorical 14 | from tensorflow.keras.models import Model, load_model, Sequential 15 | from tensorflow.keras.layers import * 16 | from scipy.interpolate import interp1d 17 | import random 18 | import scipy.io 19 | import math 20 | from scipy import ndimage 21 | from sklearn.utils import shuffle 22 | 23 | 24 | val_path = 'E:/LSUN2016_surface_relabel/surface_relabel/val/' 25 | 26 | img = cv2.imread(val_path+'sun_atssgbmizunolhzn.jpg') 27 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 28 | img = cv2.resize(img, (400 , 400)) 29 | img1 = np.array(img,copy=True) 30 | img = img[tf.newaxis,...] 31 | img = preprocess_input(img) 32 | 33 | seg = cv2.imread(val_path+'sun_atssgbmizunolhzn.png',0) 34 | seg = cv2.resize(seg, (400 ,400), interpolation= cv2.INTER_NEAREST) 35 | seg = seg/51.0 36 | 37 | ref_img = tf.io.read_file('ref_img2.png') 38 | ref_img = tf.io.decode_png(ref_img) 39 | ref_img = tf.cast(ref_img, tf.float32) / 51.0 40 | ref_img = ref_img[tf.newaxis,...] 
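# ref_img2.png stores the canonical room-layout template; dividing its gray
# levels by 51 maps them to the five planar surface labels (1-5) that the
# projective STN warps to match the input view.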
41 | #ref_img = tf.tile(ref_img, [1,1,1,1]) 42 | print(ref_img.shape) 43 | 44 | 45 | base_model = ConvNeXtTiny(include_top=False, weights="imagenet", input_shape= (400,400,3), pooling = 'avg') 46 | theta = Dense(8)(base_model.output) 47 | stl= ProjectiveTransformer((400,400)).transform(ref_img, theta) 48 | model = Model(base_model.input, stl) 49 | 50 | 51 | model.summary() 52 | model.load_weights('') 53 | 54 | 55 | out= model.predict(img) 56 | 57 | out = np.rint(out[0,:,:,0]) 58 | 59 | 60 | plt.figure('seg') 61 | plt.imshow(out, vmin = 1, vmax= 5) 62 | plt.figure('gt') 63 | plt.imshow(seg , vmin = 1, vmax= 5) 64 | plt.figure('img') 65 | plt.imshow(img1) 66 | plt.show() 67 | 68 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from xml.etree import ElementTree 3 | import cv2 4 | import tensorflow as tf 5 | import tensorflow_addons as tfa 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from tensorflow.keras.callbacks import ModelCheckpoint 9 | from sklearn.model_selection import train_test_split 10 | from spatial_transformer import ProjectiveTransformer, AffineTransformer 11 | #from tensorflow.keras.applications.xception import preprocess_input, Xception 12 | from tensorflow.keras.applications.convnext import ConvNeXtTiny, preprocess_input 13 | from tensorflow.keras.utils import to_categorical 14 | from tensorflow.keras.models import Model, load_model, Sequential 15 | from tensorflow.keras.layers import * 16 | import random 17 | import scipy.io 18 | from sklearn.utils import shuffle 19 | 20 | train_path = 'E:/LSUN2016_surface_relabel/surface_relabel/train/' 21 | val_path = 'E:/LSUN2016_surface_relabel/surface_relabel/val/' 22 | 23 | train_data = scipy.io.loadmat('training.mat') 24 | train_data = train_data.get('training')[0] 25 | val_data = scipy.io.loadmat('validation.mat') 26 | val_data = val_data.get('validation')[0] 27 | 28 | print(train_data.shape) 29 | 30 | train_data = shuffle(train_data) 31 | batch_size = 48 32 | 33 | 34 | def data_generator(data, path = train_path, batch_size=32, number_of_batches=None): 35 | counter = 0 36 | n_classes = 11 37 | 38 | #training parameters 39 | train_w , train_h = 400, 400 40 | while True: 41 | idx_start = batch_size * counter 42 | idx_end = batch_size * (counter + 1) 43 | x_batch = [] 44 | y_seg_batch = [] 45 | y_batch = [] 46 | for file in data[idx_start:idx_end]: 47 | img_name = list(file)[0][0] 48 | img = cv2.imread(path+img_name+'.jpg') 49 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 50 | img = cv2.resize(img, (train_w , train_h)) 51 | seg = cv2.imread(path+img_name+'.png',0) 52 | seg = cv2.resize(seg, (train_w , train_h), interpolation= cv2.INTER_NEAREST) 53 | seg = seg / 51.0 54 | label = list(file)[2][0][0] 55 | #print(label) 56 | if label in [3,4,5,10]: 57 | seg[seg==2] = 1 58 | rand_hflip = random.randint(0, 1) 59 | rand_brightness = random.randint(0, 1) 60 | 61 | if (rand_hflip == 1 and path == train_path): 62 | img = cv2.flip(img, 1) 63 | seg = cv2.flip(seg, 1) 64 | seg_temp = np.array(seg, copy= True) 65 | #flip the right and left walls 66 | seg[seg_temp==2] = 3 67 | seg[seg_temp==3] = 2 68 | if (rand_brightness == 1 and path == train_path): 69 | val = random.uniform(-1, 1) * 30 70 | img = img + val 71 | img = np.clip(img, 0, 255) 72 | img = np.uint8(img) 73 | 74 | img = preprocess_input(img) 75 | x_batch.append(img) 76 | y_seg_batch.append(seg) 77 | counter += 1 78 | x_train = 
np.array(x_batch)
79 |             y_seg_train = np.array(y_seg_batch)
80 |             yield x_train, y_seg_train
81 |             if (counter == number_of_batches):
82 |                 counter = 0
83 | 
84 | ref_img = tf.io.read_file('ref_img2.png')
85 | ref_img = tf.io.decode_png(ref_img)
86 | ref_img = tf.cast(ref_img, tf.float32) / 51.0
87 | ref_img = ref_img[tf.newaxis,...]
88 | ref_img = tf.tile(ref_img, [batch_size,1,1,1])
89 | print(ref_img.shape)
90 | 
91 | w = np.zeros((768, 8), dtype='float32')  # zero kernel: theta initially comes from the bias alone (768 = ConvNeXtTiny pooled feature width)
92 | 
93 | b = np.zeros(8, dtype='float32')
94 | b[0] = 1
95 | b[4] = 1  # bias = identity homography [1,0,0, 0,1,0, 0,0]
96 | 
97 | base_model = ConvNeXtTiny(include_top=False, weights="imagenet", input_shape= (400,400,3), pooling = 'avg')
98 | theta = Dense(8, weights=[w, b])(base_model.output)
99 | stl = ProjectiveTransformer((400,400)).transform(ref_img, theta)
100 | model = Model(base_model.input, stl)
101 | 
102 | model.summary()
103 | model.compile(optimizer = tfa.optimizers.AdamW(learning_rate=0.001, weight_decay = 0.0001), loss = tf.keras.losses.Huber(), metrics = ['accuracy'])
104 | 
105 | model.summary()
106 | 
107 | filepath="E:/LSUN2016_surface_relabel/surface_relabel/weights_stn/weights-improvement-{epoch:02d}-{loss:.4f}-{val_loss:.4f}.h5"
108 | checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=True, mode='auto', save_freq='epoch')
109 | callbacks_list = [checkpoint]
110 | n_train = 4000
111 | n_valid = 394
112 | model.fit(data_generator(train_data, train_path, batch_size, number_of_batches= n_train // batch_size),
113 |           steps_per_epoch=max(1, n_train//batch_size), initial_epoch = 0,
114 |           validation_data= data_generator(val_data, val_path, batch_size, number_of_batches= n_valid // batch_size),
115 |           validation_steps=max(1, n_valid//batch_size),
116 |           epochs=150,
117 |           callbacks=callbacks_list)
118 | 
--------------------------------------------------------------------------------
/weight file.txt:
--------------------------------------------------------------------------------
1 | Download from the following link: https://drive.google.com/file/d/1KWHofYVsjb0bRi6sVuZtPpanZWCGyWSr/view?usp=sharing
2 | 
--------------------------------------------------------------------------------
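
Note: the snippet below is a minimal usage sketch, not a file from the repo. It warps a dummy reference tensor with an identity homography to show the ProjectiveTransformer call convention; in practice ref_img2.png and a trained theta replace the placeholders.

import tensorflow as tf
from spatial_transformer import ProjectiveTransformer

# Placeholder single-channel reference image (stands in for ref_img2.png)
ref = tf.ones([1, 400, 400, 1], tf.float32)

# First eight homography entries in row-major order; the ninth is fixed to 1
# internally. These values form the identity, so the output equals the input.
theta = tf.constant([[1., 0., 0.,
                      0., 1., 0.,
                      0., 0.]])

warped = ProjectiveTransformer((400, 400)).transform(ref, theta)
print(warped.shape)  # (1, 400, 400, 1)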