├── .gitignore ├── README.md ├── __init__.py ├── core.py ├── core_warp.py ├── example ├── butterfly.png ├── noise_input.png └── noise_optimized.png ├── interactive.py ├── test.py ├── test_answer ├── down_down_aa.mat ├── down_down_butterfly_irregular_aa.mat ├── down_down_butterfly_irregular_noaa.mat ├── down_down_irregular_aa.mat ├── down_down_irregular_noaa.mat ├── down_down_noaa.mat ├── down_down_small_aa.mat ├── down_down_small_noaa.mat ├── down_down_x2_aa.mat ├── down_down_x3_aa.mat ├── down_down_x4_aa.mat ├── down_down_x5_aa.mat ├── gen_test.m ├── up_up_bottomright_noaa.mat ├── up_up_butterfly_irregular_noaa.mat ├── up_up_irregular_aa.mat ├── up_up_irregular_noaa.mat └── up_up_topleft_noaa.mat ├── test_benchmark.py ├── test_gradient.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | dataset/ 2 | lab/ 3 | legacy/ 4 | logs/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | env/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *,cover 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | 59 | # Sphinx documentation 60 | docs/_build/ 61 | 62 | # PyBuilder 63 | target/ 64 | 65 | # Editor 66 | .vscode 67 | 68 | # Temp files 69 | *.swp 70 | *.m~ 71 | 72 | # Output images 73 | *.png 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bicubic-pytorch 2 | 3 | Image resizing is one of the essential operations in computer vision and image processing. 4 | While MATLAB's `imresize` function is used as a standard, implementations from other libraries (e.g., PIL, OpenCV, PyTorch, ...) are not consistent with MATLAB's results, especially for a bicubic kernel. 5 | The goal of this repository is to provide a **MATLAB-like bicubic interpolation** in the widely used PyTorch framework. 6 | Issue reports that help improve it are welcome! 7 | 8 | 9 | The merits of our implementation are: 10 | * Easy to use. You only need to copy one Python file. 11 | * Consistent with MATLAB's `imresize('bicubic')`, with or without antialiasing. 12 | * Supports arbitrary resizing factors along each dimension. 13 | * Fast, with support for GPU acceleration and batching. 14 | * Fully differentiable: gradients flow from the output image back to the input. 15 | 16 | ## Updates 17 | 18 | Previous versions had trouble when a fractional `scale` was passed directly (specifying output `sizes` that imply a fractional scale factor worked fine). 19 | Version 1.2.0 fixes the issue and improves the accuracy. 
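To make the fractional-scale case concrete, here is a minimal sketch of how an output size and the per-axis scale factors can be derived. It mirrors the `math.ceil(h * scale)` rule and the `sizes[0] / h` rule used in `core.imresize`, but the helper name `output_geometry` is hypothetical and is not part of this repository.

```python
import math
import typing


def output_geometry(
        h: int,
        w: int,
        scale: typing.Optional[float] = None,
        sizes: typing.Optional[typing.Tuple[int, int]] = None):
    '''
    Hypothetical helper (not in core.py): return ((h_out, w_out), (scale_h, scale_w)).
    Exactly one of scale or sizes must be given, as in core.imresize.
    '''
    if (scale is None) == (sizes is None):
        raise ValueError('Specify exactly one of scale or sizes!')

    if sizes is None:
        # A fractional scale rarely yields an integer size, so round up.
        sizes = (math.ceil(h * scale), math.ceil(w * scale))
        scales = (scale, scale)
    else:
        # Explicit sizes imply a (possibly fractional) scale per axis.
        scales = (sizes[0] / h, sizes[1] / w)

    return sizes, scales


if __name__ == '__main__':
    print(output_geometry(456, 321, scale=0.456))
    print(output_geometry(456, 321, sizes=(123, 456)))
```

Note that the two axes can end up with different effective scale factors when `sizes` is given, which is why each dimension is resized separately.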
20 | 21 | ## Environment and Dependency 22 | 23 | This repository is tested under: 24 | * Ubuntu 18.04 25 | * PyTorch 1.5.1 (minimum 0.4.0 is required) 26 | * CUDA 10.2 27 | * MATLAB R2019b 28 | 29 | However, we avoid using any version-dependent coding style to make our method compatible with various environments. 30 | If you are not going to generate any test cases, MATLAB is not required. 31 | You do not need any additional dependencies to use this repository. 32 | 33 | 34 | ## How to use 35 | 36 | We provide two options to use this package in your project. 37 | The first way is a Git submodule system, which helps you to keep track of important updates. 38 | ```bash 39 | # In your project repository 40 | $ git submodule add https://github.com/thstkdgus35/bicubic_pytorch 41 | 42 | # To get an update 43 | $ cd bicubic_pytorch 44 | $ git pull origin 45 | ``` 46 | 47 | ```python 48 | # In your python code 49 | import torch 50 | from bicubic_pytorch import core 51 | 52 | x = torch.randn(1, 3, 224, 224) 53 | y = core.imresize(x, scale=0.5) 54 | ``` 55 | 56 | Otherwise, copy `core.py` from the repository as follows: 57 | 58 | ```python 59 | import torch 60 | from torch import cuda 61 | import core 62 | 63 | # We support 2, 3, and 4-dim Tensors 64 | # (H x W, C x H x W, and B x C x H x W, respectively). 65 | # Larger batch sizes are also supported. 66 | x = torch.randn(1, 3, 456, 321) 67 | 68 | # If the input is on a CUDA device, all computations will be done using the GPU. 69 | if cuda.is_available(): 70 | x = x.cuda() 71 | 72 | # Resize by scale 73 | x_resized_1 = core.imresize(x, scale=0.456) 74 | 75 | # Resize by resolution (456, 321) -> (123, 456) 76 | x_resized_2 = core.imresize(x, sizes=(123, 456)) 77 | 78 | # Resize without antialiasing (Not compatible with MATLAB) 79 | x_resized_3 = core.imresize(x, scale=0.456, antialiasing=False) 80 | ``` 81 | 82 | 83 | ## How to test 84 | 85 | You can run `test.py` to check the consistency with MATLAB's `imresize`. 
86 | 87 | ```bash 88 | $ python test.py 89 | ``` 90 | 91 | You can generate more test cases using `test_answer/gen_test.m`. 92 | 93 | ```bash 94 | $ cd test_answer 95 | $ matlab -nodisplay < gen_test.m 96 | ``` 97 | 98 | 99 | ## Automatic differentiation 100 | 101 | Our implementation is fully differentiable. 102 | We provide a test script that optimizes a random noise Tensor `n` so that `imresize(n)` matches a target image. 103 | Please run `test_gradient.py` to test the example. 104 | 105 | ```bash 106 | $ python test_gradient.py 107 | ``` 108 | 109 | You can find the input noise in `example/noise_input.png` and the optimized image in `example/noise_optimized.png`. 110 | 111 | ![noise](example/noise_input.png) 112 | ![optimized](example/noise_optimized.png) 113 | ![target](example/butterfly.png) 114 | 115 | From left to right: input noise, optimized image, and target image. 116 | 117 | ## Acknowledgement 118 | 119 | The repositories below have provided excellent insights. 120 | 121 | * [https://github.com/fatheral/matlab_imresize](https://github.com/fatheral/matlab_imresize) 122 | * [https://github.com/sefibk/KernelGAN](https://github.com/sefibk/KernelGAN) 123 | 124 | ## Citation 125 | 126 | If you have found our implementation useful, please star and cite this repository: 127 | ``` 128 | @misc{son2020bicubic, 129 | author = {Son, Sanghyun}, 130 | title = {bicubic-pytorch}, 131 | year = {2020}, 132 | publisher = {GitHub}, 133 | journal = {GitHub repository}, 134 | howpublished = {\url{https://github.com/thstkdgus35/bicubic-pytorch}}, 135 | } 136 | ``` 137 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['core', 'core_warp'] 2 | -------------------------------------------------------------------------------- /core.py: -------------------------------------------------------------------------------- 1 | ''' 2 | A standalone PyTorch 
implementation for fast and efficient bicubic resampling. 3 | The resulting values are the same as those of the MATLAB function imresize('bicubic'). 4 | 5 | ## Author: Sanghyun Son 6 | ## Email: sonsang35@gmail.com (primary), thstkdgus35@snu.ac.kr (secondary) 7 | ## Version: 1.2.0 8 | ## Last update: July 9th, 2020 (KST) 9 | 10 | Dependency: torch 11 | 12 | Example:: 13 | >>> import torch 14 | >>> import core 15 | >>> x = torch.arange(16).float().view(1, 1, 4, 4) 16 | >>> y = core.imresize(x, sizes=(3, 3)) 17 | >>> print(y) 18 | tensor([[[[ 0.7506, 2.1004, 3.4503], 19 | [ 6.1505, 7.5000, 8.8499], 20 | [11.5497, 12.8996, 14.2494]]]]) 21 | ''' 22 | 23 | import math 24 | import typing 25 | 26 | import torch 27 | from torch.nn import functional as F 28 | 29 | __all__ = ['imresize'] 30 | 31 | _I = typing.Optional[int] 32 | _D = typing.Optional[torch.dtype] 33 | 34 | def nearest_contribution(x: torch.Tensor) -> torch.Tensor: 35 | range_around_0 = torch.logical_and(x.gt(-0.5), x.le(0.5)) 36 | cont = range_around_0.to(dtype=x.dtype) 37 | return cont 38 | 39 | def linear_contribution(x: torch.Tensor) -> torch.Tensor: 40 | ax = x.abs() 41 | range_01 = ax.le(1) 42 | cont = (1 - ax) * range_01.to(dtype=x.dtype) 43 | return cont 44 | 45 | def cubic_contribution(x: torch.Tensor, a: float=-0.5) -> torch.Tensor: 46 | ax = x.abs() 47 | ax2 = ax * ax 48 | ax3 = ax * ax2 49 | 50 | range_01 = ax.le(1) 51 | range_12 = torch.logical_and(ax.gt(1), ax.le(2)) 52 | 53 | cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1 54 | cont_01 = cont_01 * range_01.to(dtype=x.dtype) 55 | 56 | cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a) 57 | cont_12 = cont_12 * range_12.to(dtype=x.dtype) 58 | 59 | cont = cont_01 + cont_12 60 | return cont 61 | 62 | def gaussian_contribution(x: torch.Tensor, sigma: float=2.0) -> torch.Tensor: 63 | range_3sigma = (x.abs() <= 3 * sigma + 1) 64 | # Normalization will be done after 65 | cont = torch.exp(-x.pow(2) / (2 * sigma**2)) 66 | cont = cont * 
range_3sigma.to(dtype=x.dtype) 67 | return cont 68 | 69 | def discrete_kernel( 70 | kernel: str, scale: float, antialiasing: bool=True) -> torch.Tensor: 71 | 72 | ''' 73 | For downsampling with integer scale only. 74 | ''' 75 | downsampling_factor = int(1 / scale) 76 | if kernel == 'cubic': 77 | kernel_size_orig = 4 78 | else: 79 | raise ValueError('Pass!') 80 | 81 | if antialiasing: 82 | kernel_size = kernel_size_orig * downsampling_factor 83 | else: 84 | kernel_size = kernel_size_orig 85 | 86 | if downsampling_factor % 2 == 0: 87 | a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size)) 88 | else: 89 | kernel_size -= 1 90 | a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1)) 91 | 92 | with torch.no_grad(): 93 | r = torch.linspace(-a, a, steps=kernel_size) 94 | k = cubic_contribution(r).view(-1, 1) 95 | k = torch.matmul(k, k.t()) 96 | k /= k.sum() 97 | 98 | return k 99 | 100 | def reflect_padding( 101 | x: torch.Tensor, 102 | dim: int, 103 | pad_pre: int, 104 | pad_post: int) -> torch.Tensor: 105 | 106 | ''' 107 | Apply reflect padding to the given Tensor. 108 | Note that it is slightly different from the PyTorch functional.pad, 109 | where boundary elements are used only once. 110 | Instead, we follow the MATLAB implementation 111 | which uses boundary elements twice. 112 | 113 | For example, 114 | [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation, 115 | while our implementation yields [a, a, b, c, d, d]. 
116 | ''' 117 | b, c, h, w = x.size() 118 | if dim == 2 or dim == -2: 119 | padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w) 120 | padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x) 121 | for p in range(pad_pre): 122 | padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :]) 123 | for p in range(pad_post): 124 | padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :]) 125 | else: 126 | padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post) 127 | padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x) 128 | for p in range(pad_pre): 129 | padding_buffer[..., pad_pre - p - 1].copy_(x[..., p]) 130 | for p in range(pad_post): 131 | padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)]) 132 | 133 | return padding_buffer 134 | 135 | def padding( 136 | x: torch.Tensor, 137 | dim: int, 138 | pad_pre: int, 139 | pad_post: int, 140 | padding_type: typing.Optional[str]='reflect') -> torch.Tensor: 141 | 142 | if padding_type is None: 143 | return x 144 | elif padding_type == 'reflect': 145 | x_pad = reflect_padding(x, dim, pad_pre, pad_post) 146 | else: 147 | raise ValueError('{} padding is not supported!'.format(padding_type)) 148 | 149 | return x_pad 150 | 151 | def get_padding( 152 | base: torch.Tensor, 153 | kernel_size: int, 154 | x_size: int) -> typing.Tuple[int, int, torch.Tensor]: 155 | 156 | base = base.long() 157 | r_min = base.min() 158 | r_max = base.max() + kernel_size - 1 159 | 160 | if r_min <= 0: 161 | pad_pre = -r_min 162 | pad_pre = pad_pre.item() 163 | base += pad_pre 164 | else: 165 | pad_pre = 0 166 | 167 | if r_max >= x_size: 168 | pad_post = r_max - x_size + 1 169 | pad_post = pad_post.item() 170 | else: 171 | pad_post = 0 172 | 173 | return pad_pre, pad_post, base 174 | 175 | def get_weight( 176 | dist: torch.Tensor, 177 | kernel_size: int, 178 | kernel: str='cubic', 179 | sigma: float=2.0, 180 | antialiasing_factor: float=1) -> torch.Tensor: 181 | 182 | buffer_pos = dist.new_zeros(kernel_size, len(dist)) 183 | for 
idx, buffer_sub in enumerate(buffer_pos): 184 | buffer_sub.copy_(dist - idx) 185 | 186 | # Expand (downsampling) / Shrink (upsampling) the receptive field. 187 | buffer_pos *= antialiasing_factor 188 | if kernel == 'cubic': 189 | weight = cubic_contribution(buffer_pos) 190 | elif kernel == 'gaussian': 191 | weight = gaussian_contribution(buffer_pos, sigma=sigma) 192 | else: 193 | raise ValueError('{} kernel is not supported!'.format(kernel)) 194 | 195 | weight /= weight.sum(dim=0, keepdim=True) 196 | return weight 197 | 198 | def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor: 199 | # Resize height 200 | if dim == 2 or dim == -2: 201 | k = (kernel_size, 1) 202 | h_out = x.size(-2) - kernel_size + 1 203 | w_out = x.size(-1) 204 | # Resize width 205 | else: 206 | k = (1, kernel_size) 207 | h_out = x.size(-2) 208 | w_out = x.size(-1) - kernel_size + 1 209 | 210 | unfold = F.unfold(x, k) 211 | unfold = unfold.view(unfold.size(0), -1, h_out, w_out) 212 | return unfold 213 | 214 | def reshape_input( 215 | x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, _I, _I]: 216 | 217 | if x.dim() == 4: 218 | b, c, h, w = x.size() 219 | elif x.dim() == 3: 220 | c, h, w = x.size() 221 | b = None 222 | elif x.dim() == 2: 223 | h, w = x.size() 224 | b = c = None 225 | else: 226 | raise ValueError('{}-dim Tensor is not supported!'.format(x.dim())) 227 | 228 | x = x.view(-1, 1, h, w) 229 | return x, b, c, h, w 230 | 231 | def reshape_output( 232 | x: torch.Tensor, b: _I, c: _I) -> torch.Tensor: 233 | 234 | rh = x.size(-2) 235 | rw = x.size(-1) 236 | # Back to the original dimension 237 | if b is not None: 238 | x = x.view(b, c, rh, rw) # 4-dim 239 | else: 240 | if c is not None: 241 | x = x.view(c, rh, rw) # 3-dim 242 | else: 243 | x = x.view(rh, rw) # 2-dim 244 | 245 | return x 246 | 247 | def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]: 248 | if x.dtype != torch.float32 and x.dtype != torch.float64: 249 | dtype = x.dtype 250 | x = 
x.float() 251 | else: 252 | dtype = None 253 | 254 | return x, dtype 255 | 256 | def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor: 257 | if dtype is not None: 258 | if not dtype.is_floating_point: 259 | x = x.round() 260 | # To prevent over/underflow when converting types 261 | if dtype is torch.uint8: 262 | x = x.clamp(0, 255) 263 | 264 | x = x.to(dtype=dtype) 265 | 266 | return x 267 | 268 | def resize_1d( 269 | x: torch.Tensor, 270 | dim: int, 271 | size: typing.Optional[int], 272 | scale: typing.Optional[float], 273 | kernel: str='cubic', 274 | sigma: float=2.0, 275 | padding_type: str='reflect', 276 | antialiasing: bool=True) -> torch.Tensor: 277 | 278 | ''' 279 | Args: 280 | x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W). 281 | dim (int): Dimension to resize (-2 for height, -1 for width). 282 | scale (float): Scaling factor along dim. 283 | size (int): Output size along dim. 284 | 285 | Return: 286 | ''' 287 | # Identity case 288 | if scale == 1: 289 | return x 290 | 291 | # Default bicubic kernel with antialiasing (only when downsampling) 292 | if kernel == 'cubic': 293 | kernel_size = 4 294 | else: 295 | kernel_size = math.floor(6 * sigma) 296 | 297 | if antialiasing and (scale < 1): 298 | antialiasing_factor = scale 299 | kernel_size = math.ceil(kernel_size / antialiasing_factor) 300 | else: 301 | antialiasing_factor = 1 302 | 303 | # We allow margin to both sides 304 | kernel_size += 2 305 | 306 | # Weights only depend on the shape of input and output, 307 | # so we do not calculate gradients here. 
308 | with torch.no_grad(): 309 | pos = torch.linspace( 310 | 0, size - 1, steps=size, dtype=x.dtype, device=x.device, 311 | ) 312 | pos = (pos + 0.5) / scale - 0.5 313 | base = pos.floor() - (kernel_size // 2) + 1 314 | dist = pos - base 315 | weight = get_weight( 316 | dist, 317 | kernel_size, 318 | kernel=kernel, 319 | sigma=sigma, 320 | antialiasing_factor=antialiasing_factor, 321 | ) 322 | pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim)) 323 | 324 | # To backpropagate through x 325 | x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type) 326 | unfold = reshape_tensor(x_pad, dim, kernel_size) 327 | # Subsampling first 328 | if dim == 2 or dim == -2: 329 | sample = unfold[..., base, :] 330 | weight = weight.view(1, kernel_size, sample.size(2), 1) 331 | else: 332 | sample = unfold[..., base] 333 | weight = weight.view(1, kernel_size, 1, sample.size(3)) 334 | 335 | # Apply the kernel 336 | x = sample * weight 337 | x = x.sum(dim=1, keepdim=True) 338 | return x 339 | 340 | def downsampling_2d( 341 | x: torch.Tensor, 342 | k: torch.Tensor, 343 | scale: int, 344 | padding_type: str='reflect') -> torch.Tensor: 345 | 346 | c = x.size(1) 347 | k_h = k.size(-2) 348 | k_w = k.size(-1) 349 | 350 | k = k.to(dtype=x.dtype, device=x.device) 351 | k = k.view(1, 1, k_h, k_w) 352 | k = k.repeat(c, c, 1, 1) 353 | e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False) 354 | e = e.view(c, c, 1, 1) 355 | k = k * e 356 | 357 | pad_h = (k_h - scale) // 2 358 | pad_w = (k_w - scale) // 2 359 | x = padding(x, -2, pad_h, pad_h, padding_type=padding_type) 360 | x = padding(x, -1, pad_w, pad_w, padding_type=padding_type) 361 | y = F.conv2d(x, k, padding=0, stride=scale) 362 | return y 363 | 364 | def imresize( 365 | x: torch.Tensor, 366 | scale: typing.Optional[float]=None, 367 | sizes: typing.Optional[typing.Tuple[int, int]]=None, 368 | kernel: typing.Union[str, torch.Tensor]='cubic', 369 | sigma: float=2, 370 | rotation_degree: float=0, 
371 | padding_type: str='reflect', 372 | antialiasing: bool=True) -> torch.Tensor: 373 | 374 | ''' 375 | Args: 376 | x (torch.Tensor): 377 | scale (float): 378 | sizes (tuple(int, int)): 379 | kernel (str, default='cubic'): 380 | sigma (float, default=2): 381 | rotation_degree (float, default=0): 382 | padding_type (str, default='reflect'): 383 | antialiasing (bool, default=True): 384 | 385 | Return: 386 | torch.Tensor: 387 | ''' 388 | 389 | if scale is None and sizes is None: 390 | raise ValueError('One of scale or sizes must be specified!') 391 | if scale is not None and sizes is not None: 392 | raise ValueError('Please specify scale or sizes to avoid conflict!') 393 | 394 | x, b, c, h, w = reshape_input(x) 395 | 396 | if sizes is None: 397 | ''' 398 | # Check if we can apply the convolution algorithm 399 | scale_inv = 1 / scale 400 | if isinstance(kernel, str) and scale_inv.is_integer(): 401 | kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing) 402 | elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer(): 403 | raise ValueError( 404 | 'An integer downsampling factor ' 405 | 'should be used with a predefined kernel!' 
406 | ) 407 | ''' 408 | # Determine output size 409 | sizes = (math.ceil(h * scale), math.ceil(w * scale)) 410 | scales = (scale, scale) 411 | 412 | if scale is None: 413 | scales = (sizes[0] / h, sizes[1] / w) 414 | 415 | x, dtype = cast_input(x) 416 | 417 | if isinstance(kernel, str): 418 | # Shared keyword arguments across dimensions 419 | kwargs = { 420 | 'kernel': kernel, 421 | 'sigma': sigma, 422 | 'padding_type': padding_type, 423 | 'antialiasing': antialiasing, 424 | } 425 | # Core resizing module 426 | x = resize_1d(x, -2, size=sizes[0], scale=scales[0], **kwargs) 427 | x = resize_1d(x, -1, size=sizes[1], scale=scales[1], **kwargs) 428 | elif isinstance(kernel, torch.Tensor): 429 | x = downsampling_2d(x, kernel, scale=int(1 / scale)) 430 | 431 | x = reshape_output(x, b, c) 432 | x = cast_output(x, dtype) 433 | return x 434 | 435 | if __name__ == '__main__': 436 | # Just for debugging 437 | torch.set_printoptions(precision=4, sci_mode=False, edgeitems=16, linewidth=200) 438 | a = torch.arange(64).float().view(1, 1, 8, 8) 439 | z = imresize(a, 0.5) 440 | print(z) 441 | #a = torch.arange(16).float().view(1, 1, 4, 4) 442 | ''' 443 | a = torch.zeros(1, 1, 4, 4) 444 | a[..., 0, 0] = 100 445 | a[..., 1, 0] = 10 446 | a[..., 0, 1] = 1 447 | a[..., 0, -1] = 100 448 | a = torch.zeros(1, 1, 4, 4) 449 | a[..., -1, -1] = 100 450 | a[..., -2, -1] = 10 451 | a[..., -1, -2] = 1 452 | a[..., -1, 0] = 100 453 | ''' 454 | #b = imresize(a, sizes=(3, 8), antialiasing=False) 455 | #c = imresize(a, sizes=(11, 13), antialiasing=True) 456 | #c = imresize(a, sizes=(4, 4), antialiasing=False, kernel='gaussian', sigma=1) 457 | #print(a) 458 | #print(b) 459 | #print(c) 460 | 461 | #r = discrete_kernel('cubic', 1 / 3) 462 | #print(r) 463 | ''' 464 | a = torch.arange(225).float().view(1, 1, 15, 15) 465 | imresize(a, sizes=[5, 5]) 466 | ''' 467 | -------------------------------------------------------------------------------- /core_warp.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import typing 3 | 4 | import core 5 | 6 | import cv2 7 | 8 | import torch 9 | from torch.nn import functional as F 10 | 11 | def contribution_2d(x: torch.Tensor, kernel: str='cubic') -> torch.Tensor: 12 | ''' 13 | Args: 14 | x (torch.Tensor): (2, k, N), where x[0] is the x-coordinate. 15 | kernel (str): 16 | 17 | Return 18 | torch.Tensor: (k^2, N) 19 | ''' 20 | if kernel == 'nearest': 21 | weight = core.nearest_contribution(x) 22 | elif kernel == 'bilinear': 23 | weight = core.linear_contribution(x) 24 | elif kernel == 'bicubic': 25 | weight = core.cubic_contribution(x) 26 | 27 | weight_x = weight[0].unsqueeze(0) 28 | weight_y = weight[1].unsqueeze(1) 29 | weight = weight_x * weight_y 30 | weight = weight.view(-1, weight.size(-1)) 31 | weight = weight / weight.sum(0, keepdim=True) 32 | return weight 33 | 34 | def warp_by_size( 35 | x: torch.Tensor, 36 | m: torch.Tensor, 37 | sizes: typing.Tuple[int, int], 38 | kernel: str='bicubic', 39 | padding_type: str='reflect', 40 | fill_value: int=0) -> torch.Tensor: 41 | 42 | kernels = {'nearest': 1, 'bilinear': 2, 'bicubic': 4} 43 | if kernel in kernels: 44 | k = kernels[kernel] 45 | pad = k // 2 46 | else: 47 | raise ValueError('kernel: {} is not supported!'.format(kernel)) 48 | 49 | dkwargs = {'device': x.device, 'requires_grad': False} 50 | # Construct the target coordinates 51 | # The target coordinates do not require gradients 52 | pos = torch.arange(sizes[0] * sizes[1], **dkwargs) 53 | pos_i = (pos // sizes[1]).float() 54 | pos_j = (pos % sizes[1]).float() 55 | # Map the target coordinates to the source coordinates 56 | # This implements the backward warping 57 | pos_tar = torch.stack([pos_j, pos_i, torch.ones_like(pos_i)], dim=0) 58 | pos_src = torch.matmul(m.inverse(), pos_tar) 59 | pos_src = pos_src[:2] / pos_src[-1, :] 60 | # Out of the image 61 | pos_bound = pos_src.new_tensor([x.size(-1), x.size(-2)]) - 0.5 62 | 
pos_bound.unsqueeze_(-1) 63 | pos_in = torch.logical_and(pos_src.ge(-0.5), pos_src.lt(pos_bound)) 64 | pos_in = pos_in.all(0) 65 | # Remove the outside region and compensate subpixel shift 66 | sub = (k % 2) / 2 67 | pos_src = pos_src[..., pos_in] 68 | pos_src_sub = pos_src - sub 69 | pos_discrete = pos_src_sub.ceil().long() 70 | pos_frac = pos_src_sub - pos_src.floor() 71 | pos_frac.unsqueeze_(1) 72 | # (2, 1, HW) 73 | pos_w = torch.linspace(pad - k + 1, pad, k, **dkwargs) 74 | pos_w = pos_w.view(1, -1, 1).repeat(2, 1, 1) 75 | pos_w = pos_frac - pos_w 76 | weight = contribution_2d(pos_w, kernel=kernel) 77 | weight.unsqueeze_(0) 78 | 79 | # Calculate the exact sampling point 80 | idx = pos_discrete[0] + (x.size(-1) + 1 - k % 2) * pos_discrete[1] 81 | 82 | # (B, k^2, HW) 83 | x = core.padding(x, -2, pad, pad, padding_type=padding_type) 84 | x = core.padding(x, -1, pad, pad, padding_type=padding_type) 85 | x = F.unfold(x, (k, k)) 86 | sample = x[..., idx] 87 | 88 | y = sample * weight 89 | y = y.sum(dim=1) 90 | out = y.new_full((y.size(0), pos_in.size(0)), fill_value) 91 | out.masked_scatter_(pos_in, y) 92 | out = out.view(-1, 1, *sizes) 93 | return out 94 | 95 | def warp( 96 | x: torch.Tensor, 97 | m: torch.Tensor, 98 | sizes: typing.Union[typing.Tuple[int, int], str, None]=None, 99 | kernel: str='bicubic', 100 | padding_type: str='reflect', 101 | fill_value: int=0) -> torch.Tensor: 102 | 103 | x, b, c, h, w = core.reshape_input(x) 104 | x, dtype = core.cast_input(x) 105 | m = m.to(x.device) 106 | 107 | if sizes is None: 108 | sizes = (h, w) 109 | elif isinstance(sizes, str) and sizes == 'auto': 110 | with torch.no_grad(): 111 | corners = m.new_tensor([ 112 | [-0.5, -0.5, w - 0.5, w - 0.5], 113 | [-0.5, h - 0.5, -0.5, h - 0.5], 114 | [1, 1, 1, 1], 115 | ]) 116 | corners = torch.matmul(m, corners) 117 | corners = corners / corners[-1, :] 118 | y_min = corners[1].min() + 0.5 119 | x_min = corners[0].min() + 0.5 120 | h_new = math.floor(corners[1].max() - y_min + 0.5) 
121 | w_new = math.floor(corners[0].max() - x_min + 0.5) 122 | m_comp = m.new_tensor([[1, 0, -x_min], [0, 1, -y_min], [0, 0, 1]]) 123 | m = torch.matmul(m_comp, m) 124 | sizes = (h_new, w_new) 125 | 126 | elif not isinstance(sizes, tuple): 127 | raise ValueError('sizes:', sizes, 'is not supported!') 128 | 129 | x = warp_by_size( 130 | x, 131 | m, 132 | sizes, 133 | kernel=kernel, 134 | padding_type=padding_type, 135 | fill_value=fill_value, 136 | ) 137 | x = core.reshape_output(x, b, c) 138 | x = core.cast_output(x, dtype) 139 | return x 140 | 141 | 142 | if __name__ == '__main__': 143 | import os 144 | import utils 145 | torch.set_printoptions(precision=4, sci_mode=False, edgeitems=16, linewidth=200) 146 | #x = torch.arange(64).float().view(1, 1, 8, 8) 147 | x = torch.arange(16).float().view(1, 1, 4, 4) 148 | #x = utils.get_img('example/butterfly.png') 149 | #x.requires_grad = True 150 | #m = torch.Tensor([[3.2, 0.016, -68], [1.23, 1.7, -54], [0.008, 0.0001, 1]]) 151 | #m = torch.Tensor([[2.33e-01, 3.97e-3, 3], [-4.49e-1, 2.49e-1, 1.15e2], [-2.95e-3, 1.55e-5, 1]]) 152 | m = torch.Tensor([[2, 0, 0], [0, 2, 0], [0, 0, 1]]) 153 | y = warp(x, m, sizes='auto', kernel='bicubic', fill_value=-1) 154 | z = core.imresize(x, scale=2, kernel='cubic') 155 | print(y) 156 | print(z) 157 | #os.makedirs('dummy', exist_ok=True) 158 | #utils.save_img(y, 'dummy/warp.png') 159 | -------------------------------------------------------------------------------- /example/butterfly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/example/butterfly.png -------------------------------------------------------------------------------- /example/noise_input.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/example/noise_input.png -------------------------------------------------------------------------------- /example/noise_optimized.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/example/noise_optimized.png -------------------------------------------------------------------------------- /interactive.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | import argparse 4 | import typing 5 | 6 | import numpy as np 7 | import imageio 8 | from PIL import Image 9 | 10 | import cv2 11 | import torch 12 | 13 | import core_warp 14 | import utils 15 | 16 | from PyQt5.QtWidgets import QApplication 17 | from PyQt5.QtWidgets import QMainWindow 18 | from PyQt5.QtGui import QImage 19 | from PyQt5.QtGui import QPixmap 20 | from PyQt5.QtGui import QPainter 21 | from PyQt5.QtGui import QPen 22 | from PyQt5.QtGui import QBrush 23 | from PyQt5.QtCore import Qt 24 | 25 | 26 | class Interactive(QMainWindow): 27 | 28 | def __init__(self, app: QApplication, img: str) -> None: 29 | super().__init__() 30 | self.setStyleSheet('background-color: gray;') 31 | self.margin = 300 32 | img = Image.open(img) 33 | #w, h = img.size 34 | #img = img.resize((2 * w, 2 * h), Image.NEAREST) 35 | self.img = np.array(img) 36 | self.img_tensor = utils.np2tensor(self.img).cuda() 37 | self.img_h = self.img.shape[0] 38 | self.img_w = self.img.shape[1] 39 | 40 | self.offset_h = self.margin 41 | self.offset_w = self.img_w + 2 * self.margin 42 | 43 | window_h = self.img_h + 2 * self.margin 44 | window_w = 2 * self.img_w + 3 * self.margin 45 | 46 | monitor_resolution = app.desktop().screenGeometry() 47 | screen_h = monitor_resolution.height() 48 | screen_w = monitor_resolution.width() 49 | 50 | screen_offset_h 
= (screen_h - window_h) // 2 51 | screen_offset_w = (screen_w - window_w) // 2 52 | 53 | self.setGeometry(screen_offset_w, screen_offset_h, window_w, window_h) 54 | self.reset_cps() 55 | self.line_order = ('tl', 'tr', 'br', 'bl') 56 | self.grab = None 57 | 58 | self.inter = cv2.INTER_CUBIC 59 | self.inter_idx = 2 60 | self.backend = 'opencv' 61 | self.update() 62 | return 63 | 64 | def reset_cps(self) -> None: 65 | self.cps = { 66 | 'tl': (0, 0), 67 | 'tr': (0, self.img_w - 1), 68 | 'bl': (self.img_h - 1, 0), 69 | 'br': (self.img_h - 1, self.img_w - 1), 70 | } 71 | return 72 | 73 | def keyPressEvent(self, e) -> None: 74 | if e.key() == Qt.Key_Escape: 75 | self.close() 76 | 77 | if e.key() == Qt.Key_I: 78 | self.inter_idx = (self.inter_idx + 1) % 3 79 | if self.inter_idx == 0: 80 | self.inter = cv2.INTER_NEAREST 81 | elif self.inter_idx == 1: 82 | self.inter = cv2.INTER_LINEAR 83 | else: 84 | self.inter = cv2.INTER_CUBIC 85 | elif e.key() == Qt.Key_M: 86 | if self.backend == 'opencv': 87 | self.backend = 'core' 88 | elif self.backend == 'core': 89 | self.backend = 'opencv' 90 | elif e.key() == Qt.Key_R: 91 | self.reset_cps() 92 | 93 | self.update() 94 | return 95 | 96 | def mousePressEvent(self, e) -> None: 97 | is_left = e.buttons() & Qt.LeftButton 98 | if is_left: 99 | threshold = 20 100 | min_dist = 987654321 101 | for key, val in self.cps.items(): 102 | y, x = val 103 | dy = e.y() - y - self.offset_h 104 | dx = e.x() - x - self.offset_w 105 | dist = dy ** 2 + dx ** 2 106 | if dist < min_dist: 107 | min_dist = dist 108 | self.grab = key 109 | 110 | if min_dist > threshold ** 2: 111 | self.grab = None 112 | 113 | return 114 | 115 | def get_matrix(self) -> np.array: 116 | points_from = np.array([ 117 | [0, 0], 118 | [self.img_w - 1, 0], 119 | [0, self.img_h - 1], 120 | [self.img_w - 1, self.img_h - 1], 121 | ]).astype(np.float32) 122 | points_to = np.array([ 123 | [self.cps['tl'][1], self.cps['tl'][0]], 124 | [self.cps['tr'][1], self.cps['tr'][0]], 125 | 
[self.cps['bl'][1], self.cps['bl'][0]], 126 | [self.cps['br'][1], self.cps['br'][0]], 127 | ]).astype(np.float32) 128 | m = cv2.getPerspectiveTransform(points_from, points_to) 129 | return m 130 | 131 | def get_dimension( 132 | self, 133 | m: np.array) -> typing.Tuple[float, float, float, float]: 134 | 135 | ''' 136 | 137 | ''' 138 | 139 | ''' 140 | What is a difference between corners and corner_points? 141 | corners: 142 | Actual corners of a rectangular image. 143 | Determine the image size. 144 | corner_points: 145 | The point coordinates. 146 | Determine the pixel position. 147 | ''' 148 | corners = np.array([ 149 | [-0.5, -0.5, self.img_w - 0.5, self.img_w - 0.5], 150 | [-0.5, self.img_h - 0.5, -0.5, self.img_h - 0.5], 151 | [1, 1, 1, 1], 152 | ]) 153 | corners = np.matmul(m, corners) 154 | corners /= corners[-1, :] 155 | y_min = corners[1].min() + 0.5 156 | x_min = corners[0].min() + 0.5 157 | h_new = math.floor(corners[1].max() - y_min + 0.5) 158 | w_new = math.floor(corners[0].max() - x_min + 0.5) 159 | ''' 160 | corner_points = np.array([ 161 | [0, 0, self.img_w - 1, self.img_w - 1], 162 | [0, self.img_h - 1, 0, self.img_h - 1], 163 | [1, 1, 1, 1], 164 | ]) 165 | corner_points = np.matmul(m, corner_points) 166 | corner_points /= corner_points[-1, :] 167 | y_min = corner_points[1].min() 168 | x_min = corner_points[0].min() 169 | h_new = math.floor(corner_points[1].max() - y_min) 170 | w_new = math.floor(corner_points[0].max() - x_min) 171 | ''' 172 | return y_min, x_min, h_new, w_new 173 | 174 | def mouseMoveEvent(self, e) -> None: 175 | if self.grab is not None: 176 | y_old, x_old = self.cps[self.grab] 177 | y_new = e.y() - self.offset_h 178 | x_new = e.x() - self.offset_w 179 | self.cps[self.grab] = (y_new, x_new) 180 | 181 | is_convex = True 182 | #cross = None 183 | for i, pos in enumerate(self.line_order): 184 | y1, x1 = self.cps[pos] 185 | y2, x2 = self.cps[self.line_order[(i + 1) % 4]] 186 | y3, x3 = self.cps[self.line_order[(i + 2) % 4]] 187 | dx1 
= x2 - x1 188 | dy1 = y2 - y1 189 | dx2 = x3 - x2 190 | dy2 = y3 - y2 191 | cross_new = dx1 * dy2 - dy1 * dx2 192 | if cross_new < 6000: 193 | is_convex = False 194 | break 195 | 196 | if not is_convex: 197 | self.cps[self.grab] = (y_old, x_old) 198 | 199 | self.update() 200 | return 201 | 202 | def mouseReleaseEvent(self, e) -> None: 203 | if self.grab is not None: 204 | self.grab = None 205 | 206 | return 207 | 208 | def paintEvent(self, e) -> None: 209 | if self.inter == cv2.INTER_NEAREST: 210 | inter_method = 'Nearest' 211 | elif self.inter == cv2.INTER_LINEAR: 212 | inter_method = 'Bilinear' 213 | elif self.inter == cv2.INTER_CUBIC: 214 | inter_method = 'Bicubic' 215 | 216 | self.setWindowTitle( 217 | 'Interpolation: {} / backend: {}'.format(inter_method, self.backend) 218 | ) 219 | 220 | qimg = QImage( 221 | self.img, 222 | self.img_w, 223 | self.img_h, 224 | 3 * self.img_w, 225 | QImage.Format_RGB888, 226 | ) 227 | qpix = QPixmap(qimg) 228 | 229 | qp = QPainter() 230 | qp.begin(self) 231 | qp.drawPixmap(self.margin, self.margin, self.img_w, self.img_h, qpix) 232 | 233 | m = self.get_matrix() 234 | y_min, x_min, h_new, w_new = self.get_dimension(m) 235 | mc = np.array([[1, 0, -x_min], [0, 1, -y_min], [0, 0, 1]]) 236 | m = np.matmul(mc, m) 237 | if self.backend == 'opencv': 238 | warp = cv2.warpPerspective( 239 | self.img, m, (w_new, h_new), flags=self.inter, 240 | ) 241 | elif self.backend == 'core': 242 | warp = core_warp.warp( 243 | self.img_tensor, 244 | torch.Tensor(m), 245 | sizes=(h_new, w_new), 246 | kernel=inter_method.lower(), 247 | fill_value=0.5 248 | ) 249 | warp = utils.tensor2np(warp) 250 | 251 | qimg_warp = QImage(warp, w_new, h_new, 3 * w_new, QImage.Format_RGB888) 252 | qpix_warp = QPixmap(qimg_warp) 253 | qp.drawPixmap( 254 | self.offset_w + x_min, 255 | self.offset_h + y_min, 256 | w_new, 257 | h_new, 258 | qpix_warp, 259 | ) 260 | ''' 261 | for i, pos in enumerate(self.line_order): 262 | j = (i + 1) % 4 263 | y, x = self.cps[pos] 264 | y = 
y + self.offset_h 265 | x = x + self.offset_w 266 | y_next, x_next = self.cps[self.line_order[j]] 267 | y_next = y_next + self.offset_h 268 | x_next = x_next + self.offset_w 269 | qp.drawLine(x, y, x_next, y_next) 270 | ''' 271 | center_y = self.offset_h + self.img_h // 2 272 | center_x = self.offset_w + self.img_w // 2 273 | 274 | pen_blue = QPen(Qt.blue, 5) 275 | pen_white = QPen(Qt.white, 10) 276 | text_size = 20 277 | #brush = QBrush(Qt.red, Qt.SolidPattern) 278 | #qp.setBrush(brush) 279 | for key, val in self.cps.items(): 280 | y, x = val 281 | y = y + self.offset_h 282 | x = x + self.offset_w 283 | qp.setPen(pen_blue) 284 | #qp.drawEllipse(x, y, 3, 3) 285 | qp.drawPoint(x, y) 286 | qp.setPen(pen_white) 287 | dy = y - center_y 288 | dx = x - center_x 289 | dl = math.sqrt(dy ** 2 + dx ** 2) / 10 290 | qp.drawText( 291 | x + (dx / dl) - text_size // 2, 292 | y + (dy / dl) - text_size // 2, 293 | text_size, 294 | text_size, 295 | int(Qt.AlignCenter), 296 | key, 297 | ) 298 | 299 | qp.end() 300 | return 301 | 302 | 303 | def main() -> None: 304 | parser = argparse.ArgumentParser() 305 | parser.add_argument('--img', type=str, default='example/butterfly_corners.png') 306 | parser.add_argument('--full', action='store_true') 307 | cfg = parser.parse_args() 308 | 309 | app = QApplication(sys.argv) 310 | sess = Interactive(app, cfg.img) 311 | 312 | if cfg.full: 313 | sess.showFullScreen() 314 | else: 315 | sess.show() 316 | 317 | sys.exit(app.exec_()) 318 | 319 | if __name__ == '__main__': 320 | main() -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import unittest 3 | from scipy import io 4 | 5 | import core 6 | import utils 7 | 8 | import torch 9 | from torch import cuda 10 | 11 | 12 | class TestBicubic(unittest.TestCase): 13 | ''' 14 | Why do we have to split CUDA? 
15 | ''' 16 | 17 | def __init__(self, *args, **kwargs) -> None: 18 | super().__init__(*args, **kwargs) 19 | self.input_small = torch.arange(16).view(1, 1, 4, 4).float() 20 | self.input_square = torch.arange(64).view(1, 1, 8, 8).float() 21 | self.input_rect = torch.arange(80).view(1, 1, 8, 10).float() 22 | self.input_15x15 = torch.arange(225).view(1, 1, 15, 15).float() 23 | 24 | self.input_topleft = torch.zeros(1, 1, 4, 4).float() 25 | self.input_topleft[..., 0, 0] = 100 26 | self.input_topleft[..., 1, 0] = 10 27 | self.input_topleft[..., 0, 1] = 1 28 | self.input_topleft[..., 0, 3] = 100 29 | 30 | self.input_bottomright = torch.zeros(1, 1, 4, 4).float() 31 | self.input_bottomright[..., 3, 3] = 100 32 | self.input_bottomright[..., 2, 3] = 10 33 | self.input_bottomright[..., 3, 2] = 1 34 | self.input_bottomright[..., 3, 0] = 100 35 | 36 | self.butterfly = utils.get_img(path.join('example', 'butterfly.png')) 37 | 38 | if cuda.is_available(): 39 | self.test_cuda = True 40 | self.input_small_cuda = self.input_small.cuda() 41 | self.input_square_cuda = self.input_square.cuda() 42 | self.input_rect_cuda = self.input_rect.cuda() 43 | self.input_15x15_cuda = self.input_15x15.cuda() 44 | self.input_topleft_cuda = self.input_topleft.cuda() 45 | self.input_bottomright_cuda = self.input_bottomright.cuda() 46 | self.butterfly_cuda = self.butterfly.cuda() 47 | else: 48 | self.test_cuda = False 49 | 50 | # You can use different functions for testing your implementation. 
51 | self.imresize = core.imresize 52 | self.eps = 1e-3 53 | 54 | def get_answer(self, case: str) -> torch.Tensor: 55 | mat = io.loadmat(path.join('test_answer', case + '.mat')) 56 | tensor = torch.Tensor(mat[case]) 57 | if tensor.dim() == 3: 58 | tensor = tensor.permute(2, 0, 1) 59 | 60 | while tensor.dim() < 4: 61 | tensor.unsqueeze_(0) 62 | 63 | return tensor 64 | 65 | def _check_diff(self, x: torch.Tensor, y: torch.Tensor) -> None: 66 | diff = torch.norm(x - y, 2).item() 67 | if diff > self.eps: 68 | print('Implementation:') 69 | print(x) 70 | print('MATLAB reference:') 71 | print(y) 72 | raise ArithmeticError( 73 | 'Difference is not negligible!: {}'.format(diff), 74 | ) 75 | else: 76 | print('Allowable difference: {:.4e} < {:.4e}'.format(diff, self.eps)) 77 | 78 | return 79 | 80 | def check_diff(self, x: torch.Tensor, answer: str) -> None: 81 | y = self.get_answer(answer).to(dtype=x.dtype, device=x.device) 82 | self._check_diff(x, y) 83 | return 84 | 85 | def test_consistency_down_down_x4_large_noaa(self) -> None: 86 | x = torch.randn(1, 3, 2048, 2048) 87 | with utils.Timer( 88 | '(2048, 2048) RGB to (512, 512) ' 89 | 'without AA (Cubic Conv. vs. Naive): {}'): 90 | 91 | x_down = self.imresize(x, scale=0.25, antialiasing=False) 92 | y = self.imresize(x, sizes=(512, 512), antialiasing=False) 93 | 94 | self._check_diff(x_down, y) 95 | return 96 | 97 | def test_cuda_consistency_down_down_x4_large_noaa(self) -> None: 98 | if self.test_cuda is False: 99 | return 100 | 101 | x = torch.randn(1, 3, 2048, 2048).cuda() 102 | with utils.Timer( 103 | '(2048, 2048) RGB to (512, 512) ' 104 | 'without AA (Cubic Conv. vs. 
Naive) using CUDA: {}'): 105 | 106 | x_down = self.imresize(x, scale=0.25, antialiasing=False) 107 | y = self.imresize(x, sizes=(512, 512), antialiasing=False) 108 | 109 | self._check_diff(x_down, y) 110 | return 111 | 112 | def test_consistency_down_down_x4_large_aa(self) -> None: 113 | x = torch.randn(1, 3, 2048, 2048) 114 | with utils.Timer( 115 | '(2048, 2048) RGB to (512, 512) ' 116 | 'with AA (Cubic Conv. vs. Naive): {}'): 117 | 118 | x_down = self.imresize(x, scale=0.25, antialiasing=True) 119 | y = self.imresize(x, sizes=(512, 512), antialiasing=True) 120 | 121 | self._check_diff(x_down, y) 122 | return 123 | 124 | def test_cuda_consistency_down_down_x4_large_aa(self) -> None: 125 | if self.test_cuda is False: 126 | return 127 | 128 | x = torch.randn(1, 3, 2048, 2048).cuda() 129 | with utils.Timer( 130 | '(2048, 2048) RGB to (512, 512) ' 131 | 'with AA (Cubic Conv. vs. Naive) using CUDA: {}'): 132 | 133 | x_down = self.imresize(x, scale=0.25, antialiasing=True) 134 | y = self.imresize(x, sizes=(512, 512), antialiasing=True) 135 | 136 | self._check_diff(x_down, y) 137 | return 138 | 139 | def test_down_down_x4_large_noaa(self) -> None: 140 | x = torch.randn(1, 3, 2048, 2048) 141 | with utils.Timer( 142 | '(2048, 2048) RGB to (512, 512) without AA (Cubic Conv.): {}'): 143 | 144 | x = self.imresize(x, scale=0.25, antialiasing=False) 145 | 146 | return 147 | 148 | def test_cuda_down_down_x4_large_noaa(self) -> None: 149 | if self.test_cuda is False: 150 | return 151 | 152 | x = torch.randn(1, 3, 2048, 2048).cuda() 153 | with utils.Timer( 154 | '(2048, 2048) RGB to (512, 512) ' 155 | 'without AA using CUDA (Cubic Conv.): {}'): 156 | 157 | x = self.imresize(x, scale=0.25, antialiasing=False) 158 | 159 | return 160 | 161 | def test_down_down_x4_naive_large_noaa(self) -> None: 162 | x = torch.randn(1, 3, 2048, 2048) 163 | with utils.Timer( 164 | '(2048, 2048) RGB to (512, 512) ' 165 | 'without AA (Naive): {}'): 166 | 167 | x = self.imresize(x, sizes=(512, 512), antialiasing=False) 168 | 169 | 
return 170 | 171 | def test_cuda_down_down_x4_naive_large_noaa(self) -> None: 172 | if self.test_cuda is False: 173 | return 174 | 175 | x = torch.randn(1, 3, 2048, 2048).cuda() 176 | with utils.Timer( 177 | '(2048, 2048) RGB to (512, 512) ' 178 | 'without AA using CUDA (Naive): {}'): 179 | 180 | x = self.imresize(x, sizes=(512, 512), antialiasing=False) 181 | 182 | return 183 | 184 | def test_down_down_x4_large_aa(self) -> None: 185 | x = torch.randn(1, 3, 2048, 2048) 186 | with utils.Timer( 187 | '(2048, 2048) RGB to (512, 512) with AA (Cubic Conv.): {}'): 188 | 189 | x = self.imresize(x, scale=0.25, antialiasing=True) 190 | 191 | return 192 | 193 | def test_cuda_down_down_x4_large_aa(self) -> None: 194 | if self.test_cuda is False: 195 | return 196 | 197 | x = torch.randn(1, 3, 2048, 2048).cuda() 198 | with utils.Timer( 199 | '(2048, 2048) RGB to (512, 512) ' 200 | 'with AA using CUDA (Cubic Conv.): {}'): 201 | 202 | x = self.imresize(x, scale=0.25, antialiasing=True) 203 | 204 | return 205 | 206 | def test_down_down_x4_naive_large_aa(self) -> None: 207 | x = torch.randn(1, 3, 2048, 2048) 208 | with utils.Timer( 209 | '(2048, 2048) RGB to (512, 512) with AA (Naive): {}'): 210 | 211 | x = self.imresize(x, sizes=(512, 512), antialiasing=True) 212 | 213 | return 214 | 215 | def test_cuda_down_down_x4_naive_large_aa(self) -> None: 216 | if self.test_cuda is False: 217 | return 218 | 219 | x = torch.randn(1, 3, 2048, 2048).cuda() 220 | with utils.Timer( 221 | '(2048, 2048) RGB to (512, 512) ' 222 | 'with AA using CUDA (Naive): {}'): 223 | 224 | x = self.imresize(x, sizes=(512, 512), antialiasing=True) 225 | 226 | return 227 | 228 | def test_down_down_small_noaa(self) -> None: 229 | with utils.Timer('(4, 4) to (3, 3) without AA: {}'): 230 | x = self.imresize( 231 | self.input_small, sizes=(3, 3), antialiasing=False, 232 | ) 233 | 234 | self.check_diff(x, 'down_down_small_noaa') 235 | return 236 | 237 | def test_cuda_down_down_small_noaa(self) -> None: 238 | if 
self.test_cuda is False: 239 | return 240 | 241 | with utils.Timer('(4, 4) to (3, 3) without AA using CUDA: {}'): 242 | x = self.imresize( 243 | self.input_small_cuda, sizes=(3, 3), antialiasing=False, 244 | ) 245 | 246 | self.check_diff(x, 'down_down_small_noaa') 247 | return 248 | 249 | def test_down_down_small_aa(self) -> None: 250 | with utils.Timer('(4, 4) to (3, 3) with AA: {}'): 251 | x = self.imresize( 252 | self.input_small, sizes=(3, 3), antialiasing=True, 253 | ) 254 | 255 | self.check_diff(x, 'down_down_small_aa') 256 | return 257 | 258 | def test_cuda_down_down_small_aa(self) -> None: 259 | if self.test_cuda is False: 260 | return 261 | 262 | with utils.Timer('(4, 4) to (3, 3) with AA using CUDA: {}'): 263 | x = self.imresize( 264 | self.input_small_cuda, sizes=(3, 3), antialiasing=True, 265 | ) 266 | 267 | self.check_diff(x, 'down_down_small_aa') 268 | return 269 | 270 | def test_down_down_noaa(self) -> None: 271 | with utils.Timer('(8, 8) to (3, 4) without AA: {}'): 272 | x = self.imresize( 273 | self.input_square, sizes=(3, 4), antialiasing=False, 274 | ) 275 | 276 | self.check_diff(x, 'down_down_noaa') 277 | return 278 | 279 | def test_cuda_down_down_noaa(self) -> None: 280 | if self.test_cuda is False: 281 | return 282 | 283 | with utils.Timer('(8, 8) to (3, 4) without AA using CUDA: {}'): 284 | x = self.imresize( 285 | self.input_square_cuda, sizes=(3, 4), antialiasing=False, 286 | ) 287 | 288 | self.check_diff(x, 'down_down_noaa') 289 | return 290 | 291 | def test_down_down_aa(self) -> None: 292 | with utils.Timer('(8, 8) to (3, 4) with AA: {}'): 293 | x = self.imresize( 294 | self.input_square, sizes=(3, 4), antialiasing=True, 295 | ) 296 | 297 | self.check_diff(x, 'down_down_aa') 298 | return 299 | 300 | def test_cuda_down_down_aa(self) -> None: 301 | if self.test_cuda is False: 302 | return 303 | 304 | with utils.Timer('(8, 8) to (3, 4) with AA using CUDA: {}'): 305 | x = self.imresize( 306 | self.input_square_cuda, sizes=(3, 4), 
antialiasing=True, 307 | ) 308 | 309 | self.check_diff(x, 'down_down_aa') 310 | return 311 | 312 | def test_down_down_irregular_noaa(self) -> None: 313 | with utils.Timer('(8, 8) to (5, 7) without AA: {}'): 314 | x = self.imresize( 315 | self.input_square, sizes=(5, 7), antialiasing=False, 316 | ) 317 | 318 | self.check_diff(x, 'down_down_irregular_noaa') 319 | return 320 | 321 | def test_cuda_down_down_irregular_noaa(self) -> None: 322 | if self.test_cuda is False: 323 | return 324 | 325 | with utils.Timer('(8, 8) to (5, 7) without AA using CUDA: {}'): 326 | x = self.imresize( 327 | self.input_square_cuda, sizes=(5, 7), antialiasing=False, 328 | ) 329 | 330 | self.check_diff(x, 'down_down_irregular_noaa') 331 | return 332 | 333 | def test_down_down_x2_aa(self) -> None: 334 | with utils.Timer('(8, 8) to (4, 4) with AA: {}'): 335 | x = self.imresize( 336 | self.input_square, scale=(1 / 2), antialiasing=True, 337 | ) 338 | 339 | self.check_diff(x, 'down_down_x2_aa') 340 | return 341 | 342 | def test_cuda_down_down_x2_aa(self) -> None: 343 | if self.test_cuda is False: 344 | return 345 | 346 | with utils.Timer('(8, 8) to (4, 4) with AA using CUDA: {}'): 347 | x = self.imresize( 348 | self.input_square_cuda, scale=(1 / 2), antialiasing=True, 349 | ) 350 | 351 | self.check_diff(x, 'down_down_x2_aa') 352 | return 353 | 354 | def test_down_down_x3_aa(self) -> None: 355 | with utils.Timer('(15, 15) to (5, 5) with AA: {}'): 356 | x = self.imresize( 357 | self.input_15x15, scale=(1 / 3), antialiasing=True, 358 | ) 359 | 360 | self.check_diff(x, 'down_down_x3_aa') 361 | return 362 | 363 | def test_cuda_down_down_x3_aa(self) -> None: 364 | if self.test_cuda is False: 365 | return 366 | 367 | with utils.Timer('(15, 15) to (5, 5) with AA using CUDA: {}'): 368 | x = self.imresize( 369 | self.input_15x15_cuda, scale=(1 / 3), antialiasing=True, 370 | ) 371 | 372 | self.check_diff(x, 'down_down_x3_aa') 373 | return 374 | 375 | def test_down_down_x4_aa(self) -> None: 376 | with 
utils.Timer('(8, 8) to (2, 2) with AA: {}'): 377 | x = self.imresize( 378 | self.input_square, scale=(1 / 4), antialiasing=True, 379 | ) 380 | 381 | self.check_diff(x, 'down_down_x4_aa') 382 | return 383 | 384 | def test_cuda_down_down_x4_aa(self) -> None: 385 | if self.test_cuda is False: 386 | return 387 | 388 | with utils.Timer('(8, 8) to (2, 2) with AA using CUDA: {}'): 389 | x = self.imresize( 390 | self.input_square_cuda, scale=(1 / 4), antialiasing=True, 391 | ) 392 | 393 | self.check_diff(x, 'down_down_x4_aa') 394 | return 395 | 396 | def test_down_down_x5_aa(self) -> None: 397 | with utils.Timer('(15, 15) to (3, 3) with AA: {}'): 398 | x = self.imresize( 399 | self.input_15x15, scale=(1 / 5), antialiasing=True, 400 | ) 401 | 402 | self.check_diff(x, 'down_down_x5_aa') 403 | return 404 | 405 | def test_cuda_down_down_x5_aa(self) -> None: 406 | if self.test_cuda is False: 407 | return 408 | 409 | with utils.Timer('(15, 15) to (3, 3) with AA using CUDA: {}'): 410 | x = self.imresize( 411 | self.input_15x15_cuda, scale=(1 / 5), antialiasing=True, 412 | ) 413 | 414 | self.check_diff(x, 'down_down_x5_aa') 415 | return 416 | 417 | def test_up_up_topleft_noaa(self) -> None: 418 | with utils.Timer('(4, 4) topleft to (5, 5) without AA: {}'): 419 | x = self.imresize( 420 | self.input_topleft, sizes=(5, 5), antialiasing=False, 421 | ) 422 | 423 | self.check_diff(x, 'up_up_topleft_noaa') 424 | return 425 | 426 | def test_cuda_up_up_topleft_noaa(self) -> None: 427 | if self.test_cuda is False: 428 | return 429 | 430 | with utils.Timer('(4, 4) topleft to (5, 5) without AA using CUDA: {}'): 431 | x = self.imresize( 432 | self.input_topleft_cuda, sizes=(5, 5), antialiasing=False, 433 | ) 434 | 435 | self.check_diff(x, 'up_up_topleft_noaa') 436 | return 437 | 438 | def test_up_up_bottomright_noaa(self) -> None: 439 | with utils.Timer('(4, 4) bottomright to (5, 5) without AA: {}'): 440 | x = self.imresize( 441 | self.input_bottomright, sizes=(5, 5), antialiasing=False, 442 | 
) 443 | 444 | self.check_diff(x, 'up_up_bottomright_noaa') 445 | return 446 | 447 | def test_cuda_up_up_bottomright_noaa(self) -> None: 448 | if self.test_cuda is False: 449 | return 450 | 451 | with utils.Timer('(4, 4) bottomright to (5, 5) without AA using CUDA: {}'): 452 | x = self.imresize( 453 | self.input_bottomright_cuda, sizes=(5, 5), antialiasing=False, 454 | ) 455 | 456 | self.check_diff(x, 'up_up_bottomright_noaa') 457 | return 458 | 459 | def test_up_up_irregular_noaa(self) -> None: 460 | with utils.Timer('(8, 8) to (11, 13) without AA: {}'): 461 | x = self.imresize( 462 | self.input_square, sizes=(11, 13), antialiasing=False, 463 | ) 464 | 465 | self.check_diff(x, 'up_up_irregular_noaa') 466 | return 467 | 468 | def test_cuda_up_up_irregular_noaa(self) -> None: 469 | if self.test_cuda is False: 470 | return 471 | 472 | with utils.Timer('(8, 8) to (11, 13) without AA using CUDA: {}'): 473 | x = self.imresize( 474 | self.input_square_cuda, sizes=(11, 13), antialiasing=False, 475 | ) 476 | 477 | self.check_diff(x, 'up_up_irregular_noaa') 478 | return 479 | 480 | def test_up_up_irregular_aa(self) -> None: 481 | with utils.Timer('(8, 8) to (11, 13) with AA: {}'): 482 | x = self.imresize( 483 | self.input_square, sizes=(11, 13), antialiasing=True, 484 | ) 485 | 486 | self.check_diff(x, 'up_up_irregular_aa') 487 | return 488 | 489 | def test_cuda_up_up_irregular_aa(self) -> None: 490 | if self.test_cuda is False: 491 | return 492 | 493 | with utils.Timer('(8, 8) to (11, 13) with AA using CUDA: {}'): 494 | x = self.imresize( 495 | self.input_square_cuda, sizes=(11, 13), antialiasing=True, 496 | ) 497 | 498 | self.check_diff(x, 'up_up_irregular_aa') 499 | return 500 | 501 | def test_down_down_butterfly_irregular_noaa(self) -> None: 502 | with utils.Timer('(256, 256) butterfly to (123, 234) without AA: {}'): 503 | x = self.imresize( 504 | self.butterfly, sizes=(123, 234), antialiasing=False, 505 | ) 506 | 507 | self.check_diff(x, 
'down_down_butterfly_irregular_noaa') 508 | return 509 | 510 | def test_cuda_down_down_butterfly_irregular_noaa(self) -> None: 511 | if self.test_cuda is False: 512 | return 513 | 514 | with utils.Timer('(256, 256) butterfly to (123, 234) without AA using CUDA: {}'): 515 | x = self.imresize( 516 | self.butterfly_cuda, sizes=(123, 234), antialiasing=False, 517 | ) 518 | 519 | self.check_diff(x, 'down_down_butterfly_irregular_noaa') 520 | return 521 | 522 | def test_double_down_down_butterfly_irregular_noaa(self) -> None: 523 | double = self.butterfly.double() 524 | with utils.Timer('(256, 256) butterfly (double) to (123, 234) without AA: {}'): 525 | x = self.imresize(double, sizes=(123, 234), antialiasing=False) 526 | 527 | self.check_diff(x, 'down_down_butterfly_irregular_noaa') 528 | return 529 | 530 | def test_double_cuda_down_down_butterfly_irregular_noaa(self) -> None: 531 | if self.test_cuda is False: 532 | return 533 | 534 | double = self.butterfly_cuda.double() 535 | with utils.Timer('(256, 256) butterfly (double) to (123, 234) without AA using CUDA: {}'): 536 | x = self.imresize(double, sizes=(123, 234), antialiasing=False) 537 | 538 | self.check_diff(x, 'down_down_butterfly_irregular_noaa') 539 | return 540 | 541 | def test_down_down_butterfly_irregular_aa(self) -> None: 542 | with utils.Timer('(256, 256) butterfly to (123, 234) with AA: {}'): 543 | x = self.imresize( 544 | self.butterfly, sizes=(123, 234), antialiasing=True, 545 | ) 546 | 547 | self.check_diff(x, 'down_down_butterfly_irregular_aa') 548 | return 549 | 550 | def test_cuda_down_down_butterfly_irregular_aa(self) -> None: 551 | if self.test_cuda is False: 552 | return 553 | 554 | with utils.Timer('(256, 256) butterfly to (123, 234) with AA using CUDA: {}'): 555 | x = self.imresize( 556 | self.butterfly_cuda, sizes=(123, 234), antialiasing=True, 557 | ) 558 | 559 | self.check_diff(x, 'down_down_butterfly_irregular_aa') 560 | return 561 | 562 | def test_up_up_butterfly_irregular_noaa(self) -> 
None: 563 | with utils.Timer('(256, 256) butterfly to (1234, 789) without AA: {}'): 564 | x = self.imresize( 565 | self.butterfly, sizes=(1234, 789), antialiasing=False, 566 | ) 567 | 568 | self.check_diff(x, 'up_up_butterfly_irregular_noaa') 569 | return 570 | 571 | def test_cuda_up_up_butterfly_irregular_noaa(self) -> None: 572 | if self.test_cuda is False: 573 | return 574 | 575 | with utils.Timer('(256, 256) butterfly to (1234, 789) without AA using CUDA: {}'): 576 | x = self.imresize( 577 | self.butterfly_cuda, sizes=(1234, 789), antialiasing=False, 578 | ) 579 | 580 | self.check_diff(x, 'up_up_butterfly_irregular_noaa') 581 | return 582 | 583 | 584 | if __name__ == '__main__': 585 | unittest.main() 586 | -------------------------------------------------------------------------------- /test_answer/down_down_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/down_down_aa.mat -------------------------------------------------------------------------------- /test_answer/down_down_butterfly_irregular_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/down_down_butterfly_irregular_aa.mat -------------------------------------------------------------------------------- /test_answer/down_down_butterfly_irregular_noaa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/down_down_butterfly_irregular_noaa.mat -------------------------------------------------------------------------------- /test_answer/down_down_irregular_aa.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/down_down_irregular_aa.mat -------------------------------------------------------------------------------- /test_answer/down_down_irregular_noaa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/down_down_irregular_noaa.mat -------------------------------------------------------------------------------- /test_answer/down_down_noaa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/down_down_noaa.mat -------------------------------------------------------------------------------- /test_answer/down_down_small_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/down_down_small_aa.mat -------------------------------------------------------------------------------- /test_answer/down_down_small_noaa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/down_down_small_noaa.mat -------------------------------------------------------------------------------- /test_answer/down_down_x2_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/down_down_x2_aa.mat -------------------------------------------------------------------------------- /test_answer/down_down_x3_aa.mat: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/down_down_x3_aa.mat -------------------------------------------------------------------------------- /test_answer/down_down_x4_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/down_down_x4_aa.mat -------------------------------------------------------------------------------- /test_answer/down_down_x5_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/down_down_x5_aa.mat -------------------------------------------------------------------------------- /test_answer/gen_test.m: -------------------------------------------------------------------------------- 1 | input_square = linspace(0, 63, 64); 2 | % Transpose required 3 | input_square = reshape(input_square, [8, 8])'; 4 | 5 | input_small = linspace(0, 15, 16); 6 | input_small = reshape(input_small, [4, 4])'; 7 | 8 | input_15x15 = linspace(0, 224, 225); 9 | input_15x15 = reshape(input_15x15, [15, 15])'; 10 | 11 | input_topleft = zeros(4, 4); 12 | input_topleft(1, 1) = 100; 13 | input_topleft(2, 1) = 10; 14 | input_topleft(1, 2) = 1; 15 | input_topleft(1, 4) = 100; 16 | 17 | input_bottomright = zeros(4, 4); 18 | input_bottomright(4, 4) = 100; 19 | input_bottomright(3, 4) = 10; 20 | input_bottomright(4, 3) = 1; 21 | input_bottomright(4, 1) = 100; 22 | 23 | 24 | fprintf('(4, 4) to (3, 3) without AA\n'); 25 | down_down_small_noaa = imresize( ... 26 | input_small, [3, 3], 'bicubic', 'antialiasing', false ... 
27 | ); 28 | save('down_down_small_noaa.mat', 'down_down_small_noaa'); 29 | 30 | fprintf('(4, 4) to (3, 3) with AA\n'); 31 | down_down_small_aa = imresize( ... 32 | input_small, [3, 3], 'bicubic', 'antialiasing', true ... 33 | ); 34 | save('down_down_small_aa.mat', 'down_down_small_aa'); 35 | 36 | 37 | fprintf('(8, 8) to (3, 4) without AA\n'); 38 | down_down_noaa = imresize( ... 39 | input_square, [3, 4], 'bicubic', 'antialiasing', false ... 40 | ); 41 | save('down_down_noaa.mat', 'down_down_noaa'); 42 | 43 | fprintf('(8, 8) to (3, 4) with AA\n'); 44 | down_down_aa = imresize( ... 45 | input_square, [3, 4], 'bicubic', 'antialiasing', true ... 46 | ); 47 | save('down_down_aa.mat', 'down_down_aa'); 48 | 49 | 50 | fprintf('(8, 8) to (5, 7) without AA\n'); 51 | down_down_irregular_noaa = imresize( ... 52 | input_square, [5, 7], 'bicubic', 'antialiasing', false ... 53 | ); 54 | save('down_down_irregular_noaa.mat', 'down_down_irregular_noaa'); 55 | 56 | fprintf('(8, 8) to (5, 7) with AA\n'); 57 | down_down_irregular_aa = imresize( ... 58 | input_square, [5, 7], 'bicubic', 'antialiasing', true ... 59 | ); 60 | save('down_down_irregular_aa.mat', 'down_down_irregular_aa'); 61 | 62 | 63 | fprintf('(4, 4) topleft to (5, 5) without AA\n'); 64 | up_up_topleft_noaa = imresize( ... 65 | input_topleft, [5, 5], 'bicubic', 'antialiasing', false ... 66 | ); 67 | save('up_up_topleft_noaa.mat', 'up_up_topleft_noaa'); 68 | 69 | fprintf('(4, 4) bottomright to (5, 5) without AA\n'); 70 | up_up_bottomright_noaa = imresize( ... 71 | input_bottomright, [5, 5], 'bicubic', 'antialiasing', false ... 72 | ); 73 | save('up_up_bottomright_noaa.mat', 'up_up_bottomright_noaa'); 74 | 75 | 76 | fprintf('(8, 8) to (11, 13) without AA\n'); 77 | up_up_irregular_noaa = imresize( ... 78 | input_square, [11, 13], 'bicubic', 'antialiasing', false ... 
79 | ); 80 | save('up_up_irregular_noaa.mat', 'up_up_irregular_noaa'); 81 | 82 | fprintf('(8, 8) to (11, 13) with AA\n'); 83 | up_up_irregular_aa = imresize( ... 84 | input_square, [11, 13], 'bicubic', 'antialiasing', true ... 85 | ); 86 | save('up_up_irregular_aa.mat', 'up_up_irregular_aa'); 87 | 88 | 89 | butterfly = imread(fullfile('..', 'example', 'butterfly.png')); 90 | butterfly = im2double(butterfly); 91 | fprintf('(256, 256) butterfly.png to (123, 234) without AA\n') 92 | down_down_butterfly_irregular_noaa = imresize( ... 93 | butterfly, [123, 234], 'bicubic', 'antialiasing', false ... 94 | ); 95 | save( ... 96 | 'down_down_butterfly_irregular_noaa.mat', ... 97 | 'down_down_butterfly_irregular_noaa' ... 98 | ); 99 | 100 | fprintf('(256, 256) butterfly.png to (123, 234) with AA\n') 101 | down_down_butterfly_irregular_aa = imresize( ... 102 | butterfly, [123, 234], 'bicubic', 'antialiasing', true ... 103 | ); 104 | save( ... 105 | 'down_down_butterfly_irregular_aa.mat', ... 106 | 'down_down_butterfly_irregular_aa' ... 107 | ); 108 | 109 | 110 | fprintf('(256, 256) butterfly.png to (1234, 789) without AA\n') 111 | up_up_butterfly_irregular_noaa = imresize( ... 112 | butterfly, [1234, 789], 'bicubic', 'antialiasing', false ... 113 | ); 114 | save( ... 115 | 'up_up_butterfly_irregular_noaa.mat', ... 116 | 'up_up_butterfly_irregular_noaa' ... 117 | ); 118 | 119 | 120 | fprintf('(8, 8) to (4, 4) with AA\n'); 121 | down_down_x2_aa = imresize( ... 122 | input_square, [4, 4], 'bicubic', 'antialiasing', true ... 123 | ); 124 | save('down_down_x2_aa.mat', 'down_down_x2_aa'); 125 | 126 | fprintf('(8, 8) to (2, 2) with AA\n'); 127 | down_down_x4_aa = imresize( ... 128 | input_square, [2, 2], 'bicubic', 'antialiasing', true ... 129 | ); 130 | save('down_down_x4_aa.mat', 'down_down_x4_aa'); 131 | 132 | 133 | fprintf('(15, 15) to (5, 5) with AA\n'); 134 | down_down_x3_aa = imresize( ... 135 | input_15x15, [5, 5], 'bicubic', 'antialiasing', true ... 
136 | ); 137 | save('down_down_x3_aa.mat', 'down_down_x3_aa'); 138 | 139 | fprintf('(15, 15) to (3, 3) with AA\n'); 140 | down_down_x5_aa = imresize( ... 141 | input_15x15, [3, 3], 'bicubic', 'antialiasing', true ... 142 | ); 143 | save('down_down_x5_aa.mat', 'down_down_x5_aa'); 144 | -------------------------------------------------------------------------------- /test_answer/up_up_bottomright_noaa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/up_up_bottomright_noaa.mat -------------------------------------------------------------------------------- /test_answer/up_up_butterfly_irregular_noaa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/up_up_butterfly_irregular_noaa.mat -------------------------------------------------------------------------------- /test_answer/up_up_irregular_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/up_up_irregular_aa.mat -------------------------------------------------------------------------------- /test_answer/up_up_irregular_noaa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/up_up_irregular_noaa.mat -------------------------------------------------------------------------------- /test_answer/up_up_topleft_noaa.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sanghyun-son/bicubic_pytorch/da432d11475d935641cb2a30c0e2ff976e7111fa/test_answer/up_up_topleft_noaa.mat

--------------------------------------------------------------------------------
/test_benchmark.py:
--------------------------------------------------------------------------------
from os import path
import unittest

import core_warp
import utils

import torch
from torch import cuda


class TestWarpBenchmark(unittest.TestCase):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.n = 1000
        self.butterfly = utils.get_img(path.join('example', 'butterfly.png'))
        # Batching
        self.butterfly = self.butterfly.repeat(16, 1, 1, 1)
        # 3x3 projective transformation matrix
        self.m = torch.Tensor([
            [3.2, 0.016, -68],
            [1.23, 1.7, -54],
            [0.008, 0.0001, 1],
        ])
        if cuda.is_available():
            self.butterfly = self.butterfly.cuda()
            self.m = self.m.cuda()

        with utils.Timer('Warm-up: {}'):
            for _ in range(100):
                _ = core_warp.warp(
                    self.butterfly,
                    self.m,
                    sizes='auto',
                    kernel='bicubic',
                    fill_value=0,
                )

            # Wait for pending GPU kernels so the timing is accurate.
            # Guarded because synchronize() fails on CPU-only machines.
            if cuda.is_available():
                cuda.synchronize()

    def test_warp_nearest(self) -> None:
        with utils.Timer('Nearest warping: {}'):
            for _ in range(self.n):
                _ = core_warp.warp(
                    self.butterfly,
                    self.m,
                    sizes='auto',
                    kernel='nearest',
                    fill_value=0,
                )

            if cuda.is_available():
                cuda.synchronize()

    def test_warp_bilinear(self) -> None:
        with utils.Timer('Bilinear warping: {}'):
            for _ in range(self.n):
                _ = core_warp.warp(
                    self.butterfly,
                    self.m,
                    sizes='auto',
                    kernel='bilinear',
                    fill_value=0,
                )

            if cuda.is_available():
                cuda.synchronize()

    def test_warp_bicubic(self) -> None:
        with utils.Timer('Bicubic warping: {}'):
            for _ in range(self.n):
                _ = core_warp.warp(
                    self.butterfly,
                    self.m,
                    sizes='auto',
                    kernel='bicubic',
                    fill_value=0,
                )

            # Guarded because synchronize() fails on CPU-only machines.
            if cuda.is_available():
                cuda.synchronize()


if __name__ == '__main__':
    unittest.main()

--------------------------------------------------------------------------------
/test_gradient.py:
--------------------------------------------------------------------------------
from os import path
import unittest

import core
import utils

import torch
from torch import nn
from torch.nn import functional as F
from torch import cuda
from torch import optim


class TestGradient(unittest.TestCase):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.n_iters = 200
        self.lr = 1e-2
        self.input_size = (123, 234)

        if cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        self.target = utils.get_img(path.join('example', 'butterfly.png'))
        self.target = self.target.to(self.device)
        self.target_size = (self.target.size(-2), self.target.size(-1))

    def test_backpropagation(self) -> None:
        # Optimize a random low-resolution image so that its resized
        # version matches the target; this only works if gradients flow
        # through core.imresize.
        noise = torch.rand(
            1,
            self.target.size(1),
            self.input_size[0],
            self.input_size[1],
            device=self.device,
        )
        noise_p = nn.Parameter(noise, requires_grad=True)
        utils.save_img(noise_p, path.join('example', 'noise_input.png'))
        optimizer = optim.Adam([noise_p], lr=self.lr)

        for i in range(self.n_iters):
            optimizer.zero_grad()
            noise_up = core.imresize(noise_p, size=self.target_size)
            loss = F.mse_loss(noise_up, self.target)
            loss.backward()
            if i == 0 or (i + 1) % 20 == 0:
                print('Iter {:0>4}\tLoss: {:.8f}'.format(i + 1, loss.item()))

            optimizer.step()

        utils.save_img(noise_p, path.join('example', 'noise_optimized.png'))
        assert loss.item() < 1e-2, 'Failed to optimize!'


if __name__ == '__main__':
    unittest.main()

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
'''
Minor utilities for testing.
You do not have to import this code to use core.py
'''

import time
import numpy as np
import imageio

import torch


class Timer(object):
    '''Context manager that prints the elapsed wall-clock time on exit.'''

    def __init__(self, msg: str) -> None:
        # '{}' in the message is replaced by the elapsed time in seconds.
        self.msg = msg.replace('{}', '{:.6f}s')
        self.tic = None
        return

    def __enter__(self) -> None:
        self.tic = time.time()
        return

    def __exit__(self, *args, **kwargs) -> None:
        toc = time.time() - self.tic
        print('\n' + self.msg.format(toc))
        return


def np2tensor(x: np.ndarray) -> torch.Tensor:
    '''Convert an (H, W, C) uint8 array to a (1, C, H, W) float tensor in [0, 1].'''
    x = np.transpose(x, (2, 0, 1))
    x = torch.from_numpy(x)
    with torch.no_grad():
        while x.dim() < 4:
            x.unsqueeze_(0)

        x = x.float() / 255

    return x


def tensor2np(x: torch.Tensor) -> np.ndarray:
    '''Convert a (1, C, H, W) float tensor in [0, 1] to an (H, W, C) uint8 array.'''
    with torch.no_grad():
        x = 255 * x
        x = x.round().clamp(min=0, max=255).byte()
        x = x.squeeze(0)

    x = x.cpu().numpy()
    x = np.transpose(x, (1, 2, 0))
    x = np.ascontiguousarray(x)
    return x


def get_img(img_path: str) -> torch.Tensor:
    x = imageio.imread(img_path)
    x = np2tensor(x)
    return x


def save_img(x: torch.Tensor, img_path: str) -> None:
    x = tensor2np(x)
    imageio.imwrite(img_path, x)
    return

--------------------------------------------------------------------------------