├── .flake8
├── .gitignore
├── LICENSE
├── README.md
├── bilateral.py
├── crfrnn
│   ├── __init__.py
│   └── crf.py
├── images
│   ├── crf_layer_example.png
│   ├── filtered.png
│   └── wimr_small.png
├── make_gfilt_dispatch.py
├── permutohedral
│   ├── __init__.py
│   ├── gfilt.py
│   └── hash.py
├── pyproject.toml
├── sanity_check.py
├── setup.py
└── src
    ├── build_hash.cu
    ├── build_hash_cuda.h
    ├── build_hash_kernel.h
    ├── build_hash_wrapper.cu
    ├── common.h
    ├── gfilt_cuda.cu
    ├── gfilt_cuda.h
    ├── gfilt_kernel.h
    ├── gfilt_wrapper.cu
    ├── hash_fns.cuh
    └── permutohedral.cpp

--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
[flake8]
select = B,C,E,F,P,W,B9
max-line-length = 100
### DEFAULT IGNORES FOR 4-space INDENTED PROJECTS ###
# Main Explanation Docs: https://lintlyci.github.io/Flake8Rules/
#
# E127, E128 are hard to silence in certain nested formatting situations.
# E203 doesn't work for slicing
# E265, E266 talk about comment formatting which is too opinionated.
# E402 warns on imports coming after statements. There are important use cases
# like demandimport (https://fburl.com/demandimport) that require statements
# before imports.
# E501 is not flexible enough, we're using B950 instead.
# E722 is a duplicate of B001.
# F811 looks for duplicate imports + noise for overload typing
# P207 is a duplicate of B003.
# P208 is a duplicate of C403.
# W503 talks about operator formatting which is too opinionated.
ignore = E127, E128, E203, E265, E266, E402, E501, E722, F811, P207, P208, W503
### DEFAULT IGNORES FOR 2-space INDENTED PROJECTS (uncomment) ###
# ignore = E111, E114, E121, E127, E128, E265, E266, E402, E501, P207, P208, W503
exclude =
    .git,
    .hg,
    __pycache__,
    _bin/*,
    _build/*,
    _ig_fbcode_wheel/*,
    buck-out/*,
    third-party-buck/*,
    third-party2/*

# Calculate max-complexity by changing the value below to 1, then surveying fbcode
# to see the distribution of complexity:
#   find ./[a-z0-9]* -name 'buck-*' -prune -o -name 'third*party*' -prune -o \
#       -name '*.py' -print |\
#       parallel flake8 --config ./.flake8 |\
#       perl -ne 'if (/C901/) { s/.*\((\d+)\)/$1/; print; }' | stats
# NOTE: This will take a while to run (near an hour IME) so you probably want a
# second working dir to run it in.
# Pick a reasonable point from there (e.g. p95 or "95%")
# As of 2016-05-18 the rough distribution is:
#
# count: 134807
# min: 2
# max: 206
# avg: 4.361
# median: 3
# sum: 587882
# stddev: 4.317
# variance: 18.635
#
# percentiles:
# 75%: 5
# 90%: 8
# 95%: 11
# 99%: 20
# 99.9%: 48
# 99.99%: 107
# 99.999%: 160
max-complexity = 12

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Generated code
src/gfilt_dispatch_table.h

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# editor stuff
.vscode/*
!.vscode/settings.json

# Misc
build
*.swp

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Gabriel Schwartz

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CRF-as-RNN Layer for [Pytorch](https://github.com/pytorch/pytorch)

This repository contains an implementation of the CRF-as-RNN method
[described here](http://www.robots.ox.ac.uk/~szheng/CRFasRNN.html). Please cite
their work if you use this in your own code. I am not affiliated with their
group; this is just a side project.

The pytorch module relies on two Functions: one to build the hashtable
representing a [permutohedral
lattice](http://graphics.stanford.edu/papers/permutohedral/permutohedral.pdf)
and another to perform the high-dimensional Gaussian filtering required by
approximate CRF inference.
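For orientation, here is a minimal sketch of how those two pieces compose.
`gfilt` builds the lattice hashtable internally when no prebuilt buffers are
passed; the shapes below are illustrative and a CUDA device is required:

```python
import torch as th

from permutohedral.gfilt import gfilt

# 5-D reference features (y, x, r, g, b) and 3 value channels per pixel.
ref = th.randn(1, 5, 64, 64, device="cuda:0")
val = th.rand(1, 3, 64, 64, device="cuda:0")
out = gfilt(ref, val)  # builds the hashtable from `ref`, then filters `val`
```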
## Setup

For inplace use / testing:

```
python setup.py build_ext --inplace
```

Or, to install the packages (permutohedral, crfrnn):

```
python setup.py install
```

## Pytorch Module
[![example](images/crf_layer_example.png)](images/crf_layer_example.png)

The [Pytorch](https://github.com/pytorch/pytorch) module takes two inputs for
the forward pass: a probability map (typically the output of a softmax layer),
and a reference image (typically the image being segmented/densely-classified).
The constructor requires the number of reference image channels (`n_ref`) and
the number of output labels (`n_out`). Optional additional parameters may be
provided to the module on construction:

* `sxy_bf`: spatial standard deviation for the bilateral filter.
* `sc_bf`: color standard deviation for the bilateral filter.
* `compat_bf`: label compatibility weight for the bilateral filter.
* `sxy_spatial`: spatial standard deviation for the 2D Gaussian filter.
* `compat_spatial`: label compatibility weight for the 2D Gaussian filter.

**Note**: the default color standard deviation assumes the input is a color
image in the range [0, 255]. If you use whitened or otherwise-normalized images,
you should change this value.

Here is a simple example:

```python
import torch as th

from crfrnn import CRF

n_categories = 32

class MyCNN(th.nn.Module):
    def __init__(self):
        super(MyCNN, self).__init__()
        self.relu = th.nn.ReLU()
        self.conv1 = th.nn.Conv2d(3, 64, 3, 1, 1)
        self.conv2 = th.nn.Conv2d(64, 64, 3, 1, 1)
        self.final = th.nn.Conv2d(64, n_categories, 3, 1, 1)
        self.crf = CRF(n_ref=3, n_out=n_categories)

    def forward(self, x):
        input = x
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = th.softmax(self.final(x), dim=1)
        x = self.crf(x, input)
        return x

img = th.zeros(1, 3, 384, 512, device="cuda:0")
model = MyCNN()
model.to(device="cuda:0")
model(img)
```

## Sub-Functions

The functions used for CRF inference can also be used on their own for things
like bilateral filtering. [bilateral.py](bilateral.py) contains a sample
implementation.

`python bilateral.py input.png output.png 20 0.25`

--------------------------------------------------------------------------------
/bilateral.py:
--------------------------------------------------------------------------------
import sys

import numpy as np
import torch as th
import cv2

from permutohedral.gfilt import gfilt

def gaussian_filter(ref, val, kstd):
    return gfilt(ref / kstd[None, :, None, None], val)

def usage():
    print("Usage: python bilateral.py input output sxy srgb")
    exit(1)

if len(sys.argv) != 5:
    usage()

try:
    sxy = float(sys.argv[3])
    srgb = float(sys.argv[4])
except:
    usage()

img = cv2.imread(sys.argv[1]).astype(np.float32)[..., :3] / 255.
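# cv2 returns HWC (BGR order) in [0, 255]; move channels first to match the
# (C, H, W) layout the filtering code below expects.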
img = img.transpose(2, 0, 1)
yx = np.mgrid[:img.shape[1], :img.shape[2]].astype(np.float32)
stacked = np.vstack([yx, img])

img = th.from_numpy(img).cuda()
stacked = th.from_numpy(stacked).cuda()
kstd = th.FloatTensor([sxy, sxy, srgb, srgb, srgb]).cuda()

filtered = gaussian_filter(stacked[None], img[None], kstd)[0]
filtered = (255 * filtered).permute(1, 2, 0).byte().data.cpu().numpy()
cv2.imwrite(sys.argv[2], filtered)

--------------------------------------------------------------------------------
/crfrnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HapeMask/crfrnn_layer/d4223d80aaa0bd960d5ba8faaa7beb862cc6d255/crfrnn/__init__.py

--------------------------------------------------------------------------------
/crfrnn/crf.py:
--------------------------------------------------------------------------------
import torch as th
import torch.nn as nn
import torch.nn.functional as thf

from permutohedral.gfilt import gfilt


def gaussian_filter(ref, val, kstd):
    return gfilt(ref / kstd[:, :, None, None], val)


def mgrid(h, w, dev):
    # Float grid so it can be concatenated with float reference images.
    y = th.arange(0, h, device=dev, dtype=th.float32).repeat(w, 1).t()
    x = th.arange(0, w, device=dev, dtype=th.float32).repeat(h, 1)
    return th.stack([y, x], 0)


def gkern(std, chans, dev):
    sig_sq = std ** 2
    r = sig_sq if (sig_sq % 2) else sig_sq - 1
    s = 2 * r + 1
    k = th.exp(-((mgrid(s, s, dev) - r) ** 2).sum(0) / (2 * sig_sq))
    W = th.zeros(chans, chans, s, s, device=dev)
    for i in range(chans):
        W[i, i] = k / k.sum()
    return W


class CRF(nn.Module):
    def __init__(
        self,
        n_ref: int,
        n_out: int,
        sxy_bf: float = 70,
        sc_bf: float = 12,
        compat_bf: float = 4,
        sxy_spatial: float = 6,
        compat_spatial: float = 2,
        num_iter: int = 5,
        normalize_final_iter: bool = True,
        trainable_kstd: bool = False,
    ):
        """Implements fast approximate mean-field inference for a
        fully-connected CRF with Gaussian edge potentials within a neural
        network layer using fast bilateral filtering.

        Args:
            n_ref: Number of channels in the reference images.

            n_out: Number of labels.

            sxy_bf: Spatial standard deviation of the bilateral filter.

            sc_bf: Color standard deviation of the bilateral filter.

            compat_bf: Label compatibility weight for the bilateral filter.
                Assumes a Potts model w/one parameter.

            sxy_spatial: Spatial standard deviation of the 2D Gaussian
                convolution kernel.

            compat_spatial: Label compatibility weight of the 2D Gaussian
                convolution kernel.

            num_iter: Number of steps to run in the inference loop.

            normalize_final_iter: If pre-softmax outputs are desired rather
                than label probabilities, set this to False.

            trainable_kstd: Allow the parameters of the bilateral filter to be
                learned as well. This option may make training less stable.
        """
        assert n_ref in {1, 3}, "Reference image must be either RGB or greyscale (3 or 1 channels)."

        super().__init__()

        self.n_ref = n_ref
        self.n_out = n_out
        self.sxy_bf = sxy_bf
        self.sc_bf = sc_bf
        self.compat_bf = compat_bf
        self.sxy_spatial = sxy_spatial
        self.compat_spatial = compat_spatial
        self.num_iter = num_iter
        self.normalize_final_iter = normalize_final_iter
        self.trainable_kstd = trainable_kstd

        kstd = th.FloatTensor([sxy_bf, sxy_bf, sc_bf, sc_bf, sc_bf])
        if n_ref == 1:
            kstd = kstd[:3]

        if trainable_kstd:
            self.kstd = nn.Parameter(kstd)
        else:
            self.register_buffer("kstd", kstd)

        # Built on the CPU; the buffer moves with the module via .to()/.cuda().
        self.register_buffer("gk", gkern(sxy_spatial, n_out, kstd.device))

    def forward(self, unary, ref):
        def _bilateral(V, R):
            return gaussian_filter(R, V, self.kstd[None])

        def _step(prev_q, U, ref, normalize=True):
            qbf = _bilateral(prev_q, ref)
            qsf = thf.conv2d(prev_q, self.gk, padding=self.gk.shape[-1] // 2)
            q_hat = -self.compat_bf * qbf - self.compat_spatial * qsf
            q_hat = U - q_hat
            return th.softmax(q_hat, dim=1) if normalize else q_hat

        def _inference(unary, ref):
            U = th.log(th.clamp(unary, 1e-5, 1))
            prev_q = th.softmax(U, dim=1)

            for i in range(self.num_iter):
                normalize = self.normalize_final_iter or i < self.num_iter - 1
                prev_q = _step(prev_q, U, ref, normalize=normalize)
            return prev_q

        N, _, H, W = unary.shape
        yx = mgrid(H, W, unary.device)
        grid = yx[None].repeat(N, 1, 1, 1)
        stacked = th.cat([grid, ref], dim=1)

        return _inference(unary, stacked)

--------------------------------------------------------------------------------
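For reference, a quick standalone sketch of driving the module above directly
(the shapes and label count are illustrative):

```python
import torch as th

from crfrnn.crf import CRF

crf = CRF(n_ref=3, n_out=21).cuda()
unary = th.softmax(th.randn(1, 21, 96, 128, device="cuda:0"), dim=1)
ref = th.rand(1, 3, 96, 128, device="cuda:0") * 255  # default sc_bf assumes [0, 255]
probs = crf(unary, ref)  # (1, 21, 96, 128) refined label probabilities
```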
/images/crf_layer_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HapeMask/crfrnn_layer/d4223d80aaa0bd960d5ba8faaa7beb862cc6d255/images/crf_layer_example.png

--------------------------------------------------------------------------------
/images/filtered.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HapeMask/crfrnn_layer/d4223d80aaa0bd960d5ba8faaa7beb862cc6d255/images/filtered.png

--------------------------------------------------------------------------------
/images/wimr_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HapeMask/crfrnn_layer/d4223d80aaa0bd960d5ba8faaa7beb862cc6d255/images/wimr_small.png

--------------------------------------------------------------------------------
/make_gfilt_dispatch.py:
--------------------------------------------------------------------------------
import sys

base_call = "_call_gfilt_kernels<%d, %d>(values, output, tmp_vals_1, tmp_vals_2, hash_entries, hash_keys, neib_ents, barycentric, valid_entries, n_valid, hash_cap, N, reverse, stream);"

if __name__ == "__main__":
    args = sys.argv[1:]
    if len(args) == 2:
        max_ref_dim, max_val_dim = map(int, args)
        ref_dims = range(1, max_ref_dim + 1)
        val_dims = range(1, max_val_dim + 1)
    else:
        assert(len(args) == 0)
        ref_dims = range(2, 6)
        val_dims = range(1, 16)

    print("switch(1000 * ref_dim + val_dim) {")
    for rdim in ref_dims:
        for vdim in val_dims:
            print("\tcase %d:" % (1000 * rdim + vdim))
            print("\t\t" + (base_call % (rdim, vdim)))
            print("\t\tbreak;")
    print("\tdefault:")
    print("\t\tprintf(\"Unsupported ref_dim/val_dim combination (%zd, %zd), generate a new dispatch table using 'make_gfilt_dispatch.py'.\\n\", ref_dim, val_dim);")
    print("\t\texit(-1);")
    print("}")

--------------------------------------------------------------------------------
/permutohedral/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HapeMask/crfrnn_layer/d4223d80aaa0bd960d5ba8faaa7beb862cc6d255/permutohedral/__init__.py

--------------------------------------------------------------------------------
/permutohedral/gfilt.py:
--------------------------------------------------------------------------------
import torch as th

from .hash import make_hashtable
import permutohedral_ext

th.ops.load_library(permutohedral_ext.__file__)
gfilt_cuda = th.ops.permutohedral_ext.gfilt_cuda


def make_gfilt_buffers(b, val_dim, h, w, cap, dev):
    return [
        th.zeros(b, val_dim, h, w, device=dev),  # output
        th.empty(cap, val_dim + 1, device=dev),  # tmp_vals_1
        th.empty(cap, val_dim + 1, device=dev),  # tmp_vals_2
    ]


class GaussianFilter(th.autograd.Function):
    @staticmethod
    def forward(ctx, ref, val, _hash_buffers=None, _gfilt_buffers=None):
        val = val.contiguous()
        b, ref_dim, h, w = ref.shape
        vb, val_dim, vh, vw = val.shape
        assert vb == b and vh == h and vw == w

        if _hash_buffers is None:
            hash_buffers = make_hashtable(ref)
        else:
            hash_buffers = list(_hash_buffers)
        # The CUDA wrapper expects the valid-entry count on the CPU in both cases.
        hash_buffers[-1] = hash_buffers[-1].cpu()

        assert hash_buffers[0].shape[0] == b
        cap = hash_buffers[0].shape[1]

        if _gfilt_buffers is None:
            gfilt_buffers = make_gfilt_buffers(b, val_dim, h, w, cap, val.device)
        else:
            gfilt_buffers = list(_gfilt_buffers)

        if val.is_cuda:
            gfilt_cuda(val, *gfilt_buffers, *hash_buffers, ref_dim, False)
        else:
            raise NotImplementedError("Gfilt currently requires CUDA support.")
            # gfilt_cpu(val, *gfilt_buffers, *hash_buffers, cap, ref_dim, False)

        out = gfilt_buffers[0]

        if ref.requires_grad:
            ctx.save_for_backward(ref, val, out, *hash_buffers)
        elif val.requires_grad:
            ctx.save_for_backward(ref, val, *hash_buffers)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        grads = [None, None, None, None]

        ref = ctx.saved_tensors[0]
        val = ctx.saved_tensors[1]
        hash_buffers = list(ctx.saved_tensors[-6:])

        b, ref_dim = ref.shape[:2]
        val_dim, h, w = val.shape[-3:]
        assert hash_buffers[0].shape[0] == b
        cap = hash_buffers[0].shape[1]

        def filt(v):
            if not v.is_contiguous():
                v = v.contiguous()
            gfilt_buffers = make_gfilt_buffers(b, val_dim, h, w, cap, v.device)
            gfilt_cuda(v, *gfilt_buffers, *hash_buffers, ref_dim, True)
            return gfilt_buffers[0]

        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
            filt_og = filt(grad_output)

        if ctx.needs_input_grad[0]:
            out = ctx.saved_tensors[2]

            grads[0] = th.stack(
                [
                    (grad_output * (filt(val * r_i) - r_i * out))
                    + (val * (filt(grad_output * r_i) - r_i * filt_og))
                    for r_i in ref.split(1, dim=1)
                ],
                dim=1,
            ).sum(dim=2)

        if ctx.needs_input_grad[1]:
            grads[1] = filt_og

        return grads[0], grads[1], grads[2], grads[3]


gfilt = GaussianFilter.apply

--------------------------------------------------------------------------------
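A note on the optional arguments above: `_hash_buffers` / `_gfilt_buffers` let
you build the lattice for a reference image once and filter several value
tensors against it. A minimal sketch (assuming a CUDA device; `out` aliases the
first gfilt buffer, so clone it if you need to keep each result):

```python
import torch as th

from permutohedral.gfilt import gfilt, make_gfilt_buffers
from permutohedral.hash import make_hashtable

ref = th.randn(1, 5, 64, 64, device="cuda:0")
hash_buffers = make_hashtable(ref)          # built once for this reference
cap = hash_buffers[0].shape[1]
gfilt_buffers = make_gfilt_buffers(1, 3, 64, 64, cap, ref.device)

results = []
for _ in range(3):
    val = th.rand(1, 3, 64, 64, device="cuda:0")
    out = gfilt(ref, val, hash_buffers, gfilt_buffers)
    results.append(out.clone())
```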
/permutohedral/hash.py:
--------------------------------------------------------------------------------
import torch as th

import permutohedral_ext

th.ops.load_library(permutohedral_ext.__file__)
build_hash_cuda = th.ops.permutohedral_ext.build_hash_cuda


def make_hash_buffers(b, dim, h, w, cap, dev):
    return [
        -th.ones(b, cap, dtype=th.int32, device=dev),        # hash_entries
        th.zeros(b, cap, dim, dtype=th.int16, device=dev),   # hash_keys
        th.zeros(b, dim + 1, h, w, dtype=th.int32, device=dev),  # neib_ents
        th.zeros(b, dim + 1, h, w, device=dev),              # barycentric
        th.zeros(b, cap, dtype=th.int32, device=dev),        # valid_entries
        th.zeros(b, 1).int().to(device=dev),                 # n_valid_entries
    ]


def get_hash_cap(N, dim):
    return N * (dim + 1)


def make_hashtable(points):
    b, dim, h, w = points.shape
    N = h * w
    cap = get_hash_cap(N, dim)

    buffers = make_hash_buffers(b, dim, h, w, cap, points.device)
    if points.is_cuda:
        build_hash_cuda(points.contiguous(), *buffers)
    else:
        raise NotImplementedError("Hash table currently requires CUDA support.")
        # build_hash_cpu(points.contiguous(), *buffers, cap)

    return buffers

--------------------------------------------------------------------------------
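Why `N * (dim + 1)` for the capacity: each of the N input points is splatted
onto the `dim + 1` vertices of its enclosing lattice simplex, so the table can
never need more than `N * (dim + 1)` distinct entries. For a 512x512 image with
a 5-D bilateral reference, that is 512 * 512 * 6 ≈ 1.6M slots.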
/pyproject.toml:
--------------------------------------------------------------------------------
[tool.black]
line_length = 100
skip-string-normalization = true
target_version = ['py36', 'py37', 'py38']
include = '\.pyi?$'
exclude = '''
/(
    .*\.json
    | \.eggs
    | \.git
    | \.hg
    | \.mypy_cache
    | \.nox
    | \.tox
    | \.venv
    | _build
    | buck-out
    | build
    | dist
)/
'''

--------------------------------------------------------------------------------
/sanity_check.py:
--------------------------------------------------------------------------------
import os

import cv2
import numpy as np
import torch as th
from tqdm import tqdm

from permutohedral.gfilt import gfilt

###############################################################################
# Perform a simple sanity check of the filtering / hashing components. Creates
# a circle filled with a checkerboard pattern, blurs it using the circle's mask
# + XY coords as a reference image, and then optimizes an initial random image
# + initial kernel std. values through the filter to match the target image.
###############################################################################


def gaussian_filter(ref, val, kstd):
    return gfilt(ref / kstd[:, :, None, None], val)


sxy = 3
srgb = 0.25
h, w = 512, 512

if __name__ == "__main__":
    np.random.seed(0)
    th.manual_seed(0)

    yx = np.mgrid[:h, :w].astype(np.float32)
    yx = th.from_numpy(yx).cuda()
    kstd = th.FloatTensor([sxy, sxy, srgb]).cuda()

    tgt = np.zeros((512, 512, 3), np.uint8)
    mask = np.zeros((512, 512), np.uint8)
    cv2.circle(mask, (256, 256), 128, 255, -1)
    mask = mask > 0

    # Make a simple checkerboard texture.
    color_1 = [255, 128, 32]
    color_2 = [128, 255, 32]
    for i in range(8):
        for j in range(8):
            tgt[8 + i :: 16, 8 + j :: 16] = color_1
            tgt[i::16, j::16] = color_1
            tgt[8 + i :: 16, j::16] = color_2
            tgt[i::16, 8 + j :: 16] = color_2
    tgt[~mask] = 0
    tgt = th.from_numpy(tgt.transpose(2, 0, 1).copy()).cuda().float() / 255
    prefilt = tgt.clone()
    prefilt_np = (255 * prefilt).byte().data.cpu().numpy().transpose(1, 2, 0)

    # Create the filtered target image.
    stacked = th.cat([yx, th.from_numpy(mask[None]).cuda().float()], dim=0)
    tgt = gaussian_filter(stacked[None], tgt[None], kstd[None])[0]
    tgt_np = (255 * tgt).byte().data.cpu().numpy().transpose(1, 2, 0)

    # Create a random initial image that will be optimized.
    img_est = (0.5 * th.rand(3, h, w).cuda() + 0.25).requires_grad_(True)
    kstd_orig = kstd.clone()
    kstd[:] = 8
    kstd.requires_grad_(True)
    optim = th.optim.Adam([img_est, kstd], lr=1e-3)

    if not os.path.exists("sanity_imgs"):
        os.mkdir("sanity_imgs")

    for it in tqdm(range(8000)):
        filt = gaussian_filter(stacked[None], img_est[None], kstd[None])[0]
        diff = (filt - tgt) ** 2
        loss = diff.mean()

        optim.zero_grad()
        loss.backward()
        optim.step()

        # If left unconstrained, the image can take on negative values.
        with th.no_grad():
            img_est.clamp_(0, 1)

        if it == 4000:
            for pg in optim.param_groups:
                pg["lr"] /= 2

        if it % 25 == 0:
            with th.no_grad():
                filt = gaussian_filter(stacked[None], img_est[None], kstd[None])[0]
                diff = (filt - tgt) ** 2
                img = th.cat(
                    [
                        img_est.clamp(0, 1),
                        filt.clamp(0, 1),
                        (abs(diff) / diff.max().clamp(min=1e-2)).clamp(0, 1),
                    ],
                    dim=2,
                )
                img = (255 * img).byte().data.cpu().numpy().transpose(1, 2, 0)
                img = np.vstack([img, np.hstack([prefilt_np, tgt_np, np.zeros_like(tgt_np)])]).copy()

                color = (255, 255, 255)
                font = cv2.FONT_HERSHEY_SIMPLEX
                cv2.putText(img, "Input (estimated)", (32, 64), font, 1, color, 1)
                cv2.putText(img, "Filtered (estimated)", (32 + 512, 64), font, 1, color, 1)
                cv2.putText(img, "Diff (filt_est - filt_gt)x100", (32 + 2 * 512, 64), font, 1, color, 1)
                cv2.putText(img, "Input (GT)", (32, 64 + 512), font, 1, color, 1)
                cv2.putText(img, "Filtered (GT)", (32 + 512, 64 + 512), font, 1, color, 1)
                cv2.imwrite(f"sanity_imgs/{it:06d}.png", img)

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import os

from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

GFILT_CALL = "_call_gfilt_kernels<{ref_dim}, {val_dim}>(values, output, tmp_vals_1, tmp_vals_2, hash_entries, hash_keys, neib_ents, barycentric, valid_entries, n_valid, hash_cap, N, reverse, stream);"
def make_gfilt_dispatch_table(fname, ref_dims=range(2, 6), val_dims=range(1, 16)):
    with open(fname, "w") as f:
        f.write("switch(1000 * ref_dim + val_dim) {\n")
        for rdim in ref_dims:
            for vdim in val_dims:
                f.write(f"\tcase {1000 * rdim + vdim}:\n")
                f.write("\t\t" + GFILT_CALL.format(ref_dim=rdim, val_dim=vdim) + "\n")
                f.write("\t\tbreak;\n")
        f.write("\tdefault:\n")
        f.write("\t\tprintf(\"Unsupported ref_dim/val_dim combination (%zd, %zd), generate a new dispatch table using 'make_gfilt_dispatch.py'.\\n\", ref_dim, val_dim);\n")
        f.write("\t\texit(-1);\n")
        f.write("}\n")

if __name__ == "__main__":
    table_fn = "src/gfilt_dispatch_table.h"
    if not os.path.exists(table_fn) or (os.path.getmtime(table_fn) < os.path.getmtime(__file__)):
        make_gfilt_dispatch_table(table_fn)

    cxx_args = ["-O3", "-fopenmp", "-std=c++14"]
    nvcc_args = ["-O3"]
    if "CC" in os.environ:
        nvcc_args.append("-ccbin=" + os.path.dirname(os.environ.get("CC")))

    setup(
        name="permutohedral",
        version="0.4",
        description="",
        url="",
        author="Gabriel Schwartz",
        author_email="gbschwartz@gmail.com",
        ext_modules=[
            CUDAExtension(
                "permutohedral_ext",
                [
                    "src/permutohedral.cpp",
                    "src/build_hash_wrapper.cu",
                    "src/gfilt_wrapper.cu",
                    "src/gfilt_cuda.cu",
                    "src/build_hash.cu"
                ],
                extra_compile_args={"cxx": cxx_args, "nvcc": nvcc_args},
            )
        ],
        cmdclass={"build_ext": BuildExtension},
        packages=["permutohedral"],
    )

    setup(name="crfrnn", packages=["crfrnn"])

--------------------------------------------------------------------------------
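A build tip worth knowing: `torch.utils.cpp_extension` honors the standard
`TORCH_CUDA_ARCH_LIST` environment variable when choosing GPU architectures,
and the `setup.py` above additionally reads `CC` to pick the host compiler for
nvcc. For example (the arch value here is only an illustration; use the one
matching your GPU):

```
CC=gcc-11 TORCH_CUDA_ARCH_LIST="8.6" python setup.py build_ext --inplace
```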
/src/build_hash.cu:
--------------------------------------------------------------------------------
#include <cassert>

#include "hash_fns.cuh"
#include "build_hash_kernel.h"

template<int dim>
__device__ int hash_insert(int* entries, short* keys, size_t capacity, const short* key) {
    unsigned int h = hash<dim>(key) % capacity;
    const unsigned int init_h = h;
    const size_t max_it = min(MIN_QUAD_PROBES, capacity);
    int* entry = &entries[h];

    // First try quadratic probing for a fixed number of iterations, then fall
    // back to linear probing.
    for(int iter=0; iter<max_it; ++iter) {
        entry = &entries[h];
        const int ret = atomicCAS(entry, -1, (int)h);
        if(ret == -1) {
            // We claimed an empty slot, write the key into it.
            for(int i=0; i<dim; ++i) { keys[h * dim + i] = key[i]; }
            return h;
        } else if(ret >= 0 && key_cmp<dim>(&keys[ret * dim], key)){
            // If another thread already inserted the same key, return the
            // entry for that key.
            return ret;
        }

        h = (init_h + iter * iter) % capacity;
    }

    for(int iter=0; iter<capacity; ++iter) {
        entry = &entries[h];
        const int ret = atomicCAS(entry, -1, (int)h);
        if(ret == -1) {
            // We claimed an empty slot, write the key into it.
            for(int i=0; i<dim; ++i) { keys[h * dim + i] = key[i]; }
            return h;
        } else if(ret >= 0 && key_cmp<dim>(&keys[ret * dim], key)){
            // If another thread already inserted the same key, return the
            // entry for that key.
            return ret;
        }

        h = (h + 1) % capacity;
    }

    // We wrapped around without finding a free slot, the table is full.
    return -1;
}

constexpr float root_two_thirds = 0.81649658092f;

template<int dim>
__device__ __inline__ void embed(const float* f, float* e, size_t N) {
    constexpr float sf = root_two_thirds * ((float)dim + 1.f);

    e[dim] = -sqrt((float)dim / ((float)dim + 1.f)) * f[N * (dim - 1)] * sf;
    for(int i=dim - 1; i > 0; --i) {
        e[i] = sf * f[N * i] / sqrt((i + 1.f) / (i + 2.f)) + e[i + 1] - sqrt(i / (i + 1.f)) * sf * f[N * (i - 1)];
    }
    e[0] = sf * f[0] / sqrt(0.5f) + e[1];
}

template<int dim>
__device__ __inline__ short round2mult(float f) {
    const float s = f / (dim + 1.f);
    const float lo = floor(s) * ((float)dim + 1.f);
    const float hi = ceil(s) * ((float)dim + 1.f);
    return ((hi - f) > (f - lo)) ? lo : hi;
}

template<int dim>
__device__ __inline__ void ranksort_diff(const float* v1, const short* v2, short* rank) {
    for(int i=0; i<=dim; ++i) {
        rank[i] = 0;
        const float di = v1[i] - v2[i];
        for(int j=0; j<=dim; ++j) {
            const float dj = v1[j] - v2[j];
            if (di < dj || (di==dj && i>j)) { ++rank[i]; }
        }
    }
}

template<int dim>
__device__ __inline__ short canonical_coord(int k, int i) {
    return (i < (dim + 1 - k)) ? k : (k - (dim + 1));
}

template<int dim>
__global__ void build_hash(const float* points,
                           int* hash_entries,
                           short* hash_keys,
                           int* neib_ents,
                           float* barycentric,
                           size_t hash_cap,
                           size_t N) {

    const size_t idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
    if (idx < N) {
        short rounded[dim + 1];
        short rank[dim + 1];
        float embedded[dim + 1];

        const float* p = &points[idx];
        float* b = &barycentric[idx];

        embed<dim>(p, embedded, N);

        short sum = 0;
        for(int i=0; i<=dim; ++i) {
            const short r = round2mult<dim>(embedded[i]);
            rounded[i] = r;
            sum += r;
        }
        sum /= (short)dim + 1;

        // Compute rank(embedded - rounded), decreasing order
        ranksort_diff<dim>(embedded, rounded, rank);

        // Walk the point back onto H_d (Lemma 2.9 in permutohedral_techreport.pdf)
        for (int i = 0; i <= dim; i++) {
            rank[i] += sum;
            if (rank[i] < 0) {
                rank[i] += (short)dim + 1;
                rounded[i] += (short)dim + 1;
            } else if (rank[i] > dim) {
                rank[i] -= (short)dim + 1;
                rounded[i] -= (short)dim + 1;
            }
        }

        // The temporary key has 1 extra dimension. Normally we ignore the last
        // entry in the key because they sum to 0, but we can use the key as swap
        // space to invert the sorting permutation.
        short key[dim];
        for(int k=0; k<=dim; ++k) {
            for(int i=0; i<dim; ++i) {
                key[i] = canonical_coord<dim>(k, rank[i]) + rounded[i];
            }

            const int ind = hash_insert<dim>(hash_entries, hash_keys, hash_cap, key);
            assert(ind >= 0);
            neib_ents[idx + N * k] = ind;
        }

        float bar_tmp[dim + 2]{0};
        // Compute the barycentric coordinates (p.10 in [Adams etal 2010])
        for (int i = 0; i <= dim; ++i) {
            const float delta = (embedded[i] - rounded[i]) * (1.f / ((float)dim + 1));
            bar_tmp[dim - rank[i]] += delta;
            bar_tmp[dim + 1 - rank[i]] -= delta;
        }
        // Wrap around
        bar_tmp[0] += 1.0 + bar_tmp[dim + 1];

        for (int i = 0; i <= dim; ++i) {
            b[N * i] = bar_tmp[i];
        }
    }
}

template<int dim>
__global__ void dedup(int* hash_entries, short* hash_keys, size_t hash_cap) {
    const size_t idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
    if (idx < hash_cap) {
        int& e = hash_entries[idx];
        if(e >= 0) {
            const short* key = &hash_keys[idx * dim];
            e = hash_lookup<dim>(hash_entries, hash_keys, hash_cap, key);
        }
    }
}

__global__ void find_valid(const int* hash_entries,
                           int* valid_entries,
                           int* n_valid,
                           size_t hash_cap) {

    const size_t idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
    if (idx < hash_cap) {
        const int& e = hash_entries[idx];
        if(e >= 0) {
            const int my_ind = atomicAdd(n_valid, 1);
            valid_entries[my_ind] = e;
        }
    }
}

template<int dim>
void _call_hash_kernels(const float* points,
                        int* hash_entries,
                        short* hash_keys,
                        int* neib_ents,
                        float* barycentric,
                        int* valid_entries,
                        int* n_valid,
                        size_t hash_cap,
                        size_t N,
                        cudaStream_t stream) {

    build_hash<dim><<<cuda_gridsize(N), BLOCK, 0, stream>>>(points,
        hash_entries, hash_keys, neib_ents, barycentric, hash_cap, N);
    dedup<dim><<<cuda_gridsize(hash_cap), BLOCK, 0, stream>>>(hash_entries, hash_keys, hash_cap);
    find_valid<<<cuda_gridsize(hash_cap), BLOCK, 0, stream>>>(hash_entries, valid_entries, n_valid, hash_cap);
}

void call_build_hash_kernels(const float* points,
                             int* hash_entries,
                             short* hash_keys,
                             int* neib_ents,
                             float* barycentric,
                             int* valid_entries,
                             int* n_valid,
                             size_t hash_cap,
                             size_t N, size_t dim,
                             cudaStream_t stream) {

    switch(dim) {
        case 1:
            _call_hash_kernels<1>(points, hash_entries, hash_keys, neib_ents,
                barycentric, valid_entries, n_valid, hash_cap, N, stream);
            break;
        case 2:
            _call_hash_kernels<2>(points, hash_entries, hash_keys, neib_ents,
                barycentric, valid_entries, n_valid, hash_cap, N, stream);
            break;
        case 3:
            _call_hash_kernels<3>(points, hash_entries, hash_keys, neib_ents,
                barycentric, valid_entries, n_valid, hash_cap, N, stream);
            break;
        case 4:
            _call_hash_kernels<4>(points, hash_entries, hash_keys, neib_ents,
                barycentric, valid_entries, n_valid, hash_cap, N, stream);
            break;
        case 5:
            _call_hash_kernels<5>(points, hash_entries, hash_keys, neib_ents,
                barycentric, valid_entries, n_valid, hash_cap, N, stream);
            break;
        default:
            fprintf(stderr,
                    "Can't build hash tables for more than 5 dimensional points (but you can fix this by copy/pasting the above lines a few more times if you need to).\n");
            exit(-1);
    }
}

--------------------------------------------------------------------------------
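A worked example of `round2mult` above, which rounds an embedded coordinate to
the nearest multiple of `dim + 1`: with `dim = 5` and `f = 8.2`, the candidate
multiples of 6 are `lo = 6` and `hi = 12`; since `hi - f = 3.8` exceeds
`f - lo = 2.2`, the function returns 6. Also note that the quadratic-then-linear
probing in `hash_insert` deliberately mirrors the sequence in `hash_lookup`
(from `hash_fns.cuh`): a lookup can only find a key by retracing the exact
slots an insertion would have visited.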
/src/build_hash_cuda.h:
--------------------------------------------------------------------------------
#pragma once

void build_hash_cuda(const torch::Tensor& th_points,
                     torch::Tensor th_hash_entries,
                     torch::Tensor th_hash_keys,
                     torch::Tensor th_neib_ents,
                     torch::Tensor th_barycentric,
                     torch::Tensor th_valid_entries,
                     torch::Tensor th_n_valid);

--------------------------------------------------------------------------------
/src/build_hash_kernel.h:
--------------------------------------------------------------------------------
#ifndef _BHK_H
#define _BHK_H

#define BLOCK 256

#ifdef __cplusplus
extern "C" {
#endif

void call_build_hash_kernels(const float* points,
                             int* hash_entries,
                             short* hash_keys,
                             int* neib_ents,
                             float* barycentric,
                             int* valid_entries,
                             int* n_valid,
                             size_t hash_cap,
                             size_t N, size_t dim,
                             cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

--------------------------------------------------------------------------------
/src/build_hash_wrapper.cu:
--------------------------------------------------------------------------------
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>

#include "build_hash_kernel.h"
#include "common.h"

void build_hash_cuda(const torch::Tensor& th_points,
                     torch::Tensor th_hash_entries,
                     torch::Tensor th_hash_keys,
                     torch::Tensor th_neib_ents,
                     torch::Tensor th_barycentric,
                     torch::Tensor th_valid_entries,
                     torch::Tensor th_n_valid) {

    CHECK_INPUT(th_points)
    CHECK_INPUT(th_hash_entries)
    CHECK_INPUT(th_hash_keys)
    CHECK_INPUT(th_neib_ents)
    CHECK_INPUT(th_barycentric)
    CHECK_INPUT(th_valid_entries)
    CHECK_CONTIGUOUS(th_n_valid)

    CHECK_4DIMS(th_points)
    CHECK_2DIMS(th_hash_entries)
    CHECK_3DIMS(th_hash_keys)
    CHECK_4DIMS(th_neib_ents)
    CHECK_4DIMS(th_barycentric)
    CHECK_2DIMS(th_valid_entries)
    CHECK_2DIMS(th_n_valid)

    const float* points = DATA_PTR(th_points, float);
    int* hash_entries = DATA_PTR(th_hash_entries, int);
    short* hash_keys = DATA_PTR(th_hash_keys, short);
    int* neib_ents = DATA_PTR(th_neib_ents, int);
    float* barycentric = DATA_PTR(th_barycentric, float);
    int* valid_entries = DATA_PTR(th_valid_entries, int);
    int* n_valid = DATA_PTR(th_n_valid, int);

    const int B = th_points.size(0);
    const int dim = th_points.size(1);
    const int H = th_points.size(2);
    const int W = th_points.size(3);
    const int hash_cap = th_hash_entries.size(1);

    cudaError_t err;
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    for (int b=0; b < B; ++b) {
        call_build_hash_kernels(
            points + (b * dim * H * W),
            hash_entries + (b * hash_cap),
            hash_keys + (b * hash_cap * dim),
            neib_ents + (b * (dim + 1) * H * W),
            barycentric + (b * (dim + 1) * H * W),
            valid_entries + (b * hash_cap),
            n_valid + b,
            hash_cap, H * W, dim, stream
        );

        err = cudaGetLastError();
        if (err != cudaSuccess) {
            fprintf(stderr, "build_hash CUDA kernel failure: %s\n", cudaGetErrorString(err));
            exit(-1);
        }
    }
}

--------------------------------------------------------------------------------
/src/common.h:
--------------------------------------------------------------------------------
#pragma once

#include <torch/extension.h>
#include <cuda_runtime.h>

#ifdef TORCH_CHECK
#define MYCHECK TORCH_CHECK
#else
#define MYCHECK AT_ASSERTM
#endif

#define CHECK_CUDA(x) MYCHECK(x.device().is_cuda(), #x " is not a CUDA tensor.")
#define CHECK_CONTIGUOUS(x) MYCHECK(x.is_contiguous(), #x " is not contiguous.")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)
#define CHECK_2DIMS(x) MYCHECK((x.sizes().size() == 2), #x " is not 2 dimensional.")
#define CHECK_3DIMS(x) MYCHECK((x.sizes().size() == 3), #x " is not 3 dimensional.")
#define CHECK_4DIMS(x) MYCHECK((x.sizes().size() == 4), #x " is not 4 dimensional.")

#ifdef USE_OLD_PYTORCH_DATA_MEMBER
#define DATA_PTR(x,t) x.data<t>()
#else
#define DATA_PTR(x,t) x.data_ptr<t>()
#endif

--------------------------------------------------------------------------------
/src/gfilt_cuda.cu:
--------------------------------------------------------------------------------
#include <cstdio>

#include "hash_fns.cuh"
#include "gfilt_kernel.h"

template<int ref_dim, int val_dim>
__global__ void splat(const float* values,
                      const float* barycentric,
                      const int* hash_entries,
                      const int* neib_ents,
                      float* hash_values, size_t N) {

    const size_t idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
    if (idx >= N) { return; }

    // We work with val_dim + 1 value channels here (one extra) because the
    // last channel keeps track of the normalization factor to use at the end.
    // During splatting, we splat 1 * bary into the last channel.
    float local_v[val_dim + 1];
    float local_b[ref_dim+1];
    local_v[val_dim] = 1.f;

    for(int i=0; i < val_dim; ++i) { local_v[i] = values[idx + N*i]; }
    for(int i=0; i < (ref_dim+1); ++i) { local_b[i] = barycentric[idx + N*i]; }

    // Splat this point onto each vertex of its surrounding simplex.
    for(int k=0; k<ref_dim+1; ++k) {
        const int& ind = neib_ents[idx + N*k];
        for(int i=0; i<val_dim+1; ++i) {
            atomicAdd(&hash_values[ind * (val_dim + 1) + i], local_b[k] * local_v[i]);
        }
    }
}

template<int ref_dim, int val_dim>
__global__ void blur(float* out,
                     const int* hash_entries,
                     const int* valid_entries,
                     const short* hash_keys,
                     const float* hash_values,
                     size_t hash_cap, int n_valid, size_t axis) {

    const size_t idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
    if (idx >= n_valid) { return; }

    // We work with val_dim + 1 value channels here (one extra) because the
    // last channel keeps track of the normalization factor to use at the end.
    float local_out[val_dim + 1];

    // The local key storage needs the normally-ignored value at the end so
    // that key[axis] is always a valid memory access.
    //short key[ref_dim];
    short key[ref_dim+1];

    const int& ind_c = valid_entries[idx];
    for(int i=0; i < val_dim+1; ++i) { local_out[i] = 0; }

    const short* key_c = &hash_keys[ind_c*ref_dim];

    // The two neighbors along `axis` differ by -1/+1 in every coordinate and
    // by +/- ref_dim in the `axis` coordinate.
    for(int i=0; i<ref_dim; ++i) { key[i] = key_c[i] - 1; }
    key[axis] = key_c[axis] + ref_dim;
    const int ind_l = hash_lookup<ref_dim>(hash_entries, hash_keys, hash_cap, key);

    for(int i=0; i<ref_dim; ++i) { key[i] = key_c[i] + 1; }
    key[axis] = key_c[axis] - ref_dim;
    const int ind_r = hash_lookup<ref_dim>(hash_entries, hash_keys, hash_cap, key);

    if(ind_l >= 0 && ind_r >= 0) {
        for(int i=0; i < val_dim + 1; ++i) {
            local_out[i] = (hash_values[ind_l * (val_dim + 1) + i] +
                            2*hash_values[ind_c * (val_dim + 1) + i] +
                            hash_values[ind_r * (val_dim + 1) + i]) * 0.25f;
        }
    } else if(ind_l >= 0) {
        for(int i=0; i < val_dim + 1; ++i) {
            local_out[i] = (hash_values[ind_l * (val_dim + 1) + i] +
                            2*hash_values[ind_c * (val_dim + 1) + i]) * 0.25f;
        }
    } else if(ind_r >= 0) {
        for(int i=0; i < val_dim + 1; ++i) {
            local_out[i] = (hash_values[ind_r * (val_dim + 1) + i] +
                            2*hash_values[ind_c * (val_dim + 1) + i]) * 0.25f;
        }
    } else {
        for(int i=0; i < val_dim + 1; ++i) {
            local_out[i] = hash_values[ind_c * (val_dim + 1) + i] * 0.5f;
        }
    }

    for(int i=0; i < val_dim + 1; ++i) { out[ind_c * (val_dim + 1) + i] = local_out[i]; }
}

template<int ref_dim, int val_dim>
__global__ void slice(float* out,
                      const float* barycentric,
                      const int* hash_entries,
                      const int* neib_ents,
                      const float* hash_values, int N) {

    const size_t idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
    if (idx >= N) { return; }

    // We work with val_dim + 1 value channels here (one extra) because the
    // last channel keeps track of the normalization factor by which we will
    // scale the output.
    float local_out[val_dim + 1];
    for(int i=0; i < val_dim + 1; ++i) { local_out[i] = 0; }

    // Gather values from each of the surrounding simplex vertices.
    for(int k=0; k<ref_dim+1; ++k) {
        const float b = barycentric[idx + N*k];
        const int& ind = neib_ents[idx + N*k];
        for(int i=0; i<val_dim+1; ++i) {
            local_out[i] += b * hash_values[ind * (val_dim + 1) + i];
        }
    }

    // Scale by the normalization factor accumulated in the extra channel.
    const float norm = 1.f / local_out[val_dim];
    for(int i=0; i<val_dim; ++i) { out[idx + N*i] = local_out[i] * norm; }
}

template<int ref_dim, int val_dim>
void _call_gfilt_kernels(const float* values, float* output,
                         float* tmp_vals_1, float* tmp_vals_2,
                         const int* hash_entries, const short* hash_keys,
                         const int* neib_ents, const float* barycentric,
                         const int* valid_entries, int n_valid,
                         size_t hash_cap, size_t N, bool reverse, cudaStream_t stream) {

    splat<ref_dim, val_dim><<<cuda_gridsize(N), BLOCK, 0, stream>>>(values,
        barycentric, hash_entries, neib_ents, tmp_vals_1, N);

    float* tmp_swap;

    for(int ax=reverse ? ref_dim : 0;
        ax <= ref_dim && ax >= 0;
        reverse ? --ax : ++ax
    ) {
        blur<ref_dim, val_dim><<<cuda_gridsize(n_valid), BLOCK, 0, stream>>>(
            tmp_vals_2, hash_entries, valid_entries, hash_keys,
            tmp_vals_1, hash_cap, n_valid, ax);

        tmp_swap = tmp_vals_1;
        tmp_vals_1 = tmp_vals_2;
        tmp_vals_2 = tmp_swap;
    }

    slice<ref_dim, val_dim><<<cuda_gridsize(N), BLOCK, 0, stream>>>(output, barycentric, hash_entries, neib_ents, tmp_vals_2, N);
}

void call_gfilt_kernels(const float* values, float* output,
                        float* tmp_vals_1, float* tmp_vals_2,
                        const int* hash_entries, const short* hash_keys,
                        const int* neib_ents, const float* barycentric,
                        const int* valid_entries, int n_valid,
                        size_t hash_cap, size_t N, size_t ref_dim,
                        size_t val_dim, bool reverse, cudaStream_t stream) {

#include "gfilt_dispatch_table.h"
}

--------------------------------------------------------------------------------
/src/gfilt_cuda.h:
--------------------------------------------------------------------------------
#pragma once

void gfilt_cuda(const torch::Tensor& th_values, torch::Tensor th_output,
                torch::Tensor th_tmp_vals_1, torch::Tensor th_tmp_vals_2,
                const torch::Tensor& th_hash_entries, const torch::Tensor& th_hash_keys,
                const torch::Tensor& th_neib_ents, const torch::Tensor& th_barycentric,
                const torch::Tensor& th_valid_entries, const torch::Tensor& n_valid,
                int64_t ref_dim, bool reverse);

--------------------------------------------------------------------------------
/src/gfilt_kernel.h:
--------------------------------------------------------------------------------
#ifndef _GFK_H
#define _GFK_H

#define BLOCK 256

#ifdef __cplusplus
extern "C" {
#endif

void call_gfilt_kernels(const float* values, float* output,
                        float* tmp_vals_1, float* tmp_vals_2,
                        const int* hash_entries, const short* hash_keys,
                        const int* neib_ents, const float* barycentric,
                        const int* valid_entries, int n_valid,
                        size_t hash_cap, size_t N, size_t ref_dim,
                        size_t val_dim, bool reverse, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

--------------------------------------------------------------------------------
/src/gfilt_wrapper.cu:
--------------------------------------------------------------------------------
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>

#include "gfilt_kernel.h"
#include "common.h"

void gfilt_cuda(const torch::Tensor& th_values, torch::Tensor th_output,
                torch::Tensor th_tmp_vals_1, torch::Tensor th_tmp_vals_2,
                const torch::Tensor& th_hash_entries, const torch::Tensor& th_hash_keys,
                const torch::Tensor& th_neib_ents, const torch::Tensor& th_barycentric,
                const torch::Tensor& th_valid_entries, const torch::Tensor& th_n_valid,
                int64_t ref_dim, bool reverse) {

    CHECK_4DIMS(th_values)
    CHECK_4DIMS(th_output)
    CHECK_2DIMS(th_tmp_vals_1)
    CHECK_2DIMS(th_tmp_vals_2)

    CHECK_2DIMS(th_hash_entries)
    CHECK_3DIMS(th_hash_keys)
    CHECK_4DIMS(th_neib_ents)
    CHECK_4DIMS(th_barycentric)
    CHECK_2DIMS(th_valid_entries)
    CHECK_CONTIGUOUS(th_n_valid)
    MYCHECK(th_n_valid.device().is_cpu(), "n_valid is not a CPU tensor.")

    const float* values = DATA_PTR(th_values, float);
    float* output = DATA_PTR(th_output, float);
    float* tmp_vals_1 = DATA_PTR(th_tmp_vals_1, float);
    float* tmp_vals_2 = DATA_PTR(th_tmp_vals_2, float);
    const int* hash_entries = DATA_PTR(th_hash_entries, int);
    const short* hash_keys = DATA_PTR(th_hash_keys, short);
    const int* neib_ents = DATA_PTR(th_neib_ents, int);
    const float* barycentric = DATA_PTR(th_barycentric, float);
    const int* valid_entries = DATA_PTR(th_valid_entries, int);
    int* n_valid = DATA_PTR(th_n_valid, int);

    const int B = th_values.size(0);
    const int H = th_values.size(2);
    const int W = th_values.size(3);
    const int val_dim = th_values.size(1);
    const int hash_cap = th_hash_entries.size(1);

    cudaError_t err;
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    for (int b=0; b < B; ++b) {
        th_tmp_vals_1.fill_(0.f);
        th_tmp_vals_2.fill_(0.f);

        call_gfilt_kernels(
            values + (b * val_dim * H * W),
            output + (b * val_dim * H * W),
            tmp_vals_1, tmp_vals_2,
            hash_entries + (b * hash_cap),
            hash_keys + (b * hash_cap * ref_dim),
            neib_ents + (b * (ref_dim + 1) * H * W),
            barycentric + (b * (ref_dim + 1) * H * W),
            valid_entries + (b * hash_cap),
            n_valid[b],
            hash_cap, H * W, ref_dim, val_dim,
            reverse, stream
        );

        err = cudaGetLastError();
        if (err != cudaSuccess) {
            fprintf(stderr, "gfilt CUDA kernel failure: %s\n", cudaGetErrorString(err));
            exit(-1);
        }
    }
}

--------------------------------------------------------------------------------
/src/hash_fns.cuh:
--------------------------------------------------------------------------------
#pragma once
#include "build_hash_kernel.h"

const size_t MIN_QUAD_PROBES = 10000;

inline dim3 cuda_gridsize(int n) {
    int k = (n - 1) / BLOCK + 1;
    int x = k;
    int y = 1;

    if(x > 65535) {
        x = ceil(sqrt(k));
        y = (n - 1) / (x * BLOCK) + 1;
    }

    return dim3(x, y, 1);
}

template<int key_dim>
__device__ inline unsigned int hash(const short* key) {
    unsigned int h = 0;
    /*
    for (int i=0; i < key_dim; ++i) {
        h ^= ((unsigned int)key[i]) << ((31/key_dim)*i);
    }
    */
    for (int i=0; i < key_dim; i++) {
        h += key[i];
        h = h * 2531011;
    }
    return h;
}

template<int key_dim>
__device__ inline bool key_cmp(const short* key1, const short* key2) {
    for(int i=0; i<key_dim; ++i) {
        if(key1[i] != key2[i]) { return false; }
    }
    return true;
}

template<int key_dim>
__device__ inline int hash_lookup(const int* entries, const short* keys, size_t capacity, const short* key) {
    unsigned int h = hash<key_dim>(key) % capacity;
    const unsigned int init_h = h;
    const size_t max_it = min((size_t)MIN_QUAD_PROBES, capacity);

    // The probing sequence here needs to match the one used for insertion,
    // otherwise bad things will happen.
    for(int iter=0; iter<max_it; ++iter) {
        const int& entry = entries[h];
        if(entry >= 0 && key_cmp<key_dim>(&keys[entry*key_dim], key)) { return entry; }
        h = (init_h + iter*iter) % capacity;
    }

    for(int iter=0; iter<capacity; ++iter) {
        const int& entry = entries[h];
        if(entry >= 0 && key_cmp<key_dim>(&keys[entry*key_dim], key)) { return entry; }
        h = (h+1) % capacity;
    }

    // We wrapped around without finding a matching key, the key is not in the
    // table.
    return -1;
}

--------------------------------------------------------------------------------
/src/permutohedral.cpp:
--------------------------------------------------------------------------------
#include <torch/extension.h>
#include <torch/script.h>

#include "build_hash_cuda.h"
#include "gfilt_cuda.h"


PYBIND11_MODULE(permutohedral_ext, m) {
    m.def("gfilt_cuda", &gfilt_cuda, "High-dimensional Gaussian filter (CUDA)");
    m.def("build_hash_cuda", &build_hash_cuda, "Build the permutohedral lattice hash table (CUDA)");
}

TORCH_LIBRARY(permutohedral_ext, m) {
    m.def("gfilt_cuda", &gfilt_cuda);
    m.def("build_hash_cuda", &build_hash_cuda);
}

--------------------------------------------------------------------------------
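A final wiring note: the `TORCH_LIBRARY` block is what lets the Python side
load these kernels as torch ops rather than plain pybind functions, which is
exactly how `permutohedral/gfilt.py` and `permutohedral/hash.py` consume them:

```python
import torch as th
import permutohedral_ext

th.ops.load_library(permutohedral_ext.__file__)
gfilt_cuda = th.ops.permutohedral_ext.gfilt_cuda
build_hash_cuda = th.ops.permutohedral_ext.build_hash_cuda
```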