├── .gitignore
├── LICENSE
├── NOTICE
├── README.md
├── assets
│   └── teaser.jpg
├── augments.py
├── data
│   ├── __init__.py
│   ├── benchmark.py
│   ├── div2k.py
│   └── realsr.py
├── example
│   ├── demo_hr_inputting.ipynb
│   └── inputs
│       ├── 0869.png
│       ├── Canon_003_HR.png
│       ├── Canon_003_LR4.png
│       └── Nikon_006_HR.png
├── inference.py
├── main.py
├── model
│   ├── carn.py
│   ├── edsr.py
│   ├── ops.py
│   └── rcan.py
├── option.py
├── requirements.txt
├── solver.py
└── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .DS_Store 3 | *.swp 4 | *.ipynb_checkpoints 5 | pt/ 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020-present NAVER Corp. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | CutBlur 2 | 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------------- 24 | 25 | This project contains subcomponents with separate copyright notices and license terms.
26 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 27 | 28 | ===== 29 | 30 | nmhkahn/PCARN-pytorch 31 | https://github.com/nmhkahn/PCARN-pytorch 32 | 33 | 34 | 35 | MIT License 36 | 37 | Copyright (c) 2019 Namhyuk Ahn 38 | 39 | Permission is hereby granted, free of charge, to any person obtaining a copy 40 | of this software and associated documentation files (the "Software"), to deal 41 | in the Software without restriction, including without limitation the rights 42 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 43 | copies of the Software, and to permit persons to whom the Software is 44 | furnished to do so, subject to the following conditions: 45 | 46 | The above copyright notice and this permission notice shall be included in all 47 | copies or substantial portions of the Software. 48 | 49 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 50 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 51 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 52 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 53 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 54 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 55 | SOFTWARE. 56 | 57 | ===== 58 | 59 | thstkdgus35/EDSR-PyTorch 60 | https://github.com/thstkdgus35/EDSR-PyTorch 61 | 62 | 63 | MIT License 64 | 65 | Copyright (c) 2018 Sanghyun Son 66 | 67 | Permission is hereby granted, free of charge, to any person obtaining a copy 68 | of this software and associated documentation files (the "Software"), to deal 69 | in the Software without restriction, including without limitation the rights 70 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 71 | copies of the Software, and to permit persons to whom the Software is 72 | furnished to do so, subject to the following conditions: 73 | 74 | The above copyright notice and this permission notice shall be included in all 75 | copies or substantial portions of the Software. 76 | 77 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 78 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 79 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 80 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 81 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 82 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 83 | SOFTWARE. 84 | 85 | ===== 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Rethinking Data Augmentation for Image Super-resolution (CVPR 2020) 2 |
<img src="assets/teaser.jpg">
3 | 4 | This repository provides the official PyTorch implementation of the following paper: 5 | > **Rethinking Data Augmentation for Image Super-resolution: A Comprehensive Analysis and a New Strategy**
6 | > [Jaejun Yoo](https://www.linkedin.com/in/jaejunyoo/)\*¹, [Namhyuk Ahn](https://nmhkahn.github.io)\*², [Kyung-Ah Sohn](https://sites.google.com/site/kasohn/home)²
7 | > \* indicates equal contribution. Most work was done in NAVER Corp.
8 | > ¹ EPFL
9 | > ² Ajou University
10 | > [https://arxiv.org/abs/2004.00448](https://arxiv.org/abs/2004.00448)
11 | > 12 | > **Abstract:** *Data augmentation is an effective way to improve the performance of deep networks. Unfortunately, current methods are mostly developed for high-level vision tasks (e.g., classification) and few are studied for low-level vision tasks (e.g., image restoration). In this paper, we provide a comprehensive analysis of the existing augmentation methods applied to the super-resolution task. We find that the methods discarding or manipulating the pixels or features too much hamper the image restoration, where the spatial relationship is very important. Based on our analyses, we propose **CutBlur** that cuts a low-resolution patch and pastes it to the corresponding high-resolution image region and vice versa. The key intuition of CutBlur is to enable a model to learn not only "how" but also "where" to super-resolve an image. By doing so, the model can understand "how much", instead of blindly learning to apply super-resolution to every given pixel. Our method consistently and significantly improves the performance across various scenarios, especially when the model size is big and the data is collected under real-world environments. We also show that our method improves other low-level vision tasks, such as denoising and compression artifact removal.* 13 | 14 | ### 0. Requirement 15 | Simply run: 16 | ```shell 17 | pip3 install -r requirements.txt 18 | ``` 19 | 20 | ### 1. Quick Start (Demo) 21 | You can test our models with any images. Place the images in the `./input` directory and run the script below.
22 | Before executing the script, please download the pretrained model from [CutBlur_model](https://drive.google.com/file/d/11lMnMsPv3uZHcKMOw7GAXsyi2suBCkUu/view?usp=sharing) and set the `--model` and `--pretrain` arguments accordingly. 23 | 24 | ```shell 25 | python inference.py \ 26 | --model [EDSR|RCAN|CARN] \ 27 | --pretrain <path_of_pretrained_model> \ 28 | --dataset_root ./input \ 29 | --save_root ./output 30 | ``` 31 | 32 | We also provide a [demo](./example/demo_hr_inputting.ipynb) to visualize how the mixture of augmentations (MoA) prevents the SR model from over-sharpening.
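For reference, below is a minimal sketch of the core CutBlur operation, assuming the LR input has already been upsampled to the HR resolution (the full implementation in `augments.py` additionally randomizes the cut ratio and swaps the inside/outside regions; the function name here is just for illustration):

```python
import numpy as np
import torch

def cutblur_sketch(hr, lr, ratio=0.7):
    # hr, lr: (N, C, H, W) tensors at the SAME resolution
    ch, cw = int(hr.size(2) * ratio), int(hr.size(3) * ratio)
    cy = np.random.randint(0, hr.size(2) - ch + 1)
    cx = np.random.randint(0, hr.size(3) - cw + 1)
    # cut an HR patch and paste it onto the corresponding LR region
    lr = lr.clone()
    lr[..., cy:cy+ch, cx:cx+cw] = hr[..., cy:cy+ch, cx:cx+cw]
    return hr, lr
```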
33 | 34 | ### 2. Dataset 35 | #### DIV2K 36 | We use the [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) dataset to train the model. Download and unpack the tar file to any directory you want. 37 | **Important:** For the DIV2K dataset only, all the train and valid images should be placed in the `DIV2K_train_HR` and `DIV2K_train_LR_bicubic` directories (we parse the train and valid images using the `--div2k_range` argument). 38 | 39 | #### SR Benchmark 40 | For the benchmark datasets used in the paper (Set14, Urban100, and manga109), we provide the original images [here](https://drive.google.com/file/d/11lMnMsPv3uZHcKMOw7GAXsyi2suBCkUu/view?usp=sharing). 41 | 42 | #### RealSR 43 | We use the [RealSR](https://github.com/csjcai/RealSR) dataset (version 1). In the paper, we used both Canon and Nikon images for training and testing.
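For reference, the loaders in `data/div2k.py`, `data/benchmark.py`, and `data/realsr.py` expect directory layouts along the following lines (root names and file names are illustrative; pass each root via `--dataset_root`):

```
# DIV2K
<div2k_root>/
├── DIV2K_train_HR/          # 0001.png ... 0900.png (train + valid together)
└── DIV2K_train_LR_bicubic/
    ├── X2/                  # bicubic LR for X2 pretraining
    └── X4/                  # bicubic LR for X4 fine-tuning

# SR benchmark (Set14 / Urban100 / manga109)
<benchmark_root>/
├── HR/
└── X4/

# RealSR (version 1)
<realsr_root>/
├── Canon/
│   ├── Train/4/             # e.g. Canon_001_HR.png, Canon_001_LR4.png
│   └── Test/4/
└── Nikon/
    ├── Train/4/
    └── Test/4/
```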
44 | 45 | ### 3. Evaluate Pre-trained Models 46 | For those who want to compare against our results, we provide [our results](https://drive.google.com/file/d/11lMnMsPv3uZHcKMOw7GAXsyi2suBCkUu/view?usp=sharing) generated on the RealSR and SR benchmark datasets. 47 | 48 | An example script to evaluate a pretrained model is shown below: 49 | 50 | ```shell 51 | python main.py \ 52 | --dataset [DIV2K_SR|RealSR|Set14_SR|Urban100_SR|manga109_SR] \ 53 | --model [EDSR|RCAN|CARN] \ 54 | --pretrain <path_of_pretrained_model> \ 55 | --dataset_root <path_of_dataset_root> \ 56 | --save_root ./output \ 57 | --test_only 58 | ``` 59 | 60 | For the `--dataset` argument, the `_SR` suffix is required to identify the task among SR, DN, and JPEG restoration. 61 | If you evaluate on `DIV2K_SR`, please also add the `--div2k_range 1-800/801-900` argument to specify the range of images to use. 62 | 63 | Note that `[DIV2K_SR, Set14_SR, Urban100_SR, manga109_SR]` have to be evaluated using a model trained on the DIV2K dataset (e.g. `DIV2K_EDSR_moa.pt`), while `RealSR` requires a model trained on the RealSR dataset (e.g. `RealSR_EDSR_moa.pt`). 64 | 65 | ### 4. Train Models 66 | #### DIV2K (bicubic) 67 | To reproduce the results in the paper, X2 scale pretraining is necessary. 68 | First, train the model on the X2 scale as below: 69 | ```shell 70 | python main.py \ 71 | --use_moa \ 72 | --model [EDSR|RCAN|CARN] \ 73 | --dataset DIV2K_SR \ 74 | --div2k_range 1-800/801-810 \ 75 | --scale 2 \ 76 | --dataset_root <path_of_dataset_root> 77 | ``` 78 | If you want to train the baseline model, omit the `--use_moa` option. 79 | 80 | By default, the trained model is saved in the `./pt` directory. Since evaluating all the validation images takes a long time, we validate on only the first ten images during training. 81 | 82 | Then, fine-tune the trained model on the X4 scale: 83 | ```shell 84 | python main.py \ 85 | --use_moa \ 86 | --model [EDSR|RCAN|CARN] \ 87 | --dataset DIV2K_SR \ 88 | --div2k_range 1-800/801-810 \ 89 | --scale 4 \ 90 | --pretrain <path_of_pretrained_X2_model> \ 91 | --dataset_root <path_of_dataset_root> 92 | ``` 93 | Please see `option.py` for more detailed options. 94 | 95 | #### RealSR 96 | Simply run: 97 | ```shell 98 | python main.py \ 99 | --use_moa \ 100 | --model [EDSR|RCAN|CARN] \ 101 | --dataset RealSR \ 102 | --scale 4 --camera all \ 103 | --dataset_root <path_of_dataset_root> 104 | ``` 105 |
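Under the hood, `--use_moa` turns on the default augmentation policy defined in `option.py`, and `solver.py` applies one randomly chosen augmentation per step through `augments.apply_augment`. A minimal usage sketch (tensor shapes are illustrative; note that the LR batch must be pre-upsampled to the HR resolution for CutBlur):

```python
import torch
import augments  # from this repository

hr = torch.rand(16, 3, 192, 192) * 255  # batch of HR patches
lr = torch.rand(16, 3, 192, 192) * 255  # LR patches, upsampled to the HR size

augs   = ["blend", "rgb", "mixup", "cutout", "cutmix", "cutmixup", "cutblur"]
probs  = [1.0] * len(augs)
alphas = [0.6, 1.0, 1.2, 0.001, 0.7, 0.7, 0.7]

# one augmentation is sampled per call; "cutout" additionally returns a mask
hr_aug, lr_aug, mask, chosen = augments.apply_augment(
    hr, lr, augs, probs, alphas,
    aux_prob=1.0, aux_alpha=1.2, mix_p=None,
)
```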
106 | ### Updates 107 | - **02 Apr, 2020**: Initial upload. 108 | 109 | ### License 110 | ``` 111 | Copyright (c) 2020-present NAVER Corp. 112 | 113 | Permission is hereby granted, free of charge, to any person obtaining a copy 114 | of this software and associated documentation files (the "Software"), to deal 115 | in the Software without restriction, including without limitation the rights 116 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 117 | copies of the Software, and to permit persons to whom the Software is 118 | furnished to do so, subject to the following conditions: 119 | 120 | The above copyright notice and this permission notice shall be included in 121 | all copies or substantial portions of the Software. 122 | 123 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 124 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 125 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 126 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 127 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 128 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 129 | THE SOFTWARE. 130 | ``` 131 | 132 | ### Citation 133 | ``` 134 | @article{yoo2020rethinking, 135 | title={Rethinking Data Augmentation for Image Super-resolution: A Comprehensive Analysis and a New Strategy}, 136 | author={Yoo, Jaejun and Ahn, Namhyuk and Sohn, Kyung-Ah}, 137 | journal={arXiv preprint arXiv:2004.00448}, 138 | year={2020} 139 | } 140 | ``` 141 | -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/cutblur/2448f8bb42e2705d54f0c273e0952664f4f98754/assets/teaser.jpg -------------------------------------------------------------------------------- /augments.py: -------------------------------------------------------------------------------- 1 | """ 2 | CutBlur 3 | Copyright 2020-present NAVER corp. 4 | MIT license 5 | """ 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | def apply_augment( 11 | im1, im2, 12 | augs, probs, alphas, 13 | aux_prob=None, aux_alpha=None, 14 | mix_p=None 15 | ): 16 | idx = np.random.choice(len(augs), p=mix_p) 17 | aug = augs[idx] 18 | prob = float(probs[idx]) 19 | alpha = float(alphas[idx]) 20 | mask = None 21 | 22 | if aug == "none": 23 | im1_aug, im2_aug = im1.clone(), im2.clone() 24 | elif aug == "blend": 25 | im1_aug, im2_aug = blend( 26 | im1.clone(), im2.clone(), 27 | prob=prob, alpha=alpha 28 | ) 29 | elif aug == "mixup": 30 | im1_aug, im2_aug = mixup( 31 | im1.clone(), im2.clone(), 32 | prob=prob, alpha=alpha, 33 | ) 34 | elif aug == "cutout": 35 | im1_aug, im2_aug, mask, _ = cutout( 36 | im1.clone(), im2.clone(), 37 | prob=prob, alpha=alpha 38 | ) 39 | elif aug == "cutmix": 40 | im1_aug, im2_aug = cutmix( 41 | im1.clone(), im2.clone(), 42 | prob=prob, alpha=alpha, 43 | ) 44 | elif aug == "cutmixup": 45 | im1_aug, im2_aug = cutmixup( 46 | im1.clone(), im2.clone(), 47 | mixup_prob=aux_prob, mixup_alpha=aux_alpha, 48 | cutmix_prob=prob, cutmix_alpha=alpha, 49 | ) 50 | elif aug == "cutblur": 51 | im1_aug, im2_aug = cutblur( 52 | im1.clone(), im2.clone(), 53 | prob=prob, alpha=alpha 54 | ) 55 | elif aug == "rgb": 56 | im1_aug, im2_aug = rgb( 57 | im1.clone(), im2.clone(), 58 | prob=prob 59 | ) 60 | else: 61 | raise ValueError("{} is not a valid augmentation.".format(aug)) 62 | 63 | return im1_aug, im2_aug, mask, aug 64 | 65 | 66 | def blend(im1, im2, prob=1.0, alpha=0.6): 67 | if alpha <= 0 or np.random.rand(1) >= prob: 68 | return im1, im2 69 | 70 | c = torch.empty((im2.size(0), 3, 1, 1), device=im2.device).uniform_(0, 255) 71 | rim2 = c.repeat((1, 1, im2.size(2), im2.size(3))) 72 | rim1 = c.repeat((1, 1, im1.size(2), im1.size(3))) 73 | 74 | v = np.random.uniform(alpha, 1) 75 | im1 = v * im1 + (1-v) * rim1 76 | im2 = v * im2 + (1-v) * rim2 77 | 78 | return im1, im2 79 | 80 | 81 | def mixup(im1, im2, prob=1.0, alpha=1.2): 82 | if alpha <= 0 or np.random.rand(1) >= prob: 83 | return im1, im2 84 | 85 | v = np.random.beta(alpha, alpha) 86 | r_index = torch.randperm(im1.size(0)).to(im2.device) 87 | 88 | im1 = v * im1 + (1-v) * im1[r_index, :] 89 | im2 = v * im2 + (1-v) * im2[r_index, :] 90 | return im1, im2 91 | 92 | 93 | def _cutmix(im2, prob=1.0, alpha=1.0): 94 | if alpha <= 0 or np.random.rand(1) >= prob: 95 | return None 96 | 97 | cut_ratio = np.random.randn() * 0.01 + alpha 98 | 99 | h, w = im2.size(2), im2.size(3) 100 | ch, cw = int(h*cut_ratio), int(w*cut_ratio) 101 | 102 | fcy = np.random.randint(0, h-ch+1) 103 | fcx = np.random.randint(0, w-cw+1) 104 | tcy, tcx
= fcy, fcx 105 | rindex = torch.randperm(im2.size(0)).to(im2.device) 106 | 107 | return { 108 | "rindex": rindex, "ch": ch, "cw": cw, 109 | "tcy": tcy, "tcx": tcx, "fcy": fcy, "fcx": fcx, 110 | } 111 | 112 | 113 | def cutmix(im1, im2, prob=1.0, alpha=1.0): 114 | c = _cutmix(im2, prob, alpha) 115 | if c is None: 116 | return im1, im2 117 | 118 | scale = im1.size(2) // im2.size(2) 119 | rindex, ch, cw = c["rindex"], c["ch"], c["cw"] 120 | tcy, tcx, fcy, fcx = c["tcy"], c["tcx"], c["fcy"], c["fcx"] 121 | 122 | hch, hcw = ch*scale, cw*scale 123 | hfcy, hfcx, htcy, htcx = fcy*scale, fcx*scale, tcy*scale, tcx*scale 124 | 125 | im2[..., tcy:tcy+ch, tcx:tcx+cw] = im2[rindex, :, fcy:fcy+ch, fcx:fcx+cw] 126 | im1[..., htcy:htcy+hch, htcx:htcx+hcw] = im1[rindex, :, hfcy:hfcy+hch, hfcx:hfcx+hcw] 127 | 128 | return im1, im2 129 | 130 | 131 | def cutmixup( 132 | im1, im2, 133 | mixup_prob=1.0, mixup_alpha=1.0, 134 | cutmix_prob=1.0, cutmix_alpha=1.0 135 | ): 136 | c = _cutmix(im2, cutmix_prob, cutmix_alpha) 137 | if c is None: 138 | return im1, im2 139 | 140 | scale = im1.size(2) // im2.size(2) 141 | rindex, ch, cw = c["rindex"], c["ch"], c["cw"] 142 | tcy, tcx, fcy, fcx = c["tcy"], c["tcx"], c["fcy"], c["fcx"] 143 | 144 | hch, hcw = ch*scale, cw*scale 145 | hfcy, hfcx, htcy, htcx = fcy*scale, fcx*scale, tcy*scale, tcx*scale 146 | 147 | v = np.random.beta(mixup_alpha, mixup_alpha) 148 | if mixup_alpha <= 0 or np.random.rand(1) >= mixup_prob: 149 | im2_aug = im2[rindex, :] 150 | im1_aug = im1[rindex, :] 151 | 152 | else: 153 | im2_aug = v * im2 + (1-v) * im2[rindex, :] 154 | im1_aug = v * im1 + (1-v) * im1[rindex, :] 155 | 156 | # apply mixup to inside or outside 157 | if np.random.random() > 0.5: 158 | im2[..., tcy:tcy+ch, tcx:tcx+cw] = im2_aug[..., fcy:fcy+ch, fcx:fcx+cw] 159 | im1[..., htcy:htcy+hch, htcx:htcx+hcw] = im1_aug[..., hfcy:hfcy+hch, hfcx:hfcx+hcw] 160 | else: 161 | im2_aug[..., tcy:tcy+ch, tcx:tcx+cw] = im2[..., fcy:fcy+ch, fcx:fcx+cw] 162 | im1_aug[..., htcy:htcy+hch, htcx:htcx+hcw] = im1[..., hfcy:hfcy+hch, hfcx:hfcx+hcw] 163 | im2, im1 = im2_aug, im1_aug 164 | 165 | return im1, im2 166 | 167 | 168 | def cutblur(im1, im2, prob=1.0, alpha=1.0): 169 | if im1.size() != im2.size(): 170 | raise ValueError("im1 and im2 must have the same resolution.") 171 | 172 | if alpha <= 0 or np.random.rand(1) >= prob: 173 | return im1, im2 174 | 175 | cut_ratio = np.random.randn() * 0.01 + alpha 176 | 177 | h, w = im2.size(2), im2.size(3) 178 | ch, cw = int(h*cut_ratio), int(w*cut_ratio) 179 | cy = np.random.randint(0, h-ch+1) 180 | cx = np.random.randint(0, w-cw+1) 181 | 182 | # apply CutBlur to inside or outside 183 | if np.random.random() > 0.5: 184 | im2[..., cy:cy+ch, cx:cx+cw] = im1[..., cy:cy+ch, cx:cx+cw] 185 | else: 186 | im2_aug = im1.clone() 187 | im2_aug[..., cy:cy+ch, cx:cx+cw] = im2[..., cy:cy+ch, cx:cx+cw] 188 | im2 = im2_aug 189 | 190 | return im1, im2 191 | 192 | 193 | def cutout(im1, im2, prob=1.0, alpha=0.1): 194 | scale = im1.size(2) // im2.size(2) 195 | fsize = (im2.size(0), 1)+im2.size()[2:] 196 | 197 | if alpha <= 0 or np.random.rand(1) >= prob: 198 | fim2 = np.ones(fsize) 199 | fim2 = torch.tensor(fim2, dtype=torch.float, device=im2.device) 200 | fim1 = F.interpolate(fim2, scale_factor=scale, mode="nearest") 201 | return im1, im2, fim1, fim2 202 | 203 | fim2 = np.random.choice([0.0, 1.0], size=fsize, p=[alpha, 1-alpha]) 204 | fim2 = torch.tensor(fim2, dtype=torch.float, device=im2.device) 205 | fim1 = F.interpolate(fim2, scale_factor=scale, mode="nearest") 206 | 207 | im2 *=
fim2 208 | 209 | return im1, im2, fim1, fim2 210 | 211 | 212 | def rgb(im1, im2, prob=1.0): 213 | if np.random.rand(1) >= prob: 214 | return im1, im2 215 | 216 | perm = np.random.permutation(3) 217 | im1 = im1[:, perm] 218 | im2 = im2[:, perm] 219 | 220 | return im1, im2 221 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | CutBlur 3 | Copyright 2020-present NAVER corp. 4 | MIT license 5 | """ 6 | import os 7 | import glob 8 | import importlib 9 | import numpy as np 10 | import skimage.io as io 11 | import skimage.color as color 12 | import torch 13 | import utils 14 | 15 | def generate_loader(phase, opt): 16 | cname = opt.dataset.replace("_", "") 17 | if "DIV2K" in opt.dataset: 18 | mname = importlib.import_module("data.div2k") 19 | elif "RealSR" in opt.dataset: 20 | mname = importlib.import_module("data.realsr") 21 | elif "SR" in opt.dataset: # SR benchmark datasets 22 | mname = importlib.import_module("data.benchmark") 23 | cname = "BenchmarkSR" 24 | elif "DN" in opt.dataset: # DN benchmark datasets 25 | mname = importlib.import_module("data.benchmark") 26 | cname = "BenchmarkDN" 27 | elif "JPEG" in opt.dataset: # JPEG benchmark datasets 28 | mname = importlib.import_module("data.benchmark") 29 | cname = "BenchmarkJPEG" 30 | else: 31 | raise ValueError("Unsupported dataset: {}".format(opt.dataset)) 32 | 33 | kwargs = { 34 | "batch_size": opt.batch_size if phase == "train" else 1, 35 | "num_workers": opt.num_workers if phase == "train" else 0, 36 | "shuffle": phase == "train", 37 | "drop_last": phase == "train", 38 | } 39 | 40 | dataset = getattr(mname, cname)(phase, opt) 41 | return torch.utils.data.DataLoader(dataset, **kwargs) 42 | 43 | 44 | class BaseDataset(torch.utils.data.Dataset): 45 | def __init__(self, phase, opt): 46 | print("Load dataset... (phase: {}, len: {})".format(phase, len(self.HQ_paths))) 47 | self.HQ, self.LQ = list(), list() 48 | for HQ_path, LQ_path in zip(self.HQ_paths, self.LQ_paths): 49 | self.HQ += [io.imread(HQ_path)] 50 | self.LQ += [io.imread(LQ_path)] 51 | 52 | self.phase = phase 53 | self.opt = opt 54 | 55 | def __getitem__(self, index): 56 | # follow the setup of EDSR-pytorch 57 | if self.phase == "train": 58 | index = index % len(self.HQ) 59 | 60 | def im2tensor(im): 61 | np_t = np.ascontiguousarray(im.transpose((2, 0, 1))) 62 | tensor = torch.from_numpy(np_t).float() 63 | return tensor 64 | 65 | HQ, LQ = self.HQ[index], self.LQ[index] 66 | if len(HQ.shape) < 3: 67 | HQ = color.gray2rgb(HQ) 68 | if len(LQ.shape) < 3: 69 | LQ = color.gray2rgb(LQ) 70 | 71 | if self.phase == "train": 72 | inp_scale = HQ.shape[0] // LQ.shape[0] 73 | HQ, LQ = utils.crop(HQ, LQ, self.opt.patch_size, inp_scale) 74 | HQ, LQ = utils.flip_and_rotate(HQ, LQ) 75 | return im2tensor(HQ), im2tensor(LQ) 76 | 77 | def __len__(self): 78 | # follow the setup of EDSR-pytorch 79 | if self.phase == "train": 80 | return (1000 * self.opt.batch_size) // len(self.HQ) * len(self.HQ) 81 | return len(self.HQ) 82 | -------------------------------------------------------------------------------- /data/benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | CutBlur 3 | Copyright 2020-present NAVER corp.
4 | MIT license 5 | """ 6 | import os 7 | import glob 8 | import data 9 | 10 | class BenchmarkSR(data.BaseDataset): 11 | def __init__(self, phase, opt): 12 | root = opt.dataset_root 13 | 14 | self.scale = opt.scale 15 | dir_HQ, dir_LQ = self.get_subdir() 16 | self.HQ_paths = sorted(glob.glob(os.path.join(root, dir_HQ, "*.png"))) 17 | self.LQ_paths = sorted(glob.glob(os.path.join(root, dir_LQ, "*.png"))) 18 | 19 | super().__init__(phase, opt) 20 | 21 | def get_subdir(self): 22 | dir_HQ = "HR" 23 | dir_LQ = "X{}".format(self.scale) 24 | return dir_HQ, dir_LQ 25 | 26 | 27 | class BenchmarkDN(BenchmarkSR): 28 | def __init__(self, phase, opt): 29 | self.sigma = opt.sigma 30 | 31 | super().__init__(phase, opt) 32 | 33 | def get_subdir(self): 34 | dir_HQ = "HQ" 35 | dir_LQ = "{}".format(self.sigma) 36 | return dir_HQ, dir_LQ 37 | 38 | 39 | class BenchmarkJPEG(BenchmarkSR): 40 | def __init__(self, phase, opt): 41 | self.quality = opt.quality 42 | 43 | super().__init__(phase, opt) 44 | 45 | def get_subdir(self): 46 | dir_HQ = "HQ" 47 | dir_LQ = "{}".format(self.quality) 48 | return dir_HQ, dir_LQ 49 | -------------------------------------------------------------------------------- /data/div2k.py: -------------------------------------------------------------------------------- 1 | """ 2 | CutBlur 3 | Copyright 2020-present NAVER corp. 4 | MIT license 5 | """ 6 | import os 7 | import glob 8 | import data 9 | 10 | class DIV2KSR(data.BaseDataset): 11 | def __init__(self, phase, opt): 12 | root = opt.dataset_root 13 | 14 | self.scale = opt.scale 15 | dir_HQ, dir_LQ = self.get_subdir() 16 | self.HQ_paths = sorted(glob.glob(os.path.join(root, dir_HQ, "*.png"))) 17 | self.LQ_paths = sorted(glob.glob(os.path.join(root, dir_LQ, "*.png"))) 18 | 19 | split = [int(n) for n in opt.div2k_range.replace("/", "-").split("-")] 20 | if phase == "train": 21 | s = slice(split[0]-1, split[1]) 22 | self.HQ_paths, self.LQ_paths = self.HQ_paths[s], self.LQ_paths[s] 23 | else: 24 | s = slice(split[2]-1, split[3]) 25 | self.HQ_paths, self.LQ_paths = self.HQ_paths[s], self.LQ_paths[s] 26 | 27 | super().__init__(phase, opt) 28 | 29 | def get_subdir(self): 30 | dir_HQ = "DIV2K_train_HR" 31 | dir_LQ = "DIV2K_train_LR_bicubic/X{}".format(self.scale) 32 | return dir_HQ, dir_LQ 33 | 34 | 35 | class DIV2KDN(DIV2KSR): 36 | def __init__(self, phase, opt): 37 | self.sigma = opt.sigma 38 | 39 | super().__init__(phase, opt) 40 | 41 | def get_subdir(self): 42 | dir_HQ = "DIV2K_train_HR" 43 | dir_LQ = "DIV2K_train_DN/{}".format(self.sigma) 44 | return dir_HQ, dir_LQ 45 | 46 | 47 | class DIV2KJPEG(DIV2KSR): 48 | def __init__(self, phase, opt): 49 | self.quality = opt.quality 50 | 51 | super().__init__(phase, opt) 52 | 53 | def get_subdir(self): 54 | dir_HQ = "DIV2K_train_HR" 55 | dir_LQ = "DIV2K_train_JPEG/{}".format(self.quality) 56 | return dir_HQ, dir_LQ 57 | -------------------------------------------------------------------------------- /data/realsr.py: -------------------------------------------------------------------------------- 1 | """ 2 | CutBlur 3 | Copyright 2020-present NAVER corp. 
4 | MIT license 5 | """ 6 | import os 7 | import glob 8 | import data 9 | 10 | class RealSR(data.BaseDataset): 11 | def __init__(self, phase, opt): 12 | root = opt.dataset_root 13 | 14 | self.scale = opt.scale 15 | 16 | subdir = "Train" if phase == "train" else "Test" 17 | path_canon_all = sorted(glob.glob(os.path.join( 18 | root, "Canon", subdir, str(self.scale), "*.png" 19 | ))) 20 | path_nikon_all = sorted(glob.glob(os.path.join( 21 | root, "Nikon", subdir, str(self.scale), "*.png" 22 | ))) 23 | 24 | path_canon_HR = [p for p in path_canon_all if "HR" in p] 25 | path_canon_LR = [p for p in path_canon_all if "LR" in p] 26 | path_nikon_HR = [p for p in path_nikon_all if "HR" in p] 27 | path_nikon_LR = [p for p in path_nikon_all if "LR" in p] 28 | 29 | if opt.camera == "canon": 30 | self.HQ_paths = path_canon_HR 31 | self.LQ_paths = path_canon_LR 32 | elif opt.camera == "nikon": 33 | self.HQ_paths = path_nikon_HR 34 | self.LQ_paths = path_nikon_LR 35 | elif opt.camera == "all": 36 | self.HQ_paths = path_canon_HR+path_nikon_HR 37 | self.LQ_paths = path_canon_LR+path_nikon_LR 38 | else: 39 | raise ValueError("camera must be one of the [canon, nikon, all].") 40 | 41 | super().__init__(phase, opt) 42 | -------------------------------------------------------------------------------- /example/inputs/0869.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/cutblur/2448f8bb42e2705d54f0c273e0952664f4f98754/example/inputs/0869.png -------------------------------------------------------------------------------- /example/inputs/Canon_003_HR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/cutblur/2448f8bb42e2705d54f0c273e0952664f4f98754/example/inputs/Canon_003_HR.png -------------------------------------------------------------------------------- /example/inputs/Canon_003_LR4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/cutblur/2448f8bb42e2705d54f0c273e0952664f4f98754/example/inputs/Canon_003_LR4.png -------------------------------------------------------------------------------- /example/inputs/Nikon_006_HR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/cutblur/2448f8bb42e2705d54f0c273e0952664f4f98754/example/inputs/Nikon_006_HR.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | CutBlur 3 | Copyright 2020-present NAVER corp. 
4 | MIT license 5 | """ 6 | import os 7 | import glob 8 | import importlib 9 | from tqdm import tqdm 10 | import numpy as np 11 | import skimage.io as io 12 | import skimage.color as color 13 | import torch 14 | import torch.nn.functional as F 15 | import option 16 | 17 | def im2tensor(im): 18 | np_t = np.ascontiguousarray(im.transpose((2, 0, 1))) 19 | tensor = torch.from_numpy(np_t).float() 20 | return tensor 21 | 22 | 23 | @torch.no_grad() 24 | def main(opt): 25 | os.makedirs(opt.save_root, exist_ok=True) 26 | 27 | dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | module = importlib.import_module("model.{}".format(opt.model.lower())) 29 | net = module.Net(opt).to(dev) 30 | 31 | state_dict = torch.load(opt.pretrain, map_location=lambda storage, loc: storage) 32 | net.load_state_dict(state_dict) 33 | 34 | paths = sorted(glob.glob(os.path.join(opt.dataset_root, "*.png"))) 35 | for path in tqdm(paths): 36 | name = path.split("/")[-1] 37 | 38 | LR = color.gray2rgb(io.imread(path)) 39 | LR = im2tensor(LR).unsqueeze(0).to(dev) 40 | LR = F.interpolate(LR, scale_factor=opt.scale, mode="nearest") 41 | 42 | SR = net(LR).detach() 43 | SR = SR[0].clamp(0, 255).round().cpu().byte().permute(1, 2, 0).numpy() 44 | 45 | save_path = os.path.join(opt.save_root, name) 46 | io.imsave(save_path, SR) 47 | 48 | 49 | if __name__ == "__main__": 50 | opt = option.get_option() 51 | main(opt) 52 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | CutBlur 3 | Copyright 2020-present NAVER corp. 4 | MIT license 5 | """ 6 | import json 7 | import importlib 8 | import torch 9 | from option import get_option 10 | from solver import Solver 11 | 12 | def main(): 13 | opt = get_option() 14 | torch.manual_seed(opt.seed) 15 | 16 | module = importlib.import_module("model.{}".format(opt.model.lower())) 17 | 18 | if not opt.test_only: 19 | print(json.dumps(vars(opt), indent=4)) 20 | 21 | solver = Solver(module, opt) 22 | if opt.test_only: 23 | print("Evaluate {} (loaded from {})".format(opt.model, opt.pretrain)) 24 | psnr = solver.evaluate() 25 | print("{:.2f}".format(psnr)) 26 | else: 27 | solver.fit() 28 | 29 | if __name__ == "__main__": 30 | main() 31 | -------------------------------------------------------------------------------- /model/carn.py: -------------------------------------------------------------------------------- 1 | """ 2 | CutBlur 3 | Copyright 2020-present NAVER corp. 
4 | MIT license 5 | Referenced from PCARN-pytorch, https://github.com/nmhkahn/PCARN-pytorch 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | from model import ops 10 | 11 | class Group(nn.Module): 12 | def __init__(self, num_channels, num_blocks, res_scale=1.0): 13 | super().__init__() 14 | 15 | for nb in range(num_blocks): 16 | setattr(self, 17 | "b{}".format(nb+1), 18 | ops.ResBlock(num_channels, res_scale) 19 | ) 20 | setattr(self, 21 | "c{}".format(nb+1), 22 | nn.Conv2d(num_channels*(nb+2), num_channels, 1, 1, 0) 23 | ) 24 | self.num_blocks = num_blocks 25 | 26 | def forward(self, x): 27 | c = out = x 28 | for nb in range(self.num_blocks): 29 | unit_b = getattr(self, "b{}".format(nb+1)) 30 | unit_c = getattr(self, "c{}".format(nb+1)) 31 | 32 | b = unit_b(out) 33 | c = torch.cat([c, b], dim=1) 34 | out = unit_c(c) 35 | 36 | return out 37 | 38 | 39 | class Net(nn.Module): 40 | def __init__(self, opt): 41 | super().__init__() 42 | 43 | self.sub_mean = ops.MeanShift(255) 44 | self.add_mean = ops.MeanShift(255, sign=1) 45 | 46 | head = [ 47 | ops.DownBlock(opt.scale), 48 | nn.Conv2d(3*opt.scale**2, opt.num_channels, 3, 1, 1) 49 | ] 50 | 51 | # define body module 52 | for ng in range(opt.num_groups): 53 | setattr(self, 54 | "c{}".format(ng+1), 55 | nn.Conv2d(opt.num_channels*(ng+2), opt.num_channels, 1, 1, 0) 56 | ) 57 | setattr(self, 58 | "b{}".format(ng+1), 59 | Group(opt.num_channels, opt.num_blocks) 60 | ) 61 | 62 | tail = [ 63 | ops.Upsampler(opt.num_channels, opt.scale), 64 | nn.Conv2d(opt.num_channels, 3, 3, 1, 1) 65 | ] 66 | 67 | self.head = nn.Sequential(*head) 68 | self.tail = nn.Sequential(*tail) 69 | 70 | self.opt = opt 71 | 72 | def forward(self, x): 73 | x = self.sub_mean(x) 74 | x = self.head(x) 75 | 76 | c = out = x 77 | for ng in range(self.opt.num_groups): 78 | group = getattr(self, "b{}".format(ng+1)) 79 | conv = getattr(self, "c{}".format(ng+1)) 80 | 81 | g = group(out) 82 | c = torch.cat([c, g], dim=1) 83 | out = conv(c) 84 | res = out 85 | res += x 86 | 87 | x = self.tail(res) 88 | x = self.add_mean(x) 89 | 90 | return x 91 | -------------------------------------------------------------------------------- /model/edsr.py: -------------------------------------------------------------------------------- 1 | """ 2 | CutBlur 3 | Copyright 2020-present NAVER corp. 
4 | MIT license 5 | Referenced from EDSR-PyTorch, https://github.com/thstkdgus35/EDSR-PyTorch 6 | """ 7 | import torch.nn as nn 8 | from model import ops 9 | 10 | class Net(nn.Module): 11 | def __init__(self, opt): 12 | super().__init__() 13 | 14 | self.sub_mean = ops.MeanShift(255) 15 | self.add_mean = ops.MeanShift(255, sign=1) 16 | 17 | head = [ 18 | ops.DownBlock(opt.scale), 19 | nn.Conv2d(3*opt.scale**2, opt.num_channels, 3, 1, 1) 20 | ] 21 | 22 | body = list() 23 | for _ in range(opt.num_blocks): 24 | body += [ops.ResBlock(opt.num_channels, opt.res_scale)] 25 | body += [nn.Conv2d(opt.num_channels, opt.num_channels, 3, 1, 1)] 26 | 27 | tail = [ 28 | ops.Upsampler(opt.num_channels, opt.scale), 29 | nn.Conv2d(opt.num_channels, 3, 3, 1, 1) 30 | ] 31 | 32 | self.head = nn.Sequential(*head) 33 | self.body = nn.Sequential(*body) 34 | self.tail = nn.Sequential(*tail) 35 | 36 | self.opt = opt 37 | 38 | def forward(self, x): 39 | x = self.sub_mean(x) 40 | x = self.head(x) 41 | 42 | res = self.body(x) 43 | res += x 44 | 45 | x = self.tail(res) 46 | x = self.add_mean(x) 47 | 48 | return x 49 | -------------------------------------------------------------------------------- /model/ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | CutBlur 3 | Copyright 2020-present NAVER corp. 4 | MIT license 5 | """ 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | 10 | class MeanShift(nn.Conv2d): 11 | def __init__( 12 | self, 13 | rgb_range, sign=-1, 14 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), 15 | ): 16 | super().__init__(3, 3, kernel_size=1) 17 | std = torch.Tensor(rgb_std) 18 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 19 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 20 | for p in self.parameters(): 21 | p.requires_grad = False 22 | 23 | 24 | class ResBlock(nn.Module): 25 | def __init__(self, num_channels, res_scale=1.0): 26 | super().__init__() 27 | 28 | self.body = nn.Sequential( 29 | nn.Conv2d(num_channels, num_channels, 3, 1, 1), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(num_channels, num_channels, 3, 1, 1) 32 | ) 33 | self.res_scale = res_scale 34 | 35 | def forward(self, x): 36 | res = self.body(x).mul(self.res_scale) 37 | res += x 38 | 39 | return res 40 | 41 | 42 | class Upsampler(nn.Sequential): 43 | def __init__(self, num_channels, scale): 44 | m = list() 45 | if (scale & (scale-1)) == 0: 46 | for _ in range(int(math.log(scale, 2))): 47 | m += [nn.Conv2d(num_channels, 4*num_channels, 3, 1, 1)] 48 | m.append(nn.PixelShuffle(2)) 49 | elif scale == 3: 50 | m += [nn.Conv2d(num_channels, 9*num_channels, 3, 1, 1)] 51 | m.append(nn.PixelShuffle(3)) 52 | else: 53 | raise NotImplementedError 54 | 55 | super().__init__(*m) 56 | 57 | 58 | class DownBlock(nn.Module): 59 | def __init__(self, scale): 60 | super().__init__() 61 | 62 | self.scale = scale 63 | 64 | def forward(self, x): 65 | n, c, h, w = x.size() 66 | x = x.view(n, c, h//self.scale, self.scale, w//self.scale, self.scale) 67 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() 68 | x = x.view(n, c * (self.scale**2), h//self.scale, w//self.scale) 69 | return x 70 | -------------------------------------------------------------------------------- /model/rcan.py: -------------------------------------------------------------------------------- 1 | """ 2 | CutBlur 3 | Copyright 2020-present NAVER corp. 
4 | MIT license 5 | Referenced from EDSR-PyTorch, https://github.com/thstkdgus35/EDSR-PyTorch 6 | """ 7 | import torch.nn as nn 8 | from model import ops 9 | 10 | class CALayer(nn.Module): 11 | def __init__(self, num_channels, reduction=16): 12 | super().__init__() 13 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 14 | self.conv_du = nn.Sequential( 15 | nn.Conv2d(num_channels, num_channels//reduction, 1, 1, 0), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(num_channels//reduction, num_channels, 1, 1, 0), 18 | nn.Sigmoid() 19 | ) 20 | 21 | def forward(self, x): 22 | y = self.avg_pool(x) 23 | y = self.conv_du(y) 24 | return x * y 25 | 26 | 27 | class RCAB(nn.Module): 28 | def __init__(self, num_channels, reduction, res_scale): 29 | super().__init__() 30 | 31 | body = [ 32 | nn.Conv2d(num_channels, num_channels, 3, 1, 1), 33 | nn.ReLU(inplace=True), 34 | nn.Conv2d(num_channels, num_channels, 3, 1, 1), 35 | ] 36 | body.append(CALayer(num_channels, reduction)) 37 | 38 | self.body = nn.Sequential(*body) 39 | self.res_scale = res_scale 40 | 41 | def forward(self, x): 42 | res = self.body(x).mul(self.res_scale) 43 | res += x 44 | return res 45 | 46 | 47 | class Group(nn.Module): 48 | def __init__(self, num_channels, num_blocks, reduction, res_scale=1.0): 49 | super().__init__() 50 | 51 | body = list() 52 | for _ in range(num_blocks): 53 | body += [RCAB(num_channels, reduction, res_scale)] 54 | body += [nn.Conv2d(num_channels, num_channels, 3, 1, 1)] 55 | self.body = nn.Sequential(*body) 56 | 57 | def forward(self, x): 58 | res = self.body(x) 59 | res += x 60 | return res 61 | 62 | 63 | class Net(nn.Module): 64 | def __init__(self, opt): 65 | super().__init__() 66 | 67 | self.sub_mean = ops.MeanShift(255) 68 | self.add_mean = ops.MeanShift(255, sign=1) 69 | 70 | head = [ 71 | ops.DownBlock(opt.scale), 72 | nn.Conv2d(3*opt.scale**2, opt.num_channels, 3, 1, 1) 73 | ] 74 | 75 | body = list() 76 | for _ in range(opt.num_groups): 77 | body += [ 78 | Group(opt.num_channels, opt.num_blocks, opt.reduction, opt.res_scale 79 | )] 80 | body += [nn.Conv2d(opt.num_channels, opt.num_channels, 3, 1, 1)] 81 | 82 | tail = [ 83 | ops.Upsampler(opt.num_channels, opt.scale), 84 | nn.Conv2d(opt.num_channels, 3, 3, 1, 1) 85 | ] 86 | 87 | self.head = nn.Sequential(*head) 88 | self.body = nn.Sequential(*body) 89 | self.tail = nn.Sequential(*tail) 90 | 91 | self.opt = opt 92 | 93 | def forward(self, x): 94 | x = self.sub_mean(x) 95 | x = self.head(x) 96 | 97 | res = self.body(x) 98 | res += x 99 | 100 | x = self.tail(res) 101 | x = self.add_mean(x) 102 | 103 | return x 104 | -------------------------------------------------------------------------------- /option.py: -------------------------------------------------------------------------------- 1 | """ 2 | CutBlur 3 | Copyright 2020-present NAVER corp. 
4 | MIT license 5 | """ 6 | import argparse 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--seed", type=int, default=1) 11 | 12 | # models 13 | parser.add_argument("--pretrain", type=str) 14 | parser.add_argument("--model", type=str, default="EDSR") 15 | 16 | # augmentations 17 | parser.add_argument("--use_moa", action="store_true") 18 | parser.add_argument("--augs", nargs="*", default=["none"]) 19 | parser.add_argument("--prob", nargs="*", default=[1.0]) 20 | parser.add_argument("--mix_p", nargs="*") 21 | parser.add_argument("--alpha", nargs="*", default=[1.0]) 22 | parser.add_argument("--aux_prob", type=float, default=1.0) 23 | parser.add_argument("--aux_alpha", type=float, default=1.2) 24 | 25 | # dataset 26 | parser.add_argument("--dataset_root", type=str, default="") 27 | parser.add_argument("--dataset", type=str, default="DIV2K_SR") 28 | parser.add_argument("--camera", type=str, default="all") # RealSR 29 | parser.add_argument("--div2k_range", type=str, default="1-800/801-810") 30 | parser.add_argument("--scale", type=int, default=4) # SR 31 | parser.add_argument("--sigma", type=int, default=10) # DN 32 | parser.add_argument("--quality", type=int, default=10) # DeJPEG 33 | parser.add_argument("--type", type=int, default=1) # DeBlur 34 | 35 | # training setups 36 | parser.add_argument("--lr", type=float, default=1e-4) 37 | parser.add_argument("--decay", type=str, default="200-400-600") 38 | parser.add_argument("--gamma", type=float, default=0.5) 39 | parser.add_argument("--patch_size", type=int, default=48) 40 | parser.add_argument("--batch_size", type=int, default=16) 41 | parser.add_argument("--max_steps", type=int, default=700000) 42 | parser.add_argument("--eval_steps", type=int, default=1000) 43 | parser.add_argument("--num_workers", type=int, default=2) 44 | parser.add_argument("--gclip", type=float, default=0) 45 | 46 | # misc 47 | parser.add_argument("--test_only", action="store_true") 48 | parser.add_argument("--save_result", action="store_true") 49 | parser.add_argument("--ckpt_root", type=str, default="./pt") 50 | parser.add_argument("--save_root", type=str, default="./output") 51 | 52 | return parser.parse_args() 53 | 54 | 55 | def make_template(opt): 56 | opt.strict_load = opt.test_only 57 | 58 | # model 59 | if "EDSR" in opt.model: 60 | opt.num_blocks = 32 61 | opt.num_channels = 256 62 | opt.res_scale = 0.1 63 | if "RCAN" in opt.model: 64 | opt.num_groups = 10 65 | opt.num_blocks = 20 66 | opt.num_channels = 64 67 | opt.reduction = 16 68 | opt.res_scale = 1.0 69 | opt.max_steps = 1000000 70 | opt.decay = "200-400-600-800" 71 | opt.gclip = 0.5 if opt.pretrain else opt.gclip 72 | if "CARN" in opt.model: 73 | opt.num_groups = 3 74 | opt.num_blocks = 3 75 | opt.num_channels = 64 76 | opt.res_scale = 1.0 77 | opt.batch_size = 64 78 | opt.decay = "400" 79 | 80 | # training setup 81 | if "DN" in opt.dataset or "JPEG" in opt.dataset: 82 | opt.max_steps = 1000000 83 | opt.decay = "300-550-800" 84 | if "RealSR" in opt.dataset: 85 | opt.patch_size *= opt.scale # identical (LR, HR) resolution 86 | 87 | # evaluation setup 88 | opt.crop = 6 if "DIV2K" in opt.dataset else 0 89 | opt.crop += opt.scale if "SR" in opt.dataset else 4 90 | 91 | # note: we tested on color DN task 92 | if "DIV2K" in opt.dataset or "DN" in opt.dataset: 93 | opt.eval_y_only = False 94 | else: 95 | opt.eval_y_only = True 96 | 97 | # default augmentation policies 98 | if opt.use_moa: 99 | opt.augs = ["blend", "rgb", "mixup", "cutout", "cutmix", "cutmixup", "cutblur"] 100
| opt.prob = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] 101 | opt.alpha = [0.6, 1.0, 1.2, 0.001, 0.7, 0.7, 0.7] 102 | opt.aux_prob, opt.aux_alpha = 1.0, 1.2 103 | opt.mix_p = None 104 | 105 | if "RealSR" in opt.dataset: 106 | opt.mix_p = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.4] 107 | 108 | if "DN" in opt.dataset or "JPEG" in opt.dataset: 109 | opt.prob = [0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6] 110 | if "CARN" in opt.model and not "RealSR" in opt.dataset: 111 | opt.prob = [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2] 112 | 113 | 114 | def get_option(): 115 | opt = parse_args() 116 | make_template(opt) 117 | return opt 118 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.4.0 2 | torchvision==0.4.0 3 | scikit-image 4 | numpy 5 | tqdm -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | """ 2 | CutBlur 3 | Copyright 2020-present NAVER corp. 4 | MIT license 5 | """ 6 | import os 7 | import time 8 | import skimage.io as io 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import utils 13 | import augments 14 | from data import generate_loader 15 | 16 | class Solver(): 17 | def __init__(self, module, opt): 18 | self.opt = opt 19 | 20 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | self.net = module.Net(opt).to(self.dev) 22 | print("# params:", sum(map(lambda x: x.numel(), self.net.parameters()))) 23 | 24 | if opt.pretrain: 25 | self.load(opt.pretrain) 26 | 27 | self.loss_fn = nn.L1Loss() 28 | self.optim = torch.optim.Adam( 29 | self.net.parameters(), opt.lr, 30 | betas=(0.9, 0.999), eps=1e-8 31 | ) 32 | self.scheduler = torch.optim.lr_scheduler.MultiStepLR( 33 | self.optim, [1000*int(d) for d in opt.decay.split("-")], 34 | gamma=opt.gamma, 35 | ) 36 | 37 | if not opt.test_only: 38 | self.train_loader = generate_loader("train", opt) 39 | self.test_loader = generate_loader("test", opt) 40 | 41 | self.t1, self.t2 = None, None 42 | self.best_psnr, self.best_step = 0, 0 43 | 44 | def fit(self): 45 | opt = self.opt 46 | 47 | self.t1 = time.time() 48 | for step in range(opt.max_steps): 49 | try: 50 | inputs = next(iters) 51 | except (UnboundLocalError, StopIteration): 52 | iters = iter(self.train_loader) 53 | inputs = next(iters) 54 | 55 | HR = inputs[0].to(self.dev) 56 | LR = inputs[1].to(self.dev) 57 | 58 | # match the resolution of (LR, HR) due to CutBlur 59 | if HR.size() != LR.size(): 60 | scale = HR.size(2) // LR.size(2) 61 | LR = F.interpolate(LR, scale_factor=scale, mode="nearest") 62 | 63 | HR, LR, mask, aug = augments.apply_augment( 64 | HR, LR, 65 | opt.augs, opt.prob, opt.alpha, 66 | opt.aux_prob, opt.aux_alpha, opt.mix_p 67 | ) 68 | 69 | SR = self.net(LR) 70 | if aug == "cutout": 71 | SR, HR = SR*mask, HR*mask 72 | 73 | loss = self.loss_fn(SR, HR) 74 | self.optim.zero_grad() 75 | loss.backward() 76 | 77 | if opt.gclip > 0: 78 | torch.nn.utils.clip_grad_value_(self.net.parameters(), opt.gclip) 79 | 80 | self.optim.step() 81 | self.scheduler.step() 82 | 83 | if (step+1) % opt.eval_steps == 0: 84 | self.summary_and_save(step) 85 | 86 | def summary_and_save(self, step): 87 | step, max_steps = (step+1)//1000, self.opt.max_steps//1000 88 | psnr = self.evaluate() 89 | self.t2 = time.time() 90 | 91 | if psnr >= self.best_psnr: 92 | self.best_psnr, self.best_step = psnr, step 93 |
self.save(step) 94 | 95 | curr_lr = self.scheduler.get_lr()[0] 96 | eta = (self.t2-self.t1) * (max_steps-step) / 3600 97 | print("[{}K/{}K] {:.2f} (Best: {:.2f} @ {}K step) LR: {}, ETA: {:.1f} hours" 98 | .format(step, max_steps, psnr, self.best_psnr, self.best_step, 99 | curr_lr, eta)) 100 | 101 | self.t1 = time.time() 102 | 103 | @torch.no_grad() 104 | def evaluate(self): 105 | opt = self.opt 106 | self.net.eval() 107 | 108 | if opt.save_result: 109 | save_root = os.path.join(opt.save_root, opt.dataset) 110 | os.makedirs(save_root, exist_ok=True) 111 | 112 | psnr = 0 113 | for i, inputs in enumerate(self.test_loader): 114 | HR = inputs[0].to(self.dev) 115 | LR = inputs[1].to(self.dev) 116 | 117 | # match the resolution of (LR, HR) due to CutBlur 118 | if HR.size() != LR.size(): 119 | scale = HR.size(2) // LR.size(2) 120 | LR = F.interpolate(LR, scale_factor=scale, mode="nearest") 121 | 122 | SR = self.net(LR).detach() 123 | HR = HR[0].clamp(0, 255).round().cpu().byte().permute(1, 2, 0).numpy() 124 | SR = SR[0].clamp(0, 255).round().cpu().byte().permute(1, 2, 0).numpy() 125 | 126 | if opt.save_result: 127 | save_path = os.path.join(save_root, "{:04d}.png".format(i+1)) 128 | io.imsave(save_path, SR) 129 | 130 | HR = HR[opt.crop:-opt.crop, opt.crop:-opt.crop, :] 131 | SR = SR[opt.crop:-opt.crop, opt.crop:-opt.crop, :] 132 | if opt.eval_y_only: 133 | HR = utils.rgb2ycbcr(HR) 134 | SR = utils.rgb2ycbcr(SR) 135 | psnr += utils.calculate_psnr(HR, SR) 136 | 137 | self.net.train() 138 | 139 | return psnr/len(self.test_loader) 140 | 141 | def load(self, path): 142 | state_dict = torch.load(path, map_location=lambda storage, loc: storage) 143 | 144 | if self.opt.strict_load: 145 | self.net.load_state_dict(state_dict) 146 | return 147 | 148 | # when to fine-tune the pre-trained model 149 | own_state = self.net.state_dict() 150 | for name, param in state_dict.items(): 151 | if name in own_state: 152 | if isinstance(param, nn.Parameter): 153 | param = param.data 154 | 155 | try: 156 | own_state[name].copy_(param) 157 | except Exception: 158 | # head and tail modules can be different 159 | if name.find("head") == -1 and name.find("tail") == -1: 160 | raise RuntimeError( 161 | "While copying the parameter named {}, " 162 | "whose dimensions in the model are {} and " 163 | "whose dimensions in the checkpoint are {}." 164 | .format(name, own_state[name].size(), param.size()) 165 | ) 166 | else: 167 | raise RuntimeError( 168 | "Missing key {} in model's state_dict".format(name) 169 | ) 170 | 171 | def save(self, step): 172 | os.makedirs(self.opt.ckpt_root, exist_ok=True) 173 | save_path = os.path.join(self.opt.ckpt_root, str(step)+".pt") 174 | torch.save(self.net.state_dict(), save_path) 175 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | CutBlur 3 | Copyright 2020-present NAVER corp. 
4 | MIT license 5 | """ 6 | import random 7 | import numpy as np 8 | 9 | def crop(HQ, LQ, psize, scale=4): 10 | h, w = LQ.shape[:-1] 11 | x = random.randrange(0, w-psize+1) 12 | y = random.randrange(0, h-psize+1) 13 | 14 | crop_HQ = HQ[y*scale:y*scale+psize*scale, x*scale:x*scale+psize*scale] 15 | crop_LQ = LQ[y:y+psize, x:x+psize] 16 | 17 | return crop_HQ.copy(), crop_LQ.copy() 18 | 19 | 20 | def flip_and_rotate(HQ, LQ): 21 | hflip = random.random() < 0.5 22 | vflip = random.random() < 0.5 23 | rot90 = random.random() < 0.5 24 | 25 | if hflip: 26 | HQ, LQ = HQ[:, ::-1, :], LQ[:, ::-1, :] 27 | if vflip: 28 | HQ, LQ = HQ[::-1, :, :], LQ[::-1, :, :] 29 | if rot90: 30 | HQ, LQ = HQ.transpose(1, 0, 2), LQ.transpose(1, 0, 2) 31 | 32 | return HQ, LQ 33 | 34 | 35 | def rgb2ycbcr(img, y_only=True): 36 | in_img_type = img.dtype 37 | img = img.astype(np.float32) 38 | if in_img_type != np.uint8: 39 | img *= 255. 40 | 41 | if y_only: 42 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 43 | else: 44 | rlt = np.matmul( 45 | img, 46 | [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], 47 | [24.966, 112.0, -18.214]] 48 | ) / 255.0 + [16, 128, 128] 49 | if in_img_type == np.uint8: 50 | rlt = rlt.round() 51 | else: 52 | rlt /= 255. 53 | return rlt.astype(in_img_type) 54 | 55 | 56 | def calculate_psnr(img1, img2): 57 | img1 = img1.astype(np.float64) 58 | img2 = img2.astype(np.float64) 59 | 60 | mse = np.mean((img1 - img2)**2) 61 | if mse == 0: 62 | return float("inf") 63 | return 20 * np.log10(255.0 / np.sqrt(mse)) 64 | --------------------------------------------------------------------------------
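As a closing note, a small hypothetical sanity check of the evaluation utilities above (the inputs are random, so the exact PSNR values will vary):

```python
import numpy as np
import utils  # from this repository

hr = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
sr = hr.copy()
sr[0, 0] = 0  # perturb a single pixel

print(utils.calculate_psnr(hr, sr))  # PSNR on RGB
# Y-channel PSNR, as used for most SR benchmarks (see opt.eval_y_only)
print(utils.calculate_psnr(utils.rgb2ycbcr(hr), utils.rgb2ycbcr(sr)))
```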