├── README.md
├── examples
    └── foj_demo.ipynb
├── field_of_junctions.py
└── utils.py


/README.md:
--------------------------------------------------------------------------------
# Field of Junctions

### [Project Page](http://vision.seas.harvard.edu/foj/) | [Paper](https://arxiv.org/abs/2011.13866) | [Video](https://youtu.be/M0VwBw_aVQA)

This repository contains code for:

**[Field of Junctions: Extracting Boundary Structure at Low SNR](http://vision.seas.harvard.edu/foj/)**

[Dor Verbin](https://scholar.harvard.edu/dorverbin) and [Todd Zickler](http://www.eecs.harvard.edu/~zickler/)

International Conference on Computer Vision (ICCV), 2021.

Please contact us by email for questions about our paper or code.

## Requirements

Our code is implemented in PyTorch. It has been tested with PyTorch 1.6, but it should work with other PyTorch 1.x versions. The following packages are required:

- python 3.x
- pytorch 1.x
- numpy >= 1.14.0

## Usage

To compute the field of junctions of an `HxWxC` image `img` (a numpy array of shape `[H, W, C]`), run the following code snippet:
```
from field_of_junctions import FieldOfJunctions
foj = FieldOfJunctions(img, opts)
foj.optimize()
```

In addition to the input image, the `FieldOfJunctions` class requires an object `opts` with the following fields:
```
R                          Patch size
stride                     Stride of field of junctions (e.g. opts.stride == 1 is dense)
eta                        Width of Heaviside functions
delta                      Width of boundary maps
lr_angles                  Learning rate of angles
lr_x0y0                    Learning rate of vertex positions
lambda_boundary_final      Final value of spatial boundary consistency weight lambda_B
lambda_color_final         Final value of spatial color consistency weight lambda_C
nvals                      Number of values to query in Algorithm 2 from the paper
num_initialization_iters   Number of initialization iterations
num_refinement_iters       Number of refinement iterations
greedy_step_every_iters    Frequency of "greedy" iteration (applying Algorithm 2 with consistency)
parallel_mode              Whether or not to run Algorithm 2 in parallel over all `nvals` values.
```

Note that setting `parallel_mode` to `True` typically results in faster optimization, but requires more memory during
initialization. For large images on a GPU with limited memory, you might need to set `parallel_mode` to `False`.
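
As a minimal sketch, `opts` can be any object that exposes the fields above as attributes, for example a `types.SimpleNamespace`. The numeric values below are illustrative placeholders rather than recommended settings, and `check_patch_grid` is a hypothetical helper (not part of this repository) that mirrors the constructor's requirement that `(H - R)` and `(W - R)` be divisible by `stride`:

```python
from types import SimpleNamespace

# Placeholder values for illustration only (assumptions); tune them for your images.
opts = SimpleNamespace(
    R=21,
    stride=1,
    eta=0.01,
    delta=0.05,
    lr_angles=0.003,
    lr_x0y0=0.03,
    lambda_boundary_final=0.5,
    lambda_color_final=0.1,
    nvals=31,
    num_initialization_iters=30,
    num_refinement_iters=1000,
    greedy_step_every_iters=50,
    parallel_mode=True,
)


def check_patch_grid(H, W, R, stride):
    """Hypothetical helper: return (H', W'), the number of patches along each
    image dimension, or raise if the image size does not fit the patch grid."""
    if (H - R) % stride != 0 or (W - R) % stride != 0:
        raise ValueError("(H - R) and (W - R) must be divisible by stride")
    return (H - R) // stride + 1, (W - R) // stride + 1


# A 100x120 image with R=21 and stride=1 gives an 80x100 grid of patches.
print(check_patch_grid(100, 120, opts.R, opts.stride))  # (80, 100)
```

Running a check like this before constructing `FieldOfJunctions` makes it easier to see whether an image needs to be cropped or padded for a given `R` and `stride`.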

Instead of calling `foj.optimize()`, which executes the entire optimization scheme, it is possible to access the field of junctions
during optimization by using the following equivalent code snippet:
```
foj = FieldOfJunctions(img, opts)
for i in range(foj.num_iters):
    foj.step(i)
```

See the Python notebook in the `examples/` folder for a full usage example.

### Boundary maps

To compute the (global) boundary maps for a given field of junctions object `foj`, use:

```
params = torch.cat([foj.angles, foj.x0y0], dim=1)
dists, _, patches = foj.get_dists_and_patches(params)
local_boundaries = foj.dists2boundaries(dists)
global_boundaries = foj.local2global(local_boundaries)[0, 0, :, :].detach().cpu().numpy()
```

### Boundary-aware smoothing

To compute the boundary-aware smoothing of the input image given `foj`, use:
```
params = torch.cat([foj.angles, foj.x0y0], dim=1)
dists, _, patches = foj.get_dists_and_patches(params)
smoothed_img = foj.local2global(patches)[0, :, :, :].permute(1, 2, 0).detach().cpu().numpy()
```

## Data

A zip file containing all of our synthetic data is available [here](https://vision.seas.harvard.edu/foj/dataset/foj_data.zip). It contains the 300 images we used for quantitatively evaluating our algorithm, as well as ground truth locations of edges and corners/junctions.

## Citation

For citing our paper, please use:
```
@InProceedings{verbin2021foj,
    author = {Verbin, Dor and Zickler, Todd},
    title = {Field of Junctions: Extracting Boundary Structure at Low {SNR}},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month = {October},
    year = {2021},
    pages = {6869-6878}
}
```

--------------------------------------------------------------------------------
/field_of_junctions.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

if torch.cuda.is_available():
    dev = torch.device('cuda')
else:
    dev = torch.device('cpu')

class FieldOfJunctions:
    def __init__(self, img, opts):
        """
        Inputs
        ------
        img    Input image: a numpy array of shape [H, W, C]
        opts   Object with the following attributes:
               R                          Patch size
               stride                     Stride for junctions (e.g. opts.stride == 1 is a dense field of junctions)
               eta                        Width parameter for Heaviside functions
               delta                      Width parameter for boundary maps
               lr_angles                  Angle learning rate
               lr_x0y0                    Vertex position learning rate
               lambda_boundary_final      Final value of spatial boundary consistency term
               lambda_color_final         Final value of spatial color consistency term
               nvals                      Number of values to query in Algorithm 2 from the paper
               num_initialization_iters   Number of initialization iterations
               num_refinement_iters       Number of refinement iterations
               greedy_step_every_iters    Frequency of "greedy" iteration (applying Algorithm 2 with consistency)
               parallel_mode              Whether or not to run Algorithm 2 in parallel over all `nvals` values.
        """

        # Get image dimensions
        self.H, self.W, self.C = img.shape

        # Make sure number of patches in both dimensions is an integer
        assert (self.H - opts.R) % opts.stride == 0 and (self.W - opts.R) % opts.stride == 0, \
            "Number of patches must be an integer."

        # Number of patches (throughout the documentation H_patches and W_patches are denoted by H' and W' resp.)
        self.H_patches = (self.H - opts.R) // opts.stride + 1
        self.W_patches = (self.W - opts.R) // opts.stride + 1

        # Store total number of iterations (initialization + refinement)
        self.num_iters = opts.num_initialization_iters + opts.num_refinement_iters

        # Split image into overlapping patches, creating a tensor of shape [N, C, R, R, H', W']
        t_img = torch.tensor(img, device=dev).permute(2, 0, 1).unsqueeze(0)  # input image, shape [1, C, H, W]
        self.img_patches = nn.Unfold(opts.R, stride=opts.stride)(t_img).view(1, self.C, opts.R, opts.R,
                                                                             self.H_patches, self.W_patches)

        # Create pytorch variables for angles and vertex position for each patch
        self.angles = torch.zeros(1, 3, self.H_patches, self.W_patches, dtype=torch.float32, device=dev)
        self.x0y0 = torch.zeros(1, 2, self.H_patches, self.W_patches, dtype=torch.float32, device=dev)

        # Compute gradients for angles and vertex positions
        self.angles.requires_grad = True
        self.x0y0.requires_grad = True

        # Compute number of patches containing each pixel: has shape [H, W]
        self.num_patches = torch.nn.Fold(output_size=[self.H, self.W],
                                         kernel_size=opts.R,
                                         stride=opts.stride)(torch.ones(1, opts.R**2,
                                                                        self.H_patches * self.W_patches,
                                                                        device=dev)).view(self.H, self.W)

        # Create local grid within each patch
        y, x = torch.meshgrid([torch.linspace(-1.0, 1.0, opts.R, device=dev),
                               torch.linspace(-1.0, 1.0, opts.R, device=dev)])
        self.x = x.view(1, opts.R, opts.R, 1, 1)
        self.y = y.view(1, opts.R,
                          opts.R, 1, 1)

        # Optimization parameters
        adam_beta1 = 0.5
        adam_beta2 = 0.99
        adam_eps = 1e-08

        # Create optimizers for angles and vertices
        optimizer_angles = optim.Adam([self.angles],
                                      opts.lr_angles, [adam_beta1, adam_beta2], eps=adam_eps)
        optimizer_x0y0 = optim.Adam([self.x0y0],
                                    opts.lr_x0y0, [adam_beta1, adam_beta2], eps=adam_eps)
        self.optimizers = [optimizer_angles, optimizer_x0y0]

        # Values to search over in Algorithm 2: [0, 2pi) for angles, [-3, 3] for vertex position.
        self.angle_range = torch.linspace(0.0, 2*np.pi, opts.nvals+1, device=dev)[:opts.nvals]
        self.x0y0_range = torch.linspace(-3.0, 3.0, opts.nvals, device=dev)

        # Save current global image and boundary map (initially None)
        self.global_image = None
        self.global_boundaries = None

        # Save opts
        self.opts = opts

    def optimize(self):
        """
        Optimize field of junctions.
        """
        for iteration in range(self.num_iters):
            self.step(iteration)

    def step(self, iteration):
        """
        Perform one step (either initialization's coordinate descent, or refinement gradient descent)
        Inputs
        ------
        iteration   Iteration number (integer)
        """

        # Linearly increase lambda from 0 to lambda_boundary_final and lambda_color_final
        if self.opts.num_refinement_iters <= 1:
            factor = 0.0
        else:
            factor = max([0, (iteration - self.opts.num_initialization_iters) / (self.opts.num_refinement_iters - 1)])
        lmbda_boundary = factor * self.opts.lambda_boundary_final
        lmbda_color = factor * self.opts.lambda_color_final

        if iteration < self.opts.num_initialization_iters or \
                (iteration - self.opts.num_initialization_iters + 1) % self.opts.greedy_step_every_iters == 0:
            self.initialization_step(lmbda_boundary, lmbda_color)
        else:
            self.refinement_step(lmbda_boundary, lmbda_color)

    def initialization_step(self, lmbda_boundary, lmbda_color):
        """
        Perform a single coordinate descent step (using Algorithm 2 from the paper).
        Implements a heuristic for searching along the three junction angles after updating each of
        the five parameters. The original value is included in the search, so the extra step is
        guaranteed to obtain a better (or equally-good) set of parameters.

        Inputs
        ------
        lmbda_boundary   Spatial consistency boundary loss weight
        lmbda_color      Spatial consistency color loss weight
        """
        params = torch.cat([self.angles, self.x0y0], dim=1).detach()

        # Run one step of Algorithm 2, sequentially improving each coordinate
        for i in range(5):
            # Repeat the set of parameters `nvals` times along 0th dimension
            params_query = params.repeat(self.opts.nvals, 1, 1, 1)
            param_range = self.angle_range if i < 3 else self.x0y0_range
            params_query[:, i, :, :] = params_query[:, i, :, :] + param_range.view(-1, 1, 1)
            best_ind = self.get_best_inds(params_query, lmbda_boundary, lmbda_color)

            # Update parameters
            params[0, i, :, :] = params_query[best_ind.view(1, self.H_patches, self.W_patches),
                                              i,
                                              torch.arange(self.H_patches).view(1, -1, 1),
                                              torch.arange(self.W_patches).view(1, 1, -1)]

        # Heuristic for accelerating convergence (not necessary but sometimes helps):
        # Update x0 and y0 along the three optimal angles (search over a line passing through current x0, y0)
        for i in range(3):
            params_query = params.repeat(self.opts.nvals, 1, 1, 1)
            params_query[:, 3, :, :] = params[:, 3, :, :] + torch.cos(params[:, i, :, :]) * self.x0y0_range.view(-1, 1, 1)
            params_query[:, 4, :, :] = params[:, 4, :, :] + torch.sin(params[:, i, :, :]) * self.x0y0_range.view(-1, 1, 1)
            best_ind = self.get_best_inds(params_query, lmbda_boundary, lmbda_color)

            # Update vertex positions of parameters
            for j in range(3, 5):
                params[:, j, :, :] = params_query[best_ind.view(1, self.H_patches, self.W_patches),
                                                  j,
                                                  torch.arange(self.H_patches).view(1, -1, 1),
                                                  torch.arange(self.W_patches).view(1, 1, -1)]

        # Update angles and vertex position using the best values found
        self.angles.data = params[:, :3, :, :].data
        self.x0y0.data = params[:, 3:, :, :].data

        # Update global boundaries and image
        dists, colors, patches = self.get_dists_and_patches(params, lmbda_color)
        self.global_image = self.local2global(patches)
        self.global_boundaries = self.local2global(self.dists2boundaries(dists))


    def refinement_step(self, lmbda_boundary, lmbda_color):
        """
        Perform a single refinement step

        Inputs
        ------
        lmbda_boundary   Spatial consistency boundary loss weight
        lmbda_color      Spatial consistency color loss weight
        """
        params = torch.cat([self.angles, self.x0y0], dim=1)

        # Compute distance functions, colors, and junction patches
        dists, colors, patches = self.get_dists_and_patches(params, lmbda_color)

        # Compute loss
        loss = self.get_loss(dists, colors, patches, lmbda_boundary, lmbda_color).mean()

        # Take gradient step over angles and vertex positions
        for optimizer in self.optimizers:
            optimizer.zero_grad()
        loss.backward()
        for optimizer in self.optimizers:
            optimizer.step()

        # Update global boundaries and image
        dists, colors, patches = self.get_dists_and_patches(params, lmbda_color)
        self.global_image = self.local2global(patches)
        self.global_boundaries = self.local2global(self.dists2boundaries(dists))


    def get_loss(self, dists, colors, patches, lmbda_boundary, lmbda_color):
        """
        Compute the objective of our model (see Equation 8 of the paper).

        Inputs
        ------
        dists            Tensor of shape [N, 2, R, R, H', W'] with samples of the two distance functions for every patch
        colors           Tensor of shape [N, C, 3, H', W'] storing the C-channel color of each of the 3 wedges at each patch
        patches          Tensor of shape [N, C, R, R, H', W'] with each patch having color c_i^{(j)} at the jth wedge, for each i
        lmbda_boundary   Spatial consistency boundary loss weight
        lmbda_color      Spatial consistency color loss weight

        Outputs
        -------
                         Tensor of shape [N, H', W'] with the loss at each patch
        """
        # Compute negative log-likelihood for each patch (shape [N, H', W'])
        loss_per_patch = ((self.img_patches - patches) ** 2).mean(-3).mean(-3).sum(1)

        # Add spatial consistency loss for each patch, if lambda > 0
        if lmbda_boundary > 0.0:
            loss_per_patch = loss_per_patch + lmbda_boundary * self.get_boundary_consistency_term(dists)

        if lmbda_color > 0.0:
            loss_per_patch = loss_per_patch + lmbda_color * self.get_color_consistency_term(dists, colors)

        return loss_per_patch

    def get_boundary_consistency_term(self, dists):
        """
        Compute the spatial boundary consistency term.

        Inputs
        ------
        dists   Tensor of shape [N, 2, R, R, H', W'] with samples of the two distance functions for every patch

        Outputs
        -------
                Tensor of shape [N, H', W'] with the consistency loss at each patch
        """
        # Split global boundaries into patches
        curr_global_boundaries_patches = nn.Unfold(self.opts.R, stride=self.opts.stride)(
            self.global_boundaries.detach()).view(1, 1, self.opts.R, self.opts.R, self.H_patches, self.W_patches)

        # Get local boundaries defined using the queried parameters (defined by `dists`)
        local_boundaries = self.dists2boundaries(dists)

        # Compute consistency term
        consistency = ((local_boundaries - curr_global_boundaries_patches) ** 2).mean(2).mean(2)

        return consistency[:, 0, :, :]

    def get_color_consistency_term(self, dists, colors):
        """
        Compute the spatial color consistency term.

        Inputs
        ------
        dists    Tensor of shape [N, 2, R, R, H', W'] with samples of the two distance functions for every patch
        colors   Tensor of shape [N, C, 3, H', W'] storing the C-channel color of each of the 3 wedges at each patch

        Outputs
        -------
                 Tensor of shape [N, H', W'] with the consistency loss at each patch
        """
        # Split into patches
        curr_global_image_patches = nn.Unfold(self.opts.R, stride=self.opts.stride)(
            self.global_image.detach()).view(1, self.C, self.opts.R, self.opts.R, self.H_patches, self.W_patches)

        wedges = self.dists2indicators(dists)  # shape [N, 3, R, R, H', W']

        # Compute consistency term
        consistency = (wedges.unsqueeze(1) * (
            colors.unsqueeze(-3).unsqueeze(-3) - curr_global_image_patches.unsqueeze(2)) ** 2).mean(-3).mean(-3).sum(1).sum(1)

        return consistency



    def get_dists_and_patches(self, params, lmbda_color=0.0):
        """
        Compute distance functions and piecewise-constant patches given junction parameters.

        Inputs
        ------
        params        Tensor of shape [N, 5, H', W'] holding N field of junctions parameters. Each
                      5-vector has format (angle1, angle2, angle3, x0, y0).
        lmbda_color   Spatial consistency color loss weight (default 0.0)

        Outputs
        -------
        dists     Tensor of shape [N, 2, R, R, H', W'] with samples of the two distance functions for every patch
        colors    Tensor of shape [N, C, 3, H', W']
        patches   Tensor of shape [N, C, R, R, H', W'] with the constant color function at each of the 3 wedges
        """

        # Get dists
        dists = self.params2dists(params)  # shape [N, 2, R, R, H', W']

        # Get wedge indicator functions
        wedges = self.dists2indicators(dists)  # shape [N, 3, R, R, H', W']

        if lmbda_color >= 0 and self.global_image is not None:
            curr_global_image_patches = nn.Unfold(self.opts.R, stride=self.opts.stride)(
                self.global_image.detach()).view(1, self.C, self.opts.R, self.opts.R, self.H_patches, self.W_patches)

            numerator = ((self.img_patches + lmbda_color *
                          curr_global_image_patches).unsqueeze(2) * wedges.unsqueeze(1)).sum(-3).sum(-3)
            denominator = (1.0 + lmbda_color) * wedges.sum(-3).sum(-3).unsqueeze(1)

            colors = numerator / (denominator + 1e-10)
        else:
            # Get best color for each wedge and each patch
            colors = (self.img_patches.unsqueeze(2) * wedges.unsqueeze(1)).sum(-3).sum(-3) / \
                     (wedges.sum(-3).sum(-3).unsqueeze(1) + 1e-10)

        # Fill wedges with optimal colors
        patches = (wedges.unsqueeze(1) * colors.unsqueeze(-3).unsqueeze(-3)).sum(dim=2)

        return dists, colors, patches

    def dists2boundaries(self, dists):
        """
        Compute boundary map for each patch, given distance functions. The width of the boundary is determined
        by opts.delta.

        Inputs
        ------
        dists   Tensor of shape [N, 2, R, R, H', W'] with samples of the two distance functions for every patch

        Outputs
        -------
                Tensor of shape [N, 1, R, R, H', W'] with values of boundary map for every patch
        """
        # Find places where either distance transform is small, except where d1 > 0 and d2 < 0
        d1 = dists[:, 0:1, :, :, :, :]
        d2 = dists[:, 1:2, :, :, :, :]
        minabsdist = torch.where(d1 < 0.0, -d1, torch.where(d2 < 0.0, torch.min(d1, -d2), torch.min(d1, d2)))

        return 1.0 / (1.0 + (minabsdist / self.opts.delta) ** 2)

    def local2global(self, patches):
        """
        Compute average value for each pixel over all patches containing it.
        For example, this can be used to compute the global boundary maps, or the boundary-aware smoothed image.

        Inputs
        ------
        patches   Tensor of shape [N, C, R, R, H', W']. patches[n, :, :, :, i, j] is an RxR C-channel patch
                  at the (i, j)th spatial position of the nth entry.


        Outputs
        -------
                  Tensor of shape [N, C, H, W] of averages over all patches containing each pixel.
        """
        N = patches.shape[0]
        C = patches.shape[1]
        return torch.nn.Fold(output_size=[self.H, self.W], kernel_size=self.opts.R, stride=self.opts.stride)(
            patches.view(N, C*self.opts.R**2, -1)).view(N, C, self.H, self.W) / \
            self.num_patches.unsqueeze(0).unsqueeze(0)

    def get_best_inds(self, params, lmbda_boundary, lmbda_color):
        """
        Compute the best index along the 0th dimension of `params` for each pixel position.
        Has two possible modes determined by self.opts.parallel_mode:
        1) When True, all N values are computed in parallel (generally faster, requires more memory)
        2) When False, the values are computed sequentially (generally slower, requires less memory)

        Inputs
        ------
        params           Tensor of shape [N, 5, H', W'] holding N field of junctions parameters. Each
                         5-vector has format (angle1, angle2, angle3, x0, y0).
        lmbda_boundary   Spatial consistency boundary loss weight
        lmbda_color      Spatial consistency color loss weight

        Outputs
        -------
                         Tensor of shape [H', W'] with each value in {0, ..., N-1} holding the
                         index of the best junction parameters at that position.
        """
        if self.opts.parallel_mode:
            dists, colors, smooth_patches = self.get_dists_and_patches(params, lmbda_color)
            loss_per_patch = self.get_loss(dists, colors, smooth_patches, lmbda_boundary, lmbda_color)
            best_ind = loss_per_patch.argmin(dim=0)

        else:
            # First initialize tensors
            best_ind = torch.zeros(self.H_patches, self.W_patches, device=dev, dtype=torch.int64)
            best_loss_per_patch = torch.zeros(self.H_patches, self.W_patches, device=dev) + 1e10

            # Now fill tensors by iterating over the junction dimension and choosing the best junction parameters
            for n in range(params.shape[0]):
                dists, colors, smooth_patches = self.get_dists_and_patches(params[n:n+1, :, :, :], lmbda_color)

                loss_per_patch = self.get_loss(dists, colors, smooth_patches, lmbda_boundary, lmbda_color)

                # Index with [0] so that both running tensors keep shape [H', W']
                improved_inds = loss_per_patch[0] < best_loss_per_patch
                best_ind = torch.where(improved_inds, torch.tensor(n, device=dev, dtype=torch.int64), best_ind)
                best_loss_per_patch = torch.where(improved_inds, loss_per_patch[0], best_loss_per_patch)

        return best_ind

    def params2dists(self, params, tau=1e-1):
        """
        Compute distance functions from field of junctions.

        Inputs
        ------
        params   Tensor of shape [N, 5, H', W'] holding N field of junctions parameters. Each
                 5-vector has format (angle1, angle2, angle3, x0, y0).
        tau      Constant used for lifting the level set function to be either entirely positive
                 or entirely negative when an angle approaches 0 or 2pi.


        Outputs
        -------
                 Tensor of shape [N, 2, R, R, H', W'] with samples of the two distance functions for every patch
        """
        x0 = params[:, 3, :, :].unsqueeze(1).unsqueeze(1)  # shape [N, 1, 1, H', W']
        y0 = params[:, 4, :, :].unsqueeze(1).unsqueeze(1)  # shape [N, 1, 1, H', W']

        # Sort so angle1 <= angle2 <= angle3 (mod 2pi)
        angles = torch.remainder(params[:, :3, :, :], 2 * np.pi)
        angles = torch.sort(angles, dim=1)[0]

        angle1 = angles[:, 0, :, :].unsqueeze(1).unsqueeze(1)  # shape [N, 1, 1, H', W']
        angle2 = angles[:, 1, :, :].unsqueeze(1).unsqueeze(1)  # shape [N, 1, 1, H', W']
        angle3 = angles[:, 2, :, :].unsqueeze(1).unsqueeze(1)  # shape [N, 1, 1, H', W']

        # Define another angle halfway between angle3 and angle1, clockwise from angle3
        # This isn't critical but it seems a bit more stable for computing gradients
        angle4 = 0.5 * (angle1 + angle3) + \
                     torch.where(torch.remainder(0.5 * (angle1 - angle3), 2 * np.pi) >= np.pi,
                                 torch.ones_like(angle1) * np.pi, torch.zeros_like(angle1))

        def g(dtheta):
            # Map from [0, 2pi] to [-1, 1]
            return (dtheta / np.pi - 1.0) ** 35

        # Compute the two distance functions
        sgn42 = torch.where(torch.remainder(angle2 - angle4, 2 * np.pi) < np.pi,
                            torch.ones_like(angle2), -torch.ones_like(angle2))
        tau42 = g(torch.remainder(angle2 - angle4, 2*np.pi)) * tau

        dist42 = sgn42 * torch.min( sgn42 * (-torch.sin(angle4) * (self.x - x0) + torch.cos(angle4) * (self.y - y0)),
                                   -sgn42 * (-torch.sin(angle2) * (self.x - x0) + torch.cos(angle2) * (self.y - y0))) + tau42

        sgn13 = torch.where(torch.remainder(angle3 - angle1, 2 * np.pi) < np.pi,
                            torch.ones_like(angle3), -torch.ones_like(angle3))
        tau13 = g(torch.remainder(angle3 - angle1, 2*np.pi)) * tau
        dist13 = sgn13 * torch.min( sgn13 * (-torch.sin(angle1) * (self.x - x0) + torch.cos(angle1) * (self.y - y0)),
                                   -sgn13 * (-torch.sin(angle3) * (self.x - x0) + torch.cos(angle3) * (self.y - y0))) + tau13

        return torch.stack([dist13, dist42], dim=1)

    def dists2indicators(self, dists):
        """
        Computes the indicator functions u_1, u_2, u_3 from the distance functions d_{13}, d_{12}

        Inputs
        ------
        dists   Tensor of shape [N, 2, R, R, H', W'] with samples of the two distance functions for every patch

        Outputs
        -------
                Tensor of shape [N, 3, R, R, H', W'] with samples of the three indicator functions for every patch
        """
        # Apply smooth Heaviside function to distance functions
        hdists = 0.5 * (1.0 + (2.0 / np.pi) * torch.atan(dists / self.opts.eta))

        # Convert Heaviside functions into wedge indicator functions
        return torch.stack([1.0 - hdists[:, 0, :, :, :, :],
                                  hdists[:, 0, :, :, :, :] * (1.0 - hdists[:, 1, :, :, :, :]),
                                  hdists[:, 0, :, :, :, :] * hdists[:, 1, :, :, :, :]], dim=1)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import torch


def patchstack(patches, border=2, padvalue=1.0):
    """
    Stack field of patches into one large image.

    Inputs
    ------
    patches    Tensor of shape [..., R, R, H', W']
    border     Space (in pixels) between neighboring patches (integer)
    padvalue   Value to fill space with

    Outputs
    -------
               Tensor of shape [..., (R+border)*H', (R+border)*W'], containing stacked patches
               (each patch is padded by border//2 pixels on every side)

    """
    assert border % 2 == 0, f"border must be even (but got {border})"

    # Pad 3rd and 4th to last dimensions with border//2 pixels valued `padvalue`.
    padamt = (0, 0, 0, 0, border//2, border//2, border//2, border//2)
    padded = torch.nn.functional.pad(patches, padamt, value=padvalue).detach().cpu()

    permutation = list(range(len(patches.shape)))
    permutation[-4] = -2
    permutation[-3] = -4
    permutation[-2] = -1
    permutation[-1] = -3

    new_shape = list(padded.shape[:-2])
    new_shape[-2] *= padded.shape[-2]
    new_shape[-1] *= padded.shape[-1]

    output = padded.permute(permutation).contiguous().view(new_shape)

    return output


def tile(a, dim, n_tile):
    """
    Tile tensor a along dimension `dim` with `n_tile` repeats.

    Written by Edouard360:
    https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/4
    """
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*(repeat_idx))
    # Keep the index tensor on the same device as `a` so index_select also works for CUDA tensors
    order_index = torch.LongTensor(
        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(a.device)
    return torch.index_select(a, dim, order_index)

--------------------------------------------------------------------------------