├── LICENSE
├── README.md
├── boxfilter.py
├── demo.py
├── guided_filter.py
└── sample_images
├── cat.bmp
├── cave-flash.bmp
├── cave-noflash.bmp
├── filtered_cat.png
├── filtered_cave.png
├── filtered_mask.png
├── toy-mask.bmp
├── toy.bmp
└── tulips.bmp
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 teppei suzuki
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # guided-filter-pytorch
2 | PyTorch implementation of Guided Image Filtering (He et al., ECCV 2010), including the fast (subsampled) variant.
3 | The implementation is based on the [original MATLAB implementation](http://kaiminghe.com/eccv10/).
4 |
5 | # Example results
6 | ## Denoising
7 | ![Denoising result](sample_images/filtered_cave.png)
8 |
9 | ## Structure transferring
10 | ![Structure transferring result](sample_images/filtered_mask.png)
--------------------------------------------------------------------------------
/boxfilter.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
def _window_sum(src, r, dim):
    """Return the sliding-window sum of ``src`` with radius ``r`` along ``dim``.

    Output element ``i`` equals ``src[i - r : i + r + 1].sum()`` along ``dim``,
    with the window truncated at both borders, so the output has the same
    shape as ``src``. Any ``r >= 0`` is supported, including windows wider
    than the axis itself.

    Raises:
        ValueError: if ``r`` is negative.
    """
    if r < 0:
        raise ValueError(f"radius must be non-negative, got {r}")
    length = src.shape[dim]
    cum = src.cumsum(dim)
    # Prepend a zero slice so that padded[k] == sum(src[:k]) along ``dim``;
    # every window sum is then one difference, with no special border cases.
    zero_shape = list(cum.shape)
    zero_shape[dim] = 1
    padded = torch.cat([cum.new_zeros(zero_shape), cum], dim)
    idx = torch.arange(length, device=src.device)
    upper = (idx + r + 1).clamp(max=length)  # exclusive end of each window
    lower = (idx - r).clamp(min=0)           # inclusive start of each window
    return padded.index_select(dim, upper) - padded.index_select(dim, lower)


def _diff_x(src, r):
    """Sliding-window sum of radius ``r`` along dim -2 (rows)."""
    return _window_sum(src, r, -2)


def _diff_y(src, r):
    """Sliding-window sum of radius ``r`` along dim -1 (columns)."""
    return _window_sum(src, r, -1)


def boxfilter2d(src, radius):
    """Box-filter the last two dims of ``src`` with a (2r+1) x (2r+1) window.

    Returns the *sum* (not the mean) over each window, truncated at the image
    borders; the output has the same shape as ``src``. Divide by
    ``boxfilter2d(torch.ones_like(src), radius)`` to obtain window means.

    Raises:
        ValueError: if ``radius`` is negative.
    """
    return _diff_y(_diff_x(src, radius), radius)
27 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | import torchvision.utils as tv_utils
4 |
5 | from guided_filter import GuidedFilter2d, FastGuidedFilter2d
6 |
def structure_transferring(radius, eps, fast=False, out_filename="result.png", scale=2):
    """Filter the mask ``toy-mask.bmp`` using the color image ``toy.bmp`` as guide.

    The filtered mask is written to ``out_filename`` (min-max normalized by
    ``save_image``).

    Args:
        radius: filter radius in pixels.
        eps: regularization coefficient, on the [0, 1] intensity scale.
        fast: use ``FastGuidedFilter2d`` with subsampling factor ``scale``.
        out_filename: path of the output image.
        scale: subsampling factor of the fast filter (only used if ``fast``).
    """
    img = plt.imread("sample_images/toy.bmp")
    mask = plt.imread("sample_images/toy-mask.bmp")[..., 0]
    if fast:
        gf = FastGuidedFilter2d(radius, eps, scale)
    else:
        gf = GuidedFilter2d(radius, eps)

    tch_img = torch.from_numpy(img).permute(2, 0, 1)[None]
    tch_mask = torch.from_numpy(mask)[None, None]
    # plt.imread returns integer arrays for non-PNG files; rescale 8-bit data
    # to [0, 1] so ``eps`` is interpreted on the same intensity scale as in
    # the reference MATLAB implementation.
    tch_img = tch_img.float() / 255. if tch_img.dtype == torch.uint8 else tch_img.float()
    tch_mask = tch_mask.float() / 255. if tch_mask.dtype == torch.uint8 else tch_mask.float()

    out = gf(tch_mask, tch_img)

    tv_utils.save_image(out, out_filename, normalize=True)
21 |
def filtering(radius, eps, fast=False, out_filename="result.png", scale=2):
    """Edge-preserving smoothing of ``cat.bmp`` using the image as its own guide.

    The result is written to ``out_filename`` (min-max normalized by
    ``save_image``).

    Args:
        radius: filter radius in pixels.
        eps: regularization coefficient, on the [0, 1] intensity scale.
        fast: use ``FastGuidedFilter2d`` with subsampling factor ``scale``.
        out_filename: path of the output image.
        scale: subsampling factor of the fast filter (only used if ``fast``).
    """
    img = plt.imread("sample_images/cat.bmp")
    if fast:
        gf = FastGuidedFilter2d(radius, eps, scale)
    else:
        gf = GuidedFilter2d(radius, eps)

    # assumes cat.bmp loads as a single-channel (H, W) array -- TODO confirm
    tch_img = torch.from_numpy(img)[None, None]
    # plt.imread returns integer arrays for non-PNG files; rescale 8-bit data
    # to [0, 1] so ``eps`` is interpreted on the same intensity scale as in
    # the reference MATLAB implementation.
    tch_img = tch_img.float() / 255. if tch_img.dtype == torch.uint8 else tch_img.float()

    out = gf(tch_img, tch_img)

    tv_utils.save_image(out, out_filename, normalize=True)
34 |
def denoising(radius, eps, fast=False, out_filename="result.png", scale=2):
    """Denoise ``cave-noflash.bmp`` using the flash image ``cave-flash.bmp`` as guide.

    The result is written to ``out_filename`` (min-max normalized by
    ``save_image``).

    Args:
        radius: filter radius in pixels.
        eps: regularization coefficient, on the [0, 1] intensity scale.
        fast: use ``FastGuidedFilter2d`` with subsampling factor ``scale``.
        out_filename: path of the output image.
        scale: subsampling factor of the fast filter (only used if ``fast``).
    """
    img = plt.imread("sample_images/cave-noflash.bmp")
    guide = plt.imread("sample_images/cave-flash.bmp")
    if fast:
        gf = FastGuidedFilter2d(radius, eps, scale)
    else:
        gf = GuidedFilter2d(radius, eps)

    tch_img = torch.from_numpy(img).permute(2, 0, 1)[None]
    tch_guide = torch.from_numpy(guide).permute(2, 0, 1)[None]
    # plt.imread returns integer arrays for non-PNG files; rescale 8-bit data
    # to [0, 1] so ``eps`` is interpreted on the same intensity scale as in
    # the reference MATLAB implementation.
    tch_img = tch_img.float() / 255. if tch_img.dtype == torch.uint8 else tch_img.float()
    tch_guide = tch_guide.float() / 255. if tch_guide.dtype == torch.uint8 else tch_guide.float()

    out = gf(tch_img, tch_guide)

    tv_utils.save_image(out, out_filename, normalize=True)
49 |
if __name__ == "__main__":
    import argparse

    # CLI task name -> demo function; all share the same positional signature.
    tasks = {
        "transferring": structure_transferring,
        "filtering": filtering,
        "denoising": denoising,
    }

    parser = argparse.ArgumentParser(description="Guided filter demos on the bundled sample images.")
    # Omitting --task has always run the denoising demo; make that explicit
    # instead of relying on a catch-all else branch.
    parser.add_argument("--task", choices=list(tasks), default="denoising", type=str,
                        help="demo to run (default: denoising)")
    parser.add_argument("--radius", default=30, type=int, help="filter radius in pixels")
    parser.add_argument("--eps", default=1e-4, type=float, help="regularization coefficient")
    parser.add_argument("--fast", action="store_true", help="use the fast (subsampled) guided filter")
    parser.add_argument("--output", default="results.png", type=str, help="output image path")
    args = parser.parse_args()

    tasks[args.task](args.radius, args.eps, args.fast, args.output)
66 |
--------------------------------------------------------------------------------
/guided_filter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from boxfilter import boxfilter2d
6 |
class GuidedFilter2d(nn.Module):
    """Guided image filter as an ``nn.Module``.

    Args:
        radius: filter radius in pixels.
        eps: regularization coefficient.
    """

    def __init__(self, radius: int, eps: float):
        super().__init__()
        self.r = radius
        self.eps = eps

    def forward(self, x, guide):
        """Filter ``x`` using ``guide``.

        Args:
            x: (B, C, H, W) or (B, H, W) tensor to filter.
            guide: (B, 3, H, W) color guide, or (B, 1, H, W) / (B, H, W)
                gray-scale guide.

        Returns:
            The filtered (B, C, H, W) tensor.

        Raises:
            NotImplementedError: if the guide has neither 1 nor 3 channels.
        """
        if guide.ndim == 3:
            # A (B, H, W) guide is single-channel; add the channel dim so the
            # dispatch below does not mistake H for the channel count.
            guide = guide[:, None]
        channels = guide.shape[1]
        if channels == 3:
            return guidedfilter2d_color(guide, x, self.r, self.eps)
        if channels == 1:
            return guidedfilter2d_gray(guide, x, self.r, self.eps)
        raise NotImplementedError(f"guide must have 1 or 3 channels, got {channels}")
20 |
class FastGuidedFilter2d(GuidedFilter2d):
    """Fast guided filter.

    Like ``GuidedFilter2d``, but the linear coefficients are computed on the
    inputs subsampled by factor ``s`` (using radius ``radius // s``) and then
    bilinearly upsampled back to full resolution.

    Args:
        radius: filter radius in pixels at full resolution.
        eps: regularization coefficient.
        s: subsampling factor.
    """

    def __init__(self, radius: int, eps: float, s: int):
        super().__init__(radius, eps)
        self.s = s

    def forward(self, x, guide):
        """Filter ``x`` using ``guide`` with subsampling factor ``self.s``.

        Args:
            x: (B, C, H, W) or (B, H, W) tensor to filter.
            guide: (B, 3, H, W) color guide, or (B, 1, H, W) / (B, H, W)
                gray-scale guide.

        Returns:
            The filtered (B, C, H, W) tensor at full resolution.

        Raises:
            NotImplementedError: if the guide has neither 1 nor 3 channels.
        """
        if guide.ndim == 3:
            # A (B, H, W) guide is single-channel; add the channel dim so the
            # dispatch below does not mistake H for the channel count.
            guide = guide[:, None]
        channels = guide.shape[1]
        if channels == 3:
            return guidedfilter2d_color(guide, x, self.r, self.eps, self.s)
        if channels == 1:
            return guidedfilter2d_gray(guide, x, self.r, self.eps, self.s)
        raise NotImplementedError(f"guide must have 1 or 3 channels, got {channels}")
34 |
def guidedfilter2d_color(guide, src, radius, eps, scale=None):
    """guided filter for a color guide image

    Parameters
    -----
    guide: (B, 3, H, W)-dim torch.Tensor
        guide image
    src: (B, C, H, W)-dim torch.Tensor
        filtering image; a (B, H, W) tensor is treated as (B, 1, H, W)
    radius: int
        filter radius
    eps: float
        regularization coefficient, added to the diagonal of the guide's
        local 3x3 covariance
    scale: int or None
        if given, the linear coefficients are computed on guide and src
        subsampled by this factor (nearest) with radius ``radius // scale``
        and bilinearly upsampled back to the full resolution (fast filter)

    Returns
    -----
    (B, C, H, W)-dim torch.Tensor
        filtered image
    """
    assert guide.shape[1] == 3, "guide must have 3 channels"
    if src.ndim == 3:
        src = src[:, None]
    if scale is not None:
        # Keep the full-resolution guide for the final step. It is only
        # rebound below, never modified in place, so no copy is needed.
        guide_full = guide
        src = F.interpolate(src, scale_factor=1./scale, mode="nearest")
        guide = F.interpolate(guide, scale_factor=1./scale, mode="nearest")
        radius = radius // scale

    guide_r, guide_g, guide_b = torch.chunk(guide, 3, 1)  # b x 1 x H x W
    # number of pixels in each (border-truncated) window; turns sums into means
    ones = torch.ones_like(guide_r)
    N = boxfilter2d(ones, radius)

    mean_I = boxfilter2d(guide, radius) / N  # b x 3 x H x W
    mean_I_r, mean_I_g, mean_I_b = torch.chunk(mean_I, 3, 1)  # b x 1 x H x W

    mean_p = boxfilter2d(src, radius) / N  # b x C x H x W

    mean_Ip_r = boxfilter2d(guide_r * src, radius) / N  # b x C x H x W
    mean_Ip_g = boxfilter2d(guide_g * src, radius) / N  # b x C x H x W
    mean_Ip_b = boxfilter2d(guide_b * src, radius) / N  # b x C x H x W

    cov_Ip_r = mean_Ip_r - mean_I_r * mean_p  # b x C x H x W
    cov_Ip_g = mean_Ip_g - mean_I_g * mean_p  # b x C x H x W
    cov_Ip_b = mean_Ip_b - mean_I_b * mean_p  # b x C x H x W

    # unique entries of the symmetric 3x3 matrix Sigma + eps * I, where Sigma
    # is the local covariance of the guide channels
    var_I_rr = boxfilter2d(guide_r * guide_r, radius) / N - mean_I_r * mean_I_r + eps  # b x 1 x H x W
    var_I_rg = boxfilter2d(guide_r * guide_g, radius) / N - mean_I_r * mean_I_g  # b x 1 x H x W
    var_I_rb = boxfilter2d(guide_r * guide_b, radius) / N - mean_I_r * mean_I_b  # b x 1 x H x W
    var_I_gg = boxfilter2d(guide_g * guide_g, radius) / N - mean_I_g * mean_I_g + eps  # b x 1 x H x W
    var_I_gb = boxfilter2d(guide_g * guide_b, radius) / N - mean_I_g * mean_I_b  # b x 1 x H x W
    var_I_bb = boxfilter2d(guide_b * guide_b, radius) / N - mean_I_b * mean_I_b + eps  # b x 1 x H x W

    # determinant
    cov_det = (var_I_rr * var_I_gg * var_I_bb
               + var_I_rg * var_I_gb * var_I_rb
               + var_I_rb * var_I_rg * var_I_gb
               - var_I_rb * var_I_gg * var_I_rb
               - var_I_rg * var_I_rg * var_I_bb
               - var_I_rr * var_I_gb * var_I_gb)  # b x 1 x H x W

    # inverse in closed form: adjugate (cofactors) divided by the determinant
    inv_var_I_rr = (var_I_gg * var_I_bb - var_I_gb * var_I_gb) / cov_det  # b x 1 x H x W
    inv_var_I_rg = - (var_I_rg * var_I_bb - var_I_rb * var_I_gb) / cov_det  # b x 1 x H x W
    inv_var_I_rb = (var_I_rg * var_I_gb - var_I_rb * var_I_gg) / cov_det  # b x 1 x H x W
    inv_var_I_gg = (var_I_rr * var_I_bb - var_I_rb * var_I_rb) / cov_det  # b x 1 x H x W
    inv_var_I_gb = - (var_I_rr * var_I_gb - var_I_rb * var_I_rg) / cov_det  # b x 1 x H x W
    inv_var_I_bb = (var_I_rr * var_I_gg - var_I_rg * var_I_rg) / cov_det  # b x 1 x H x W

    inv_sigma = torch.stack([
        torch.stack([inv_var_I_rr, inv_var_I_rg, inv_var_I_rb], 1),
        torch.stack([inv_var_I_rg, inv_var_I_gg, inv_var_I_gb], 1),
        torch.stack([inv_var_I_rb, inv_var_I_gb, inv_var_I_bb], 1)
    ], 1).squeeze(-3)  # b x 3 x 3 x H x W

    cov_Ip = torch.stack([cov_Ip_r, cov_Ip_g, cov_Ip_b], 1)  # b x 3 x C x H x W

    # a_j = sum_i cov_Ip_i * inv_sigma[i, j], i.e. (Sigma + eps I)^-1 cov_Ip
    # per pixel and per src channel (inv_sigma is symmetric)
    a = torch.einsum("bichw,bijhw->bjchw", (cov_Ip, inv_sigma))  # b x 3 x C x H x W
    b = mean_p - a[:, 0] * mean_I_r - a[:, 1] * mean_I_g - a[:, 2] * mean_I_b  # b x C x H x W

    # average the coefficients of all windows covering each pixel
    mean_a = torch.stack([boxfilter2d(a[:, i], radius) / N for i in range(3)], 1)
    mean_b = boxfilter2d(b, radius) / N

    if scale is not None:
        guide = guide_full
        mean_a = torch.stack([F.interpolate(mean_a[:, i], guide.shape[-2:], mode='bilinear') for i in range(3)], 1)
        mean_b = F.interpolate(mean_b, guide.shape[-2:], mode='bilinear')

    # q = sum over guide channels of mean_a * guide, plus mean_b
    q = torch.einsum("bichw,bihw->bchw", (mean_a, guide)) + mean_b

    return q
120 |
def guidedfilter2d_gray(guide, src, radius, eps, scale=None):
    """guided filter for a gray scale guide image

    Parameters
    -----
    guide: (B, 1, H, W)-dim torch.Tensor
        guide image; a (B, H, W) tensor is treated as (B, 1, H, W)
    src: (B, C, H, W)-dim torch.Tensor
        filtering image; a (B, H, W) tensor is treated as (B, 1, H, W)
    radius: int
        filter radius
    eps: float
        regularization coefficient, added to the local variance of the guide
    scale: int or None
        if given, the linear coefficients are computed on guide and src
        subsampled by this factor (nearest) with radius ``radius // scale``
        and bilinearly upsampled back to the full resolution (fast filter)

    Returns
    -----
    (B, C, H, W)-dim torch.Tensor
        filtered image
    """
    if guide.ndim == 3:
        guide = guide[:, None]
    if src.ndim == 3:
        src = src[:, None]

    if scale is not None:
        # Keep the full-resolution guide for the final step. It is only
        # rebound below, never modified in place, so no copy is needed.
        guide_full = guide
        src = F.interpolate(src, scale_factor=1./scale, mode="nearest")
        guide = F.interpolate(guide, scale_factor=1./scale, mode="nearest")
        radius = radius // scale

    # number of pixels in each (border-truncated) window; turns sums into means
    ones = torch.ones_like(guide)
    N = boxfilter2d(ones, radius)

    mean_I = boxfilter2d(guide, radius) / N
    mean_p = boxfilter2d(src, radius) / N
    mean_Ip = boxfilter2d(guide*src, radius) / N
    cov_Ip = mean_Ip - mean_I * mean_p

    mean_II = boxfilter2d(guide*guide, radius) / N
    var_I = mean_II - mean_I * mean_I

    # per-window linear coefficients; eps regularizes (shrinks) a
    a = cov_Ip / (var_I + eps)
    b = mean_p - a * mean_I

    # average the coefficients of all windows covering each pixel
    mean_a = boxfilter2d(a, radius) / N
    mean_b = boxfilter2d(b, radius) / N

    if scale is not None:
        guide = guide_full
        mean_a = F.interpolate(mean_a, guide.shape[-2:], mode='bilinear')
        mean_b = F.interpolate(mean_b, guide.shape[-2:], mode='bilinear')

    q = mean_a * guide + mean_b
    return q
170 |
--------------------------------------------------------------------------------
/sample_images/cat.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/cat.bmp
--------------------------------------------------------------------------------
/sample_images/cave-flash.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/cave-flash.bmp
--------------------------------------------------------------------------------
/sample_images/cave-noflash.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/cave-noflash.bmp
--------------------------------------------------------------------------------
/sample_images/filtered_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/filtered_cat.png
--------------------------------------------------------------------------------
/sample_images/filtered_cave.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/filtered_cave.png
--------------------------------------------------------------------------------
/sample_images/filtered_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/filtered_mask.png
--------------------------------------------------------------------------------
/sample_images/toy-mask.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/toy-mask.bmp
--------------------------------------------------------------------------------
/sample_images/toy.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/toy.bmp
--------------------------------------------------------------------------------
/sample_images/tulips.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/tulips.bmp
--------------------------------------------------------------------------------