├── README.md ├── datasets └── .gitignore ├── experiments └── .gitignore ├── results └── .gitignore └── src ├── .vscode └── settings.json ├── __pycache__ ├── data_loss.cpython-37.pyc ├── dataset.cpython-37.pyc ├── filters.cpython-37.pyc ├── networks.cpython-37.pyc ├── options.cpython-37.pyc ├── saver.cpython-37.pyc ├── trainer_down.cpython-37.pyc ├── trainer_sr.cpython-37.pyc └── utility.cpython-37.pyc ├── bicubic_pytorch ├── README.md ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── core.cpython-37.pyc ├── core.py ├── example │ ├── butterfly.png │ ├── noise_input.png │ └── noise_optimized.png ├── test.py ├── test_answer │ ├── down_down_aa.mat │ ├── down_down_butterfly_irregular_aa.mat │ ├── down_down_butterfly_irregular_noaa.mat │ ├── down_down_irregular_aa.mat │ ├── down_down_irregular_noaa.mat │ ├── down_down_noaa.mat │ ├── down_down_small_aa.mat │ ├── down_down_small_noaa.mat │ ├── down_down_x2_aa.mat │ ├── down_down_x3_aa.mat │ ├── down_down_x4_aa.mat │ ├── down_down_x5_aa.mat │ ├── gen_test.m │ ├── up_up_bottomright_noaa.mat │ ├── up_up_butterfly_irregular_noaa.mat │ ├── up_up_irregular_aa.mat │ ├── up_up_irregular_noaa.mat │ └── up_up_topleft_noaa.mat ├── test_gradient.py └── utils.py ├── data_loss.py ├── dataset.py ├── demo.sh ├── filters.py ├── model ├── Discriminator.py ├── ESRGAN.py ├── __init__.py ├── __pycache__ │ ├── Discriminator.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── common.cpython-37.pyc │ ├── common_rrdb.cpython-37.pyc │ ├── edsr.cpython-37.pyc │ └── rrdb.cpython-37.pyc ├── block.py ├── common.py ├── common_rrdb.py ├── ddbpn.py ├── didn.py ├── edsr.py ├── mdsr.py ├── rcan.py ├── rdn.py ├── rrdb.py └── vdsr.py ├── networks.py ├── options.py ├── saver.py ├── test_down.py ├── test_sr.py ├── train.py ├── trainer_down.py ├── trainer_sr.py └── utility.py
/README.md: -------------------------------------------------------------------------------- 1 | # Adaptive-Downsampling-Model 2 | This repository is an official PyTorch implementation of the paper **"Toward Real-World Super-Resolution via Adaptive Downsampling Models"**, which was accepted to TPAMI ([link](https://ieeexplore.ieee.org/document/9521710)). 3 | 4 | ## Dependencies 5 | * Python 3.7 6 | * PyTorch >= 1.6.0 7 | * matplotlib 8 | * imageio 9 | * pyyaml 10 | * scipy 11 | * numpy 12 | * tqdm 13 | * PIL 14 | 15 | In this project, we learn the adaptive downsampling model (**ADM**) with an unpaired dataset consisting of *HR* and *LR* images that are not pixel-aligned. 16 | 17 | ## 🚉: Dataset preparation 18 | 19 | As *HR* does not need to be pixel-aligned with *LR*, we recommend using the [DIV2K](http://www.vision.ee.ethz.ch/%7Etimofter/publications/Agustsson-CVPRW-2017.pdf) dataset as the *HR* dataset, which consists of clean, high-resolution images. 20 | You can download it from [here](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar) (7.1GB). 21 | We note that the experiments in our paper used 400 HR images ('0001.png'-'0400.png') as *HR*. 22 | 23 | For the *LR* dataset, you should put together a set of images that undergo a similar downsampling process (e.g., [DPED](https://people.ee.ethz.ch/~ihnatova/), [RealSR](https://github.com/csjcai/RealSR), or your own images from the same camera setting). 24 | 25 | Please put the *HR* and *LR* datasets in ```datasets/```. Again, different sizes for the two datasets are acceptable since, as noted above, they do not need to be pixel-aligned. However, note that the total number of iterations per epoch varies with your dataset size.
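For reference, a typical layout of ```datasets/``` might look like the following (the folder names *Source* and *Target* are placeholders; any names work as long as they match the ```--source``` / ```--target``` arguments used below):

```
datasets/
├── Source   # HR images, e.g., DIV2K '0001.png'-'0400.png'
└── Target   # LR images that share a similar downsampling process
```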
Since our **ADM** learns an average downsampling kernel from the *LR* dataset, please use as many LR images with diverse scenes as possible for stable training (we recommend more than 200 HD-scale images). For more details on the effect of the number of training samples, please refer to our paper. 26 | 27 | ## 🚋: Training 28 | 29 | ### Learning to Downsample 30 | 31 | Let us denote the folder name of the *HR* dataset as *Source*, and that of the *LR* dataset as *Target*. 32 | 33 | Our ADM retrieves the downsampling process between the two datasets *Source* and *Target*, and generates a downsampled version of *Source* with the retrieved downsampling kernel. The basic training usage is as follows: 34 | 35 | ``` 36 | cd src 37 | CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --name save_name --make_down 38 | ``` 39 | 40 | The generated downsampled version of *Source* will be saved at ```./experiments/save_name/down_results/```. Note that you can use *Source* and its generated downsampled version as a **paired dataset** in conventional SR settings. 41 | 42 | ### Joint training with image super-resolution (Optional) 43 | 44 | We additionally support joint training with an SR network, which uses the intermediate generated images as a paired dataset and hence does not require a separate SR network training step. Usage of joint training is as follows: 45 | 46 | ``` 47 | cd src 48 | CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --name save_name --joint 49 | ``` 50 | 51 | The default SR model is 'EDSR-baseline', but you can change it with ```--sr_model```. 52 | 53 | If you have validation datasets for evaluating SR performance, please place them in ```datasets/```. You can then measure the performance of the SR network by computing PSNR on the validation datasets (a minimal example is sketched at the end of this README). Specify them with ```--test_lr filename_lr --test_hr filename_hr```, where *filename_hr* and *filename_lr* should be paired images. 54 | 55 | If you do not have paired validation datasets, which is common in real-world settings, you can visualize SR results with ```--test_lr filename_lr --realsr --save_results```; the SR results will then be saved in ```./experiments/save_name/sr_results/```. Note that *filename_lr* can be the same as *Target*. 56 | 57 | You can check the detailed usage of this repo in ```demo.sh```. 58 | 59 | Please note that the experiments reported in our paper were conducted in a separate manner (i.e., we only generate downsampled images and train the SR network with the corresponding official implementation), so the joint-training results may differ slightly from the numbers in the paper. 60 | 61 | ## BibTeX 62 | 63 | @ARTICLE{9521710, 64 | author={Son, Sanghyun and Kim, Jaeha and Lai, Wei-Sheng and Yang, Ming-Hsuan and Lee, Kyoung Mu}, 65 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 66 | title={Toward Real-World Super-Resolution via Adaptive Downsampling Models}, 67 | year={2021}, 68 | volume={}, 69 | number={}, 70 | pages={1-1}, 71 | doi={10.1109/TPAMI.2021.3106790} 72 | } 73 | 74 | ## :e-mail: Contact 75 | 76 | If you have any questions, please email `jhkim97s2@gmail.com`.
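As referenced in the joint-training section above, the following is a minimal sketch (not part of this repository) of how PSNR between an SR result and its paired HR image could be computed offline. The file paths and image names are hypothetical examples only; adjust them to your own data.

```python
# Minimal PSNR sketch (illustrative only; not part of this repository).
# Paths below are hypothetical examples.
import numpy as np
import imageio

def psnr(sr: np.ndarray, hr: np.ndarray, max_val: float = 255.0) -> float:
    # PSNR = 10 * log10(MAX^2 / MSE), computed on 8-bit images by default
    sr = sr.astype(np.float64)
    hr = hr.astype(np.float64)
    mse = np.mean((sr - hr) ** 2)
    if mse == 0:
        return float('inf')
    return 10.0 * np.log10(max_val ** 2 / mse)

sr_img = imageio.imread('experiments/save_name/sr_results/0001.png')  # hypothetical SR output
hr_img = imageio.imread('datasets/filename_hr/0001.png')              # hypothetical paired HR image
print('PSNR: {:.2f} dB'.format(psnr(sr_img, hr_img)))
```

When ```--test_lr``` and ```--test_hr``` are provided, the training script measures PSNR on the validation set by itself; the sketch above is only for checking saved results offline.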
77 | 78 | -------------------------------------------------------------------------------- /datasets/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /experiments/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /results/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /src/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/home/jaeha/anaconda3/envs/py3/bin/python", 3 | "git.ignoreLimitWarning": true 4 | } -------------------------------------------------------------------------------- /src/__pycache__/data_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/__pycache__/data_loss.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/filters.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/__pycache__/filters.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/networks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/__pycache__/networks.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/__pycache__/options.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/saver.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/__pycache__/saver.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/trainer_down.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/__pycache__/trainer_down.cpython-37.pyc -------------------------------------------------------------------------------- 
/src/__pycache__/trainer_sr.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/__pycache__/trainer_sr.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/utility.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/__pycache__/utility.cpython-37.pyc -------------------------------------------------------------------------------- /src/bicubic_pytorch/README.md: -------------------------------------------------------------------------------- 1 | # bicubic-pytorch 2 | 3 | Image resizing is one of the essential operations in computer vision and image processing. 4 | While MATLAB's `imresize` function is used as a standard, implementations from other libraries (e.g., PIL, OpenCV, PyTorch, ...) are not consistent with the MATLAB, especially for a bicubic kernel. 5 | The goal of this repository is to provide a **MATLAB-like bicubic interpolation** in a widely-used PyTorch framework. 6 | Any issues are welcomed to make it better! 7 | 8 | 9 | The merits of our implementation are: 10 | * Easy to use. You only need one python file to copy. 11 | * Consistent with MATLAB's `imresize('bicubic')`, with or without antialiasing. 12 | * Support arbitrary resizing factors on different dimensions. 13 | * Very fast, support GPU acceleration and batching. 14 | * Fully differentiable with respect to input and output images. 15 | 16 | 17 | ## Environment and Dependency 18 | 19 | This repository is tested under: 20 | * Ubuntu 18.04 21 | * PyTorch 1.5.1 (minimum 0.4.0 is required) 22 | * CUDA 10.2 23 | * MATLAB R2019b 24 | 25 | However, we avoid using any version-dependent coding style to make our method compatible with various environments. 26 | If you are not going to generate any test cases, MATLAB is not required. 27 | You do not need any additional dependencies to use this repository. 28 | 29 | 30 | ## How to use 31 | 32 | We provide two options to use this package in your project. 33 | The first way is a Git submodule system, which helps you to keep track of important updates. 34 | ```bash 35 | # In your project repository 36 | $ git submodule add https://github.com/thstkdgus35/bicubic_pytorch 37 | 38 | # To get an update 39 | $ cd bicubic_pytorch 40 | $ git pull origin 41 | ``` 42 | 43 | ```python 44 | # In your python code 45 | import torch 46 | from bicubic_pytorch import core 47 | 48 | x = torch.randn(1, 3, 224, 224) 49 | y = core.imresize(x, scale=0.5) 50 | ``` 51 | 52 | Otherwise, copy `core.py` from the repository as follows: 53 | 54 | ```python 55 | import torch 56 | from torch import cuda 57 | import core 58 | 59 | # We support 2, 3, and 4-dim Tensors 60 | # (H x W, C x H x W, and B x C x H x W, respectively). 61 | # Larger batch sizes are also supported. 62 | x = torch.randn(1, 3, 456, 321) 63 | 64 | # If the input is on a CUDA device, all computations will be done using the GPU. 
65 | if cuda.is_available(): 66 | x = x.cuda() 67 | 68 | # Resize by scale 69 | x_resized_1 = core.imresize(x, scale=0.456) 70 | 71 | # Resize by resolution (456, 321) -> (123, 456) 72 | x_resized_2 = core.imresize(x, sides=(123, 456)) 73 | 74 | # Resize without antialiasing (Not compatible with MATLAB) 75 | x_resized_3 = core.imresize(x, scale=0.456, antialiasing=False) 76 | ``` 77 | 78 | 79 | ## How to test 80 | 81 | You can run `test.py` to check the consistency with MATLAB's `imresize`. 82 | 83 | ```bash 84 | $ python test.py 85 | ``` 86 | 87 | You can generate more test cases using `test_answer/gen_test.m`. 88 | 89 | ```bash 90 | $ cd test_answer 91 | $ matlab -nodisplay < gen_test.m 92 | ``` 93 | 94 | 95 | ## Automatic differentiation 96 | 97 | Our implementation is fully differentiable. 98 | We provide a test script to optimize a random noise Tensor `n` so that `imresize(n)` be a target image. 99 | Please run `test_gradient.py` to test the example. 100 | 101 | ```bash 102 | $ python test_gradient.py 103 | ``` 104 | 105 | You can check the input noise from `example/noise_input.png` and the optimized image from `example/noise_optimized.png`. 106 | 107 | ![noise](example/noise_input.png) 108 | ![optimized](example/noise_optimized.png) 109 | ![target](example/butterfly.png) 110 | 111 | From the left, input noise, optimized, and target images. 112 | 113 | ## Acknowledgement 114 | 115 | The repositories below have provided excellent insights. 116 | 117 | * [https://github.com/fatheral/matlab_imresize](https://github.com/fatheral/matlab_imresize) 118 | * [https://github.com/sefibk/KernelGAN](https://github.com/sefibk/KernelGAN) 119 | 120 | ## Citation 121 | 122 | If you have found our implementation useful, please star and cite this repository: 123 | ``` 124 | @misc{son2020bicubic, 125 | author = {Son, Sanghyun}, 126 | title = {bicubic-pytorch}, 127 | year = {2020}, 128 | publisher = {GitHub}, 129 | journal = {GitHub repository}, 130 | howpublished = {\usr{https://github.com/thstkdgus35/bicubic-pytorch}}, 131 | } 132 | ``` 133 | -------------------------------------------------------------------------------- /src/bicubic_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['core'] 2 | -------------------------------------------------------------------------------- /src/bicubic_pytorch/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/bicubic_pytorch/__pycache__/core.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/__pycache__/core.cpython-37.pyc -------------------------------------------------------------------------------- /src/bicubic_pytorch/core.py: -------------------------------------------------------------------------------- 1 | ''' 2 | A standalone PyTorch implementation for fast and efficient bicubic resampling. 3 | The resulting values are the same to MATLAB function imresize('bicubic'). 
4 | 5 | ## Author: Sanghyun Son 6 | ## Email: sonsang35@gmail.com (primary), thstkdgus35@snu.ac.kr (secondary) 7 | ## Version: 1.1.0 8 | ## Last update: July 9th, 2020 (KST) 9 | 10 | Depencency: torch 11 | 12 | Example:: 13 | >>> import torch 14 | >>> import core 15 | >>> x = torch.arange(16).float().view(1, 1, 4, 4) 16 | >>> y = core.imresize(x, sides=(3, 3)) 17 | >>> print(y) 18 | tensor([[[[ 0.7506, 2.1004, 3.4503], 19 | [ 6.1505, 7.5000, 8.8499], 20 | [11.5497, 12.8996, 14.2494]]]]) 21 | ''' 22 | 23 | import math 24 | import typing 25 | 26 | import torch 27 | from torch.nn import functional as F 28 | 29 | __all__ = ['imresize'] 30 | 31 | K = typing.TypeVar('K', str, torch.Tensor) 32 | 33 | def cubic_contribution(x: torch.Tensor, a: float=-0.5) -> torch.Tensor: 34 | ax = x.abs() 35 | ax2 = ax * ax 36 | ax3 = ax * ax2 37 | 38 | range_01 = (ax <= 1) 39 | range_12 = (ax > 1) * (ax <= 2) 40 | 41 | cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1 42 | cont_01 = cont_01 * range_01.to(dtype=x.dtype) 43 | 44 | cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a) 45 | cont_12 = cont_12 * range_12.to(dtype=x.dtype) 46 | 47 | cont = cont_01 + cont_12 48 | cont = cont / cont.sum() 49 | return cont 50 | 51 | def gaussian_contribution(x: torch.Tensor, sigma: float=2.0) -> torch.Tensor: 52 | range_3sigma = (x.abs() <= 3 * sigma + 1) 53 | # Normalization will be done after 54 | cont = torch.exp(-x.pow(2) / (2 * sigma**2)) 55 | cont = cont * range_3sigma.to(dtype=x.dtype) 56 | return cont 57 | 58 | def discrete_kernel( 59 | kernel: str, scale: float, antialiasing: bool=True) -> torch.Tensor: 60 | 61 | ''' 62 | For downsampling with integer scale only. 63 | ''' 64 | downsampling_factor = int(1 / scale) 65 | if kernel == 'cubic': 66 | kernel_size_orig = 4 67 | else: 68 | raise ValueError('Pass!') 69 | 70 | if antialiasing: 71 | kernel_size = kernel_size_orig * downsampling_factor 72 | else: 73 | kernel_size = kernel_size_orig 74 | 75 | if downsampling_factor % 2 == 0: 76 | a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size)) 77 | else: 78 | kernel_size -= 1 79 | a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1)) 80 | 81 | with torch.no_grad(): 82 | r = torch.linspace(-a, a, steps=kernel_size) 83 | k = cubic_contribution(r).view(-1, 1) 84 | k = torch.matmul(k, k.t()) 85 | k /= k.sum() 86 | 87 | return k 88 | 89 | def reflect_padding( 90 | x: torch.Tensor, 91 | dim: int, 92 | pad_pre: int, 93 | pad_post: int) -> torch.Tensor: 94 | 95 | ''' 96 | Apply reflect padding to the given Tensor. 97 | Note that it is slightly different from the PyTorch functional.pad, 98 | where boundary elements are used only once. 99 | Instead, we follow the MATLAB implementation 100 | which uses boundary elements twice. 101 | 102 | For example, 103 | [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation, 104 | while our implementation yields [a, a, b, c, d, d]. 
105 | ''' 106 | b, c, h, w = x.size() 107 | if dim == 2 or dim == -2: 108 | padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w) 109 | padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x) 110 | for p in range(pad_pre): 111 | padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :]) 112 | for p in range(pad_post): 113 | padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :]) 114 | else: 115 | padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post) 116 | padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x) 117 | for p in range(pad_pre): 118 | padding_buffer[..., pad_pre - p - 1].copy_(x[..., p]) 119 | for p in range(pad_post): 120 | padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)]) 121 | 122 | return padding_buffer 123 | 124 | def padding( 125 | x: torch.Tensor, 126 | dim: int, 127 | pad_pre: int, 128 | pad_post: int, 129 | padding_type: str='reflect') -> torch.Tensor: 130 | 131 | if padding_type == 'reflect': 132 | x_pad = reflect_padding(x, dim, pad_pre, pad_post) 133 | else: 134 | raise ValueError('{} padding is not supported!'.format(padding_type)) 135 | 136 | return x_pad 137 | 138 | def get_padding( 139 | base: torch.Tensor, 140 | kernel_size: int, 141 | x_size: int) -> typing.Tuple[int, int, torch.Tensor]: 142 | 143 | base = base.long() 144 | r_min = base.min() 145 | r_max = base.max() + kernel_size - 1 146 | 147 | if r_min <= 0: 148 | pad_pre = -r_min 149 | pad_pre = pad_pre.item() 150 | base += pad_pre 151 | else: 152 | pad_pre = 0 153 | 154 | if r_max >= x_size: 155 | pad_post = r_max - x_size + 1 156 | pad_post = pad_post.item() 157 | else: 158 | pad_post = 0 159 | 160 | return pad_pre, pad_post, base 161 | 162 | def get_weight( 163 | dist: torch.Tensor, 164 | kernel_size: int, 165 | kernel: str='cubic', 166 | sigma: float=2.0, 167 | antialiasing_factor: float=1) -> torch.Tensor: 168 | 169 | buffer_pos = dist.new_zeros(kernel_size, len(dist)) 170 | for idx, buffer_sub in enumerate(buffer_pos): 171 | buffer_sub.copy_(dist - idx) 172 | 173 | # Expand (downsampling) / Shrink (upsampling) the receptive field. 174 | buffer_pos *= antialiasing_factor 175 | if kernel == 'cubic': 176 | weight = cubic_contribution(buffer_pos) 177 | elif kernel == 'gaussian': 178 | weight = gaussian_contribution(buffer_pos, sigma=sigma) 179 | else: 180 | raise ValueError('{} kernel is not supported!'.format(kernel)) 181 | 182 | weight /= weight.sum(dim=0, keepdim=True) 183 | return weight 184 | 185 | def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor: 186 | # Resize height 187 | if dim == 2 or dim == -2: 188 | k = (kernel_size, 1) 189 | h_out = x.size(-2) - kernel_size + 1 190 | w_out = x.size(-1) 191 | # Resize width 192 | else: 193 | k = (1, kernel_size) 194 | h_out = x.size(-2) 195 | w_out = x.size(-1) - kernel_size + 1 196 | 197 | unfold = F.unfold(x, k) 198 | unfold = unfold.view(unfold.size(0), -1, h_out, w_out) 199 | return unfold 200 | 201 | def resize_1d( 202 | x: torch.Tensor, 203 | dim: int, 204 | side: int=None, 205 | kernel: str='cubic', 206 | sigma: float=2.0, 207 | padding_type: str='reflect', 208 | antialiasing: bool=True) -> torch.Tensor: 209 | 210 | ''' 211 | Args: 212 | x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W). 
213 | dim (int): 214 | scale (float): 215 | side (int): 216 | 217 | Return: 218 | ''' 219 | scale = side / x.size(dim) 220 | # Identity case 221 | if scale == 1: 222 | return x 223 | 224 | # Default bicubic kernel with antialiasing (only when downsampling) 225 | if kernel == 'cubic': 226 | kernel_size = 4 227 | else: 228 | kernel_size = math.floor(6 * sigma) 229 | 230 | if antialiasing and (scale < 1): 231 | antialiasing_factor = scale 232 | kernel_size = math.ceil(kernel_size / antialiasing_factor) 233 | else: 234 | antialiasing_factor = 1 235 | 236 | # We allow margin to both sides 237 | kernel_size += 2 238 | 239 | # Weights only depend on the shape of input and output, 240 | # so we do not calculate gradients here. 241 | with torch.no_grad(): 242 | d = 1 / (2 * side) 243 | pos = torch.linspace( 244 | start=d, 245 | end=(1 - d), 246 | steps=side, 247 | dtype=x.dtype, 248 | device=x.device, 249 | ) 250 | pos = x.size(dim) * pos - 0.5 251 | base = pos.floor() - (kernel_size // 2) + 1 252 | dist = pos - base 253 | weight = get_weight( 254 | dist, 255 | kernel_size, 256 | kernel=kernel, 257 | sigma=sigma, 258 | antialiasing_factor=antialiasing_factor, 259 | ) 260 | pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim)) 261 | 262 | # To backpropagate through x 263 | x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type) 264 | unfold = reshape_tensor(x_pad, dim, kernel_size) 265 | # Subsampling first 266 | if dim == 2 or dim == -2: 267 | sample = unfold[..., base, :] 268 | weight = weight.view(1, kernel_size, sample.size(2), 1) 269 | else: 270 | sample = unfold[..., base] 271 | weight = weight.view(1, kernel_size, 1, sample.size(3)) 272 | 273 | # Apply the kernel 274 | down = sample * weight 275 | down = down.sum(dim=1, keepdim=True) 276 | return down 277 | 278 | def downsampling_2d( 279 | x: torch.Tensor, 280 | k: torch.Tensor, 281 | scale: int, 282 | padding_type: str='reflect') -> torch.Tensor: 283 | 284 | c = x.size(1) 285 | k_h = k.size(-2) 286 | k_w = k.size(-1) 287 | 288 | k = k.to(dtype=x.dtype, device=x.device) 289 | k = k.view(1, 1, k_h, k_w) 290 | k = k.repeat(c, c, 1, 1) 291 | e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False) 292 | e = e.view(c, c, 1, 1) 293 | k = k * e 294 | 295 | pad_h = (k_h - scale) // 2 296 | pad_w = (k_w - scale) // 2 297 | x = padding(x, -2, pad_h, pad_h, padding_type=padding_type) 298 | x = padding(x, -1, pad_w, pad_w, padding_type=padding_type) 299 | y = F.conv2d(x, k, padding=0, stride=scale) 300 | return y 301 | 302 | def imresize( 303 | x: torch.Tensor, 304 | scale: float=None, 305 | sides: typing.Tuple[int, int]=None, 306 | kernel: K='cubic', 307 | sigma: float=2, 308 | rotation_degree: float=0, 309 | padding_type: str='reflect', 310 | antialiasing: bool=True) -> torch.Tensor: 311 | 312 | ''' 313 | Args: 314 | x (torch.Tensor): 315 | scale (float): 316 | sides (tuple(int, int)): 317 | kernel (str, default='cubic'): 318 | sigma (float, default=2): 319 | rotation_degree (float, default=0): 320 | padding_type (str, default='reflect'): 321 | antialiasing (bool, default=True): 322 | 323 | Return: 324 | torch.Tensor: 325 | ''' 326 | 327 | if scale is None and sides is None: 328 | raise ValueError('One of scale or sides must be specified!') 329 | if scale is not None and sides is not None: 330 | raise ValueError('Please specify scale or sides to avoid conflict!') 331 | 332 | if x.dim() == 4: 333 | b, c, h, w = x.size() 334 | elif x.dim() == 3: 335 | c, h, w = x.size() 336 | b = None 337 | elif x.dim() == 2: 
338 | h, w = x.size() 339 | b = c = None 340 | else: 341 | raise ValueError('{}-dim Tensor is not supported!'.format(x.dim())) 342 | 343 | x = x.view(-1, 1, h, w) 344 | 345 | if sides is None: 346 | # Determine output size 347 | sides = (math.ceil(h * scale), math.ceil(w * scale)) 348 | scale_inv = 1 / scale 349 | if isinstance(kernel, str) and scale_inv.is_integer(): 350 | kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing) 351 | elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer(): 352 | raise ValueError( 353 | 'An integer downsampling factor ' 354 | 'should be used with a predefined kernel!' 355 | ) 356 | 357 | if x.dtype != torch.float32 or x.dtype != torch.float64: 358 | dtype = x.dtype 359 | x = x.float() 360 | else: 361 | dtype = None 362 | 363 | if isinstance(kernel, str): 364 | # Shared keyword arguments across dimensions 365 | kwargs = { 366 | 'kernel': kernel, 367 | 'sigma': sigma, 368 | 'padding_type': padding_type, 369 | 'antialiasing': antialiasing, 370 | } 371 | # Core resizing module 372 | x = resize_1d(x, -2, side=sides[0], **kwargs) 373 | x = resize_1d(x, -1, side=sides[1], **kwargs) 374 | elif isinstance(kernel, torch.Tensor): 375 | x = downsampling_2d(x, kernel, scale=int(1 / scale)) 376 | 377 | rh = x.size(-2) 378 | rw = x.size(-1) 379 | # Back to the original dimension 380 | if b is not None: 381 | x = x.view(b, c, rh, rw) # 4-dim 382 | else: 383 | if c is not None: 384 | x = x.view(c, rh, rw) # 3-dim 385 | else: 386 | x = x.view(rh, rw) # 2-dim 387 | 388 | if dtype is not None: 389 | if not dtype.is_floating_point: 390 | x = x.round() 391 | # To prevent over/underflow when converting types 392 | if dtype is torch.uint8: 393 | x = x.clamp(0, 255) 394 | 395 | x = x.to(dtype=dtype) 396 | 397 | return x 398 | 399 | if __name__ == '__main__': 400 | # Just for debugging 401 | torch.set_printoptions(precision=4, sci_mode=False, edgeitems=16, linewidth=200) 402 | a = torch.arange(64).float().view(1, 1, 8, 8) 403 | z = imresize(a, 0.5) 404 | print(z) 405 | #a = torch.arange(16).float().view(1, 1, 4, 4) 406 | ''' 407 | a = torch.zeros(1, 1, 4, 4) 408 | a[..., 0, 0] = 100 409 | a[..., 1, 0] = 10 410 | a[..., 0, 1] = 1 411 | a[..., 0, -1] = 100 412 | a = torch.zeros(1, 1, 4, 4) 413 | a[..., -1, -1] = 100 414 | a[..., -2, -1] = 10 415 | a[..., -1, -2] = 1 416 | a[..., -1, 0] = 100 417 | ''' 418 | #b = imresize(a, sides=(3, 8), antialiasing=False) 419 | #c = imresize(a, sides=(11, 13), antialiasing=True) 420 | #c = imresize(a, sides=(4, 4), antialiasing=False, kernel='gaussian', sigma=1) 421 | #print(a) 422 | #print(b) 423 | #print(c) 424 | 425 | #r = discrete_kernel('cubic', 1 / 3) 426 | #print(r) 427 | ''' 428 | a = torch.arange(225).float().view(1, 1, 15, 15) 429 | imresize(a, sides=[5, 5]) 430 | ''' 431 | -------------------------------------------------------------------------------- /src/bicubic_pytorch/example/butterfly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/example/butterfly.png -------------------------------------------------------------------------------- /src/bicubic_pytorch/example/noise_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/example/noise_input.png 
-------------------------------------------------------------------------------- /src/bicubic_pytorch/example/noise_optimized.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/example/noise_optimized.png -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/down_down_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/down_down_aa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/down_down_butterfly_irregular_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/down_down_butterfly_irregular_aa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/down_down_butterfly_irregular_noaa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/down_down_butterfly_irregular_noaa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/down_down_irregular_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/down_down_irregular_aa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/down_down_irregular_noaa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/down_down_irregular_noaa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/down_down_noaa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/down_down_noaa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/down_down_small_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/down_down_small_aa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/down_down_small_noaa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/down_down_small_noaa.mat 
-------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/down_down_x2_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/down_down_x2_aa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/down_down_x3_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/down_down_x3_aa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/down_down_x4_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/down_down_x4_aa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/down_down_x5_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/down_down_x5_aa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/gen_test.m: -------------------------------------------------------------------------------- 1 | input_square = linspace(0, 63, 64); 2 | % Transpose required 3 | input_square = reshape(input_square, [8, 8])'; 4 | 5 | input_small = linspace(0, 15, 16); 6 | input_small = reshape(input_small, [4, 4])'; 7 | 8 | input_15x15 = linspace(0, 224, 225); 9 | input_15x15 = reshape(input_15x15, [15, 15])'; 10 | 11 | input_topleft = zeros(4, 4); 12 | input_topleft(1, 1) = 100; 13 | input_topleft(2, 1) = 10; 14 | input_topleft(1, 2) = 1; 15 | input_topleft(1, 4) = 100; 16 | 17 | input_bottomright = zeros(4, 4); 18 | input_bottomright(4, 4) = 100; 19 | input_bottomright(3, 4) = 10; 20 | input_bottomright(4, 3) = 1; 21 | input_bottomright(4, 1) = 100; 22 | 23 | 24 | fprintf('(4, 4) to (3, 3) without AA\n'); 25 | down_down_small_noaa = imresize( ... 26 | input_small, [3, 3], 'bicubic', 'antialiasing', false ... 27 | ); 28 | save('down_down_small_noaa.mat', 'down_down_small_noaa'); 29 | 30 | fprintf('(4, 4) to (3, 3) with AA\n'); 31 | down_down_small_aa = imresize( ... 32 | input_small, [3, 3], 'bicubic', 'antialiasing', true ... 33 | ); 34 | save('down_down_small_aa.mat', 'down_down_small_aa'); 35 | 36 | 37 | fprintf('(8, 8) to (3, 4) without AA\n'); 38 | down_down_noaa = imresize( ... 39 | input_square, [3, 4], 'bicubic', 'antialiasing', false ... 40 | ); 41 | save('down_down_noaa.mat', 'down_down_noaa'); 42 | 43 | fprintf('(8, 8) to (3, 4) with AA\n'); 44 | down_down_aa = imresize( ... 45 | input_square, [3, 4], 'bicubic', 'antialiasing', true ... 46 | ); 47 | save('down_down_aa.mat', 'down_down_aa'); 48 | 49 | 50 | fprintf('(8, 8) to (5, 7) without AA\n'); 51 | down_down_irregular_noaa = imresize( ... 52 | input_square, [5, 7], 'bicubic', 'antialiasing', false ... 
53 | ); 54 | save('down_down_irregular_noaa.mat', 'down_down_irregular_noaa'); 55 | 56 | fprintf('(8, 8) to (5, 7) with AA\n'); 57 | down_down_irregular_aa = imresize( ... 58 | input_square, [5, 7], 'bicubic', 'antialiasing', true ... 59 | ); 60 | save('down_down_irregular_aa.mat', 'down_down_irregular_aa'); 61 | 62 | 63 | fprintf('(4, 4) topleft to (5, 5) without AA\n'); 64 | up_up_topleft_noaa = imresize( ... 65 | input_topleft, [5, 5], 'bicubic', 'antialiasing', false ... 66 | ); 67 | save('up_up_topleft_noaa.mat', 'up_up_topleft_noaa'); 68 | 69 | fprintf('(4, 4) bottomright to (5, 5) without AA\n'); 70 | up_up_bottomright_noaa = imresize( ... 71 | input_bottomright, [5, 5], 'bicubic', 'antialiasing', false ... 72 | ); 73 | save('up_up_bottomright_noaa.mat', 'up_up_bottomright_noaa'); 74 | 75 | 76 | fprintf('(8, 8) to (11, 13) without AA\n'); 77 | up_up_irregular_noaa = imresize( ... 78 | input_square, [11, 13], 'bicubic', 'antialiasing', false ... 79 | ); 80 | save('up_up_irregular_noaa.mat', 'up_up_irregular_noaa'); 81 | 82 | fprintf('(8, 8) to (11, 13) with AA\n'); 83 | up_up_irregular_aa = imresize( ... 84 | input_square, [11, 13], 'bicubic', 'antialiasing', true ... 85 | ); 86 | save('up_up_irregular_aa.mat', 'up_up_irregular_aa'); 87 | 88 | 89 | butterfly = imread(fullfile('..', 'example', 'butterfly.png')); 90 | butterfly = im2double(butterfly); 91 | fprintf('(256, 256) butterfly.png to (123, 234) without AA\n') 92 | down_down_butterfly_irregular_noaa = imresize( ... 93 | butterfly, [123, 234], 'bicubic', 'antialiasing', false ... 94 | ); 95 | save( ... 96 | 'down_down_butterfly_irregular_noaa.mat', ... 97 | 'down_down_butterfly_irregular_noaa' ... 98 | ); 99 | 100 | fprintf('(256, 256) butterfly.png to (123, 234) with AA\n') 101 | down_down_butterfly_irregular_aa = imresize( ... 102 | butterfly, [123, 234], 'bicubic', 'antialiasing', true ... 103 | ); 104 | save( ... 105 | 'down_down_butterfly_irregular_aa.mat', ... 106 | 'down_down_butterfly_irregular_aa' ... 107 | ); 108 | 109 | 110 | fprintf('(256, 256) butterfly.png to (1234, 789) without AA\n') 111 | up_up_butterfly_irregular_noaa = imresize( ... 112 | butterfly, [1234, 789], 'bicubic', 'antialiasing', false ... 113 | ); 114 | save( ... 115 | 'up_up_butterfly_irregular_noaa.mat', ... 116 | 'up_up_butterfly_irregular_noaa' ... 117 | ); 118 | 119 | 120 | fprintf('(8, 8) to (4, 4) with AA\n'); 121 | down_down_x2_aa = imresize( ... 122 | input_square, [4, 4], 'bicubic', 'antialiasing', true ... 123 | ); 124 | save('down_down_x2_aa.mat', 'down_down_x2_aa'); 125 | 126 | fprintf('(8, 8) to (2, 2) with AA\n'); 127 | down_down_x4_aa = imresize( ... 128 | input_square, [2, 2], 'bicubic', 'antialiasing', true ... 129 | ); 130 | save('down_down_x4_aa.mat', 'down_down_x4_aa'); 131 | 132 | 133 | fprintf('(15, 15) to (5, 5) with AA\n'); 134 | down_down_x3_aa = imresize( ... 135 | input_15x15, [5, 5], 'bicubic', 'antialiasing', true ... 136 | ); 137 | save('down_down_x3_aa.mat', 'down_down_x3_aa'); 138 | 139 | fprintf('(15, 15) to (3, 3) with AA\n'); 140 | down_down_x5_aa = imresize( ... 141 | input_15x15, [3, 3], 'bicubic', 'antialiasing', true ... 
142 | ); 143 | save('down_down_x5_aa.mat', 'down_down_x5_aa'); 144 | -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/up_up_bottomright_noaa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/up_up_bottomright_noaa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/up_up_butterfly_irregular_noaa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/up_up_butterfly_irregular_noaa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/up_up_irregular_aa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/up_up_irregular_aa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/up_up_irregular_noaa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/up_up_irregular_noaa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_answer/up_up_topleft_noaa.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/bicubic_pytorch/test_answer/up_up_topleft_noaa.mat -------------------------------------------------------------------------------- /src/bicubic_pytorch/test_gradient.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import unittest 3 | 4 | import core 5 | import utils 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from torch import cuda 11 | from torch import optim 12 | 13 | 14 | class TestGradient(unittest.TestCase): 15 | 16 | def __init__(self, *args, **kwargs) -> None: 17 | super().__init__(*args, **kwargs) 18 | self.n_iters = 200 19 | self.lr = 1e-2 20 | self.input_size = (123, 234) 21 | 22 | if cuda.is_available(): 23 | self.device = torch.device('cuda') 24 | else: 25 | self.device = torch.device('cpu') 26 | 27 | self.target = utils.get_img(path.join('example', 'butterfly.png')) 28 | self.target = self.target.to(self.device) 29 | self.target_size = (self.target.size(-2), self.target.size(-1)) 30 | 31 | def test_backpropagation(self) -> None: 32 | noise = torch.rand( 33 | 1, 34 | self.target.size(1), 35 | self.input_size[0], 36 | self.input_size[1], 37 | device=self.device, 38 | ) 39 | noise_p = nn.Parameter(noise, requires_grad=True) 40 | utils.save_img(noise_p, path.join('example', 'noise_input.png')) 41 | optimizer = optim.Adam([noise_p], lr=self.lr) 42 | 43 | for i in range(self.n_iters): 44 | optimizer.zero_grad() 45 | noise_up = core.imresize(noise_p, sides=self.target_size) 46 | loss = F.mse_loss(noise_up, 
self.target) 47 | loss.backward() 48 | if i == 0 or (i + 1) % 20 == 0: 49 | print('Iter {:0>4}\tLoss: {:.8f}'.format(i + 1, loss.item())) 50 | 51 | optimizer.step() 52 | 53 | utils.save_img(noise_p, path.join('example', 'noise_optimized.png')) 54 | assert loss.item() < 1e-2, 'Failed to optimize!' 55 | 56 | 57 | if __name__ == '__main__': 58 | unittest.main() -------------------------------------------------------------------------------- /src/bicubic_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Minor utilities for testing. 3 | You do not have to import this code to use core.py 4 | ''' 5 | 6 | import time 7 | import numpy as np 8 | import imageio 9 | 10 | import torch 11 | 12 | 13 | class Timer(object): 14 | 15 | def __init__(self, msg: str) -> None: 16 | self.msg = msg.replace('{}', '{:.6f}s') 17 | self.tic = None 18 | return 19 | 20 | def __enter__(self) -> None: 21 | self.tic = time.time() 22 | return 23 | 24 | def __exit__(self, *args, **kwargs) -> None: 25 | toc = time.time() - self.tic 26 | print('\n' + self.msg.format(toc)) 27 | return 28 | 29 | 30 | def get_img(img_path: str) -> torch.Tensor: 31 | img = imageio.imread(img_path) 32 | img = np.transpose(img, (2, 0, 1)) 33 | img = torch.from_numpy(img) 34 | while img.dim() < 4: 35 | img.unsqueeze_(0) 36 | 37 | img = img.float() / 255 38 | return img 39 | 40 | def save_img(x: torch.Tensor, img_path: str) -> None: 41 | with torch.no_grad(): 42 | x = 255 * x 43 | x = x.round().clamp(min=0, max=255).byte() 44 | x = x.squeeze(0) 45 | 46 | x = x.cpu().numpy() 47 | x = np.transpose(x, (1, 2, 0)) 48 | imageio.imwrite(img_path, x) 49 | return 50 | -------------------------------------------------------------------------------- /src/data_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from bicubic_pytorch import core 5 | 6 | def get_data_loss(img_s, img_gen, data_loss_type, down_filter, args): 7 | criterionL1 = nn.L1Loss() 8 | 9 | if data_loss_type == 'adl': 10 | if down_filter is not None: 11 | padL = args.adl_ksize // 2 12 | padR = args.adl_ksize // 2 13 | if args.adl_ksize % 2 == 0: 14 | padL -= 1 15 | 16 | filtered_img_s = down_filter(F.pad(img_s,(padL,padR,padL,padR),mode='replicate')) 17 | down_filtered_img_s = F.interpolate(filtered_img_s, scale_factor=0.5, mode='nearest', 18 | recompute_scale_factor=False) 19 | return criterionL1(down_filtered_img_s, img_gen) 20 | else: 21 | # use lal for initial few epochs for stablizing 22 | return get_data_loss(img_s, img_gen, 'lfl', down_filter, args) 23 | 24 | elif data_loss_type == 'lfl': 25 | hr_filter = nn.AvgPool2d(kernel_size=args.box_size*2, stride=args.box_size*2) 26 | lr_filter = nn.AvgPool2d(kernel_size=args.box_size, stride=args.box_size) 27 | return criterionL1(lr_filter(img_gen), hr_filter(img_s)) 28 | 29 | elif data_loss_type == 'bic': 30 | return criterionL1(core.imresize(img_s, scale=0.5), img_gen) 31 | 32 | elif data_loss_type == 'gau': 33 | gau_filter = GaussianLoss(scale=int(args.scale), sigma=args.gaussian_sigma, 34 | kernel_size=args.gaussian_ksize, strided=(not args.gaussian_dense)).cuda() 35 | return gau_filter(img_gen, img_s) 36 | 37 | else: 38 | raise NotImplementedError('Not supported data loss type') 39 | 40 | 41 | 42 | class GaussianLoss(nn.Module): 43 | def __init__( 44 | self, 45 | n_colors: int=3, 46 | kernel_size: int=16, 47 | scale: int=2, 48 | sigma: float=2.0, 49 | strided: 
bool=True, 50 | distance: str='l1') -> None: 51 | 52 | super().__init__() 53 | kx = gaussian_kernel(kernel_size=kernel_size, scale=1, sigma=sigma) 54 | kx = to_4d(kx, n_colors) 55 | self.register_buffer('kx', kx) 56 | 57 | ky = gaussian_kernel( 58 | kernel_size=(scale * kernel_size), 59 | scale=scale, 60 | sigma=sigma, 61 | ) 62 | ky = to_4d(ky, n_colors) 63 | self.register_buffer('ky', ky) 64 | 65 | self.scale = scale 66 | self.strided = strided 67 | self.distance = distance 68 | 69 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 70 | loss = filter_loss( 71 | x, 72 | self.kx, 73 | y, 74 | self.ky, 75 | scale=self.scale, 76 | strided=self.strided, 77 | distance=self.distance, 78 | ) 79 | return loss 80 | 81 | 82 | def filter_loss( 83 | x: torch.Tensor, 84 | kx: torch.Tensor, 85 | y: torch.Tensor, 86 | ky: torch.Tensor, 87 | scale: int=2, 88 | strided: bool=True, 89 | distance: str='l1') -> torch.Tensor: 90 | 91 | wx = x.size(-1) 92 | wy = y.size(-1) 93 | # x should be smaller than y 94 | if wx >= wy: 95 | return filter_loss(y, ky, x, kx, strided=strided, scale=scale) 96 | 97 | if strided: 98 | sx = ky.size(-1) 99 | else: 100 | sx = 1 101 | 102 | sy = scale * sx 103 | x = F.conv2d(x, kx, stride=sx, padding=0) 104 | y = F.conv2d(y, ky, stride=sy, padding=0) 105 | 106 | if distance == 'l1': 107 | loss = F.l1_loss(x, y) 108 | elif distance == 'mse': 109 | loss = F.mse_loss(x, y) 110 | else: 111 | raise ValueError('{} loss is not supported!'.format(distance)) 112 | 113 | return loss 114 | 115 | def gaussian_kernel( 116 | kernel_size: int=16, 117 | scale: int=1, 118 | sigma: float=2.0) -> torch.Tensor: 119 | 120 | kernel_half = kernel_size // 2 121 | # Distance from the center point 122 | if kernel_size % 2 == 0: 123 | r = torch.linspace(-kernel_half + 0.5, kernel_half - 0.5, kernel_size) 124 | else: 125 | r = torch.linspace(-kernel_half, kernel_half, kernel_size) 126 | 127 | # Do not backpropagate through the kernel 128 | r.requires_grad = False 129 | r /= scale 130 | 131 | r = r.view(1, -1) 132 | r = r.repeat(kernel_size, 1) 133 | r = r ** 2 134 | r = r + r.t() 135 | 136 | exponent = -r / (2 * sigma**2) 137 | k = exponent.exp() 138 | k = k / k.sum() 139 | return k 140 | 141 | def to_4d(k: torch.Tensor, n_colors: int) -> torch.Tensor: 142 | with torch.no_grad(): 143 | k.unsqueeze_(0).unsqueeze_(0) 144 | k = k.repeat(n_colors, n_colors, 1, 1) 145 | e = torch.eye(n_colors, n_colors) 146 | e.unsqueeze_(-1).unsqueeze_(-1) 147 | k *= e 148 | 149 | return k -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import torch 4 | import random 5 | import pickle 6 | import imageio 7 | import torch.utils.data as data 8 | from PIL import Image, ImageFile 9 | from torchvision.transforms import Compose, Resize, RandomCrop, CenterCrop, RandomHorizontalFlip, RandomVerticalFlip, ToTensor, Normalize 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | 12 | class unpaired_dataset(data.Dataset): 13 | def __init__(self, args, phase='train'): 14 | if phase == 'train': 15 | self.dataroot = args.train_dataroot 16 | Source_type = args.source 17 | Target_type = args.target 18 | 19 | else: 20 | self.dataroot = args.test_dataroot 21 | Source_type = args.source 22 | Target_type = args.source 23 | 24 | self.args = args 25 | 26 | ## Source 27 | images_source = sorted(os.listdir(os.path.join(self.dataroot, Source_type))) 28 | self.images_source = 
[os.path.join(self.dataroot, Source_type, x) for x in images_source] 29 | ## Target 30 | images_target = sorted(os.listdir(os.path.join(self.dataroot, Target_type))) 31 | self.images_target = [os.path.join(self.dataroot, Target_type, x) for x in images_target] 32 | 33 | self.phase = phase 34 | self.binary = False 35 | 36 | print('\nphase: {}'.format(phase)) 37 | 38 | ## checking or making binary files to boost loading speed 39 | if not args.nobin and not os.path.exists(os.path.join(self.dataroot, 'bin')): 40 | os.mkdir(os.path.join(self.dataroot, 'bin')) 41 | if not args.nobin: 42 | if not os.path.exists(os.path.join(self.dataroot, 'bin', Source_type)): 43 | os.mkdir(os.path.join(self.dataroot, 'bin', Source_type)) 44 | print('no binary file for Source is detected') 45 | print('making binary for Source ...') 46 | for i in tqdm.tqdm(range(len(self.images_source))): 47 | f = os.path.join(self.dataroot, 'bin', Source_type, self.images_source[i].split('/')[-1].split('.')[0]+'.pt') 48 | with open(f, 'wb') as _f: 49 | pickle.dump(imageio.imread(self.images_source[i]), _f) 50 | print('Done') 51 | self.binary = True 52 | else: 53 | print('binary files for {} already exist'.format(Source_type)) 54 | self.binary = True 55 | 56 | if not os.path.exists(os.path.join(self.dataroot, 'bin', Target_type)): 57 | os.mkdir(os.path.join(self.dataroot, 'bin', Target_type)) 58 | print('no binary file for {} are detected'.format(Target_type)) 59 | print('making binary for {} ...'.format(Target_type)) 60 | for j in tqdm.tqdm(range(len(self.images_target))): 61 | f = os.path.join(self.dataroot, 'bin', Target_type, self.images_target[j].split('/')[-1].split('.')[0]+'.pt') 62 | with open(f, 'wb') as _f: 63 | pickle.dump(imageio.imread(self.images_target[j]), _f) 64 | print('Done') 65 | self.binary = True 66 | else: 67 | if phase == 'train': 68 | print('binary files for {} already exist'.format(Target_type)) 69 | self.binary = True 70 | else: 71 | print('do not use binary files') 72 | 73 | ## change base folder to bin if binary option is enabled 74 | if self.binary: 75 | images_source = sorted(os.listdir(os.path.join(self.dataroot, 'bin', Source_type))) 76 | images_target = sorted(os.listdir(os.path.join(self.dataroot, 'bin', Target_type))) 77 | self.images_source = [os.path.join(self.dataroot, 'bin', Source_type, x) for x in images_source] 78 | self.images_target = [os.path.join(self.dataroot, 'bin', Target_type, x) for x in images_target] 79 | 80 | self.images_source_size = len(self.images_source) 81 | self.images_target_size = len(self.images_target) 82 | 83 | if phase=='test': 84 | patches_source_size = len(self.images_source) 85 | patches_target_size = len(self.images_target) 86 | else: 87 | patches_source_size = 0 88 | patches_target_size = 0 89 | for i in range(len(self.images_source)): 90 | img_name = self.images_source[i] 91 | if self.binary: 92 | with open(img_name, 'rb') as _f: 93 | img = pickle.load(_f) 94 | img = Image.fromarray(img) 95 | else: 96 | img = Image.open(img_name).convert('RGB') 97 | 98 | patches_source_size += (img.size[0] // 192 ) * (img.size[1] // 192) * 0.75 # just hyper parameter 99 | 100 | for i in range(len(self.images_target)): 101 | img_name = self.images_target[i] 102 | if self.binary: 103 | with open(img_name, 'rb') as _f: 104 | img = pickle.load(_f) 105 | img = Image.fromarray(img) 106 | else: 107 | img = Image.open(img_name).convert('RGB') 108 | 109 | patches_target_size += (img.size[0] // 96) * (img.size[1] // 96) * 0.75 110 | 111 | 112 | ## since we are dealing with unpaired 
setting, 113 | ## we can not assure same dataset size between source and target domain. 114 | ## Therefore we just set dataset size as maximum value of two. 115 | self.dataset_size = int(max(patches_source_size, patches_target_size)) 116 | if phase == 'test': 117 | self.dataset_size = self.images_source_size 118 | 119 | if self.phase == 'train': 120 | transforms_source = [RandomCrop(args.patch_size_down)] 121 | transforms_target = [RandomCrop(args.patch_size_down//2)] 122 | if args.flip: 123 | transforms_source.append(RandomHorizontalFlip()) 124 | transforms_source.append(RandomVerticalFlip()) 125 | transforms_target.append(RandomHorizontalFlip()) 126 | transforms_target.append(RandomVerticalFlip()) 127 | else: 128 | transforms_source = [] 129 | transforms_target = [] 130 | 131 | transforms_source.append(ToTensor()) 132 | transforms_source.append(Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])) 133 | self.transforms_source = Compose(transforms_source) 134 | 135 | transforms_target.append(ToTensor()) 136 | transforms_target.append(Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])) 137 | self.transforms_target = Compose(transforms_target) 138 | 139 | if phase == 'train': 140 | print('Source: %d, Target: %d images'%(self.images_source_size, self.images_target_size)) 141 | else: 142 | print('Source: %d'%(self.images_source_size)) 143 | 144 | 145 | def __getitem__(self, index): 146 | index_source = index % self.images_source_size 147 | index_target = random.randint(0, self.images_target_size - 1) ## for randomness 148 | 149 | data_source, fn = self.load_img(self.images_source[index_source]) 150 | data_target, _ = self.load_img(self.images_target[index_target], domain='target') 151 | 152 | return data_source, data_target, fn 153 | 154 | def load_img(self, img_name, input_dim=3, domain='source'): 155 | ## loading images 156 | if self.binary: 157 | with open(img_name, 'rb') as _f: 158 | img = pickle.load(_f) 159 | img = Image.fromarray(img) 160 | else: 161 | img = Image.open(img_name).convert('RGB') 162 | fn = img_name.split('/')[-1] 163 | 164 | ## apply different transfomation along domain 165 | if domain == 'source': 166 | img = self.transforms_source(img) 167 | else: 168 | img = self.transforms_target(img) 169 | 170 | ## rotating 171 | rot = self.args.rot and random.random() < 0.5 172 | if rot: 173 | img = img.transpose(1,2) 174 | 175 | ## flipping 176 | flip_h = self.args.flip and random.random() < 0.5 177 | flip_v = self.args.flip and random.random() < 0.5 178 | if flip_h: 179 | img = torch.flip(img, [2]) 180 | if flip_v: 181 | img = torch.flip(img, [1]) 182 | 183 | return img, fn 184 | 185 | def __len__(self): 186 | if self.phase == 'train': 187 | return int( self.dataset_size * 2 ) # one epoch for two cycle of training dataset 188 | else: 189 | return self.dataset_size 190 | 191 | 192 | class paired_dataset(data.Dataset): # only for joint SR 193 | def __init__(self, args): 194 | self.dataroot = args.test_dataroot 195 | self.args = args 196 | 197 | if args.realsr: 198 | test_hr = args.test_lr 199 | else: 200 | if args.test_hr is None: 201 | raise NotImplementedError("test_hr set should be given") 202 | test_hr = args.test_hr 203 | 204 | ## HR 205 | images_hr = sorted(os.listdir(os.path.join(self.dataroot, test_hr))) 206 | images_hr = images_hr[int(args.test_range.split('-')[0])-1: int(args.test_range.split('-')[1]) ] 207 | self.images_hr = [os.path.join(self.dataroot, test_hr, x) for x in images_hr] 208 | ## LR 209 | images_lr = sorted(os.listdir(os.path.join(self.dataroot, 
args.test_lr))) 210 | images_lr = images_lr[int(args.test_range.split('-')[0])-1: int(args.test_range.split('-')[1]) ] 211 | self.images_lr = [os.path.join(self.dataroot, args.test_lr, x) for x in images_lr] 212 | 213 | self.images_hr_size = len(self.images_hr) 214 | self.images_lr_size = len(self.images_lr) 215 | 216 | assert(self.images_hr_size == self.images_lr_size) 217 | 218 | transforms = [] 219 | transforms.append(ToTensor()) 220 | self.transforms = Compose(transforms) 221 | 222 | 223 | print('\njoint training option is enabled') 224 | print('HR set: {}, LR set: {}'.format(args.test_hr, args.test_lr)) 225 | print('number of test images for SR : %d images' %(self.images_hr_size)) 226 | 227 | 228 | def __getitem__(self, index): 229 | data_hr, fn_hr = self.load_img(self.images_hr[index]) 230 | data_lr, fn_lr = self.load_img(self.images_lr[index], domain='lr') 231 | 232 | return data_hr, data_lr, fn_lr 233 | 234 | def load_img(self, img_name, input_dim=3, domain='hr'): 235 | ## loading images 236 | img = Image.open(img_name).convert('RGB') 237 | fn = img_name.split('/')[-1] 238 | 239 | ## apply transfomation 240 | img = self.transforms(img) 241 | 242 | ## rotating and flipping 243 | rot = self.args.rot and random.random() < 0.5 244 | flip_h = self.args.flip and random.random() < 0.5 245 | flip_v = self.args.flip and random.random() < 0.5 246 | if rot: 247 | img = img.transpose(1,2) 248 | if flip_h: 249 | img = torch.flip(img, [2]) 250 | if flip_v: 251 | img = torch.flip(img, [1]) 252 | 253 | return img, fn 254 | 255 | def __len__(self): 256 | return self.images_hr_size -------------------------------------------------------------------------------- /src/demo.sh: -------------------------------------------------------------------------------- 1 | # retreiving downsampling kernel between HR image dataset 'Source' and LR image dataset 'Target'. 2 | # generate downsampled version of 'Source' image from retrieved downsampling kernel saved in './experiments/save_name/down_results/'. 3 | CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --name save_name --make_down 4 | 5 | ### (Optional) joint trainig with SR model(default: edsr) 6 | ## edsr style training 7 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --name save_name --joint 8 | 9 | ## esrgan style training 10 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --name save_name --joint --training_type esrgan 11 | 12 | ## scale x4 training with RRDB model 13 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --name save_name --joint --sr_model rrdb --scale 4 14 | 15 | ## specify sr training duration 16 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --name save_name --joint --epochs_sr_start 41 --epochs_sr_end 80 17 | 18 | ## you can test your sr model on validation set if you have. 19 | ## note that validation set 'filename_lr' and 'filename_hr' should be located at './dataset/'. 20 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --name save_name --joint --test_lr filename_lr --test_hr filename_hr 21 | 22 | ## you may only have target LR images, as common case for real-world target. 23 | ## in that case, you can visualize sr_results and check them in './experiments/save_name/sr_results/'. 
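## (note: with '--realsr', paired_dataset in dataset.py reuses the '--test_lr' folder in place of '--test_hr', so no ground-truth HR images are required)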
24 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --name save_name --joint --test_lr filename_lr --realsr --save_results 25 | 26 | 27 | ## more demo scripts used for main paper are in below 28 | 29 | ## x2 scale sr for synthetic DIV2K target 30 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --joint --test_lr filname_lr --test_hr filename_hr --save_results --sr_model edsr --scale 2 --epochs_sr_start 41 --training_type edsr --batch_size 24 --name synthetic_div2k-edsrx2 31 | 32 | ## x4 scale sr for synthetic DIV2K target 33 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --joint --test_lr filename_lr --test_hr filename_hr --save_results --sr_model edsr --scale 4 --epochs_sr_start 41 --training_type edsr --batch_size 24 --name synthetic_div2k-edsrx4 34 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --joint --test_lr filename_lr --test_hr filename_hr --save_results --sr_model rrdb --scale 4 --epochs_sr_start 61 --training_type edsr --batch_size 24 --pretrain_sr model/pretrain/rrdb_x4-9d40f7f7.pth --name synthetic_div2k-rrdbx4 35 | 36 | ## x2 scale sr for realsr dataset 37 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --joint --test_lr filename_lr --test_hr filename_hr --save_results --sr_model edsr --scale 2 --epochs_sr_start 41 --training_type edsr --batch_size 24 --name realsr-edsrx2 --noise 38 | 39 | ## x4 scale sr for realsr dataset 40 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --joint --test_lr filename_lr --test_hr filename_hr --save_results --sr_model edsr --scale 4 --epochs_sr_start 41 --training_type edsr --batch_size 24 --noise --name realsr-edsrx4 --make_down 41 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --joint --test_lr filename_lr --test_hr filename_hr --save_results --sr_model rrdb --scale 4 --epochs_sr_start 61 --training_type edsr --batch_size 24 --pretrain_sr model/pretrain/rrdb_x4-9d40f7f7.pth --noise --name realsr-rrdbx4 --make_down 42 | 43 | ## x2 scale sr for dped dataset 44 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --joint --test_lr Target --realsr --save_results --sr_model edsr --scale 2 --epochs_sr_start 41 --training_type edsr --batch_size 24 --name dped-edsrx2-edsr 45 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --joint --test_lr Target --realsr --save_results --sr_model edsr --scale 2 --epochs_sr_start 41 --training_type esrgan --batch_size 24 --name dped-edsrx2-esrgan 46 | 47 | ## x4 scale sr for dped dataset 48 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --joint --test_lr Target --realsr --save_results --sr_model edsr --scale 4 --epochs_sr_start 61 --training_type edsr --batch_size 24 --chop --name dped-edsrx4-edsr 49 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --joint --test_lr Target --realsr --save_results --sr_model edsr --scale 4 --epochs_sr_start 61 --training_type esrgan --batch_size 24 --chop --name dped-edsrx4-esrgan 50 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --joint --test_lr Target --realsr --save_results --sr_model rrdb --scale 4 --epochs_sr_start 61 --training_type edsr --batch_size 8 --chop --pretrain_sr model/pretrain/rrdb_x4-9d40f7f7.pth --name dped-rrdbx4-edsr 51 | # CUDA_VISIBLE_DEVICES=0 python train.py --source Source --target Target --joint --test_lr Target --realsr --save_results 
--sr_model rrdb --scale 4 --epochs_sr_start 61 --training_type esrgan --batch_size 8 --chop --pretrain_sr model/pretrain/rrdb_x4-9d40f7f7.pth --name dped-rrdbx4-esrgan #--con_w 0.01 -------------------------------------------------------------------------------- /src/filters.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | def gaussian_kernel(sigma: float=1) -> torch.Tensor: 6 | kernel_width = math.ceil(3 * sigma) 7 | r = torch.linspace(-kernel_width, kernel_width, 2 * kernel_width + 1) 8 | r = r.view(1, -1) 9 | r = r.repeat(2 * kernel_width + 1, 1) 10 | r = r**2 11 | # Squared distance from origin 12 | r = r + r.t() 13 | 14 | exp = -r / (2 * sigma**2) 15 | coeff = exp.exp() 16 | coeff = coeff / coeff.sum() 17 | return coeff 18 | 19 | def filtering(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor: 20 | k = k.to(x.device) 21 | kh, kw = k.size() 22 | if x.dim() == 4: 23 | c = x.size(1) 24 | k = k.view(1, 1, kh, kw) 25 | k = k.repeat(c, c, 1, 1) 26 | e = torch.eye(c, c) 27 | e = e.to(x.device) 28 | e = e.view(c, c, 1, 1) 29 | k *= e 30 | else: 31 | raise ValueError('x.dim() == {}! It should be 3 or 4.'.format(x.dim())) 32 | 33 | x = F.pad(x, (kh // 2, kh // 2, kw // 2, kw // 2), mode='replicate') 34 | y = F.conv2d(x, k, padding=0) 35 | return y 36 | 37 | def gaussian_filtering(x: torch.Tensor, sigma: float=1) -> torch.Tensor: 38 | k = gaussian_kernel(sigma=sigma) 39 | y = filtering(x, k) 40 | return y 41 | 42 | def find_kernel( 43 | x: torch.Tensor, 44 | y: torch.Tensor, 45 | scale: int, 46 | k: int, 47 | max_patches: int=None, 48 | threshold: float=1e-5) -> torch.Tensor: 49 | ''' 50 | Args: 51 | x (torch.Tensor): (B x C x H x W or C x H x W) A high-resolution image. 52 | y (torch.Tensor): (B x C x H x W or C x H x W) A low-resolution image. 53 | scale (int): Downsampling scale. 54 | k (int): Kernel size. 55 | max_patches (int, optional): Maximum number of patches to use. 56 | If not specified, use minimum number of patches. 57 | If set to -1, use all possible patches. 58 | You will get a better result with more patches. 59 | 60 | threshold (float, optional): Ignore values smaller than the threshold. 61 | 62 | Return: 63 | torch.Tensor: (k x k) The calculated kernel. 
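    Example:
        A minimal sketch (the image paths are hypothetical); it mirrors the
        commented-out demo at the bottom of this file.

        >>> from PIL import Image
        >>> from torchvision.transforms import functional as TF
        >>> x = TF.to_tensor(Image.open('hr.png'))     # high-resolution image
        >>> y = TF.to_tensor(Image.open('lr_x2.png'))  # low-resolution image
        >>> kernel = find_kernel(x, y, scale=2, k=20, max_patches=-1)
        >>> kernel.shape
        torch.Size([20, 20])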
64 | ''' 65 | if x.dim() == 3: 66 | x = x.unsqueeze(0) 67 | 68 | if y.dim() == 3: 69 | y = y.unsqueeze(0) 70 | 71 | bx, cx, hx, wx = x.size() 72 | by, cy, hy, wy = y.size() 73 | 74 | # If y is larger than x 75 | if hx < hy: 76 | return find_kernel(y, x) 77 | 78 | # We convert RGB images to grayscale 79 | def luminance(rgb): 80 | coeff = rgb.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1) 81 | l = torch.sum(coeff * rgb, dim=1, keepdim=True) 82 | return l 83 | 84 | if cx == 3: 85 | x = luminance(x) 86 | if cy == 3: 87 | y = luminance(y) 88 | 89 | k_half = k // 2 90 | crop_y = math.ceil((k_half - 1) / 2) 91 | if crop_y > 0: 92 | y = y[..., crop_y:-crop_y, crop_y:-crop_y] 93 | hy_crop = hy - 2 * crop_y 94 | wy_crop = wy - 2 * crop_y 95 | 96 | # Flatten 97 | y = y.reshape(by, -1) 98 | 99 | hx_target = k + scale * (hy_crop - 1) 100 | crop_x = (hx - hx_target) // 2 101 | if crop_x > 0: 102 | x = x[..., crop_x:-crop_x, crop_x:-crop_x] 103 | hx_crop = hx - 2 * crop_x 104 | wx_crop = wx - 2 * crop_x 105 | 106 | x = F.unfold(x, k, stride=scale) 107 | x_spatial = x.view(bx, k, k, -1) 108 | 109 | ''' 110 | Gradient-based sampling 111 | Caculate the gradient to determine which patches to use 112 | ''' 113 | gx = x.new_zeros(1, k, k, 1) 114 | gx[:, k_half - 1, k_half - 1, :] = -1 115 | gx[:, k_half - 1, k_half, :] = 1 116 | grad_x = x_spatial * gx 117 | grad_x = grad_x.view(bx, k * k, -1) 118 | 119 | gy = x.new_zeros(1, k, k, 1) 120 | gy[:, k_half - 1, k_half - 1, :] = -1 121 | gy[:, k_half, k_half - 1, :] = 1 122 | grad_y = x_spatial * gy 123 | grad_y = grad_y.view(by, k * k, -1) 124 | 125 | grad = grad_x.sum(1).pow(2) + grad_y.sum(1).pow(2) 126 | grad_order = grad.view(-1).argsort(dim=-1, descending=True) 127 | 128 | # We need at least k^2 patches 129 | if max_patches is None: 130 | max_patches = k**2 131 | elif max_patches == -1: 132 | max_patches = len(grad_order) 133 | else: 134 | max_patches = min(max(k**2, max_patches), len(grad_order)) 135 | 136 | grad_order = grad_order[:max_patches].view(-1) 137 | ''' 138 | Increase precision for numerical accuracy. 139 | You will get wrong results with FLOAT32!!! 
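    (The kernel is recovered from the normal equations below, kernel = y x^T (x x^T)^{-1};
    the Gram matrix x x^T can be poorly conditioned, so float32 round-off may dominate its inverse.)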
140 | ''' 141 | # We use only one sample in the given batch 142 | x_sampled = x[0, ..., grad_order].double() 143 | x_t = x_sampled.t() 144 | 145 | y_sampled = y[0, ..., grad_order].double() 146 | y = y_sampled.unsqueeze(0) 147 | 148 | kernel = torch.matmul(y, x_t) 149 | kernel_c = torch.matmul(x_sampled, x_t) 150 | kernel_c = torch.inverse(kernel_c) 151 | kernel = torch.matmul(kernel, kernel_c) 152 | # For debugging 153 | #from scipy import io 154 | #io.savemat('tensor.mat', {'x_t': x_t.numpy(), 'y': y.numpy(), 'kernel': kernel.numpy()}) 155 | 156 | # Kernel thresholding and normalization 157 | kernel = kernel * (kernel.abs() > threshold).double() 158 | #kernel = kernel / kernel.sum() 159 | kernel = kernel.view(k, k).float() 160 | return kernel 161 | 162 | ''' 163 | if __name__ == '__main__': 164 | import numpy as np 165 | import imageio 166 | 167 | a = imageio.imread('../../lab/baby.png') 168 | a = np.transpose(a, (2, 0, 1)) 169 | a = torch.from_numpy(a).unsqueeze(0).float() 170 | b = gaussian_filtering(a, sigma=0.3) 171 | b = b.round().clamp(min=0, max=255).byte() 172 | b = b.squeeze(0) 173 | b = b.numpy() 174 | b = np.transpose(b, (1, 2, 0)) 175 | imageio.imwrite('../../lab/baby_filtered.png', b) 176 | 177 | #x = torch.arange(64).view(1, 1, 8, 8).float() 178 | #y = torch.arange(16).view(1, 1, 4, 4).float() 179 | from PIL import Image 180 | from torchvision.transforms import functional as TF 181 | x = Image.open('../../../dataset/DIV2K/DIV2K_train_HR/0001.png') 182 | x = TF.to_tensor(x) 183 | y = Image.open('DIV2K_train_LR_d104/X2/0001x2.png') 184 | y = TF.to_tensor(y) 185 | k = 20 186 | kernel = find_kernel(x, y, scale=2, k=k, max_patches=-1) 187 | 188 | kernel /= kernel.abs().max() 189 | k_pos = kernel * (kernel > 0).float() 190 | k_neg = kernel * (kernel < 0).float() 191 | k_rgb = torch.stack([-k_neg, k_pos, torch.zeros_like(k_pos)], dim=0) 192 | pil = TF.to_pil_image(k_rgb.cpu()) 193 | pil = pil.resize((k * 20, k * 20), resample=Image.NEAREST) 194 | pil.save('kernel.png') 195 | pil.show() 196 | ''' 197 | 198 | -------------------------------------------------------------------------------- /src/model/Discriminator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | def SpectralConv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True): 6 | return spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)) 7 | 8 | 9 | class Discriminator(nn.Module): 10 | def __init__(self, args, num_conv_block=4): 11 | super(Discriminator, self).__init__() 12 | block = [] 13 | 14 | in_channels = 3 15 | out_channels = 64 16 | 17 | ConvKind = nn.Conv2d 18 | #if args.dis_spectral_norm: 19 | # ConvKind = SpectralConv2d 20 | 21 | for _ in range(num_conv_block): 22 | block += [nn.ReflectionPad2d(1), 23 | ConvKind(in_channels, out_channels, 3), 24 | nn.LeakyReLU(), 25 | nn.BatchNorm2d(out_channels)] 26 | in_channels = out_channels 27 | 28 | block += [nn.ReflectionPad2d(1), 29 | ConvKind(in_channels, out_channels, 3, 2), 30 | nn.LeakyReLU()] 31 | out_channels *= 2 32 | 33 | out_channels //= 2 34 | in_channels = out_channels 35 | 36 | block += [ConvKind(in_channels, out_channels, 3), 37 | nn.LeakyReLU(0.2), 38 | ConvKind(out_channels, out_channels, 3)] 39 | 40 | self.feature_extraction = nn.Sequential(*block) 41 | 42 | self.classification = nn.Sequential( 43 | nn.Linear(8192, 100), 44 | nn.Linear(100, 
1) 45 | ) 46 | 47 | def forward(self, x): 48 | x = self.feature_extraction(x) 49 | x = x.view(x.size(0), -1) 50 | x = self.classification(x) 51 | return x 52 | 53 | 54 | #################################################################### 55 | #--------------------- Spectral Normalization --------------------- 56 | # This part of code is copied from pytorch master branch (0.5.0) 57 | #################################################################### 58 | class SpectralNorm(object): 59 | def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): 60 | self.name = name 61 | self.dim = dim 62 | if n_power_iterations <= 0: 63 | raise ValueError('Expected n_power_iterations to be positive, but ' 64 | 'got n_power_iterations={}'.format(n_power_iterations)) 65 | self.n_power_iterations = n_power_iterations 66 | self.eps = eps 67 | def compute_weight(self, module): 68 | weight = getattr(module, self.name + '_orig') 69 | u = getattr(module, self.name + '_u') 70 | weight_mat = weight 71 | if self.dim != 0: 72 | # permute dim to front 73 | weight_mat = weight_mat.permute(self.dim, 74 | *[d for d in range(weight_mat.dim()) if d != self.dim]) 75 | height = weight_mat.size(0) 76 | weight_mat = weight_mat.reshape(height, -1) 77 | with torch.no_grad(): 78 | for _ in range(self.n_power_iterations): 79 | v = F.normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps) 80 | u = F.normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps) 81 | sigma = torch.dot(u, torch.matmul(weight_mat, v)) 82 | weight = weight / sigma 83 | return weight, u 84 | def remove(self, module): 85 | weight = getattr(module, self.name) 86 | delattr(module, self.name) 87 | delattr(module, self.name + '_u') 88 | delattr(module, self.name + '_orig') 89 | module.register_parameter(self.name, torch.nn.Parameter(weight)) 90 | def __call__(self, module, inputs): 91 | if module.training: 92 | weight, u = self.compute_weight(module) 93 | setattr(module, self.name, weight) 94 | setattr(module, self.name + '_u', u) 95 | else: 96 | r_g = getattr(module, self.name + '_orig').requires_grad 97 | getattr(module, self.name).detach_().requires_grad_(r_g) 98 | 99 | @staticmethod 100 | def apply(module, name, n_power_iterations, dim, eps): 101 | fn = SpectralNorm(name, n_power_iterations, dim, eps) 102 | weight = module._parameters[name] 103 | height = weight.size(dim) 104 | u = F.normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps) 105 | delattr(module, fn.name) 106 | module.register_parameter(fn.name + "_orig", weight) 107 | module.register_buffer(fn.name, weight.data) 108 | module.register_buffer(fn.name + "_u", u) 109 | module.register_forward_pre_hook(fn) 110 | return fn 111 | 112 | def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None): 113 | if dim is None: 114 | if isinstance(module, (torch.nn.ConvTranspose1d, 115 | torch.nn.ConvTranspose2d, 116 | torch.nn.ConvTranspose3d)): 117 | dim = 1 118 | else: 119 | dim = 0 120 | SpectralNorm.apply(module, name, n_power_iterations, dim, eps) 121 | return module 122 | 123 | def remove_spectral_norm(module, name='weight'): 124 | for k, hook in module._forward_pre_hooks.items(): 125 | if isinstance(hook, SpectralNorm) and hook.name == name: 126 | hook.remove(module) 127 | del module._forward_pre_hooks[k] 128 | return module 129 | raise ValueError("spectral_norm of '{}' not found in {}".format(name, module)) -------------------------------------------------------------------------------- /src/model/ESRGAN.py: 
-------------------------------------------------------------------------------- 1 | from model.block import * 2 | 3 | 4 | class ESRGAN(nn.Module): 5 | def __init__(self, in_channels, out_channels, nf=64, gc=32, scale_factor=4, n_basic_block=23): 6 | super(ESRGAN, self).__init__() 7 | 8 | self.conv1 = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(in_channels, nf, 3), nn.ReLU()) 9 | 10 | basic_block_layer = [] 11 | 12 | for _ in range(n_basic_block): 13 | basic_block_layer += [ResidualInResidualDenseBlock(nf, gc)] 14 | 15 | self.basic_block = nn.Sequential(*basic_block_layer) 16 | 17 | self.conv2 = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(nf, nf, 3), nn.ReLU()) 18 | self.upsample = upsample_block(nf, scale_factor=scale_factor) 19 | self.conv3 = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(nf, nf, 3), nn.ReLU()) 20 | self.conv4 = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(nf, out_channels, 3), nn.ReLU()) 21 | 22 | def forward(self, x): 23 | x1 = self.conv1(x) 24 | x = self.basic_block(x1) 25 | x = self.conv2(x) 26 | x = self.upsample(x + x1) 27 | x = self.conv3(x) 28 | x = self.conv4(x) 29 | return x 30 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel as P 7 | import torch.utils.model_zoo 8 | 9 | class Model(nn.Module): 10 | def __init__(self, args): 11 | super(Model, self).__init__() 12 | self.scale = args.scale 13 | self.idx_scale = 0 14 | self.input_large = (args.sr_model == 'VDSR') 15 | self.self_ensemble = False 16 | self.chop = args.chop 17 | self.precision = args.precision 18 | self.cpu = False #args.cpu 19 | self.device = torch.device(args.gpu) #'cpu' if args.cpu else 'cuda') 20 | self.n_GPUs = 1 #args.n_GPUs 21 | self.save_models = True 22 | 23 | module = import_module('model.' 
+ args.sr_model.lower()) 24 | self.model = module.make_model(args).to(self.device) 25 | if self.precision == 'half': 26 | self.model.half() 27 | 28 | 29 | def forward(self, x, idx_scale): 30 | self.idx_scale = idx_scale 31 | if hasattr(self.model, 'set_scale'): 32 | self.model.set_scale(idx_scale) 33 | 34 | if self.training: 35 | if self.n_GPUs > 1: 36 | return P.data_parallel(self.model, x, range(self.n_GPUs)) 37 | else: 38 | return self.model(x) 39 | else: 40 | if self.chop: 41 | forward_function = self.forward_chop 42 | else: 43 | forward_function = self.model.forward 44 | 45 | if self.self_ensemble: 46 | return self.forward_x8(x, forward_function=forward_function) 47 | else: 48 | return forward_function(x) 49 | 50 | def save(self, apath, epoch, is_best=False, save_specific_epoch=False): 51 | save_dirs = [os.path.join(apath, 'model_latest.pt')] 52 | 53 | if is_best: 54 | save_dirs.append(os.path.join(apath, 'model_best.pt')) 55 | if save_specific_epoch: 56 | save_dirs.append(os.path.join(apath, 'model_{:03d}.pt'.format(epoch))) 57 | 58 | for s in save_dirs: 59 | torch.save(self.model.state_dict(), s) 60 | 61 | def load(self, apath, pre_train='', resume=-1, cpu=False): 62 | load_from = None 63 | kwargs = {} 64 | if cpu: 65 | kwargs = {'map_location': lambda storage, loc: storage} 66 | 67 | if resume == -1: 68 | load_from = torch.load( 69 | os.path.join(apath, 'model_latest.pt'), 70 | **kwargs 71 | ) 72 | print('load model from : ', os.path.join(apath, 'model_latest.pt')) 73 | elif resume == 0: 74 | if pre_train == 'download': 75 | print('Download the model') 76 | dir_model = os.path.join('..', 'models') 77 | os.makedirs(dir_model, exist_ok=True) 78 | load_from = torch.utils.model_zoo.load_url( 79 | self.model.url, 80 | model_dir=dir_model, 81 | **kwargs 82 | ) 83 | elif pre_train: 84 | print('Load the model from {}'.format(pre_train)) 85 | load_from = torch.load(pre_train, **kwargs) 86 | else: 87 | load_from = torch.load( 88 | os.path.join(apath, 'model_{}.pt'.format(resume)), 89 | **kwargs 90 | ) 91 | 92 | if load_from: 93 | self.model.load_state_dict(load_from, strict=False) 94 | 95 | def forward_chop(self, *args, shave=10, min_size=1000000): 96 | scale = 1 if self.input_large else self.scale[self.idx_scale] 97 | n_GPUs = min(self.n_GPUs, 4) 98 | # height, width 99 | h, w = args[0].size()[-2:] 100 | 101 | top = slice(0, h//2 + shave) 102 | bottom = slice(h - h//2 - shave, h) 103 | left = slice(0, w//2 + shave) 104 | right = slice(w - w//2 - shave, w) 105 | x_chops = [torch.cat([ 106 | a[..., top, left], 107 | a[..., top, right], 108 | a[..., bottom, left], 109 | a[..., bottom, right] 110 | ]) for a in args] 111 | 112 | y_chops = [] 113 | if h * w < 4 * min_size: 114 | for i in range(0, 4, n_GPUs): 115 | x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops] 116 | y = P.data_parallel(self.model, *x, range(n_GPUs)) 117 | if not isinstance(y, list): y = [y] 118 | if not y_chops: 119 | y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y] 120 | else: 121 | for y_chop, _y in zip(y_chops, y): 122 | y_chop.extend(_y.chunk(n_GPUs, dim=0)) 123 | else: 124 | for p in zip(*x_chops): 125 | y = self.forward_chop(*p, shave=shave, min_size=min_size) 126 | if not isinstance(y, list): y = [y] 127 | if not y_chops: 128 | y_chops = [[_y] for _y in y] 129 | else: 130 | for y_chop, _y in zip(y_chops, y): y_chop.append(_y) 131 | 132 | h *= int(scale) 133 | w *= int(scale) 134 | top = slice(0, h//2) 135 | bottom = slice(h - h//2, h) 136 | bottom_r = slice(h//2 - h, None) 137 | left = slice(0, w//2) 138 | 
right = slice(w - w//2, w) 139 | right_r = slice(w//2 - w, None) 140 | 141 | # batch size, number of color channels 142 | b, c = y_chops[0][0].size()[:-2] 143 | y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops] 144 | for y_chop, _y in zip(y_chops, y): 145 | _y[..., top, left] = y_chop[0][..., top, left] 146 | _y[..., top, right] = y_chop[1][..., top, right_r] 147 | _y[..., bottom, left] = y_chop[2][..., bottom_r, left] 148 | _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r] 149 | 150 | if len(y) == 1: y = y[0] 151 | 152 | return y 153 | 154 | def forward_x8(self, *args, forward_function=None): 155 | def _transform(v, op): 156 | if self.precision != 'single': v = v.float() 157 | 158 | v2np = v.data.cpu().numpy() 159 | if op == 'v': 160 | tfnp = v2np[:, :, :, ::-1].copy() 161 | elif op == 'h': 162 | tfnp = v2np[:, :, ::-1, :].copy() 163 | elif op == 't': 164 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 165 | 166 | ret = torch.Tensor(tfnp).to(self.device) 167 | if self.precision == 'half': ret = ret.half() 168 | 169 | return ret 170 | 171 | list_x = [] 172 | for a in args: 173 | x = [a] 174 | for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x]) 175 | 176 | list_x.append(x) 177 | 178 | list_y = [] 179 | for x in zip(*list_x): 180 | y = forward_function(*x) 181 | if not isinstance(y, list): y = [y] 182 | if not list_y: 183 | list_y = [[_y] for _y in y] 184 | else: 185 | for _list_y, _y in zip(list_y, y): _list_y.append(_y) 186 | 187 | for _list_y in list_y: 188 | for i in range(len(_list_y)): 189 | if i > 3: 190 | _list_y[i] = _transform(_list_y[i], 't') 191 | if i % 4 > 1: 192 | _list_y[i] = _transform(_list_y[i], 'h') 193 | if (i % 4) % 2 == 1: 194 | _list_y[i] = _transform(_list_y[i], 'v') 195 | 196 | y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y] 197 | if len(y) == 1: y = y[0] 198 | 199 | return y 200 | -------------------------------------------------------------------------------- /src/model/__pycache__/Discriminator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/model/__pycache__/Discriminator.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/model/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/common.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/model/__pycache__/common.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/common_rrdb.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/model/__pycache__/common_rrdb.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/edsr.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/model/__pycache__/edsr.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/rrdb.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaehaKim97/Adaptive-Downsampling-Model/100058d9af5132dcedfbd9b45e0e7c9f5ad38cc6/src/model/__pycache__/rrdb.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class ResidualDenseBlock(nn.Module): 6 | def __init__(self, nf, gc=32, res_scale=0.2): 7 | super(ResidualDenseBlock, self).__init__() 8 | self.layer1 = nn.Sequential(nn.Conv2d(nf + 0 * gc, gc, 3, padding=1, bias=True), nn.LeakyReLU()) 9 | self.layer2 = nn.Sequential(nn.Conv2d(nf + 1 * gc, gc, 3, padding=1, bias=True), nn.LeakyReLU()) 10 | self.layer3 = nn.Sequential(nn.Conv2d(nf + 2 * gc, gc, 3, padding=1, bias=True), nn.LeakyReLU()) 11 | self.layer4 = nn.Sequential(nn.Conv2d(nf + 3 * gc, gc, 3, padding=1, bias=True), nn.LeakyReLU()) 12 | self.layer5 = nn.Sequential(nn.Conv2d(nf + 4 * gc, nf, 3, padding=1, bias=True), nn.LeakyReLU()) 13 | 14 | self.res_scale = res_scale 15 | 16 | def forward(self, x): 17 | layer1 = self.layer1(x) 18 | layer2 = self.layer2(torch.cat((x, layer1), 1)) 19 | layer3 = self.layer3(torch.cat((x, layer1, layer2), 1)) 20 | layer4 = self.layer4(torch.cat((x, layer1, layer2, layer3), 1)) 21 | layer5 = self.layer5(torch.cat((x, layer1, layer2, layer3, layer4), 1)) 22 | return layer5.mul(self.res_scale) + x 23 | 24 | 25 | class ResidualInResidualDenseBlock(nn.Module): 26 | def __init__(self, nf, gc=32, res_scale=0.2): 27 | super(ResidualInResidualDenseBlock, self).__init__() 28 | self.layer1 = ResidualDenseBlock(nf, gc) 29 | self.layer2 = ResidualDenseBlock(nf, gc) 30 | self.layer3 = ResidualDenseBlock(nf, gc, ) 31 | self.res_scale = res_scale 32 | 33 | def forward(self, x): 34 | out = self.layer1(x) 35 | out = self.layer2(out) 36 | out = self.layer3(out) 37 | return out.mul(self.res_scale) + x 38 | 39 | 40 | def upsample_block(nf, scale_factor=2): 41 | block = [] 42 | for _ in range(scale_factor//2): 43 | block += [ 44 | nn.Conv2d(nf, nf * (2 ** 2), 1), 45 | nn.PixelShuffle(2), 46 | nn.ReLU() 47 | ] 48 | 49 | return nn.Sequential(*block) 50 | 51 | -------------------------------------------------------------------------------- /src/model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 8 | return nn.Conv2d( 9 | in_channels, out_channels, kernel_size, 10 | padding=(kernel_size//2), bias=bias) 11 | 12 | class MeanShift(nn.Conv2d): 13 | def __init__( 14 | self, rgb_range, 15 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 16 | 17 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 18 | std = torch.Tensor(rgb_std) 19 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 20 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 21 | for p in self.parameters(): 22 | p.requires_grad 
= False 23 | 24 | class BasicBlock(nn.Sequential): 25 | def __init__( 26 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 27 | bn=True, act=nn.ReLU(True)): 28 | 29 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 30 | if bn: 31 | m.append(nn.BatchNorm2d(out_channels)) 32 | if act is not None: 33 | m.append(act) 34 | 35 | super(BasicBlock, self).__init__(*m) 36 | 37 | class ResBlock(nn.Module): 38 | def __init__( 39 | self, conv, n_feats, kernel_size, 40 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 41 | 42 | super(ResBlock, self).__init__() 43 | m = [] 44 | for i in range(2): 45 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 46 | if bn: 47 | m.append(nn.BatchNorm2d(n_feats)) 48 | if i == 0: 49 | m.append(act) 50 | 51 | self.body = nn.Sequential(*m) 52 | self.res_scale = res_scale 53 | 54 | def forward(self, x): 55 | res = self.body(x).mul(self.res_scale) 56 | res += x 57 | 58 | return res 59 | 60 | class Upsampler(nn.Sequential): 61 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 62 | 63 | scale = int(scale) # default type of scale is string 64 | m = [] 65 | if scale == 1: # Is scale = 2^n? 66 | for _ in range(1): 67 | m.append(conv(n_feats, n_feats, 3, bias))# 4 * n_feats, 3, bias)) 68 | #m.append(nn.PixelShuffle(2)) 69 | if bn: 70 | m.append(nn.BatchNorm2d(n_feats)) 71 | if act == 'relu': 72 | m.append(nn.ReLU(True)) 73 | elif act == 'prelu': 74 | m.append(nn.PReLU(n_feats)) 75 | 76 | elif (scale & (scale - 1)) == 0: # Is scale = 2^n? 77 | for _ in range(int(math.log(scale, 2))): 78 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 79 | m.append(nn.PixelShuffle(2)) 80 | if bn: 81 | m.append(nn.BatchNorm2d(n_feats)) 82 | if act == 'relu': 83 | m.append(nn.ReLU(True)) 84 | elif act == 'prelu': 85 | m.append(nn.PReLU(n_feats)) 86 | 87 | elif scale == 3: 88 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 89 | m.append(nn.PixelShuffle(3)) 90 | if bn: 91 | m.append(nn.BatchNorm2d(n_feats)) 92 | if act == 'relu': 93 | m.append(nn.ReLU(True)) 94 | elif act == 'prelu': 95 | m.append(nn.PReLU(n_feats)) 96 | else: 97 | raise NotImplementedError 98 | 99 | super(Upsampler, self).__init__(*m) 100 | 101 | -------------------------------------------------------------------------------- /src/model/common_rrdb.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import typing 4 | 5 | #from misc import module_utils 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional 10 | from torch.nn import init 11 | from torchvision import models 12 | 13 | 14 | def default_conv( 15 | in_channels: int, 16 | out_channels: int, 17 | kernel_size: int, 18 | stride: int=1, 19 | padding: typing.Optional[int]=None, 20 | bias=True, 21 | padding_mode: str='zeros'): 22 | 23 | if padding is None: 24 | padding = (kernel_size - 1) // 2 25 | 26 | conv = nn.Conv2d( 27 | in_channels, 28 | out_channels, 29 | kernel_size, 30 | stride=stride, 31 | padding=padding, 32 | bias=bias, 33 | padding_mode=padding_mode, 34 | ) 35 | return conv 36 | 37 | 38 | def get_model(m, cfg, *args, make=True, conv=default_conv, **kwargs): 39 | ''' 40 | Automatically find a class implementation and instantiate it. 41 | 42 | Args: 43 | m (str): Name of the module. 44 | cfg (Namespace): Global configurations. 45 | args (list): Additional arguments for the model class. 46 | make (bool, optional): If set to False, return model class itself. 
47 | conv 48 | kwargs (dict): Additional keyword arguments for the model class. 49 | ''' 50 | model_class = module_utils.find_representative(m) 51 | if model_class is not None: 52 | if hasattr(model_class, 'get_kwargs'): 53 | model_kwargs = model_class.get_kwargs(cfg, conv=conv) 54 | else: 55 | model_kwargs = kwargs 56 | 57 | if make: 58 | return model_class(*args, **model_kwargs) 59 | else: 60 | return model_class 61 | else: 62 | raise NotImplementedError('The model class is not implemented!') 63 | 64 | 65 | def model_class(model_cls, cfg=None, make=True, conv=default_conv): 66 | if make and hasattr(model_cls, 'get_kwargs'): 67 | return model_cls(**model_cls.get_kwargs(cfg, conv=conv)) 68 | else: 69 | return model_cls 70 | 71 | 72 | def init_gans(target): 73 | for m in target.modules(): 74 | if isinstance(m, nn.modules.conv._ConvNd): 75 | m.weight.data.normal_(0.0, 0.02) 76 | if hasattr(m, 'bias') and m.bias is not None: 77 | m.bias.data.zero_() 78 | if isinstance(m, nn.Linear): 79 | m.weight.data.normal_(0.0, 0.02) 80 | if hasattr(m, 'bias') and m.bias is not None: 81 | m.bias.data.zero_() 82 | 83 | 84 | def append_module(m, name, n_feats): 85 | if name is None: 86 | return 87 | 88 | if name == 'batch': 89 | m.append(nn.BatchNorm2d(n_feats)) 90 | elif name == 'layer': 91 | m.append(nn.GroupNorm(1, n_feats)) 92 | elif name == 'instance': 93 | m.append(nn.InstanceNorm2d(n_feats)) 94 | 95 | if name == 'relu': 96 | m.append(nn.ReLU(True)) 97 | elif name == 'lrelu': 98 | m.append(nn.LeakyReLU(negative_slope=0.2, inplace=True)) 99 | elif name == 'prelu': 100 | m.append(nn.PReLU()) 101 | 102 | 103 | class MeanShift(nn.Conv2d): 104 | ''' 105 | Re-normalize input w.r.t given mean and std. 106 | This module assume that input lies in between -1 ~ 1 107 | ''' 108 | def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 109 | ''' 110 | Default values are ImageNet mean and std. 111 | ''' 112 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 113 | mean = torch.Tensor(mean) 114 | std = torch.Tensor(std) 115 | self.weight.data.copy_(torch.diag(0.5 / std).view(3, 3, 1, 1)) 116 | self.bias.data.copy_((0.5 - mean) / std) 117 | for p in self.parameters(): 118 | p.requires_grad = False 119 | 120 | 121 | class BasicBlock(nn.Sequential): 122 | ''' 123 | Make a basic block which consists of Conv-(Norm)-(Act). 124 | 125 | Args: 126 | in_channels (int): Conv in_channels. 127 | out_channels (int): Conv out_channels. 128 | kernel_size (int): Conv kernel_size. 129 | stride (int, default=1): Conv stride. 130 | norm ( or 'batch' or 'layer'): Norm function. 131 | act (<'relu'> or 'lrelu' or 'prelu'): Activation function. 132 | conv (funcion, optional): A function for making a conv layer. 
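    Example:
        A minimal illustrative sketch (the shapes are assumptions):

        >>> import torch
        >>> block = BasicBlock(3, 64, 3, norm='batch', act='relu')
        >>> x = torch.randn(1, 3, 32, 32)
        >>> block(x).shape
        torch.Size([1, 64, 32, 32])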
133 | ''' 134 | 135 | def __init__( 136 | self, 137 | in_channels: int, 138 | out_channels: int, 139 | kernel_size: int, 140 | stride: int=1, 141 | padding: typing.Optional[int]=None, 142 | norm: typing.Optional[str]=None, 143 | act: typing.Optional[str]='relu', 144 | bias: bool=None, 145 | padding_mode: str='zeros', 146 | conv=default_conv): 147 | 148 | if bias is None: 149 | bias = norm is None 150 | 151 | m = [conv( 152 | in_channels, 153 | out_channels, 154 | kernel_size, 155 | bias=bias, 156 | stride=stride, 157 | padding=padding, 158 | padding_mode=padding_mode, 159 | )] 160 | append_module(m, norm, out_channels) 161 | append_module(m, act, out_channels) 162 | super().__init__(*m) 163 | 164 | 165 | class BasicTBlock(BasicBlock): 166 | 167 | def __init__(self, *args, **kwargs): 168 | kwargs['conv'] = nn.ConvTranspose2d 169 | super().__init__(*args, **kwargs) 170 | 171 | 172 | class ResBlock(nn.Sequential): 173 | ''' 174 | Make a residual block which consists of Conv-(Norm)-Act-Conv-(Norm). 175 | 176 | Args: 177 | n_feats (int): Conv in/out_channels. 178 | kernel_size (int): Conv kernel_size. 179 | norm ( or 'batch' or 'layer'): Norm function. 180 | act (<'relu'> or 'lrelu' or 'prelu'): Activation function. 181 | res_scale (float, optional): Residual scaling. 182 | conv (funcion, optional): A function for making a conv layer. 183 | 184 | Note: 185 | Residual scaling: 186 | From Szegedy et al., 187 | "Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning" 188 | See https://arxiv.org/pdf/1602.07261.pdf for more detail. 189 | 190 | To modify stride, change the conv function. 191 | ''' 192 | 193 | def __init__( 194 | self, 195 | n_feats: int, 196 | kernel_size: int, 197 | norm: typing.Optional[str]=None, 198 | act: str='relu', 199 | res_scale: float=1, 200 | res_prob: float=1, 201 | padding_mode: str='zeros', 202 | conv=default_conv) -> None: 203 | 204 | bias = norm is None 205 | m = [] 206 | for i in range(2): 207 | m.append(conv( 208 | n_feats, 209 | n_feats, 210 | kernel_size, 211 | bias=bias, 212 | padding_mode=padding_mode, 213 | )) 214 | append_module(m, norm, n_feats) 215 | if i == 0: 216 | append_module(m, act, n_feats) 217 | 218 | super().__init__(*m) 219 | self.res_scale = res_scale 220 | self.res_prob = res_prob 221 | return 222 | 223 | def forward(self, x): 224 | if self.training and random.random() > self.res_prob: 225 | return x 226 | 227 | x = x + self.res_scale * super(ResBlock, self).forward(x) 228 | return x 229 | 230 | 231 | class Upsampler(nn.Sequential): 232 | ''' 233 | Make an upsampling block using sub-pixel convolution 234 | 235 | Args: 236 | 237 | Note: 238 | From Shi et al., 239 | "Real-Time Single Image and Video Super-Resolution 240 | Using an Efficient Sub-pixel Convolutional Neural Network" 241 | See https://arxiv.org/pdf/1609.05158.pdf for more detail 242 | ''' 243 | 244 | def __init__( 245 | self, 246 | scale: int, 247 | n_feats: int, 248 | norm: typing.Optional[str]=None, 249 | act: typing.Optional[str]=None, 250 | bias: bool=True, 251 | padding_mode: str='zeros', 252 | conv=default_conv): 253 | 254 | bias = norm is None 255 | m = [] 256 | log_scale = math.log(scale, 2) 257 | # check if the scale is power of 2 258 | if int(log_scale) == log_scale: 259 | for _ in range(int(log_scale)): 260 | m.append(conv( 261 | n_feats, 262 | 4 * n_feats, 263 | 3, 264 | bias=bias, 265 | padding_mode=padding_mode, 266 | )) 267 | m.append(nn.PixelShuffle(2)) 268 | append_module(m, norm, n_feats) 269 | append_module(m, act, n_feats) 270 | elif scale 
== 3: 271 | m.append(conv( 272 | n_feats, 273 | 9 * n_feats, 274 | 3, 275 | bias=bias, 276 | padding_mode=padding_mode, 277 | )) 278 | m.append(nn.PixelShuffle(3)) 279 | append_module(m, norm, n_feats) 280 | append_module(m, act, n_feats) 281 | else: 282 | raise NotImplementedError 283 | 284 | super(Upsampler, self).__init__(*m) 285 | 286 | 287 | class UpsamplerI(nn.Module): 288 | ''' 289 | Interpolation based upsampler 290 | ''' 291 | 292 | def __init__( 293 | self, scale, n_feats, algorithm='nearest', activation=True, conv=default_conv): 294 | 295 | super(UpsamplerI, self).__init__() 296 | log_scale = int(math.log(scale, 2)) 297 | self.algorithm = algorithm 298 | self.activation = activation 299 | self.convs = nn.ModuleList([ 300 | conv(n_feats, n_feats, 3) for _ in range(log_scale) 301 | ]) 302 | 303 | def forward(self, x): 304 | for conv in self.convs: 305 | x = functional.interpolate(x, scale_factor=2, mode=self.algorithm) 306 | x = conv(x) 307 | if self.activation: 308 | x = functional.leaky_relu(x, negative_slope=0.2, inplace=True) 309 | 310 | return x 311 | 312 | 313 | class PixelSort(nn.Module): 314 | ''' 315 | An inverse operation of nn.PixelShuffle. Only for scale 2. 316 | ''' 317 | 318 | def __init__(self): 319 | super(PixelSort, self).__init__() 320 | 321 | def forward(self, x): 322 | ''' 323 | Tiling input into smaller resolutions. 324 | 325 | Args: 326 | x (Tensor): 327 | 328 | Return: 329 | Tensor: 330 | 331 | Example:: 332 | 333 | >>> x = torch.Tensor(16, 64, 256, 256) 334 | >>> ps = PixelSort() 335 | >>> y = ps(x) 336 | >>> y.size() 337 | torch.Size([16, 256, 128, 128]) 338 | 339 | ''' 340 | 341 | ''' 342 | _, c, h, w = x.size() 343 | #h //= self.scale 344 | #w //= self.scale 345 | #x = x.view(-1, c, h, self.scale, w, self.scale) 346 | x = x.permute(0, 1, 3, 5, 2, 4).contiguous() 347 | #x = x.view(-1, self.scale**2 * c, h, w) 348 | ''' 349 | # we have a jit compatibility issue with the code above... 
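        # Instead, the four scale-2 polyphase components (top-left, top-right,
        # bottom-left, bottom-right pixels) are gathered with strided slicing
        # and concatenated along the channel dimension: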
350 | from_zero = slice(0, None, 2) 351 | from_one = slice(1, None, 2) 352 | tl = x[..., from_zero, from_zero] 353 | tr = x[..., from_zero, from_one] 354 | bl = x[..., from_one, from_zero] 355 | br = x[..., from_one, from_one] 356 | x = torch.cat((tl, tr, bl, br), dim=1) 357 | return x 358 | 359 | class Downsampler(nn.Sequential): 360 | 361 | def __init__( 362 | self, scale, n_feats, 363 | norm=None, act=None, conv=default_conv): 364 | 365 | bias = norm is None 366 | m = [] 367 | log_scale = math.log(scale, 2) 368 | if int(log_scale) == log_scale: 369 | for _ in range(int(log_scale)): 370 | m.append(PixelSort()) 371 | m.append(conv(4 * n_feats, n_feats, 3, bias=bias)) 372 | append_module(m, norm, n_feats) 373 | append_module(m, act, n_feats) 374 | else: 375 | raise NotImplementedError 376 | 377 | super(Downsampler, self).__init__(*m) 378 | 379 | def extract_vgg(name): 380 | gen = models.vgg19(pretrained=True).features 381 | vgg = None 382 | configs = ( 383 | '11', '12', 384 | '21', '22', 385 | '31', '32', '33', '34', 386 | '41', '42', '43', '44', 387 | '51', '52', '53', '54', 388 | ) 389 | sub_mean = MeanShift() 390 | def sub_vgg(config): 391 | sub_modules = [sub_mean] 392 | pool_idx = 0 393 | conv_idx = 0 394 | pools = int(config[0]) 395 | convs = int(config[1]) 396 | for m in gen: 397 | if convs == 0: 398 | return sub_mean 399 | sub_modules.append(m) 400 | if isinstance(m, nn.Conv2d): 401 | conv_idx += 1 402 | elif isinstance(m, nn.MaxPool2d): 403 | conv_idx = 0 404 | pool_idx += 1 405 | 406 | if conv_idx == convs and pool_idx == pools - 1: 407 | return nn.Sequential(*sub_modules) 408 | 409 | for config in configs: 410 | if config in name: 411 | vgg = sub_vgg(config) 412 | break 413 | 414 | if vgg is None: 415 | vgg = sub_vgg('54') 416 | 417 | return vgg 418 | 419 | def extract_resnet(name): 420 | configs = ('18', '34', '50', '101', '152') 421 | resnet = models.resnet50 422 | for config in configs: 423 | if config in name: 424 | resnet = getattr(models, 'resnet{}'.format(config)) 425 | break 426 | 427 | resnet = resnet(pretrained=True) 428 | resnet.avgpool = nn.AdaptiveAvgPool2d(1) 429 | resnet.fc = nn.Identity() 430 | resnet.eval() 431 | resnet_seq = nn.Sequential(MeanShift(), resnet) 432 | return resnet_seq 433 | 434 | 435 | if __name__ == '__main__': 436 | ''' 437 | torch.set_printoptions(precision=3, linewidth=120) 438 | with torch.no_grad(): 439 | x = torch.arange(64).view(1, 1, 8, 8).float() 440 | ps = Downsampler(2, 1) 441 | print(ps(x)) 442 | from torch import jit 443 | jit_traced = jit.trace(ps, x) 444 | print(jit_traced.graph) 445 | print(jit_traced) 446 | jit_traced.save('jit_test.pt') 447 | jit_load = jit.load('jit_test.pt') 448 | print(jit_load(x)) 449 | ''' 450 | x = 2 * torch.rand(1, 3, 4, 4) - 1 451 | print(x) 452 | ms = MeanShift() 453 | print(ms(x)) 454 | -------------------------------------------------------------------------------- /src/model/ddbpn.py: -------------------------------------------------------------------------------- 1 | # Deep Back-Projection Networks For Super-Resolution 2 | # https://arxiv.org/abs/1803.02735 3 | 4 | from model import common 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_model(args, parent=False): 11 | return DDBPN(args) 12 | 13 | def projection_conv(in_channels, out_channels, scale, up=True): 14 | kernel_size, stride, padding = { 15 | 2: (6, 2, 2), 16 | 4: (8, 4, 2), 17 | 8: (12, 8, 2) 18 | }[scale] 19 | if up: 20 | conv_f = nn.ConvTranspose2d 21 | else: 22 | conv_f = nn.Conv2d 23 | 24 | return conv_f( 25 | 
in_channels, out_channels, kernel_size, 26 | stride=stride, padding=padding 27 | ) 28 | 29 | class DenseProjection(nn.Module): 30 | def __init__(self, in_channels, nr, scale, up=True, bottleneck=True): 31 | super(DenseProjection, self).__init__() 32 | if bottleneck: 33 | self.bottleneck = nn.Sequential(*[ 34 | nn.Conv2d(in_channels, nr, 1), 35 | nn.PReLU(nr) 36 | ]) 37 | inter_channels = nr 38 | else: 39 | self.bottleneck = None 40 | inter_channels = in_channels 41 | 42 | self.conv_1 = nn.Sequential(*[ 43 | projection_conv(inter_channels, nr, scale, up), 44 | nn.PReLU(nr) 45 | ]) 46 | self.conv_2 = nn.Sequential(*[ 47 | projection_conv(nr, inter_channels, scale, not up), 48 | nn.PReLU(inter_channels) 49 | ]) 50 | self.conv_3 = nn.Sequential(*[ 51 | projection_conv(inter_channels, nr, scale, up), 52 | nn.PReLU(nr) 53 | ]) 54 | 55 | def forward(self, x): 56 | if self.bottleneck is not None: 57 | x = self.bottleneck(x) 58 | 59 | a_0 = self.conv_1(x) 60 | b_0 = self.conv_2(a_0) 61 | e = b_0.sub(x) 62 | a_1 = self.conv_3(e) 63 | 64 | out = a_0.add(a_1) 65 | 66 | return out 67 | 68 | class DDBPN(nn.Module): 69 | def __init__(self, args): 70 | super(DDBPN, self).__init__() 71 | scale = args.scale[0] 72 | 73 | n0 = 128 74 | nr = 32 75 | self.depth = 6 76 | 77 | rgb_mean = (0.4488, 0.4371, 0.4040) 78 | rgb_std = (1.0, 1.0, 1.0) 79 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 80 | initial = [ 81 | nn.Conv2d(args.n_colors, n0, 3, padding=1), 82 | nn.PReLU(n0), 83 | nn.Conv2d(n0, nr, 1), 84 | nn.PReLU(nr) 85 | ] 86 | self.initial = nn.Sequential(*initial) 87 | 88 | self.upmodules = nn.ModuleList() 89 | self.downmodules = nn.ModuleList() 90 | channels = nr 91 | for i in range(self.depth): 92 | self.upmodules.append( 93 | DenseProjection(channels, nr, scale, True, i > 1) 94 | ) 95 | if i != 0: 96 | channels += nr 97 | 98 | channels = nr 99 | for i in range(self.depth - 1): 100 | self.downmodules.append( 101 | DenseProjection(channels, nr, scale, False, i != 0) 102 | ) 103 | channels += nr 104 | 105 | reconstruction = [ 106 | nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1) 107 | ] 108 | self.reconstruction = nn.Sequential(*reconstruction) 109 | 110 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 111 | 112 | def forward(self, x): 113 | x = self.sub_mean(x) 114 | x = self.initial(x) 115 | 116 | h_list = [] 117 | l_list = [] 118 | for i in range(self.depth - 1): 119 | if i == 0: 120 | l = x 121 | else: 122 | l = torch.cat(l_list, dim=1) 123 | h_list.append(self.upmodules[i](l)) 124 | l_list.append(self.downmodules[i](torch.cat(h_list, dim=1))) 125 | 126 | h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1))) 127 | out = self.reconstruction(torch.cat(h_list, dim=1)) 128 | out = self.add_mean(out) 129 | 130 | return out 131 | 132 | -------------------------------------------------------------------------------- /src/model/didn.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/SonghyunYu//color_model.py 2 | import math 3 | 4 | from model import common 5 | 6 | import torch 7 | from torch import nn 8 | 9 | def make_model(args, parent=False): 10 | return _NetG() 11 | 12 | 13 | class _Residual_Block(nn.Module): 14 | def __init__(self): 15 | super(_Residual_Block, self).__init__() 16 | 17 | #res1 18 | self.conv1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 19 | self.relu2 = nn.PReLU() 20 | self.conv3 = nn.Conv2d(in_channels=256, 
out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 21 | self.relu4 = nn.PReLU() 22 | #res1 23 | #concat1 24 | 25 | self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1, bias=False) 26 | self.relu6 = nn.PReLU() 27 | 28 | #res2 29 | self.conv7 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False) 30 | self.relu8 = nn.PReLU() 31 | #res2 32 | #concat2 33 | 34 | self.conv9 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=2, padding=1, bias=False) 35 | self.relu10 = nn.PReLU() 36 | 37 | #res3 38 | self.conv11 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=False) 39 | self.relu12 = nn.PReLU() 40 | #res3 41 | 42 | self.conv13 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1, stride=1, padding=0, bias=False) 43 | self.up14 = nn.PixelShuffle(2) 44 | 45 | #concat2 46 | self.conv15 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0, bias=False) 47 | #res4 48 | self.conv16 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False) 49 | self.relu17 = nn.PReLU() 50 | #res4 51 | 52 | self.conv18 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1, stride=1, padding=0, bias=False) 53 | self.up19 = nn.PixelShuffle(2) 54 | 55 | #concat1 56 | self.conv20 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0, bias=False) 57 | #res5 58 | self.conv21 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 59 | self.relu22 = nn.PReLU() 60 | self.conv23 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 61 | self.relu24 = nn.PReLU() 62 | #res5 63 | 64 | self.conv25 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 65 | 66 | 67 | def forward(self, x): 68 | res1 = x 69 | out = self.relu4(self.conv3(self.relu2(self.conv1(x)))) 70 | out = torch.add(res1, out) 71 | cat1 = out 72 | 73 | out = self.relu6(self.conv5(out)) 74 | res2 = out 75 | out = self.relu8(self.conv7(out)) 76 | out = torch.add(res2, out) 77 | cat2 = out 78 | 79 | out = self.relu10(self.conv9(out)) 80 | res3 = out 81 | 82 | out = self.relu12(self.conv11(out)) 83 | out = torch.add(res3, out) 84 | 85 | out = self.up14(self.conv13(out)) 86 | 87 | out = torch.cat([out, cat2], 1) 88 | out = self.conv15(out) 89 | res4 = out 90 | out = self.relu17(self.conv16(out)) 91 | out = torch.add(res4, out) 92 | 93 | out = self.up19(self.conv18(out)) 94 | 95 | out = torch.cat([out, cat1], 1) 96 | out = self.conv20(out) 97 | res5 = out 98 | out = self.relu24(self.conv23(self.relu22(self.conv21(out)))) 99 | out = torch.add(res5, out) 100 | 101 | out = self.conv25(out) 102 | out = torch.add(out, res1) 103 | 104 | return out 105 | 106 | 107 | class Recon_Block(nn.Module): 108 | def __init__(self): 109 | super(Recon_Block, self).__init__() 110 | 111 | self.conv1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 112 | self.relu2 = nn.PReLU() 113 | self.conv3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 114 | self.relu4 = nn.PReLU() 115 | 116 | self.conv5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 117 | self.relu6= nn.PReLU() 118 | self.conv7 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, 
bias=False) 119 | self.relu8 = nn.PReLU() 120 | 121 | self.conv9 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 122 | self.relu10 = nn.PReLU() 123 | self.conv11 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 124 | self.relu12 = nn.PReLU() 125 | 126 | self.conv13 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 127 | self.relu14 = nn.PReLU() 128 | self.conv15 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 129 | self.relu16 = nn.PReLU() 130 | 131 | self.conv17 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 132 | 133 | def forward(self, x): 134 | res1 = x 135 | output = self.relu4(self.conv3(self.relu2(self.conv1(x)))) 136 | output = torch.add(output, res1) 137 | 138 | res2 = output 139 | output = self.relu8(self.conv7(self.relu6(self.conv5(output)))) 140 | output = torch.add(output, res2) 141 | 142 | res3 = output 143 | output = self.relu12(self.conv11(self.relu10(self.conv9(output)))) 144 | output = torch.add(output, res3) 145 | 146 | res4 = output 147 | output = self.relu16(self.conv15(self.relu14(self.conv13(output)))) 148 | output = torch.add(output, res4) 149 | 150 | output = self.conv17(output) 151 | output = torch.add(output, res1) 152 | 153 | return output 154 | 155 | 156 | class _NetG(nn.Module): 157 | def __init__(self, conv=common.default_conv): 158 | super(_NetG, self).__init__() 159 | 160 | self.conv_input = nn.Conv2d(in_channels=3, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 161 | self.relu1 = nn.PReLU() 162 | self.conv_down = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1, bias=False) 163 | self.relu2 = nn.PReLU() 164 | 165 | self.recursive_A = _Residual_Block() 166 | self.recursive_B = _Residual_Block() 167 | self.recursive_C = _Residual_Block() 168 | self.recursive_D = _Residual_Block() 169 | self.recursive_E = _Residual_Block() 170 | self.recursive_F = _Residual_Block() 171 | 172 | self.recon = Recon_Block() 173 | #concat 174 | 175 | self.conv_mid = nn.Conv2d(in_channels=1536, out_channels=256, kernel_size=1, stride=1, padding=0, bias=False) 176 | self.relu3 = nn.PReLU() 177 | self.conv_mid2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 178 | self.relu4 = nn.PReLU() 179 | 180 | self.subpixel = nn.PixelShuffle(2) 181 | self.conv_output = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False) 182 | 183 | @staticmethod 184 | def get_kwargs(cfg, conv=common.default_conv): 185 | kwargs = {} 186 | return kwargs 187 | 188 | def forward(self, x): 189 | residual = x 190 | out = self.relu1(self.conv_input(x)) 191 | out = self.relu2(self.conv_down(out)) 192 | 193 | out1 = self.recursive_A(out) 194 | out2 = self.recursive_B(out1) 195 | out3 = self.recursive_C(out2) 196 | out4 = self.recursive_D(out3) 197 | out5 = self.recursive_E(out4) 198 | out6 = self.recursive_F(out5) 199 | 200 | recon1 = self.recon(out1) 201 | recon2 = self.recon(out2) 202 | recon3 = self.recon(out3) 203 | recon4 = self.recon(out4) 204 | recon5 = self.recon(out5) 205 | recon6 = self.recon(out6) 206 | 207 | out = torch.cat([recon1, recon2, recon3, recon4, recon5, recon6], 1) 208 | 209 | out = self.relu3(self.conv_mid(out)) 210 | residual2 = out 211 | out = self.relu4(self.conv_mid2(out)) 212 | out = torch.add(out, residual2) 213 | 214 | out= 
self.subpixel(out) 215 | out = self.conv_output(out) 216 | out = torch.add(out, residual) 217 | 218 | return out 219 | 220 | 221 | # For automatic model loading 222 | REPRESENTATIVE = _NetG -------------------------------------------------------------------------------- /src/model/edsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | url = { 6 | 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt', 7 | 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt', 8 | 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt', 9 | 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt', 10 | 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt', 11 | 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt' 12 | } 13 | 14 | def make_model(args, parent=False): 15 | return EDSR(args) 16 | 17 | class EDSR(nn.Module): 18 | def __init__(self, args, conv=common.default_conv): 19 | super(EDSR, self).__init__() 20 | 21 | self.args = args 22 | n_resblocks = 16 #args.n_resblocks 23 | n_feats = 64 #args.n_feats 24 | kernel_size = 3 25 | scale = args.scale #args.scale[0] 26 | rgb_range = 1 #1 # 27 | res_scale = 1 28 | n_colors = 3 29 | act = nn.ReLU(True) 30 | 31 | url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale) 32 | if url_name in url: 33 | self.url = url[url_name] 34 | else: 35 | self.url = None 36 | self.sub_mean = common.MeanShift(rgb_range) 37 | self.add_mean = common.MeanShift(rgb_range, sign=1) 38 | 39 | # define head module 40 | m_head = [conv(n_colors, n_feats, kernel_size)] 41 | 42 | # define body module 43 | m_body = [ 44 | common.ResBlock( 45 | conv, n_feats, kernel_size, act=act, res_scale=res_scale 46 | ) for _ in range(n_resblocks) 47 | ] 48 | m_body.append(conv(n_feats, n_feats, kernel_size)) 49 | 50 | # define tail module 51 | m_tail = [ 52 | common.Upsampler(conv, scale, n_feats, act=False), 53 | #common.Upsampler(conv, 1, n_feats, act=False), 54 | conv(n_feats, n_colors, kernel_size) 55 | ] 56 | 57 | self.head = nn.Sequential(*m_head) 58 | self.body = nn.Sequential(*m_body) 59 | self.tail = nn.Sequential(*m_tail) 60 | 61 | def forward(self, x): 62 | x = self.sub_mean(x) 63 | x = self.head(x) 64 | 65 | res = self.body(x) 66 | res += x 67 | 68 | x = self.tail(res) 69 | x = self.add_mean(x) 70 | 71 | return x 72 | 73 | def load_state_dict(self, state_dict, strict=True): 74 | own_state = self.state_dict() 75 | for name, param in state_dict.items(): 76 | if name in own_state: 77 | if isinstance(param, nn.Parameter): 78 | param = param.data 79 | try: 80 | own_state[name].copy_(param) 81 | except Exception: 82 | if name.find('tail') == -1: 83 | raise RuntimeError('While copying the parameter named {}, ' 84 | 'whose dimensions in the model are {} and ' 85 | 'whose dimensions in the checkpoint are {}.' 
86 | .format(name, own_state[name].size(), param.size())) 87 | elif strict: 88 | if name.find('tail') == -1: 89 | raise KeyError('unexpected key "{}" in state_dict' 90 | .format(name)) 91 | 92 | -------------------------------------------------------------------------------- /src/model/mdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | url = { 6 | 'r16f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr_baseline-a00cab12.pt', 7 | 'r80f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr-4a78bedf.pt' 8 | } 9 | 10 | def make_model(args, parent=False): 11 | return MDSR(args) 12 | 13 | class MDSR(nn.Module): 14 | def __init__(self, args, conv=common.default_conv): 15 | super(MDSR, self).__init__() 16 | n_resblocks = args.n_resblocks 17 | n_feats = args.n_feats 18 | kernel_size = 3 19 | act = nn.ReLU(True) 20 | self.scale_idx = 0 21 | self.url = url['r{}f{}'.format(n_resblocks, n_feats)] 22 | self.sub_mean = common.MeanShift(args.rgb_range) 23 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 24 | 25 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 26 | 27 | self.pre_process = nn.ModuleList([ 28 | nn.Sequential( 29 | common.ResBlock(conv, n_feats, 5, act=act), 30 | common.ResBlock(conv, n_feats, 5, act=act) 31 | ) for _ in args.scale 32 | ]) 33 | 34 | m_body = [ 35 | common.ResBlock( 36 | conv, n_feats, kernel_size, act=act 37 | ) for _ in range(n_resblocks) 38 | ] 39 | m_body.append(conv(n_feats, n_feats, kernel_size)) 40 | 41 | self.upsample = nn.ModuleList([ 42 | common.Upsampler(conv, s, n_feats, act=False) for s in args.scale 43 | ]) 44 | 45 | m_tail = [conv(n_feats, args.n_colors, kernel_size)] 46 | 47 | self.head = nn.Sequential(*m_head) 48 | self.body = nn.Sequential(*m_body) 49 | self.tail = nn.Sequential(*m_tail) 50 | 51 | def forward(self, x): 52 | x = self.sub_mean(x) 53 | x = self.head(x) 54 | x = self.pre_process[self.scale_idx](x) 55 | 56 | res = self.body(x) 57 | res += x 58 | 59 | x = self.upsample[self.scale_idx](res) 60 | x = self.tail(x) 61 | x = self.add_mean(x) 62 | 63 | return x 64 | 65 | def set_scale(self, scale_idx): 66 | self.scale_idx = scale_idx 67 | 68 | -------------------------------------------------------------------------------- /src/model/rcan.py: -------------------------------------------------------------------------------- 1 | ## ECCV-2018-Image Super-Resolution Using Very Deep Residual Channel Attention Networks 2 | ## https://arxiv.org/abs/1807.02758 3 | from model import common 4 | 5 | import torch.nn as nn 6 | 7 | def make_model(args, parent=False): 8 | return RCAN(args) 9 | 10 | ## Channel Attention (CA) Layer 11 | class CALayer(nn.Module): 12 | def __init__(self, channel, reduction=16): 13 | super(CALayer, self).__init__() 14 | # global average pooling: feature --> point 15 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 16 | # feature channel downscale and upscale --> channel weight 17 | self.conv_du = nn.Sequential( 18 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 21 | nn.Sigmoid() 22 | ) 23 | 24 | def forward(self, x): 25 | y = self.avg_pool(x) 26 | y = self.conv_du(y) 27 | return x * y 28 | 29 | ## Residual Channel Attention Block (RCAB) 30 | class RCAB(nn.Module): 31 | def __init__( 32 | self, conv, n_feat, kernel_size, reduction, 33 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 34 | 35 | super(RCAB, 
self).__init__() 36 | modules_body = [] 37 | for i in range(2): 38 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 39 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 40 | if i == 0: modules_body.append(act) 41 | modules_body.append(CALayer(n_feat, reduction)) 42 | self.body = nn.Sequential(*modules_body) 43 | self.res_scale = res_scale 44 | 45 | def forward(self, x): 46 | res = self.body(x) 47 | #res = self.body(x).mul(self.res_scale) 48 | res += x 49 | return res 50 | 51 | ## Residual Group (RG) 52 | class ResidualGroup(nn.Module): 53 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 54 | super(ResidualGroup, self).__init__() 55 | modules_body = [] 56 | modules_body = [ 57 | RCAB( 58 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 59 | for _ in range(n_resblocks)] 60 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 61 | self.body = nn.Sequential(*modules_body) 62 | 63 | def forward(self, x): 64 | res = self.body(x) 65 | res += x 66 | return res 67 | 68 | ## Residual Channel Attention Network (RCAN) 69 | class RCAN(nn.Module): 70 | def __init__(self, args, conv=common.default_conv): 71 | super(RCAN, self).__init__() 72 | 73 | n_resgroups = args.n_resgroups 74 | n_resblocks = args.n_resblocks 75 | n_feats = args.n_feats 76 | kernel_size = 3 77 | reduction = args.reduction 78 | scale = args.scale[0] 79 | act = nn.ReLU(True) 80 | 81 | # RGB mean for DIV2K 82 | self.sub_mean = common.MeanShift(args.rgb_range) 83 | 84 | # define head module 85 | modules_head = [conv(args.n_colors, n_feats, kernel_size)] 86 | 87 | # define body module 88 | modules_body = [ 89 | ResidualGroup( 90 | conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \ 91 | for _ in range(n_resgroups)] 92 | 93 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 94 | 95 | # define tail module 96 | modules_tail = [ 97 | common.Upsampler(conv, scale, n_feats, act=False), 98 | conv(n_feats, args.n_colors, kernel_size)] 99 | 100 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 101 | 102 | self.head = nn.Sequential(*modules_head) 103 | self.body = nn.Sequential(*modules_body) 104 | self.tail = nn.Sequential(*modules_tail) 105 | 106 | def forward(self, x): 107 | x = self.sub_mean(x) 108 | x = self.head(x) 109 | 110 | res = self.body(x) 111 | res += x 112 | 113 | x = self.tail(res) 114 | x = self.add_mean(x) 115 | 116 | return x 117 | 118 | def load_state_dict(self, state_dict, strict=False): 119 | own_state = self.state_dict() 120 | for name, param in state_dict.items(): 121 | if name in own_state: 122 | if isinstance(param, nn.Parameter): 123 | param = param.data 124 | try: 125 | own_state[name].copy_(param) 126 | except Exception: 127 | if name.find('tail') >= 0: 128 | print('Replace pre-trained upsampler to new one...') 129 | else: 130 | raise RuntimeError('While copying the parameter named {}, ' 131 | 'whose dimensions in the model are {} and ' 132 | 'whose dimensions in the checkpoint are {}.' 
133 | .format(name, own_state[name].size(), param.size())) 134 | elif strict: 135 | if name.find('tail') == -1: 136 | raise KeyError('unexpected key "{}" in state_dict' 137 | .format(name)) 138 | 139 | if strict: 140 | missing = set(own_state.keys()) - set(state_dict.keys()) 141 | if len(missing) > 0: 142 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 143 | -------------------------------------------------------------------------------- /src/model/rdn.py: -------------------------------------------------------------------------------- 1 | # Residual Dense Network for Image Super-Resolution 2 | # https://arxiv.org/abs/1802.08797 3 | 4 | from model import common 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_model(args, parent=False): 11 | return RDN(args) 12 | 13 | class RDB_Conv(nn.Module): 14 | def __init__(self, inChannels, growRate, kSize=3): 15 | super(RDB_Conv, self).__init__() 16 | Cin = inChannels 17 | G = growRate 18 | self.conv = nn.Sequential(*[ 19 | nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1), 20 | nn.ReLU() 21 | ]) 22 | 23 | def forward(self, x): 24 | out = self.conv(x) 25 | return torch.cat((x, out), 1) 26 | 27 | class RDB(nn.Module): 28 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3): 29 | super(RDB, self).__init__() 30 | G0 = growRate0 31 | G = growRate 32 | C = nConvLayers 33 | 34 | convs = [] 35 | for c in range(C): 36 | convs.append(RDB_Conv(G0 + c*G, G)) 37 | self.convs = nn.Sequential(*convs) 38 | 39 | # Local Feature Fusion 40 | self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1) 41 | 42 | def forward(self, x): 43 | return self.LFF(self.convs(x)) + x 44 | 45 | class RDN(nn.Module): 46 | def __init__(self, args): 47 | super(RDN, self).__init__() 48 | r = args.scale[0] 49 | G0 = args.G0 50 | kSize = args.RDNkSize 51 | 52 | # number of RDB blocks, conv layers, out channels 53 | self.D, C, G = { 54 | 'A': (20, 6, 32), 55 | 'B': (16, 8, 64), 56 | }[args.RDNconfig] 57 | 58 | # Shallow feature extraction net 59 | self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1) 60 | self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 61 | 62 | # Redidual dense blocks and dense feature fusion 63 | self.RDBs = nn.ModuleList() 64 | for i in range(self.D): 65 | self.RDBs.append( 66 | RDB(growRate0 = G0, growRate = G, nConvLayers = C) 67 | ) 68 | 69 | # Global Feature Fusion 70 | self.GFF = nn.Sequential(*[ 71 | nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1), 72 | nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 73 | ]) 74 | 75 | # Up-sampling net 76 | if r == 2 or r == 3: 77 | self.UPNet = nn.Sequential(*[ 78 | nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1), 79 | nn.PixelShuffle(r), 80 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 81 | ]) 82 | elif r == 4: 83 | self.UPNet = nn.Sequential(*[ 84 | nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1), 85 | nn.PixelShuffle(2), 86 | nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1), 87 | nn.PixelShuffle(2), 88 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 89 | ]) 90 | else: 91 | raise ValueError("scale must be 2 or 3 or 4.") 92 | 93 | def forward(self, x): 94 | f__1 = self.SFENet1(x) 95 | x = self.SFENet2(f__1) 96 | 97 | RDBs_out = [] 98 | for i in range(self.D): 99 | x = self.RDBs[i](x) 100 | RDBs_out.append(x) 101 | 102 | x = self.GFF(torch.cat(RDBs_out,1)) 103 | x += f__1 104 | 105 | return self.UPNet(x) 106 | 
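# ------------------------------------------------------------------
# Minimal usage sketch for the RDN model above (illustrative, not part
# of the repository). It shows how the args namespace read by
# RDN.__init__ can be built by hand for a quick smoke test. It assumes
# the interpreter is started from src/ so that the `model` package
# resolves and that model/__init__.py imports cleanly; the concrete
# hyper-parameter values below are example assumptions, not repository
# defaults.
# ------------------------------------------------------------------
import argparse
import torch
from model.rdn import RDN

args = argparse.Namespace(
    scale=[2],        # RDN reads args.scale[0] as the integer upscaling factor
    G0=64,            # base number of feature maps
    RDNkSize=3,       # kernel size of the shallow feature extraction convs
    RDNconfig='B',    # selects (D, C, G) = (16, 8, 64)
    n_colors=3,       # RGB input and output
)
net = RDN(args)
with torch.no_grad():
    y = net(torch.randn(1, 3, 48, 48))
print(y.shape)        # torch.Size([1, 3, 96, 96]) for scale 2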
-------------------------------------------------------------------------------- /src/model/rrdb.py: -------------------------------------------------------------------------------- 1 | from model import common_rrdb as common 2 | #from config import get_config 3 | 4 | from torch import nn 5 | from torch.nn import functional 6 | 7 | def make_model(args, parent=False): 8 | return RRDB(args) 9 | 10 | class RDBlock(nn.Module): 11 | 12 | def __init__( 13 | self, n_feats, gf, drep=5, res_scale=0.2, conv=common.default_conv): 14 | 15 | super(RDBlock, self).__init__() 16 | self.n_feats = n_feats 17 | self.n_max = n_feats + (drep - 1) * gf 18 | self.gf = gf 19 | self.res_scale = res_scale 20 | m = [conv(n_feats + i * gf, gf, 3) for i in range(drep - 1)] 21 | m.append(conv(self.n_max, n_feats, 3)) 22 | self.convs = nn.ModuleList(m) 23 | 24 | def forward(self, x): 25 | ''' 26 | We will not use torch.cat since it consumes GPU memory lot. 27 | ''' 28 | b, _, h, w = x.size() 29 | c = self.n_feats 30 | buf = x.new_empty(b, self.n_max, h, w) 31 | buf[:, :c, :, :] = x 32 | for i, conv in enumerate(self.convs): 33 | x_inter = conv(buf[:, :c, :, :]) 34 | if i == len(self.convs) - 1: 35 | return x + self.res_scale * x_inter 36 | else: 37 | x_inter = functional.leaky_relu( 38 | x_inter, negative_slope=0.2, inplace=True 39 | ) 40 | buf[:, c:c + self.gf, :, :] = x_inter 41 | c += self.gf 42 | 43 | 44 | class RRDBlock(nn.Sequential): 45 | 46 | def __init__( 47 | self, n_feats, gf, rep=3, res_scale=0.2, conv=common.default_conv): 48 | 49 | self.res_scale = res_scale 50 | args = [n_feats, gf] 51 | kwargs = {'res_scale': res_scale, 'conv': conv} 52 | m = [RDBlock(*args, **kwargs) for _ in range(rep)] 53 | super().__init__(*m) 54 | 55 | def forward(self, x): 56 | x = x + self.res_scale * super().forward(x) 57 | return x 58 | 59 | 60 | class RRDB(nn.Module): 61 | ''' 62 | RRDB model 63 | 64 | Note: 65 | From 66 | "ESRGAN" 67 | See for more detail. 
68 | ''' 69 | 70 | def __init__( 71 | self, args, scale=4, depth=23, n_colors=3, n_feats=64, 72 | gf=32, res_scale=0.2, conv=common.default_conv): 73 | 74 | super(RRDB, self).__init__() 75 | scale = int(args.scale) 76 | self.conv = conv(n_colors, n_feats, 3) 77 | block = lambda: RRDBlock(n_feats, gf, res_scale=res_scale) 78 | m = [block() for _ in range(depth)] 79 | m.append(conv(n_feats, n_feats, 3)) 80 | self.rrdblocks = nn.Sequential(*m) 81 | self.recon = nn.Sequential( 82 | common.UpsamplerI(scale, n_feats, algorithm='nearest'), 83 | conv(n_feats, n_feats, 3), 84 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 85 | conv(n_feats, n_colors, 3) 86 | ) 87 | self.res_scale = res_scale 88 | 89 | @staticmethod 90 | def get_kwargs(cfg, conv=common.default_conv): 91 | parse_list = ['scale', 'n_colors', 'n_feats'] 92 | kwargs = get_config.parse_namespace(cfg, *parse_list) 93 | kwargs['gf'] = cfg.n_feats // 2 94 | kwargs['depth'] = 23 95 | kwargs['res_scale'] = 0.2 96 | kwargs['conv'] = conv 97 | return kwargs 98 | 99 | def forward(self, x): 100 | x = self.conv(x) 101 | x = x + self.rrdblocks(x) 102 | x = self.recon(x) 103 | return x 104 | 105 | def load_state_dict(self, state_dict, strict=True): 106 | own_state = self.state_dict() 107 | for k in own_state.keys(): 108 | if k not in state_dict and 'recon' not in k: 109 | raise RuntimeError(k + ' does not exist!') 110 | else: 111 | if k in state_dict: 112 | own_state[k] = state_dict[k] 113 | 114 | super().load_state_dict(own_state, strict=strict) 115 | 116 | # For loading the module 117 | REPRESENTATIVE = RRDB 118 | 119 | -------------------------------------------------------------------------------- /src/model/vdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | 6 | url = { 7 | 'r20f64': '' 8 | } 9 | 10 | def make_model(args, parent=False): 11 | return VDSR(args) 12 | 13 | class VDSR(nn.Module): 14 | def __init__(self, args, conv=common.default_conv): 15 | super(VDSR, self).__init__() 16 | 17 | n_resblocks = args.n_resblocks 18 | n_feats = args.n_feats 19 | kernel_size = 3 20 | self.url = url['r{}f{}'.format(n_resblocks, n_feats)] 21 | self.sub_mean = common.MeanShift(args.rgb_range) 22 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 23 | 24 | def basic_block(in_channels, out_channels, act): 25 | return common.BasicBlock( 26 | conv, in_channels, out_channels, kernel_size, 27 | bias=True, bn=False, act=act 28 | ) 29 | 30 | # define body module 31 | m_body = [] 32 | m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True))) 33 | for _ in range(n_resblocks - 2): 34 | m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True))) 35 | m_body.append(basic_block(n_feats, args.n_colors, None)) 36 | 37 | self.body = nn.Sequential(*m_body) 38 | 39 | def forward(self, x): 40 | x = self.sub_mean(x) 41 | res = self.body(x) 42 | res += x 43 | x = self.add_mean(res) 44 | 45 | return x 46 | 47 | -------------------------------------------------------------------------------- /src/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import functools 5 | from torch.optim import lr_scheduler 6 | import torch.nn.functional as F 7 | import torchvision.models as py_models 8 | import numpy 9 | import copy 10 | 11 | 12 | #################################################################### 13 | 
#------------------------- Discriminators -------------------------- 14 | #################################################################### 15 | class D_Module(nn.Module): 16 | def __init__(self, args, n_layer=5, norm='None', sn=False): 17 | super(D_Module, self).__init__() 18 | ch = 64 19 | 20 | self.args = args 21 | self.Diss = nn.ModuleList() 22 | self.Diss.append(self._make_net(ch, 3, n_layer, norm, sn)) 23 | 24 | def _make_net(self, ch, input_dim, n_layer, norm, sn): 25 | model = [MyConv2d(input_dim, ch, kernel_size=7, stride=1, padding=3, norm=norm, sn=sn, Leaky=True)] 26 | tch = ch 27 | for _ in range(1,n_layer): 28 | model += [MyConv2d(tch, min(1024, tch * 2), kernel_size=5, stride=2, padding=2, norm=norm, sn=sn, Leaky=True)] 29 | tch *= 2 30 | tch = min(1024, tch) 31 | model += [nn.Conv2d(tch, 1, 2, 1, 0)] 32 | 33 | return nn.Sequential(*model) 34 | 35 | def forward(self, x): 36 | outs = [] 37 | outs.append(self.Diss[0](x)) 38 | 39 | return outs 40 | 41 | 42 | #################################################################### 43 | #--------------------------- Generators ---------------------------- 44 | #################################################################### 45 | 46 | class G_Module(nn.Module): 47 | def __init__(self, args, norm=None, nl_layer=None): 48 | super(G_Module, self).__init__() 49 | 50 | tch = 64 51 | res = True 52 | 53 | headB = [MyConv2d(3, tch, kernel_size=3, stride=1, padding=1, norm=norm)] 54 | self.headB = nn.Sequential(*headB) 55 | 56 | bodyB1 = [MyConv2d(tch, tch, kernel_size=5, stride=1, padding=2, norm=norm, Res=res), 57 | MyConv2d(tch, tch, kernel_size=5, stride=1, padding=2, norm=norm, Res=res), 58 | MyConv2d(tch, tch, kernel_size=5, stride=1, padding=2, norm=norm, Res=res), 59 | MyConv2d(tch, tch, kernel_size=5, stride=1, padding=2, norm=norm, Res=res),] 60 | 61 | bodyB2 = [MyConv2d(tch, tch*2, kernel_size=2, stride=2, padding=0, norm=norm),] 62 | self.avgpool2 = nn.AvgPool2d(kernel_size=2, stride=2) 63 | 64 | bodyB3 = [MyConv2d(tch*2, tch*2, kernel_size=3, stride=1, padding=1, norm=norm, Res=res), 65 | MyConv2d(tch*2, tch*2, kernel_size=3, stride=1, padding=1, norm=norm, Res=res), 66 | MyConv2d(tch*2, tch*2, kernel_size=3, stride=1, padding=1, norm=norm, Res=res), ] 67 | 68 | self.bodyB1 = nn.Sequential(*bodyB1) 69 | self.bodyB2 = nn.Sequential(*bodyB2) 70 | self.bodyB3 = nn.Sequential(*bodyB3) 71 | 72 | tailB = [ nn.Conv2d(tch*2, 3, kernel_size=1, stride=1, padding=0) ] 73 | self.tailB = nn.Sequential(*tailB) 74 | 75 | 76 | def forward(self, HR): 77 | tres = self.avgpool2(HR) 78 | out = self.headB(HR) 79 | 80 | res = out 81 | out = self.bodyB1(out) 82 | out += res 83 | 84 | out = self.bodyB2(out) 85 | 86 | res = out 87 | out = self.bodyB3(out) 88 | out += res 89 | 90 | out = self.tailB(out) 91 | out += tres 92 | 93 | return out 94 | 95 | #################################################################### 96 | #--------------------------- losses ---------------------------- 97 | #################################################################### 98 | class PerceptualLoss(): 99 | def __init__(self, loss, gpu=0, p_layer=14): 100 | super(PerceptualLoss, self).__init__() 101 | self.criterion = loss 102 | 103 | cnn = py_models.vgg19(pretrained=True).features 104 | cnn = cnn.cuda() 105 | model = nn.Sequential() 106 | model = model.cuda() 107 | for i,layer in enumerate(list(cnn)): 108 | model.add_module(str(i),layer) 109 | if i == p_layer: 110 | break 111 | self.contentFunc = model 112 | 113 | def getloss(self, fakeIm, realIm): 114 | if 
isinstance(fakeIm, numpy.ndarray): 115 | fakeIm = torch.from_numpy(fakeIm).permute(2, 0, 1).unsqueeze(0).float().cuda() 116 | realIm = torch.from_numpy(realIm).permute(2, 0, 1).unsqueeze(0).float().cuda() 117 | f_fake = self.contentFunc.forward(fakeIm) 118 | f_real = self.contentFunc.forward(realIm) 119 | f_real_no_grad = f_real.detach() 120 | loss = self.criterion(f_fake, f_real_no_grad) 121 | return loss 122 | 123 | class PerceptualLoss16(): 124 | def __init__(self, loss, gpu=0, p_layer=14): 125 | super(PerceptualLoss16, self).__init__() 126 | self.criterion = loss 127 | # conv_3_3_layer = 14 128 | checkpoint = torch.load('/vggface_path/VGGFace16.pth') 129 | vgg16 = py_models.vgg16(num_classes=2622) 130 | vgg16.load_state_dict(checkpoint['state_dict']) 131 | cnn = vgg16.features 132 | cnn = cnn.cuda() 133 | # cnn = cnn.to(gpu) 134 | model = nn.Sequential() 135 | model = model.cuda() 136 | for i,layer in enumerate(list(cnn)): 137 | # print(layer) 138 | model.add_module(str(i),layer) 139 | if i == p_layer: 140 | break 141 | self.contentFunc = model 142 | del vgg16, cnn, checkpoint 143 | 144 | def getloss(self, fakeIm, realIm): 145 | if isinstance(fakeIm, numpy.ndarray): 146 | fakeIm = torch.from_numpy(fakeIm).permute(2, 0, 1).unsqueeze(0).float().cuda() 147 | realIm = torch.from_numpy(realIm).permute(2, 0, 1).unsqueeze(0).float().cuda() 148 | 149 | f_fake = self.contentFunc.forward(fakeIm) 150 | f_real = self.contentFunc.forward(realIm) 151 | f_real_no_grad = f_real.detach() 152 | loss = self.criterion(f_fake, f_real_no_grad) 153 | return loss 154 | 155 | #################################################################### 156 | #------------------------- Basic Functions ------------------------- 157 | #################################################################### 158 | def get_non_linearity(layer_type='relu'): 159 | if layer_type == 'relu': 160 | nl_layer = functools.partial(nn.ReLU, inplace=False) 161 | elif layer_type == 'lrelu': 162 | nl_layer = functools.partial(nn.LeakyReLU, negative_slope=0.2, inplace=False) 163 | elif layer_type == 'elu': 164 | nl_layer = functools.partial(nn.ELU, inplace=False) 165 | else: 166 | raise NotImplementedError('nonlinearity activitation [%s] is not found' % layer_type) 167 | return nl_layer 168 | def conv3x3(in_planes, out_planes): 169 | return [nn.ReflectionPad2d(1), nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=0, bias=True)] 170 | 171 | def gaussian_weights_init(m): 172 | classname = m.__class__.__name__ 173 | if classname.find('Conv') != -1 and classname.find('Conv') == 0: 174 | m.weight.data.normal_(0.0, 0.02) 175 | 176 | #################################################################### 177 | #-------------------------- Basic Blocks -------------------------- 178 | #################################################################### 179 | class MyConv2d(nn.Module): 180 | def __init__(self, n_in, n_out, kernel_size, stride, padding=0, norm=None, Res=False, sn=False, Leaky=False): 181 | super(MyConv2d, self).__init__() 182 | model = [nn.ReflectionPad2d(padding)] 183 | 184 | if sn: 185 | model += [spectral_norm(nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=0, bias=True))] 186 | else: 187 | model += [nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=0, bias=True)] 188 | 189 | if norm == 'Instance': 190 | model += [nn.InstanceNorm2d(n_out, affine=False)] 191 | elif norm == 'Batch': 192 | model += [nn.BatchNorm2d(n_out)] 193 | elif norm == 'Layer': 194 | model += 
[LayerNorm(n_out)] 195 | elif norm != 'None': 196 | raise NotImplementedError('not implemeted norm type') 197 | 198 | if Leaky: 199 | model += [nn.LeakyReLU(inplace=False)] 200 | else: 201 | model += [nn.ReLU(inplace=False)] 202 | 203 | self.model = nn.Sequential(*model) 204 | self.model.apply(gaussian_weights_init) 205 | self.Res = Res 206 | 207 | def forward(self, x): 208 | if self.Res: 209 | return self.model(x) + x 210 | else: 211 | return self.model(x) 212 | 213 | ###################################################################### 214 | ## The code of LayerNorm is modified from MUNIT (https://github.com/NVlabs/MUNIT) 215 | class LayerNorm(nn.Module): 216 | def __init__(self, n_out, eps=1e-5, affine=True): 217 | super(LayerNorm, self).__init__() 218 | self.n_out = n_out 219 | self.affine = affine 220 | if self.affine: 221 | self.weight = nn.Parameter(torch.ones(n_out, 1, 1)) 222 | self.bias = nn.Parameter(torch.zeros(n_out, 1, 1)) 223 | return 224 | def forward(self, x): 225 | normalized_shape = x.size()[1:] 226 | if self.affine: 227 | return F.layer_norm(x, normalized_shape, self.weight.expand(normalized_shape), self.bias.expand(normalized_shape)) 228 | else: 229 | return F.layer_norm(x, normalized_shape) 230 | 231 | 232 | #################################################################### 233 | #--------------------- Spectral Normalization --------------------- 234 | # This part of code is copied from pytorch master branch (0.5.0) 235 | #################################################################### 236 | class SpectralNorm(object): 237 | def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): 238 | self.name = name 239 | self.dim = dim 240 | if n_power_iterations <= 0: 241 | raise ValueError('Expected n_power_iterations to be positive, but ' 242 | 'got n_power_iterations={}'.format(n_power_iterations)) 243 | self.n_power_iterations = n_power_iterations 244 | self.eps = eps 245 | def compute_weight(self, module): 246 | weight = getattr(module, self.name + '_orig') 247 | u = getattr(module, self.name + '_u') 248 | weight_mat = weight 249 | if self.dim != 0: 250 | # permute dim to front 251 | weight_mat = weight_mat.permute(self.dim, 252 | *[d for d in range(weight_mat.dim()) if d != self.dim]) 253 | height = weight_mat.size(0) 254 | weight_mat = weight_mat.reshape(height, -1) 255 | with torch.no_grad(): 256 | for _ in range(self.n_power_iterations): 257 | v = F.normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps) 258 | u = F.normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps) 259 | sigma = torch.dot(u, torch.matmul(weight_mat, v)) 260 | weight = weight / sigma 261 | return weight, u 262 | def remove(self, module): 263 | weight = getattr(module, self.name) 264 | delattr(module, self.name) 265 | delattr(module, self.name + '_u') 266 | delattr(module, self.name + '_orig') 267 | module.register_parameter(self.name, torch.nn.Parameter(weight)) 268 | def __call__(self, module, inputs): 269 | if module.training: 270 | weight, u = self.compute_weight(module) 271 | setattr(module, self.name, weight) 272 | setattr(module, self.name + '_u', u) 273 | else: 274 | r_g = getattr(module, self.name + '_orig').requires_grad 275 | getattr(module, self.name).detach_().requires_grad_(r_g) 276 | 277 | @staticmethod 278 | def apply(module, name, n_power_iterations, dim, eps): 279 | fn = SpectralNorm(name, n_power_iterations, dim, eps) 280 | weight = module._parameters[name] 281 | height = weight.size(dim) 282 | u = 
F.normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps) 283 | delattr(module, fn.name) 284 | module.register_parameter(fn.name + "_orig", weight) 285 | module.register_buffer(fn.name, weight.data) 286 | module.register_buffer(fn.name + "_u", u) 287 | module.register_forward_pre_hook(fn) 288 | return fn 289 | 290 | def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None): 291 | if dim is None: 292 | if isinstance(module, (torch.nn.ConvTranspose1d, 293 | torch.nn.ConvTranspose2d, 294 | torch.nn.ConvTranspose3d)): 295 | dim = 1 296 | else: 297 | dim = 0 298 | SpectralNorm.apply(module, name, n_power_iterations, dim, eps) 299 | return module 300 | 301 | def remove_spectral_norm(module, name='weight'): 302 | for k, hook in module._forward_pre_hooks.items(): 303 | if isinstance(hook, SpectralNorm) and hook.name == name: 304 | hook.remove(module) 305 | del module._forward_pre_hooks[k] 306 | return module 307 | raise ValueError("spectral_norm of '{}' not found in {}".format(name, module)) -------------------------------------------------------------------------------- /src/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | 4 | 5 | class Options(): 6 | def __init__(self): 7 | self.parser = argparse.ArgumentParser() 8 | 9 | ## dataset related 10 | # learning downsampling 11 | self.parser.add_argument('--source', type=str, default='Source', help='Source type') 12 | self.parser.add_argument('--target', type=str, default='Target', help='target type') 13 | # validation set for SR, only used in joint training 14 | self.parser.add_argument('--test_hr', type=str, help='HR images for validating') 15 | self.parser.add_argument('--test_lr', type=str, help='LR images for validating') 16 | ## data loader related 17 | self.parser.add_argument('--train_dataroot', type=str, default='../datasets/', help='path of train data') 18 | self.parser.add_argument('--test_dataroot', type=str, default='../datasets/', help='path of test data') 19 | self.parser.add_argument('--phase', type=str, default='train', help='phase for dataloading') 20 | self.parser.add_argument('--batch_size', type=int, default=24, help='batch size') 21 | self.parser.add_argument('--patch_size_down', type=int, default=128, help='cropped image size for learning downsampling') 22 | self.parser.add_argument('--nThreads', type=int, default=4, help='# of threads for data loader') 23 | self.parser.add_argument('--flip', action='store_true', help='specified if flip') 24 | self.parser.add_argument('--rot', action='store_true', help='specified if rotate') 25 | self.parser.add_argument('--nobin', action='store_true', help='specified if not use bin') 26 | ## ouptput related 27 | self.parser.add_argument('--name', type=str, default='', help='folder name to save outputs') 28 | self.parser.add_argument('--experiment_dir', type=str, default='../experiments', help='path for saving result images and models') 29 | self.parser.add_argument('--result_dir', type=str, default='../results', help='path for saving result images and models') 30 | self.parser.add_argument('--img_save_freq', type=int, default=1, help='freq (epoch) of saving images') 31 | self.parser.add_argument('--model_save_freq', type=int, default=10, help='freq (epoch) of saving models') 32 | self.parser.add_argument('--make_down', action='store_true', help='specified if test') 33 | 34 | ## training related 35 | # common 36 | self.parser.add_argument('--gpu', type=str, 
default='cuda', help='gpu ids: e.g. 0 0,1,2, 0,2') 37 | self.parser.add_argument('--scale', type=str, choices=('2', '4'), default='2', help='scale to SR, only support [2, 4]') 38 | # learning downsampling 39 | self.parser.add_argument('--resume_down', type=str, default=None, help='load training states to resume the downsampling learning') 40 | self.parser.add_argument('--epochs_down', type=int, default=80, help='number of epochs for training downsampling') 41 | self.parser.add_argument('--lr_down', type=float, default=0.00005, help='learning rate for learning downsampling') 42 | #self.parser.add_argument('--lr_policy', type=str, default='step', help='type of learn rate decay') 43 | self.parser.add_argument('--decay_batch_size_down', type=int, default=400000, help='decay batch size for learning downsampling') # currently, not using 44 | self.parser.add_argument('--dis_norm', type=str, default='Instance', help='normalization layer in discriminator [None, Batch, Instance, Layer]') 45 | self.parser.add_argument('--gen_norm', type=str, default='Instance', help='normalization layer in generator [None, Batch, Instance, Layer]') 46 | self.parser.add_argument('--cycle_recon', action='store_true', help='use self reconstruction loss for training downsampler, note that it is only available in the joint SR training case') 47 | self.parser.add_argument('--cycle_recon_ratio', type=float, default=0.1, help='hyper parameter for self reconstruction loss') 48 | # training SR 49 | self.parser.add_argument('--joint', action='store_true', help='jointly training downsampler and SR network') 50 | self.parser.add_argument('--pretrain_sr', type=str, default=None, help='load pretrained SR model for stable SR learning') 51 | self.parser.add_argument('--resume_sr', type=str, default=None, help='load training states to resume the SR learning') 52 | self.parser.add_argument('--epochs_sr_start', type=int, default=41, help='start epochs for training SR') 53 | self.parser.add_argument('--epochs_sr_end', type=int, default=80, help='end epochs for training SR') 54 | self.parser.add_argument('--lr_sr', type=float, default=0.00010, help='learning rate for training SR') 55 | self.parser.add_argument('--adv_w', type=float, default=0.01, help='weight for adversarial loss in esrgan training') 56 | self.parser.add_argument('--per_w', type=float, default=1.0, help='weight for perceptual loss in esrgan training') 57 | self.parser.add_argument('--con_w', type=float, default=0.1, help='weight for content loss in esrgan training') 58 | self.parser.add_argument('--noise', action='store_true', help='inject noise in SR training') 59 | self.parser.add_argument('--noise_std', type=float, default=5.0, help='injected std of noise') 60 | self.parser.add_argument('--decay_batch_size_sr', type=int, default=50000, help='decay batch size for training SR') 61 | self.parser.add_argument('--sr_model', type=str, choices=('edsr','rrdb'), default='edsr', help='choose the SR model') 62 | self.parser.add_argument('--training_type', type=str, default='edsr', choices=('edsr', 'esrgan'), help='choose training type of SR') 63 | self.parser.add_argument('--precision', type=str, choices=('single','half'), default='single', help='precision for forwarding SR') 64 | self.parser.add_argument('--realsr', action='store_true', help='just make SR image without calculating PSNR') 65 | self.parser.add_argument('--baseline', action='store_true', help='just train SR network with bicubic downsampled image') 66 | self.parser.add_argument('--patch_size_sr', type=int, 
default=128, help='cropped image size for learning sr') 67 | self.parser.add_argument('--chop', action='store_true', help='enable memory-efficient forward') 68 | self.parser.add_argument('--test_range', type=str, default='1-10', help='test data range') 69 | 70 | ## experimnet related 71 | self.parser.add_argument('--save_snapshot', type=int, default = 20, help='save snapshot') 72 | self.parser.add_argument('--save_log', action='store_true', help='enable saving log option') 73 | self.parser.add_argument('--save_results', action='store_true', help='enable saving intermediate image option') 74 | self.parser.add_argument('--save_intermodel', action='store_true', help='enable saving intermediate model option') 75 | self.parser.add_argument('--edsr_format', type=bool, default=False, help='save image as EDSR format') 76 | 77 | ## data loss related 78 | self.parser.add_argument('--data_loss_type', type=str, choices=('adl', 'lfl', 'bic', 'gau'), default='adl', help='type of available data type') 79 | # lfl 80 | self.parser.add_argument('--box_size', type=int, default=16, help='box size for filtering') 81 | # adl 82 | self.parser.add_argument('--adl_interval', type=int, default=10, help='update interval of data loss') 83 | self.parser.add_argument('--adl_ksize', type=int, default=20, help='kernel size for kernel estimation') 84 | self.parser.add_argument('--num_for_kernel_estimate', type=int, default=50, help='number of image to estimate kernel') 85 | # gau 86 | self.parser.add_argument('--gaussian_sigma', type=float, default=2.0, help='gaussian std') 87 | self.parser.add_argument('--gaussian_ksize', type=int, default=16, help='gaussian kernel size') 88 | self.parser.add_argument('--gaussian_dense', action='store_true', help='option for dense gaussian') 89 | # balance b/w adv loss 90 | self.parser.add_argument('--ratio', type=float, default=100, help='ratio between adv loss and data loss') 91 | 92 | 93 | def parse(self): 94 | self.opt = self.parser.parse_args() 95 | 96 | if self.opt.name == '': 97 | self.opt.name = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')+'_'+self.opt.phase+'_'+self.opt.data_loss_type+'_'+self.opt.source+'_'+self.opt.target 98 | 99 | return self.opt 100 | 101 | class TestOptions(): 102 | def __init__(self): 103 | self.parser = argparse.ArgumentParser() 104 | 105 | ## dataset related 106 | # learning downsampling 107 | self.parser.add_argument('--source', type=str, default='Source', help='Source type') 108 | self.parser.add_argument('--target', type=str, default='Target', help='target type') 109 | # validation set for SR, only used in joint training 110 | self.parser.add_argument('--test_hr', type=str, help='HR images for validating') 111 | self.parser.add_argument('--test_lr', type=str, help='LR images for validating') 112 | ## data loader related 113 | self.parser.add_argument('--train_dataroot', type=str, default='../datasets/', help='path of train data') 114 | self.parser.add_argument('--test_dataroot', type=str, default='../datasets/', help='path of test data') 115 | self.parser.add_argument('--phase', type=str, default='test', help='phase for dataloading') 116 | self.parser.add_argument('--batch_size', type=int, default=1, help='batch size') 117 | self.parser.add_argument('--nThreads', type=int, default=4, help='# of threads for data loader') 118 | self.parser.add_argument('--nobin', action='store_true', help='specified if not use bin') 119 | self.parser.add_argument('--flip', action='store_true', help='specified if flip') 120 | self.parser.add_argument('--rot', 
action='store_true', help='specified if rotate') 121 | ## output related 122 | self.parser.add_argument('--name', type=str, default='', help='folder name to save outputs') 123 | self.parser.add_argument('--result_dir', type=str, default='../results', help='path for saving result images and models') 124 | self.parser.add_argument('--make_down', action='store_true', help='specified if test') 125 | self.parser.add_argument('--img_save_freq', type=int, default=1, help='freq (epoch) of saving images') 126 | self.parser.add_argument('--model_save_freq', type=int, default=10, help='freq (epoch) of saving models') 127 | 128 | ## testing related 129 | # common 130 | self.parser.add_argument('--gpu', type=str, default='cuda', help='gpu ids: e.g. 0 0,1,2, 0,2') 131 | self.parser.add_argument('--scale', type=str, choices=('2', '4'), default='2', help='scale to SR, only support [2, 4]') 132 | # testing downsampler 133 | self.parser.add_argument('--resume_down', type=str, default=None, help='load training states to resume the downsampling learning') 134 | self.parser.add_argument('--dis_norm', type=str, default='Instance', help='normalization layer in discriminator [None, Batch, Instance, Layer]') 135 | self.parser.add_argument('--gen_norm', type=str, default='Instance', help='normalization layer in generator [None, Batch, Instance, Layer]') 136 | # testing SR 137 | self.parser.add_argument('--joint', type=bool, default=True, help='always set true in test mode') 138 | self.parser.add_argument('--pretrain_sr', type=str, default=None, help='load pretrained SR model for stable SR learning') 139 | self.parser.add_argument('--resume_sr', type=str, default=None, help='load training states to resume the SR learning') 140 | self.parser.add_argument('--sr_model', type=str, choices=('edsr','rrdb'), default='edsr', help='choose the SR model') 141 | self.parser.add_argument('--training_type', type=str, default='edsr', choices=('edsr', 'esrgan'), help='choose training type of SR') 142 | self.parser.add_argument('--precision', type=str, choices=('single','half'), default='single', help='precision for forwarding SR') 143 | self.parser.add_argument('--realsr', action='store_true', help='just make SR image without calculating PSNR') 144 | self.parser.add_argument('--chop', action='store_true', help='enable memory-efficient forward') 145 | self.parser.add_argument('--test_range', type=str, default='1-10', help='test data range') 146 | ## experiment related 147 | self.parser.add_argument('--save_log', action='store_true', help='enable saving log option') 148 | self.parser.add_argument('--save_results', action='store_true', help='enable saving intermediate image option') 149 | self.parser.add_argument('--edsr_format', type=bool, default=False, help='save image as EDSR format') 150 | 151 | ## data loss related 152 | self.parser.add_argument('--data_loss_type', type=str, choices=('adl', 'lfl', 'bic', 'gau'), default='adl', help='type of available data type') 153 | # lfl 154 | self.parser.add_argument('--box_size', type=int, default=16, help='box size for filtering') 155 | # adl 156 | self.parser.add_argument('--adl_ksize', type=int, default=20, help='kernel size for kernel estimation') 157 | self.parser.add_argument('--num_for_kernel_estimate', type=int, default=50, help='number of image to estimate kernel') 158 | # gau 159 | self.parser.add_argument('--gaussian_sigma', type=float, default=2.0, help='gaussian std') 160 | self.parser.add_argument('--gaussian_ksize', type=int, default=16, help='gaussian kernel 
size') 161 | self.parser.add_argument('--gaussian_dense', action='store_true', help='option for dense gaussian') 162 | 163 | 164 | def parse(self): 165 | self.opt = self.parser.parse_args() 166 | args = vars(self.opt) 167 | print('\n--- loading options ---') 168 | for name, value in sorted(args.items()): 169 | print('%s: %s' % (str(name), str(value))) 170 | # set irrelevant options 171 | self.opt.dis_norm = 'None' 172 | self.opt.dis_spectral_norm = False 173 | 174 | if self.opt.name == '': 175 | self.opt.name = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')+'_'+self.opt.phase 176 | 177 | return self.opt 178 | -------------------------------------------------------------------------------- /src/saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import time 4 | import torch 5 | import torchvision 6 | import numpy as np 7 | from PIL import Image 8 | from numpy import savetxt 9 | import matplotlib.pyplot as plt 10 | from torchvision.transforms import functional as TF 11 | from torchvision.transforms import ToPILImage, Compose 12 | 13 | from PIL import ImageFile 14 | ImageFile.LOAD_TRUNCATED_IMAGES = True 15 | 16 | 17 | class Saver(): 18 | def __init__(self, args, test=False): 19 | self.args = args 20 | if test: 21 | default_dir = os.path.join(args.result_dir, args.name) 22 | else: 23 | default_dir = os.path.join(args.experiment_dir, args.name) 24 | self.display_dir = os.path.join(default_dir, 'training_progress') 25 | self.model_dir = os.path.join(default_dir, 'models') 26 | self.image_dir = os.path.join(default_dir, 'down_results') 27 | self.kernel_dir = os.path.join(default_dir, 'estimated_kernels') 28 | 29 | if args.edsr_format: 30 | self.image_dir = os.path.join(default_dir, args.name) 31 | 32 | self.img_save_freq = args.img_save_freq 33 | self.model_save_freq = args.model_save_freq 34 | 35 | ## make directory 36 | if not os.path.exists(self.display_dir): os.makedirs(self.display_dir) 37 | if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) 38 | if not os.path.exists(self.image_dir): os.makedirs(self.image_dir) 39 | if not os.path.exists(self.kernel_dir): os.makedirs(self.kernel_dir) 40 | 41 | if args.joint: 42 | self.image_sr_dir = os.path.join(default_dir, 'sr_results') 43 | if not os.path.exists(self.image_sr_dir): os.makedirs(self.image_sr_dir) 44 | 45 | config = os.path.join(default_dir,'config.yml') 46 | with open(config, 'w') as outfile: 47 | yaml.dump(args.__dict__, outfile, default_flow_style=False) 48 | 49 | ## save result images 50 | def write_img_down(self, ep, model): 51 | if (ep + 1) % self.img_save_freq == 0: 52 | assembled_images = model.assemble_outputs() 53 | img_filename = '%s/gen_%05d.png' % (self.display_dir, ep) 54 | torchvision.utils.save_image(assembled_images / 2 + 0.5, img_filename, nrow=1) 55 | elif ep == -1: 56 | assembled_images = model.assemble_outputs() 57 | img_filename = '%s/gen_last.png' % (self.display_dir, ep) 58 | torchvision.utils.save_image(assembled_images / 2 + 0.5, img_filename, nrow=1) 59 | 60 | ## save result images 61 | def write_img_LR(self, ep, num, model, args, fn): 62 | result_savepath = os.path.join(self.image_dir, 'ep_%03d'%ep) 63 | filename = fn[0].split('.')[0] 64 | 65 | if args.edsr_format: 66 | if args.scale == '2': 67 | scale = 'x2' 68 | elif args.scale == '4': 69 | scale = 'x4' 70 | else: 71 | raise NotImplementedError('Scale 2 and 4 are only available.') 72 | result_savepath = os.path.join(self.image_dir, scale) 73 | else: 74 | scale = '' 75 | 76 
| if not os.path.exists(result_savepath): 77 | os.mkdir(result_savepath) 78 | 79 | images_list = model.get_outputs() 80 | 81 | img_filename = os.path.join(result_savepath, '%s%s.png'%(filename, scale)) 82 | torchvision.utils.save_image(images_list[1] / 2 + 0.5, img_filename, nrow=1) 83 | 84 | ## save result images 85 | def write_img_SR(self, ep, sr, filename): 86 | result_savepath = os.path.join(self.image_sr_dir, 'ep_%03d'%ep) 87 | 88 | if not os.path.exists(result_savepath): 89 | os.mkdir(result_savepath) 90 | 91 | img_filename = os.path.join(result_savepath, filename[0]) 92 | 93 | torchvision.utils.save_image(sr, img_filename, nrow=1) 94 | 95 | ## save model 96 | def write_model_down(self, ep, total_it, model): 97 | if ep != -1: 98 | print('save the down model @ ep %d' % (ep)) 99 | model.state_save('%s/training_down_%04d.pth' % (self.model_dir, ep), ep, total_it) 100 | model.model_save('%s/model_down_%04d.pth' % (self.model_dir, ep), ep, total_it) 101 | else: 102 | model.state_save('%s/training_down_last.pth' % (self.model_dir), ep, total_it) 103 | model.model_save('%s/model_down_last.pth' % (self.model_dir), ep, total_it) 104 | 105 | def write_model_sr(self, ep, total_it, model): 106 | if ep != -1: 107 | print('save the sr model @ ep %d' % (ep)) 108 | model.state_save('%s/training_sr_%04d.pth' % (self.model_dir, ep), ep, total_it) 109 | model.model_save('%s/model_sr_%04d.pth' % (self.model_dir, ep), ep, total_it) 110 | else: 111 | model.state_save('%s/training_sr_last.pth' % (self.model_dir), ep, total_it) 112 | model.model_save('%s/model_sr_last.pth' % (self.model_dir), ep, total_it) 113 | 114 | ## visualzie estimated kernel 115 | def write_kernel(self, ep, kernel): 116 | kernel_np = np.array(kernel.cpu()) 117 | savetxt(os.path.join(self.kernel_dir, 'kernel_%02d.csv'%(ep)), kernel_np, delimiter=',') 118 | 119 | kernel /= kernel.abs().max() 120 | k_pos = kernel * (kernel > 0).float() 121 | k_neg = kernel * (kernel < 0).float() 122 | k_rgb = torch.stack([-k_neg, k_pos, torch.zeros_like(k_pos)], dim=0) 123 | pil = TF.to_pil_image(k_rgb.cpu()) 124 | pil = pil.resize((self.args.adl_ksize * 20, self.args.adl_ksize * 20), resample=Image.NEAREST) 125 | pil.save(os.path.join(self.kernel_dir, 'kernel_%02d.png'%(ep))) # save kernel png 126 | 127 | -------------------------------------------------------------------------------- /src/test_down.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from saver import Saver 4 | from options import TestOptions 5 | from dataset import unpaired_dataset 6 | from utility import quantize, _normalize 7 | from trainer_down import AdaptiveDownsamplingModel 8 | 9 | 10 | # parse options 11 | parser = TestOptions() 12 | args = parser.parse() 13 | 14 | # test mode 15 | args.batch_size = 1 16 | 17 | # daita loader 18 | print('\nmaking dataset ...') 19 | dataset = unpaired_dataset(args, phase='test') 20 | test_loader_down = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.nThreads) 21 | 22 | # model 23 | print('\nmaking model ...') 24 | ADM = AdaptiveDownsamplingModel(args) 25 | 26 | if args.resume_down is None: 27 | raise NotImplementedError('put trained downsampling model for testing') 28 | else: 29 | ep0, total_it = ADM.resume(args.resume_down, train=False) 30 | ep0 += 1 31 | print('load model successfully!') 32 | 33 | saver = Saver(args, test=True) 34 | 35 | print('\ntest start ...') 36 | ADM.eval() 37 | with torch.no_grad(): 38 | for number, (img_s, _, fn) in 
enumerate(test_loader_down): 39 | 40 | if (number+1) % (len(test_loader_down)//10) == 0: 41 | print('[{:05d} / {:05d}] ...'.format(number+1, len(test_loader_down))) 42 | 43 | ADM.update_img(img_s) 44 | ADM.generate_LR() 45 | 46 | if args.scale == '4': 47 | [ _x ] = _normalize(ADM.img_gen) # normalize [-1,1] to [0,1] 48 | _x = quantize(_x) 49 | [ _x ] = _normalize(_x, mul=2, add=-0.5, reverse=True) # normalize [0,1] to [-1,1] 50 | ADM.img_gen = _x 51 | ADM.generate_LR(scale=args.scale) 52 | 53 | saver.write_img_LR(1, (number+1), ADM, args, fn) 54 | print('\ntest done!') -------------------------------------------------------------------------------- /src/test_sr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm import tqdm 4 | from saver import Saver 5 | from trainer_sr import SRModel 6 | from options import TestOptions 7 | from utility import timer, calc_psnr, quantize 8 | from trainer_down import AdaptiveDownsamplingModel 9 | from dataset import unpaired_dataset, paired_dataset 10 | 11 | # parse options 12 | parser = TestOptions() 13 | args = parser.parse() 14 | 15 | # test mode 16 | saver = Saver(args, test=True) 17 | args.batch_size = 1 18 | 19 | # daita loader 20 | dataset = paired_dataset(args) 21 | test_loader_sr = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.nThreads) 22 | 23 | # model 24 | SRM = SRModel(args, train=False) 25 | if (args.resume_sr is None) and (args.pretrain_sr is None): 26 | raise NotImplementedError('put pretrained model for test') 27 | elif args.resume_sr is not None: 28 | _, _ = SRM.resume(args.resume_sr, train=False) 29 | print('load model successfully!') 30 | 31 | eval_timer_sr = timer() 32 | 33 | eval_timer_sr.tic() 34 | SRM.eval() 35 | ep0 = 0 36 | psnr_sum = 0 37 | cnt = 0 38 | with torch.no_grad(): 39 | for img_hr, img_lr, fn in tqdm(test_loader_sr, ncols=80): 40 | img_hr, img_lr = img_hr.cuda(), img_lr.cuda() 41 | if args.precision == 'half': 42 | img_lr = img_lr.half() 43 | 44 | SRM.update_img(img_lr) 45 | SRM.generate_HR() 46 | 47 | img_sr = quantize(SRM.img_gen) 48 | 49 | if args.save_results: 50 | saver.write_img_SR(ep0, img_sr, fn) 51 | 52 | if not args.realsr: 53 | psnr_sum += calc_psnr( 54 | img_sr, img_hr, args.scale, rgb_range=1 55 | ) 56 | cnt += 1 57 | 58 | eval_timer_sr.hold() 59 | if not args.realsr: 60 | print('PSNR on test set: %.04f, %.01fs' % (psnr_sum/(cnt), eval_timer_sr.release())) 61 | 62 | 63 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from saver import Saver 7 | from options import Options 8 | from trainer_sr import SRModel 9 | from bicubic_pytorch import core 10 | from trainer_down import AdaptiveDownsamplingModel 11 | from dataset import unpaired_dataset, paired_dataset 12 | from utility import log_writer, plot_loss_down, plot_psnr, timer, calc_psnr, quantize, _normalize 13 | 14 | ## parse options 15 | parser = Options() 16 | args = parser.parse() 17 | 18 | ## data loader 19 | print('preparing dataset ...') 20 | dataset = unpaired_dataset(args, phase='train') 21 | train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.nThreads) 22 | dataset = unpaired_dataset(args, phase='test') 23 | test_loader_down = 
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.nThreads) 24 | if args.joint and (args.test_lr is not None): 25 | dataset = paired_dataset(args) 26 | test_loader_sr = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.nThreads) 27 | ep0 = 0 28 | total_it = 0 29 | 30 | ## SR model only if joint training 31 | if args.joint: 32 | print('\nMaking SR-Model... ') 33 | SRM = SRModel(args) 34 | 35 | ## Adaptive Downsampling Model 36 | print('\nMaking Adpative-Downsampling-Model...') 37 | ADM = AdaptiveDownsamplingModel(args) 38 | if args.resume_down is not None: 39 | ep0, total_it = ADM.resume(args.resume_down) 40 | print('\nLoad downsampling model from {}'.format(args.resume_down)) 41 | if args.resume_sr is not None: 42 | ep0, total_it = SRM.resume(args.resume_sr) 43 | print('\nLoad SR model from {}'.format(args.resume_sr)) 44 | 45 | ## saver and training log 46 | saver = Saver(args) 47 | data_timer, train_timer_down, kernel_estimator_timer = timer(), timer(), timer() 48 | if args.joint: 49 | train_timer_sr, eval_timer_sr = timer(), timer() 50 | training_log = log_writer(args.experiment_dir, args.name) 51 | 52 | ## losses 53 | loss_dis = [] # discriminator loss 54 | loss_gen = [] # generator loss 55 | loss_data = [] # data loss 56 | if args.joint: 57 | psnrs = [] # L1 loss for SR 58 | 59 | max_epochs = max(args.epochs_down, args.epochs_sr_end) if args.joint else args.epochs_down 60 | 61 | print('\ntraining start') 62 | for ep in range(ep0, max_epochs): 63 | sr_txt1, sr_txt2 = '', '' 64 | if args.joint and ((ep+1) in range(args.epochs_sr_start, args.epochs_sr_end+1)): 65 | sr_txt1 = ' SR lr %.06f' % (SRM.gen_opt.param_groups[0]['lr']) 66 | sr_txt2 = ' SR loss |' 67 | 68 | training_log.write('\n[ epoch %03d/%03d ] G lr %.06f D lr %.06f%s' 69 | % (ep+1, max_epochs ,ADM.gen_opt.param_groups[0]['lr'], ADM.dis_opt.param_groups[0]['lr'], sr_txt1)) 70 | print_txt = '| Progress | Dis | Gen | data |%s' % (sr_txt2) 71 | training_log.write('-'*len(print_txt)) 72 | training_log.write(print_txt) 73 | 74 | loss_dis_item = 0 75 | loss_gen_item = 0 76 | loss_data_item = 0 77 | if args.joint: 78 | loss_sr_item = 0 79 | cnt = 0 80 | 81 | data_timer.tic() 82 | for it, (img_s, img_t, _) in enumerate(train_loader): 83 | if img_t.size(0) != args.batch_size: 84 | continue 85 | data_timer.hold() 86 | 87 | train_timer_down.tic() 88 | ADM.update_img(img_s, img_t) 89 | ADM.generate_LR() 90 | train_timer_down.hold() 91 | 92 | ## train downsampling network ADM 93 | train_timer_down.tic() 94 | if ((ep+1) in range(0,args.epochs_down+1)) and (not args.baseline): 95 | ADM.update_D() 96 | if args.cycle_recon: 97 | img_lr, img_hr = ADM.img_gen.clamp(min=-1.0, max=1.0), ADM.img_s.detach() 98 | img_lr, img_hr = _normalize(img_lr, img_hr) # normalize [-1,1] to [0,1] 99 | 100 | img_lr = quantize(img_lr, fake=True) 101 | 102 | SRM.update_img(img_lr, img_hr) 103 | SRM.generate_HR() 104 | SRM.calculate_grad() 105 | ADM.update_G(SRM_recon_loss=SRM.recon_loss) 106 | else: 107 | ADM.update_G() 108 | train_timer_down.hold() 109 | 110 | ## if joint training is enabled, train SR network SRM with generated image by ADM 111 | sr_loss_txt, sr_timer_txt = '', '' 112 | if args.joint and ((ep+1) in range(args.epochs_sr_start, args.epochs_sr_end+1)): 113 | train_timer_sr.tic() 114 | if args.baseline: # use bicubic 115 | img_lr, img_hr = core.imresize(ADM.img_s, scale=0.5).detach(), ADM.img_s.detach() 116 | else: 117 | if args.scale == '4': 118 | [ _x ] = _normalize(ADM.img_gen) # 
normalize [-1,1] to [0,1] 119 | _x = quantize(_x) 120 | [ _x ] = _normalize(_x, mul=2, add=-0.5, reverse=True) # normalize [0,1] to [-1,1] 121 | ADM.img_gen = _x 122 | ADM.generate_LR(scale=args.scale) 123 | img_lr, img_hr = ADM.img_gen.detach().clamp(min=-1.0, max=1.0), ADM.img_s.detach() 124 | img_lr, img_hr = _normalize(img_lr, img_hr) # normalize [-1,1] to [0,1] 125 | 126 | if args.noise: 127 | n = args.noise_std * torch.Tensor(np.random.normal(size=img_lr.shape)).cuda() / 255.0 128 | img_lr = (img_lr + n).clamp(max=1.0, min=0.0) 129 | 130 | #b = max(args.patch_size_down - args.patch_size_sr, 0) // 4 131 | #if ( b != 0 ): img_lr, img_hr = img_lr[:,:,b:-b, b:-b], img_hr[:,:,2*b:-2*b, 2*b:-2*b] 132 | 133 | img_lr = quantize(img_lr) 134 | 135 | SRM.update_img(img_lr, img_hr) 136 | SRM.generate_HR() 137 | SRM.update_G() 138 | 139 | loss_sr_item += SRM.gen_loss 140 | train_timer_sr.hold() 141 | 142 | loss_dis_item += ADM.loss_dis 143 | loss_gen_item += ADM.loss_gen 144 | loss_data_item += ADM.loss_data 145 | cnt += 1 146 | 147 | ## print training log with save 148 | if (it+1) % (len(train_loader)//10) == 0: 149 | if args.joint and ((ep+1) in range(args.epochs_sr_start, args.epochs_sr_end+1)): 150 | loss_sr_item_avg = loss_sr_item/cnt 151 | sr_loss_txt = ' %0.5f |' % (loss_sr_item_avg) 152 | sr_timer_txt = ' +%.01fs' % (train_timer_sr.release()) 153 | loss_sr_item = 0 154 | 155 | 156 | loss_dis_item_avg = loss_dis_item/cnt 157 | loss_gen_item_avg = loss_gen_item/cnt 158 | loss_data_item_avg = loss_data_item/cnt 159 | training_log.write('| %04d/%04d | %.05f | %.05f | %.06f |%s %.01f+%.01fs %s' 160 | % ( (it+1), len(train_loader), loss_dis_item_avg, loss_gen_item_avg, 161 | loss_data_item_avg, sr_loss_txt, 162 | train_timer_down.release(), data_timer.release(), sr_timer_txt)) 163 | loss_dis_item = 0 164 | loss_gen_item = 0 165 | loss_data_item = 0 166 | cnt = 0 167 | 168 | if args.save_results: 169 | saver.write_img_down(ep*len(train_loader) + (it+1), ADM) 170 | 171 | data_timer.tic() 172 | training_log.write('-'*len(print_txt)) 173 | 174 | loss_dis.append(loss_dis_item_avg) 175 | loss_gen.append(loss_gen_item_avg) 176 | loss_data.append(loss_data_item_avg) 177 | plot_loss_down(os.path.join(args.experiment_dir, args.name), loss_dis, loss_gen, loss_data) 178 | 179 | 180 | ## 2d linear kernel estimating 181 | kernel_estimator_timer.tic() 182 | ADM.eval() 183 | with torch.no_grad(): 184 | for cnt, (img_s, _, _) in enumerate(test_loader_down): 185 | img_s = img_s[:, :, 0:min(img_s.shape[2], 1000), 0:min(img_s.shape[3], 1000)] # for memory efficiency 186 | ADM.update_img(img_s) 187 | ADM.generate_LR() 188 | kernel = ADM.find_kernel() # estimate 2d kernel of current generator network 189 | estimated_kernel = ADM.stack_kernel(cnt+1, kernel) # stack to average retrieved 2d kernel 190 | if cnt == args.num_for_kernel_estimate: # not use all of test set to save computational cost 191 | break 192 | saver.write_kernel(ep+1, estimated_kernel) 193 | if (args.data_loss_type == 'adl') and ((ep+1) != args.epochs_down) and ((ep+1) % args.adl_interval == 0): 194 | training_log.write('Data Loss Update with Estimated kernel at %02d'%(ep+1)) 195 | ADM.update_dataloss() 196 | training_log.write('Total time to estimate kernel: %.01f s'%kernel_estimator_timer.toc()) 197 | ADM.train() 198 | 199 | 200 | ## sr evaluation for joint training 201 | if args.joint and ((ep+1) in range(args.epochs_sr_start, args.epochs_sr_end+1)) and (args.test_lr is not None): 202 | if (not args.realsr) or args.save_results: 203 | 
            eval_timer_sr.tic()
204 |             SRM.eval()
205 |             psnr_sum = 0
206 |             cnt = 0
207 |             with torch.no_grad():
208 |                 for img_hr, img_lr, fn in tqdm(test_loader_sr, ncols=80):
209 |                     img_hr, img_lr = img_hr.cuda(), img_lr.cuda()
210 |                     SRM.update_img(img_lr)
211 |                     SRM.generate_HR()
212 |
213 |                     img_sr = quantize(SRM.img_gen)
214 |                     if args.save_results:
215 |                         saver.write_img_SR(ep, img_sr, fn)
216 |
217 |                     if not args.realsr:
218 |                         psnr_sum += calc_psnr(
219 |                             img_sr, img_hr, args.scale
220 |                         )
221 |                         cnt += 1
222 |             eval_timer_sr.hold()
223 |             if not args.realsr:
224 |                 training_log.write('PSNR on test set: %.04f, %.01fs' % (psnr_sum/(cnt), eval_timer_sr.release()))
225 |                 psnrs.append(psnr_sum/(cnt))
226 |                 plot_psnr(os.path.join(args.experiment_dir, args.name), psnrs)
227 |             else:
228 |                 training_log.write('Total time elapsed: %.01fs' % (eval_timer_sr.release()))
229 |             SRM.train()
230 |
231 |     if (ep+1) % args.save_snapshot == 0:
232 |         saver.write_model_down(ep+1, total_it+1, ADM)
233 |         if args.joint and ((ep+1) in range(args.epochs_sr_start, args.epochs_sr_end+1)):
234 |             saver.write_model_sr(ep+1, total_it+1, SRM)
235 |
236 | ## Save last model and state
237 | training_log.write('Saving last model and training state..')
238 | saver.write_model_down(-1, total_it+1, ADM)
239 | if args.joint and ((ep+1) in range(args.epochs_sr_start, args.epochs_sr_end+1)):
240 |     saver.write_model_sr(-1, total_it+1, SRM)
241 |
242 |
243 | if args.make_down:
244 |     print('\nmaking downsampled images ...')
245 |     ADM.eval()
246 |     with torch.no_grad():
247 |         for number, (img_s, _, fn) in enumerate(test_loader_down):
248 |             ADM.update_img(img_s)
249 |             ADM.generate_LR()
250 |             if args.scale == '4':
251 |                 ADM.generate_LR(scale=args.scale)
252 |             saver.write_img_LR(ep+1, (number+1), ADM, args, fn)
253 |     ADM.train()
254 |     print('\ndone!')
255 |
256 | ## Save network weights
257 | #saver.write_model(ep+1, total_it+1, ADM)
258 |
259 |
--------------------------------------------------------------------------------
/src/trainer_down.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | import torch
4 | import networks
5 | from PIL import Image
6 | import torch.nn as nn
7 | from filters import find_kernel
8 | from data_loss import get_data_loss
9 | from utility import get_gaussian_kernel, get_avgpool_kernel
10 |
11 | from PIL import ImageFile
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 |
14 | class AdaptiveDownsamplingModel(nn.Module):
15 |     def __init__(self, args):
16 |         super(AdaptiveDownsamplingModel, self).__init__()
17 |
18 |         self.args = args
19 |         self.gpu = args.gpu
20 |         self.data_loss_type = args.data_loss_type # data loss option
21 |
22 |         self.gen = networks.G_Module(args, norm=args.gen_norm, nl_layer=networks.get_non_linearity(layer_type='lrelu')) # generator
23 |         self.gen.apply(networks.gaussian_weights_init)
24 |         self.gen.cuda(args.gpu)
25 |
26 |         self.down_filter = None
27 |
28 |         if self.args.phase == 'train':
29 |             self.dis = networks.D_Module(args, norm=args.dis_norm) # discriminators
30 |             self.dis.apply(networks.gaussian_weights_init)
31 |             self.dis.cuda(args.gpu)
32 |
33 |             self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=args.lr_down, betas=(0.9, 0.999))
34 |             self.dis_opt = torch.optim.Adam(self.dis.parameters(), lr=args.lr_down, betas=(0.9, 0.999))
35 |
36 |             self.gen_sch = torch.optim.lr_scheduler.StepLR(self.gen_opt, args.decay_batch_size_down, gamma=0.5)
37 |             self.dis_sch = torch.optim.lr_scheduler.StepLR(self.dis_opt, args.decay_batch_size_down, gamma=0.5)
38 |
39 |         self.ratio = args.ratio
40 |         print('Data loss type : ', args.data_loss_type)
41 |
42 |     ## update images to model
43 |     def update_img(self, img_s, img_t=None):
44 |         self.img_s = img_s.cuda(self.args.gpu).detach()
45 |         if img_t is not None:
46 |             self.img_t = img_t.cuda(self.args.gpu).detach()
47 |
48 |         self.loss_dis = 0
49 |         self.loss_gen = 0
50 |         self.loss_data = 0
51 |
52 |     ## generate LR images
53 |     def generate_LR(self, scale='2'):
54 |         if scale == '2':
55 |             self.img_gen = self.gen(self.img_s)
56 |         elif scale == '4':
57 |             self.img_gen = self.gen(self.img_gen)
58 |         else:
59 |             raise NotImplementedError('scale is only available for [2, 4]')
60 |
61 |     ## update discriminator D
62 |     def update_D(self):
63 |         self.dis_opt.zero_grad()
64 |
65 |         loss_D = self.backward_D_gan(self.dis, self.img_t, self.img_gen)
66 |
67 |         self.loss_dis = loss_D.item()
68 |
69 |         self.dis_opt.step()
70 |         self.dis_sch.step()
71 |
72 |     ## update generator G
73 |     def update_G(self, SRM_recon_loss=0):
74 |         self.gen_opt.zero_grad()
75 |
76 |         loss_gan = self.backward_G_gan(self.img_gen, self.dis)
77 |         loss_data = get_data_loss(self.img_s, self.img_gen, self.data_loss_type, self.down_filter, self.args) * self.ratio
78 |
79 |         if self.args.cycle_recon:
80 |             loss_G = loss_gan + loss_data + SRM_recon_loss * self.args.cycle_recon_ratio
81 |         else:
82 |             loss_G = loss_gan + loss_data
83 |
84 |         loss_G.backward() #retain_graph=True)
85 |
86 |         self.loss_gen = loss_gan.item()
87 |         self.loss_data = loss_data.item()
88 |
89 |         self.gen_opt.step()
90 |         self.gen_sch.step()
91 |
92 |     ## loss function for discriminator D
93 |     ## real to ones, and fake to zeros
94 |     def backward_D_gan(self, netD, real, fake):
95 |         pred_fake = netD.forward(fake.detach())
96 |         pred_real = netD.forward(real)
97 |         loss_D = 0
98 |         for it, (out_a, out_b) in enumerate(zip(pred_fake, pred_real)):
99 |             out_fake = torch.sigmoid(out_a).clamp(min=0.0, max=1.0)
100 |             out_real = torch.sigmoid(out_b).clamp(min=0.0, max=1.0)
101 |             all0 = torch.zeros_like(out_fake).cuda(self.gpu).clamp(min=0.0, max=1.0)
102 |             all1 = torch.ones_like(out_real).cuda(self.gpu).clamp(min=0.0, max=1.0)
103 |             ad_fake_loss = nn.functional.binary_cross_entropy(out_fake, all0)
104 |             ad_true_loss = nn.functional.binary_cross_entropy(out_real, all1)
105 |             loss_D += ad_true_loss + ad_fake_loss
106 |         loss_D.backward()
107 |         return loss_D
108 |
109 |     def backward_G_gan(self, fake, netD=None):
110 |         outs_fake = netD.forward(fake)
111 |         loss_G = 0
112 |         for out_a in outs_fake:
113 |             outputs_fake = torch.sigmoid(out_a)
114 |             all_ones = torch.ones_like(outputs_fake).cuda(self.gpu)
115 |             loss_G += nn.functional.binary_cross_entropy(outputs_fake, all_ones)
116 |         return loss_G
117 |
118 |     ## estimate the 2d kernel with linear approximation
119 |     def find_kernel(self):
120 |         return find_kernel(self.img_s[0], self.img_gen[0], scale=2, k=self.args.adl_ksize, max_patches=-1)
121 |
122 |     ## averaging estimated kernel for stabilization
123 |     def stack_kernel(self, cnt, kernel):
124 |         if cnt == 1:
125 |             self.estimated_kernel = kernel
126 |         else:
127 |             self.estimated_kernel += kernel
128 |
129 |         return self.estimated_kernel / float(cnt)
130 |
131 |     ## ADL; update data loss with retrieved kernel
132 |     ## customize 2d convolution filter weight with estimated 2d kernel,
133 |     ## and set require_grad as False.
134 | def update_dataloss(self): 135 | channels = 3 136 | kernel_size = self.args.adl_ksize 137 | my_kernel = self.estimated_kernel 138 | 139 | my_kernel = my_kernel / torch.sum(my_kernel) # sum to one 140 | 141 | my_kernel = my_kernel.view(1, 1, kernel_size, kernel_size) 142 | my_kernel = my_kernel.repeat(channels, 1, 1, 1) 143 | 144 | my_filter = nn.Conv2d(in_channels=channels, out_channels=channels, 145 | kernel_size=kernel_size, groups=channels, bias=False) 146 | 147 | my_filter.weight.data = my_kernel 148 | my_filter.weight.requires_grad = False 149 | 150 | self.down_filter = my_filter 151 | 152 | 153 | def resume(self, model_dir, train=True): 154 | checkpoint = torch.load(model_dir, map_location=lambda storage, loc: storage) 155 | self.gen.load_state_dict(checkpoint['gen']) 156 | if train: 157 | self.dis.load_state_dict(checkpoint['dis']) 158 | self.dis_opt.load_state_dict(checkpoint['dis_opt']) 159 | self.gen_opt.load_state_dict(checkpoint['gen_opt']) 160 | 161 | if self.data_loss_type == 'adl' and (checkpoint['ep']+1 > self.args.adl_interval): 162 | channels = 3 163 | kernel_size = self.args.adl_ksize 164 | self.down_filter = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, groups=channels, bias=False) 165 | self.down_filter.load_state_dict(checkpoint['down_filter']) 166 | self.down_filter.weight.requires_grad = False 167 | self.down_filter = self.down_filter.cuda() 168 | 169 | return checkpoint['ep'], checkpoint['total_it'] 170 | 171 | def state_save(self, filename, ep, total_it): 172 | state = {'dis': self.dis.state_dict(), 173 | 'gen': self.gen.state_dict(), 174 | 'dis_opt': self.dis_opt.state_dict(), 175 | 'gen_opt': self.gen_opt.state_dict(), 176 | 'ep': ep, 177 | 'total_it': total_it 178 | } 179 | if self.data_loss_type == 'adl' and (self.down_filter is not None): 180 | state['down_filter'] = self.down_filter.state_dict() 181 | time.sleep(5) 182 | torch.save(state, filename) 183 | return 184 | 185 | def model_save(self, filename, ep, total_it): 186 | state = {'dis': self.dis.state_dict(), 187 | 'gen': self.gen.state_dict(), 188 | } 189 | time.sleep(5) 190 | torch.save(state, filename) 191 | return 192 | 193 | def assemble_outputs(self): 194 | images_source = self.img_s.detach() 195 | 196 | images_target = torch.zeros_like(self.img_s) # template 197 | margin = (self.img_s.shape[2] - self.img_t.shape[2]) // 2 198 | images_target[:, :, margin:-margin, margin:-margin] = self.img_t.detach() 199 | 200 | images_generated = torch.zeros_like(self.img_s) # template 201 | margin = (self.img_s.shape[2] - self.img_gen.shape[2]) // 2 202 | images_generated[:, :, margin:-margin, margin:-margin] = self.img_gen.detach() 203 | 204 | images_blank = torch.zeros_like(self.img_s).detach() # blank 205 | 206 | row1 = torch.cat((images_source[0:1, ::], images_blank[0:1, ::], images_generated[0:1, ::]),3) 207 | row2 = torch.cat((images_target[0:1, ::], images_blank[0:1, ::], images_blank[0:1, ::]),3) 208 | 209 | return torch.cat((row1,row2),2) 210 | 211 | def get_outputs(self): 212 | img_s = self.img_s.detach() 213 | img_gen = self.img_gen.detach() 214 | 215 | return [img_s, img_gen] 216 | -------------------------------------------------------------------------------- /src/trainer_sr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import torch 5 | import model 6 | import utility 7 | import numpy as np 8 | import torch.nn as nn 9 | import torch.nn.utils as utils 10 | 11 | from model 
import common 12 | 13 | from tqdm import tqdm 14 | from PIL import Image 15 | from decimal import Decimal 16 | from torch.optim.adam import Adam 17 | from torchvision.models.vgg import vgg19 18 | from model.Discriminator import Discriminator 19 | 20 | 21 | class SRModel(nn.Module): 22 | def __init__(self, args, train=True): 23 | super(SRModel, self).__init__() 24 | 25 | self.args = args 26 | self.scale = args.scale 27 | self.gpu = 'cuda' 28 | self.error_last = 1e8 29 | 30 | self.training_type = args.training_type 31 | print('sr model : ',self.args.sr_model) 32 | print('training type : ',self.training_type) 33 | 34 | # define model, optimizer, scheduler, loss 35 | self.gen = model.Model(args) 36 | 37 | if args.pretrain_sr is not None: 38 | checkpoint = torch.load(args.pretrain_sr, map_location=lambda storage, loc: storage) 39 | self.gen.load_state_dict(checkpoint) 40 | print('Load pretrained SR model from {}'.format(args.pretrain_sr)) 41 | 42 | if train: 43 | self.gen_opt = Adam(self.gen.parameters(), lr=args.lr_sr, betas=(0.9, 0.999)) 44 | self.gen_sch = torch.optim.lr_scheduler.StepLR(self.gen_opt, args.decay_batch_size_sr, gamma=0.5) #args.gamma) 45 | 46 | self.content_criterion = nn.L1Loss().to(self.gpu) 47 | 48 | if self.training_type == 'esrgan': 49 | self.dis = Discriminator(args).to(self.gpu) 50 | self.dis_opt = Adam(self.dis.parameters(), lr=args.lr_sr, betas=(0.9, 0.999)) 51 | self.dis_sch = torch.optim.lr_scheduler.StepLR(self.dis_opt, args.decay_batch_size_sr, gamma=0.5) 52 | 53 | self.adversarial_criterion = nn.BCEWithLogitsLoss().to(self.gpu) 54 | self.perception_criterion = PerceptualLoss().to(self.gpu) 55 | self.dis.train() 56 | 57 | self.gen.train() 58 | 59 | self.gen_loss = 0 60 | self.recon_loss = 0 61 | 62 | ## update images to the model 63 | def update_img(self, lr, hr=None): 64 | self.img_lr = lr 65 | self.img_hr = hr 66 | 67 | self.gen_loss = 0 68 | self.recon_loss = 0 69 | 70 | def generate_HR(self): 71 | #self.img_lr *= 255 72 | self.img_gen = self.gen(self.img_lr, 0) 73 | #self.img_gen /= 255 74 | 75 | def update_G(self): 76 | # EDSR style 77 | if self.training_type == 'edsr': 78 | self.gen_opt.zero_grad() 79 | 80 | self.recon_loss = self.content_criterion(self.img_gen, self.img_hr) * 255.0 # compensate range of 0 to 1 81 | 82 | self.recon_loss.backward() 83 | self.gen_opt.step() 84 | self.gen_loss = self.recon_loss.item() 85 | self.gen_sch.step() 86 | 87 | # ESRGAN style 88 | elif self.training_type == 'esrgan': 89 | if self.args.cycle_recon: 90 | raise NotImplementedError('Do not support using cycle reconstruction loss in ESRGAN training') 91 | real_labels = torch.ones((self.img_hr.size(0), 1)).to(self.gpu) 92 | fake_labels = torch.zeros((self.img_hr.size(0), 1)).to(self.gpu) 93 | 94 | # training generator 95 | self.gen_opt.zero_grad() 96 | 97 | score_real = self.dis(self.img_hr) 98 | score_fake = self.dis(self.img_gen) 99 | 100 | discriminator_rf = score_real - score_fake.mean() 101 | discriminator_fr = score_fake - score_real.mean() 102 | 103 | adversarial_loss_rf = self.adversarial_criterion(discriminator_rf, fake_labels) 104 | adversarial_loss_fr = self.adversarial_criterion(discriminator_fr, real_labels) 105 | adversarial_loss = (adversarial_loss_fr + adversarial_loss_rf) / 2 106 | 107 | perceptual_loss = self.perception_criterion(self.img_gen, self.img_hr) 108 | content_loss = self.content_criterion(self.img_gen, self.img_hr) * 255.0 # compensate range of 0 to 1 109 | 110 | gen_loss = adversarial_loss * self.args.adv_w + \ 111 | perceptual_loss * 
self.args.per_w + \
112 |                        content_loss * self.args.con_w
113 |
114 |             gen_loss.backward()
115 |             self.gen_loss = gen_loss.item()
116 |             self.gen_opt.step()
117 |
118 |             # training discriminator
119 |             self.dis_opt.zero_grad()
120 |
121 |             score_real = self.dis(self.img_hr)
122 |             score_fake = self.dis(self.img_gen.detach())
123 |             discriminator_rf = score_real - score_fake.mean()
124 |             discriminator_fr = score_fake - score_real.mean()
125 |
126 |             adversarial_loss_rf = self.adversarial_criterion(discriminator_rf, real_labels)
127 |             adversarial_loss_fr = self.adversarial_criterion(discriminator_fr, fake_labels)
128 |             discriminator_loss = (adversarial_loss_fr + adversarial_loss_rf) / 2
129 |
130 |             discriminator_loss.backward()
131 |             self.dis_opt.step()
132 |
133 |             self.gen_sch.step()
134 |             self.dis_sch.step()
135 |         else:
136 |             raise NotImplementedError('unsupported training type')
137 |
138 |     def resume(self, model_dir, train=True):
139 |         checkpoint = torch.load(model_dir, map_location=lambda storage, loc: storage)
140 |         self.gen.load_state_dict(checkpoint['gen'])
141 |         if train:
142 |             self.gen_opt.load_state_dict(checkpoint['gen_opt'])
143 |             if self.training_type == 'esrgan':
144 |                 self.dis.load_state_dict(checkpoint['dis'])
145 |                 self.dis_opt.load_state_dict(checkpoint['dis_opt'])
146 |         return checkpoint['ep'], checkpoint['total_it']
147 |
148 |     def state_save(self, filename, ep, total_it):
149 |         state = {'gen': self.gen.state_dict(),
150 |                  'gen_opt': self.gen_opt.state_dict(),
151 |                  'ep': ep,
152 |                  'total_it': total_it
153 |                  }
154 |         if self.training_type == 'esrgan':
155 |             state['dis'] = self.dis.state_dict()
156 |             state['dis_opt'] = self.dis_opt.state_dict()
157 |         time.sleep(5)
158 |         torch.save(state, filename)
159 |         return
160 |
161 |     def model_save(self, filename, ep, total_it):
162 |         state = {'gen': self.gen.state_dict()}
163 |         if self.training_type == 'esrgan':
164 |             state['dis'] = self.dis.state_dict()
165 |         time.sleep(5)
166 |         torch.save(state, filename)
167 |         return
168 |
169 |
170 | class PerceptualLoss(nn.Module):
171 |     def __init__(self):
172 |         super(PerceptualLoss, self).__init__()
173 |
174 |         vgg = vgg19(pretrained=True)
175 |         loss_network = nn.Sequential(*list(vgg.features)[:35]).eval()
176 |         for param in loss_network.parameters():
177 |             param.requires_grad = False
178 |         self.loss_network = loss_network
179 |         self.l1_loss = nn.L1Loss()
180 |
181 |     def forward(self, high_resolution, fake_high_resolution):
182 |         perception_loss = self.l1_loss(self.loss_network(high_resolution), self.loss_network(fake_high_resolution))
183 |         return perception_loss
184 |
--------------------------------------------------------------------------------
/src/utility.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import math
5 | import time
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | import scipy.io as scio
9 | import random
10 |
11 |
12 | class log_writer():
13 |     def __init__(self, experiment_dir, name):
14 |         self.log_txt = os.path.join(os.path.join(experiment_dir, name),'training_log.txt')
15 |         open(self.log_txt, 'w').close()  # create (or truncate) the log file
16 |
17 |     def write(self, log):
18 |         print(log)
19 |         with open(self.log_txt, 'a') as f:
20 |             f.write(log+'\n')
21 |
22 |
23 | def plot_loss_down(save, loss_d, loss_g, loss_dl):
24 |     assert len(loss_d) == len(loss_g)
25 |     assert len(loss_d) == len(loss_dl)
26 |     axis = np.linspace(1, len(loss_d), len(loss_d))
27 |     fig = plt.figure()
28 |
29 |     plt.plot(axis, loss_d, axis, 
loss_g, axis, loss_dl) 30 | 31 | #plt.legend() 32 | plt.xlabel('epoch') 33 | plt.ylabel('loss_d(blue), loss_g(orange), loss_dl(green)') 34 | plt.grid(True) 35 | plt.savefig(os.path.join(save,'down_loss_graph.pdf')) 36 | plt.close(fig) 37 | 38 | def plot_psnr(save, psnrs): 39 | axis = np.linspace(1, len(psnrs), len(psnrs)) 40 | fig = plt.figure() 41 | 42 | plt.plot(axis, psnrs) 43 | 44 | #plt.legend() 45 | plt.xlabel('epoch') 46 | plt.ylabel('PSNR') 47 | plt.grid(True) 48 | plt.savefig(os.path.join(save,'sr_psnr_graph.pdf')) 49 | plt.close(fig) 50 | 51 | 52 | class timer(): 53 | def __init__(self): 54 | self.acc = 0 55 | self.tic() 56 | 57 | def tic(self): 58 | self.t0 = time.time() 59 | 60 | def toc(self, restart=False): 61 | diff = time.time() - self.t0 62 | if restart: self.t0 = time.time() 63 | return diff 64 | 65 | def hold(self): 66 | self.acc += self.toc() 67 | 68 | def release(self): 69 | ret = self.acc 70 | self.acc = 0 71 | 72 | return ret 73 | 74 | def reset(self): 75 | self.acc = 0 76 | 77 | 78 | def get_gaussian_kernel(kernel_size=5, sigma=1, channels=3): 79 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) 80 | x_coord = torch.arange(kernel_size) 81 | x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size) 82 | y_grid = x_grid.t() 83 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() 84 | 85 | mean = (kernel_size - 1)/2. 86 | variance = sigma**2. 87 | 88 | # Calculate the 2-dimensional gaussian kernel which is 89 | # the product of two gaussian distributions for two different 90 | # variables (in this case called x and y) 91 | gaussian_kernel = (1./(2.*math.pi*variance)) *\ 92 | torch.exp( 93 | -torch.sum((xy_grid - mean)**2., dim=-1) /\ 94 | (2*variance) 95 | ) 96 | 97 | # Make sure sum of values in gaussian kernel equals 1. 
98 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) 99 | 100 | # Reshape to 2d depthwise convolutional weight 101 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size) 102 | gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1) 103 | 104 | gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels, 105 | kernel_size=kernel_size, groups=channels, bias=False) 106 | 107 | gaussian_filter.weight.data = gaussian_kernel 108 | gaussian_filter.weight.requires_grad = False 109 | 110 | return gaussian_filter 111 | 112 | def get_avgpool_kernel(kernel_size=16, stride=1, channels=3): 113 | my_zeros = torch.empty(kernel_size, kernel_size) 114 | my_kernel = torch.ones_like(my_zeros) 115 | my_kernel = my_kernel / torch.sum(my_kernel) 116 | 117 | # Reshape to 2d depthwise convolutional weight 118 | my_kernel = my_kernel.view(1, 1, kernel_size, kernel_size) 119 | my_kernel = my_kernel.repeat(channels, 1, 1, 1) 120 | 121 | my_filter = nn.Conv2d(in_channels=channels, out_channels=channels, 122 | kernel_size=kernel_size, groups=channels, stride=stride, bias=False) 123 | 124 | my_filter.weight.data = my_kernel 125 | my_filter.weight.requires_grad = False 126 | 127 | return my_filter 128 | 129 | 130 | def calc_psnr(sr, hr, scale, rgb_range=1, dataset=None): 131 | if hr.nelement() == 1: return 0 132 | 133 | diff = (sr - hr) / rgb_range 134 | if False: #dataset and dataset.dataset.benchmark: 135 | shave = scale 136 | if diff.size(1) > 1: 137 | gray_coeffs = [65.738, 129.057, 25.064] 138 | convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256 139 | diff = diff.mul(convert).sum(dim=1) 140 | else: 141 | shave = int(scale) + 6 142 | 143 | valid = diff[..., shave:-shave, shave:-shave] 144 | mse = valid.pow(2).mean() 145 | 146 | return -10 * math.log10(mse) 147 | 148 | def quantize(img, rgb_range=1, fake=False): 149 | pixel_range = 255 / rgb_range 150 | if fake: 151 | fake_img = img.mul(pixel_range).clamp(0, 255).round().div(pixel_range).detach() 152 | res = fake_img - img.detach() 153 | return img + res 154 | else: 155 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) 156 | 157 | def _normalize(*args, mul=0.5, add=0.5, reverse=False): 158 | 159 | if reverse: 160 | ret = [ 161 | (args[0] + add) * mul, 162 | *[(a + add) * mul for a in args[1:]] 163 | ] 164 | else: 165 | ret = [ 166 | args[0] * mul + add, 167 | *[a * mul + add for a in args[1:]] 168 | ] 169 | 170 | return ret 171 | 172 | 173 | if __name__ == '__main__': 174 | torch.set_printoptions(precision=4, linewidth=200, sci_mode=False) 175 | k = get_avgpool_kernel(kernel_size=16) 176 | --------------------------------------------------------------------------------
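A note on the straight-through quantization used in `src/train.py` and `src/utility.py`: the cycle-reconstruction path calls `quantize(img_lr, fake=True)` so that the SR network sees an 8-bit-rounded LR image while gradients still reach the downsampling generator. The standalone sketch below is illustrative only (not part of the repository); `quantize_ste` is a hypothetical helper name, nothing beyond `torch` is assumed, and it reproduces the same trick.

```
import torch

def quantize_ste(img, rgb_range=1.0):
    # Same idea as utility.quantize(..., fake=True): the forward pass returns the
    # 8-bit-rounded image, but the rounding residual is computed on detached tensors,
    # so the backward pass behaves as the identity (a straight-through estimator).
    pixel_range = 255 / rgb_range
    rounded = img.detach().mul(pixel_range).clamp(0, 255).round().div(pixel_range)
    return img + (rounded - img.detach())

x = torch.rand(1, 3, 8, 8, requires_grad=True)
y = quantize_ste(x)
y.sum().backward()
print(torch.allclose(x.grad, torch.ones_like(x)))  # True: gradients pass straight through the rounding
```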