├── .flake8
├── .gitignore
├── LICENSE
├── README.md
├── bilateral.py
├── crfrnn
│   ├── __init__.py
│   └── crf.py
├── images
│   ├── crf_layer_example.png
│   ├── filtered.png
│   └── wimr_small.png
├── make_gfilt_dispatch.py
├── permutohedral
│   ├── __init__.py
│   ├── gfilt.py
│   └── hash.py
├── pyproject.toml
├── sanity_check.py
├── setup.py
└── src
    ├── build_hash.cu
    ├── build_hash_cuda.h
    ├── build_hash_kernel.h
    ├── build_hash_wrapper.cu
    ├── common.h
    ├── gfilt_cuda.cu
    ├── gfilt_cuda.h
    ├── gfilt_kernel.h
    ├── gfilt_wrapper.cu
    ├── hash_fns.cuh
    └── permutohedral.cpp
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | select = B,C,E,F,P,W,B9
3 | max-line-length = 100
4 | ### DEFAULT IGNORES FOR 4-space INDENTED PROJECTS ###
5 | # Main Explanation Docs: https://lintlyci.github.io/Flake8Rules/
6 | #
7 | # E127, E128 are hard to silence in certain nested formatting situations.
8 | # E203 doesn't work for slicing
9 | # E265, E266 talk about comment formatting which is too opinionated.
10 | # E402 warns on imports coming after statements. There are important use cases
11 | # like demandimport (https://fburl.com/demandimport) that require statements
12 | # before imports.
13 | # E501 is not flexible enough, we're using B950 instead.
14 | # E722 is a duplicate of B001.
15 | # F811 looks for duplicate imports + noise for overload typing
16 | # P207 is a duplicate of B003.
17 | # P208 is a duplicate of C403.
18 | # W503 talks about operator formatting which is too opinionated.
19 | ignore = E127, E128, E203, E265, E266, E402, E501, E722, F811, P207, P208, W503
20 | ### DEFAULT IGNORES FOR 2-space INDENTED PROJECTS (uncomment) ###
21 | # ignore = E111, E114, E121, E127, E128, E265, E266, E402, E501, P207, P208, W503
22 | exclude =
23 |     .git,
24 |     .hg,
25 |     __pycache__,
26 |     _bin/*,
27 |     _build/*,
28 |     _ig_fbcode_wheel/*,
29 |     buck-out/*,
30 |     third-party-buck/*,
31 |     third-party2/*
32 |
33 | # Calculate max-complexity by changing the value below to 1, then surveying fbcode
34 | # to see the distribution of complexity:
35 | # find ./[a-z0-9]* -name 'buck-*' -prune -o -name 'third*party*' -prune -o \
36 | # -name '*.py' -print |\
37 | # parallel flake8 --config ./.flake8 |\
38 | # perl -ne 'if (/C901/) { s/.*\((\d+)\)/$1/; print; }' | stats
39 | # NOTE: This will take a while to run (near an hour IME) so you probably want a
40 | # second working dir to run it in.
41 | # Pick a reasonable point from there (e.g. p95 or "95%")
42 | # As of 2016-05-18 the rough distribution is:
43 | #
44 | # count: 134807
45 | # min: 2
46 | # max: 206
47 | # avg: 4.361
48 | # median: 3
49 | # sum: 587882
50 | # stddev: 4.317
51 | # variance: 18.635
52 | #
53 | # percentiles:
54 | # 75%: 5
55 | # 90%: 8
56 | # 95%: 11
57 | # 99%: 20
58 | # 99.9%: 48
59 | # 99.99%: 107
60 | # 99.999%: 160
61 | max-complexity = 12
62 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Generated code
2 | crfrnn/src/gfilt_dispatch_table.h
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 |
8 | # C extensions
9 | *.so
10 |
11 | # editor stuff
12 | .vscode/*
13 | !.vscode/settings.json
14 |
15 | # Misc
16 | build
17 | *.swp
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Gabriel Schwartz
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CRF-as-RNN Layer for [Pytorch](https://github.com/pytorch/pytorch)
2 |
3 | This repository contains an implementation of the CRF-as-RNN method
4 | [described here](http://www.robots.ox.ac.uk/~szheng/CRFasRNN.html). Please cite
5 | their work if you use this in your own code. I am not affiliated with their
6 | group; this is just a side-project.
7 |
8 | The pytorch module relies on two Functions: one to build the hashtable
9 | representing a [permutohedral
10 | lattice](http://graphics.stanford.edu/papers/permutohedral/permutohedral.pdf)
11 | and another to perform the high-dimensional Gaussian filtering required by
12 | approximate CRF inference.
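
Both functions can also be driven directly from Python once the extension is
built. A minimal sketch (the tensor shapes follow `permutohedral/gfilt.py` and
`permutohedral/hash.py`; the random inputs here are purely illustrative):

```python
import torch as th

from permutohedral.gfilt import gfilt
from permutohedral.hash import make_hashtable

ref = th.rand(1, 5, 64, 64, device="cuda:0")  # (batch, ref_dim, H, W) lattice positions
val = th.rand(1, 3, 64, 64, device="cuda:0")  # (batch, val_dim, H, W) values to filter

out = gfilt(ref, val)  # builds the lattice hashtable internally

buffers = make_hashtable(ref)  # or build the table once and reuse it
out2 = gfilt(ref, val, buffers)
```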
13 |
14 | ## Setup
15 |
16 | For in-place use / testing:
17 |
18 | ```
19 | python setup.py build_ext --inplace
20 | ```
21 |
22 | Or, to install the packages (permutohedral, crfrnn):
23 |
24 | ```
25 | python setup.py install
26 | ```
27 |
28 | ## Pytorch Module
29 | ![CRF layer example](images/crf_layer_example.png)
30 |
31 | The [Pytorch](https://github.com/pytorch/pytorch) module takes two inputs for
32 | the forward pass: a probability map (typically the output of a softmax layer),
33 | and a reference image (typically the image being segmented/densely-classified).
34 | The module must be constructed with `n_ref` (the number of channels in the
35 | reference image) and `n_out` (the number of labels). Additional optional
36 | parameters may be provided on construction:
35 |
36 | * `sxy_bf`: spatial standard deviation for the bilateral filter.
37 | * `sc_bf`: color standard deviation for the bilateral filter.
38 | * `compat_bf`: label compatibility weight for the bilateral filter.
39 | * `sxy_spatial`: spatial standard deviation for the 2D Gaussian filter.
40 | * `compat_spatial`: label compatibility weight for the 2D Gaussian filter.
41 |
42 | **Note**: the default color standard deviation assumes the input is a color
43 | image in the range [0, 255]. If you use whitened or otherwise-normalized images,
44 | you should change this value.
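
With whitened or otherwise-normalized inputs, a much smaller color standard
deviation is appropriate; the exact value is problem-dependent. For instance
(using the `n_categories` from the example below):

```python
crf = CRF(n_ref=3, n_out=n_categories, sc_bf=0.25)
```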
45 |
46 | Here is a simple example:
47 |
48 | ```python
49 | import torch as th
50 |
51 | from crfrnn import CRF
52 |
53 | n_categories = 32
54 |
55 | class MyCNN(th.nn.Module):
56 |     def __init__(self):
57 |         super(MyCNN, self).__init__()
58 |         self.relu = th.nn.ReLU()
59 |         self.conv1 = th.nn.Conv2d(3, 64, 3, 1, 1)
60 |         self.conv2 = th.nn.Conv2d(64, 64, 3, 1, 1)
61 |         self.final = th.nn.Conv2d(64, n_categories, 3, 1, 1)
62 |         self.crf = CRF(n_ref=3, n_out=n_categories)
63 |
64 |     def forward(self, x):
65 |         input = x
66 |         x = self.relu(self.conv1(x))
67 |         x = self.relu(self.conv2(x))
68 |         x = th.softmax(self.final(x), dim=1)
69 |         x = self.crf(x, input)
70 |         return x
71 |
72 | img = th.zeros(1, 3, 384, 512, device="cuda:0")
73 | model = MyCNN()
74 | model.to(device="cuda:0")
75 | model(img)
76 | ```
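
Since both underlying Functions implement a backward pass, the CRF layer is
differentiable and the whole network can be trained end-to-end. A hypothetical
training step (the labels and loss choice here are illustrative only):

```python
import torch.nn.functional as F

labels = th.randint(0, n_categories, (1, 384, 512), device="cuda:0")
probs = model(img)  # the CRF outputs per-pixel label probabilities
loss = F.nll_loss(th.log(probs.clamp(min=1e-5)), labels)
loss.backward()
```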
77 |
78 | ## Sub-Functions
79 |
80 | The functions used for CRF inference can also be used on their own for things
81 | like bilateral filtering. [bilateral.py](bilateral.py) contains a sample
82 | implementation.
83 |
84 | `python bilateral.py input.png output.png 20 0.25`
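
The core of that script is just a reference stack of XY coordinates and colors
fed through `gfilt`. A condensed sketch, assuming `img` is a `float32` CHW
tensor already on the GPU:

```python
import numpy as np
import torch as th

from permutohedral.gfilt import gfilt

sxy, srgb = 20.0, 0.25  # spatial / color standard deviations
yx = np.mgrid[: img.shape[1], : img.shape[2]].astype(np.float32)
stacked = th.cat([th.from_numpy(yx).cuda(), img], dim=0)
kstd = th.tensor([sxy, sxy, srgb, srgb, srgb], device="cuda:0")

# Dividing the reference features by the kernel stds turns gfilt's unit
# Gaussian into a bilateral filter with the requested bandwidths.
filtered = gfilt(stacked[None] / kstd[None, :, None, None], img[None])[0]
```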
85 |
86 | ![input](images/wimr_small.png)
87 | ![filtered](images/filtered.png)
--------------------------------------------------------------------------------
/bilateral.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import numpy as np
4 | import torch as th
5 | import cv2
6 |
7 | from permutohedral.gfilt import gfilt
8 |
9 | def gaussian_filter(ref, val, kstd):
10 |     return gfilt(ref / kstd[None, :, None, None], val)
11 |
12 | def usage():
13 |     print("Usage: python bilateral.py input output sxy srgb")
14 |     exit(1)
15 |
16 | if len(sys.argv) != 5:
17 |     usage()
18 |
19 | try:
20 |     sxy = float(sys.argv[3])
21 |     srgb = float(sys.argv[4])
22 | except ValueError:
23 |     usage()
24 |
25 | img = cv2.imread(sys.argv[1]).astype(np.float32)[..., :3] / 255.
26 | img = img.transpose(2, 0, 1)
27 | yx = np.mgrid[:img.shape[1], :img.shape[2]].astype(np.float32)
28 | stacked = np.vstack([yx, img])
29 |
30 | img = th.from_numpy(img).cuda()
31 | stacked = th.from_numpy(stacked).cuda()
32 | kstd = th.FloatTensor([sxy, sxy, srgb, srgb, srgb]).cuda()
33 |
34 | filtered = gaussian_filter(stacked[None], img[None], kstd)[0]
35 | filtered = (255 * filtered).permute(1, 2, 0).byte().data.cpu().numpy()
36 | cv2.imwrite(sys.argv[2], filtered)
37 |
--------------------------------------------------------------------------------
/crfrnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HapeMask/crfrnn_layer/d4223d80aaa0bd960d5ba8faaa7beb862cc6d255/crfrnn/__init__.py
--------------------------------------------------------------------------------
/crfrnn/crf.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | import torch.nn as nn
3 | import torch.nn.functional as thf
4 |
5 | from permutohedral.gfilt import gfilt
6 |
7 |
8 | def gaussian_filter(ref, val, kstd):
9 |     return gfilt(ref / kstd[:, :, None, None], val)
10 |
11 |
12 | def mgrid(h, w, dev):
13 |     y = th.arange(0, h, device=dev).repeat(w, 1).t()
14 |     x = th.arange(0, w, device=dev).repeat(h, 1)
15 |     return th.stack([y, x], 0)
16 |
17 |
18 | def gkern(std, chans, dev):
19 |     sig_sq = std ** 2
20 |     r = sig_sq if (sig_sq % 2) else sig_sq - 1
21 |     s = 2 * r + 1
22 |     k = th.exp(-((mgrid(s, s, dev) - r) ** 2).sum(0) / (2 * sig_sq))
23 |     W = th.zeros(chans, chans, s, s, device=dev)
24 |     for i in range(chans):
25 |         W[i, i] = k / k.sum()
26 |     return W
27 |
28 |
29 | class CRF(nn.Module):
30 |     def __init__(
31 |         self,
32 |         n_ref: int,
33 |         n_out: int,
34 |         sxy_bf: float = 70,
35 |         sc_bf: float = 12,
36 |         compat_bf: float = 4,
37 |         sxy_spatial: float = 6,
38 |         compat_spatial: float = 2,
39 |         num_iter: int = 5,
40 |         normalize_final_iter: bool = True,
41 |         trainable_kstd: bool = False,
42 |     ):
43 |         """Implements fast approximate mean-field inference for a
44 |         fully-connected CRF with Gaussian edge potentials within a neural
45 |         network layer using fast bilateral filtering.
46 |
47 |         Args:
48 |             n_ref: Number of channels in the reference images.
49 |
50 |             n_out: Number of labels.
51 |
52 |             sxy_bf: Spatial standard deviation of the bilateral filter.
53 |
54 |             sc_bf: Color standard deviation of the bilateral filter.
55 |
56 |             compat_bf: Label compatibility weight for the bilateral filter.
57 |                 Assumes a Potts model w/one parameter.
58 |
59 |             sxy_spatial: Spatial standard deviation of the 2D Gaussian
60 |                 convolution kernel.
61 |
62 |             compat_spatial: Label compatibility weight of the 2D Gaussian
63 |                 convolution kernel.
64 |
65 |             num_iter: Number of steps to run in the inference loop.
66 |
67 |             normalize_final_iter: If pre-softmax outputs are desired rather
68 |                 than label probabilities, set this to False.
69 |
70 |             trainable_kstd: Allow the parameters of the bilateral filter to be
71 |                 learned as well. This option may make training less stable.
72 |         """
73 |         assert n_ref in {1, 3}, "Reference image must be either RGB or greyscale (3 or 1 channels)."
74 |
75 |         super().__init__()
76 |
77 |         self.n_ref = n_ref
78 |         self.n_out = n_out
79 |         self.sxy_bf = sxy_bf
80 |         self.sc_bf = sc_bf
81 |         self.compat_bf = compat_bf
82 |         self.sxy_spatial = sxy_spatial
83 |         self.compat_spatial = compat_spatial
84 |         self.num_iter = num_iter
85 |         self.normalize_final_iter = normalize_final_iter
86 |         self.trainable_kstd = trainable_kstd
87 |
88 |         kstd = th.FloatTensor([sxy_bf, sxy_bf, sc_bf, sc_bf, sc_bf])
89 |         if n_ref == 1:
90 |             kstd = kstd[:3]
91 |
92 |         if trainable_kstd:
93 |             self.kstd = nn.Parameter(kstd)
94 |         else:
95 |             self.register_buffer("kstd", kstd)
96 |
97 |         self.register_buffer("gk", gkern(sxy_spatial, n_out, kstd.device))
98 |
99 |     def forward(self, unary, ref):
100 |         def _bilateral(V, R):
101 |             return gaussian_filter(R, V, self.kstd[None])
102 |
103 |         def _step(prev_q, U, ref, normalize=True):
104 |             qbf = _bilateral(prev_q, ref)
105 |             qsf = thf.conv2d(prev_q, self.gk, padding=self.gk.shape[-1] // 2)
106 |             q_hat = -self.compat_bf * qbf - self.compat_spatial * qsf
107 |             q_hat = U - q_hat
108 |             return th.softmax(q_hat, dim=1) if normalize else q_hat
109 |
110 |         def _inference(unary, ref):
111 |             U = th.log(th.clamp(unary, 1e-5, 1))
112 |             prev_q = th.softmax(U, dim=1)
113 |
114 |             for i in range(self.num_iter):
115 |                 normalize = self.normalize_final_iter or i < self.num_iter - 1
116 |                 prev_q = _step(prev_q, U, ref, normalize=normalize)
117 |             return prev_q
118 |
119 |         N, _, H, W = unary.shape
120 |         yx = mgrid(H, W, unary.device)
121 |         grid = yx[None].repeat(N, 1, 1, 1)
122 |         stacked = th.cat([grid, ref], dim=1)
123 |
124 |         return _inference(unary, stacked)
125 |
--------------------------------------------------------------------------------
/images/crf_layer_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HapeMask/crfrnn_layer/d4223d80aaa0bd960d5ba8faaa7beb862cc6d255/images/crf_layer_example.png
--------------------------------------------------------------------------------
/images/filtered.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HapeMask/crfrnn_layer/d4223d80aaa0bd960d5ba8faaa7beb862cc6d255/images/filtered.png
--------------------------------------------------------------------------------
/images/wimr_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HapeMask/crfrnn_layer/d4223d80aaa0bd960d5ba8faaa7beb862cc6d255/images/wimr_small.png
--------------------------------------------------------------------------------
/make_gfilt_dispatch.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | base_call = "_call_gfilt_kernels<%d, %d>(values, output, tmp_vals_1, tmp_vals_2, hash_entries, hash_keys, neib_ents, barycentric, valid_entries, n_valid, hash_cap, N, reverse, stream);"
4 |
5 | if __name__ == "__main__":
6 |     args = sys.argv[1:]
7 |     if len(args) == 2:
8 |         max_ref_dim, max_val_dim = map(int, args)
9 |         ref_dims = range(1, max_ref_dim + 1)
10 |         val_dims = range(1, max_val_dim + 1)
11 |     else:
12 |         assert len(args) == 0
13 |         ref_dims = range(2, 6)
14 |         val_dims = range(1, 16)
15 |
16 |     print("switch(1000 * ref_dim + val_dim) {")
17 |     for rdim in ref_dims:
18 |         for vdim in val_dims:
19 |             print("\tcase %d:" % (1000 * rdim + vdim))
20 |             print("\t\t" + (base_call % (rdim, vdim)))
21 |             print("\t\tbreak;")
22 |     print("\tdefault:")
23 |     print("\t\tprintf(\"Unsupported ref_dim/val_dim combination (%zd, %zd), generate a new dispatch table using 'make_gfilt_dispatch.py'.\\n\", ref_dim, val_dim);")
24 |     print("\t\texit(-1);")
25 |     print("}")
26 |
--------------------------------------------------------------------------------
/permutohedral/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HapeMask/crfrnn_layer/d4223d80aaa0bd960d5ba8faaa7beb862cc6d255/permutohedral/__init__.py
--------------------------------------------------------------------------------
/permutohedral/gfilt.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 | from .hash import make_hashtable
4 | import permutohedral_ext
5 |
6 | th.ops.load_library(permutohedral_ext.__file__)
7 | gfilt_cuda = th.ops.permutohedral_ext.gfilt_cuda
8 |
9 |
10 | def make_gfilt_buffers(b, val_dim, h, w, cap, dev):
11 |     return [
12 |         th.zeros(b, val_dim, h, w, device=dev),  # output
13 |         th.empty(cap, val_dim + 1, device=dev),  # tmp_vals_1
14 |         th.empty(cap, val_dim + 1, device=dev),  # tmp_vals_2
15 |     ]
16 |
17 |
18 | class GaussianFilter(th.autograd.Function):
19 |     @staticmethod
20 |     def forward(ctx, ref, val, _hash_buffers=None, _gfilt_buffers=None):
21 |         val = val.contiguous()
22 |         b, ref_dim, h, w = ref.shape
23 |         vb, val_dim, vh, vw = val.shape
24 |         assert vb == b and vh == h and vw == w
25 |
26 |         if _hash_buffers is None:
27 |             hash_buffers = make_hashtable(ref)
28 |         else:
29 |             hash_buffers = list(_hash_buffers)
30 |         hash_buffers[-1] = hash_buffers[-1].cpu()
31 |
32 |         assert hash_buffers[0].shape[0] == b
33 |         cap = hash_buffers[0].shape[1]
34 |
35 |         if _gfilt_buffers is None:
36 |             gfilt_buffers = make_gfilt_buffers(b, val_dim, h, w, cap, val.device)
37 |         else:
38 |             gfilt_buffers = list(_gfilt_buffers)
39 |
40 |         if val.is_cuda:
41 |             gfilt_cuda(val, *gfilt_buffers, *hash_buffers, ref_dim, False)
42 |         else:
43 |             raise NotImplementedError("Gfilt currently requires CUDA support.")
44 |             # gfilt_cpu(val, *gfilt_buffers, *hash_buffers, cap, ref_dim, False)
45 |
46 |         out = gfilt_buffers[0]
47 |
48 |         if ref.requires_grad:
49 |             ctx.save_for_backward(ref, val, out, *hash_buffers)
50 |         elif val.requires_grad:
51 |             ctx.save_for_backward(ref, val, *hash_buffers)
52 |
53 |         return out
54 |
55 |     @staticmethod
56 |     def backward(ctx, grad_output):
57 |         grads = [None, None, None, None]
58 |
59 |         ref = ctx.saved_tensors[0]
60 |         val = ctx.saved_tensors[1]
61 |         hash_buffers = list(ctx.saved_tensors[-6:])
62 |
63 |         b, ref_dim = ref.shape[:2]
64 |         val_dim, h, w = val.shape[-3:]
65 |         assert hash_buffers[0].shape[0] == b
66 |         cap = hash_buffers[0].shape[1]
67 |
68 |         def filt(v):
69 |             if not v.is_contiguous():
70 |                 v = v.contiguous()
71 |             gfilt_buffers = make_gfilt_buffers(b, val_dim, h, w, cap, v.device)
72 |             gfilt_cuda(v, *gfilt_buffers, *hash_buffers, ref_dim, True)
73 |             return gfilt_buffers[0]
74 |
75 |         if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
76 |             filt_og = filt(grad_output)
77 |
78 |         if ctx.needs_input_grad[0]:
79 |             out = ctx.saved_tensors[2]
80 |
81 |             grads[0] = th.stack(
82 |                 [
83 |                     (grad_output * (filt(val * r_i) - r_i * out))
84 |                     + (val * (filt(grad_output * r_i) - r_i * filt_og))
85 |                     for r_i in ref.split(1, dim=1)
86 |                 ],
87 |                 dim=1,
88 |             ).sum(dim=2)
89 |
90 |         if ctx.needs_input_grad[1]:
91 |             grads[1] = filt_og
92 |
93 |         return grads[0], grads[1], grads[2], grads[3]
94 |
95 |
96 | gfilt = GaussianFilter.apply
97 |
--------------------------------------------------------------------------------
/permutohedral/hash.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 | import permutohedral_ext
4 |
5 | th.ops.load_library(permutohedral_ext.__file__)
6 | build_hash_cuda = th.ops.permutohedral_ext.build_hash_cuda
7 |
8 |
9 | def make_hash_buffers(b, dim, h, w, cap, dev):
10 |     return [
11 |         -th.ones(b, cap, dtype=th.int32, device=dev),  # hash_entries
12 |         th.zeros(b, cap, dim, dtype=th.int16, device=dev),  # hash_keys
13 |         th.zeros(b, dim + 1, h, w, dtype=th.int32, device=dev),  # neib_ents
14 |         th.zeros(b, dim + 1, h, w, device=dev),  # barycentric
15 |         th.zeros(b, cap, dtype=th.int32, device=dev),  # valid_entries
16 |         th.zeros(b, 1).int().to(device=dev),  # n_valid_entries
17 |     ]
18 |
19 |
20 | def get_hash_cap(N, dim):
21 |     return N * (dim + 1)
22 |
23 |
24 | def make_hashtable(points):
25 |     b, dim, h, w = points.shape
26 |     N = h * w
27 |     cap = get_hash_cap(N, dim)
28 |
29 |     buffers = make_hash_buffers(b, dim, h, w, cap, points.device)
30 |     if points.is_cuda:
31 |         build_hash_cuda(points.contiguous(), *buffers)
32 |     else:
33 |         raise NotImplementedError("Hash table currently requires CUDA support.")
34 |         # build_hash_cpu(points.contiguous(), *buffers, cap)
35 |
36 |     return buffers
37 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line_length = 100
3 | skip-string-normalization = true
4 | target_version = ['py36', 'py37', 'py38']
5 | include = '\.pyi?$'
6 | exclude = '''
7 | /(
8 | .*\.json
9 | | \.eggs
10 | | \.git
11 | | \.hg
12 | | \.mypy_cache
13 | | \.nox
14 | | \.tox
15 | | \.venv
16 | | _build
17 | | buck-out
18 | | build
19 | | dist
20 | )/
21 | '''
22 |
--------------------------------------------------------------------------------
/sanity_check.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | import torch as th
6 | from tqdm import tqdm
7 |
8 | from permutohedral.gfilt import gfilt
9 |
10 | ###############################################################################
11 | # Perform a simple sanity check of the filtering / hashing components. Creates
12 | # a circle filled with a checkerboard pattern, blurs it using the circle's mask
13 | # + XY coords as a reference image, and then optimizes an initial random image
14 | # + initial kernel std. values through the filter to match the target image.
15 | ###############################################################################
16 |
17 |
18 | def gaussian_filter(ref, val, kstd):
19 |     return gfilt(ref / kstd[:, :, None, None], val)
20 |
21 |
22 | sxy = 3
23 | srgb = 0.25
24 | h, w = 512, 512
25 |
26 | if __name__ == "__main__":
27 |     np.random.seed(0)
28 |     th.manual_seed(0)
29 |
30 |     yx = np.mgrid[:h, :w].astype(np.float32)
31 |     yx = th.from_numpy(yx).cuda()
32 |     kstd = th.FloatTensor([sxy, sxy, srgb]).cuda()
33 |
34 |     tgt = np.zeros((512, 512, 3), np.uint8)
35 |     mask = np.zeros((512, 512), np.uint8)
36 |     cv2.circle(mask, (256, 256), 128, 255, -1)
37 |     mask = mask > 0
38 |
39 |     # Make a simple checkerboard texture.
40 |     color_1 = [255, 128, 32]
41 |     color_2 = [128, 255, 32]
42 |     for i in range(8):
43 |         for j in range(8):
44 |             tgt[8 + i :: 16, 8 + j :: 16] = color_1
45 |             tgt[i::16, j::16] = color_1
46 |             tgt[8 + i :: 16, j::16] = color_2
47 |             tgt[i::16, 8 + j :: 16] = color_2
48 |     tgt[~mask] = 0
49 |     tgt = th.from_numpy(tgt.transpose(2, 0, 1).copy()).cuda().float() / 255
50 |     prefilt = tgt.clone()
51 |     prefilt_np = (255 * prefilt).byte().data.cpu().numpy().transpose(1, 2, 0)
52 |
53 |     # Create the filtered target image.
54 |     stacked = th.cat([yx, th.from_numpy(mask[None]).cuda().float()], dim=0)
55 |     tgt = gaussian_filter(stacked[None], tgt[None], kstd[None])[0]
56 |     tgt_np = (255 * tgt).byte().data.cpu().numpy().transpose(1, 2, 0)
57 |
58 |     # Create a random initial image that will be optimized.
59 |     img_est = (0.5 * th.rand(3, h, w).cuda() + 0.25).requires_grad_(True)
60 |     kstd_orig = kstd.clone()
61 |     kstd[:] = 8
62 |     kstd.requires_grad_(True)
63 |     optim = th.optim.Adam([img_est, kstd], lr=1e-3)
64 |
65 |     if not os.path.exists("sanity_imgs"):
66 |         os.mkdir("sanity_imgs")
67 |
68 |     for it in tqdm(range(8000)):
69 |         filt = gaussian_filter(stacked[None], img_est[None], kstd[None])[0]
70 |         diff = (filt - tgt) ** 2
71 |         loss = diff.mean()
72 |
73 |         optim.zero_grad()
74 |         loss.backward()
75 |         optim.step()
76 |
77 |         # If left unconstrained, the image can take on negative values.
78 |         with th.no_grad():
79 |             img_est.clamp_(0, 1)
80 |
81 |         if it == 4000:
82 |             for pg in optim.param_groups:
83 |                 pg["lr"] /= 2
84 |
85 |         if it % 25 == 0:
86 |             with th.no_grad():
87 |                 filt = gaussian_filter(stacked[None], img_est[None], kstd[None])[0]
88 |                 diff = (filt - tgt) ** 2
89 |                 img = th.cat(
90 |                     [
91 |                         img_est.clamp(0, 1),
92 |                         filt.clamp(0, 1),
93 |                         (abs(diff) / diff.max().clamp(min=1e-2)).clamp(0, 1),
94 |                     ],
95 |                     dim=2,
96 |                 )
97 |                 img = (255 * img).byte().data.cpu().numpy().transpose(1, 2, 0)
98 |                 img = np.vstack([img, np.hstack([prefilt_np, tgt_np, np.zeros_like(tgt_np)])]).copy()
99 |
100 |                 color = (255, 255, 255)
101 |                 font = cv2.FONT_HERSHEY_SIMPLEX
102 |                 cv2.putText(img, "Input (estimated)", (32, 64), font, 1, color, 1)
103 |                 cv2.putText(img, "Filtered (estimated)", (32 + 512, 64), font, 1, color, 1)
104 |                 cv2.putText(img, "Diff (filt_est - filt_gt)x100", (32 + 2 * 512, 64), font, 1, color, 1)
105 |                 cv2.putText(img, "Input (GT)", (32, 64 + 512), font, 1, color, 1)
106 |                 cv2.putText(img, "Filtered (GT)", (32 + 512, 64 + 512), font, 1, color, 1)
107 |                 cv2.imwrite(f"sanity_imgs/{it:06d}.png", img)
108 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from setuptools import setup
4 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
5 |
6 | GFILT_CALL = "_call_gfilt_kernels<{ref_dim}, {val_dim}>(values, output, tmp_vals_1, tmp_vals_2, hash_entries, hash_keys, neib_ents, barycentric, valid_entries, n_valid, hash_cap, N, reverse, stream);"
7 | def make_gfilt_dispatch_table(fname, ref_dims=range(2, 6), val_dims=range(1, 16)):
8 |     with open(fname, "w") as f:
9 |         f.write("switch(1000 * ref_dim + val_dim) {\n")
10 |         for rdim in ref_dims:
11 |             for vdim in val_dims:
12 |                 f.write(f"\tcase {1000 * rdim + vdim}:\n")
13 |                 f.write("\t\t" + GFILT_CALL.format(ref_dim=rdim, val_dim=vdim) + "\n")
14 |                 f.write("\t\tbreak;\n")
15 |         f.write("\tdefault:\n")
16 |         f.write("\t\tprintf(\"Unsupported ref_dim/val_dim combination (%zd, %zd), generate a new dispatch table using 'make_gfilt_dispatch.py'.\\n\", ref_dim, val_dim);\n")
17 |         f.write("\t\texit(-1);\n")
18 |         f.write("}\n")
19 |
20 | if __name__ == "__main__":
21 |     table_fn = "src/gfilt_dispatch_table.h"
22 |     if not os.path.exists(table_fn) or (os.path.getmtime(table_fn) < os.path.getmtime(__file__)):
23 |         make_gfilt_dispatch_table(table_fn)
24 |
25 |     cxx_args = ["-O3", "-fopenmp", "-std=c++14"]
26 |     nvcc_args = ["-O3"]
27 |     if "CC" in os.environ:
28 |         nvcc_args.append("-ccbin=" + os.path.dirname(os.environ.get("CC")))
29 |
30 |     setup(
31 |         name="permutohedral",
32 |         version="0.4",
33 |         description="",
34 |         url="",
35 |         author="Gabriel Schwartz",
36 |         author_email="gbschwartz@gmail.com",
37 |         ext_modules=[
38 |             CUDAExtension(
39 |                 "permutohedral_ext",
40 |                 [
41 |                     "src/permutohedral.cpp",
42 |                     "src/build_hash_wrapper.cu",
43 |                     "src/gfilt_wrapper.cu",
44 |                     "src/gfilt_cuda.cu",
45 |                     "src/build_hash.cu"
46 |                 ],
47 |                 extra_compile_args={"cxx": cxx_args, "nvcc": nvcc_args},
48 |             )
49 |         ],
50 |         cmdclass={"build_ext": BuildExtension},
51 |         packages=["permutohedral"],
52 |     )
53 |
54 |     setup(name="crfrnn", packages=["crfrnn"])
55 |
--------------------------------------------------------------------------------
/src/build_hash.cu:
--------------------------------------------------------------------------------
1 | #include <cassert>
2 |
3 | #include "hash_fns.cuh"
4 | #include "build_hash_kernel.h"
5 |
6 | template<int dim>
7 | __device__ int hash_insert(int* entries, short* keys, size_t capacity, const short* key) {
8 |     unsigned int h = hash<dim>(key) % capacity;
9 |     const unsigned int init_h = h;
10 |     const size_t max_it = min(MIN_QUAD_PROBES, capacity);
11 |     int* entry = &entries[h];
12 |
13 |     // First try quadratic probing for a fixed number of iterations, then fall
14 |     // back to linear probing.
15 |     for(int iter=0; iter<max_it; ++iter) {
16 |         entry = &entries[h];
17 |         const int ret = atomicCAS(entry, -1, -2);
18 |
19 |         if(ret == -1) {
20 |             // We claimed an empty slot: store the key, then publish the entry.
21 |             for(int i=0; i<dim; ++i) { keys[h * dim + i] = key[i]; }
22 |             *entry = h;
23 |             return h;
24 |         }
25 |         if(ret >= 0 && key_cmp<dim>(&keys[ret * dim], key)){
26 |             // If another thread already inserted the same key, return the
27 |             // entry for that key.
28 |             return ret;
29 |         }
30 |
31 |         h = (init_h + iter * iter) % capacity;
32 |     }
33 |
34 |     for(int iter=0; iter<capacity; ++iter) {
35 |         entry = &entries[h];
36 |         const int ret = atomicCAS(entry, -1, -2);
37 |
38 |         if(ret == -1) {
39 |             for(int i=0; i<dim; ++i) { keys[h * dim + i] = key[i]; }
40 |             *entry = h;
41 |             return h;
42 |         }
43 |         if(ret >= 0 && key_cmp<dim>(&keys[ret * dim], key)){
44 |             // If another thread already inserted the same key, return the
45 |             // entry for that key.
46 |             return ret;
47 |         }
48 |
49 |         h = (h + 1) % capacity;
50 |     }
51 |
52 |     // We wrapped around without finding a free slot, the table is full.
53 |     return -1;
54 |
55 | constexpr float root_two_thirds = 0.81649658092f;
56 |
57 | template<int dim>
58 | __device__ __inline__ void embed(const float* f, float* e, size_t N) {
59 |     constexpr float sf = root_two_thirds * ((float)dim + 1.f);
60 |
61 |     e[dim] = -sqrt((float)dim / ((float)dim + 1.f)) * f[N * (dim - 1)] * sf;
62 |     for(int i=dim - 1; i > 0; --i) {
63 |         e[i] = sf * f[N * i] / sqrt((i + 1.f) / (i + 2.f)) + e[i + 1] - sqrt(i / (i + 1.f)) * sf * f[N * (i - 1)];
64 |     }
65 |     e[0] = sf * f[0] / sqrt(0.5f) + e[1];
66 | }
67 |
68 | template<int dim>
69 | __device__ __inline__ short round2mult(float f) {
70 |     const float s = f / (dim + 1.f);
71 |     const float lo = floor(s) * ((float)dim + 1.f);
72 |     const float hi = ceil(s) * ((float)dim + 1.f);
73 |     return ((hi - f) > (f - lo)) ? lo : hi;
74 | }
75 |
76 | template<int dim>
77 | __device__ __inline__ void ranksort_diff(const float* v1, const short* v2, short* rank) {
78 |     for(int i=0; i<=dim; ++i) {
79 |         rank[i] = 0;
80 |         const float di = v1[i] - v2[i];
81 |         for(int j=0; j<=dim; ++j) {
82 |             const float dj = v1[j] - v2[j];
83 |             if (di < dj || (di==dj && i>j)) { ++rank[i]; }
84 |         }
85 |     }
86 | }
87 |
88 | template<int dim>
89 | __device__ __inline__ short canonical_coord(int k, int i) {
90 |     return (i < (dim + 1 - k)) ? k : (k - (dim + 1));
91 | }
92 |
93 | template<int dim>
94 | __global__ void build_hash(const float* points,
95 |                            int* hash_entries,
96 |                            short* hash_keys,
97 |                            int* neib_ents,
98 |                            float* barycentric,
99 |                            size_t hash_cap,
100 |                            size_t N) {
101 |
102 |     const size_t idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
103 |     if (idx < N) {
104 |         short rounded[dim + 1];
105 |         short rank[dim + 1];
106 |         float embedded[dim + 1];
107 |
108 |         const float* p = &points[idx];
109 |         float* b = &barycentric[idx];
110 |
111 |         embed<dim>(p, embedded, N);
112 |
113 |         short sum = 0;
114 |         for(int i=0; i<=dim; ++i) {
115 |             const short r = round2mult<dim>(embedded[i]);
116 |             rounded[i] = r;
117 |             sum += r;
118 |         }
119 |         sum /= (short)dim + 1;
120 |
121 |         // Compute rank(embedded - rounded), decreasing order
122 |         ranksort_diff<dim>(embedded, rounded, rank);
123 |
124 |         // Walk the point back onto H_d (Lemma 2.9 in permutohedral_techreport.pdf)
125 |         for (int i = 0; i <= dim; i++) {
126 |             rank[i] += sum;
127 |             if (rank[i] < 0) {
128 |                 rank[i] += (short)dim + 1;
129 |                 rounded[i] += (short)dim + 1;
130 |             } else if (rank[i] > dim) {
131 |                 rank[i] -= (short)dim + 1;
132 |                 rounded[i] -= (short)dim + 1;
133 |             }
134 |         }
135 |
136 |         // The temporary key has 1 extra dimension. Normally we ignore the last
137 |         // entry in the key because they sum to 0, but we can use the key as swap
138 |         // space to invert the sorting permutation.
139 |         short key[dim];
140 |         for(int k=0; k<=dim; ++k) {
141 |             for(int i=0; i<dim; ++i) {
142 |                 key[i] = canonical_coord<dim>(k, rank[i]) + rounded[i];
143 |             }
144 |
145 |             const int ind = hash_insert<dim>(hash_entries, hash_keys, hash_cap, key);
146 |             assert(ind >= 0);
147 |             neib_ents[idx + N * k] = ind;
148 |         }
149 |
150 |         float bar_tmp[dim + 2]{0};
151 |         // Compute the barycentric coordinates (p.10 in [Adams etal 2010])
152 |         for (int i = 0; i <= dim; ++i) {
153 |             const float delta = (embedded[i] - rounded[i]) * (1.f / ((float)dim + 1));
154 |             bar_tmp[dim - rank[i]] += delta;
155 |             bar_tmp[dim + 1 - rank[i]] -= delta;
156 |         }
157 |         // Wrap around
158 |         bar_tmp[0] += 1.0 + bar_tmp[dim + 1];
159 |
160 |         for (int i = 0; i <= dim; ++i) {
161 |             b[N * i] = bar_tmp[i];
162 |         }
163 |     }
164 | }
165 |
166 | template<int dim>
167 | __global__ void dedup(int* hash_entries, short* hash_keys, size_t hash_cap) {
168 |     const size_t idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
169 |     if (idx < hash_cap) {
170 |         int& e = hash_entries[idx];
171 |         if(e >= 0) {
172 |             const short* key = &hash_keys[idx * dim];
173 |             e = hash_lookup<dim>(hash_entries, hash_keys, hash_cap, key);
174 |         }
175 |     }
176 | }
177 |
178 | __global__ void find_valid(const int* hash_entries,
179 |                            int* valid_entries,
180 |                            int* n_valid,
181 |                            size_t hash_cap) {
182 |
183 |     const size_t idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
184 |     if (idx < hash_cap) {
185 |         const int& e = hash_entries[idx];
186 |         if(e >= 0) {
187 |             const int my_ind = atomicAdd(n_valid, 1);
188 |             valid_entries[my_ind] = e;
189 |         }
190 |     }
191 | }
192 |
193 | template<int dim>
194 | void _call_hash_kernels(const float* points,
195 |                         int* hash_entries,
196 |                         short* hash_keys,
197 |                         int* neib_ents,
198 |                         float* barycentric,
199 |                         int* valid_entries,
200 |                         int* n_valid,
201 |                         size_t hash_cap,
202 |                         size_t N,
203 |                         cudaStream_t stream) {
204 |
205 |     build_hash<dim><<<cuda_gridsize(N), BLOCK, 0, stream>>>(points,
206 |             hash_entries, hash_keys, neib_ents, barycentric, hash_cap, N);
207 |     dedup<dim><<<cuda_gridsize(hash_cap), BLOCK, 0, stream>>>(hash_entries, hash_keys, hash_cap);
208 |     find_valid<<<cuda_gridsize(hash_cap), BLOCK, 0, stream>>>(hash_entries, valid_entries, n_valid, hash_cap);
209 | }
210 |
211 | void call_build_hash_kernels(const float* points,
212 |                              int* hash_entries,
213 |                              short* hash_keys,
214 |                              int* neib_ents,
215 |                              float* barycentric,
216 |                              int* valid_entries,
217 |                              int* n_valid,
218 |                              size_t hash_cap,
219 |                              size_t N, size_t dim,
220 |                              cudaStream_t stream) {
221 |
222 |     switch(dim) {
223 |         case 1:
224 |             _call_hash_kernels<1>(points, hash_entries, hash_keys, neib_ents,
225 |                     barycentric, valid_entries, n_valid, hash_cap, N, stream);
226 |             break;
227 |         case 2:
228 |             _call_hash_kernels<2>(points, hash_entries, hash_keys, neib_ents,
229 |                     barycentric, valid_entries, n_valid, hash_cap, N, stream);
230 |             break;
231 |         case 3:
232 |             _call_hash_kernels<3>(points, hash_entries, hash_keys, neib_ents,
233 |                     barycentric, valid_entries, n_valid, hash_cap, N, stream);
234 |             break;
235 |         case 4:
236 |             _call_hash_kernels<4>(points, hash_entries, hash_keys, neib_ents,
237 |                     barycentric, valid_entries, n_valid, hash_cap, N, stream);
238 |             break;
239 |         case 5:
240 |             _call_hash_kernels<5>(points, hash_entries, hash_keys, neib_ents,
241 |                     barycentric, valid_entries, n_valid, hash_cap, N, stream);
242 |             break;
243 |         default:
244 |             fprintf(stderr,
245 |                     "Can't build hash tables for more than 5 dimensional points (but you can fix this by copy/pasting the above lines a few more times if you need to).\n");
246 |             exit(-1);
247 |     }
248 | }
249 |
250 |
--------------------------------------------------------------------------------
/src/build_hash_cuda.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | void build_hash_cuda(const torch::Tensor& th_points,
4 |                      torch::Tensor th_hash_entries,
5 |                      torch::Tensor th_hash_keys,
6 |                      torch::Tensor th_neib_ents,
7 |                      torch::Tensor th_barycentric,
8 |                      torch::Tensor th_valid_entries,
9 |                      torch::Tensor th_n_valid);
10 |
--------------------------------------------------------------------------------
/src/build_hash_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _BHK_H
2 | #define _BHK_H
3 |
4 | #define BLOCK 256
5 |
6 | #ifdef __cplusplus
7 | extern "C" {
8 | #endif
9 |
10 | void call_build_hash_kernels(const float* points,
11 |                              int* hash_entries,
12 |                              short* hash_keys,
13 |                              int* neib_ents,
14 |                              float* barycentric,
15 |                              int* valid_entries,
16 |                              int* n_valid,
17 |                              size_t hash_cap,
18 |                              size_t N, size_t dim,
19 |                              cudaStream_t stream);
20 |
21 | #ifdef __cplusplus
22 | }
23 | #endif
24 |
25 | #endif
26 |
--------------------------------------------------------------------------------
/src/build_hash_wrapper.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <c10/cuda/CUDAStream.h>
3 |
4 | #include "build_hash_kernel.h"
5 | #include "common.h"
6 |
7 | void build_hash_cuda(const torch::Tensor& th_points,
8 |                      torch::Tensor th_hash_entries,
9 |                      torch::Tensor th_hash_keys,
10 |                      torch::Tensor th_neib_ents,
11 |                      torch::Tensor th_barycentric,
12 |                      torch::Tensor th_valid_entries,
13 |                      torch::Tensor th_n_valid) {
14 |
15 |     CHECK_INPUT(th_points)
16 |     CHECK_INPUT(th_hash_entries)
17 |     CHECK_INPUT(th_hash_keys)
18 |     CHECK_INPUT(th_neib_ents)
19 |     CHECK_INPUT(th_barycentric)
20 |     CHECK_INPUT(th_valid_entries)
21 |     CHECK_CONTIGUOUS(th_n_valid)
22 |
23 |     CHECK_4DIMS(th_points)
24 |     CHECK_2DIMS(th_hash_entries)
25 |     CHECK_3DIMS(th_hash_keys)
26 |     CHECK_4DIMS(th_neib_ents)
27 |     CHECK_4DIMS(th_barycentric)
28 |     CHECK_2DIMS(th_valid_entries)
29 |     CHECK_2DIMS(th_n_valid)
30 |
31 |     const float* points = DATA_PTR(th_points, float);
32 |     int* hash_entries = DATA_PTR(th_hash_entries, int);
33 |     short* hash_keys = DATA_PTR(th_hash_keys, short);
34 |     int* neib_ents = DATA_PTR(th_neib_ents, int);
35 |     float* barycentric = DATA_PTR(th_barycentric, float);
36 |     int* valid_entries = DATA_PTR(th_valid_entries, int);
37 |     int* n_valid = DATA_PTR(th_n_valid, int);
38 |
39 |     const int B = th_points.size(0);
40 |     const int dim = th_points.size(1);
41 |     const int H = th_points.size(2);
42 |     const int W = th_points.size(3);
43 |     const int hash_cap = th_hash_entries.size(1);
44 |
45 |     cudaError_t err;
46 |     cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
47 |
48 |     for (int b=0; b < B; ++b) {
49 |         call_build_hash_kernels(
50 |             points + (b * dim * H * W),
51 |             hash_entries + (b * hash_cap),
52 |             hash_keys + (b * hash_cap * dim),
53 |             neib_ents + (b * (dim + 1) * H * W),
54 |             barycentric + (b * (dim + 1) * H * W),
55 |             valid_entries + (b * hash_cap),
56 |             n_valid + b,
57 |             hash_cap, H * W, dim, stream
58 |         );
59 |
60 |         err = cudaGetLastError();
61 |         if (err != cudaSuccess) {
62 |             fprintf(stderr, "build_hash CUDA kernel failure: %s\n", cudaGetErrorString(err));
63 |             exit(-1);
64 |         }
65 |     }
66 | }
67 |
--------------------------------------------------------------------------------
/src/common.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <torch/extension.h>
4 | #include <ATen/ATen.h>
5 |
6 | #ifdef TORCH_CHECK
7 | #define MYCHECK TORCH_CHECK
8 | #else
9 | #define MYCHECK AT_ASSERTM
10 | #endif
11 |
12 | #define CHECK_CUDA(x) MYCHECK(x.device().is_cuda(), #x " is not a CUDA tensor.")
13 | #define CHECK_CONTIGUOUS(x) MYCHECK(x.is_contiguous(), #x " is not contiguous.")
14 | #define CHECK_INPUT(x) \
15 |     CHECK_CUDA(x);     \
16 |     CHECK_CONTIGUOUS(x)
17 | #define CHECK_2DIMS(x) MYCHECK((x.sizes().size() == 2), #x " is not 2 dimensional.")
18 | #define CHECK_3DIMS(x) MYCHECK((x.sizes().size() == 3), #x " is not 3 dimensional.")
19 | #define CHECK_4DIMS(x) MYCHECK((x.sizes().size() == 4), #x " is not 4 dimensional.")
20 |
21 | #ifdef USE_OLD_PYTORCH_DATA_MEMBER
22 | #define DATA_PTR(x,t) x.data<t>()
23 | #else
24 | #define DATA_PTR(x,t) x.data_ptr<t>()
25 | #endif
26 |
27 |
--------------------------------------------------------------------------------
/src/gfilt_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <cstdio>
2 |
3 | #include "hash_fns.cuh"
4 | #include "gfilt_kernel.h"
5 |
6 | template<int ref_dim, int val_dim>
7 | __global__ void splat(const float* values,
8 |                       const float* barycentric,
9 |                       const int* hash_entries,
10 |                       const int* neib_ents,
11 |                       float* hash_values, size_t N) {
12 |
13 |     const size_t idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
14 |     if (idx >= N) { return; }
15 |
16 |     // We work with val_dim + 1 value channels here (one extra) because the
17 |     // last channel keeps track of the normalization factor to use at the end.
18 |     // During splatting, we splat 1 * bary into the last channel.
19 |     float local_v[val_dim + 1];
20 |     float local_b[ref_dim+1];
21 |     local_v[val_dim] = 1.f;
22 |
23 |     for(int i=0; i < val_dim; ++i) { local_v[i] = values[idx + N*i]; }
24 |     for(int i=0; i < (ref_dim+1); ++i) { local_b[i] = barycentric[idx + N*i]; }
25 |
26 |     // Splat this point onto each vertex of its surrounding simplex.
27 |     for(int k=0; k < (ref_dim+1); ++k) {
28 |         const int ind = hash_entries[neib_ents[idx + N*k]];
29 |
30 |         for(int i=0; i < (val_dim+1); ++i) {
31 |             atomicAdd(&hash_values[ind * (val_dim + 1) + i], local_b[k] * local_v[i]);
32 |         }
33 |     }
34 | }
35 |
36 |
37 | template<int ref_dim, int val_dim>
38 | __global__ void blur(float* out,
39 |                      const int* hash_entries,
40 |                      const int* valid_entries,
41 |                      const short* hash_keys,
42 |                      const float* hash_values,
43 |                      size_t hash_cap, int n_valid, size_t axis) {
44 |
45 |     const size_t idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
46 |     if (idx >= n_valid) { return; }
47 |
48 |     // We work with val_dim + 1 value channels here (one extra) because the
49 |     // last channel keeps track of the normalization factor to use at the end.
50 |     float local_out[val_dim + 1];
51 |
52 |     // The local key storage needs the normally-ignored value at the end so
53 |     // that key[axis] is always a valid memory access.
54 |     //short key[ref_dim];
55 |     short key[ref_dim+1];
56 |
57 |     const int& ind_c = valid_entries[idx];
58 |     for(int i=0; i < val_dim+1; ++i) { local_out[i] = 0; }
59 |
60 |     const short* key_c = &hash_keys[ind_c*ref_dim];
61 |
62 |     for(int i=0; i < ref_dim; ++i) { key[i] = key_c[i] + 1; }
63 |     key[axis] = key_c[axis] - ref_dim;
64 |
65 |     const int ind_l = hash_lookup<ref_dim>(hash_entries, hash_keys, hash_cap, key);
66 |
67 |     for(int i=0; i < ref_dim; ++i) { key[i] = key_c[i] - 1; }
68 |     key[axis] = key_c[axis] + ref_dim;
69 |
70 |     const int ind_r = hash_lookup<ref_dim>(hash_entries, hash_keys, hash_cap, key);
71 |
72 |     if(ind_l >= 0 && ind_r >= 0) {
73 |         for(int i=0; i < val_dim + 1; ++i) {
74 |             local_out[i] = (hash_values[ind_l * (val_dim + 1) + i] +
75 |                             2*hash_values[ind_c * (val_dim + 1) + i] +
76 |                             hash_values[ind_r * (val_dim + 1) + i]) * 0.25f;
77 |         }
78 |     } else if(ind_l >= 0) {
79 |         for(int i=0; i < val_dim + 1; ++i) {
80 |             local_out[i] = (hash_values[ind_l * (val_dim + 1) + i] +
81 |                             2*hash_values[ind_c * (val_dim + 1) + i]) * 0.25f;
82 |         }
83 |     } else if(ind_r >= 0) {
84 |         for(int i=0; i < val_dim + 1; ++i) {
85 |             local_out[i] = (hash_values[ind_r * (val_dim + 1) + i] +
86 |                             2*hash_values[ind_c * (val_dim + 1) + i]) * 0.25f;
87 |         }
88 |     } else {
89 |         for(int i=0; i < val_dim + 1; ++i) {
90 |             local_out[i] = hash_values[ind_c * (val_dim + 1) + i] * 0.5f;
91 |         }
92 |     }
93 |
94 |     for(int i=0; i < val_dim + 1; ++i) { out[ind_c * (val_dim + 1) + i] = local_out[i]; }
95 | }
96 |
97 | template<int ref_dim, int val_dim>
98 | __global__ void slice(float* out,
99 |                       const float* barycentric,
100 |                       const int* hash_entries,
101 |                       const int* neib_ents,
102 |                       const float* hash_values, int N) {
103 |
104 |     const size_t idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
105 |     if (idx >= N) { return; }
106 |
107 |     // We work with val_dim + 1 value channels here (one extra) because the
108 |     // last channel keeps track of the normalization factor by which we will
109 |     // scale the output.
110 |     float local_out[val_dim + 1];
111 |     for(int i=0; i < val_dim + 1; ++i) { local_out[i] = 0; }
112 |
113 |     // Gather values from each of the surrounding simplex vertices.
114 |     for(int k=0; k < (ref_dim+1); ++k) {
115 |         const int ind = hash_entries[neib_ents[idx + N*k]];
116 |         for(int i=0; i < val_dim + 1; ++i) {
117 |             local_out[i] += barycentric[idx + N*k] * hash_values[ind * (val_dim + 1) + i];
118 |         }
119 |     }
120 |
121 |     // Scale by the accumulated normalization channel.
122 |     for(int i=0; i < val_dim; ++i) {
123 |         out[idx + N*i] = local_out[i] / local_out[val_dim];
124 |     }
125 | }
126 |
127 |
128 | template<int ref_dim, int val_dim>
129 | void _call_gfilt_kernels(const float* values, float* output,
130 |                          float* tmp_vals_1, float* tmp_vals_2,
131 |                          const int* hash_entries, const short* hash_keys,
132 |                          const int* neib_ents, const float* barycentric,
133 |                          const int* valid_entries, int n_valid,
134 |                          size_t hash_cap, size_t N, bool reverse, cudaStream_t stream) {
135 |
136 |     splat<ref_dim, val_dim><<<cuda_gridsize(N), BLOCK, 0, stream>>>(values,
137 |             barycentric, hash_entries, neib_ents, tmp_vals_1, N);
138 |
139 |     float* tmp_swap;
140 |
141 |     for(int ax=reverse ? ref_dim : 0;
142 |         ax <= ref_dim && ax >= 0;
143 |         reverse ? --ax : ++ax
144 |     ) {
145 |         blur<ref_dim, val_dim><<<cuda_gridsize(n_valid), BLOCK, 0,
146 |                 stream>>>(tmp_vals_2, hash_entries, valid_entries, hash_keys,
147 |                 tmp_vals_1, hash_cap, n_valid, ax);
148 |
149 |         tmp_swap = tmp_vals_1;
150 |         tmp_vals_1 = tmp_vals_2;
151 |         tmp_vals_2 = tmp_swap;
152 |     }
153 |
154 |     slice<ref_dim, val_dim><<<cuda_gridsize(N), BLOCK, 0, stream>>>(output, barycentric, hash_entries, neib_ents, tmp_vals_2, N);
155 | }
156 |
157 | void call_gfilt_kernels(const float* values, float* output,
158 |                         float* tmp_vals_1, float* tmp_vals_2,
159 |                         const int* hash_entries, const short* hash_keys,
160 |                         const int* neib_ents, const float* barycentric,
161 |                         const int* valid_entries, int n_valid,
162 |                         size_t hash_cap, size_t N, size_t ref_dim,
163 |                         size_t val_dim, bool reverse, cudaStream_t stream) {
164 |
165 |     #include "gfilt_dispatch_table.h"
166 | }
167 |
--------------------------------------------------------------------------------
/src/gfilt_cuda.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | void gfilt_cuda(const torch::Tensor& th_values, torch::Tensor th_output,
4 |                 torch::Tensor th_tmp_vals_1, torch::Tensor th_tmp_vals_2,
5 |                 const torch::Tensor& th_hash_entries, const torch::Tensor& th_hash_keys,
6 |                 const torch::Tensor& th_neib_ents, const torch::Tensor& th_barycentric,
7 |                 const torch::Tensor& th_valid_entries, const torch::Tensor& th_n_valid,
8 |                 int64_t ref_dim, bool reverse);
9 |
--------------------------------------------------------------------------------
/src/gfilt_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _GFK_H
2 | #define _GFK_H
3 |
4 | #define BLOCK 256
5 |
6 | #ifdef __cplusplus
7 | extern "C" {
8 | #endif
9 |
10 | void call_gfilt_kernels(const float* values, float* output,
11 |                         float* tmp_vals_1, float* tmp_vals_2,
12 |                         const int* hash_entries, const short* hash_keys,
13 |                         const int* neib_ents, const float* barycentric,
14 |                         const int* valid_entries, int n_valid,
15 |                         size_t hash_cap, size_t N, size_t ref_dim,
16 |                         size_t val_dim, bool reverse, cudaStream_t stream);
17 |
18 | #ifdef __cplusplus
19 | }
20 | #endif
21 |
22 | #endif
23 |
--------------------------------------------------------------------------------
/src/gfilt_wrapper.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <c10/cuda/CUDAStream.h>
3 |
4 | #include "gfilt_kernel.h"
5 | #include "common.h"
6 |
7 | void gfilt_cuda(const torch::Tensor& th_values, torch::Tensor th_output,
8 |                 torch::Tensor th_tmp_vals_1, torch::Tensor th_tmp_vals_2,
9 |                 const torch::Tensor& th_hash_entries, const torch::Tensor& th_hash_keys,
10 |                 const torch::Tensor& th_neib_ents, const torch::Tensor& th_barycentric,
11 |                 const torch::Tensor& th_valid_entries, const torch::Tensor& th_n_valid,
12 |                 int64_t ref_dim, bool reverse) {
13 |
14 |     CHECK_4DIMS(th_values)
15 |     CHECK_4DIMS(th_output)
16 |     CHECK_2DIMS(th_tmp_vals_1)
17 |     CHECK_2DIMS(th_tmp_vals_2)
18 |
19 |     CHECK_2DIMS(th_hash_entries)
20 |     CHECK_3DIMS(th_hash_keys)
21 |     CHECK_4DIMS(th_neib_ents)
22 |     CHECK_4DIMS(th_barycentric)
23 |     CHECK_2DIMS(th_valid_entries)
24 |     CHECK_CONTIGUOUS(th_n_valid)
25 |     MYCHECK(th_n_valid.device().is_cpu(), "n_valid is not a CPU tensor.")
26 |
27 |     const float* values = DATA_PTR(th_values, float);
28 |     float* output = DATA_PTR(th_output, float);
29 |     float* tmp_vals_1 = DATA_PTR(th_tmp_vals_1, float);
30 |     float* tmp_vals_2 = DATA_PTR(th_tmp_vals_2, float);
31 |     const int* hash_entries = DATA_PTR(th_hash_entries, int);
32 |     const short* hash_keys = DATA_PTR(th_hash_keys, short);
33 |     const int* neib_ents = DATA_PTR(th_neib_ents, int);
34 |     const float* barycentric = DATA_PTR(th_barycentric, float);
35 |     const int* valid_entries = DATA_PTR(th_valid_entries, int);
36 |     int* n_valid = DATA_PTR(th_n_valid, int);
37 |
38 |     const int B = th_values.size(0);
39 |     const int H = th_values.size(2);
40 |     const int W = th_values.size(3);
41 |     const int val_dim = th_values.size(1);
42 |     const int hash_cap = th_hash_entries.size(1);
43 |
44 |     cudaError_t err;
45 |     cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
46 |
47 |     for (int b=0; b < B; ++b) {
48 |         th_tmp_vals_1.fill_(0.f);
49 |         th_tmp_vals_2.fill_(0.f);
50 |
51 |         call_gfilt_kernels(
52 |             values + (b * val_dim * H * W),
53 |             output + (b * val_dim * H * W),
54 |             tmp_vals_1, tmp_vals_2,
55 |             hash_entries + (b * hash_cap),
56 |             hash_keys + (b * hash_cap * ref_dim),
57 |             neib_ents + (b * (ref_dim + 1) * H * W),
58 |             barycentric + (b * (ref_dim + 1) * H * W),
59 |             valid_entries + (b * hash_cap),
60 |             n_valid[b],
61 |             hash_cap, H * W, ref_dim, val_dim,
62 |             reverse, stream
63 |         );
64 |
65 |         err = cudaGetLastError();
66 |         if (err != cudaSuccess) {
67 |             fprintf(stderr, "gfilt CUDA kernel failure: %s\n", cudaGetErrorString(err));
68 |             exit(-1);
69 |         }
70 |     }
71 | }
72 |
--------------------------------------------------------------------------------
/src/hash_fns.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "build_hash_kernel.h"
3 |
4 | const size_t MIN_QUAD_PROBES = 10000;
5 |
6 | inline dim3 cuda_gridsize(int n) {
7 |     int k = (n - 1) / BLOCK + 1;
8 |     int x = k;
9 |     int y = 1;
10 |
11 |     if(x > 65535) {
12 |         x = ceil(sqrt(k));
13 |         y = (n - 1) / (x * BLOCK) + 1;
14 |     }
15 |
16 |     return dim3(x, y, 1);
17 | }
18 |
19 | template<int key_dim>
20 | __device__ inline unsigned int hash(const short* key) {
21 |     unsigned int h = 0;
22 |     /*
23 |     for (int i=0; i < key_dim; ++i) {
24 |         h ^= ((unsigned int)key[i]) << ((31/key_dim)*i);
25 |     }
26 |     */
27 |     for (int i=0; i < key_dim; i++) {
28 |         h += key[i];
29 |         h = h * 2531011;
30 |     }
31 |     return h;
32 | }
33 |
34 | template<int key_dim>
35 | __device__ inline bool key_cmp(const short* key1, const short* key2) {
36 |     for(int i=0; i<key_dim; ++i) {
37 |         if(key1[i] != key2[i]) { return false; }
38 |     }
39 |     return true;
40 | }
41 |
42 | template<int key_dim>
43 | __device__ inline int hash_lookup(const int* entries, const short* keys, size_t capacity, const short* key) {
44 |     unsigned int h = hash<key_dim>(key) % capacity;
45 |     const unsigned int init_h = h;
46 |     const size_t max_it = min((size_t)MIN_QUAD_PROBES, capacity);
47 |
48 |     // The probing sequence here needs to match the one used for insertion,
49 |     // otherwise bad things will happen.
50 |     for(int iter=0; iter<max_it; ++iter) {
51 |         const int entry = entries[h];
52 |         if(entry >= 0 && key_cmp<key_dim>(&keys[entry*key_dim], key)) { return entry; }
53 |         h = (init_h + iter*iter) % capacity;
54 |     }
55 |
56 |     for(int iter=0; iter<capacity; ++iter) {
57 |         const int entry = entries[h];
58 |         if(entry >= 0 && key_cmp<key_dim>(&keys[entry*key_dim], key)) { return entry; }
59 |         h = (h+1) % capacity;
60 |     }
61 |
62 |     // We wrapped around without finding a matching key, the key is not in the
63 |     // table.
64 |     return -1;
65 | }
66 |
--------------------------------------------------------------------------------
/src/permutohedral.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <torch/library.h>
3 |
4 | #include "build_hash_cuda.h"
5 | #include "gfilt_cuda.h"
6 |
7 |
8 | PYBIND11_MODULE(permutohedral_ext, m) {
9 |     m.def("gfilt_cuda", &gfilt_cuda, "High-dimensional Gaussian filter (CUDA)");
10 |     m.def("build_hash_cuda", &build_hash_cuda, "Build the permutohedral lattice hash table (CUDA)");
11 | }
12 |
13 | TORCH_LIBRARY(permutohedral_ext, m) {
14 |     m.def("gfilt_cuda", &gfilt_cuda);
15 |     m.def("build_hash_cuda", &build_hash_cuda);
16 | }
17 |
--------------------------------------------------------------------------------