├── LICENSE ├── README.md ├── boxfilter.py ├── demo.py ├── guided_filter.py └── sample_images ├── cat.bmp ├── cave-flash.bmp ├── cave-noflash.bmp ├── filtered_cat.png ├── filtered_cave.png ├── filtered_mask.png ├── toy-mask.bmp ├── toy.bmp └── tulips.bmp /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 teppei suzuki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # guided-filter-pytorch 2 | PyTorch implementation of Guided Image Filtering. 3 | The implementation is based on the [original MATLAB implementation](http://kaiminghe.com/eccv10/).
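A minimal usage sketch based on the modules defined in `guided_filter.py` (this snippet is not part of the original README; the tensor shapes and parameter values below are only illustrative). The modules are called as `filter(src, guide)`, where `src` is the image to be filtered and `guide` has either 3 or 1 channels:

```python
import torch
from guided_filter import GuidedFilter2d, FastGuidedFilter2d

# exact guided filter and the subsampled "fast" variant (downsampling factor s)
gf = GuidedFilter2d(radius=30, eps=1e-4)
fast_gf = FastGuidedFilter2d(radius=30, eps=1e-4, s=2)

src = torch.rand(1, 3, 256, 256)    # image to be filtered, (B, C, H, W)
guide = torch.rand(1, 3, 256, 256)  # guidance image, 3 or 1 channels

out = gf(src, guide)            # (B, C, H, W)
out_fast = fast_gf(src, guide)  # same shape, computed on a downsampled grid
```

See `demo.py` for end-to-end examples (denoising, self-filtering, and structure transfer) on the bundled sample images.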
4 | 5 | # Example results 6 | ## Denoising 7 | 8 | 9 | ## Structure transferring 10 | -------------------------------------------------------------------------------- /boxfilter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _diff_x(src, r): 4 | cum_src = src.cumsum(-2) 5 | 6 | left = cum_src[..., r:2*r + 1, :] 7 | middle = cum_src[..., 2*r + 1:, :] - cum_src[..., :-2*r - 1, :] 8 | right = cum_src[..., -1:, :] - cum_src[..., -2*r - 1:-r - 1, :] 9 | 10 | output = torch.cat([left, middle, right], -2) 11 | 12 | return output 13 | 14 | def _diff_y(src, r): 15 | cum_src = src.cumsum(-1) 16 | 17 | left = cum_src[..., r:2*r + 1] 18 | middle = cum_src[..., 2*r + 1:] - cum_src[..., :-2*r - 1] 19 | right = cum_src[..., -1:] - cum_src[..., -2*r - 1:-r - 1] 20 | 21 | output = torch.cat([left, middle, right], -1) 22 | 23 | return output 24 | 25 | def boxfilter2d(src, radius): 26 | return _diff_y(_diff_x(src, radius), radius) 27 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import torchvision.utils as tv_utils 4 | 5 | from guided_filter import GuidedFilter2d, FastGuidedFilter2d 6 | 7 | def structure_transferring(radius, eps, fast=False, out_filename="result.png"): 8 | img = plt.imread("sample_images/toy.bmp") 9 | mask = plt.imread("sample_images/toy-mask.bmp")[...,0] 10 | if fast: 11 | GF = FastGuidedFilter2d(radius, eps, 2) 12 | else: 13 | GF = GuidedFilter2d(radius, eps) 14 | 15 | tch_img = torch.from_numpy(img).permute(2, 0, 1)[None].float() 16 | tch_mask = torch.from_numpy(mask)[None, None].float() 17 | 18 | out = GF(tch_mask, tch_img) 19 | 20 | tv_utils.save_image(out, out_filename, normalize=True) 21 | 22 | def filtering(radius, eps, fast=False, out_filename="result.png"): 23 | img = plt.imread("sample_images/cat.bmp") 24 | if fast: 25 | GF = FastGuidedFilter2d(radius, eps, 2) 26 | else: 27 | GF = GuidedFilter2d(radius, eps) 28 | 29 | tch_img = torch.from_numpy(img)[None, None].float() 30 | 31 | out = GF(tch_img, tch_img) 32 | 33 | tv_utils.save_image(out, out_filename, normalize=True) 34 | 35 | def denoising(radius, eps, fast=False, out_filename="result.png"): 36 | img = plt.imread("sample_images/cave-noflash.bmp") 37 | guide = plt.imread("sample_images/cave-flash.bmp") 38 | if fast: 39 | GF = FastGuidedFilter2d(radius, eps, 2) 40 | else: 41 | GF = GuidedFilter2d(radius, eps) 42 | 43 | tch_img = torch.from_numpy(img).permute(2, 0, 1)[None].float() 44 | tch_guide = torch.from_numpy(guide).permute(2, 0, 1)[None].float() 45 | 46 | out = GF(tch_img, tch_guide) 47 | 48 | tv_utils.save_image(out, out_filename, normalize=True) 49 | 50 | if __name__ == "__main__": 51 | import argparse 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument("--task", choices=["transferring", "filtering", "denoising"], type=str) 54 | parser.add_argument("--radius", default=30, type=int) 55 | parser.add_argument("--eps", default=1e-4, type=float) 56 | parser.add_argument("--fast", action="store_true") 57 | parser.add_argument("--output", default="results.png", type=str) 58 | args = parser.parse_args() 59 | 60 | if args.task == "transferring": 61 | structure_transferring(args.radius, args.eps, args.fast, args.output) 62 | elif args.task == "filtering": 63 | filtering(args.radius, args.eps, args.fast, args.output) 64 | else: 65 | denoising(args.radius, args.eps, 
args.fast, args.output) 66 | -------------------------------------------------------------------------------- /guided_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from boxfilter import boxfilter2d 6 | 7 | class GuidedFilter2d(nn.Module): 8 | def __init__(self, radius: int, eps: float): 9 | super().__init__() 10 | self.r = radius 11 | self.eps = eps 12 | 13 | def forward(self, x, guide): 14 | if guide.shape[1] == 3: 15 | return guidedfilter2d_color(guide, x, self.r, self.eps) 16 | elif guide.shape[1] == 1: 17 | return guidedfilter2d_gray(guide, x, self.r, self.eps) 18 | else: 19 | raise NotImplementedError 20 | 21 | class FastGuidedFilter2d(GuidedFilter2d): 22 | """Fast guided filter""" 23 | def __init__(self, radius: int, eps: float, s: int): 24 | super().__init__(radius, eps) 25 | self.s = s 26 | 27 | def forward(self, x, guide): 28 | if guide.shape[1] == 3: 29 | return guidedfilter2d_color(guide, x, self.r, self.eps, self.s) 30 | elif guide.shape[1] == 1: 31 | return guidedfilter2d_gray(guide, x, self.r, self.eps, self.s) 32 | else: 33 | raise NotImplementedError 34 | 35 | def guidedfilter2d_color(guide, src, radius, eps, scale=None): 36 | """guided filter for a color guide image 37 | 38 | Parameters 39 | ----- 40 | guide: (B, 3, H, W)-dim torch.Tensor 41 | guide image 42 | src: (B, C, H, W)-dim torch.Tensor 43 | filtering image 44 | radius: int 45 | filter radius 46 | eps: float 47 | regularization coefficient 48 | """ 49 | assert guide.shape[1] == 3 50 | if src.ndim == 3: 51 | src = src[:, None] 52 | if scale is not None: 53 | guide_sub = guide.clone() 54 | src = F.interpolate(src, scale_factor=1./scale, mode="nearest") 55 | guide = F.interpolate(guide, scale_factor=1./scale, mode="nearest") 56 | radius = radius // scale 57 | 58 | guide_r, guide_g, guide_b = torch.chunk(guide, 3, 1) # b x 1 x H x W 59 | ones = torch.ones_like(guide_r) 60 | N = boxfilter2d(ones, radius) 61 | 62 | mean_I = boxfilter2d(guide, radius) / N # b x 3 x H x W 63 | mean_I_r, mean_I_g, mean_I_b = torch.chunk(mean_I, 3, 1) # b x 1 x H x W 64 | 65 | mean_p = boxfilter2d(src, radius) / N # b x C x H x W 66 | 67 | mean_Ip_r = boxfilter2d(guide_r * src, radius) / N # b x C x H x W 68 | mean_Ip_g = boxfilter2d(guide_g * src, radius) / N # b x C x H x W 69 | mean_Ip_b = boxfilter2d(guide_b * src, radius) / N # b x C x H x W 70 | 71 | cov_Ip_r = mean_Ip_r - mean_I_r * mean_p # b x C x H x W 72 | cov_Ip_g = mean_Ip_g - mean_I_g * mean_p # b x C x H x W 73 | cov_Ip_b = mean_Ip_b - mean_I_b * mean_p # b x C x H x W 74 | 75 | var_I_rr = boxfilter2d(guide_r * guide_r, radius) / N - mean_I_r * mean_I_r + eps # b x 1 x H x W 76 | var_I_rg = boxfilter2d(guide_r * guide_g, radius) / N - mean_I_r * mean_I_g # b x 1 x H x W 77 | var_I_rb = boxfilter2d(guide_r * guide_b, radius) / N - mean_I_r * mean_I_b # b x 1 x H x W 78 | var_I_gg = boxfilter2d(guide_g * guide_g, radius) / N - mean_I_g * mean_I_g + eps # b x 1 x H x W 79 | var_I_gb = boxfilter2d(guide_g * guide_b, radius) / N - mean_I_g * mean_I_b # b x 1 x H x W 80 | var_I_bb = boxfilter2d(guide_b * guide_b, radius) / N - mean_I_b * mean_I_b + eps # b x 1 x H x W 81 | 82 | # determinant 83 | cov_det = var_I_rr * var_I_gg * var_I_bb \ 84 | + var_I_rg * var_I_gb * var_I_rb \ 85 | + var_I_rb * var_I_rg * var_I_gb \ 86 | - var_I_rb * var_I_gg * var_I_rb \ 87 | - var_I_rg * var_I_rg * var_I_bb \ 88 | - var_I_rr * var_I_gb * var_I_gb # b x 1 x H x W 89 
| 90 | # inverse 91 | inv_var_I_rr = (var_I_gg * var_I_bb - var_I_gb * var_I_gb) / cov_det # b x 1 x H x W 92 | inv_var_I_rg = - (var_I_rg * var_I_bb - var_I_rb * var_I_gb) / cov_det # b x 1 x H x W 93 | inv_var_I_rb = (var_I_rg * var_I_gb - var_I_rb * var_I_gg) / cov_det # b x 1 x H x W 94 | inv_var_I_gg = (var_I_rr * var_I_bb - var_I_rb * var_I_rb) / cov_det # b x 1 x H x W 95 | inv_var_I_gb = - (var_I_rr * var_I_gb - var_I_rb * var_I_rg) / cov_det # b x 1 x H x W 96 | inv_var_I_bb = (var_I_rr * var_I_gg - var_I_rg * var_I_rg) / cov_det # b x 1 x H x W 97 | 98 | inv_sigma = torch.stack([ 99 | torch.stack([inv_var_I_rr, inv_var_I_rg, inv_var_I_rb], 1), 100 | torch.stack([inv_var_I_rg, inv_var_I_gg, inv_var_I_gb], 1), 101 | torch.stack([inv_var_I_rb, inv_var_I_gb, inv_var_I_bb], 1) 102 | ], 1).squeeze(-3) # b x 3 x 3 x H x W 103 | 104 | cov_Ip = torch.stack([cov_Ip_r, cov_Ip_g, cov_Ip_b], 1) # b x 3 x C x H x W 105 | 106 | a = torch.einsum("bichw,bijhw->bjchw", (cov_Ip, inv_sigma)) 107 | b = mean_p - a[:, 0] * mean_I_r - a[:, 1] * mean_I_g - a[:, 2] * mean_I_b # b x C x H x W 108 | 109 | mean_a = torch.stack([boxfilter2d(a[:, i], radius) / N for i in range(3)], 1) 110 | mean_b = boxfilter2d(b, radius) / N 111 | 112 | if scale is not None: 113 | guide = guide_sub 114 | mean_a = torch.stack([F.interpolate(mean_a[:, i], guide.shape[-2:], mode='bilinear') for i in range(3)], 1) 115 | mean_b = F.interpolate(mean_b, guide.shape[-2:], mode='bilinear') 116 | 117 | q = torch.einsum("bichw,bihw->bchw", (mean_a, guide)) + mean_b 118 | 119 | return q 120 | 121 | def guidedfilter2d_gray(guide, src, radius, eps, scale=None): 122 | """guided filter for a gray scale guide image 123 | 124 | Parameters 125 | ----- 126 | guide: (B, 1, H, W)-dim torch.Tensor 127 | guide image 128 | src: (B, C, H, W)-dim torch.Tensor 129 | filtering image 130 | radius: int 131 | filter radius 132 | eps: float 133 | regularization coefficient 134 | """ 135 | if guide.ndim == 3: 136 | guide = guide[:, None] 137 | if src.ndim == 3: 138 | src = src[:, None] 139 | 140 | if scale is not None: 141 | guide_sub = guide.clone() 142 | src = F.interpolate(src, scale_factor=1./scale, mode="nearest") 143 | guide = F.interpolate(guide, scale_factor=1./scale, mode="nearest") 144 | radius = radius // scale 145 | 146 | ones = torch.ones_like(guide) 147 | N = boxfilter2d(ones, radius) 148 | 149 | mean_I = boxfilter2d(guide, radius) / N 150 | mean_p = boxfilter2d(src, radius) / N 151 | mean_Ip = boxfilter2d(guide*src, radius) / N 152 | cov_Ip = mean_Ip - mean_I * mean_p 153 | 154 | mean_II = boxfilter2d(guide*guide, radius) / N 155 | var_I = mean_II - mean_I * mean_I 156 | 157 | a = cov_Ip / (var_I + eps) 158 | b = mean_p - a * mean_I 159 | 160 | mean_a = boxfilter2d(a, radius) / N 161 | mean_b = boxfilter2d(b, radius) / N 162 | 163 | if scale is not None: 164 | guide = guide_sub 165 | mean_a = F.interpolate(mean_a, guide.shape[-2:], mode='bilinear') 166 | mean_b = F.interpolate(mean_b, guide.shape[-2:], mode='bilinear') 167 | 168 | q = mean_a * guide + mean_b 169 | return q 170 | -------------------------------------------------------------------------------- /sample_images/cat.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/cat.bmp -------------------------------------------------------------------------------- /sample_images/cave-flash.bmp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/cave-flash.bmp -------------------------------------------------------------------------------- /sample_images/cave-noflash.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/cave-noflash.bmp -------------------------------------------------------------------------------- /sample_images/filtered_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/filtered_cat.png -------------------------------------------------------------------------------- /sample_images/filtered_cave.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/filtered_cave.png -------------------------------------------------------------------------------- /sample_images/filtered_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/filtered_mask.png -------------------------------------------------------------------------------- /sample_images/toy-mask.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/toy-mask.bmp -------------------------------------------------------------------------------- /sample_images/toy.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/toy.bmp -------------------------------------------------------------------------------- /sample_images/tulips.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/guided-filter-pytorch/1acb13f0710d53f88ab427c18ec1951663e06efb/sample_images/tulips.bmp --------------------------------------------------------------------------------
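A note on `boxfilter.py`: `boxfilter2d` computes the (2r+1) x (2r+1) window *sum* via cumulative sums, truncating the window at the image borders (equivalent to summing over a zero-padded image). That is why `guided_filter.py` divides by `N = boxfilter2d(ones, radius)`, the actual window size at each pixel, instead of a constant window area. The sanity check below is not part of the repository; `naive_boxsum` is a hypothetical helper written only for this comparison, and the script assumes it is run from the repository root so that `boxfilter` is importable.

```python
import torch
import torch.nn.functional as F
from boxfilter import boxfilter2d

def naive_boxsum(src, r):
    # brute-force (2r+1) x (2r+1) window sum with zero padding, which matches
    # the truncated-window behaviour of the cumulative-sum implementation
    h, w = src.shape[-2:]
    padded = F.pad(src, (r, r, r, r))  # pad left/right/top/bottom with zeros
    out = torch.zeros_like(src)
    for dy in range(2 * r + 1):
        for dx in range(2 * r + 1):
            out += padded[..., dy:dy + h, dx:dx + w]
    return out

x = torch.rand(2, 3, 32, 32)
print(torch.allclose(boxfilter2d(x, 3), naive_boxsum(x, 3), atol=1e-4))  # True
```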