├── README.md
├── examples
│   └── foj_demo.ipynb
├── field_of_junctions.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
# Field of Junctions

### [Project Page](http://vision.seas.harvard.edu/foj/) | [Paper](https://arxiv.org/abs/2011.13866) | [Video](https://youtu.be/M0VwBw_aVQA)

This repository contains code for:

**[Field of Junctions: Extracting Boundary Structure at Low SNR](http://vision.seas.harvard.edu/foj/)**

[Dor Verbin](https://scholar.harvard.edu/dorverbin) and [Todd Zickler](http://www.eecs.harvard.edu/~zickler/)

International Conference on Computer Vision (ICCV), 2021.

Please contact us by email for questions about our paper or code.

## Requirements

Our code is implemented in PyTorch. It has been tested with PyTorch 1.6, but it should work with other PyTorch 1.x versions. The following packages are required:

- Python 3.x
- PyTorch 1.x
- numpy >= 1.14.0

## Usage

To analyze an `HxWxC` image (a numpy array) into its field of junctions, run the following code snippet:
```
from field_of_junctions import FieldOfJunctions
foj = FieldOfJunctions(img, opts)
foj.optimize()
```
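The constructor unpacks `img.shape` into `H, W, C` (so a grayscale image needs a trailing channel axis, e.g. `img[..., None]`) and asserts that `(H - R)` and `(W - R)` are both divisible by `stride`, where `R` and `stride` are fields of `opts` described next. If an image does not satisfy this, one option is to crop it. The sketch below uses a hypothetical helper, `crop_to_valid_size` (not part of this repository), that drops the fewest bottom rows and right columns needed and reports the resulting patch counts `H'` and `W'`; the conversion to `float32` is an assumption, chosen to match the dtype of the junction parameters.

```python
import numpy as np


def crop_to_valid_size(img, R, stride):
    """Crop an HxWxC image so that (H - R) and (W - R) are divisible by `stride`.

    Hypothetical helper (not part of this repository). Returns the cropped
    float32 image and the number of patches (H', W') along each axis.
    """
    H, W = img.shape[:2]
    if H < R or W < R:
        raise ValueError("image must be at least R x R pixels")
    # Drop the smallest number of rows/columns from the bottom/right edges
    H_valid = H - (H - R) % stride
    W_valid = W - (W - R) % stride
    cropped = np.ascontiguousarray(img[:H_valid, :W_valid], dtype=np.float32)
    H_patches = (H_valid - R) // stride + 1
    W_patches = (W_valid - R) // stride + 1
    return cropped, (H_patches, W_patches)


img = np.random.rand(100, 130, 3)
cropped, (H_p, W_p) = crop_to_valid_size(img, R=21, stride=2)
print(cropped.shape, H_p, W_p)  # (99, 129, 3) 40 55
```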

In addition to the input image, the `FieldOfJunctions` class requires an object `opts` with the following fields:
```
R                           Patch size (patches are R x R pixels)
stride                      Stride of field of junctions (e.g. opts.stride == 1 is dense)
eta                         Width of Heaviside functions
delta                       Width of boundary maps
lr_angles                   Learning rate of angles
lr_x0y0                     Learning rate of vertex positions
lambda_boundary_final       Final value of spatial boundary consistency weight lambda_B
lambda_color_final          Final value of spatial color consistency weight lambda_C
nvals                       Number of values to query in Algorithm 2 from the paper
num_initialization_iters    Number of initialization iterations
num_refinement_iters        Number of refinement iterations
greedy_step_every_iters     Period, in refinement iterations, of "greedy" steps (Algorithm 2 with consistency)
parallel_mode               Whether or not to run Algorithm 2 in parallel over all `nvals` values
```
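Since the class only reads these fields as attributes, any object with the right attribute names should work. One lightweight option, sketched below under that assumption, is `types.SimpleNamespace`. The helper `make_opts` and the tuple `REQUIRED_FIELDS` are hypothetical, and the numeric values are arbitrary placeholders for illustration, not the settings used in the paper or in the demo notebook.

```python
from types import SimpleNamespace

# Attribute names read by FieldOfJunctions (taken from the list above)
REQUIRED_FIELDS = (
    "R", "stride", "eta", "delta", "lr_angles", "lr_x0y0",
    "lambda_boundary_final", "lambda_color_final", "nvals",
    "num_initialization_iters", "num_refinement_iters",
    "greedy_step_every_iters", "parallel_mode",
)


def make_opts(**kwargs):
    """Hypothetical helper: build an options object, rejecting missing or unknown fields."""
    missing = [name for name in REQUIRED_FIELDS if name not in kwargs]
    unknown = [name for name in kwargs if name not in REQUIRED_FIELDS]
    if missing or unknown:
        raise ValueError(f"missing fields: {missing}, unknown fields: {unknown}")
    return SimpleNamespace(**kwargs)


# Placeholder values for illustration only (not tuned, not taken from the paper)
opts = make_opts(
    R=21, stride=2, eta=0.01, delta=0.05,
    lr_angles=0.003, lr_x0y0=0.03,
    lambda_boundary_final=0.5, lambda_color_final=0.1,
    nvals=31, num_initialization_iters=30, num_refinement_iters=1000,
    greedy_step_every_iters=50, parallel_mode=True,
)
```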

Note that setting `parallel_mode` to `True` typically results in faster optimization, but requires more memory whenever Algorithm 2 is run (during initialization and in the greedy refinement steps). For large images on a GPU with limited memory, you might need to set `parallel_mode` to `False`.

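The memory gap comes from batching: in parallel mode, `get_best_inds` evaluates all `nvals` candidate parameter sets at once, so tensors such as the wedge indicators (shape `[nvals, 3, R, R, H', W']`) and the piecewise-constant patches (shape `[nvals, C, R, R, H', W']`) are materialized for every candidate, whereas sequential mode processes one candidate at a time. The back-of-the-envelope sketch below estimates the size of just these two tensors, assuming float32 storage; the real peak is higher, since intermediate tensors (distance functions, consistency terms) are not counted. `batched_tensor_bytes` is a hypothetical helper.

```python
def batched_tensor_bytes(H, W, C, R, stride, nvals, parallel_mode, bytes_per_elem=4):
    """Rough size of the wedge and patch tensors evaluated per batch in Algorithm 2.

    Hypothetical estimate: counts only the [N, 3, R, R, H', W'] wedge indicators and the
    [N, C, R, R, H', W'] patches, with N = nvals in parallel mode and N = 1 otherwise.
    """
    H_patches = (H - R) // stride + 1
    W_patches = (W - R) // stride + 1
    n = nvals if parallel_mode else 1
    elems_per_candidate = (3 + C) * R * R * H_patches * W_patches
    return n * elems_per_candidate * bytes_per_elem


for mode in (True, False):
    gib = batched_tensor_bytes(H=321, W=321, C=3, R=21, stride=2, nvals=31, parallel_mode=mode) / 2**30
    print(f"parallel_mode={mode}: ~{gib:.2f} GiB")
```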
Instead of calling `foj.optimize()`, which runs the entire optimization scheme, you can access the field of junctions during optimization with the following equivalent code snippet:
```
foj = FieldOfJunctions(img, opts)
for i in range(foj.num_iters):
    foj.step(i)
```
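The loop above follows the schedule implemented in `FieldOfJunctions.step`: the first `num_initialization_iters` iterations are coordinate-descent (Algorithm 2) steps; afterwards, gradient-descent refinement steps are taken, except that every `greedy_step_every_iters`-th refinement iteration is again a coordinate-descent ("greedy") step. The consistency weights ramp linearly from 0 to their final values over the refinement phase. The pure-Python sketch below mirrors that arithmetic to show which kind of step runs at each iteration; `schedule` is a hypothetical helper, not part of the repository.

```python
def schedule(num_initialization_iters, num_refinement_iters, greedy_step_every_iters,
             lambda_boundary_final=1.0, lambda_color_final=1.0):
    """Yield (iteration, step_kind, lambda_boundary, lambda_color), mirroring FieldOfJunctions.step."""
    num_iters = num_initialization_iters + num_refinement_iters
    for iteration in range(num_iters):
        # Linear ramp of the consistency weights over the refinement phase
        if num_refinement_iters <= 1:
            factor = 0.0
        else:
            factor = max(0.0, (iteration - num_initialization_iters) / (num_refinement_iters - 1))
        greedy = (iteration - num_initialization_iters + 1) % greedy_step_every_iters == 0
        kind = "initialization" if iteration < num_initialization_iters or greedy else "refinement"
        yield iteration, kind, factor * lambda_boundary_final, factor * lambda_color_final


for it, kind, lam_b, lam_c in schedule(2, 5, 3):
    print(it, kind, round(lam_b, 2), round(lam_c, 2))
```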

See the Jupyter notebook in the `examples/` folder (`examples/foj_demo.ipynb`) for a full usage example.

### Boundary maps

To compute the (global) boundary map, an `HxW` numpy array, for a given field of junctions object `foj`:

```
import torch

params = torch.cat([foj.angles, foj.x0y0], dim=1)
dists, _, _ = foj.get_dists_and_patches(params)
local_boundaries = foj.dists2boundaries(dists)
global_boundaries = foj.local2global(local_boundaries)[0, 0, :, :].detach().cpu().numpy()
```
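Conceptually, `dists2boundaries` turns each patch's two distance functions into a soft boundary map with profile `1 / (1 + (d / delta)^2)`, where `d` is the distance to the nearest wedge boundary (the zero set of the second distance function is ignored inside the first wedge), and `local2global` averages each pixel over all the patches that contain it. The numpy sketch below reproduces these two operations for a single-channel field of patches laid out as `[R, R, H', W']`. It is an illustration written with explicit loops instead of `torch.nn.Fold`, the function names `boundary_map` and `local_to_global` are hypothetical, and the random distance values stand in for the output of `params2dists`.

```python
import numpy as np


def boundary_map(d1, d2, delta):
    """Soft boundary map from the two distance functions (same rule as dists2boundaries)."""
    # Inside wedge 1 (d1 < 0) only |d1| matters; elsewhere use the nearer of the two zero sets
    minabsdist = np.where(d1 < 0.0, -d1, np.where(d2 < 0.0, np.minimum(d1, -d2), np.minimum(d1, d2)))
    return 1.0 / (1.0 + (minabsdist / delta) ** 2)


def local_to_global(patches, H, W, stride):
    """Average a field of single-channel patches, shape [R, R, H', W'], into an H x W map."""
    R, _, H_patches, W_patches = patches.shape
    total = np.zeros((H, W))
    count = np.zeros((H, W))
    for i in range(H_patches):
        for j in range(W_patches):
            top, left = i * stride, j * stride
            total[top:top + R, left:left + R] += patches[:, :, i, j]
            count[top:top + R, left:left + R] += 1.0
    return total / count


# Toy example: random distance values for a field of 3x3 patches with stride 1 over a 5x5 image
rng = np.random.RandomState(0)
d1 = rng.normal(size=(3, 3, 3, 3))
d2 = rng.normal(size=(3, 3, 3, 3))
local_boundaries = boundary_map(d1, d2, delta=0.05)
global_boundaries = local_to_global(local_boundaries, H=5, W=5, stride=1)
print(global_boundaries.shape)  # (5, 5)
```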

### Boundary-aware smoothing

To compute the boundary-aware smoothing of the input image given `foj`, an `HxWxC` numpy array, use:
```
import torch

params = torch.cat([foj.angles, foj.x0y0], dim=1)
_, _, patches = foj.get_dists_and_patches(params)
smoothed_img = foj.local2global(patches)[0, :, :, :].permute(1, 2, 0).detach().cpu().numpy()
```
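Each patch in `patches` is piecewise constant: the two distance functions are passed through the smooth Heaviside function `H(d) = 0.5 * (1 + (2 / pi) * arctan(d / eta))` to obtain three soft wedge indicators, and each wedge is filled with the indicator-weighted mean of the image inside it (this is what `get_dists_and_patches` computes when the color consistency weight is 0). The smoothed image is then the per-pixel average of these patches, as with the boundary maps. The numpy sketch below shows the per-patch computation for a single `R x R x C` patch; `smooth_patch` is a hypothetical name, and the toy distance functions (a vertical line for `d1`, a large constant for `d2` so that wedge 2 is nearly empty) are chosen for illustration rather than produced by `params2dists`.

```python
import numpy as np


def heaviside(d, eta):
    """Smooth Heaviside function, same form as in dists2indicators."""
    return 0.5 * (1.0 + (2.0 / np.pi) * np.arctan(d / eta))


def smooth_patch(patch, d1, d2, eta):
    """Fill the three soft wedges of one R x R x C patch with their weighted mean colors.

    d1, d2: the two distance functions sampled on the R x R grid of the patch.
    """
    h1, h2 = heaviside(d1, eta), heaviside(d2, eta)
    # Soft wedge indicators u_1, u_2, u_3; they sum to 1 at every pixel
    wedges = np.stack([1.0 - h1, h1 * (1.0 - h2), h1 * h2])            # [3, R, R]
    weights = wedges.reshape(3, -1)                                     # [3, R*R]
    pixels = patch.reshape(-1, patch.shape[-1])                         # [R*R, C]
    # Indicator-weighted mean color of each wedge (small constant avoids division by zero)
    colors = weights @ pixels / (weights.sum(axis=1, keepdims=True) + 1e-10)  # [3, C]
    # Each pixel receives the indicator-weighted combination of the three wedge colors
    return (wedges[..., None] * colors[:, None, None, :]).sum(axis=0)   # [R, R, C]


# Toy example: a vertical step edge on a 5x5 grid, with d1 = x and d2 large (wedge 2 nearly empty)
R = 5
y, x = np.meshgrid(np.linspace(-1.0, 1.0, R), np.linspace(-1.0, 1.0, R), indexing="ij")
step_edge = (x > 0).astype(float)[..., None]
smoothed = smooth_patch(step_edge, d1=x, d2=np.full((R, R), 10.0), eta=0.01)
print(np.round(smoothed[:, :, 0], 2))
```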

## Data

A zip file containing all of our synthetic data is available [here](https://vision.seas.harvard.edu/foj/dataset/foj_data.zip). It contains the 300 images we used for quantitatively evaluating our algorithm, as well as the ground-truth locations of edges and corners/junctions.

## Citation

To cite our paper, please use:
```
@InProceedings{verbin2021foj,
  author    = {Verbin, Dor and Zickler, Todd},
  title     = {Field of Junctions: Extracting Boundary Structure at Low {SNR}},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  month     = {October},
  year      = {2021},
  pages     = {6869-6878}
}
```

--------------------------------------------------------------------------------
/field_of_junctions.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

if torch.cuda.is_available():
    dev = torch.device('cuda')
else:
    dev = torch.device('cpu')

class FieldOfJunctions:
    def __init__(self, img, opts):
        """
        Inputs
        ------
        img     Input image: a numpy array of shape [H, W, C]
        opts    Object with the following attributes:
                    R                           Patch size
                    stride                      Stride for junctions (e.g. opts.stride == 1 is a dense field of junctions)
                    eta                         Width parameter for Heaviside functions
                    delta                       Width parameter for boundary maps
                    lr_angles                   Angle learning rate
                    lr_x0y0                     Vertex position learning rate
                    lambda_boundary_final       Final value of spatial boundary consistency term
                    lambda_color_final          Final value of spatial color consistency term
                    nvals                       Number of values to query in Algorithm 2 from the paper
                    num_initialization_iters    Number of initialization iterations
                    num_refinement_iters        Number of refinement iterations
                    greedy_step_every_iters     Period, in refinement iterations, of "greedy" steps (Algorithm 2 with consistency)
                    parallel_mode               Whether or not to run Algorithm 2 in parallel over all `nvals` values.
        """

        # Get image dimensions
        self.H, self.W, self.C = img.shape

        # Make sure number of patches in both dimensions is an integer
        assert (self.H - opts.R) % opts.stride == 0 and (self.W - opts.R) % opts.stride == 0, \
            "Number of patches must be an integer."

        # Number of patches (throughout the documentation H_patches and W_patches are denoted by H' and W' resp.)
        self.H_patches = (self.H - opts.R) // opts.stride + 1
        self.W_patches = (self.W - opts.R) // opts.stride + 1

        # Store total number of iterations (initialization + refinement)
        self.num_iters = opts.num_initialization_iters + opts.num_refinement_iters

        # Split image into overlapping patches, creating a tensor of shape [1, C, R, R, H', W']
        t_img = torch.tensor(img, device=dev).permute(2, 0, 1).unsqueeze(0)  # input image, shape [1, C, H, W]
        self.img_patches = nn.Unfold(opts.R, stride=opts.stride)(t_img).view(1, self.C, opts.R, opts.R,
                                                                             self.H_patches, self.W_patches)

        # Create pytorch variables for angles and vertex position for each patch
        self.angles = torch.zeros(1, 3, self.H_patches, self.W_patches, dtype=torch.float32, device=dev)
        self.x0y0 = torch.zeros(1, 2, self.H_patches, self.W_patches, dtype=torch.float32, device=dev)

        # Enable gradient computation for angles and vertex positions
        self.angles.requires_grad = True
        self.x0y0.requires_grad = True

        # Compute number of patches containing each pixel: has shape [H, W]
        self.num_patches = torch.nn.Fold(output_size=[self.H, self.W],
                                         kernel_size=opts.R,
                                         stride=opts.stride)(torch.ones(1, opts.R**2,
                                                                        self.H_patches * self.W_patches,
                                                                        device=dev)).view(self.H, self.W)

        # Create local grid within each patch
        y, x = torch.meshgrid([torch.linspace(-1.0, 1.0, opts.R, device=dev),
                               torch.linspace(-1.0, 1.0, opts.R, device=dev)])
        self.x = x.view(1, opts.R, opts.R, 1, 1)
        self.y = y.view(1, opts.R, opts.R, 1, 1)

        # Optimization parameters
        adam_beta1 = 0.5
        adam_beta2 = 0.99
        adam_eps = 1e-08

        # Create optimizers for angles and vertices
        optimizer_angles = optim.Adam([self.angles],
                                      opts.lr_angles, [adam_beta1, adam_beta2], eps=adam_eps)
        optimizer_x0y0 = optim.Adam([self.x0y0],
                                    opts.lr_x0y0, [adam_beta1, adam_beta2], eps=adam_eps)
        self.optimizers = [optimizer_angles, optimizer_x0y0]

        # Values to search over in Algorithm 2: [0, 2pi) for angles, [-3, 3] for vertex position.
        self.angle_range = torch.linspace(0.0, 2*np.pi, opts.nvals+1, device=dev)[:opts.nvals]
        self.x0y0_range = torch.linspace(-3.0, 3.0, opts.nvals, device=dev)

        # Save current global image and boundary map (initially None)
        self.global_image = None
        self.global_boundaries = None

        # Save opts
        self.opts = opts

    def optimize(self):
        """
        Optimize field of junctions.
        """
        for iteration in range(self.num_iters):
            self.step(iteration)

    def step(self, iteration):
        """
        Perform one step: either an initialization coordinate descent step, or a refinement gradient
        descent step. During refinement, every `greedy_step_every_iters`-th iteration is a coordinate
        descent ("greedy") step instead.

        Inputs
        ------
        iteration    Iteration number (integer)
        """

        # Linearly increase lambda from 0 to lambda_boundary_final and lambda_color_final
        if self.opts.num_refinement_iters <= 1:
            factor = 0.0
        else:
            factor = max([0, (iteration - self.opts.num_initialization_iters) / (self.opts.num_refinement_iters - 1)])
        lmbda_boundary = factor * self.opts.lambda_boundary_final
        lmbda_color = factor * self.opts.lambda_color_final

        if iteration < self.opts.num_initialization_iters or \
           (iteration - self.opts.num_initialization_iters + 1) % self.opts.greedy_step_every_iters == 0:
            self.initialization_step(lmbda_boundary, lmbda_color)
        else:
            self.refinement_step(lmbda_boundary, lmbda_color)

    def initialization_step(self, lmbda_boundary, lmbda_color):
        """
        Perform a single coordinate descent step (using Algorithm 2 from the paper).
        After all five parameters have been updated, a heuristic searches for a better vertex position
        along the lines through the current vertex in the directions of the three junction angles.
        The current value is included in every search (for the vertex searches this requires `nvals`
        to be odd, so that 0 lies in `x0y0_range`), so the extra step obtains a better (or equally good)
        set of parameters.

        Inputs
        ------
        lmbda_boundary    Spatial consistency boundary loss weight
        lmbda_color       Spatial consistency color loss weight
        """
        params = torch.cat([self.angles, self.x0y0], dim=1).detach()

        # Run one step of Algorithm 2, sequentially improving each coordinate
        for i in range(5):
            # Repeat the set of parameters `nvals` times along 0th dimension
            params_query = params.repeat(self.opts.nvals, 1, 1, 1)
            param_range = self.angle_range if i < 3 else self.x0y0_range
            params_query[:, i, :, :] = params_query[:, i, :, :] + param_range.view(-1, 1, 1)
            best_ind = self.get_best_inds(params_query, lmbda_boundary, lmbda_color)

            # Update parameters
            params[0, i, :, :] = params_query[best_ind.view(1, self.H_patches, self.W_patches),
                                              i,
                                              torch.arange(self.H_patches).view(1, -1, 1),
                                              torch.arange(self.W_patches).view(1, 1, -1)]

        # Heuristic for accelerating convergence (not necessary but sometimes helps):
        # Update x0 and y0 along the three optimal angles (search over a line passing through current x0, y0)
        for i in range(3):
            params_query = params.repeat(self.opts.nvals, 1, 1, 1)
            params_query[:, 3, :, :] = params[:, 3, :, :] + torch.cos(params[:, i, :, :]) * self.x0y0_range.view(-1, 1, 1)
            params_query[:, 4, :, :] = params[:, 4, :, :] + torch.sin(params[:, i, :, :]) * self.x0y0_range.view(-1, 1, 1)
            best_ind = self.get_best_inds(params_query, lmbda_boundary, lmbda_color)

            # Update vertex positions of parameters
            for j in range(3, 5):
                params[:, j, :, :] = params_query[best_ind.view(1, self.H_patches, self.W_patches),
                                                  j,
                                                  torch.arange(self.H_patches).view(1, -1, 1),
                                                  torch.arange(self.W_patches).view(1, 1, -1)]

        # Update angles and vertex position using the best values found
        self.angles.data = params[:, :3, :, :].data
        self.x0y0.data = params[:, 3:, :, :].data

        # Update global boundaries and image
        dists, colors, patches = self.get_dists_and_patches(params, lmbda_color)
        self.global_image = self.local2global(patches)
        self.global_boundaries = self.local2global(self.dists2boundaries(dists))

    def refinement_step(self, lmbda_boundary, lmbda_color):
        """
        Perform a single refinement (gradient descent) step

        Inputs
        ------
        lmbda_boundary    Spatial consistency boundary loss weight
        lmbda_color       Spatial consistency color loss weight
        """
        params = torch.cat([self.angles, self.x0y0], dim=1)

        # Compute distance functions, colors, and junction patches
        dists, colors, patches = self.get_dists_and_patches(params, lmbda_color)

        # Compute loss
        loss = self.get_loss(dists, colors, patches, lmbda_boundary, lmbda_color).mean()

        # Take gradient step over angles and vertex positions
        for optimizer in self.optimizers:
            optimizer.zero_grad()
        loss.backward()
        for optimizer in self.optimizers:
            optimizer.step()

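        # Update global boundaries and image using the updated parameters. `params` above was
        # concatenated before the optimizer step, so it is rebuilt here (no gradients are needed).
        with torch.no_grad():
            params = torch.cat([self.angles, self.x0y0], dim=1)
            dists, colors, patches = self.get_dists_and_patches(params, lmbda_color)
            self.global_image = self.local2global(patches)
            self.global_boundaries = self.local2global(self.dists2boundaries(dists))
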
    def get_loss(self, dists, colors, patches, lmbda_boundary, lmbda_color):
        """
        Compute the objective of our model (see Equation 8 of the paper).

        Inputs
        ------
        dists             Tensor of shape [N, 2, R, R, H', W'] with samples of the two distance functions for every patch
        colors            Tensor of shape [N, C, 3, H', W'] storing the (C-channel) color of each of the 3 wedges at each patch
        patches           Tensor of shape [N, C, R, R, H', W'] with each patch having color c_i^{(j)} at the jth wedge, for each i
        lmbda_boundary    Spatial consistency boundary loss weight
        lmbda_color       Spatial consistency color loss weight

        Outputs
        -------
        Tensor of shape [N, H', W'] with the loss at each patch
        """
        # Compute the data term for each patch: squared error, i.e. a Gaussian negative log-likelihood
        # up to constants (shape [N, H', W'])
        loss_per_patch = ((self.img_patches - patches) ** 2).mean(-3).mean(-3).sum(1)

        # Add spatial consistency loss for each patch, if lambda > 0
        if lmbda_boundary > 0.0:
            loss_per_patch = loss_per_patch + lmbda_boundary * self.get_boundary_consistency_term(dists)

        if lmbda_color > 0.0:
            loss_per_patch = loss_per_patch + lmbda_color * self.get_color_consistency_term(dists, colors)

        return loss_per_patch

    def get_boundary_consistency_term(self, dists):
        """
        Compute the spatial boundary consistency term.

        Inputs
        ------
        dists    Tensor of shape [N, 2, R, R, H', W'] with samples of the two distance functions for every patch

        Outputs
        -------
        Tensor of shape [N, H', W'] with the consistency loss at each patch
        """
        # Split global boundaries into patches
        curr_global_boundaries_patches = nn.Unfold(self.opts.R, stride=self.opts.stride)(
            self.global_boundaries.detach()).view(1, 1, self.opts.R, self.opts.R, self.H_patches, self.W_patches)

        # Get local boundaries defined using the queried parameters (defined by `dists`)
        local_boundaries = self.dists2boundaries(dists)

        # Compute consistency term
        consistency = ((local_boundaries - curr_global_boundaries_patches) ** 2).mean(2).mean(2)

        return consistency[:, 0, :, :]

    def get_color_consistency_term(self, dists, colors):
        """
        Compute the spatial color consistency term.

        Inputs
        ------
        dists     Tensor of shape [N, 2, R, R, H', W'] with samples of the two distance functions for every patch
        colors    Tensor of shape [N, C, 3, H', W'] storing the (C-channel) color of each of the 3 wedges at each patch

        Outputs
        -------
        Tensor of shape [N, H', W'] with the consistency loss at each patch
        """
        # Split global image into patches
        curr_global_image_patches = nn.Unfold(self.opts.R, stride=self.opts.stride)(
            self.global_image.detach()).view(1, self.C, self.opts.R, self.opts.R, self.H_patches, self.W_patches)

        wedges = self.dists2indicators(dists)  # shape [N, 3, R, R, H', W']

        # Compute consistency term
        consistency = (wedges.unsqueeze(1) * (
            colors.unsqueeze(-3).unsqueeze(-3) - curr_global_image_patches.unsqueeze(2)) ** 2).mean(-3).mean(-3).sum(1).sum(1)

        return consistency

    def get_dists_and_patches(self, params, lmbda_color=0.0):
        """
        Compute distance functions and piecewise-constant patches given junction parameters.

        Inputs
        ------
        params         Tensor of shape [N, 5, H', W'] holding N field of junctions parameters. Each
                       5-vector has format (angle1, angle2, angle3, x0, y0).
        lmbda_color    Spatial consistency color loss weight. Once a global image is available, the
                       wedge colors are computed with the color consistency term taken into account.

        Outputs
        -------
        dists      Tensor of shape [N, 2, R, R, H', W'] with samples of the two distance functions for every patch
        colors     Tensor of shape [N, C, 3, H', W'] storing the (C-channel) color of each of the 3 wedges at each patch
        patches    Tensor of shape [N, C, R, R, H', W'] with the constant color function at each of the 3 wedges
        """

        # Get dists
        dists = self.params2dists(params)  # shape [N, 2, R, R, H', W']

        # Get wedge indicator functions
        wedges = self.dists2indicators(dists)  # shape [N, 3, R, R, H', W']

        if lmbda_color >= 0 and self.global_image is not None:
            curr_global_image_patches = nn.Unfold(self.opts.R, stride=self.opts.stride)(
                self.global_image.detach()).view(1, self.C, self.opts.R, self.opts.R, self.H_patches, self.W_patches)

            numerator = ((self.img_patches + lmbda_color *
                          curr_global_image_patches).unsqueeze(2) * wedges.unsqueeze(1)).sum(-3).sum(-3)
            denominator = (1.0 + lmbda_color) * wedges.sum(-3).sum(-3).unsqueeze(1)

            colors = numerator / (denominator + 1e-10)
        else:
            # Get best color for each wedge and each patch
            colors = (self.img_patches.unsqueeze(2) * wedges.unsqueeze(1)).sum(-3).sum(-3) / \
                     (wedges.sum(-3).sum(-3).unsqueeze(1) + 1e-10)

        # Fill wedges with optimal colors
        patches = (wedges.unsqueeze(1) * colors.unsqueeze(-3).unsqueeze(-3)).sum(dim=2)

        return dists, colors, patches

    def dists2boundaries(self, dists):
        """
        Compute boundary map for each patch, given distance functions. The width of the boundary is determined
        by opts.delta.

        Inputs
        ------
        dists    Tensor of shape [N, 2, R, R, H', W'] with samples of the two distance functions for every patch

        Outputs
        -------
        Tensor of shape [N, 1, R, R, H', W'] with values of boundary map for every patch
        """
        # Distance to the nearest boundary: inside wedge 1 (d1 < 0) only |d1| counts, since the zero set
        # of d2 is not a boundary there; elsewhere (d1 >= 0) use the smaller of d1 and |d2|
        d1 = dists[:, 0:1, :, :, :, :]
        d2 = dists[:, 1:2, :, :, :, :]
        minabsdist = torch.where(d1 < 0.0, -d1, torch.where(d2 < 0.0, torch.min(d1, -d2), torch.min(d1, d2)))

        return 1.0 / (1.0 + (minabsdist / self.opts.delta) ** 2)

    def local2global(self, patches):
        """
        Compute average value for each pixel over all patches containing it.
        For example, this can be used to compute the global boundary maps, or the boundary-aware smoothed image.

        Inputs
        ------
        patches    Tensor of shape [N, C, R, R, H', W']. patches[n, :, :, :, i, j] is an RxR C-channel patch
                   at the (i, j)th spatial position of the nth entry.

        Outputs
        -------
        Tensor of shape [N, C, H, W] of averages over all patches containing each pixel.
        """
        N = patches.shape[0]
        C = patches.shape[1]
        return torch.nn.Fold(output_size=[self.H, self.W], kernel_size=self.opts.R, stride=self.opts.stride)(
            patches.view(N, C*self.opts.R**2, -1)).view(N, C, self.H, self.W) / \
            self.num_patches.unsqueeze(0).unsqueeze(0)

    def get_best_inds(self, params, lmbda_boundary, lmbda_color):
        """
        Compute the best index along the 0th dimension of `params` for each patch position.
        Has two possible modes determined by self.opts.parallel_mode:
            1) When True, all N values are computed in parallel (generally faster, requires more memory)
            2) When False, the values are computed sequentially (generally slower, requires less memory)

        Inputs
        ------
        params            Tensor of shape [N, 5, H', W'] holding N field of junctions parameters. Each
                          5-vector has format (angle1, angle2, angle3, x0, y0).
        lmbda_boundary    Spatial consistency boundary loss weight
        lmbda_color       Spatial consistency color loss weight

        Outputs
        -------
        Tensor of shape [H', W'] with each value in {0, ..., N-1} holding the
        index of the best junction parameters at that position.
        """
        if self.opts.parallel_mode:
            dists, colors, smooth_patches = self.get_dists_and_patches(params, lmbda_color)
            loss_per_patch = self.get_loss(dists, colors, smooth_patches, lmbda_boundary, lmbda_color)
            best_ind = loss_per_patch.argmin(dim=0)

        else:
            # First initialize tensors
            best_ind = torch.zeros(self.H_patches, self.W_patches, device=dev, dtype=torch.int64)
            best_loss_per_patch = torch.zeros(self.H_patches, self.W_patches, device=dev) + 1e10

            # Now fill tensors by iterating over the junction dimension and choosing the best junction parameters
            for n in range(params.shape[0]):
                dists, colors, smooth_patches = self.get_dists_and_patches(params[n:n+1, :, :, :], lmbda_color)

                # Loss of the nth candidate at every position, shape [H', W']
                loss_per_patch = self.get_loss(dists, colors, smooth_patches, lmbda_boundary, lmbda_color)[0]

                improved_inds = loss_per_patch < best_loss_per_patch
                best_ind = torch.where(improved_inds, torch.tensor(n, device=dev, dtype=torch.int64), best_ind)
                best_loss_per_patch = torch.where(improved_inds, loss_per_patch, best_loss_per_patch)

        return best_ind

    def params2dists(self, params, tau=1e-1):
        """
        Compute distance functions from field of junctions.

        Inputs
        ------
        params    Tensor of shape [N, 5, H', W'] holding N field of junctions parameters. Each
                  5-vector has format (angle1, angle2, angle3, x0, y0).
        tau       Constant used for lifting the level set function to be either entirely positive
                  or entirely negative when an angle approaches 0 or 2pi.

        Outputs
        -------
        Tensor of shape [N, 2, R, R, H', W'] with samples of the two distance functions for every patch
        """
        x0 = params[:, 3, :, :].unsqueeze(1).unsqueeze(1)  # shape [N, 1, 1, H', W']
        y0 = params[:, 4, :, :].unsqueeze(1).unsqueeze(1)  # shape [N, 1, 1, H', W']

        # Sort so angle1 <= angle2 <= angle3 (mod 2pi)
        angles = torch.remainder(params[:, :3, :, :], 2 * np.pi)
        angles = torch.sort(angles, dim=1)[0]

        angle1 = angles[:, 0, :, :].unsqueeze(1).unsqueeze(1)  # shape [N, 1, 1, H', W']
        angle2 = angles[:, 1, :, :].unsqueeze(1).unsqueeze(1)  # shape [N, 1, 1, H', W']
        angle3 = angles[:, 2, :, :].unsqueeze(1).unsqueeze(1)  # shape [N, 1, 1, H', W']

        # Define another angle halfway between angle3 and angle1, clockwise from angle3
        # This isn't critical but it seems a bit more stable for computing gradients
        angle4 = 0.5 * (angle1 + angle3) + \
                 torch.where(torch.remainder(0.5 * (angle1 - angle3), 2 * np.pi) >= np.pi,
                             torch.ones_like(angle1) * np.pi, torch.zeros_like(angle1))

        def g(dtheta):
            # Map from [0, 2pi] to [-1, 1]
            return (dtheta / np.pi - 1.0) ** 35

        # Compute the two distance functions
        sgn42 = torch.where(torch.remainder(angle2 - angle4, 2 * np.pi) < np.pi,
                            torch.ones_like(angle2), -torch.ones_like(angle2))
        tau42 = g(torch.remainder(angle2 - angle4, 2*np.pi)) * tau

        dist42 = sgn42 * torch.min(sgn42 * (-torch.sin(angle4) * (self.x - x0) + torch.cos(angle4) * (self.y - y0)),
                                   -sgn42 * (-torch.sin(angle2) * (self.x - x0) + torch.cos(angle2) * (self.y - y0))) + tau42

        sgn13 = torch.where(torch.remainder(angle3 - angle1, 2 * np.pi) < np.pi,
                            torch.ones_like(angle3), -torch.ones_like(angle3))
        tau13 = g(torch.remainder(angle3 - angle1, 2*np.pi)) * tau
        dist13 = sgn13 * torch.min(sgn13 * (-torch.sin(angle1) * (self.x - x0) + torch.cos(angle1) * (self.y - y0)),
                                   -sgn13 * (-torch.sin(angle3) * (self.x - x0) + torch.cos(angle3) * (self.y - y0))) + tau13

        return torch.stack([dist13, dist42], dim=1)

    def dists2indicators(self, dists):
        """
        Computes the indicator functions u_1, u_2, u_3 from the distance functions d_{13}, d_{42}

        Inputs
        ------
        dists    Tensor of shape [N, 2, R, R, H', W'] with samples of the two distance functions for every patch

        Outputs
        -------
        Tensor of shape [N, 3, R, R, H', W'] with samples of the three indicator functions for every patch
        """
        # Apply smooth Heaviside function to distance functions
        hdists = 0.5 * (1.0 + (2.0 / np.pi) * torch.atan(dists / self.opts.eta))

        # Convert Heaviside functions into wedge indicator functions
        return torch.stack([1.0 - hdists[:, 0, :, :, :, :],
                            hdists[:, 0, :, :, :, :] * (1.0 - hdists[:, 1, :, :, :, :]),
                            hdists[:, 0, :, :, :, :] * hdists[:, 1, :, :, :, :]], dim=1)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import torch


def patchstack(patches, border=2, padvalue=1.0):
    """
    Stack field of patches into one large image.

    Inputs
    ------
    patches     Tensor of shape [..., R, R, H', W']
    border      Space (in pixels) between neighboring patches (even integer)
    padvalue    Value to fill space with

    Outputs
    -------
    Tensor of shape [..., (R+border)*H', (R+border)*W'] containing the stacked patches, each surrounded
    by border//2 pixels valued `padvalue` on every side

    """
    assert border % 2 == 0, f"border must be even (but got {border})"

    # Pad 3rd and 4th to last dimensions with border//2 pixels valued `padvalue`.
    padamt = (0, 0, 0, 0, border//2, border//2, border//2, border//2)
    padded = torch.nn.functional.pad(patches, padamt, value=padvalue).detach().cpu()

    permutation = list(range(len(patches.shape)))
    permutation[-4] = -2
    permutation[-3] = -4
    permutation[-2] = -1
    permutation[-1] = -3

    new_shape = list(padded.shape[:-2])
    new_shape[-2] *= padded.shape[-2]
    new_shape[-1] *= padded.shape[-1]

    output = padded.permute(permutation).contiguous().view(new_shape)

    return output


def tile(a, dim, n_tile):
    """
    Tile tensor a along dimension `dim` with `n_tile` repeats. Each slice along `dim` is repeated
    `n_tile` times consecutively, which matches torch.repeat_interleave(a, n_tile, dim=dim).

    Written by Edouard360:
    https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/4
    """
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*(repeat_idx))
    # Create the index on the same device as `a` so that index_select also works for GPU tensors
    order_index = torch.tensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]),
                               dtype=torch.long, device=a.device)
    return torch.index_select(a, dim, order_index)

--------------------------------------------------------------------------------