├── README.md
├── ref_img2.png
├── spatial_transformer.py
├── test.py
├── train.py
└── weight file.txt
/README.md:
--------------------------------------------------------------------------------
1 | # ST-RoomNet: Learning Room Layout Estimation From Single Image Through Unsupervised Spatial Transformations
2 |
3 | This is the official implementation of ST-RoomNet: https://openaccess.thecvf.com/content/CVPR2023W/VOCVALC/html/Ibrahem_ST-RoomNet_Learning_Room_Layout_Estimation_From_Single_Image_Through_Unsupervised_CVPRW_2023_paper.html
4 |
5 | The spatial transformer module is based on this repo: https://github.com/dantkz/spatial-transformer-tensorflow
6 |
7 | We modified the spatial transformer module to work with TensorFlow 2.x and added more features, such as nearest-neighbor interpolation in addition to the original bilinear and bicubic interpolation.
8 |
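A minimal usage sketch of the modified module (illustrative only; the shapes and parameter values below are assumptions, not taken from the paper):

```python
import tensorflow as tf
from spatial_transformer import ProjectiveTransformer

# image to warp: [batch, height, width, channels]
ref = tf.random.uniform((1, 400, 400, 1))
# 8 projective parameters; [1, 0, 0, 0, 1, 0, 0, 0] is the identity
# homography (the ninth entry of the 3x3 matrix is fixed to 1 internally)
theta = tf.constant([[1., 0., 0., 0., 1., 0., 0., 0.]])

stl = ProjectiveTransformer((400, 400), interp_method='bilinear')  # or 'bicubic' / 'nearest'
out = stl.transform(ref, theta)  # -> (1, 400, 400, 1)
```
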
9 | Requirements:
10 |
11 | opencv 4.4.1
12 |
13 | tensorflow 2.9.1 (train.py and test.py additionally import tensorflow_addons, scipy, scikit-learn, and matplotlib)
14 |
15 | If you use this code, please cite the paper as follows:
16 |
17 | @InProceedings{Ibrahem_2023_CVPR,
18 | author = {Ibrahem, Hatem and Salem, Ahmed and Kang, Hyun-Soo},
19 | title = {ST-RoomNet: Learning Room Layout Estimation From Single Image Through Unsupervised Spatial Transformations},
20 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
21 | month = {June},
22 | year = {2023},
23 | pages = {3375-3383}
24 | }
25 |
--------------------------------------------------------------------------------
/ref_img2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lab231/ST-RoomNet/14f9492bad6479a0d88aeb688de70bd1d0ac8d21/ref_img2.png
--------------------------------------------------------------------------------
/spatial_transformer.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of Spatial Transformer Networks
3 |
4 | References
5 | ----------
6 | [1] Spatial Transformer Networks
7 | Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
8 | Submitted on 5 Jun 2015
9 |
10 | [2] https://github.com/tensorflow/models/tree/master/transformer/transformerlayer.py
11 |
12 | [3] https://github.com/daviddao/spatial-transformer-tensorflow
13 |
14 | [4] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
15 |
16 | [5] https://github.com/Lasagne/Lasagne/blob/master/lasagne/layers/special.py
17 |
18 | [6] Fred L. Bookstein (1989):
19 | Principal warps: thin-plate splines and the decomposition of deformations.
20 | IEEE Transactions on Pattern Analysis and Machine Intelligence.
21 | http://doi.org/10.1109/34.24792
22 |
23 | """
24 |
25 | import tensorflow as tf
26 | import math
27 |
28 |
29 | """
30 | Legacy Function
31 |
32 | """
33 | def transformer(inp, theta, out_size, name='SpatialTransformer', **kwargs):
34 | #with tf.variable_scope(name):
35 | stl = AffineTransformer(out_size)
36 |     output = stl.transform(inp, theta)
37 | return output
38 |
39 | class AffineVolumeTransformer(object):
40 | """Spatial Affine Volume Transformer Layer
41 | Implements a spatial transformer layer for volumetric 3D input.
42 | Implemented by Daniyar Turmukhambetov.
43 | """
44 |
45 | def __init__(self, out_size, name='SpatialAffineVolumeTransformer', interp_method='bilinear', **kwargs):
46 | """
47 | Parameters
48 | ----------
49 | out_size : tuple of three ints
50 | The size of the output of the spatial network (depth, height, width), i.e. z, y, x
51 | name : string
52 | The scope name of the variables in this network.
53 |
54 | """
55 | self.name = name
56 | self.out_size = out_size
57 | self.param_dim = 3*4
58 | self.interp_method=interp_method
59 |
60 | #with tf.variable_scope(self.name):
61 | self.voxel_grid = _meshgrid3d(self.out_size)
62 |
63 |
64 | def transform(self, inp, theta):
65 | """
66 | Affine Transformation of input tensor inp with parameters theta
67 |
68 | Parameters
69 | ----------
70 | inp : float
71 | The input tensor should have the shape
72 | [batch_size, depth, height, width, in_channels].
73 | theta: float
74 | The output of the localisation network
75 | should have the shape
76 | [batch_size, 12].
77 | Notes
78 | -----
79 |         To initialize the network to the identity transform, initialize ``theta`` to:
80 | identity = np.array([[1., 0., 0., 0.],
81 | [0., 1., 0., 0.],
82 | [0., 0., 1., 0.]])
83 | identity = identity.flatten()
84 | theta = tf.Variable(initial_value=identity)
85 |
86 | """
87 |
88 | x_s, y_s, z_s = self._transform(inp, theta)
89 |
90 | output = _interpolate3d(
91 | inp, x_s, y_s, z_s,
92 | self.out_size,
93 | method=self.interp_method
94 | )
95 |
96 | batch_size, _, _, _, num_channels = inp.get_shape().as_list()
97 | output = tf.reshape(output, [batch_size, self.out_size[0], self.out_size[1], self.out_size[2], num_channels])
98 |
99 | return output
100 |
101 |
102 |
103 | def _transform(self, inp, theta):
104 |
105 | batch_size, _, _, _, num_channels = inp.get_shape().as_list()
106 |
107 | theta = tf.reshape(theta, (-1, 3, 4))
108 | voxel_grid = tf.tile(self.voxel_grid, [batch_size])
109 | voxel_grid = tf.reshape(voxel_grid, [batch_size, 4, -1])
110 |
111 | # Transform A x (x_t, y_t, z_t, 1)^T -> (x_s, y_s, z_s)
112 | T_g = tf.matmul(theta, voxel_grid)
113 | x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
114 | y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
115 | z_s = tf.slice(T_g, [0, 2, 0], [-1, 1, -1])
116 |
117 | x_s_flat = tf.reshape(x_s, [-1])
118 | y_s_flat = tf.reshape(y_s, [-1])
119 | z_s_flat = tf.reshape(z_s, [-1])
120 | return x_s_flat, y_s_flat, z_s_flat
121 |
122 |
123 | class AffineTransformer(object):
124 | """Spatial Affine Transformer Layer
125 |
126 | Implements a spatial transformer layer as described in [1]_.
127 | Based on [2]_ and [3]_. Edited by Daniyar Turmukhambetov.
128 |
129 | """
130 |
131 | def __init__(self, out_size, name='SpatialAffineTransformer', interp_method='bilinear', **kwargs):
132 | """
133 | Parameters
134 | ----------
135 | out_size : tuple of two ints
136 | The size of the output of the spatial network (height, width).
137 | name : string
138 | The scope name of the variables in this network.
139 |
140 | """
141 | self.name = name
142 | self.out_size = out_size
143 | self.param_dim = 6
144 | self.interp_method=interp_method
145 |
146 | #with tf.variable_scope(self.name):
147 | self.pixel_grid = _meshgrid(self.out_size)
148 |
149 |
150 | def transform(self, inp, theta):
151 | """
152 | Affine Transformation of input tensor inp with parameters theta
153 |
154 | Parameters
155 | ----------
156 | inp : float
157 | The input tensor should have the shape
158 | [batch_size, height, width, num_channels].
159 | theta: float
160 | The output of the localisation network
161 | should have the shape
162 | [batch_size, 6].
163 | Notes
164 | -----
165 |         To initialize the network to the identity transform, initialize ``theta`` to:
166 | identity = np.array([[1., 0., 0.],
167 | [0., 1., 0.]])
168 | identity = identity.flatten()
169 | theta = tf.Variable(initial_value=identity)
170 |
171 | """
172 | #with tf.variable_scope(self.name):
173 | x_s, y_s = self._transform(inp, theta)
174 |
175 | output = _interpolate(
176 | inp, x_s, y_s,
177 | self.out_size,
178 | method=self.interp_method
179 | )
180 |
181 | batch_size, _, _, num_channels = inp.get_shape().as_list()
182 | output = tf.reshape(output, [batch_size, self.out_size[0], self.out_size[1], num_channels])
183 |
184 | return output
185 |
186 |
187 |
188 | def _transform(self, inp, theta):
189 | #with tf.variable_scope(self.name + '_affine_transform'):
190 | batch_size, _, _, num_channels = inp.get_shape().as_list()
191 |
192 | theta = tf.reshape(theta, (-1, 2, 3))
193 | pixel_grid = tf.tile(self.pixel_grid, [batch_size])
194 | pixel_grid = tf.reshape(pixel_grid, [batch_size, 3, -1])
195 |
196 | # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
197 | T_g = tf.matmul(theta, pixel_grid)
198 | x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
199 | y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
200 | x_s_flat = tf.reshape(x_s, [-1])
201 | y_s_flat = tf.reshape(y_s, [-1])
202 | return x_s_flat, y_s_flat
203 |
204 | class ProjectiveTransformer(object):
205 | """Spatial Projective Transformer Layer
206 |
207 | Implements a spatial transformer layer as described in [1]_.
208 | Based on [2]_ and [3]_. Edited by Daniyar Turmukhambetov.
209 |
210 | """
211 |
212 |     def __init__(self, out_size, name='SpatialProjectiveTransformer', interp_method='bilinear', **kwargs): # interp_method='nearest' should only be used at inference time
213 | """
214 | Parameters
215 | ----------
216 | out_size : tuple of two ints
217 | The size of the output of the spatial network (height, width).
218 | name : string
219 | The scope name of the variables in this network.
220 |
221 | """
222 | self.name = name
223 | self.out_size = out_size
224 | self.param_dim = 8
225 | self.interp_method=interp_method
226 |
227 | #with tf.variable_scope(self.name):
228 | self.pixel_grid = _meshgrid(self.out_size)
229 |
230 |
231 | def transform(self, inp, theta):
232 | """
233 | Projective Transformation of input tensor inp with parameters theta
234 |
235 | Parameters
236 | ----------
237 | inp : float
238 | The input tensor should have the shape
239 | [batch_size, height, width, num_channels].
240 | theta: float
241 | The output of the localisation network
242 | should have the shape
243 | [batch_size, 8].
244 | Notes
245 | -----
246 |         To initialize the network to the identity transform, initialize ``theta`` to:
247 |             identity = np.array([1., 0., 0.,
248 |                                  0., 1., 0.,
249 |                                  0., 0.])
250 | theta = tf.Variable(initial_value=identity)
251 |
252 | """
253 | #with tf.variable_scope(self.name):
254 | x_s, y_s = self._transform(inp, theta)
255 |
256 | output = _interpolate(
257 | inp, x_s, y_s,
258 | self.out_size,
259 | method=self.interp_method
260 | )
261 |
262 | batch_size, _, _, num_channels = inp.get_shape().as_list()
263 | output = tf.reshape(output, [batch_size, self.out_size[0], self.out_size[1], num_channels])
264 |
265 | return output
266 |
267 |
268 |
269 | def _transform(self, inp, theta):
270 | #with tf.variable_scope(self.name + '_projective_transform'):
271 | batch_size, _, _, num_channels = inp.get_shape().as_list()
272 |
273 | theta = tf.reshape(theta, (batch_size, 8))
274 | theta = tf.concat([theta, tf.ones([batch_size, 1])], 1)
275 | theta = tf.reshape(theta, (batch_size, 3, 3))
276 |
277 | pixel_grid = tf.tile(self.pixel_grid, [batch_size])
278 | pixel_grid = tf.reshape(pixel_grid, [batch_size, 3, -1])
279 | # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
280 | T_g = tf.matmul(theta, pixel_grid)
281 |
282 | x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
283 | y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
284 | z_s = tf.slice(T_g, [0, 2, 0], [-1, 1, -1])
285 |
286 |
287 |         z_s += 1e-7  # avoid division by zero
288 |
289 | x_s = x_s/z_s
290 | y_s = y_s/z_s
291 |
292 | x_s_flat = tf.reshape(x_s, [-1])
293 | y_s_flat = tf.reshape(y_s, [-1])
294 |
295 | return x_s_flat, y_s_flat
296 |
297 | class ElasticTransformer(object):
298 | """Spatial Elastic Transformer Layer with Thin Plate Spline deformations
299 |
300 | Implements a spatial transformer layer as described in [1]_.
301 | Based on [4]_ and [5]_. Edited by Daniyar Turmukhambetov.
302 |
303 | """
304 |
305 | def __init__(self, out_size, param_dim=2*16, name='SpatialElasticTransformer', interp_method='bilinear', **kwargs):
306 | """
307 | Parameters
308 | ----------
309 | out_size : tuple of two ints
310 | The size of the output of the spatial network (height, width).
311 | param_dim: int
312 | The 2 x number of control points that define
313 | Thin Plate Splines deformation field.
314 | number of control points *MUST* be a square of an integer.
315 | 2 x 16 by default.
316 | name : string
317 | The scope name of the variables in this network.
318 |
319 | """
320 | num_control_points = int(param_dim/2)
321 |         assert param_dim == 2*num_control_points, 'param_dim must be even (2 x the number of control points).'
322 |
323 | self.name = name
324 | self.param_dim = param_dim
325 | self.interp_method=interp_method
326 | self.num_control_points = num_control_points
327 | self.out_size = out_size
328 |
329 | self.grid_size = math.floor(math.sqrt(self.num_control_points))
330 | assert self.grid_size*self.grid_size == self.num_control_points, 'num_control_points must be a square of an int'
331 |
332 | #with tf.variable_scope(self.name):
333 | # Create source grid
334 | self.source_points = ElasticTransformer.get_meshgrid(self.grid_size, self.grid_size)
335 | # Construct pixel grid
336 | self.pixel_grid = ElasticTransformer.get_meshgrid(self.out_size[1], self.out_size[0])
337 | self.num_pixels = self.out_size[0]*self.out_size[1]
338 | self.pixel_distances, self.L_inv = self._initialize_tps(self.source_points, self.pixel_grid)
339 |
340 |
341 | def transform(self, inp, theta, forward=True, **kwargs):
342 | """
343 | Parameters
344 | ----------
345 | inp : float
346 | The input tensor should have the shape
347 | [batch_size, height, width, num_channels].
348 | theta: float
349 | Should have the shape of [batch_size, self.num_control_points x 2]
350 | Theta is the output of the localisation network, so it is
351 | the x and y offsets of the destination coordinates
352 | of each of the control points.
353 | Notes
354 | -----
355 | To initialize the network to the identity transform initialize ``theta`` to zeros:
356 | identity = np.zeros(16*2)
357 | identity = identity.flatten()
358 | theta = tf.Variable(initial_value=identity)
359 |
360 | """
361 | #with tf.variable_scope(self.name):
362 | # reshape destination offsets to be (batch_size, 2, num_control_points)
363 | # and add to source_points
364 | source_points = tf.expand_dims(self.source_points, 0)
365 | theta = source_points + tf.reshape(theta, [-1, 2, self.num_control_points])
366 |
367 | x_s, y_s = self._transform(
368 | inp, theta, self.num_control_points,
369 | self.pixel_grid, self.num_pixels,
370 | self.pixel_distances, self.L_inv,
371 | self.name + '_elastic_transform', forward)
372 | if forward:
373 | output = _interpolate(
374 | inp, x_s, y_s,
375 | self.out_size,
376 | method=self.interp_method
377 | )
378 | else:
379 | rx_s, ry_s = self._transform(
380 | inp, theta, self.num_control_points,
381 | self.pixel_grid, self.num_pixels,
382 | self.pixel_distances, self.L_inv,
383 | self.name + '_elastic_transform', forward)
384 | output = _interpolate(
385 | inp, rx_s, ry_s,
386 | self.out_size,
387 | method=self.interp_method
388 | )
389 |             # NOTE: _transform ignores the `forward` flag, so this branch currently samples the same grid as the forward path.
390 |
391 |
392 | batch_size, _, _, num_channels = inp.get_shape().as_list()
393 | output = tf.reshape(output, [batch_size, self.out_size[0], self.out_size[1], num_channels])
394 | return output
395 |
396 | def _transform(self, inp, theta, num_control_points, pixel_grid, num_pixels, pixel_distances, L_inv, name, forward=True):
397 | #with tf.variable_scope(name):
398 | batch_size = inp.get_shape().as_list()[0]
399 |
400 | # Solve as in ref [2]
401 | theta = tf.reshape(theta, [-1, num_control_points])
402 | coefficients = tf.matmul(theta, L_inv)
403 | coefficients = tf.reshape(coefficients, [-1, 2, num_control_points+3])
404 |
405 | # Transform each point on the target grid (out_size)
406 | right_mat = tf.concat([pixel_grid, pixel_distances], 0)
407 | right_mat = tf.tile(tf.expand_dims(right_mat, 0), (batch_size, 1, 1))
408 | transformed_points = tf.matmul(coefficients, right_mat)
409 | transformed_points = tf.reshape(transformed_points, [-1, 2, num_pixels])
410 |
411 | x_s_flat = tf.reshape(transformed_points[:,0,:], [-1])
412 | y_s_flat = tf.reshape(transformed_points[:,1,:], [-1])
413 |
414 | return x_s_flat, y_s_flat
415 |
416 | # U function for the new point and each source point
417 | @staticmethod
418 | def U_func(points1, points2):
419 | # The U function is simply U(r) = r^2 * log(r^2), as in ref [5]_,
420 | # where r is the euclidean distance
421 | r_sq = tf.transpose(tf.reduce_sum(tf.square(points1 - points2), axis=0))
422 |         log_r = tf.math.log(r_sq)  # tf.log was removed in TF2
423 |         log_r = tf.where(tf.math.is_inf(log_r), tf.zeros_like(log_r), log_r)
424 | phi = r_sq * log_r
425 |
426 | # The U function is simply U(r) = r, where r is the euclidean distance
427 | #phi = tf.sqrt(r_sq)
428 | return phi
429 |
430 |
431 | @staticmethod
432 | def get_meshgrid(grid_size_x, grid_size_y):
433 | # Create 2 x num_points array of source points
434 | x_points, y_points = tf.meshgrid(
435 | tf.linspace(-1.0, 1.0, int(grid_size_x)),
436 | tf.linspace(-1.0, 1.0, int(grid_size_y)))
437 | x_flat = tf.reshape(x_points, (1,-1))
438 | y_flat = tf.reshape(y_points, (1,-1))
439 | points = tf.concat([x_flat, y_flat], 0)
440 | return points
441 |
442 |
443 | def _initialize_tps(self, source_points, pixel_grid):
444 | """
445 | Initializes the thin plate spline calculation by creating the source
446 | point array and the inverted L matrix used for calculating the
447 | transformations as in ref [5]_
448 |
449 | Returns
450 | ----------
451 |         pixel_distances : float
452 |             Tensor of shape [num_control_points + 1, out_height*out_width]:
453 |             a row of ones stacked on the U distances from each pixel to each control point.
454 |         L_inv : float
455 |             Tensor of shape [num_control_points, num_control_points + 3],
456 |             a transposed slice of the inverted TPS system matrix.
457 |
458 | """
459 |
460 | tL = ElasticTransformer.U_func(tf.expand_dims(source_points, 2), tf.expand_dims(source_points, 1))
461 |
462 | # Initialize L
463 | L_top = tf.concat([tf.zeros([2,3]), source_points], 1)
464 | L_mid = tf.concat([tf.zeros([1, 2]), tf.ones([1, self.num_control_points+1])], 1)
465 | L_bot = tf.concat([tf.transpose(source_points), tf.ones([self.num_control_points, 1]), tL], 1)
466 |
467 | L = tf.concat([L_top, L_mid, L_bot], 0)
468 |         L_inv = tf.linalg.inv(L)  # tf.matrix_inverse was removed in TF2
469 |
470 | # Construct right mat
471 | to_transform = tf.expand_dims(pixel_grid, 2)
472 | stacked_source_points = tf.expand_dims(source_points, 1)
473 | distances = ElasticTransformer.U_func(to_transform, stacked_source_points)
474 |
475 | # Add in the coefficients for the affine translation (1, x, and y,
476 | # corresponding to a_1, a_x, and a_y)
477 | ones = tf.ones(shape=[1, self.num_pixels])
478 | pixel_distances = tf.concat([ones, distances], 0)
479 | L_inv = tf.transpose(L_inv[:,3:])
480 |
481 | return pixel_distances, L_inv
482 |
483 |
484 |
485 | """
486 | Common Functions
487 |
488 | """
489 |
490 | def _meshgrid3d(out_size):
491 | """
492 | the regular grid of coordinates to sample the values after the transformation
493 |
494 | """
495 | #with tf.variable_scope('meshgrid3d'):
496 |
497 | # This should be equivalent to:
498 | # x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
499 | # np.linspace(-1, 1, height))
500 | # ones = np.ones(np.prod(x_t.shape))
501 | # grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
502 |
503 | #z_t, y_t, x_t = tf.meshgrid(tf.linspace(0., out_size[0]-1., out_size[0]),
504 | # tf.linspace(0., out_size[1]-1., out_size[1]),
505 | # tf.linspace(0., out_size[2]-1., out_size[2]), indexing='ij')
506 |
507 | z_t, y_t, x_t = tf.meshgrid(tf.linspace(-1., 1., out_size[0]),
508 | tf.linspace(-1., 1., out_size[1]),
509 | tf.linspace(-1., 1., out_size[2]), indexing='ij')
510 |
511 | x_t_flat = tf.reshape(x_t, (1, -1))
512 | y_t_flat = tf.reshape(y_t, (1, -1))
513 | z_t_flat = tf.reshape(z_t, (1, -1))
514 |
515 | ones = tf.ones_like(x_t_flat)
516 | grid = tf.concat([x_t_flat, y_t_flat, z_t_flat, ones], 0)
517 | grid = tf.reshape(grid, [-1])
518 | return grid
519 |
520 | def _meshgrid(out_size):
521 | """
522 | the regular grid of coordinates to sample the values after the transformation
523 |
524 | """
525 | #with tf.variable_scope('meshgrid'):
526 |
527 | # This should be equivalent to:
528 | # x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
529 | # np.linspace(-1, 1, height))
530 | # ones = np.ones(np.prod(x_t.shape))
531 | # grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
532 |
533 | x_t, y_t = tf.meshgrid(tf.linspace(-1.0, 1.0, out_size[1]),
534 | tf.linspace(-1.0, 1.0, out_size[0]))
535 | x_t_flat = tf.reshape(x_t, (1, -1))
536 | y_t_flat = tf.reshape(y_t, (1, -1))
537 |
538 |
539 | grid = tf.concat([x_t_flat, y_t_flat, tf.ones_like(x_t_flat)], 0)
540 | grid = tf.reshape(grid, [-1])
541 |
542 | return grid
543 |
544 |
545 | def _repeat(x, n_repeats):
546 | #with tf.variable_scope('_repeat'):
547 | rep = tf.tile(tf.expand_dims(x,1), [1, n_repeats])
548 | return tf.reshape(rep, [-1])
549 |
550 | def _interpolate(im, x, y, out_size, method):
551 | if method=='bilinear':
552 | return bilinear_interp(im, x, y, out_size)
553 | if method=='bicubic':
554 | return bicubic_interp(im, x, y, out_size)
555 | if method=='nearest':
556 | return Nearest_interp(im, x, y, out_size)
557 |     raise ValueError('Unknown interpolation method: %s' % method)
558 |
559 | def _interpolate3d(vol, x, y, z, out_size, method='bilinear'):
560 |     return bilinear_interp3d(vol, x, y, z, out_size)  # only bilinear is implemented for volumes
561 |
562 | def bilinear_interp3d(vol, x, y, z, out_size, edge_size=1):
563 | #with tf.variable_scope('bilinear_interp3d'):
564 | batch_size, depth, height, width, channels = vol.get_shape().as_list()
565 |
566 | if edge_size>0:
567 | vol = tf.pad(vol, [[0,0], [edge_size,edge_size], [edge_size,edge_size], [edge_size,edge_size], [0,0]], mode='CONSTANT')
568 |
569 | x = tf.cast(x, tf.float32)
570 | y = tf.cast(y, tf.float32)
571 | z = tf.cast(z, tf.float32)
572 |
573 | depth_f = tf.cast(depth, tf.float32)
574 | height_f = tf.cast(height, tf.float32)
575 | width_f = tf.cast(width, tf.float32)
576 |
577 | out_depth = out_size[0]
578 | out_height = out_size[1]
579 | out_width = out_size[2]
580 |
581 | # scale indices to [0, width/height/depth - 1]
582 | x = (x + 1.) / 2. * (width_f -1.)
583 | y = (y + 1.) / 2. * (height_f -1.)
584 | z = (z + 1.) / 2. * (depth_f -1.)
585 |
586 | # clip to to [0, width/height/depth - 1] +- edge_size
587 | x = tf.clip_by_value(x, -edge_size, width_f -1. + edge_size)
588 | y = tf.clip_by_value(y, -edge_size, height_f -1. + edge_size)
589 | z = tf.clip_by_value(z, -edge_size, depth_f -1. + edge_size)
590 |
591 | x += edge_size
592 | y += edge_size
593 | z += edge_size
594 |
595 | # do sampling
596 | x0_f = tf.floor(x)
597 | y0_f = tf.floor(y)
598 | z0_f = tf.floor(z)
599 | x1_f = x0_f + 1
600 | y1_f = y0_f + 1
601 | z1_f = z0_f + 1
602 |
603 | x0 = tf.cast(x0_f, tf.int32)
604 | y0 = tf.cast(y0_f, tf.int32)
605 | z0 = tf.cast(z0_f, tf.int32)
606 |
607 | x1 = tf.cast(tf.minimum(x1_f, width_f - 1. + 2*edge_size), tf.int32)
608 | y1 = tf.cast(tf.minimum(y1_f, height_f - 1. + 2*edge_size), tf.int32)
609 | z1 = tf.cast(tf.minimum(z1_f, depth_f - 1. + 2*edge_size), tf.int32)
610 |
611 | dim3 = (width + 2*edge_size)
612 | dim2 = (width + 2*edge_size)*(height + 2*edge_size)
613 | dim1 = (width + 2*edge_size)*(height + 2*edge_size)*(depth + 2*edge_size)
614 |
615 | base = _repeat(tf.range(batch_size)*dim1, out_depth*out_height*out_width)
616 | base_z0 = base + z0*dim2
617 | base_z1 = base + z1*dim2
618 |
619 | base_y00 = base_z0 + y0*dim3
620 | base_y01 = base_z0 + y1*dim3
621 | base_y10 = base_z1 + y0*dim3
622 | base_y11 = base_z1 + y1*dim3
623 |
624 | idx_000 = base_y00 + x0
625 | idx_001 = base_y00 + x1
626 | idx_010 = base_y01 + x0
627 | idx_011 = base_y01 + x1
628 | idx_100 = base_y10 + x0
629 | idx_101 = base_y10 + x1
630 | idx_110 = base_y11 + x0
631 | idx_111 = base_y11 + x1
632 |
633 | # use indices to lookup pixels in the flat image and restore
634 | # channels dim
635 | vol_flat = tf.reshape(vol, [-1, channels])
636 | I000 = tf.gather(vol_flat, idx_000)
637 | I001 = tf.gather(vol_flat, idx_001)
638 | I010 = tf.gather(vol_flat, idx_010)
639 | I011 = tf.gather(vol_flat, idx_011)
640 | I100 = tf.gather(vol_flat, idx_100)
641 | I101 = tf.gather(vol_flat, idx_101)
642 | I110 = tf.gather(vol_flat, idx_110)
643 | I111 = tf.gather(vol_flat, idx_111)
644 |
645 | # and finally calculate interpolated values
646 | w000 = tf.expand_dims((z1_f-z)*(y1_f-y)*(x1_f-x),1)
647 | w001 = tf.expand_dims((z1_f-z)*(y1_f-y)*(x-x0_f),1)
648 | w010 = tf.expand_dims((z1_f-z)*(y-y0_f)*(x1_f-x),1)
649 | w011 = tf.expand_dims((z1_f-z)*(y-y0_f)*(x-x0_f),1)
650 | w100 = tf.expand_dims((z-z0_f)*(y1_f-y)*(x1_f-x),1)
651 | w101 = tf.expand_dims((z-z0_f)*(y1_f-y)*(x-x0_f),1)
652 | w110 = tf.expand_dims((z-z0_f)*(y-y0_f)*(x1_f-x),1)
653 | w111 = tf.expand_dims((z-z0_f)*(y-y0_f)*(x-x0_f),1)
654 |
655 | output = tf.add_n([
656 | w000*I000,
657 | w001*I001,
658 | w010*I010,
659 | w011*I011,
660 | w100*I100,
661 | w101*I101,
662 | w110*I110,
663 | w111*I111])
664 | return output
665 |
666 |
667 | def bilinear_interp(im, x, y, out_size):
668 | #with tf.variable_scope('bilinear_interp'):
669 | batch_size, height, width, channels = im.get_shape().as_list()
670 |
671 | x = tf.cast(x, tf.float32)
672 | y = tf.cast(y, tf.float32)
673 | height_f = tf.cast(height, tf.float32)
674 | width_f = tf.cast(width, tf.float32)
675 | out_height = out_size[0]
676 | out_width = out_size[1]
677 |
678 | # scale indices from [-1, 1] to [0, width/height - 1]
679 | x = tf.clip_by_value(x, -1, 1)
680 | y = tf.clip_by_value(y, -1, 1)
681 | x = (x + 1.0) / 2.0 * (width_f-1.0)
682 | y = (y + 1.0) / 2.0 * (height_f-1.0)
683 |
684 | # do sampling
685 | x0_f = tf.floor(x)
686 | y0_f = tf.floor(y)
687 | x1_f = x0_f + 1
688 | y1_f = y0_f + 1
689 |
690 | x0 = tf.cast(x0_f, tf.int32)
691 | y0 = tf.cast(y0_f, tf.int32)
692 | x1 = tf.cast(tf.minimum(x1_f, width_f - 1), tf.int32)
693 | y1 = tf.cast(tf.minimum(y1_f, height_f - 1), tf.int32)
694 |
695 | dim2 = width
696 | dim1 = width*height
697 |
698 | base = _repeat(tf.range(batch_size)*dim1, out_height*out_width)
699 |
700 | base_y0 = base + y0*dim2
701 | base_y1 = base + y1*dim2
702 |
703 | idx_00 = base_y0 + x0
704 | idx_01 = base_y0 + x1
705 | idx_10 = base_y1 + x0
706 | idx_11 = base_y1 + x1
707 |
708 | # use indices to lookup pixels in the flat image and restore
709 | # channels dim
710 | im_flat = tf.reshape(im, [-1, channels])
711 |
712 | I00 = tf.gather(im_flat, idx_00)
713 | I01 = tf.gather(im_flat, idx_01)
714 | I10 = tf.gather(im_flat, idx_10)
715 | I11 = tf.gather(im_flat, idx_11)
716 |
717 | # and finally calculate interpolated values
718 | w00 = tf.expand_dims(((x1_f-x) * (y1_f-y)), 1)
719 | w01 = tf.expand_dims(((x-x0_f) * (y1_f-y)), 1)
720 | w10 = tf.expand_dims(((x1_f-x) * (y-y0_f)), 1)
721 | w11 = tf.expand_dims(((x-x0_f) * (y-y0_f)), 1)
722 |
723 | output = tf.add_n([w00*I00, w01*I01, w10*I10, w11*I11])
724 | return output
725 |
726 |
727 | def bicubic_interp(im, x, y, out_size):
728 | alpha = -0.75 # same as in tf.image.resize_images, see:
729 | # tensorflow/tensorflow/core/kernels/resize_bicubic_op.cc
730 | bicubic_coeffs = (
731 | (1, 0, -(alpha+3), (alpha+2)),
732 | (0, alpha, -2*alpha, alpha ),
733 | (0, -alpha, 2*alpha+3, -alpha-2 ),
734 | (0, 0, alpha, -alpha )
735 | )
736 |
737 | #with tf.variable_scope('bicubic_interp'):
738 | batch_size, height, width, channels = im.get_shape().as_list()
739 |
740 | x = tf.cast(x, tf.float32)
741 | y = tf.cast(y, tf.float32)
742 | height_f = tf.cast(height, tf.float32)
743 | width_f = tf.cast(width, tf.float32)
744 | out_height = out_size[0]
745 | out_width = out_size[1]
746 |
747 | # scale indices from [-1, 1] to [0, width/height - 1]
748 | x = tf.clip_by_value(x, -1, 1)
749 | y = tf.clip_by_value(y, -1, 1)
750 | x = (x + 1.0) / 2.0 * (width_f-1.0)
751 | y = (y + 1.0) / 2.0 * (height_f-1.0)
752 |
753 | # do sampling
754 | # integer coordinates of 4x4 neighbourhood around (x0_f, y0_f)
755 | x0_f = tf.floor(x)
756 | y0_f = tf.floor(y)
757 | xm1_f = x0_f - 1
758 | ym1_f = y0_f - 1
759 | xp1_f = x0_f + 1
760 | yp1_f = y0_f + 1
761 | xp2_f = x0_f + 2
762 | yp2_f = y0_f + 2
763 |
764 | # clipped integer coordinates
765 | xs = [0, 0, 0, 0]
766 | ys = [0, 0, 0, 0]
767 | xs[0] = tf.cast(x0_f, tf.int32)
768 | ys[0] = tf.cast(y0_f, tf.int32)
769 | xs[1] = tf.cast(tf.maximum(xm1_f, 0), tf.int32)
770 | ys[1] = tf.cast(tf.maximum(ym1_f, 0), tf.int32)
771 | xs[2] = tf.cast(tf.minimum(xp1_f, width_f - 1), tf.int32)
772 | ys[2] = tf.cast(tf.minimum(yp1_f, height_f - 1), tf.int32)
773 | xs[3] = tf.cast(tf.minimum(xp2_f, width_f - 1), tf.int32)
774 | ys[3] = tf.cast(tf.minimum(yp2_f, height_f - 1), tf.int32)
775 |
776 | # indices of neighbours for the batch
777 | dim2 = width
778 | dim1 = width*height
779 | base = _repeat(tf.range(batch_size)*dim1, out_height*out_width)
780 |
781 | idx = []
782 | for i in range(4):
783 | idx.append([])
784 | for j in range(4):
785 | cur_idx = base + ys[i]*dim2 + xs[j]
786 | idx[i].append(cur_idx)
787 |
788 | # use indices to lookup pixels in the flat image and restore
789 | # channels dim
790 | im_flat = tf.reshape(im, [-1, channels])
791 |
792 | Is = []
793 | for i in range(4):
794 | Is.append([])
795 | for j in range(4):
796 | Is[i].append(tf.gather(im_flat, idx[i][j]))
797 |
798 | def get_weights(x, x0_f):
799 | tx = (x-x0_f)
800 | tx2 = tx * tx
801 | tx3 = tx2 * tx
802 | t = [1, tx, tx2, tx3]
803 | weights = []
804 | for i in range(4):
805 | result = 0
806 | for j in range(4):
807 | result = result + bicubic_coeffs[i][j]*t[j]
808 | result = tf.reshape(result, [-1, 1])
809 | weights.append(result)
810 | return weights
811 |
812 |
813 | # to calculate interpolated values first,
814 | # interpolate in x dim 4 times for y=[0, -1, 1, 2]
815 | weights = get_weights(x, x0_f)
816 | x_interp = []
817 | for i in range(4):
818 | result = []
819 | for j in range(4):
820 | result = result + [weights[j]*Is[i][j]]
821 | x_interp.append(tf.add_n(result))
822 |
823 | # finally, interpolate in y dim using interpolations in x dim
824 | weights = get_weights(y, y0_f)
825 | y_interp = []
826 | for i in range(4):
827 | y_interp = y_interp + [weights[i]*x_interp[i]]
828 |
829 | output = tf.add_n(y_interp)
830 | return output
831 |
832 |
833 | def Nearest_interp(im, x, y, out_size):
834 |     #with tf.variable_scope('nearest_interp'):
835 | batch_size, height, width, channels = im.get_shape().as_list()
836 |
837 | x = tf.cast(x, tf.float32)
838 | y = tf.cast(y, tf.float32)
839 | height_f = tf.cast(height, tf.float32)
840 | width_f = tf.cast(width, tf.float32)
841 | out_height = out_size[0]
842 | out_width = out_size[1]
843 |
844 | # scale indices from [-1, 1] to [0, width/height - 1]
845 | x = tf.clip_by_value(x, -1, 1)
846 | y = tf.clip_by_value(y, -1, 1)
847 | x = (x + 1.0) / 2.0 * (width_f-1.0)
848 | y = (y + 1.0) / 2.0 * (height_f-1.0)
849 |
850 | # do sampling
851 |     x0_f = tf.round(x)  # round to the nearest pixel; floor would bias toward the top-left neighbour
852 |     y0_f = tf.round(y)
853 |
854 |
855 | x0 = tf.cast(x0_f, tf.int32)
856 | y0 = tf.cast(y0_f, tf.int32)
857 |
858 | dim2 = width
859 | dim1 = width*height
860 |
861 | base = _repeat(tf.range(batch_size)*dim1, out_height*out_width)
862 |
863 | base_y0 = base + y0*dim2
864 |
865 | idx = base_y0 + x0
866 |
867 | # use indices to lookup pixels in the flat image and restore
868 | # channels dim
869 | im_flat = tf.reshape(im, [-1, channels])
870 |
871 | I00 = tf.gather(im_flat, idx)
872 |
873 | output = I00
874 | return output
875 |
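
# ----------------------------------------------------------------------
# Illustrative self-test (not part of the original code): warp a random
# image with the identity transform and check that the samplers return
# (approximately) the input. The shapes below are arbitrary assumptions.
# ----------------------------------------------------------------------
if __name__ == '__main__':
    img = tf.random.uniform((1, 8, 8, 3))

    # Affine identity: [[1, 0, 0], [0, 1, 0]] flattened to 6 parameters.
    affine = AffineTransformer((8, 8))
    theta_a = tf.constant([[1., 0., 0., 0., 1., 0.]])
    out_a = affine.transform(img, theta_a)

    # Projective identity: the first 8 entries of the flattened 3x3
    # identity matrix; the ninth entry is fixed to 1 inside
    # ProjectiveTransformer._transform.
    proj = ProjectiveTransformer((8, 8), interp_method='nearest')
    theta_p = tf.constant([[1., 0., 0., 0., 1., 0., 0., 0.]])
    out_p = proj.transform(img, theta_p)

    print('affine identity error:', float(tf.reduce_max(tf.abs(out_a - img))))      # ~0
    print('projective identity error:', float(tf.reduce_max(tf.abs(out_p - img))))  # ~0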
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from xml.etree import ElementTree
3 | import cv2
4 | import tensorflow as tf
5 | import tensorflow_addons as tfa
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from tensorflow.keras.callbacks import ModelCheckpoint
9 | from sklearn.model_selection import train_test_split
10 | from spatial_transformer import ProjectiveTransformer, AffineTransformer
11 | #from tensorflow.keras.applications.xception import preprocess_input, Xception
12 | from tensorflow.keras.applications.convnext import ConvNeXtTiny, preprocess_input
13 | from tensorflow.keras.utils import to_categorical
14 | from tensorflow.keras.models import Model, load_model, Sequential
15 | from tensorflow.keras.layers import *
16 | from scipy.interpolate import interp1d
17 | import random
18 | import scipy.io
19 | import math
20 | from scipy import ndimage
21 | from sklearn.utils import shuffle
22 |
23 |
24 | val_path = 'E:/LSUN2016_surface_relabel/surface_relabel/val/'
25 |
26 | img = cv2.imread(val_path+'sun_atssgbmizunolhzn.jpg')
27 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
28 | img = cv2.resize(img, (400 , 400))
29 | img1 = np.array(img,copy=True)
30 | img = img[tf.newaxis,...]
31 | img = preprocess_input(img)
32 |
33 | seg = cv2.imread(val_path+'sun_atssgbmizunolhzn.png',0)
34 | seg = cv2.resize(seg, (400 ,400), interpolation= cv2.INTER_NEAREST)
35 | seg = seg/51.0
36 |
37 | ref_img = tf.io.read_file('ref_img2.png')
38 | ref_img = tf.io.decode_png(ref_img)
39 | ref_img = tf.cast(ref_img, tf.float32) / 51.0  # png stores plane ids as multiples of 51
40 | ref_img = ref_img[tf.newaxis,...]
41 | #ref_img = tf.tile(ref_img, [1,1,1,1])
42 | print(ref_img.shape)
43 |
44 |
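# ST-RoomNet inference (see README/paper): the ConvNeXt backbone regresses
# the 8 parameters of a projective transform, and that transform warps the
# fixed reference layout ref_img2.png into the predicted room-layout
# segmentation.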
45 | base_model = ConvNeXtTiny(include_top=False, weights="imagenet", input_shape= (400,400,3), pooling = 'avg')
46 | theta = Dense(8)(base_model.output)
47 | stl = ProjectiveTransformer((400,400)).transform(ref_img, theta)
48 | model = Model(base_model.input, stl)
49 |
50 |
51 | model.summary()
52 | model.load_weights('')  # set this to the path of the downloaded weight file (see weight file.txt)
53 |
54 |
55 | out = model.predict(img)
56 |
57 | out = np.rint(out[0,:,:,0])
58 |
59 |
60 | plt.figure('seg')
61 | plt.imshow(out, vmin = 1, vmax= 5)
62 | plt.figure('gt')
63 | plt.imshow(seg , vmin = 1, vmax= 5)
64 | plt.figure('img')
65 | plt.imshow(img1)
66 | plt.show()
67 |
68 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from xml.etree import ElementTree
3 | import cv2
4 | import tensorflow as tf
5 | import tensorflow_addons as tfa
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from tensorflow.keras.callbacks import ModelCheckpoint
9 | from sklearn.model_selection import train_test_split
10 | from spatial_transformer import ProjectiveTransformer, AffineTransformer
11 | #from tensorflow.keras.applications.xception import preprocess_input, Xception
12 | from tensorflow.keras.applications.convnext import ConvNeXtTiny, preprocess_input
13 | from tensorflow.keras.utils import to_categorical
14 | from tensorflow.keras.models import Model, load_model, Sequential
15 | from tensorflow.keras.layers import *
16 | import random
17 | import scipy.io
18 | from sklearn.utils import shuffle
19 |
20 | train_path = 'E:/LSUN2016_surface_relabel/surface_relabel/train/'
21 | val_path = 'E:/LSUN2016_surface_relabel/surface_relabel/val/'
22 |
23 | train_data = scipy.io.loadmat('training.mat')
24 | train_data = train_data.get('training')[0]
25 | val_data = scipy.io.loadmat('validation.mat')
26 | val_data = val_data.get('validation')[0]
27 |
28 | print(train_data.shape)
29 |
30 | train_data = shuffle(train_data)
31 | batch_size = 48
32 |
33 |
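# Generator yielding (preprocessed image, layout segmentation) batches.
# The segmentation pngs store surface ids as multiples of 51, so dividing
# by 51.0 recovers the class ids. For room-type labels in [3, 4, 5, 10],
# label 2 is merged into label 1, and horizontal flips also swap the
# left/right wall labels (2 <-> 3) so the target stays consistent with
# the flipped image.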
34 | def data_generator(data, path = train_path, batch_size=32, number_of_batches=None):
35 | counter = 0
36 | n_classes = 11
37 |
38 | #training parameters
39 | train_w , train_h = 400, 400
40 | while True:
41 | idx_start = batch_size * counter
42 | idx_end = batch_size * (counter + 1)
43 | x_batch = []
44 | y_seg_batch = []
45 | y_batch = []
46 | for file in data[idx_start:idx_end]:
47 | img_name = list(file)[0][0]
48 | img = cv2.imread(path+img_name+'.jpg')
49 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
50 | img = cv2.resize(img, (train_w , train_h))
51 | seg = cv2.imread(path+img_name+'.png',0)
52 | seg = cv2.resize(seg, (train_w , train_h), interpolation= cv2.INTER_NEAREST)
53 | seg = seg / 51.0
54 | label = list(file)[2][0][0]
55 | #print(label)
56 | if label in [3,4,5,10]:
57 | seg[seg==2] = 1
58 | rand_hflip = random.randint(0, 1)
59 | rand_brightness = random.randint(0, 1)
60 |
61 | if (rand_hflip == 1 and path == train_path):
62 | img = cv2.flip(img, 1)
63 | seg = cv2.flip(seg, 1)
64 | seg_temp = np.array(seg, copy= True)
65 | #flip the right and left walls
66 | seg[seg_temp==2] = 3
67 | seg[seg_temp==3] = 2
68 | if (rand_brightness == 1 and path == train_path):
69 | val = random.uniform(-1, 1) * 30
70 | img = img + val
71 | img = np.clip(img, 0, 255)
72 | img = np.uint8(img)
73 |
74 | img = preprocess_input(img)
75 | x_batch.append(img)
76 | y_seg_batch.append(seg)
77 | counter += 1
78 | x_train = np.array(x_batch)
79 | y_seg_train = np.array(y_seg_batch)
80 | yield x_train, y_seg_train
81 | if (counter == number_of_batches):
82 | counter = 0
83 |
84 | ref_img = tf.io.read_file('ref_img2.png')
85 | ref_img = tf.io.decode_png(ref_img)
86 | ref_img = tf.cast(ref_img, tf.float32) / 51.0
87 | ref_img = ref_img[tf.newaxis,...]
88 | ref_img = tf.tile(ref_img, [batch_size,1,1,1])
89 | print(ref_img.shape)
90 |
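# Initialize the theta head at the identity homography: a zero kernel and
# bias [1, 0, 0, 0, 1, 0, 0, 0] (the ninth entry of the 3x3 matrix is
# fixed to 1 inside ProjectiveTransformer), so training starts from an
# unwarped reference image. 768 is ConvNeXt-Tiny's pooled feature size.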
91 | w = np.zeros((768, 8), dtype='float32')
92 |
93 | b = np.zeros(8, dtype='float32')
94 | b[0] = 1
95 | b[4] = 1
96 |
97 | base_model = ConvNeXtTiny(include_top=False, weights="imagenet", input_shape= (400,400,3), pooling = 'avg')
98 | theta = Dense(8, weights=[w, b])(base_model.output)
99 | stl = ProjectiveTransformer((400,400)).transform(ref_img, theta)
100 | model = Model(base_model.input, stl)
101 |
102 | model.summary()
103 | model.compile(optimizer = tfa.optimizers.AdamW(learning_rate=0.001, weight_decay = 0.0001), loss = ['huber_loss'], metrics = ['accuracy'])
104 |
106 |
107 | filepath="E:/LSUN2016_surface_relabel/surface_relabel/weights_stn/weights-improvement-{epoch:02d}-{loss:.4f}-{val_loss:.4f}.h5"
108 | checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=True, mode='auto', period=1)
109 | callbacks_list = [checkpoint]
110 | n_train = 4000
111 | n_valid= 394
112 | model.fit(data_generator(train_data, train_path, batch_size, number_of_batches= n_train // batch_size),
113 | steps_per_epoch=max(1, n_train//batch_size), initial_epoch = 0,
114 | validation_data= data_generator(val_data, val_path, batch_size, number_of_batches= n_valid // batch_size),
115 | validation_steps=max(1, n_valid//batch_size),
116 | epochs=150,
117 | callbacks=callbacks_list)
118 |
--------------------------------------------------------------------------------
/weight file.txt:
--------------------------------------------------------------------------------
1 | Download from the following link: https://drive.google.com/file/d/1KWHofYVsjb0bRi6sVuZtPpanZWCGyWSr/view?usp=sharing
2 |
--------------------------------------------------------------------------------