├── LICENSE ├── README.md ├── assets ├── city_depth.png ├── depth_out.png └── edges_out.png ├── heatmethod └── __init__.py └── main.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 jakericedesigns 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-Heat-Method 2 | A (crude) differentiable implementation of [The Heat Method for Distance Computation, by Crane et al.](https://www.cs.cmu.edu/~kmcrane/Projects/HeatMethod/) 3 | 4 | I didn't wrap it up into a proper pytorch method, because I'm lazy and also don't fully understand all of that junk. But this works, and seems to generate valid gradients during backprop. 
5 | 6 | 7 | ## Dependencies 8 | - PyTorch 9 | - Kornia (optional, used in the demo for edge extraction) 10 | 11 | 12 | ## Author's Notes 13 | 14 | The `heat_method` function performs best when fed a single-channel image tensor in which the edges to measure the distance to are set to a constant value and everything else is set to 0. 15 | 16 | It uses the Jacobi method to solve the linear systems. There are plenty of better/faster ways of solving them (FFTs, torch.linalg.solve), but once again, I'm lazy. 17 | 18 | --- 19 | 20 | ### Example Input: 21 | 22 | ![Input Edges](https://github.com/jakericedesigns/Pytorch-Heat-Method/blob/main/assets/edges_out.png) 23 | 24 | ### Example Output: 25 | 26 | ![Output Distance](https://github.com/jakericedesigns/Pytorch-Heat-Method/blob/main/assets/depth_out.png) 27 | -------------------------------------------------------------------------------- /assets/city_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakericedesigns/Pytorch-Heat-Method/19a9c22cac7ba3e032c0c3470d97cede2acb7090/assets/city_depth.png -------------------------------------------------------------------------------- /assets/depth_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakericedesigns/Pytorch-Heat-Method/19a9c22cac7ba3e032c0c3470d97cede2acb7090/assets/depth_out.png -------------------------------------------------------------------------------- /assets/edges_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakericedesigns/Pytorch-Heat-Method/19a9c22cac7ba3e032c0c3470d97cede2acb7090/assets/edges_out.png -------------------------------------------------------------------------------- /heatmethod/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | 
# A crude PyTorch implementation of The Heat Method for Distance Computation
# (Geodesics in Heat), by Crane et al.
# https://www.cs.cmu.edu/~kmcrane/Projects/HeatMethod/

# An easy reference for the finite-difference stencils used below:
# https://developer.download.nvidia.com/books/HTML/gpugems/gpugems_ch38.html
# (easy to find again: search for "gpu gems fluids").

# 3x3 stencils kept as plain nested lists; `_stencil` turns them into conv2d
# weights that match the dtype/device of whatever tensor is being filtered.
_NEIGHBOR_SUM = [[0.0, 1.0, 0.0],
                 [1.0, 0.0, 1.0],
                 [0.0, 1.0, 0.0]]
_DIFF_X = [[0.0, 0.0, 0.0],
           [1.0, 0.0, -1.0],
           [0.0, 0.0, 0.0]]
_DIFF_Y = [[0.0, 1.0, 0.0],
           [0.0, 0.0, 0.0],
           [0.0, -1.0, 0.0]]


def _stencil(rows, like):
    """Return `rows` as a (1, 1, 3, 3) conv2d weight with `like`'s dtype and device."""
    return torch.tensor([[rows]], dtype=like.dtype, device=like.device)


def _pad_replicate(x):
    """Pad the last two dims of `x` by one pixel, repeating the border values."""
    return nn.functional.pad(x, (1, 1, 1, 1), mode="replicate")


def solve_poisson(img, iters=1000):
    """Run `iters` Jacobi sweeps on the Poisson equation  Δx = b  with b = `img`.

    Write A = Δ = D + L + U, where D is the diagonal of the 5-point Laplace
    matrix (-4) and L + U is the sum of the four direct neighbours.  One
    Jacobi sweep is  x_1 = D^-1 * (b - (L + U) * x_0), i.e.

        x_1 = (b - neighbour_sum(x_0)) / -4

    The current solution is replicate-padded before every sweep (boundary
    condition), so the output has the same shape as `img`.

    Args:
        img: right-hand side b, a (B, 1, H, W) tensor (the kernel has one
            input and one output channel).
        iters: number of Jacobi sweeps; with 0 the input is returned as is.

    Returns:
        The solution estimate after `iters` sweeps, starting from x_0 = b.
    """
    neighbours = _stencil(_NEIGHBOR_SUM, img)

    solution = img
    for _ in range(iters):
        solution = (img - nn.functional.conv2d(_pad_replicate(solution), neighbours)) / -4.0
    return solution


def screened_poisson(img, timestep=.1, mass=.01, iters=1000):
    """Run `iters` Jacobi sweeps on the screened Poisson equation  (M - t*Δ)x = b.

    With A = M - t*Δ split as D + L + U:
      * M is a mass matrix with only diagonal entries, so the diagonal of A
        is  M - (-4 * t) = M + 4t;
      * the off-diagonal part is  0 - (L + U) * t, i.e. the neighbour sum
        scaled by -t.
    One Jacobi sweep  x_1 = D^-1 * (b - (L + U) * x_0)  therefore becomes

        x_1 = (b - (-t) * neighbour_sum(x_0)) / (M + 4t)

    The current solution is replicate-padded before every sweep (boundary
    condition), so the output has the same shape as `img`.

    Args:
        img: right-hand side b, a (B, 1, H, W) tensor.
        timestep: diffusion time t.
        mass: the (scalar) diagonal mass term M.
        iters: number of Jacobi sweeps; with 0 the input is returned as is.

    Returns:
        The solution estimate after `iters` sweeps, starting from x_0 = b.
    """
    # Both factors are loop invariants, so build them once instead of per sweep.
    off_diagonal = -_stencil(_NEIGHBOR_SUM, img) * timestep
    diagonal = mass + 4.0 * timestep

    solution = img
    for _ in range(iters):
        solution = (img - nn.functional.conv2d(_pad_replicate(solution), off_diagonal)) / diagonal
    return solution


def finite_diff_grad(img):
    """Central-difference gradient of a single-channel image.

    The convolutions are unpadded, so a one-pixel border is lost.  Because
    conv2d is a cross-correlation, each stencil evaluates "previous minus
    next" neighbour (left - right, top - bottom), halved.

    Args:
        img: (B, 1, H, W) tensor.

    Returns:
        (B, 2, H-2, W-2) tensor; channel 0 is the x difference, channel 1
        the y difference.
    """
    diff_x = nn.functional.conv2d(img, _stencil(_DIFF_X, img))
    diff_y = nn.functional.conv2d(img, _stencil(_DIFF_Y, img))
    return torch.cat((diff_x, diff_y), 1) / 2.0


def finite_diff_div(grad):
    """Finite-difference divergence of a two-channel vector field.

    Uses the same stencils as `finite_diff_grad` (x stencil on channel 0,
    y stencil on channel 1), sums the two results and divides by 4.  The
    convolutions are unpadded, so a one-pixel border is lost.

    Args:
        grad: (B, 2, H, W) tensor holding the x and y components.

    Returns:
        (B, 1, H-2, W-2) tensor.
    """
    field_x = grad[:, 0:1, ...]
    field_y = grad[:, 1:, ...]
    div_x = nn.functional.conv2d(field_x, _stencil(_DIFF_X, grad))
    div_y = nn.functional.conv2d(field_y, _stencil(_DIFF_Y, grad))
    return (div_x + div_y) / 4.0


def heat_method(image, timestep=1.0, mass=.01, iters_diffusion=500, iters_poisson=1000):
    """Approximate the distance to the non-zero pixels of `image` (heat method).

    Steps:
      1. diffuse the input with `screened_poisson`;
      2. take the negated `finite_diff_grad` of the heat and normalise it
         per pixel to unit length;
      3. take `finite_diff_div` of that direction field;
      4. integrate with `solve_poisson`.

    Args:
        image: (B, 1, H, W) tensor; works best when the pixels to measure the
            distance to hold a constant value and everything else is 0.
        timestep: diffusion time passed to `screened_poisson`.
        mass: mass term passed to `screened_poisson`.
        iters_diffusion: Jacobi sweeps for the diffusion step.
        iters_poisson: Jacobi sweeps for the final Poisson solve.

    Returns:
        (B, 1, H-4, W-4) tensor.  The gradient and divergence steps each trim
        a one-pixel border because their convolutions are unpadded.
    """
    heat = screened_poisson(image, timestep, mass, iters_diffusion)  # diffusion
    direction = -finite_diff_grad(heat)  # inverted gradient

    # keepdim=True keeps the length as (B, 1, H, W) so it divides both
    # channels of every batch item.  Without it the (B, H, W) result only
    # lines up by accident for B == 1, mixes batch items for B == 2 and
    # raises for B >= 3.
    length = direction.norm(dim=1, keepdim=True)
    # Where the gradient vanishes, divide by 1 instead of 0: the direction
    # stays zero there rather than turning into NaN and poisoning the solve.
    length = torch.where(length > 0, length, torch.ones_like(length))
    direction = direction / length

    div = finite_diff_div(direction)
    return solve_poisson(div, iters_poisson)


# --------------------------------------------------------------------------------
# /main.py:
# --------------------------------------------------------------------------------
import os

import numpy as np
import torch


def fitrange(x):
    """Linearly rescale `x` so that its minimum maps to 0 and its maximum to 1.

    A constant input (max == min) maps to all zeros instead of NaN.
    """
    c_max = torch.max(x)
    c_min = torch.min(x)

    span = c_max - c_min
    # Guard the constant-image case, where the original 0 / 0 produced NaN.
    span = torch.where(span > 0, span, torch.ones_like(span))
    return (x - c_min) / span


def main():
    """Demo: extract Canny edges from an image and save their distance map."""
    # Demo-only dependencies are imported here so that `fitrange` can be
    # imported without them being installed.
    from PIL import Image
    from torchvision.transforms import functional as TF
    import kornia
    import heatmethod as heat

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    # change me :)
    init_image = "assets/city_depth.png"

    with Image.open(init_image) as source:
        pil_image = source.convert('RGB')
    image = TF.to_tensor(pil_image).to(device).unsqueeze(0)

    # Extract edges from the input image (second element of canny's result).
    print("Extracting edges from input picture")
    edges = kornia.filters.canny(image)[1]

    # Generate the distance map to the extracted edges.
    print("Generating distance transform to the extracted edges")
    depth_map = heat.heat_method(edges, timestep=0.1, mass=.01, iters_diffusion=500, iters_poisson=1000)

    # Write the results to disk.
    print("Saving images")
    TF.to_pil_image(edges[0].cpu()).save('assets/edges_out.png')
    TF.to_pil_image(fitrange(depth_map)[0].cpu()).save('assets/depth_out.png')


if __name__ == "__main__":
    main()
# --------------------------------------------------------------------------------