├── .gitignore
├── LICENSE
├── NOTICE
├── README.md
├── assets
│   └── teaser.jpg
├── augments.py
├── data
│   ├── __init__.py
│   ├── benchmark.py
│   ├── div2k.py
│   └── realsr.py
├── example
│   ├── demo_hr_inputting.ipynb
│   └── inputs
│       ├── 0869.png
│       ├── Canon_003_HR.png
│       ├── Canon_003_LR4.png
│       └── Nikon_006_HR.png
├── inference.py
├── main.py
├── model
│   ├── carn.py
│   ├── edsr.py
│   ├── ops.py
│   └── rcan.py
├── option.py
├── requirements.txt
├── solver.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .DS_Store
3 | *.swp
4 | *.ipynb_checkpoints
5 | pt/
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020-present NAVER Corp.
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | CutBlur
2 |
3 | Copyright (c) 2020-present NAVER Corp.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
23 | --------------------------------------------------------------------------------------
24 |
25 | This project contains subcomponents with separate copyright notices and license terms.
26 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
27 |
28 | =====
29 |
30 | nmhkahn/PCARN-pytorch
31 | https://github.com/nmhkahn/PCARN-pytorch
32 |
33 |
34 |
35 | MIT License
36 |
37 | Copyright (c) 2019 Namhyuk Ahn
38 |
39 | Permission is hereby granted, free of charge, to any person obtaining a copy
40 | of this software and associated documentation files (the "Software"), to deal
41 | in the Software without restriction, including without limitation the rights
42 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
43 | copies of the Software, and to permit persons to whom the Software is
44 | furnished to do so, subject to the following conditions:
45 |
46 | The above copyright notice and this permission notice shall be included in all
47 | copies or substantial portions of the Software.
48 |
49 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
50 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
51 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
52 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
53 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
54 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
55 | SOFTWARE.
56 |
57 | =====
58 |
59 | thstkdgus35/EDSR-PyTorch
60 | https://github.com/thstkdgus35/EDSR-PyTorch
61 |
62 |
63 | MIT License
64 |
65 | Copyright (c) 2018 Sanghyun Son
66 |
67 | Permission is hereby granted, free of charge, to any person obtaining a copy
68 | of this software and associated documentation files (the "Software"), to deal
69 | in the Software without restriction, including without limitation the rights
70 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
71 | copies of the Software, and to permit persons to whom the Software is
72 | furnished to do so, subject to the following conditions:
73 |
74 | The above copyright notice and this permission notice shall be included in all
75 | copies or substantial portions of the Software.
76 |
77 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
78 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
79 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
80 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
81 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
82 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
83 | SOFTWARE.
84 |
85 | =====
86 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Rethinking Data Augmentation for Image Super-resolution (CVPR 2020)
2 |
<img src="assets/teaser.jpg">
3 |
4 | This repository provides the official PyTorch implementation of the following paper:
5 | > **Rethinking Data Augmentation for Image Super-resolution: A Comprehensive Analysis and a New Strategy**
6 | > [Jaejun Yoo](https://www.linkedin.com/in/jaejunyoo/)\*<sup>1</sup>, [Namhyuk Ahn](https://nmhkahn.github.io)\*<sup>2</sup>, [Kyung-Ah Sohn](https://sites.google.com/site/kasohn/home)<sup>2</sup>
7 | > \* indicates equal contribution. Most work was done in NAVER Corp.
8 | > <sup>1</sup> EPFL
9 | > <sup>2</sup> Ajou University
10 | > [https://arxiv.org/abs/2004.00448](https://arxiv.org/abs/2004.00448)
11 | >
12 | > **Abstract:** *Data augmentation is an effective way to improve the performance of deep networks. Unfortunately, current methods are mostly developed for high-level vision tasks (e.g., classification) and few are studied for low-level vision tasks (e.g., image restoration). In this paper, we provide a comprehensive analysis of the existing augmentation methods applied to the super-resolution task. We find that the methods discarding or manipulating the pixels or features too much hamper the image restoration, where the spatial relationship is very important. Based on our analyses, we propose **CutBlur** that cuts a low-resolution patch and pastes it to the corresponding high-resolution image region and vice versa. The key intuition of CutBlur is to enable a model to learn not only "how" but also "where" to super-resolve an image. By doing so, the model can understand "how much", instead of blindly learning to apply super-resolution to every given pixel. Our method consistently and significantly improves the performance across various scenarios, especially when the model size is big and the data is collected under real-world environments. We also show that our method improves other low-level vision tasks, such as denoising and compression artifact removal.*
13 |
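To make the key operation concrete, below is a minimal sketch of CutBlur (the full implementation, including the randomized inside/outside direction and the probability handling, lives in `augments.py`). Here `hr` and `lr_up` are assumed to be same-sized tensors, with the LR image pre-upsampled to HR resolution:

```python
import torch

def cutblur_sketch(hr, lr_up, ratio=0.7):
    """Swap a random window between the (upsampled) LR and HR images."""
    _, _, h, w = lr_up.size()
    ch, cw = int(h * ratio), int(w * ratio)
    cy = torch.randint(0, h - ch + 1, (1,)).item()
    cx = torch.randint(0, w - cw + 1, (1,)).item()

    out = lr_up.clone()
    # paste the ground-truth HR patch into the LR input; the full method
    # also randomly applies the reverse (LR patch into the HR image).
    out[..., cy:cy+ch, cx:cx+cw] = hr[..., cy:cy+ch, cx:cx+cw]
    return out
```
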
14 | ### 0. Requirement
15 | Simply run:
16 | ```shell
17 | pip3 install -r requirements.txt
18 | ```
19 |
20 | ### 1. Quick Start (Demo)
21 | You can test our models with any images. Place the images in the `./input` directory and run the script below.
22 | Before executing the script, please download the pretrained models from [CutBlur_model](https://drive.google.com/file/d/11lMnMsPv3uZHcKMOw7GAXsyi2suBCkUu/view?usp=sharing) and set the `--model` and `--pretrain` arguments accordingly.
23 |
24 | ```shell
25 | python inference.py \
26 | --model [EDSR|RCAN|CARN] \
27 | --pretrain \
28 | --dataset_root ./input \
29 | --save_root ./output
30 | ```
31 |
32 | We also provide a [demo](./example/demo_hr_inputting.ipynb) to visualize how the mixture of augmentations (MoA) prevents the SR model from over-sharpening.
33 |
34 | ### 2. Dataset
35 | #### DIV2K
36 | We use the [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) dataset to train the model. Download and unpack the tar file to any directory you want.
37 | **Important:** For the DIV2K dataset only, all the train and valid images should be placed in the `DIV2K_train_HR` and `DIV2K_train_LR_bicubic` directories (we parse the train and valid images using the `--div2k_range` argument).
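
For reference, the DIV2K loader (`data/div2k.py`) reads images from the following sub-directories under `--dataset_root` (a sketch; only the directory names are mandated by the code):

```
<dataset_root>
├── DIV2K_train_HR
└── DIV2K_train_LR_bicubic
    ├── X2
    └── X4
```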
38 |
39 | #### SR Benchmark
40 | For the benchmark datasets used in the paper (Set14, Urban100, and manga109), we provide the original images [here](https://drive.google.com/file/d/11lMnMsPv3uZHcKMOw7GAXsyi2suBCkUu/view?usp=sharing).
41 |
42 | #### RealSR
43 | We use the [RealSR](https://github.com/csjcai/RealSR) dataset (version 1). In the paper, we utilized both the Canon and Nikon images for training and testing.
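
For reference, the RealSR loader (`data/realsr.py`) expects the version-1 layout sketched below (shown for `--scale 4`), where HR/LR pairs are distinguished by the `HR`/`LR` tags in their filenames (e.g. `Canon_003_HR.png`, `Canon_003_LR4.png`):

```
<dataset_root>
├── Canon
│   ├── Train/4
│   └── Test/4
└── Nikon
    ├── Train/4
    └── Test/4
```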
44 |
45 | ### 3. Evaluate Pre-trained Models
46 | For those who want to compare with our results, we provide [our results](https://drive.google.com/file/d/11lMnMsPv3uZHcKMOw7GAXsyi2suBCkUu/view?usp=sharing) generated on the RealSR and SR benchmark datasets.
47 |
48 | We present an example script to evaluate the pretrained model below:
49 |
50 | ```shell
51 | python main.py \
52 | --dataset [DIV2K_SR|RealSR|Set14_SR|Urban100_SR|manga109_SR] \
53 | --model [EDSR|RCAN|CARN] \
54 | --pretrain \
55 | --dataset_root \
56 | --save_root ./output \
57 | --test_only
58 | ```
59 |
60 | For the `--dataset` argument, the `_SR` suffix is required to identify the task among SR, DN, and JPEG restoration.
61 | If you evaluate on `DIV2K_SR`, please add the `--div2k_range 1-800/801-900` argument to specify the range of images to use.
62 |
63 | Note that `[DIV2K_SR, Set14_SR, Urban100_SR, manga109_SR]` have to be evaluated using a model trained on the DIV2K dataset (e.g. `DIV2K_EDSR_moa.pt`), while `[RealSR]` requires a model trained on the RealSR dataset (e.g. `RealSR_EDSR_moa.pt`).
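
For example, a complete command to evaluate an EDSR model trained with MoA on the DIV2K validation split might look like this (the checkpoint and dataset paths below are placeholders):

```shell
python main.py \
--dataset DIV2K_SR \
--model EDSR \
--pretrain ./pt/DIV2K_EDSR_moa.pt \
--dataset_root /path/to/DIV2K \
--div2k_range 1-800/801-900 \
--save_root ./output \
--test_only
```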
64 |
65 | ### 4. Train Models
66 | #### DIV2K (bicubic)
67 | To reproduce the results in the paper, X2-scale pretraining is necessary.
68 | First, train the model on the X2 scale as below:
69 | ```shell
70 | python main.py \
71 | --use_moa \
72 | --model [EDSR|RCAN|CARN] \
73 | --dataset DIV2K_SR \
74 | --div2k_range 1-800/801-810 \
75 | --scale 2 \
76 | --dataset_root
77 | ```
78 | If you want to train the baseline model, omit the `--use_moa` option.
79 |
80 | By default, the trained model is saved in the `./pt` directory. Since evaluating all the validation images takes a long time, we validate only on the first ten images during training.
81 |
82 | Then, fine-tune the trained model on the X4 scale:
83 | ```shell
84 | python main.py \
85 | --use_moa \
86 | --model [EDSR|RCAN|CARN] \
87 | --dataset DIV2K_SR \
88 | --div2k_range 1-800/801-810 \
89 | --scale 4 \
90 | --pretrain \
91 | --dataset_root
92 | ```
93 | Please see `option.py` for more detailed options.
94 |
95 | #### RealSR
96 | Simply run this code:
97 | ```shell
98 | python main.py \
99 | --use_moa \
100 | --model [EDSR|RCAN|CARN] \
101 | --dataset RealSR \
102 | --scale 4 --camera all \
103 | --dataset_root
104 | ```
105 |
106 | ### Updates
107 | - **02 Apr, 2020**: Initial upload.
108 |
109 | ### License
110 | ```
111 | Copyright (c) 2020-present NAVER Corp.
112 |
113 | Permission is hereby granted, free of charge, to any person obtaining a copy
114 | of this software and associated documentation files (the "Software"), to deal
115 | in the Software without restriction, including without limitation the rights
116 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
117 | copies of the Software, and to permit persons to whom the Software is
118 | furnished to do so, subject to the following conditions:
119 |
120 | The above copyright notice and this permission notice shall be included in
121 | all copies or substantial portions of the Software.
122 |
123 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
124 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
125 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
126 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
127 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
128 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
129 | THE SOFTWARE.
130 | ```
131 |
132 | ### Citation
133 | ```
134 | @article{yoo2020rethinking,
135 | title={Rethinking Data Augmentation for Image Super-resolution: A Comprehensive Analysis and a New Strategy},
136 | author={Yoo, Jaejun and Ahn, Namhyuk and Sohn, Kyung-Ah},
137 | journal={arXiv preprint arXiv:2004.00448},
138 | year={2020}
139 | }
140 | ```
141 |
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/cutblur/2448f8bb42e2705d54f0c273e0952664f4f98754/assets/teaser.jpg
--------------------------------------------------------------------------------
/augments.py:
--------------------------------------------------------------------------------
1 | """
2 | CutBlur
3 | Copyright 2020-present NAVER corp.
4 | MIT license
5 | """
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | def apply_augment(
11 | im1, im2,
12 | augs, probs, alphas,
13 | aux_prob=None, aux_alpha=None,
14 | mix_p=None
15 | ):
16 | idx = np.random.choice(len(augs), p=mix_p)
17 | aug = augs[idx]
18 | prob = float(probs[idx])
19 | alpha = float(alphas[idx])
20 | mask = None
21 |
22 | if aug == "none":
23 | im1_aug, im2_aug = im1.clone(), im2.clone()
24 | elif aug == "blend":
25 | im1_aug, im2_aug = blend(
26 | im1.clone(), im2.clone(),
27 | prob=prob, alpha=alpha
28 | )
29 | elif aug == "mixup":
30 |         im1_aug, im2_aug = mixup(
31 | im1.clone(), im2.clone(),
32 | prob=prob, alpha=alpha,
33 | )
34 | elif aug == "cutout":
35 | im1_aug, im2_aug, mask, _ = cutout(
36 | im1.clone(), im2.clone(),
37 | prob=prob, alpha=alpha
38 | )
39 | elif aug == "cutmix":
40 | im1_aug, im2_aug = cutmix(
41 | im1.clone(), im2.clone(),
42 | prob=prob, alpha=alpha,
43 | )
44 | elif aug == "cutmixup":
45 | im1_aug, im2_aug = cutmixup(
46 | im1.clone(), im2.clone(),
47 | mixup_prob=aux_prob, mixup_alpha=aux_alpha,
48 | cutmix_prob=prob, cutmix_alpha=alpha,
49 | )
50 | elif aug == "cutblur":
51 | im1_aug, im2_aug = cutblur(
52 | im1.clone(), im2.clone(),
53 | prob=prob, alpha=alpha
54 | )
55 | elif aug == "rgb":
56 | im1_aug, im2_aug = rgb(
57 | im1.clone(), im2.clone(),
58 | prob=prob
59 | )
60 | else:
61 |         raise ValueError("{} is not a valid augmentation.".format(aug))
62 |
63 | return im1_aug, im2_aug, mask, aug
64 |
65 |
66 | def blend(im1, im2, prob=1.0, alpha=0.6):
67 | if alpha <= 0 or np.random.rand(1) >= prob:
68 | return im1, im2
69 |
70 | c = torch.empty((im2.size(0), 3, 1, 1), device=im2.device).uniform_(0, 255)
71 | rim2 = c.repeat((1, 1, im2.size(2), im2.size(3)))
72 | rim1 = c.repeat((1, 1, im1.size(2), im1.size(3)))
73 |
74 | v = np.random.uniform(alpha, 1)
75 | im1 = v * im1 + (1-v) * rim1
76 | im2 = v * im2 + (1-v) * rim2
77 |
78 | return im1, im2
79 |
80 |
81 | def mixup(im1, im2, prob=1.0, alpha=1.2):
82 | if alpha <= 0 or np.random.rand(1) >= prob:
83 | return im1, im2
84 |
85 | v = np.random.beta(alpha, alpha)
86 | r_index = torch.randperm(im1.size(0)).to(im2.device)
87 |
88 | im1 = v * im1 + (1-v) * im1[r_index, :]
89 | im2 = v * im2 + (1-v) * im2[r_index, :]
90 | return im1, im2
91 |
92 |
93 | def _cutmix(im2, prob=1.0, alpha=1.0):
94 | if alpha <= 0 or np.random.rand(1) >= prob:
95 | return None
96 |
97 | cut_ratio = np.random.randn() * 0.01 + alpha
98 |
99 | h, w = im2.size(2), im2.size(3)
100 |     ch, cw = int(h*cut_ratio), int(w*cut_ratio)
101 |
102 | fcy = np.random.randint(0, h-ch+1)
103 | fcx = np.random.randint(0, w-cw+1)
104 | tcy, tcx = fcy, fcx
105 | rindex = torch.randperm(im2.size(0)).to(im2.device)
106 |
107 | return {
108 | "rindex": rindex, "ch": ch, "cw": cw,
109 | "tcy": tcy, "tcx": tcx, "fcy": fcy, "fcx": fcx,
110 | }
111 |
112 |
113 | def cutmix(im1, im2, prob=1.0, alpha=1.0):
114 | c = _cutmix(im2, prob, alpha)
115 | if c is None:
116 | return im1, im2
117 |
118 | scale = im1.size(2) // im2.size(2)
119 | rindex, ch, cw = c["rindex"], c["ch"], c["cw"]
120 | tcy, tcx, fcy, fcx = c["tcy"], c["tcx"], c["fcy"], c["fcx"]
121 |
122 | hch, hcw = ch*scale, cw*scale
123 | hfcy, hfcx, htcy, htcx = fcy*scale, fcx*scale, tcy*scale, tcx*scale
124 |
125 | im2[..., tcy:tcy+ch, tcx:tcx+cw] = im2[rindex, :, fcy:fcy+ch, fcx:fcx+cw]
126 | im1[..., htcy:htcy+hch, htcx:htcx+hcw] = im1[rindex, :, hfcy:hfcy+hch, hfcx:hfcx+hcw]
127 |
128 | return im1, im2
129 |
130 |
131 | def cutmixup(
132 | im1, im2,
133 | mixup_prob=1.0, mixup_alpha=1.0,
134 | cutmix_prob=1.0, cutmix_alpha=1.0
135 | ):
136 | c = _cutmix(im2, cutmix_prob, cutmix_alpha)
137 | if c is None:
138 | return im1, im2
139 |
140 | scale = im1.size(2) // im2.size(2)
141 | rindex, ch, cw = c["rindex"], c["ch"], c["cw"]
142 | tcy, tcx, fcy, fcx = c["tcy"], c["tcx"], c["fcy"], c["fcx"]
143 |
144 | hch, hcw = ch*scale, cw*scale
145 | hfcy, hfcx, htcy, htcx = fcy*scale, fcx*scale, tcy*scale, tcx*scale
146 |
147 | v = np.random.beta(mixup_alpha, mixup_alpha)
148 | if mixup_alpha <= 0 or np.random.rand(1) >= mixup_prob:
149 | im2_aug = im2[rindex, :]
150 | im1_aug = im1[rindex, :]
151 |
152 | else:
153 | im2_aug = v * im2 + (1-v) * im2[rindex, :]
154 | im1_aug = v * im1 + (1-v) * im1[rindex, :]
155 |
156 | # apply mixup to inside or outside
157 | if np.random.random() > 0.5:
158 | im2[..., tcy:tcy+ch, tcx:tcx+cw] = im2_aug[..., fcy:fcy+ch, fcx:fcx+cw]
159 | im1[..., htcy:htcy+hch, htcx:htcx+hcw] = im1_aug[..., hfcy:hfcy+hch, hfcx:hfcx+hcw]
160 | else:
161 | im2_aug[..., tcy:tcy+ch, tcx:tcx+cw] = im2[..., fcy:fcy+ch, fcx:fcx+cw]
162 | im1_aug[..., htcy:htcy+hch, htcx:htcx+hcw] = im1[..., hfcy:hfcy+hch, hfcx:hfcx+hcw]
163 | im2, im1 = im2_aug, im1_aug
164 |
165 | return im1, im2
166 |
167 |
168 | def cutblur(im1, im2, prob=1.0, alpha=1.0):
169 | if im1.size() != im2.size():
170 | raise ValueError("im1 and im2 have to be the same resolution.")
171 |
172 | if alpha <= 0 or np.random.rand(1) >= prob:
173 | return im1, im2
174 |
175 | cut_ratio = np.random.randn() * 0.01 + alpha
176 |
177 | h, w = im2.size(2), im2.size(3)
178 |     ch, cw = int(h*cut_ratio), int(w*cut_ratio)
179 | cy = np.random.randint(0, h-ch+1)
180 | cx = np.random.randint(0, w-cw+1)
181 |
182 | # apply CutBlur to inside or outside
183 | if np.random.random() > 0.5:
184 | im2[..., cy:cy+ch, cx:cx+cw] = im1[..., cy:cy+ch, cx:cx+cw]
185 | else:
186 | im2_aug = im1.clone()
187 | im2_aug[..., cy:cy+ch, cx:cx+cw] = im2[..., cy:cy+ch, cx:cx+cw]
188 | im2 = im2_aug
189 |
190 | return im1, im2
191 |
192 |
193 | def cutout(im1, im2, prob=1.0, alpha=0.1):
194 | scale = im1.size(2) // im2.size(2)
195 | fsize = (im2.size(0), 1)+im2.size()[2:]
196 |
197 | if alpha <= 0 or np.random.rand(1) >= prob:
198 | fim2 = np.ones(fsize)
199 | fim2 = torch.tensor(fim2, dtype=torch.float, device=im2.device)
200 | fim1 = F.interpolate(fim2, scale_factor=scale, mode="nearest")
201 | return im1, im2, fim1, fim2
202 |
203 | fim2 = np.random.choice([0.0, 1.0], size=fsize, p=[alpha, 1-alpha])
204 | fim2 = torch.tensor(fim2, dtype=torch.float, device=im2.device)
205 | fim1 = F.interpolate(fim2, scale_factor=scale, mode="nearest")
206 |
207 | im2 *= fim2
208 |
209 | return im1, im2, fim1, fim2
210 |
211 |
212 | def rgb(im1, im2, prob=1.0):
213 | if np.random.rand(1) >= prob:
214 | return im1, im2
215 |
216 | perm = np.random.permutation(3)
217 | im1 = im1[:, perm]
218 | im2 = im2[:, perm]
219 |
220 | return im1, im2
221 |
--------------------------------------------------------------------------------
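A note on usage: the snippet below sketches how `apply_augment` is consumed, mirroring the call in `solver.py`. The random tensors stand in for a real (HR, LR) batch, and the policy lists are illustrative rather than the defaults from `option.py`:

```python
import torch
import torch.nn.functional as F
import augments

HR = torch.rand(4, 3, 96, 96) * 255   # HR batch in [0, 255]
LR = torch.rand(4, 3, 48, 48) * 255   # x2-downscaled LR batch

# CutBlur requires (LR, HR) at the same resolution, so upsample LR first
# (solver.py does the same with nearest-neighbor interpolation).
LR = F.interpolate(LR, scale_factor=2, mode="nearest")

HR_aug, LR_aug, mask, aug = augments.apply_augment(
    HR, LR,
    ["none", "mixup", "cutblur"],  # candidate augmentations
    [1.0, 1.0, 1.0],               # per-augmentation probability
    [1.0, 1.2, 0.7],               # per-augmentation strength (alpha)
    aux_prob=1.0, aux_alpha=1.2,   # only used by "cutmixup"
    mix_p=None,                    # None: choose among augs uniformly
)
print(aug, HR_aug.shape, LR_aug.shape)  # mask is non-None only for "cutout"
```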
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | CutBlur
3 | Copyright 2020-present NAVER corp.
4 | MIT license
5 | """
6 | import os
7 | import glob
8 | import importlib
9 | import numpy as np
10 | import skimage.io as io
11 | import skimage.color as color
12 | import torch
13 | import utils
14 |
15 | def generate_loader(phase, opt):
16 | cname = opt.dataset.replace("_", "")
17 | if "DIV2K" in opt.dataset:
18 | mname = importlib.import_module("data.div2k")
19 | elif "RealSR" in opt.dataset:
20 | mname = importlib.import_module("data.realsr")
21 | elif "SR" in opt.dataset: # SR benchmark datasets
22 | mname = importlib.import_module("data.benchmark")
23 | cname = "BenchmarkSR"
24 |     elif "DN" in opt.dataset: # DN benchmark datasets
25 |         mname = importlib.import_module("data.benchmark")
26 |         cname = "BenchmarkDN"
27 |     elif "JPEG" in opt.dataset: # JPEG benchmark datasets
28 |         mname = importlib.import_module("data.benchmark")
29 |         cname = "BenchmarkJPEG"
30 | else:
31 | raise ValueError("Unsupported dataset: {}".format(opt.dataset))
32 |
33 | kwargs = {
34 | "batch_size": opt.batch_size if phase == "train" else 1,
35 | "num_workers": opt.num_workers if phase == "train" else 0,
36 | "shuffle": phase == "train",
37 | "drop_last": phase == "train",
38 | }
39 |
40 | dataset = getattr(mname, cname)(phase, opt)
41 | return torch.utils.data.DataLoader(dataset, **kwargs)
42 |
43 |
44 | class BaseDataset(torch.utils.data.Dataset):
45 | def __init__(self, phase, opt):
46 | print("Load dataset... (phase: {}, len: {})".format(phase, len(self.HQ_paths)))
47 | self.HQ, self.LQ = list(), list()
48 | for HQ_path, LQ_path in zip(self.HQ_paths, self.LQ_paths):
49 | self.HQ += [io.imread(HQ_path)]
50 | self.LQ += [io.imread(LQ_path)]
51 |
52 | self.phase = phase
53 | self.opt = opt
54 |
55 | def __getitem__(self, index):
56 | # follow the setup of EDSR-pytorch
57 | if self.phase == "train":
58 | index = index % len(self.HQ)
59 |
60 | def im2tensor(im):
61 | np_t = np.ascontiguousarray(im.transpose((2, 0, 1)))
62 | tensor = torch.from_numpy(np_t).float()
63 | return tensor
64 |
65 | HQ, LQ = self.HQ[index], self.LQ[index]
66 | if len(HQ.shape) < 3:
67 | HQ = color.gray2rgb(HQ)
68 | if len(LQ.shape) < 3:
69 | LQ = color.gray2rgb(LQ)
70 |
71 | if self.phase == "train":
72 | inp_scale = HQ.shape[0] // LQ.shape[0]
73 | HQ, LQ = utils.crop(HQ, LQ, self.opt.patch_size, inp_scale)
74 | HQ, LQ = utils.flip_and_rotate(HQ, LQ)
75 | return im2tensor(HQ), im2tensor(LQ)
76 |
77 | def __len__(self):
78 | # follow the setup of EDSR-pytorch
79 | if self.phase == "train":
80 | return (1000 * self.opt.batch_size) // len(self.HQ) * len(self.HQ)
81 | return len(self.HQ)
82 |
--------------------------------------------------------------------------------
/data/benchmark.py:
--------------------------------------------------------------------------------
1 | """
2 | CutBlur
3 | Copyright 2020-present NAVER corp.
4 | MIT license
5 | """
6 | import os
7 | import glob
8 | import data
9 |
10 | class BenchmarkSR(data.BaseDataset):
11 | def __init__(self, phase, opt):
12 | root = opt.dataset_root
13 |
14 | self.scale = opt.scale
15 | dir_HQ, dir_LQ = self.get_subdir()
16 | self.HQ_paths = sorted(glob.glob(os.path.join(root, dir_HQ, "*.png")))
17 | self.LQ_paths = sorted(glob.glob(os.path.join(root, dir_LQ, "*.png")))
18 |
19 | super().__init__(phase, opt)
20 |
21 | def get_subdir(self):
22 | dir_HQ = "HR"
23 | dir_LQ = "X{}".format(self.scale)
24 | return dir_HQ, dir_LQ
25 |
26 |
27 | class BenchmarkDN(BenchmarkSR):
28 | def __init__(self, phase, opt):
29 | self.sigma = opt.sigma
30 |
31 | super().__init__(phase, opt)
32 |
33 | def get_subdir(self):
34 | dir_HQ = "HQ"
35 | dir_LQ = "{}".format(self.sigma)
36 | return dir_HQ, dir_LQ
37 |
38 |
39 | class BenchmarkJPEG(BenchmarkSR):
40 | def __init__(self, phase, opt):
41 | self.quality = opt.quality
42 |
43 | super().__init__(phase, opt)
44 |
45 | def get_subdir(self):
46 | dir_HQ = "HQ"
47 | dir_LQ = "{}".format(self.quality)
48 | return dir_HQ, dir_LQ
49 |
--------------------------------------------------------------------------------
/data/div2k.py:
--------------------------------------------------------------------------------
1 | """
2 | CutBlur
3 | Copyright 2020-present NAVER corp.
4 | MIT license
5 | """
6 | import os
7 | import glob
8 | import data
9 |
10 | class DIV2KSR(data.BaseDataset):
11 | def __init__(self, phase, opt):
12 | root = opt.dataset_root
13 |
14 | self.scale = opt.scale
15 | dir_HQ, dir_LQ = self.get_subdir()
16 | self.HQ_paths = sorted(glob.glob(os.path.join(root, dir_HQ, "*.png")))
17 | self.LQ_paths = sorted(glob.glob(os.path.join(root, dir_LQ, "*.png")))
18 |
19 | split = [int(n) for n in opt.div2k_range.replace("/", "-").split("-")]
20 | if phase == "train":
21 | s = slice(split[0]-1, split[1])
22 | self.HQ_paths, self.LQ_paths = self.HQ_paths[s], self.LQ_paths[s]
23 | else:
24 | s = slice(split[2]-1, split[3])
25 | self.HQ_paths, self.LQ_paths = self.HQ_paths[s], self.LQ_paths[s]
26 |
27 | super().__init__(phase, opt)
28 |
29 | def get_subdir(self):
30 | dir_HQ = "DIV2K_train_HR"
31 | dir_LQ = "DIV2K_train_LR_bicubic/X{}".format(self.scale)
32 | return dir_HQ, dir_LQ
33 |
34 |
35 | class DIV2KDN(DIV2KSR):
36 | def __init__(self, phase, opt):
37 | self.sigma = opt.sigma
38 |
39 | super().__init__(phase, opt)
40 |
41 | def get_subdir(self):
42 | dir_HQ = "DIV2K_train_HR"
43 | dir_LQ = "DIV2K_train_DN/{}".format(self.sigma)
44 | return dir_HQ, dir_LQ
45 |
46 |
47 | class DIV2KJPEG(DIV2KSR):
48 | def __init__(self, phase, opt):
49 | self.quality = opt.quality
50 |
51 | super().__init__(phase, opt)
52 |
53 | def get_subdir(self):
54 | dir_HQ = "DIV2K_train_HR"
55 | dir_LQ = "DIV2K_train_JPEG/{}".format(self.quality)
56 | return dir_HQ, dir_LQ
57 |
--------------------------------------------------------------------------------
/data/realsr.py:
--------------------------------------------------------------------------------
1 | """
2 | CutBlur
3 | Copyright 2020-present NAVER corp.
4 | MIT license
5 | """
6 | import os
7 | import glob
8 | import data
9 |
10 | class RealSR(data.BaseDataset):
11 | def __init__(self, phase, opt):
12 | root = opt.dataset_root
13 |
14 | self.scale = opt.scale
15 |
16 | subdir = "Train" if phase == "train" else "Test"
17 | path_canon_all = sorted(glob.glob(os.path.join(
18 | root, "Canon", subdir, str(self.scale), "*.png"
19 | )))
20 | path_nikon_all = sorted(glob.glob(os.path.join(
21 | root, "Nikon", subdir, str(self.scale), "*.png"
22 | )))
23 |
24 | path_canon_HR = [p for p in path_canon_all if "HR" in p]
25 | path_canon_LR = [p for p in path_canon_all if "LR" in p]
26 | path_nikon_HR = [p for p in path_nikon_all if "HR" in p]
27 | path_nikon_LR = [p for p in path_nikon_all if "LR" in p]
28 |
29 | if opt.camera == "canon":
30 | self.HQ_paths = path_canon_HR
31 | self.LQ_paths = path_canon_LR
32 | elif opt.camera == "nikon":
33 | self.HQ_paths = path_nikon_HR
34 | self.LQ_paths = path_nikon_LR
35 | elif opt.camera == "all":
36 | self.HQ_paths = path_canon_HR+path_nikon_HR
37 | self.LQ_paths = path_canon_LR+path_nikon_LR
38 | else:
39 | raise ValueError("camera must be one of the [canon, nikon, all].")
40 |
41 | super().__init__(phase, opt)
42 |
--------------------------------------------------------------------------------
/example/inputs/0869.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/cutblur/2448f8bb42e2705d54f0c273e0952664f4f98754/example/inputs/0869.png
--------------------------------------------------------------------------------
/example/inputs/Canon_003_HR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/cutblur/2448f8bb42e2705d54f0c273e0952664f4f98754/example/inputs/Canon_003_HR.png
--------------------------------------------------------------------------------
/example/inputs/Canon_003_LR4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/cutblur/2448f8bb42e2705d54f0c273e0952664f4f98754/example/inputs/Canon_003_LR4.png
--------------------------------------------------------------------------------
/example/inputs/Nikon_006_HR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/cutblur/2448f8bb42e2705d54f0c273e0952664f4f98754/example/inputs/Nikon_006_HR.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | """
2 | CutBlur
3 | Copyright 2020-present NAVER corp.
4 | MIT license
5 | """
6 | import os
7 | import glob
8 | import importlib
9 | from tqdm import tqdm
10 | import numpy as np
11 | import skimage.io as io
12 | import skimage.color as color
13 | import torch
14 | import torch.nn.functional as F
15 | import option
16 |
17 | def im2tensor(im):
18 | np_t = np.ascontiguousarray(im.transpose((2, 0, 1)))
19 | tensor = torch.from_numpy(np_t).float()
20 | return tensor
21 |
22 |
23 | @torch.no_grad()
24 | def main(opt):
25 | os.makedirs(opt.save_root, exist_ok=True)
26 |
27 | dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28 | module = importlib.import_module("model.{}".format(opt.model.lower()))
29 | net = module.Net(opt).to(dev)
30 |
31 | state_dict = torch.load(opt.pretrain, map_location=lambda storage, loc: storage)
32 | net.load_state_dict(state_dict)
33 |
34 | paths = sorted(glob.glob(os.path.join(opt.dataset_root, "*.png")))
35 | for path in tqdm(paths):
36 |         name = os.path.basename(path)
37 |
38 | LR = color.gray2rgb(io.imread(path))
39 | LR = im2tensor(LR).unsqueeze(0).to(dev)
40 | LR = F.interpolate(LR, scale_factor=opt.scale, mode="nearest")
41 |
42 | SR = net(LR).detach()
43 | SR = SR[0].clamp(0, 255).round().cpu().byte().permute(1, 2, 0).numpy()
44 |
45 | save_path = os.path.join(opt.save_root, name)
46 | io.imsave(save_path, SR)
47 |
48 |
49 | if __name__ == "__main__":
50 | opt = option.get_option()
51 | main(opt)
52 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | CutBlur
3 | Copyright 2020-present NAVER corp.
4 | MIT license
5 | """
6 | import json
7 | import importlib
8 | import torch
9 | from option import get_option
10 | from solver import Solver
11 |
12 | def main():
13 | opt = get_option()
14 | torch.manual_seed(opt.seed)
15 |
16 | module = importlib.import_module("model.{}".format(opt.model.lower()))
17 |
18 | if not opt.test_only:
19 | print(json.dumps(vars(opt), indent=4))
20 |
21 | solver = Solver(module, opt)
22 | if opt.test_only:
23 | print("Evaluate {} (loaded from {})".format(opt.model, opt.pretrain))
24 | psnr = solver.evaluate()
25 | print("{:.2f}".format(psnr))
26 | else:
27 | solver.fit()
28 |
29 | if __name__ == "__main__":
30 | main()
31 |
--------------------------------------------------------------------------------
/model/carn.py:
--------------------------------------------------------------------------------
1 | """
2 | CutBlur
3 | Copyright 2020-present NAVER corp.
4 | MIT license
5 | Referenced from PCARN-pytorch, https://github.com/nmhkahn/PCARN-pytorch
6 | """
7 | import torch
8 | import torch.nn as nn
9 | from model import ops
10 |
11 | class Group(nn.Module):
12 | def __init__(self, num_channels, num_blocks, res_scale=1.0):
13 | super().__init__()
14 |
15 | for nb in range(num_blocks):
16 | setattr(self,
17 | "b{}".format(nb+1),
18 | ops.ResBlock(num_channels, res_scale)
19 | )
20 | setattr(self,
21 | "c{}".format(nb+1),
22 | nn.Conv2d(num_channels*(nb+2), num_channels, 1, 1, 0)
23 | )
24 | self.num_blocks = num_blocks
25 |
26 | def forward(self, x):
27 | c = out = x
28 | for nb in range(self.num_blocks):
29 | unit_b = getattr(self, "b{}".format(nb+1))
30 | unit_c = getattr(self, "c{}".format(nb+1))
31 |
32 | b = unit_b(out)
33 | c = torch.cat([c, b], dim=1)
34 | out = unit_c(c)
35 |
36 | return out
37 |
38 |
39 | class Net(nn.Module):
40 | def __init__(self, opt):
41 | super().__init__()
42 |
43 | self.sub_mean = ops.MeanShift(255)
44 | self.add_mean = ops.MeanShift(255, sign=1)
45 |
46 | head = [
47 | ops.DownBlock(opt.scale),
48 | nn.Conv2d(3*opt.scale**2, opt.num_channels, 3, 1, 1)
49 | ]
50 |
51 | # define body module
52 | for ng in range(opt.num_groups):
53 | setattr(self,
54 | "c{}".format(ng+1),
55 | nn.Conv2d(opt.num_channels*(ng+2), opt.num_channels, 1, 1, 0)
56 | )
57 | setattr(self,
58 | "b{}".format(ng+1),
59 | Group(opt.num_channels, opt.num_blocks)
60 | )
61 |
62 | tail = [
63 | ops.Upsampler(opt.num_channels, opt.scale),
64 | nn.Conv2d(opt.num_channels, 3, 3, 1, 1)
65 | ]
66 |
67 | self.head = nn.Sequential(*head)
68 | self.tail = nn.Sequential(*tail)
69 |
70 | self.opt = opt
71 |
72 | def forward(self, x):
73 | x = self.sub_mean(x)
74 | x = self.head(x)
75 |
76 | c = out = x
77 | for ng in range(self.opt.num_groups):
78 | group = getattr(self, "b{}".format(ng+1))
79 | conv = getattr(self, "c{}".format(ng+1))
80 |
81 | g = group(out)
82 | c = torch.cat([c, g], dim=1)
83 | out = conv(c)
84 | res = out
85 | res += x
86 |
87 | x = self.tail(res)
88 | x = self.add_mean(x)
89 |
90 | return x
91 |
--------------------------------------------------------------------------------
/model/edsr.py:
--------------------------------------------------------------------------------
1 | """
2 | CutBlur
3 | Copyright 2020-present NAVER corp.
4 | MIT license
5 | Referenced from EDSR-PyTorch, https://github.com/thstkdgus35/EDSR-PyTorch
6 | """
7 | import torch.nn as nn
8 | from model import ops
9 |
10 | class Net(nn.Module):
11 | def __init__(self, opt):
12 | super().__init__()
13 |
14 | self.sub_mean = ops.MeanShift(255)
15 | self.add_mean = ops.MeanShift(255, sign=1)
16 |
17 | head = [
18 | ops.DownBlock(opt.scale),
19 | nn.Conv2d(3*opt.scale**2, opt.num_channels, 3, 1, 1)
20 | ]
21 |
22 | body = list()
23 | for _ in range(opt.num_blocks):
24 | body += [ops.ResBlock(opt.num_channels, opt.res_scale)]
25 | body += [nn.Conv2d(opt.num_channels, opt.num_channels, 3, 1, 1)]
26 |
27 | tail = [
28 | ops.Upsampler(opt.num_channels, opt.scale),
29 | nn.Conv2d(opt.num_channels, 3, 3, 1, 1)
30 | ]
31 |
32 | self.head = nn.Sequential(*head)
33 | self.body = nn.Sequential(*body)
34 | self.tail = nn.Sequential(*tail)
35 |
36 | self.opt = opt
37 |
38 | def forward(self, x):
39 | x = self.sub_mean(x)
40 | x = self.head(x)
41 |
42 | res = self.body(x)
43 | res += x
44 |
45 | x = self.tail(res)
46 | x = self.add_mean(x)
47 |
48 | return x
49 |
--------------------------------------------------------------------------------
/model/ops.py:
--------------------------------------------------------------------------------
1 | """
2 | CutBlur
3 | Copyright 2020-present NAVER corp.
4 | MIT license
5 | """
6 | import math
7 | import torch
8 | import torch.nn as nn
9 |
10 | class MeanShift(nn.Conv2d):
11 | def __init__(
12 | self,
13 | rgb_range, sign=-1,
14 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0),
15 | ):
16 | super().__init__(3, 3, kernel_size=1)
17 | std = torch.Tensor(rgb_std)
18 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
19 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
20 | for p in self.parameters():
21 | p.requires_grad = False
22 |
23 |
24 | class ResBlock(nn.Module):
25 | def __init__(self, num_channels, res_scale=1.0):
26 | super().__init__()
27 |
28 | self.body = nn.Sequential(
29 | nn.Conv2d(num_channels, num_channels, 3, 1, 1),
30 | nn.ReLU(inplace=True),
31 | nn.Conv2d(num_channels, num_channels, 3, 1, 1)
32 | )
33 | self.res_scale = res_scale
34 |
35 | def forward(self, x):
36 | res = self.body(x).mul(self.res_scale)
37 | res += x
38 |
39 | return res
40 |
41 |
42 | class Upsampler(nn.Sequential):
43 | def __init__(self, num_channels, scale):
44 | m = list()
45 | if (scale & (scale-1)) == 0:
46 | for _ in range(int(math.log(scale, 2))):
47 | m += [nn.Conv2d(num_channels, 4*num_channels, 3, 1, 1)]
48 | m.append(nn.PixelShuffle(2))
49 | elif scale == 3:
50 | m += [nn.Conv2d(num_channels, 9*num_channels, 3, 1, 1)]
51 | m.append(nn.PixelShuffle(3))
52 | else:
53 | raise NotImplementedError
54 |
55 | super().__init__(*m)
56 |
57 |
58 | class DownBlock(nn.Module):
59 | def __init__(self, scale):
60 | super().__init__()
61 |
62 | self.scale = scale
63 |
64 | def forward(self, x):
65 | n, c, h, w = x.size()
66 | x = x.view(n, c, h//self.scale, self.scale, w//self.scale, self.scale)
67 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous()
68 | x = x.view(n, c * (self.scale**2), h//self.scale, w//self.scale)
69 | return x
70 |
--------------------------------------------------------------------------------
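A note on `DownBlock`: it is a space-to-depth rearrangement that folds each `scale x scale` pixel block into the channel dimension, which lets the networks take the HR-sized (CutBlur-compatible) input while computing at the LR spatial resolution. A minimal shape check, assuming this repo is on `PYTHONPATH`:

```python
import torch
from model.ops import DownBlock

x = torch.randn(1, 3, 8, 8)
y = DownBlock(2)(x)
print(y.shape)  # torch.Size([1, 12, 4, 4]): 3 * 2**2 channels, spatial dims halved
```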
/model/rcan.py:
--------------------------------------------------------------------------------
1 | """
2 | CutBlur
3 | Copyright 2020-present NAVER corp.
4 | MIT license
5 | Referenced from EDSR-PyTorch, https://github.com/thstkdgus35/EDSR-PyTorch
6 | """
7 | import torch.nn as nn
8 | from model import ops
9 |
10 | class CALayer(nn.Module):
11 | def __init__(self, num_channels, reduction=16):
12 | super().__init__()
13 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
14 | self.conv_du = nn.Sequential(
15 | nn.Conv2d(num_channels, num_channels//reduction, 1, 1, 0),
16 | nn.ReLU(inplace=True),
17 | nn.Conv2d(num_channels//reduction, num_channels, 1, 1, 0),
18 | nn.Sigmoid()
19 | )
20 |
21 | def forward(self, x):
22 | y = self.avg_pool(x)
23 | y = self.conv_du(y)
24 | return x * y
25 |
26 |
27 | class RCAB(nn.Module):
28 | def __init__(self, num_channels, reduction, res_scale):
29 | super().__init__()
30 |
31 | body = [
32 | nn.Conv2d(num_channels, num_channels, 3, 1, 1),
33 | nn.ReLU(inplace=True),
34 | nn.Conv2d(num_channels, num_channels, 3, 1, 1),
35 | ]
36 | body.append(CALayer(num_channels, reduction))
37 |
38 | self.body = nn.Sequential(*body)
39 | self.res_scale = res_scale
40 |
41 | def forward(self, x):
42 | res = self.body(x).mul(self.res_scale)
43 | res += x
44 | return res
45 |
46 |
47 | class Group(nn.Module):
48 | def __init__(self, num_channels, num_blocks, reduction, res_scale=1.0):
49 | super().__init__()
50 |
51 | body = list()
52 | for _ in range(num_blocks):
53 | body += [RCAB(num_channels, reduction, res_scale)]
54 | body += [nn.Conv2d(num_channels, num_channels, 3, 1, 1)]
55 | self.body = nn.Sequential(*body)
56 |
57 | def forward(self, x):
58 | res = self.body(x)
59 | res += x
60 | return res
61 |
62 |
63 | class Net(nn.Module):
64 | def __init__(self, opt):
65 | super().__init__()
66 |
67 | self.sub_mean = ops.MeanShift(255)
68 | self.add_mean = ops.MeanShift(255, sign=1)
69 |
70 | head = [
71 | ops.DownBlock(opt.scale),
72 | nn.Conv2d(3*opt.scale**2, opt.num_channels, 3, 1, 1)
73 | ]
74 |
75 | body = list()
76 | for _ in range(opt.num_groups):
77 |             body += [
78 |                 Group(opt.num_channels, opt.num_blocks, opt.reduction, opt.res_scale)
79 |             ]
80 | body += [nn.Conv2d(opt.num_channels, opt.num_channels, 3, 1, 1)]
81 |
82 | tail = [
83 | ops.Upsampler(opt.num_channels, opt.scale),
84 | nn.Conv2d(opt.num_channels, 3, 3, 1, 1)
85 | ]
86 |
87 | self.head = nn.Sequential(*head)
88 | self.body = nn.Sequential(*body)
89 | self.tail = nn.Sequential(*tail)
90 |
91 | self.opt = opt
92 |
93 | def forward(self, x):
94 | x = self.sub_mean(x)
95 | x = self.head(x)
96 |
97 | res = self.body(x)
98 | res += x
99 |
100 | x = self.tail(res)
101 | x = self.add_mean(x)
102 |
103 | return x
104 |
--------------------------------------------------------------------------------
/option.py:
--------------------------------------------------------------------------------
1 | """
2 | CutBlur
3 | Copyright 2020-present NAVER corp.
4 | MIT license
5 | """
6 | import argparse
7 |
8 | def parse_args():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--seed", type=int, default=1)
11 |
12 | # models
13 | parser.add_argument("--pretrain", type=str)
14 | parser.add_argument("--model", type=str, default="EDSR")
15 |
16 | # augmentations
17 | parser.add_argument("--use_moa", action="store_true")
18 | parser.add_argument("--augs", nargs="*", default=["none"])
19 | parser.add_argument("--prob", nargs="*", default=[1.0])
20 |     parser.add_argument("--mix_p", nargs="*", type=float)
21 | parser.add_argument("--alpha", nargs="*", default=[1.0])
22 | parser.add_argument("--aux_prob", type=float, default=1.0)
23 | parser.add_argument("--aux_alpha", type=float, default=1.2)
24 |
25 | # dataset
26 | parser.add_argument("--dataset_root", type=str, default="")
27 | parser.add_argument("--dataset", type=str, default="DIV2K_SR")
28 | parser.add_argument("--camera", type=str, default="all") # RealSR
29 | parser.add_argument("--div2k_range", type=str, default="1-800/801-810")
30 | parser.add_argument("--scale", type=int, default=4) # SR
31 | parser.add_argument("--sigma", type=int, default=10) # DN
32 | parser.add_argument("--quality", type=int, default=10) # DeJPEG
33 | parser.add_argument("--type", type=int, default=1) # DeBlur
34 |
35 | # training setups
36 | parser.add_argument("--lr", type=float, default=1e-4)
37 | parser.add_argument("--decay", type=str, default="200-400-600")
38 |     parser.add_argument("--gamma", type=float, default=0.5)
39 | parser.add_argument("--patch_size", type=int, default=48)
40 | parser.add_argument("--batch_size", type=int, default=16)
41 | parser.add_argument("--max_steps", type=int, default=700000)
42 | parser.add_argument("--eval_steps", type=int, default=1000)
43 | parser.add_argument("--num_workers", type=int, default=2)
44 |     parser.add_argument("--gclip", type=float, default=0)
45 |
46 | # misc
47 | parser.add_argument("--test_only", action="store_true")
48 | parser.add_argument("--save_result", action="store_true")
49 | parser.add_argument("--ckpt_root", type=str, default="./pt")
50 | parser.add_argument("--save_root", type=str, default="./output")
51 |
52 | return parser.parse_args()
53 |
54 |
55 | def make_template(opt):
56 | opt.strict_load = opt.test_only
57 |
58 | # model
59 | if "EDSR" in opt.model:
60 | opt.num_blocks = 32
61 | opt.num_channels = 256
62 | opt.res_scale = 0.1
63 | if "RCAN" in opt.model:
64 | opt.num_groups = 10
65 | opt.num_blocks = 20
66 | opt.num_channels = 64
67 | opt.reduction = 16
68 | opt.res_scale = 1.0
69 | opt.max_steps = 1000000
70 | opt.decay = "200-400-600-800"
71 | opt.gclip = 0.5 if opt.pretrain else opt.gclip
72 | if "CARN" in opt.model:
73 | opt.num_groups = 3
74 | opt.num_blocks = 3
75 | opt.num_channels = 64
76 | opt.res_scale = 1.0
77 | opt.batch_size = 64
78 | opt.decay = "400"
79 |
80 | # training setup
81 | if "DN" in opt.dataset or "JPEG" in opt.dataset:
82 | opt.max_steps = 1000000
83 | opt.decay = "300-550-800"
84 | if "RealSR" in opt.dataset:
85 | opt.patch_size *= opt.scale # identical (LR, HR) resolution
86 |
87 | # evaluation setup
88 | opt.crop = 6 if "DIV2K" in opt.dataset else 0
89 | opt.crop += opt.scale if "SR" in opt.dataset else 4
90 |
91 | # note: we tested on color DN task
92 | if "DIV2K" in opt.dataset or "DN" in opt.dataset:
93 | opt.eval_y_only = False
94 | else:
95 | opt.eval_y_only = True
96 |
97 | # default augmentation policies
98 | if opt.use_moa:
99 | opt.augs = ["blend", "rgb", "mixup", "cutout", "cutmix", "cutmixup", "cutblur"]
100 | opt.prob = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
101 | opt.alpha = [0.6, 1.0, 1.2, 0.001, 0.7, 0.7, 0.7]
102 | opt.aux_prob, opt.aux_alpha = 1.0, 1.2
103 | opt.mix_p = None
104 |
105 | if "RealSR" in opt.dataset:
106 | opt.mix_p = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.4]
107 |
108 | if "DN" in opt.dataset or "JPEG" in opt.dataset:
109 | opt.prob = [0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6]
110 | if "CARN" in opt.model and not "RealSR" in opt.dataset:
111 | opt.prob = [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2]
112 |
113 |
114 | def get_option():
115 | opt = parse_args()
116 | make_template(opt)
117 | return opt
118 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.4.0
2 | torchvision==0.4.0
3 | scikit-image
4 | numpy
5 | tqdm
--------------------------------------------------------------------------------
/solver.py:
--------------------------------------------------------------------------------
1 | """
2 | CutBlur
3 | Copyright 2020-present NAVER corp.
4 | MIT license
5 | """
6 | import os
7 | import time
8 | import skimage.io as io
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import utils
13 | import augments
14 | from data import generate_loader
15 |
16 | class Solver():
17 | def __init__(self, module, opt):
18 | self.opt = opt
19 |
20 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21 | self.net = module.Net(opt).to(self.dev)
22 | print("# params:", sum(map(lambda x: x.numel(), self.net.parameters())))
23 |
24 | if opt.pretrain:
25 | self.load(opt.pretrain)
26 |
27 | self.loss_fn = nn.L1Loss()
28 | self.optim = torch.optim.Adam(
29 | self.net.parameters(), opt.lr,
30 | betas=(0.9, 0.999), eps=1e-8
31 | )
32 | self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
33 | self.optim, [1000*int(d) for d in opt.decay.split("-")],
34 | gamma=opt.gamma,
35 | )
36 |
37 | if not opt.test_only:
38 | self.train_loader = generate_loader("train", opt)
39 | self.test_loader = generate_loader("test", opt)
40 |
41 | self.t1, self.t2 = None, None
42 | self.best_psnr, self.best_step = 0, 0
43 |
44 | def fit(self):
45 | opt = self.opt
46 |
47 | self.t1 = time.time()
48 | for step in range(opt.max_steps):
49 | try:
50 | inputs = next(iters)
51 | except (UnboundLocalError, StopIteration):
52 | iters = iter(self.train_loader)
53 | inputs = next(iters)
54 |
55 | HR = inputs[0].to(self.dev)
56 | LR = inputs[1].to(self.dev)
57 |
58 | # match the resolution of (LR, HR) due to CutBlur
59 | if HR.size() != LR.size():
60 | scale = HR.size(2) // LR.size(2)
61 | LR = F.interpolate(LR, scale_factor=scale, mode="nearest")
62 |
63 | HR, LR, mask, aug = augments.apply_augment(
64 | HR, LR,
65 | opt.augs, opt.prob, opt.alpha,
66 |             opt.aux_prob, opt.aux_alpha, opt.mix_p
67 | )
68 |
69 | SR = self.net(LR)
70 | if aug == "cutout":
71 | SR, HR = SR*mask, HR*mask
72 |
73 | loss = self.loss_fn(SR, HR)
74 | self.optim.zero_grad()
75 | loss.backward()
76 |
77 | if opt.gclip > 0:
78 | torch.nn.utils.clip_grad_value_(self.net.parameters(), opt.gclip)
79 |
80 | self.optim.step()
81 | self.scheduler.step()
82 |
83 | if (step+1) % opt.eval_steps == 0:
84 | self.summary_and_save(step)
85 |
86 | def summary_and_save(self, step):
87 | step, max_steps = (step+1)//1000, self.opt.max_steps//1000
88 | psnr = self.evaluate()
89 | self.t2 = time.time()
90 |
91 | if psnr >= self.best_psnr:
92 | self.best_psnr, self.best_step = psnr, step
93 | self.save(step)
94 |
95 | curr_lr = self.scheduler.get_lr()[0]
96 | eta = (self.t2-self.t1) * (max_steps-step) / 3600
97 | print("[{}K/{}K] {:.2f} (Best: {:.2f} @ {}K step) LR: {}, ETA: {:.1f} hours"
98 | .format(step, max_steps, psnr, self.best_psnr, self.best_step,
99 | curr_lr, eta))
100 |
101 | self.t1 = time.time()
102 |
103 | @torch.no_grad()
104 | def evaluate(self):
105 | opt = self.opt
106 | self.net.eval()
107 |
108 | if opt.save_result:
109 | save_root = os.path.join(opt.save_root, opt.dataset)
110 | os.makedirs(save_root, exist_ok=True)
111 |
112 | psnr = 0
113 | for i, inputs in enumerate(self.test_loader):
114 | HR = inputs[0].to(self.dev)
115 | LR = inputs[1].to(self.dev)
116 |
117 | # match the resolution of (LR, HR) due to CutBlur
118 | if HR.size() != LR.size():
119 | scale = HR.size(2) // LR.size(2)
120 | LR = F.interpolate(LR, scale_factor=scale, mode="nearest")
121 |
122 | SR = self.net(LR).detach()
123 | HR = HR[0].clamp(0, 255).round().cpu().byte().permute(1, 2, 0).numpy()
124 | SR = SR[0].clamp(0, 255).round().cpu().byte().permute(1, 2, 0).numpy()
125 |
126 | if opt.save_result:
127 | save_path = os.path.join(save_root, "{:04d}.png".format(i+1))
128 | io.imsave(save_path, SR)
129 |
130 | HR = HR[opt.crop:-opt.crop, opt.crop:-opt.crop, :]
131 | SR = SR[opt.crop:-opt.crop, opt.crop:-opt.crop, :]
132 | if opt.eval_y_only:
133 | HR = utils.rgb2ycbcr(HR)
134 | SR = utils.rgb2ycbcr(SR)
135 | psnr += utils.calculate_psnr(HR, SR)
136 |
137 | self.net.train()
138 |
139 | return psnr/len(self.test_loader)
140 |
141 | def load(self, path):
142 | state_dict = torch.load(path, map_location=lambda storage, loc: storage)
143 |
144 | if self.opt.strict_load:
145 | self.net.load_state_dict(state_dict)
146 | return
147 |
148 | # when to fine-tune the pre-trained model
149 | own_state = self.net.state_dict()
150 | for name, param in state_dict.items():
151 | if name in own_state:
152 | if isinstance(param, nn.Parameter):
153 | param = param.data
154 |
155 | try:
156 | own_state[name].copy_(param)
157 | except Exception:
158 | # head and tail modules can be different
159 | if name.find("head") == -1 and name.find("tail") == -1:
160 | raise RuntimeError(
161 | "While copying the parameter named {}, "
162 | "whose dimensions in the model are {} and "
163 | "whose dimensions in the checkpoint are {}."
164 | .format(name, own_state[name].size(), param.size())
165 | )
166 | else:
167 | raise RuntimeError(
168 | "Missing key {} in model's state_dict".format(name)
169 | )
170 |
171 | def save(self, step):
172 | os.makedirs(self.opt.ckpt_root, exist_ok=True)
173 | save_path = os.path.join(self.opt.ckpt_root, str(step)+".pt")
174 | torch.save(self.net.state_dict(), save_path)
175 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | CutBlur
3 | Copyright 2020-present NAVER corp.
4 | MIT license
5 | """
6 | import random
7 | import numpy as np
8 |
9 | def crop(HQ, LQ, psize, scale=4):
10 | h, w = LQ.shape[:-1]
11 | x = random.randrange(0, w-psize+1)
12 | y = random.randrange(0, h-psize+1)
13 |
14 | crop_HQ = HQ[y*scale:y*scale+psize*scale, x*scale:x*scale+psize*scale]
15 | crop_LQ = LQ[y:y+psize, x:x+psize]
16 |
17 | return crop_HQ.copy(), crop_LQ.copy()
18 |
19 |
20 | def flip_and_rotate(HQ, LQ):
21 | hflip = random.random() < 0.5
22 | vflip = random.random() < 0.5
23 | rot90 = random.random() < 0.5
24 |
25 | if hflip:
26 | HQ, LQ = HQ[:, ::-1, :], LQ[:, ::-1, :]
27 | if vflip:
28 | HQ, LQ = HQ[::-1, :, :], LQ[::-1, :, :]
29 | if rot90:
30 | HQ, LQ = HQ.transpose(1, 0, 2), LQ.transpose(1, 0, 2)
31 |
32 | return HQ, LQ
33 |
34 |
35 | def rgb2ycbcr(img, y_only=True):
36 | in_img_type = img.dtype
37 |     img = img.astype(np.float32)
38 | if in_img_type != np.uint8:
39 | img *= 255.
40 |
41 | if y_only:
42 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
43 | else:
44 | rlt = np.matmul(
45 | img,
46 | [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
47 | [24.966, 112.0, -18.214]]
48 | ) / 255.0 + [16, 128, 128]
49 | if in_img_type == np.uint8:
50 | rlt = rlt.round()
51 | else:
52 | rlt /= 255.
53 | return rlt.astype(in_img_type)
54 |
55 |
56 | def calculate_psnr(img1, img2):
57 | img1 = img1.astype(np.float64)
58 | img2 = img2.astype(np.float64)
59 |
60 | mse = np.mean((img1 - img2)**2)
61 | if mse == 0:
62 | return float("inf")
63 | return 20 * np.log10(255.0 / np.sqrt(mse))
64 |
--------------------------------------------------------------------------------
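As a sanity check of the evaluation metric, `calculate_psnr` follows the standard 8-bit definition, 20 * log10(255 / sqrt(MSE)). A toy example, assuming this repo is on `PYTHONPATH`:

```python
import numpy as np
from utils import calculate_psnr

a = np.zeros((8, 8, 3), dtype=np.uint8)
b = a.copy()
b[0, 0, 0] = 255                 # a single maximal pixel error
print(calculate_psnr(a, a))      # inf: identical images
print(calculate_psnr(a, b))      # ~22.8 dB: 10 * log10(192), one wrong pixel out of 192
```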