├── FLOPs
│   ├── __pycache__
│   │   ├── count_hooks.cpython-35.pyc
│   │   └── profile.cpython-35.pyc
│   ├── count_hooks.py
│   └── profile.py
├── README.md
├── TF_model
│   └── imdn_rtc_time.tflite
├── Test_Datasets
│   ├── RealSR
│   │   ├── ValidationGT
│   │   │   ├── cam1_01.png
│   │   │   ├── cam1_010.png
│   │   │   ├── cam1_02.png
│   │   │   ├── cam1_03.png
│   │   │   ├── cam1_04.png
│   │   │   ├── cam1_05.png
│   │   │   ├── cam1_06.png
│   │   │   ├── cam1_07.png
│   │   │   ├── cam1_08.png
│   │   │   ├── cam1_09.png
│   │   │   ├── cam2_01.png
│   │   │   ├── cam2_010.png
│   │   │   ├── cam2_02.png
│   │   │   ├── cam2_03.png
│   │   │   ├── cam2_04.png
│   │   │   ├── cam2_05.png
│   │   │   ├── cam2_06.png
│   │   │   ├── cam2_07.png
│   │   │   ├── cam2_08.png
│   │   │   └── cam2_09.png
│   │   └── ValidationLR
│   │       ├── cam1_01.png
│   │       ├── cam1_010.png
│   │       ├── cam1_02.png
│   │       ├── cam1_03.png
│   │       ├── cam1_04.png
│   │       ├── cam1_05.png
│   │       ├── cam1_06.png
│   │       ├── cam1_07.png
│   │       ├── cam1_08.png
│   │       ├── cam1_09.png
│   │       ├── cam2_01.png
│   │       ├── cam2_010.png
│   │       ├── cam2_02.png
│   │       ├── cam2_03.png
│   │       ├── cam2_04.png
│   │       ├── cam2_05.png
│   │       ├── cam2_06.png
│   │       ├── cam2_07.png
│   │       ├── cam2_08.png
│   │       └── cam2_09.png
│   ├── Set5
│   │   ├── baby.bmp
│   │   ├── bird.bmp
│   │   ├── butterfly.bmp
│   │   ├── head.bmp
│   │   └── woman.bmp
│   └── Set5_LR
│       ├── x2
│       │   ├── babyx2.bmp
│       │   ├── birdx2.bmp
│       │   ├── butterflyx2.bmp
│       │   ├── headx2.bmp
│       │   └── womanx2.bmp
│       ├── x3
│       │   ├── babyx3.bmp
│       │   ├── birdx3.bmp
│       │   ├── butterflyx3.bmp
│       │   ├── headx3.bmp
│       │   └── womanx3.bmp
│       └── x4
│           ├── babyx4.bmp
│           ├── birdx4.bmp
│           ├── butterflyx4.bmp
│           ├── headx4.bmp
│           └── womanx4.bmp
├── calc_FLOPs.py
├── checkpoints
│   ├── IMDN_AS.pth
│   ├── IMDN_x2.pth
│   ├── IMDN_x3.pth
│   ├── IMDN_x4.pth
│   ├── model_RTC.pth
│   └── model_RTE.pth
├── data
│   ├── DIV2K.py
│   ├── RealSR.py
│   ├── Set5_val.py
│   ├── common.py
│   └── image_folder.py
├── images
│   ├── Pressure_test.png
│   ├── acmmm19_poster.pdf
│   ├── adaptive_cropping.png
│   ├── imdb_plus.png
│   ├── imdn_rtc.jpg
│   ├── lenna.png
│   ├── memory.png
│   ├── parameters.png
│   ├── psnr_ssim.png
│   ├── reparam.png
│   └── time.png
├── model
│   ├── __pycache__
│   │   ├── architecture.cpython-35.pyc
│   │   └── block.cpython-35.pyc
│   ├── architecture.py
│   └── block.py
├── scripts
│   └── png2npy.py
├── test_IMDN.py
├── test_IMDN_AS.py
├── train_IMDN.py
└── utils.py
/FLOPs/__pycache__/count_hooks.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/FLOPs/__pycache__/count_hooks.cpython-35.pyc
--------------------------------------------------------------------------------
/FLOPs/__pycache__/profile.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/FLOPs/__pycache__/profile.cpython-35.pyc
--------------------------------------------------------------------------------
/FLOPs/count_hooks.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | multiply_adds = 1
7 |
8 |
9 | def count_convNd(m, x, y):
10 | x = x[0]
11 | cin = m.in_channels
12 | batch_size = x.size(0)
13 |
14 | kernel_ops = m.weight.size()[2:].numel()
15 | bias_ops = 1 if m.bias is not None else 0
16 | ops_per_element = kernel_ops + bias_ops
17 | output_elements = y.nelement()
18 |
19 |     # y.nelement() already counts batch x cout x oW x oH, so do not
20 |     # multiply by batch_size again; bias is added once per output element
21 |     total_ops = output_elements * (cin * kernel_ops // m.groups + bias_ops)
22 | m.total_ops = torch.Tensor([int(total_ops)])
23 |
24 |
25 | def count_conv2d(m, x, y):
26 | x = x[0]
27 |
28 | cin = m.in_channels
29 | cout = m.out_channels
30 | kh, kw = m.kernel_size
31 | batch_size = x.size()[0]
32 |
33 | out_h = y.size(2)
34 | out_w = y.size(3)
35 |
36 | # ops per output element
37 | # kernel_mul = kh * kw * cin
38 | # kernel_add = kh * kw * cin - 1
39 | kernel_ops = multiply_adds * kh * kw
40 | bias_ops = 1 if m.bias is not None else 0
41 |     # (bias is added once per output element, not once per input channel)
42 |
43 |     # total ops
44 |     # num_out_elements = y.numel()
45 |     output_elements = batch_size * out_w * out_h * cout
46 |     total_ops = output_elements * (kernel_ops * cin // m.groups + bias_ops)
47 |
48 | m.total_ops = torch.Tensor([int(total_ops)])
49 |
50 |
51 | def count_convtranspose2d(m, x, y):
52 | x = x[0]
53 |
54 | cin = m.in_channels
55 | cout = m.out_channels
56 | kh, kw = m.kernel_size
57 | batch_size = x.size()[0]
58 |
59 | out_h = y.size(2)
60 | out_w = y.size(3)
61 |
62 |     # ops per output element
63 |     # kernel_mul = kh * kw * cin
64 |     # kernel_add = kh * kw * cin - 1
65 |     kernel_ops = multiply_adds * kh * kw * cin // m.groups
66 |     bias_ops = 1 if m.bias is not None else 0
67 |     ops_per_element = kernel_ops + bias_ops
68 |
69 |     # total ops
70 |     # (the old code overwrote ops_per_element with m.weight.nelement(),
71 |     #  which over-counts by a factor of cout; use the count from above)
72 |     output_elements = y.nelement()
73 |
74 |     total_ops = output_elements * ops_per_element
75 |
76 | m.total_ops = torch.Tensor([int(total_ops)])
77 |
78 |
79 | def count_bn(m, x, y):
80 | x = x[0]
81 |
82 | nelements = x.numel()
83 | # subtract, divide, gamma, beta
84 | total_ops = 4 * nelements
85 |
86 | m.total_ops = torch.Tensor([int(total_ops)])
87 |
88 |
89 | def count_relu(m, x, y):
90 | x = x[0]
91 |
92 | nelements = x.numel()
93 | total_ops = nelements
94 |
95 | m.total_ops = torch.Tensor([int(total_ops)])
96 |
97 |
98 | def count_sigmoid(m, x, y):
99 | x = x[0]
100 | nelements = x.numel()
101 |
102 | total_exp = nelements
103 | total_add = nelements
104 | total_div = nelements
105 |
106 | total_ops = total_exp + total_add + total_div
107 | m.total_ops = torch.Tensor([int(total_ops)])
108 |
109 | def count_pixelshuffle(m, x, y):
110 | x = x[0]
111 | nelements = x.numel()
112 | total_ops = nelements
113 | m.total_ops = torch.Tensor([int(total_ops)])
114 |
115 |
116 | def count_softmax(m, x, y):
117 | x = x[0]
118 |
119 | batch_size, nfeatures = x.size()
120 |
121 | total_exp = nfeatures
122 | total_add = nfeatures - 1
123 | total_div = nfeatures
124 | total_ops = batch_size * (total_exp + total_add + total_div)
125 |
126 | m.total_ops = torch.Tensor([int(total_ops)])
127 |
128 |
129 | def count_maxpool(m, x, y):
130 | kernel_ops = torch.prod(torch.Tensor([m.kernel_size]))
131 | num_elements = y.numel()
132 | total_ops = kernel_ops * num_elements
133 |
134 | m.total_ops = torch.Tensor([int(total_ops)])
135 |
136 |
137 | def count_adap_maxpool(m, x, y):
138 | kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze()
139 | kernel_ops = torch.prod(kernel)
140 | num_elements = y.numel()
141 | total_ops = kernel_ops * num_elements
142 |
143 | m.total_ops = torch.Tensor([int(total_ops)])
144 |
145 |
146 | def count_avgpool(m, x, y):
147 | total_add = torch.prod(torch.Tensor([m.kernel_size]))
148 | total_div = 1
149 | kernel_ops = total_add + total_div
150 | num_elements = y.numel()
151 | total_ops = kernel_ops * num_elements
152 |
153 | m.total_ops = torch.Tensor([int(total_ops)])
154 |
155 |
156 | def count_adap_avgpool(m, x, y):
157 | kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze()
158 | total_add = torch.prod(kernel)
159 | total_div = 1
160 | kernel_ops = total_add + total_div
161 | num_elements = y.numel()
162 | total_ops = kernel_ops * num_elements
163 |
164 | m.total_ops = torch.Tensor([int(total_ops)])
165 |
166 |
167 | def count_linear(m, x, y):
168 | # per output element
169 | total_mul = m.in_features
170 | total_add = m.in_features - 1
171 | num_elements = y.numel()
172 | total_ops = (total_mul + total_add) * num_elements
173 |
174 | m.total_ops = torch.Tensor([int(total_ops)])
175 |
--------------------------------------------------------------------------------
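
As a quick hand-check of the convolution accounting above (a sketch; the layer sizes are arbitrary, not taken from this repo): for a `Conv2d(3, 8, kernel_size=3, padding=1)` layer on a 1×3×32×32 input, the per-output-element count used by `count_convNd` works out as follows.

```python
# Hand-check of count_convNd (illustrative sizes).
cin, cout, k, H, W = 3, 8, 3, 32, 32

kernel_ops = k * k                   # multiply-adds per input-channel tap
bias_ops = 1                         # one bias add per output element
output_elements = 1 * cout * H * W   # batch=1; 'same' padding keeps H x W

total_ops = output_elements * (cin * kernel_ops // 1 + bias_ops)  # groups=1
print(total_ops)  # 229376
```
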
/FLOPs/profile.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.modules.conv import _ConvNd
6 |
7 | from .count_hooks import *
8 |
9 | register_hooks = {
10 | nn.Conv1d: count_convNd,
11 | nn.Conv2d: count_convNd,
12 | nn.Conv3d: count_convNd,
13 | nn.ConvTranspose2d: count_convtranspose2d,
14 |
15 | nn.BatchNorm1d: count_bn,
16 | nn.BatchNorm2d: count_bn,
17 | nn.BatchNorm3d: count_bn,
18 |
19 | nn.ReLU: count_relu,
20 | nn.ReLU6: count_relu,
21 | nn.LeakyReLU: count_relu,
22 | nn.PReLU: count_relu,
23 |
24 | nn.MaxPool1d: count_maxpool,
25 | nn.MaxPool2d: count_maxpool,
26 | nn.MaxPool3d: count_maxpool,
27 | nn.AdaptiveMaxPool1d: count_adap_maxpool,
28 | nn.AdaptiveMaxPool2d: count_adap_maxpool,
29 | nn.AdaptiveMaxPool3d: count_adap_maxpool,
30 |
31 | nn.AvgPool1d: count_avgpool,
32 | nn.AvgPool2d: count_avgpool,
33 | nn.AvgPool3d: count_avgpool,
34 |
35 | nn.AdaptiveAvgPool1d: count_adap_avgpool,
36 | nn.AdaptiveAvgPool2d: count_adap_avgpool,
37 | nn.AdaptiveAvgPool3d: count_adap_avgpool,
38 | nn.Linear: count_linear,
39 | nn.Dropout: None,
40 | nn.PixelShuffle: count_pixelshuffle,
41 | nn.Sigmoid: count_sigmoid,
42 | }
43 |
44 |
45 | def profile(model, input_size, custom_ops={}, device="cpu"):
46 | handler_collection = []
47 |
48 | def add_hooks(m):
49 | if len(list(m.children())) > 0:
50 | return
51 |
52 | m.register_buffer('total_ops', torch.zeros(1))
53 | m.register_buffer('total_params', torch.zeros(1))
54 |
55 | for p in m.parameters():
56 | m.total_params += torch.Tensor([p.numel()])
57 |
58 | m_type = type(m)
59 | fn = None
60 |
61 | if m_type in custom_ops:
62 | fn = custom_ops[m_type]
63 | elif m_type in register_hooks:
64 | fn = register_hooks[m_type]
65 | else:
66 | print("Not implemented for ", m)
67 |
68 | if fn is not None:
69 | #print("Register FLOP counter for module %s" % str(m))
70 | handler = m.register_forward_hook(fn)
71 | handler_collection.append(handler)
72 |
73 | original_device = model.parameters().__next__().device
74 | training = model.training
75 |
76 | model.eval().to(device)
77 | model.apply(add_hooks)
78 |
79 | x = torch.zeros(input_size).to(device)
80 | with torch.no_grad():
81 | model(x)
82 |
83 | total_ops = 0
84 | total_params = 0
85 | for m in model.modules():
86 | if len(list(m.children())) > 0: # skip for non-leaf module
87 | continue
88 | total_ops += m.total_ops
89 | total_params += m.total_params
90 |
91 | total_ops = total_ops.item()
92 | total_params = total_params.item()
93 |
94 | model.train(training).to(original_device)
95 | for handler in handler_collection:
96 | handler.remove()
97 |
98 | return total_ops, total_params
99 |
--------------------------------------------------------------------------------
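
A minimal usage sketch of `profile` on a toy network (the layer sizes are illustrative; run it from the repository root so the `FLOPs` package is importable, exactly as `calc_FLOPs.py` does):

```python
# Sketch: counting FLOPs and parameters of a toy model with FLOPs.profile.
import torch.nn as nn
from FLOPs.profile import profile

toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 3, kernel_size=3, padding=1),
)
flops, params = profile(toy, input_size=(1, 3, 64, 64))
print('{:.3f} MFLOPs, {:.0f} params'.format(flops / 1e6, params))
```

`profile` registers a forward hook on every leaf module it recognizes, runs one dummy forward pass, then sums the per-module `total_ops` buffers; unrecognized leaf modules are reported and counted as zero ops.
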
/README.md:
--------------------------------------------------------------------------------
1 | # IMDN
2 | Lightweight Image Super-Resolution with Information Multi-distillation Network (ACM MM 2019)
3 |
4 | [[arXiv]](https://arxiv.org/pdf/1909.11856v1.pdf)
5 | [[Poster]](https://github.com/Zheng222/IMDN/blob/master/images/acmmm19_poster.pdf)
6 | [[ACM DL]](https://dl.acm.org/citation.cfm?id=3351084)
7 |
8 | ## :sparkles: News
9 | - Nov 26, 2021. **Added the IMDN_RTC TFLite model.**
10 |
11 | # [CVPR 2022 Workshop NTIRE report](https://github.com/ofsoundof/NTIRE2022_ESR)
12 | IMDN+ won the **Second Runner-up** award at the NTIRE 2022 Efficient SR Challenge (Sub-Track 2, Overall Performance Track).
13 |
14 | ![IMDB+](images/imdb_plus.png)
15 |
16 | *IMDB+*
17 |
18 | ![structural re-parameterization](images/reparam.png)
19 |
20 | *structural re-parameterization*
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 | ##### Model complexity
29 | | number of parameters | 275,844 |
30 | |:---:|:---:|
31 | | FLOPs | 17.9848G (input size: 3×256×256) |
32 | | GPU memory consumption | 2893M (DIV2K test) |
33 | | number of activations | 92.7990M (input size: 3×256×256) |
34 | | runtime | 0.026783s (RTX 2080Ti, DIV2K test) |
35 |
36 |
37 |
38 | ##### PSNR / SSIM (Y channel) on 5 benchmark datasets.
39 | | Metrics | Set5 | Set14 | B100 | Urban100 | Manga109 |
40 | |:---:|:---:|:---:|:---:|:---:|:---:|
41 | | PSNR| 32.11 | 28.63 | 27.58 | 26.10 | 30.55 |
42 | | SSIM| 0.8934 | 0.7823 | 0.7358 | 0.7846 | 0.9072 |
43 |
44 |
45 |
46 | # [ICCV 2019 Workshop AIM report](https://arxiv.org/abs/1911.01249)
47 | The simplified version of IMDN won **first place** in the Constrained Super-Resolution Challenge (Track 1 & Track 2). The test code is available on [Google Drive](https://drive.google.com/open?id=1BQkpqp2oZUH_J_amJv33ehGjx6gvCd0L).
48 |
49 | # [AI in RTC 2019-rainbow](https://www.dcjingsai.com/common/bbs/topicDetails.html?tid=3787)
50 | The ultra-lightweight version of IMDN won **first place** in the Super-Resolution Algorithm Performance Comparison Challenge ([IMDN_RTC](https://github.com/Zheng222/IMDN/blob/53f1dac25e8cd8e11ad65484eadf0d1e31d602fa/model/architecture.py#L79)).
51 |
52 | Degradation type: Bicubic
53 |
54 | [PyTorch Checkpoint](https://github.com/Zheng222/IMDN/blob/master/checkpoints/model_RTC.pth)
55 |
56 | [Tensorflow Lite Checkpoint](https://github.com/Zheng222/IMDN/blob/master/checkpoints/imdn_rtc_time.tflite)
57 |
58 | input_shape = (1, 720, 480, 3), AI Benchmark (OPPO Find X3, Qualcomm Snapdragon 870, FP16, TFLite GPU delegate)
59 |
60 |
61 | ![IMDN_RTC on AI Benchmark](images/imdn_rtc.jpg)
62 |
63 | # [AI in RTE 2020-rainbow](https://www.dcjingsai.com/v2/news-detail.html?id=64)
64 | The down-up version of IMDN won **second place** in the Super-Resolution Algorithm Performance Comparison Challenge
65 | ([IMDN_RTE](https://github.com/Zheng222/IMDN/blob/53f1dac25e8cd8e11ad65484eadf0d1e31d602fa/model/architecture.py#L98)).
66 |
67 | Degradation type: Downsampling + noise
68 |
69 | [Checkpoint](https://github.com/Zheng222/IMDN/blob/master/checkpoints/model_RTE.pth)
70 | # Highlights
71 | 1. An information multi-distillation block (IMDB) with a contrast-aware channel attention (CCA) layer.
72 |
73 | 2. An adaptive cropping strategy (ACS) that lets one model process images of arbitrary size (and hence handle any upscaling factor).
74 |
75 | 3. An exploration of the factors that affect actual inference time.
76 |
77 | ## Testing
78 | PyTorch 1.1
79 | * Run testing:
80 | ```bash
81 | # Set5 x2 IMDN
82 | python test_IMDN.py --test_hr_folder Test_Datasets/Set5/ --test_lr_folder Test_Datasets/Set5_LR/x2/ --output_folder results/Set5/x2 --checkpoint checkpoints/IMDN_x2.pth --upscale_factor 2
83 | # RealSR IMDN_AS
84 | python test_IMDN_AS.py --test_hr_folder Test_Datasets/RealSR/ValidationGT --test_lr_folder Test_Datasets/RealSR/ValidationLR/ --output_folder results/RealSR --checkpoint checkpoints/IMDN_AS.pth
85 |
86 | ```
87 | * Calculate IMDN_RTC's FLOPs and parameters for a 240×360 input:
88 | ```bash
89 | python calc_FLOPs.py
90 | ```
91 |
92 | ## Training
93 | * Download [Training dataset DIV2K](https://drive.google.com/open?id=12hOYsMa8t1ErKj6PZA352icsx9mz1TwB)
94 | * Convert PNG files to NPY files:
95 | ```bash
96 | python scripts/png2npy.py --pathFrom /path/to/DIV2K/ --pathTo /path/to/DIV2K_decoded/
97 | ```
98 | * Train the x2, x3, and x4 models:
99 | ```bash
100 | python train_IMDN.py --root /path/to/DIV2K_decoded/ --scale 2 --pretrained checkpoints/IMDN_x2.pth
101 | python train_IMDN.py --root /path/to/DIV2K_decoded/ --scale 3 --pretrained checkpoints/IMDN_x3.pth
102 | python train_IMDN.py --root /path/to/DIV2K_decoded/ --scale 4 --pretrained checkpoints/IMDN_x4.pth
103 | ```
104 |
105 | ## Results
106 | [Baidu Netdisk](https://pan.baidu.com/s/1DY0Npete3WsIoFbjmgXQlw) (extraction code: 8yqj) or
107 | [Google Drive](https://drive.google.com/open?id=1GsEcpIZ7uA97D89WOGa9sWTSl4choy_O)
108 |
109 | The PSNR/SSIM values below are computed with MATLAB R2017a; for the evaluation code, see [Evaluate_PSNR_SSIM.m](https://github.com/yulunzhang/RCAN/blob/master/RCAN_TestCode/Evaluate_PSNR_SSIM.m).
110 |
111 | ## Pressure Test
112 |
113 | ![Pressure test](images/Pressure_test.png)
114 |
115 | Pressure test for the ×4 SR model.
116 |
117 | *Note: `torch.cuda.Event()` is used to record inference times.*
118 |
119 |
120 | ## PSNR & SSIM
121 |
122 | ![PSNR/SSIM comparison](images/psnr_ssim.png)
123 | Average PSNR/SSIM on the Set5, Set14, BSD100, Urban100, and Manga109 datasets.
124 |
125 |
126 | ## Memory consumption
127 |
128 | ![Memory consumption](images/memory.png)
129 | Memory consumption (MB) and average inference time (seconds).
130 |
131 |
132 | ## Model parameters
133 |
134 | ![Parameters vs. performance](images/parameters.png)
135 |
136 | Trade-off between performance and number of parameters on the Set5 ×4 dataset.
137 |
138 |
139 | ## Running time
140 |
141 | ![Running time vs. performance](images/time.png)
142 |
143 | Trade-off between performance and running time on the Set5 ×4 dataset. VDSR, DRCN, and LapSRN were implemented in MatConvNet; DRRN and IDN used the Caffe package; the remaining methods (EDSR-baseline, CARN, and our IMDN) used PyTorch.
144 |
145 |
146 | ## Adaptive Cropping
147 |
148 | ![Adaptive cropping](images/adaptive_cropping.png)
149 | Diagram of the adaptive cropping strategy (ACS); the cropped image patches are outlined by the green dotted boxes.
150 |
151 |
152 | ## Visualization of feature maps
153 |
154 | ![Feature-map visualization](images/lenna.png)
155 | Visualization of the output feature maps of the 6th progressive refinement module (PRM).
156 |
157 |
158 | ## Citation
159 |
160 | If you find IMDN useful in your research, please consider citing:
161 |
162 | ```
163 | @inproceedings{Hui-IMDN-2019,
164 | title={Lightweight Image Super-Resolution with Information Multi-distillation Network},
165 | author={Hui, Zheng and Gao, Xinbo and Yang, Yunchu and Wang, Xiumei},
166 | booktitle={Proceedings of the 27th ACM International Conference on Multimedia (ACM MM)},
167 | pages={2024--2032},
168 | year={2019}
169 | }
170 |
171 | @inproceedings{AIM19constrainedSR,
172 | title={AIM 2019 Challenge on Constrained Super-Resolution: Methods and Results},
173 | author={Kai Zhang and Shuhang Gu and Radu Timofte and others},
174 | booktitle={The IEEE International Conference on Computer Vision (ICCV) Workshops},
175 | year={2019}
176 | }
177 |
178 | ```
179 |
--------------------------------------------------------------------------------
/TF_model/imdn_rtc_time.tflite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/TF_model/imdn_rtc_time.tflite
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam1_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_01.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam1_010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_010.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam1_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_02.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam1_03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_03.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam1_04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_04.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam1_05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_05.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam1_06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_06.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam1_07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_07.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam1_08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_08.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam1_09.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_09.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam2_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_01.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam2_010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_010.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam2_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_02.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam2_03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_03.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam2_04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_04.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam2_05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_05.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam2_06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_06.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam2_07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_07.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam2_08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_08.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationGT/cam2_09.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_09.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam1_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_01.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam1_010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_010.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam1_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_02.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam1_03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_03.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam1_04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_04.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam1_05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_05.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam1_06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_06.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam1_07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_07.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam1_08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_08.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam1_09.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_09.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam2_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_01.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam2_010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_010.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam2_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_02.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam2_03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_03.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam2_04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_04.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam2_05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_05.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam2_06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_06.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam2_07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_07.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam2_08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_08.png
--------------------------------------------------------------------------------
/Test_Datasets/RealSR/ValidationLR/cam2_09.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_09.png
--------------------------------------------------------------------------------
/Test_Datasets/Set5/baby.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5/baby.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5/bird.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5/bird.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5/butterfly.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5/butterfly.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5/head.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5/head.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5/woman.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5/woman.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5_LR/x2/babyx2.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x2/babyx2.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5_LR/x2/birdx2.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x2/birdx2.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5_LR/x2/butterflyx2.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x2/butterflyx2.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5_LR/x2/headx2.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x2/headx2.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5_LR/x2/womanx2.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x2/womanx2.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5_LR/x3/babyx3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x3/babyx3.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5_LR/x3/birdx3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x3/birdx3.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5_LR/x3/butterflyx3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x3/butterflyx3.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5_LR/x3/headx3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x3/headx3.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5_LR/x3/womanx3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x3/womanx3.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5_LR/x4/babyx4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x4/babyx4.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5_LR/x4/birdx4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x4/birdx4.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5_LR/x4/butterflyx4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x4/butterflyx4.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5_LR/x4/headx4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x4/headx4.bmp
--------------------------------------------------------------------------------
/Test_Datasets/Set5_LR/x4/womanx4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x4/womanx4.bmp
--------------------------------------------------------------------------------
/calc_FLOPs.py:
--------------------------------------------------------------------------------
1 | from model import architecture
2 | from FLOPs.profile import profile
3 |
4 | width = 360
5 | height = 240
6 | model = architecture.IMDN_RTC(upscale=2)
7 | flops, params = profile(model, input_size=(1, 3, height, width))
8 | print('IMDN_RTC: {} x {}, FLOPs: {:.4f} GFLOPs, params: {}'.format(height, width, flops / 1e9, params))
9 |
--------------------------------------------------------------------------------
/checkpoints/IMDN_AS.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/checkpoints/IMDN_AS.pth
--------------------------------------------------------------------------------
/checkpoints/IMDN_x2.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/checkpoints/IMDN_x2.pth
--------------------------------------------------------------------------------
/checkpoints/IMDN_x3.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/checkpoints/IMDN_x3.pth
--------------------------------------------------------------------------------
/checkpoints/IMDN_x4.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/checkpoints/IMDN_x4.pth
--------------------------------------------------------------------------------
/checkpoints/model_RTC.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/checkpoints/model_RTC.pth
--------------------------------------------------------------------------------
/checkpoints/model_RTE.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/checkpoints/model_RTE.pth
--------------------------------------------------------------------------------
/data/DIV2K.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import os.path
3 | import cv2
4 | import numpy as np
5 | from data import common
6 |
7 | def default_loader(path):
8 | return cv2.imread(path, cv2.IMREAD_UNCHANGED)[:, :, [2, 1, 0]]
9 |
10 | def npy_loader(path):
11 | return np.load(path)
12 |
13 | IMG_EXTENSIONS = [
14 | '.png', '.npy',
15 | ]
16 |
17 | def is_image_file(filename):
18 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
19 |
20 | def make_dataset(dir):
21 | images = []
22 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
23 |
24 | for root, _, fnames in sorted(os.walk(dir)):
25 | for fname in fnames:
26 | if is_image_file(fname):
27 | path = os.path.join(root, fname)
28 | images.append(path)
29 | return images
30 |
31 |
32 | class div2k(data.Dataset):
33 | def __init__(self, opt):
34 | self.opt = opt
35 | self.scale = self.opt.scale
36 | self.root = self.opt.root
37 | self.ext = self.opt.ext # '.png' or '.npy'(default)
38 | self.train = True if self.opt.phase == 'train' else False
39 | self.repeat = self.opt.test_every // (self.opt.n_train // self.opt.batch_size)
40 | self._set_filesystem(self.root)
41 | self.images_hr, self.images_lr = self._scan()
42 |
43 | def _set_filesystem(self, dir_data):
44 | self.root = dir_data + '/DIV2K_decoded'
45 | self.dir_hr = os.path.join(self.root, 'DIV2K_HR')
46 | self.dir_lr = os.path.join(self.root, 'DIV2K_LR_bicubic/x' + str(self.scale))
47 |
48 | def __getitem__(self, idx):
49 | lr, hr = self._load_file(idx)
50 | lr, hr = self._get_patch(lr, hr)
51 | lr, hr = common.set_channel(lr, hr, n_channels=self.opt.n_colors)
52 | lr_tensor, hr_tensor = common.np2Tensor(lr, hr, rgb_range=self.opt.rgb_range)
53 | return lr_tensor, hr_tensor
54 |
55 |     def __len__(self):
56 |         if self.train:
57 |             return self.opt.n_train * self.repeat
58 |         return len(self.images_hr)  # eval phase: one sample per image
59 | def _get_index(self, idx):
60 | if self.train:
61 | return idx % self.opt.n_train
62 | else:
63 | return idx
64 |
65 | def _get_patch(self, img_in, img_tar):
66 | patch_size = self.opt.patch_size
67 | scale = self.scale
68 | if self.train:
69 | img_in, img_tar = common.get_patch(
70 | img_in, img_tar, patch_size=patch_size, scale=scale)
71 | img_in, img_tar = common.augment(img_in, img_tar)
72 | else:
73 | ih, iw = img_in.shape[:2]
74 | img_tar = img_tar[0:ih * scale, 0:iw * scale, :]
75 | return img_in, img_tar
76 |
77 | def _scan(self):
78 | list_hr = sorted(make_dataset(self.dir_hr))
79 | list_lr = sorted(make_dataset(self.dir_lr))
80 | return list_hr, list_lr
81 |
82 | def _load_file(self, idx):
83 | idx = self._get_index(idx)
84 | if self.ext == '.npy':
85 | lr = npy_loader(self.images_lr[idx])
86 | hr = npy_loader(self.images_hr[idx])
87 | else:
88 | lr = default_loader(self.images_lr[idx])
89 | hr = default_loader(self.images_hr[idx])
90 | return lr, hr
91 |
--------------------------------------------------------------------------------
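
A construction sketch for `div2k`; the options object below is hypothetical (the real values come from `train_IMDN.py`'s argument parser), and `root` must contain a `DIV2K_decoded` directory produced by `scripts/png2npy.py`:

```python
# Sketch: building the DIV2K training set (illustrative option values; expects
# /path/to/DIV2K_decoded/DIV2K_HR and .../DIV2K_LR_bicubic/x2 with .npy files).
from types import SimpleNamespace
from data.DIV2K import div2k

opt = SimpleNamespace(scale=2, root='/path/to', ext='.npy', phase='train',
                      test_every=1000, n_train=800, batch_size=16,
                      patch_size=192, n_colors=3, rgb_range=1)
train_set = div2k(opt)
lr, hr = train_set[0]  # 3x96x96 LR patch and 3x192x192 HR patch (C x H x W)
```
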
/data/RealSR.py:
--------------------------------------------------------------------------------
1 | import os
2 | from data import common
3 | import numpy as np
4 | import torch.utils.data as data
5 | from data.image_folder import make_dataset
6 |
7 |
8 | class srdata(data.Dataset):
9 | def __init__(self, args, train=True):
10 | self.args = args
11 | self.train = train
12 | self.root = self.args.root
13 | self.split = 'train' if train else 'test'
14 | self.scale = 1
15 | self.repeat = args.test_every // (args.n_train // args.batch_size) # 300 / (60/6) = 30
16 | self._set_filesystem(self.root)
17 | self.images_hr, self.images_lr = self._scan_npy()
18 |
19 | def _set_filesystem(self, dir_data):
20 | self.apath = dir_data + '/RealSR_decoded'
21 | self.dir_hr = os.path.join(self.apath, 'TrainGT')
22 | self.dir_lr = os.path.join(self.apath, 'TrainLR')
23 | self.ext = '.npy'
24 |
25 | def __getitem__(self, idx):
26 | lr, hr = self._load_file(idx)
27 | lr, hr = self._get_patch(lr, hr)
28 | lr, hr = common.set_channel(lr, hr, n_channels=self.args.n_colors)
29 | lr_tensor, hr_tensor = common.np2Tensor(lr, hr, rgb_range=self.args.rgb_range)
30 | return lr_tensor, hr_tensor
31 |
32 | def __len__(self):
33 | if self.train:
34 | return self.args.n_train * self.repeat
35 | else:
36 | return self.args.n_val
37 |
38 | def _get_index(self, idx):
39 | if self.train:
40 | return idx % self.args.n_train
41 | else:
42 | return idx
43 |
44 | def _get_patch(self, img_in, img_tar):
45 | patch_size = self.args.patch_size
46 | scale = self.scale
47 | if self.train:
48 | img_in, img_tar = common.get_patch(
49 | img_in, img_tar, patch_size=patch_size, scale=scale)
50 | img_in, img_tar = common.augment(img_in, img_tar)
51 |
52 | else:
53 | ih, iw = img_in.shape[:2]
54 | img_tar = img_tar[0:ih * scale, 0:iw * scale, :]
55 |
56 | return img_in, img_tar
57 |
58 | def _scan(self):
59 | list_hr = []
60 | list_lr = []
61 | if self.train:
62 | idx_begin = 0
63 | idx_end = self.args.n_train
64 | else:
65 | idx_begin = self.args.n_train
66 | idx_end = self.args.offset_val + self.args.n_val
67 |
68 | for i in range(idx_begin + 1, idx_end + 1):
69 | filename = '{:0>6}'.format(i)
70 | list_hr.append(os.path.join(self.dir_hr, filename + self.ext))
71 | list_lr.append(os.path.join(self.dir_lr, '{}x{}{}'.format(filename, self.scale, self.ext)))
72 |
73 | return list_hr, list_lr
74 |
75 | def _scan_npy(self):
76 | list_hr = sorted(make_dataset(self.dir_hr))
77 | list_lr = sorted(make_dataset(self.dir_lr))
78 | return list_hr, list_lr
79 |
80 | def _load_file(self, idx):
81 | idx = self._get_index(idx)
82 | lr = np.load(self.images_lr[idx])
83 | hr = np.load(self.images_hr[idx])
84 | return lr, hr
85 |
--------------------------------------------------------------------------------
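
The `repeat` factor stretches one pass over a small training set so that a single DataLoader epoch covers exactly `test_every` optimizer steps; here is a worked version of the `# 300 / (60/6) = 30` comment above:

```python
# Worked example of the repeat logic in srdata.__init__ (numbers taken from
# the inline comment; they are illustrative, not fixed defaults).
test_every, n_train, batch_size = 300, 60, 6
batches_per_pass = n_train // batch_size  # 10 batches per pass over the data
repeat = test_every // batches_per_pass   # 30
print(n_train * repeat)                   # 1800 samples -> 300 batches of 6
```
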
/data/Set5_val.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from os.path import join
3 | from os import listdir
4 | from torchvision.transforms import Compose, ToTensor
5 | from PIL import Image
6 | import numpy as np
7 |
8 |
9 | def img_modcrop(image, modulo):
10 | sz = image.size
11 | w = np.int32(sz[0] / modulo) * modulo
12 | h = np.int32(sz[1] / modulo) * modulo
13 | out = image.crop((0, 0, w, h))
14 | return out
15 |
16 |
17 | def np2tensor():
18 | return Compose([
19 | ToTensor(),
20 | ])
21 |
22 |
23 | def is_image_file(filename):
24 | return any(filename.endswith(extension) for extension in [".bmp", ".png", ".jpg"])
25 |
26 |
27 | def load_image(filepath):
28 | return Image.open(filepath).convert('RGB')
29 |
30 |
31 | class DatasetFromFolderVal(data.Dataset):
32 | def __init__(self, hr_dir, lr_dir, upscale):
33 | super(DatasetFromFolderVal, self).__init__()
34 | self.hr_filenames = sorted([join(hr_dir, x) for x in listdir(hr_dir) if is_image_file(x)])
35 | self.lr_filenames = sorted([join(lr_dir, x) for x in listdir(lr_dir) if is_image_file(x)])
36 | self.upscale = upscale
37 |
38 | def __getitem__(self, index):
39 | input = load_image(self.lr_filenames[index])
40 | target = load_image(self.hr_filenames[index])
41 | input = np2tensor()(input)
42 | target = np2tensor()(img_modcrop(target, self.upscale))
43 |
44 | return input, target
45 |
46 | def __len__(self):
47 | return len(self.lr_filenames)
48 |
--------------------------------------------------------------------------------
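
A usage sketch for the validation set above; the paths follow the `Test_Datasets` layout shipped with this repo:

```python
# Sketch: iterating Set5 LR/HR validation pairs.
from torch.utils.data import DataLoader
from data.Set5_val import DatasetFromFolderVal

val_set = DatasetFromFolderVal('Test_Datasets/Set5',
                               'Test_Datasets/Set5_LR/x2', upscale=2)
val_loader = DataLoader(val_set, batch_size=1, shuffle=False)
for lr, hr in val_loader:
    # HR is mod-cropped, so its height/width are exact multiples of the scale
    print(lr.shape, hr.shape)
```
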
/data/common.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import numpy as np
4 | import skimage.color as sc
5 |
6 |
7 | def get_patch(*args, patch_size, scale):
8 | ih, iw = args[0].shape[:2]
9 |
10 | tp = patch_size # target patch (HR)
11 | ip = tp // scale # input patch (LR)
12 |
13 | ix = random.randrange(0, iw - ip + 1)
14 | iy = random.randrange(0, ih - ip + 1)
15 | tx, ty = scale * ix, scale * iy
16 |
17 | ret = [
18 | args[0][iy:iy + ip, ix:ix + ip, :],
19 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
20 | ] # results
21 | return ret
22 |
23 |
24 | def set_channel(*args, n_channels=3):
25 | def _set_channel(img):
26 | if img.ndim == 2:
27 | img = np.expand_dims(img, axis=2)
28 |
29 | c = img.shape[2]
30 | if n_channels == 1 and c == 3:
31 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
32 | elif n_channels == 3 and c == 1:
33 | img = np.concatenate([img] * n_channels, 2)
34 |
35 | return img
36 |
37 | return [_set_channel(a) for a in args]
38 |
39 |
40 | def np2Tensor(*args, rgb_range):
41 | def _np2Tensor(img):
42 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
43 | tensor = torch.from_numpy(np_transpose).float()
44 | tensor.mul_(rgb_range / 255)
45 |
46 | return tensor
47 |
48 | return [_np2Tensor(a) for a in args]
49 |
50 |
51 | def augment(*args, hflip=True, rot=True):
52 | hflip = hflip and random.random() < 0.5
53 | vflip = rot and random.random() < 0.5
54 | rot90 = rot and random.random() < 0.5
55 |
56 | def _augment(img):
57 | if hflip: img = img[:, ::-1, :]
58 | if vflip: img = img[::-1, :, :]
59 | if rot90: img = img.transpose(1, 0, 2)
60 |
61 | return img
62 |
63 | return [_augment(a) for a in args]
64 |
65 |
--------------------------------------------------------------------------------
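
A sketch showing how `get_patch` keeps the LR and HR crops aligned (synthetic arrays; the sizes are arbitrary):

```python
# Sketch: paired LR/HR patch cropping with common.get_patch.
import numpy as np
from data import common

lr = np.zeros((50, 60, 3), dtype=np.uint8)    # H x W x C low-resolution image
hr = np.zeros((100, 120, 3), dtype=np.uint8)  # its 2x high-resolution pair
lr_patch, hr_patch = common.get_patch(lr, hr, patch_size=96, scale=2)
print(lr_patch.shape, hr_patch.shape)         # (48, 48, 3) (96, 96, 3)
```

The LR crop origin is drawn at random; the HR origin is simply the LR origin scaled by `scale`, so the two patches always cover the same image region.
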
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that it also loads images from the current
5 | # directory as well as the subdirectories
6 | ###############################################################################
7 |
8 | import torch.utils.data as data
9 |
10 | from PIL import Image
11 | import os
12 | import os.path
13 |
14 | def default_flist_reader(flist):
15 | imlist = []
16 | with open(flist, 'r') as rf:
17 | for line in rf.readlines():
18 | impath = line.strip()
19 | imlist.append(impath)
20 |
21 | return imlist
22 |
23 |
24 | IMG_EXTENSIONS = [
25 | '.jpg', '.JPG', '.jpeg', '.JPEG',
26 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.npy'
27 | ]
28 |
29 |
30 | def is_image_file(filename):
31 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
32 |
33 |
34 | def make_dataset(dir):
35 | images = []
36 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
37 |
38 | for root, _, fnames in sorted(os.walk(dir)):
39 | for fname in fnames:
40 | if is_image_file(fname):
41 | path = os.path.join(root, fname)
42 | images.append(path)
43 |
44 | return images
45 |
46 |
47 | def default_loader(path):
48 | return Image.open(path).convert('RGB')
49 |
50 |
51 | class ImageFolder(data.Dataset):
52 |
53 | def __init__(self, root, transform=None, return_paths=False,
54 | loader=default_loader):
55 | imgs = make_dataset(root)
56 | if len(imgs) == 0:
57 | raise(RuntimeError("Found 0 images in: " + root + "\n"
58 | "Supported image extensions are: " +
59 | ",".join(IMG_EXTENSIONS)))
60 |
61 | self.root = root
62 | self.imgs = imgs
63 | self.transform = transform
64 | self.return_paths = return_paths
65 | self.loader = loader
66 |
67 | def __getitem__(self, index):
68 | path = self.imgs[index]
69 | img = self.loader(path)
70 | if self.transform is not None:
71 | img = self.transform(img)
72 | if self.return_paths:
73 | return img, path
74 | else:
75 | return img
76 |
77 | def __len__(self):
78 | return len(self.imgs)
--------------------------------------------------------------------------------
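
A usage sketch for `ImageFolder` (the transform and path are illustrative):

```python
# Sketch: loading every image under a directory tree with ImageFolder.
from torchvision.transforms import ToTensor
from data.image_folder import ImageFolder

dataset = ImageFolder('Test_Datasets/Set5', transform=ToTensor(),
                      return_paths=True)
img, path = dataset[0]
print(path, img.shape)  # e.g. a 3 x H x W float tensor in [0, 1]
```
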
/images/Pressure_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/Pressure_test.png
--------------------------------------------------------------------------------
/images/acmmm19_poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/acmmm19_poster.pdf
--------------------------------------------------------------------------------
/images/adaptive_cropping.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/adaptive_cropping.png
--------------------------------------------------------------------------------
/images/imdb_plus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/imdb_plus.png
--------------------------------------------------------------------------------
/images/imdn_rtc.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/imdn_rtc.jpg
--------------------------------------------------------------------------------
/images/lenna.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/lenna.png
--------------------------------------------------------------------------------
/images/memory.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/memory.png
--------------------------------------------------------------------------------
/images/parameters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/parameters.png
--------------------------------------------------------------------------------
/images/psnr_ssim.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/psnr_ssim.png
--------------------------------------------------------------------------------
/images/reparam.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/reparam.png
--------------------------------------------------------------------------------
/images/time.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/time.png
--------------------------------------------------------------------------------
/model/__pycache__/architecture.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/model/__pycache__/architecture.cpython-35.pyc
--------------------------------------------------------------------------------
/model/__pycache__/block.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/model/__pycache__/block.cpython-35.pyc
--------------------------------------------------------------------------------
/model/architecture.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from . import block as B
3 | import torch
4 |
5 | # For any upscale factors
6 | class IMDN_AS(nn.Module):
7 | def __init__(self, in_nc=3, nf=64, num_modules=6, out_nc=3, upscale=4):
8 | super(IMDN_AS, self).__init__()
9 |
10 | self.fea_conv = nn.Sequential(B.conv_layer(in_nc, nf, kernel_size=3, stride=2),
11 | nn.LeakyReLU(0.05),
12 | B.conv_layer(nf, nf, kernel_size=3, stride=2))
13 |
14 | # IMDBs
15 | self.IMDB1 = B.IMDModule(in_channels=nf)
16 | self.IMDB2 = B.IMDModule(in_channels=nf)
17 | self.IMDB3 = B.IMDModule(in_channels=nf)
18 | self.IMDB4 = B.IMDModule(in_channels=nf)
19 | self.IMDB5 = B.IMDModule(in_channels=nf)
20 | self.IMDB6 = B.IMDModule(in_channels=nf)
21 | self.c = B.conv_block(nf * num_modules, nf, kernel_size=1, act_type='lrelu')
22 |
23 | self.LR_conv = B.conv_layer(nf, nf, kernel_size=3)
24 |
25 | upsample_block = B.pixelshuffle_block
26 | self.upsampler = upsample_block(nf, out_nc, upscale_factor=upscale)
27 |
28 |
29 | def forward(self, input):
30 | out_fea = self.fea_conv(input)
31 | out_B1 = self.IMDB1(out_fea)
32 | out_B2 = self.IMDB2(out_B1)
33 | out_B3 = self.IMDB3(out_B2)
34 | out_B4 = self.IMDB4(out_B3)
35 | out_B5 = self.IMDB5(out_B4)
36 | out_B6 = self.IMDB6(out_B5)
37 |
38 | out_B = self.c(torch.cat([out_B1, out_B2, out_B3, out_B4, out_B5, out_B6], dim=1))
39 | out_lr = self.LR_conv(out_B) + out_fea
40 | output = self.upsampler(out_lr)
41 | return output
42 |
43 | class IMDN(nn.Module):
44 | def __init__(self, in_nc=3, nf=64, num_modules=6, out_nc=3, upscale=4):
45 | super(IMDN, self).__init__()
46 |
47 | self.fea_conv = B.conv_layer(in_nc, nf, kernel_size=3)
48 |
49 | # IMDBs
50 | self.IMDB1 = B.IMDModule(in_channels=nf)
51 | self.IMDB2 = B.IMDModule(in_channels=nf)
52 | self.IMDB3 = B.IMDModule(in_channels=nf)
53 | self.IMDB4 = B.IMDModule(in_channels=nf)
54 | self.IMDB5 = B.IMDModule(in_channels=nf)
55 | self.IMDB6 = B.IMDModule(in_channels=nf)
56 | self.c = B.conv_block(nf * num_modules, nf, kernel_size=1, act_type='lrelu')
57 |
58 | self.LR_conv = B.conv_layer(nf, nf, kernel_size=3)
59 |
60 | upsample_block = B.pixelshuffle_block
61 | self.upsampler = upsample_block(nf, out_nc, upscale_factor=upscale)
62 |
63 |
64 | def forward(self, input):
65 | out_fea = self.fea_conv(input)
66 | out_B1 = self.IMDB1(out_fea)
67 | out_B2 = self.IMDB2(out_B1)
68 | out_B3 = self.IMDB3(out_B2)
69 | out_B4 = self.IMDB4(out_B3)
70 | out_B5 = self.IMDB5(out_B4)
71 | out_B6 = self.IMDB6(out_B5)
72 |
73 | out_B = self.c(torch.cat([out_B1, out_B2, out_B3, out_B4, out_B5, out_B6], dim=1))
74 | out_lr = self.LR_conv(out_B) + out_fea
75 | output = self.upsampler(out_lr)
76 | return output
77 |
78 | # AI in RTC Image Super-Resolution Algorithm Performance Comparison Challenge (Winner solution)
79 | class IMDN_RTC(nn.Module):
80 | def __init__(self, in_nc=3, nf=12, num_modules=5, out_nc=3, upscale=2):
81 | super(IMDN_RTC, self).__init__()
82 |
83 | fea_conv = [B.conv_layer(in_nc, nf, kernel_size=3)]
84 | rb_blocks = [B.IMDModule_speed(in_channels=nf) for _ in range(num_modules)]
85 | LR_conv = B.conv_layer(nf, nf, kernel_size=1)
86 |
87 | upsample_block = B.pixelshuffle_block
88 | upsampler = upsample_block(nf, out_nc, upscale_factor=upscale)
89 |
90 | self.model = B.sequential(*fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),
91 | *upsampler)
92 |
93 | def forward(self, input):
94 | output = self.model(input)
95 | return output
96 |
97 |
98 | class IMDN_RTE(nn.Module):
99 | def __init__(self, upscale=2, in_nc=3, nf=20, out_nc=3):
100 | super(IMDN_RTE, self).__init__()
101 | self.upscale = upscale
102 | self.fea_conv = nn.Sequential(B.conv_layer(in_nc, nf, 3),
103 | nn.ReLU(inplace=True),
104 | B.conv_layer(nf, nf, 3, stride=2, bias=False))
105 |
106 |         self.block1 = B.IMDModule_Large(nf)
107 |         self.block2 = B.IMDModule_Large(nf)
108 |         self.block3 = B.IMDModule_Large(nf)
109 |         self.block4 = B.IMDModule_Large(nf)
110 |         self.block5 = B.IMDModule_Large(nf)
111 |         self.block6 = B.IMDModule_Large(nf)
112 |
113 | self.LR_conv = B.conv_layer(nf, nf, 1, bias=False)
114 |
115 | self.upsampler = B.pixelshuffle_block(nf, out_nc, upscale_factor=upscale**2)
116 |
117 | def forward(self, input):
118 |
119 | fea = self.fea_conv(input)
120 | out_b1 = self.block1(fea)
121 | out_b2 = self.block2(out_b1)
122 | out_b3 = self.block3(out_b2)
123 | out_b4 = self.block4(out_b3)
124 | out_b5 = self.block5(out_b4)
125 | out_b6 = self.block6(out_b5)
126 |
127 | out_lr = self.LR_conv(out_b6) + fea
128 |
129 | output = self.upsampler(out_lr)
130 |
131 | return output
132 |
133 |
--------------------------------------------------------------------------------
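
A quick way to sanity-check the two main networks above is to push a dummy tensor through them and inspect the output shapes. The following is a minimal illustrative sketch (not repository code), assuming the repository root is on PYTHONPATH so that model.architecture imports cleanly:

    import torch
    from model import architecture

    lr = torch.randn(1, 3, 48, 48)  # dummy RGB low-resolution patch

    # IMDN upsamples by `upscale`: 48x48 -> 96x96 for x2
    net = architecture.IMDN(upscale=2)
    with torch.no_grad():
        print(net(lr).shape)  # torch.Size([1, 3, 96, 96])

    # IMDN_AS expects its input at the target resolution; two stride-2 convs
    # downsample x4 and the pixel-shuffle tail upsamples x4, so the output
    # matches the input size (input sides must be divisible by 4)
    net_as = architecture.IMDN_AS()
    with torch.no_grad():
        print(net_as(lr).shape)  # torch.Size([1, 3, 48, 48])
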
/model/block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from collections import OrderedDict
3 | import torch
4 |
5 |
6 | def conv_layer(in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
7 | padding = int((kernel_size - 1) / 2) * dilation
8 | return nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias, dilation=dilation,
9 | groups=groups)
10 |
11 |
12 | def norm(norm_type, nc):
13 | norm_type = norm_type.lower()
14 | if norm_type == 'batch':
15 | layer = nn.BatchNorm2d(nc, affine=True)
16 | elif norm_type == 'instance':
17 | layer = nn.InstanceNorm2d(nc, affine=False)
18 | else:
19 | raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
20 | return layer
21 |
22 |
23 | def pad(pad_type, padding):
24 | pad_type = pad_type.lower()
25 | if padding == 0:
26 | return None
27 | if pad_type == 'reflect':
28 | layer = nn.ReflectionPad2d(padding)
29 | elif pad_type == 'replicate':
30 | layer = nn.ReplicationPad2d(padding)
31 | else:
32 | raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
33 | return layer
34 |
35 |
36 | def get_valid_padding(kernel_size, dilation):
37 | kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
38 | padding = (kernel_size - 1) // 2
39 | return padding
40 |
41 |
42 | def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
43 | pad_type='zero', norm_type=None, act_type='relu'):
44 | padding = get_valid_padding(kernel_size, dilation)
45 | p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
46 | padding = padding if pad_type == 'zero' else 0
47 |
48 | c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
49 | dilation=dilation, bias=bias, groups=groups)
50 | a = activation(act_type) if act_type else None
51 | n = norm(norm_type, out_nc) if norm_type else None
52 | return sequential(p, c, n, a)
53 |
54 |
55 | def activation(act_type, inplace=True, neg_slope=0.05, n_prelu=1):
56 | act_type = act_type.lower()
57 | if act_type == 'relu':
58 | layer = nn.ReLU(inplace)
59 | elif act_type == 'lrelu':
60 | layer = nn.LeakyReLU(neg_slope, inplace)
61 | elif act_type == 'prelu':
62 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
63 | else:
64 | raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
65 | return layer
66 |
67 |
68 | class ShortcutBlock(nn.Module):
69 | def __init__(self, submodule):
70 | super(ShortcutBlock, self).__init__()
71 | self.sub = submodule
72 |
73 | def forward(self, x):
74 | output = x + self.sub(x)
75 | return output
76 |
77 | def mean_channels(F):
78 | assert(F.dim() == 4)
79 | spatial_sum = F.sum(3, keepdim=True).sum(2, keepdim=True)
80 | return spatial_sum / (F.size(2) * F.size(3))
81 |
82 | def stdv_channels(F):
83 | assert(F.dim() == 4)
84 | F_mean = mean_channels(F)
85 | F_variance = (F - F_mean).pow(2).sum(3, keepdim=True).sum(2, keepdim=True) / (F.size(2) * F.size(3))
86 | return F_variance.pow(0.5)
87 |
88 | def sequential(*args):
89 | if len(args) == 1:
90 | if isinstance(args[0], OrderedDict):
91 | raise NotImplementedError('sequential does not support OrderedDict input.')
92 | return args[0]
93 | modules = []
94 | for module in args:
95 | if isinstance(module, nn.Sequential):
96 | for submodule in module.children():
97 | modules.append(submodule)
98 | elif isinstance(module, nn.Module):
99 | modules.append(module)
100 | return nn.Sequential(*modules)
101 |
102 | # contrast-aware channel attention module
103 | class CCALayer(nn.Module):
104 | def __init__(self, channel, reduction=16):
105 | super(CCALayer, self).__init__()
106 |
107 | self.contrast = stdv_channels
108 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
109 | self.conv_du = nn.Sequential(
110 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
111 | nn.ReLU(inplace=True),
112 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
113 | nn.Sigmoid()
114 | )
115 |
116 |
117 | def forward(self, x):
118 | y = self.contrast(x) + self.avg_pool(x)
119 | y = self.conv_du(y)
120 | return x * y
121 |
122 |
123 | class IMDModule(nn.Module):
124 | def __init__(self, in_channels, distillation_rate=0.25):
125 | super(IMDModule, self).__init__()
126 | self.distilled_channels = int(in_channels * distillation_rate)
127 | self.remaining_channels = int(in_channels - self.distilled_channels)
128 | self.c1 = conv_layer(in_channels, in_channels, 3)
129 | self.c2 = conv_layer(self.remaining_channels, in_channels, 3)
130 | self.c3 = conv_layer(self.remaining_channels, in_channels, 3)
131 | self.c4 = conv_layer(self.remaining_channels, self.distilled_channels, 3)
132 | self.act = activation('lrelu', neg_slope=0.05)
133 | self.c5 = conv_layer(in_channels, in_channels, 1)
134 | self.cca = CCALayer(self.distilled_channels * 4)
135 |
136 | def forward(self, input):
137 | out_c1 = self.act(self.c1(input))
138 | distilled_c1, remaining_c1 = torch.split(out_c1, (self.distilled_channels, self.remaining_channels), dim=1)
139 | out_c2 = self.act(self.c2(remaining_c1))
140 | distilled_c2, remaining_c2 = torch.split(out_c2, (self.distilled_channels, self.remaining_channels), dim=1)
141 | out_c3 = self.act(self.c3(remaining_c2))
142 | distilled_c3, remaining_c3 = torch.split(out_c3, (self.distilled_channels, self.remaining_channels), dim=1)
143 | out_c4 = self.c4(remaining_c3)
144 | out = torch.cat([distilled_c1, distilled_c2, distilled_c3, out_c4], dim=1)
145 | out_fused = self.c5(self.cca(out)) + input
146 | return out_fused
147 |
148 | class IMDModule_speed(nn.Module):
149 | def __init__(self, in_channels, distillation_rate=0.25):
150 | super(IMDModule_speed, self).__init__()
151 | self.distilled_channels = int(in_channels * distillation_rate)
152 | self.remaining_channels = int(in_channels - self.distilled_channels)
153 | self.c1 = conv_layer(in_channels, in_channels, 3)
154 | self.c2 = conv_layer(self.remaining_channels, in_channels, 3)
155 | self.c3 = conv_layer(self.remaining_channels, in_channels, 3)
156 | self.c4 = conv_layer(self.remaining_channels, self.distilled_channels, 3)
157 | self.act = activation('lrelu', neg_slope=0.05)
158 | self.c5 = conv_layer(self.distilled_channels * 4, in_channels, 1)
159 |
160 | def forward(self, input):
161 | out_c1 = self.act(self.c1(input))
162 | distilled_c1, remaining_c1 = torch.split(out_c1, (self.distilled_channels, self.remaining_channels), dim=1)
163 | out_c2 = self.act(self.c2(remaining_c1))
164 | distilled_c2, remaining_c2 = torch.split(out_c2, (self.distilled_channels, self.remaining_channels), dim=1)
165 | out_c3 = self.act(self.c3(remaining_c2))
166 | distilled_c3, remaining_c3 = torch.split(out_c3, (self.distilled_channels, self.remaining_channels), dim=1)
167 | out_c4 = self.c4(remaining_c3)
168 |
169 | out = torch.cat([distilled_c1, distilled_c2, distilled_c3, out_c4], dim=1)
170 | out_fused = self.c5(out) + input
171 | return out_fused
172 |
173 | class IMDModule_Large(nn.Module):
174 | def __init__(self, in_channels, distillation_rate=1/4):
175 | super(IMDModule_Large, self).__init__()
176 |         self.distilled_channels = int(in_channels * distillation_rate)  # 5 (channel counts in comments assume in_channels=20, as in IMDN_RTE)
177 |         self.remaining_channels = int(in_channels - self.distilled_channels)  # 15
178 |         self.c1 = conv_layer(in_channels, in_channels, 3, bias=False)  # 20 --> 20
179 |         self.c2 = conv_layer(self.remaining_channels, in_channels, 3, bias=False)  # 15 --> 20
180 |         self.c3 = conv_layer(self.remaining_channels, in_channels, 3, bias=False)  # 15 --> 20
181 | self.c4 = conv_layer(self.remaining_channels, self.remaining_channels, 3, bias=False) # 15 --> 15
182 | self.c5 = conv_layer(self.remaining_channels-self.distilled_channels, self.remaining_channels-self.distilled_channels, 3, bias=False) # 10 --> 10
183 | self.c6 = conv_layer(self.distilled_channels, self.distilled_channels, 3, bias=False) # 5 --> 5
184 | self.act = activation('relu')
185 | self.c7 = conv_layer(self.distilled_channels * 6, in_channels, 1, bias=False)
186 |
187 | def forward(self, input):
188 |         out_c1 = self.act(self.c1(input))  # 20 --> 20
189 |         distilled_c1, remaining_c1 = torch.split(out_c1, (self.distilled_channels, self.remaining_channels), dim=1)  # 5, 15
190 |         out_c2 = self.act(self.c2(remaining_c1))  # 15 --> 20
191 |         distilled_c2, remaining_c2 = torch.split(out_c2, (self.distilled_channels, self.remaining_channels), dim=1)  # 5, 15
192 |         out_c3 = self.act(self.c3(remaining_c2))  # 15 --> 20
193 |         distilled_c3, remaining_c3 = torch.split(out_c3, (self.distilled_channels, self.remaining_channels), dim=1)  # 5, 15
194 |         out_c4 = self.act(self.c4(remaining_c3))  # 15 --> 15
195 |         distilled_c4, remaining_c4 = torch.split(out_c4, (self.distilled_channels, self.remaining_channels-self.distilled_channels), dim=1)  # 5, 10
196 |         out_c5 = self.act(self.c5(remaining_c4))  # 10 --> 10
197 |         distilled_c5, remaining_c5 = torch.split(out_c5, (self.distilled_channels, self.remaining_channels-self.distilled_channels*2), dim=1)  # 5, 5
198 |         out_c6 = self.act(self.c6(remaining_c5))  # 5 --> 5
199 |
200 | out = torch.cat([distilled_c1, distilled_c2, distilled_c3, distilled_c4, distilled_c5, out_c6], dim=1)
201 | out_fused = self.c7(out) + input
202 | return out_fused
203 |
204 | def pixelshuffle_block(in_channels, out_channels, upscale_factor=2, kernel_size=3, stride=1):
205 | conv = conv_layer(in_channels, out_channels * (upscale_factor ** 2), kernel_size, stride)
206 | pixel_shuffle = nn.PixelShuffle(upscale_factor)
207 | return sequential(conv, pixel_shuffle)
208 |
--------------------------------------------------------------------------------
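
The information multi-distillation arithmetic in IMDModule is easiest to follow with concrete numbers. A minimal illustrative sketch (not repository code), using the nf=64 configuration that IMDN instantiates:

    import torch
    from model import block as B

    m = B.IMDModule(in_channels=64)
    print(m.distilled_channels, m.remaining_channels)  # 16 48

    # Each of the first three stages keeps a 16-channel "distilled" slice and
    # refines the 48-channel remainder; the fourth conv emits a final
    # 16-channel slice, so the concatenation restores 16 * 4 = 64 channels
    # before the CCA layer, the 1x1 fusion conv, and the residual add.
    x = torch.randn(1, 64, 32, 32)
    print(m(x).shape)  # torch.Size([1, 64, 32, 32])
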
/scripts/png2npy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import skimage.io as sio
4 | import numpy as np
5 |
6 | parser = argparse.ArgumentParser(description='Pre-processing .png images')
7 | parser.add_argument('--pathFrom', default='',
8 | help='directory of images to convert')
9 | parser.add_argument('--pathTo', default='',
10 |                     help='directory to save the converted .npy files')
11 | parser.add_argument('--split', default=True,
12 | help='save individual images')
13 | parser.add_argument('--select', default='',
14 |                     help='only convert files whose path contains this substring')
15 |
16 | args = parser.parse_args()
17 |
18 | for (path, dirs, files) in os.walk(args.pathFrom):
19 | print(path)
20 | targetDir = os.path.join(args.pathTo, path[len(args.pathFrom) + 1:])
21 | if len(args.select) > 0 and path.find(args.select) == -1:
22 | continue
23 |
24 | if not os.path.exists(targetDir):
25 |         os.makedirs(targetDir)  # makedirs, so nested targets still work when parent dirs were skipped via --select
26 |
27 | if len(dirs) == 0:
28 | pack = {}
29 | n = 0
30 | for fileName in files:
31 | (idx, ext) = os.path.splitext(fileName)
32 | if ext == '.png':
33 | image = sio.imread(os.path.join(path, fileName))
34 | if args.split:
35 | np.save(os.path.join(targetDir, idx + '.npy'), image)
36 | n += 1
37 | if n % 100 == 0:
38 | print('Converted ' + str(n) + ' images.')
39 |
--------------------------------------------------------------------------------
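
The script walks --pathFrom, mirrors its directory layout under --pathTo, and saves each .png as a .npy array so the training loader can skip PNG decoding. A hypothetical invocation (paths are placeholders):

    python scripts/png2npy.py --pathFrom /path/to/DIV2K --pathTo /path/to/DIV2K_decoded
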
/test_IMDN.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import numpy as np
5 | import utils
6 | import skimage.color as sc
7 | import cv2
8 | from model import architecture
9 | # Testing settings
10 |
11 | parser = argparse.ArgumentParser(description='IMDN')
12 | parser.add_argument("--test_hr_folder", type=str, default='Test_Datasets/Set5/',
13 | help='the folder of the target images')
14 | parser.add_argument("--test_lr_folder", type=str, default='Test_Datasets/Set5_LR/x2/',
15 | help='the folder of the input images')
16 | parser.add_argument("--output_folder", type=str, default='results/Set5/x2')
17 | parser.add_argument("--checkpoint", type=str, default='checkpoints/IMDN_x2.pth',
18 |                     help='checkpoint file to use')
19 | parser.add_argument('--cuda', action='store_true', default=True,
20 | help='use cuda')
21 | parser.add_argument("--upscale_factor", type=int, default=2,
22 | help='upscaling factor')
23 | parser.add_argument("--is_y", action='store_true', default=True,
24 |                     help='evaluate on the Y channel; if False, evaluate on RGB channels')
25 | opt = parser.parse_args()
26 |
27 | print(opt)
28 |
29 | cuda = opt.cuda
30 | device = torch.device('cuda' if cuda else 'cpu')
31 |
32 | filepath = opt.test_hr_folder
33 | if filepath.split('/')[-2] == 'Set5' or filepath.split('/')[-2] == 'Set14':
34 | ext = '.bmp'
35 | else:
36 | ext = '.png'
37 |
38 | filelist = utils.get_list(filepath, ext=ext)
39 | psnr_list = np.zeros(len(filelist))
40 | ssim_list = np.zeros(len(filelist))
41 | time_list = np.zeros(len(filelist))
42 |
43 | model = architecture.IMDN(upscale=opt.upscale_factor)
44 | model_dict = utils.load_state_dict(opt.checkpoint)
45 | model.load_state_dict(model_dict, strict=True)
46 |
47 | i = 0
48 | start = torch.cuda.Event(enable_timing=True)
49 | end = torch.cuda.Event(enable_timing=True)
50 |
51 | for imname in filelist:
52 | im_gt = cv2.imread(imname, cv2.IMREAD_COLOR)[:, :, [2, 1, 0]] # BGR to RGB
53 | im_gt = utils.modcrop(im_gt, opt.upscale_factor)
54 | im_l = cv2.imread(opt.test_lr_folder + imname.split('/')[-1].split('.')[0] + 'x' + str(opt.upscale_factor) + ext, cv2.IMREAD_COLOR)[:, :, [2, 1, 0]] # BGR to RGB
55 | if len(im_gt.shape) < 3:
56 | im_gt = im_gt[..., np.newaxis]
57 | im_gt = np.concatenate([im_gt] * 3, 2)
58 | im_l = im_l[..., np.newaxis]
59 | im_l = np.concatenate([im_l] * 3, 2)
60 | im_input = im_l / 255.0
61 | im_input = np.transpose(im_input, (2, 0, 1))
62 | im_input = im_input[np.newaxis, ...]
63 | im_input = torch.from_numpy(im_input).float()
64 |
65 | if cuda:
66 | model = model.to(device)
67 | im_input = im_input.to(device)
68 |
69 | with torch.no_grad():
70 | start.record()
71 | out = model(im_input)
72 | end.record()
73 | torch.cuda.synchronize()
74 | time_list[i] = start.elapsed_time(end) # milliseconds
75 |
76 | out_img = utils.tensor2np(out.detach()[0])
77 | crop_size = opt.upscale_factor
78 | cropped_sr_img = utils.shave(out_img, crop_size)
79 | cropped_gt_img = utils.shave(im_gt, crop_size)
80 | if opt.is_y is True:
81 | im_label = utils.quantize(sc.rgb2ycbcr(cropped_gt_img)[:, :, 0])
82 | im_pre = utils.quantize(sc.rgb2ycbcr(cropped_sr_img)[:, :, 0])
83 | else:
84 | im_label = cropped_gt_img
85 | im_pre = cropped_sr_img
86 | psnr_list[i] = utils.compute_psnr(im_pre, im_label)
87 | ssim_list[i] = utils.compute_ssim(im_pre, im_label)
88 |
89 |
90 | output_folder = os.path.join(opt.output_folder,
91 | imname.split('/')[-1].split('.')[0] + 'x' + str(opt.upscale_factor) + '.png')
92 |
93 | if not os.path.exists(opt.output_folder):
94 | os.makedirs(opt.output_folder)
95 |
96 | cv2.imwrite(output_folder, out_img[:, :, [2, 1, 0]])
97 | i += 1
98 |
99 |
100 | print("Mean PSNR: {}, SSIM: {}, TIME: {} ms".format(np.mean(psnr_list), np.mean(ssim_list), np.mean(time_list)))
101 |
--------------------------------------------------------------------------------
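
test_IMDN.py measures inference time with CUDA events rather than a wall-clock timer, since GPU kernels launch asynchronously and a Python-side timer would stop before the kernels finish. The same pattern in isolation, as a minimal illustrative sketch:

    import torch

    if torch.cuda.is_available():
        conv = torch.nn.Conv2d(3, 3, 3, padding=1).cuda()
        x = torch.randn(1, 3, 256, 256, device='cuda')
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()            # recorded on the CUDA stream
        y = conv(x)
        end.record()
        torch.cuda.synchronize()  # elapsed_time is valid only after a sync
        print('{:.3f} ms'.format(start.elapsed_time(end)))
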
/test_IMDN_AS.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import numpy as np
5 | import utils
6 | import skimage.color as sc
7 | import cv2
8 | from model import architecture
9 |
10 | parser = argparse.ArgumentParser(description='IMDN_AS')
11 | parser.add_argument("--test_hr_folder", type=str, default='Test_Datasets/RealSR/ValidationGT',
12 | help='the folder of the target images')
13 | parser.add_argument("--test_lr_folder", type=str, default='Test_Datasets/RealSR/ValidationLR/',
14 | help='the folder of the input images')
15 | parser.add_argument("--output_folder", type=str, default='results/RealSR')
16 | parser.add_argument("--checkpoint", type=str, default='checkpoints/IMDN_AS.pth',
17 |                     help='checkpoint file to use')
18 | parser.add_argument('--cuda', action='store_true', default=True,
19 | help='use cuda')
20 | parser.add_argument("--is_y", action='store_true', default=False,
21 |                     help='evaluate on the Y channel; if False, evaluate on RGB channels')
22 | opt = parser.parse_args()
23 |
24 | print(opt)
25 |
26 | cuda = opt.cuda
27 | device = torch.device('cuda' if cuda else 'cpu')
28 |
29 |
30 | def crop_forward(x, model, shave=32):
31 | b, c, h, w = x.size()
32 | h_half, w_half = h // 2, w // 2
33 |
34 | h_size, w_size = h_half + shave - (h_half + shave) % 4, w_half + shave - (w_half + shave) % 4
35 |
36 | inputlist = [
37 | x[:, :, 0:h_size, 0:w_size],
38 | x[:, :, 0:h_size, (w - w_size):w],
39 | x[:, :, (h - h_size):h, 0:w_size],
40 | x[:, :, (h - h_size):h, (w - w_size):w]]
41 |
42 | outputlist = []
43 |
44 | with torch.no_grad():
45 | input_batch = torch.cat(inputlist, dim=0)
46 | output_batch = model(input_batch)
47 | outputlist.extend(output_batch.chunk(4, dim=0))
48 |
49 | output = torch.zeros_like(x)
50 |
51 | output[:, :, 0:h_half, 0:w_half] \
52 | = outputlist[0][:, :, 0:h_half, 0:w_half]
53 | output[:, :, 0:h_half, w_half:w] \
54 | = outputlist[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
55 | output[:, :, h_half:h, 0:w_half] \
56 | = outputlist[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
57 | output[:, :, h_half:h, w_half:w] \
58 | = outputlist[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
59 |
60 | return output
61 |
62 |
63 | filepath = opt.test_hr_folder
64 | if filepath.split('/')[-2] == 'Set5' or filepath.split('/')[-2] == 'Set14':
65 | ext = '.bmp'
66 | else:
67 | ext = '.png'
68 |
69 | filelist = utils.get_list(filepath, ext=ext)
70 | psnr_list = np.zeros(len(filelist))
71 | ssim_list = np.zeros(len(filelist))
72 | time_list = np.zeros(len(filelist))
73 |
74 | model = architecture.IMDN_AS()
75 | model_dict = utils.load_state_dict(opt.checkpoint)
76 | model.load_state_dict(model_dict, strict=True)
77 |
78 | start = torch.cuda.Event(enable_timing=True)
79 | end = torch.cuda.Event(enable_timing=True)
80 | i = 0
81 | for imname in filelist:
82 | im_gt = cv2.imread(imname)[:, :, [2, 1, 0]]
83 | im_l = cv2.imread(opt.test_lr_folder + imname.split('/')[-1])[:, :, [2, 1, 0]]
84 | if len(im_gt.shape) < 3:
85 | im_gt = im_gt[..., np.newaxis]
86 | im_gt = np.concatenate([im_gt] * 3, 2)
87 | im_l = im_l[..., np.newaxis]
88 | im_l = np.concatenate([im_l] * 3, 2)
89 | im_input = im_l / 255.0
90 | im_input = np.transpose(im_input, (2, 0, 1))
91 | im_input = im_input[np.newaxis, ...]
92 | im_input = torch.from_numpy(im_input).float()
93 |
94 | if cuda:
95 | model = model.to(device)
96 | im_input = im_input.to(device)
97 |
98 | _, _, h, w = im_input.size()
99 | with torch.no_grad():
100 |
101 | if h % 4 == 0 and w % 4 == 0:
102 | start.record()
103 | out = model(im_input)
104 | end.record()
105 | torch.cuda.synchronize()
106 | time_list[i] = start.elapsed_time(end) # milliseconds
107 | else:
108 | start.record()
109 | out = crop_forward(im_input, model)
110 | end.record()
111 | torch.cuda.synchronize()
112 | time_list[i] = start.elapsed_time(end) # milliseconds
113 |
114 | sr_img = utils.tensor2np(out.detach()[0])
115 | if opt.is_y is True:
116 | im_label = utils.quantize(sc.rgb2ycbcr(im_gt)[:, :, 0])
117 | im_pre = utils.quantize(sc.rgb2ycbcr(sr_img)[:, :, 0])
118 | else:
119 | im_label = im_gt
120 | im_pre = sr_img
121 | psnr_list[i] = utils.compute_psnr(im_pre, im_label)
122 | ssim_list[i] = utils.compute_ssim(im_pre, im_label)
123 |
124 | output_folder = os.path.join(opt.output_folder,
125 | imname.split('/')[-1])
126 |
127 | if not os.path.exists(opt.output_folder):
128 | os.makedirs(opt.output_folder)
129 |
130 | cv2.imwrite(output_folder, sr_img[:, :, [2, 1, 0]])
131 | i += 1
132 |
133 | print("Mean PSNR: {}, SSIM: {}, Time: {} ms".format(np.mean(psnr_list), np.mean(ssim_list), np.mean(time_list)))
134 |
--------------------------------------------------------------------------------
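
crop_forward handles inputs whose sides are not multiples of 4 by running four overlapping quadrants as a single batch and stitching the kept corners back together; shave=32 gives each quadrant enough context that seams fall outside the regions that are copied into the output. The size arithmetic, worked through for a hypothetical 333x501 input:

    h, w, shave = 333, 501, 32
    h_half, w_half = h // 2, w // 2                 # 166, 250
    h_size = h_half + shave - (h_half + shave) % 4  # 196, a multiple of 4
    w_size = w_half + shave - (w_half + shave) % 4  # 280, a multiple of 4
    # Each quadrant is h_size x w_size; only the corner that maps to its
    # quarter of the output is copied back, so the four pieces tile h x w.
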
/train_IMDN.py:
--------------------------------------------------------------------------------
1 | import argparse, os
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from torch.utils.data import DataLoader
6 | from model import architecture
7 | from data import DIV2K, Set5_val
8 | import utils
9 | import skimage.color as sc
10 | import random
11 | from collections import OrderedDict
12 | # os.environ["CUDA_VISIBLE_DEVICES"] = '0'
13 |
14 | # Training settings
15 | parser = argparse.ArgumentParser(description="IMDN")
16 | parser.add_argument("--batch_size", type=int, default=16,
17 | help="training batch size")
18 | parser.add_argument("--testBatchSize", type=int, default=1,
19 | help="testing batch size")
20 | parser.add_argument("-nEpochs", type=int, default=1000,
21 | help="number of epochs to train")
22 | parser.add_argument("--lr", type=float, default=2e-4,
23 | help="Learning Rate. Default=2e-4")
24 | parser.add_argument("--step_size", type=int, default=200,
25 | help="learning rate decay per N epochs")
26 | parser.add_argument("--gamma", type=int, default=0.5,
27 | help="learning rate decay factor for step decay")
28 | parser.add_argument("--cuda", action="store_true", default=True,
29 | help="use cuda")
30 | parser.add_argument("--resume", default="", type=str,
31 | help="path to checkpoint")
32 | parser.add_argument("--start-epoch", default=1, type=int,
33 | help="manual epoch number")
34 | parser.add_argument("--threads", type=int, default=8,
35 | help="number of threads for data loading")
36 | parser.add_argument("--root", type=str, default="training_data/",
37 | help='dataset directory')
38 | parser.add_argument("--n_train", type=int, default=800,
39 | help="number of training set")
40 | parser.add_argument("--n_val", type=int, default=1,
41 | help="number of validation set")
42 | parser.add_argument("--test_every", type=int, default=1000)
43 | parser.add_argument("--scale", type=int, default=2,
44 | help="super-resolution scale")
45 | parser.add_argument("--patch_size", type=int, default=192,
46 | help="output patch size")
47 | parser.add_argument("--rgb_range", type=int, default=1,
48 | help="maxium value of RGB")
49 | parser.add_argument("--n_colors", type=int, default=3,
50 | help="number of color channels to use")
51 | parser.add_argument("--pretrained", default="", type=str,
52 | help="path to pretrained models")
53 | parser.add_argument("--seed", type=int, default=1)
54 | parser.add_argument("--isY", action="store_true", default=True)
55 | parser.add_argument("--ext", type=str, default='.npy')
56 | parser.add_argument("--phase", type=str, default='train')
57 |
58 | args = parser.parse_args()
59 | print(args)
60 | torch.backends.cudnn.benchmark = True
61 | # random seed
62 | seed = args.seed
63 | if seed is None:
64 | seed = random.randint(1, 10000)
65 | print("Ramdom Seed: ", seed)
66 | random.seed(seed)
67 | torch.manual_seed(seed)
68 |
69 | cuda = args.cuda
70 | device = torch.device('cuda' if cuda else 'cpu')
71 |
72 | print("===> Loading datasets")
73 |
74 | trainset = DIV2K.div2k(args)
75 | testset = Set5_val.DatasetFromFolderVal("Test_Datasets/Set5/",
76 | "Test_Datasets/Set5_LR/x{}/".format(args.scale),
77 | args.scale)
78 | training_data_loader = DataLoader(dataset=trainset, num_workers=args.threads, batch_size=args.batch_size, shuffle=True, pin_memory=True, drop_last=True)
79 | testing_data_loader = DataLoader(dataset=testset, num_workers=args.threads, batch_size=args.testBatchSize,
80 | shuffle=False)
81 |
82 | print("===> Building models")
83 | args.is_train = True
84 |
85 | model = architecture.IMDN(upscale=args.scale)
86 | l1_criterion = nn.L1Loss()
87 |
88 | print("===> Setting GPU")
89 | if cuda:
90 | model = model.to(device)
91 | l1_criterion = l1_criterion.to(device)
92 |
93 | if args.pretrained:
94 |
95 | if os.path.isfile(args.pretrained):
96 | print("===> loading models '{}'".format(args.pretrained))
97 | checkpoint = torch.load(args.pretrained)
98 |         new_state_dict = OrderedDict()
99 |         for k, v in checkpoint.items():
100 |             if 'module' in k:
101 |                 name = k[7:]
102 |             else:
103 |                 name = k
104 |             new_state_dict[name] = v
105 |         model_dict = model.state_dict()
106 |         pretrained_dict = {k: v for k, v in new_state_dict.items() if k in model_dict}
107 |
108 | for k, v in model_dict.items():
109 | if k not in pretrained_dict:
110 | print(k)
111 | model.load_state_dict(pretrained_dict, strict=True)
112 |
113 | else:
114 | print("===> no models found at '{}'".format(args.pretrained))
115 |
116 | print("===> Setting Optimizer")
117 |
118 | optimizer = optim.Adam(model.parameters(), lr=args.lr)
119 |
120 |
121 | def train(epoch):
122 | model.train()
123 | utils.adjust_learning_rate(optimizer, epoch, args.step_size, args.lr, args.gamma)
124 | print('epoch =', epoch, 'lr = ', optimizer.param_groups[0]['lr'])
125 | for iteration, (lr_tensor, hr_tensor) in enumerate(training_data_loader, 1):
126 |
127 | if args.cuda:
128 | lr_tensor = lr_tensor.to(device) # ranges from [0, 1]
129 | hr_tensor = hr_tensor.to(device) # ranges from [0, 1]
130 |
131 | optimizer.zero_grad()
132 | sr_tensor = model(lr_tensor)
133 | loss_l1 = l1_criterion(sr_tensor, hr_tensor)
134 | loss_sr = loss_l1
135 |
136 | loss_sr.backward()
137 | optimizer.step()
138 | if iteration % 100 == 0:
139 | print("===> Epoch[{}]({}/{}): Loss_l1: {:.5f}".format(epoch, iteration, len(training_data_loader),
140 | loss_l1.item()))
141 |
142 |
143 | def valid():
144 | model.eval()
145 |
146 | avg_psnr, avg_ssim = 0, 0
147 | for batch in testing_data_loader:
148 | lr_tensor, hr_tensor = batch[0], batch[1]
149 | if args.cuda:
150 | lr_tensor = lr_tensor.to(device)
151 | hr_tensor = hr_tensor.to(device)
152 |
153 | with torch.no_grad():
154 | pre = model(lr_tensor)
155 |
156 | sr_img = utils.tensor2np(pre.detach()[0])
157 | gt_img = utils.tensor2np(hr_tensor.detach()[0])
158 | crop_size = args.scale
159 | cropped_sr_img = utils.shave(sr_img, crop_size)
160 | cropped_gt_img = utils.shave(gt_img, crop_size)
161 | if args.isY is True:
162 | im_label = utils.quantize(sc.rgb2ycbcr(cropped_gt_img)[:, :, 0])
163 | im_pre = utils.quantize(sc.rgb2ycbcr(cropped_sr_img)[:, :, 0])
164 | else:
165 | im_label = cropped_gt_img
166 | im_pre = cropped_sr_img
167 | avg_psnr += utils.compute_psnr(im_pre, im_label)
168 | avg_ssim += utils.compute_ssim(im_pre, im_label)
169 | print("===> Valid. psnr: {:.4f}, ssim: {:.4f}".format(avg_psnr / len(testing_data_loader), avg_ssim / len(testing_data_loader)))
170 |
171 |
172 | def save_checkpoint(epoch):
173 | model_folder = "checkpoint_x{}/".format(args.scale)
174 | model_out_path = model_folder + "epoch_{}.pth".format(epoch)
175 | if not os.path.exists(model_folder):
176 | os.makedirs(model_folder)
177 | torch.save(model.state_dict(), model_out_path)
178 | print("===> Checkpoint saved to {}".format(model_out_path))
179 |
180 | def print_network(net):
181 | num_params = 0
182 | for param in net.parameters():
183 | num_params += param.numel()
184 | print(net)
185 | print('Total number of parameters: %d' % num_params)
186 |
187 |
188 | print("===> Training")
189 | print_network(model)
190 | for epoch in range(args.start_epoch, args.nEpochs + 1):
191 | valid()
192 | train(epoch)
193 | save_checkpoint(epoch)
194 |
--------------------------------------------------------------------------------
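
Training uses plain step decay via utils.adjust_learning_rate: the rate is multiplied by gamma every step_size epochs. With the default arguments the schedule halves lr=2e-4 every 200 epochs; a small illustrative check:

    # lr = lr_init * gamma ** (epoch // step_size), with the script's defaults
    lr_init, step_size, gamma = 2e-4, 200, 0.5
    for epoch in (1, 200, 400, 600, 800, 1000):
        print(epoch, lr_init * gamma ** (epoch // step_size))
    # 1: 2e-4, 200: 1e-4, 400: 5e-5, 600: 2.5e-5, 800: 1.25e-5, 1000: 6.25e-6
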
/utils.py:
--------------------------------------------------------------------------------
1 | from skimage.measure import compare_psnr as psnr
2 | from skimage.measure import compare_ssim as ssim
3 | import numpy as np
4 | import os
5 | import torch
6 | from collections import OrderedDict
7 |
8 | def compute_psnr(im1, im2):
9 | p = psnr(im1, im2)
10 | return p
11 |
12 |
13 | def compute_ssim(im1, im2):
14 | isRGB = len(im1.shape) == 3 and im1.shape[-1] == 3
15 | s = ssim(im1, im2, K1=0.01, K2=0.03, gaussian_weights=True, sigma=1.5, use_sample_covariance=False,
16 | multichannel=isRGB)
17 | return s
18 |
19 |
20 | def shave(im, border):
21 | border = [border, border]
22 | im = im[border[0]:-border[0], border[1]:-border[1], ...]
23 | return im
24 |
25 |
26 | def modcrop(im, modulo):
27 | sz = im.shape
28 | h = np.int32(sz[0] / modulo) * modulo
29 | w = np.int32(sz[1] / modulo) * modulo
30 | ims = im[0:h, 0:w, ...]
31 | return ims
32 |
33 |
34 | def get_list(path, ext):
35 | return [os.path.join(path, f) for f in os.listdir(path) if f.endswith(ext)]
36 |
37 |
38 | def convert_shape(img):
39 | img = np.transpose((img * 255.0).round(), (1, 2, 0))
40 | img = np.uint8(np.clip(img, 0, 255))
41 | return img
42 |
43 |
44 | def quantize(img):
45 | return img.clip(0, 255).round().astype(np.uint8)
46 |
47 |
48 | def tensor2np(tensor, out_type=np.uint8, min_max=(0, 1)):
49 | tensor = tensor.float().cpu().clamp_(*min_max)
50 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0, 1]
51 | img_np = tensor.numpy()
52 | img_np = np.transpose(img_np, (1, 2, 0))
53 | if out_type == np.uint8:
54 | img_np = (img_np * 255.0).round()
55 |
56 | return img_np.astype(out_type)
57 |
58 | def convert2np(tensor):
59 | return tensor.cpu().mul(255).clamp(0, 255).byte().squeeze().permute(1, 2, 0).numpy()
60 |
61 |
62 | def adjust_learning_rate(optimizer, epoch, step_size, lr_init, gamma):
63 | factor = epoch // step_size
64 | lr = lr_init * (gamma ** factor)
65 | for param_group in optimizer.param_groups:
66 | param_group['lr'] = lr
67 |
68 | def load_state_dict(path):
69 |
70 | state_dict = torch.load(path)
71 |     new_state_dict = OrderedDict()
72 |     for k, v in state_dict.items():
73 |         if 'module' in k:
74 |             name = k[7:]
75 |         else:
76 |             name = k
77 |         new_state_dict[name] = v
78 |     return new_state_dict
79 |
--------------------------------------------------------------------------------
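
load_state_dict exists so that checkpoints saved from an nn.DataParallel-wrapped model (whose parameter keys carry a 'module.' prefix) load into a bare model. A minimal illustrative sketch of the same renaming, on made-up keys:

    from collections import OrderedDict

    # keys as nn.DataParallel would save them (hypothetical example)
    saved = OrderedDict([('module.fea_conv.weight', 0), ('module.fea_conv.bias', 1)])
    clean = OrderedDict((k[7:] if k.startswith('module.') else k, v)
                        for k, v in saved.items())
    print(list(clean))  # ['fea_conv.weight', 'fea_conv.bias']
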