├── FLOPs ├── __pycache__ │ ├── count_hooks.cpython-35.pyc │ └── profile.cpython-35.pyc ├── count_hooks.py └── profile.py ├── README.md ├── TF_model └── imdn_rtc_time.tflite ├── Test_Datasets ├── RealSR │ ├── ValidationGT │ │ ├── cam1_01.png │ │ ├── cam1_010.png │ │ ├── cam1_02.png │ │ ├── cam1_03.png │ │ ├── cam1_04.png │ │ ├── cam1_05.png │ │ ├── cam1_06.png │ │ ├── cam1_07.png │ │ ├── cam1_08.png │ │ ├── cam1_09.png │ │ ├── cam2_01.png │ │ ├── cam2_010.png │ │ ├── cam2_02.png │ │ ├── cam2_03.png │ │ ├── cam2_04.png │ │ ├── cam2_05.png │ │ ├── cam2_06.png │ │ ├── cam2_07.png │ │ ├── cam2_08.png │ │ └── cam2_09.png │ └── ValidationLR │ │ ├── cam1_01.png │ │ ├── cam1_010.png │ │ ├── cam1_02.png │ │ ├── cam1_03.png │ │ ├── cam1_04.png │ │ ├── cam1_05.png │ │ ├── cam1_06.png │ │ ├── cam1_07.png │ │ ├── cam1_08.png │ │ ├── cam1_09.png │ │ ├── cam2_01.png │ │ ├── cam2_010.png │ │ ├── cam2_02.png │ │ ├── cam2_03.png │ │ ├── cam2_04.png │ │ ├── cam2_05.png │ │ ├── cam2_06.png │ │ ├── cam2_07.png │ │ ├── cam2_08.png │ │ └── cam2_09.png ├── Set5 │ ├── baby.bmp │ ├── bird.bmp │ ├── butterfly.bmp │ ├── head.bmp │ └── woman.bmp └── Set5_LR │ ├── x2 │ ├── babyx2.bmp │ ├── birdx2.bmp │ ├── butterflyx2.bmp │ ├── headx2.bmp │ └── womanx2.bmp │ ├── x3 │ ├── babyx3.bmp │ ├── birdx3.bmp │ ├── butterflyx3.bmp │ ├── headx3.bmp │ └── womanx3.bmp │ └── x4 │ ├── babyx4.bmp │ ├── birdx4.bmp │ ├── butterflyx4.bmp │ ├── headx4.bmp │ └── womanx4.bmp ├── calc_FLOPs.py ├── checkpoints ├── IMDN_AS.pth ├── IMDN_x2.pth ├── IMDN_x3.pth ├── IMDN_x4.pth ├── model_RTC.pth └── model_RTE.pth ├── data ├── DIV2K.py ├── RealSR.py ├── Set5_val.py ├── common.py └── image_folder.py ├── images ├── Pressure_test.png ├── acmmm19_poster.pdf ├── adaptive_cropping.png ├── imdb_plus.png ├── imdn_rtc.jpg ├── lenna.png ├── memory.png ├── parameters.png ├── psnr_ssim.png ├── reparam.png └── time.png ├── model ├── __pycache__ │ ├── architecture.cpython-35.pyc │ └── block.cpython-35.pyc ├── architecture.py └── block.py ├── scripts └── png2npy.py ├── test_IMDN.py ├── test_IMDN_AS.py ├── train_IMDN.py └── utils.py /FLOPs/__pycache__/count_hooks.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/FLOPs/__pycache__/count_hooks.cpython-35.pyc -------------------------------------------------------------------------------- /FLOPs/__pycache__/profile.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/FLOPs/__pycache__/profile.cpython-35.pyc -------------------------------------------------------------------------------- /FLOPs/count_hooks.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | multiply_adds = 1 7 | 8 | 9 | def count_convNd(m, x, y): 10 | x = x[0] 11 | cin = m.in_channels 12 | batch_size = x.size(0) 13 | 14 | kernel_ops = m.weight.size()[2:].numel() 15 | bias_ops = 1 if m.bias is not None else 0 16 | ops_per_element = kernel_ops + bias_ops 17 | output_elements = y.nelement() 18 | 19 | # y.nelement() is N x cout x oH x oW, i.e. it already includes the batch dimension, so multiplying by batch_size again would double-count it 20 | total_ops = cin * output_elements * ops_per_element // m.groups 21 | # total_ops = batch_size * output_elements * (cin * kernel_ops // m.groups + bias_ops) 22 | m.total_ops = torch.Tensor([int(total_ops)]) 23 | 24 | 25 | def count_conv2d(m, x, y): 26 | x
= x[0] 27 | 28 | cin = m.in_channels 29 | cout = m.out_channels 30 | kh, kw = m.kernel_size 31 | batch_size = x.size()[0] 32 | 33 | out_h = y.size(2) 34 | out_w = y.size(3) 35 | 36 | # ops per output element 37 | # kernel_mul = kh * kw * cin 38 | # kernel_add = kh * kw * cin - 1 39 | kernel_ops = multiply_adds * kh * kw 40 | bias_ops = 1 if m.bias is not None else 0 41 | ops_per_element = kernel_ops + bias_ops 42 | 43 | # total ops 44 | # num_out_elements = y.numel() 45 | output_elements = batch_size * out_w * out_h * cout 46 | total_ops = output_elements * ops_per_element * cin // m.groups 47 | 48 | m.total_ops = torch.Tensor([int(total_ops)]) 49 | 50 | 51 | def count_convtranspose2d(m, x, y): 52 | x = x[0] 53 | 54 | cin = m.in_channels 55 | cout = m.out_channels 56 | kh, kw = m.kernel_size 57 | batch_size = x.size()[0] 58 | 59 | out_h = y.size(2) 60 | out_w = y.size(3) 61 | 62 | # ops per output element 63 | # kernel_mul = kh * kw * cin 64 | # kernel_add = kh * kw * cin - 1 65 | kernel_ops = multiply_adds * kh * kw * cin // m.groups 66 | bias_ops = 1 if m.bias is not None else 0 67 | ops_per_element = kernel_ops + bias_ops 68 | 69 | # total ops 70 | # num_out_elements = y.numel() 71 | # output_elements = batch_size * out_w * out_h * cout 72 | ops_per_element = m.weight.nelement() 73 | output_elements = y.nelement() 74 | total_ops = output_elements * ops_per_element 75 | 76 | m.total_ops = torch.Tensor([int(total_ops)]) 77 | 78 | 79 | def count_bn(m, x, y): 80 | x = x[0] 81 | 82 | nelements = x.numel() 83 | # subtract, divide, gamma, beta 84 | total_ops = 4 * nelements 85 | 86 | m.total_ops = torch.Tensor([int(total_ops)]) 87 | 88 | 89 | def count_relu(m, x, y): 90 | x = x[0] 91 | 92 | nelements = x.numel() 93 | total_ops = nelements 94 | 95 | m.total_ops = torch.Tensor([int(total_ops)]) 96 | 97 | 98 | def count_sigmoid(m, x, y): 99 | x = x[0] 100 | nelements = x.numel() 101 | 102 | total_exp = nelements 103 | total_add = nelements 104 | total_div = nelements 105 | 106 | total_ops = total_exp + total_add + total_div 107 | m.total_ops = torch.Tensor([int(total_ops)]) 108 | 109 | def count_pixelshuffle(m, x, y): 110 | x = x[0] 111 | nelements = x.numel() 112 | total_ops = nelements 113 | m.total_ops = torch.Tensor([int(total_ops)]) 114 | 115 | 116 | def count_softmax(m, x, y): 117 | x = x[0] 118 | 119 | batch_size, nfeatures = x.size() 120 | 121 | total_exp = nfeatures 122 | total_add = nfeatures - 1 123 | total_div = nfeatures 124 | total_ops = batch_size * (total_exp + total_add + total_div) 125 | 126 | m.total_ops = torch.Tensor([int(total_ops)]) 127 | 128 | 129 | def count_maxpool(m, x, y): 130 | kernel_ops = torch.prod(torch.Tensor([m.kernel_size])) 131 | num_elements = y.numel() 132 | total_ops = kernel_ops * num_elements 133 | 134 | m.total_ops = torch.Tensor([int(total_ops)]) 135 | 136 | 137 | def count_adap_maxpool(m, x, y): 138 | kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze() 139 | kernel_ops = torch.prod(kernel) 140 | num_elements = y.numel() 141 | total_ops = kernel_ops * num_elements 142 | 143 | m.total_ops = torch.Tensor([int(total_ops)]) 144 | 145 | 146 | def count_avgpool(m, x, y): 147 | total_add = torch.prod(torch.Tensor([m.kernel_size])) 148 | total_div = 1 149 | kernel_ops = total_add + total_div 150 | num_elements = y.numel() 151 | total_ops = kernel_ops * num_elements 152 | 153 | m.total_ops = torch.Tensor([int(total_ops)]) 154 | 155 | 156 | def count_adap_avgpool(m, x, y): 157 | kernel = torch.Tensor([*(x[0].shape[2:])]) 
// torch.Tensor(list((m.output_size,))).squeeze() 158 | total_add = torch.prod(kernel) 159 | total_div = 1 160 | kernel_ops = total_add + total_div 161 | num_elements = y.numel() 162 | total_ops = kernel_ops * num_elements 163 | 164 | m.total_ops = torch.Tensor([int(total_ops)]) 165 | 166 | 167 | def count_linear(m, x, y): 168 | # per output element 169 | total_mul = m.in_features 170 | total_add = m.in_features - 1 171 | num_elements = y.numel() 172 | total_ops = (total_mul + total_add) * num_elements 173 | 174 | m.total_ops = torch.Tensor([int(total_ops)]) 175 | -------------------------------------------------------------------------------- /FLOPs/profile.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.modules.conv import _ConvNd 6 | 7 | from .count_hooks import * 8 | 9 | register_hooks = { 10 | nn.Conv1d: count_convNd, 11 | nn.Conv2d: count_convNd, 12 | nn.Conv3d: count_convNd, 13 | nn.ConvTranspose2d: count_convtranspose2d, 14 | 15 | nn.BatchNorm1d: count_bn, 16 | nn.BatchNorm2d: count_bn, 17 | nn.BatchNorm3d: count_bn, 18 | 19 | nn.ReLU: count_relu, 20 | nn.ReLU6: count_relu, 21 | nn.LeakyReLU: count_relu, 22 | nn.PReLU: count_relu, 23 | 24 | nn.MaxPool1d: count_maxpool, 25 | nn.MaxPool2d: count_maxpool, 26 | nn.MaxPool3d: count_maxpool, 27 | nn.AdaptiveMaxPool1d: count_adap_maxpool, 28 | nn.AdaptiveMaxPool2d: count_adap_maxpool, 29 | nn.AdaptiveMaxPool3d: count_adap_maxpool, 30 | 31 | nn.AvgPool1d: count_avgpool, 32 | nn.AvgPool2d: count_avgpool, 33 | nn.AvgPool3d: count_avgpool, 34 | 35 | nn.AdaptiveAvgPool1d: count_adap_avgpool, 36 | nn.AdaptiveAvgPool2d: count_adap_avgpool, 37 | nn.AdaptiveAvgPool3d: count_adap_avgpool, 38 | nn.Linear: count_linear, 39 | nn.Dropout: None, 40 | nn.PixelShuffle: count_pixelshuffle, 41 | nn.Sigmoid: count_sigmoid, 42 | } 43 | 44 | 45 | def profile(model, input_size, custom_ops={}, device="cpu"): 46 | handler_collection = [] 47 | 48 | def add_hooks(m): 49 | if len(list(m.children())) > 0: 50 | return 51 | 52 | m.register_buffer('total_ops', torch.zeros(1)) 53 | m.register_buffer('total_params', torch.zeros(1)) 54 | 55 | for p in m.parameters(): 56 | m.total_params += torch.Tensor([p.numel()]) 57 | 58 | m_type = type(m) 59 | fn = None 60 | 61 | if m_type in custom_ops: 62 | fn = custom_ops[m_type] 63 | elif m_type in register_hooks: 64 | fn = register_hooks[m_type] 65 | else: 66 | print("Not implemented for ", m) 67 | 68 | if fn is not None: 69 | #print("Register FLOP counter for module %s" % str(m)) 70 | handler = m.register_forward_hook(fn) 71 | handler_collection.append(handler) 72 | 73 | original_device = model.parameters().__next__().device 74 | training = model.training 75 | 76 | model.eval().to(device) 77 | model.apply(add_hooks) 78 | 79 | x = torch.zeros(input_size).to(device) 80 | with torch.no_grad(): 81 | model(x) 82 | 83 | total_ops = 0 84 | total_params = 0 85 | for m in model.modules(): 86 | if len(list(m.children())) > 0: # skip for non-leaf module 87 | continue 88 | total_ops += m.total_ops 89 | total_params += m.total_params 90 | 91 | total_ops = total_ops.item() 92 | total_params = total_params.item() 93 | 94 | model.train(training).to(original_device) 95 | for handler in handler_collection: 96 | handler.remove() 97 | 98 | return total_ops, total_params 99 | -------------------------------------------------------------------------------- /README.md: 
--------------------------------------------------------------------------------
1 | # IMDN
2 | Lightweight Image Super-Resolution with Information Multi-distillation Network (ACM MM 2019)
3 | 
4 | [[arXiv]](https://arxiv.org/pdf/1909.11856v1.pdf)
5 | [[Poster]](https://github.com/Zheng222/IMDN/blob/master/images/acmmm19_poster.pdf)
6 | [[ACM DL]](https://dl.acm.org/citation.cfm?id=3351084)
7 | 
8 | ## :sparkles: News
9 | - Nov 26, 2021. **Added the IMDN_RTC TFLite model.**
10 | 
11 | # [CVPR 2022 Workshop NTIRE report](https://github.com/ofsoundof/NTIRE2022_ESR)
12 | IMDN+ was the **Second Runner-up** at the NTIRE 2022 Efficient SR Challenge (Sub-Track 2 - Overall Performance Track).
13 | 
14 | *Figures: the IMDB+ block and its structural re-parameterization (images/imdb_plus.png, images/reparam.png).*
26 | 
27 | 
28 | ##### Model complexity
29 | | number of parameters | 275,844 |
30 | |:---:|:---:|
31 | | FLOPs | 17.9848G (input size: 3×256×256) |
32 | | GPU memory consumption | 2893M (DIV2K test) |
33 | | number of activations | 92.7990M (input size: 3×256×256) |
34 | | runtime | 0.026783s (RTX 2080Ti, DIV2K test) |
35 | 
36 | 
37 | 
38 | ##### PSNR / SSIM (Y channel) on 5 benchmark datasets
39 | | Metrics | Set5 | Set14 | B100 | Urban100 | Manga109 |
40 | |:---:|:---:|:---:|:---:|:---:|:---:|
41 | | PSNR | 32.11 | 28.63 | 27.58 | 26.10 | 30.55 |
42 | | SSIM | 0.8934 | 0.7823 | 0.7358 | 0.7846 | 0.9072 |
43 | 
44 | 
45 | 
46 | # [ICCV 2019 Workshop AIM report](https://arxiv.org/abs/1911.01249)
47 | The simplified version of IMDN won **first place** in the AIM 2019 Constrained Super-Resolution Challenge (Track 1 & Track 2). The test code is available on [Google Drive](https://drive.google.com/open?id=1BQkpqp2oZUH_J_amJv33ehGjx6gvCd0L).
48 | 
49 | # [AI in RTC 2019-rainbow](https://www.dcjingsai.com/common/bbs/topicDetails.html?tid=3787)
50 | The ultra-lightweight version of IMDN won **first place** in the Super Resolution Algorithm Performance Comparison Challenge (see [model/architecture.py#L79](https://github.com/Zheng222/IMDN/blob/53f1dac25e8cd8e11ad65484eadf0d1e31d602fa/model/architecture.py#L79)).
51 | 
52 | Degradation type: bicubic
53 | 
54 | [PyTorch Checkpoint](https://github.com/Zheng222/IMDN/blob/master/checkpoints/model_RTC.pth)
55 | 
56 | [TensorFlow Lite Checkpoint](https://github.com/Zheng222/IMDN/blob/master/checkpoints/imdn_rtc_time.tflite)
57 | 
58 | input_shape = (1, 720, 480, 3), measured with AI Benchmark (OPPO Find X3, Qualcomm Snapdragon 870, FP16, TFLite GPU Delegate)
59 | 
60 | 
61 | 
62 | 
63 | # [AI in RTE 2020-rainbow](https://www.dcjingsai.com/v2/news-detail.html?id=64)
64 | The down-up version of IMDN won **second place** in the Super Resolution Algorithm Performance Comparison Challenge
65 | (see [model/architecture.py#L98](https://github.com/Zheng222/IMDN/blob/53f1dac25e8cd8e11ad65484eadf0d1e31d602fa/model/architecture.py#L98)).
66 | 
67 | Degradation type: downsampling + noise
68 | 
69 | [Checkpoint](https://github.com/Zheng222/IMDN/blob/master/checkpoints/model_RTE.pth)
70 | # Highlights
71 | 1. Our information multi-distillation block (IMDB) with a contrast-aware channel attention (CCA) layer.
72 | 
73 | 2. The adaptive cropping strategy (ACS), which processes images of arbitrary size (implementing any upscaling factor with one model).
74 | 
75 | 3. An exploration of the factors affecting actual inference time.
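Complexity figures like those quoted above come from a FLOPs profiler; this repository ships its own in `FLOPs/profile.py`, driven by `calc_FLOPs.py`. A minimal sketch of calling it directly, assuming it is run from the repository root (the model variant and input size below are just the `calc_FLOPs.py` defaults):

```python
from model import architecture
from FLOPs.profile import profile

# IMDN_RTC at x2 with a 1x3x240x360 input, mirroring calc_FLOPs.py;
# swap in e.g. architecture.IMDN(upscale=4) to profile other variants.
model = architecture.IMDN_RTC(upscale=2)
flops, params = profile(model, input_size=(1, 3, 240, 360))
print('FLOPs: {:.4f} G, params: {:d}'.format(flops / 1e9, int(params)))
```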
76 | 
77 | ## Testing
78 | PyTorch 1.1
79 | * Run testing:
80 | ```bash
81 | # Set5 x2 IMDN
82 | python test_IMDN.py --test_hr_folder Test_Datasets/Set5/ --test_lr_folder Test_Datasets/Set5_LR/x2/ --output_folder results/Set5/x2 --checkpoint checkpoints/IMDN_x2.pth --upscale_factor 2
83 | # RealSR IMDN_AS
84 | python test_IMDN_AS.py --test_hr_folder Test_Datasets/RealSR/ValidationGT --test_lr_folder Test_Datasets/RealSR/ValidationLR/ --output_folder results/RealSR --checkpoint checkpoints/IMDN_AS.pth
85 | 
86 | ```
87 | * Calculate IMDN_RTC's FLOPs and parameters (input size 240×360):
88 | ```bash
89 | python calc_FLOPs.py
90 | ```
91 | 
92 | ## Training
93 | * Download the [DIV2K training dataset](https://drive.google.com/open?id=12hOYsMa8t1ErKj6PZA352icsx9mz1TwB)
94 | * Convert the PNG files to NPY files:
95 | ```bash
96 | python scripts/png2npy.py --pathFrom /path/to/DIV2K/ --pathTo /path/to/DIV2K_decoded/
97 | ```
98 | * Run training for the x2, x3, and x4 models:
99 | ```bash
100 | python train_IMDN.py --root /path/to/DIV2K_decoded/ --scale 2 --pretrained checkpoints/IMDN_x2.pth
101 | python train_IMDN.py --root /path/to/DIV2K_decoded/ --scale 3 --pretrained checkpoints/IMDN_x3.pth
102 | python train_IMDN.py --root /path/to/DIV2K_decoded/ --scale 4 --pretrained checkpoints/IMDN_x4.pth
103 | ```
104 | 
105 | ## Results
106 | [Baidu Netdisk](https://pan.baidu.com/s/1DY0Npete3WsIoFbjmgXQlw) (extraction code: 8yqj) or
107 | [Google Drive](https://drive.google.com/open?id=1GsEcpIZ7uA97D89WOGa9sWTSl4choy_O)
108 | 
109 | The following PSNR/SSIM values are evaluated with MATLAB R2017a; see [Evaluate_PSNR_SSIM.m](https://github.com/yulunzhang/RCAN/blob/master/RCAN_TestCode/Evaluate_PSNR_SSIM.m) for the evaluation code.
110 | 
111 | ## Pressure Test
112 | 

114 | Pressure test for ×4 SR model.
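The note below mentions that `torch.cuda.Event()` is used to record inference times. A minimal sketch of that timing pattern (the model, input shape, and iteration count are placeholders; the warm-up loop and `synchronize()` calls matter because CUDA kernels execute asynchronously):

```python
import torch
from model import architecture

model = architecture.IMDN(upscale=4).cuda().eval()
x = torch.randn(1, 3, 180, 320, device='cuda')  # placeholder LR input

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
with torch.no_grad():
    for _ in range(10):          # warm-up so lazy CUDA init is not timed
        model(x)
    torch.cuda.synchronize()
    start.record()
    model(x)
    end.record()
    torch.cuda.synchronize()     # wait for the timed kernels to finish
print('inference time: {:.6f} s'.format(start.elapsed_time(end) / 1000.0))
```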
116 | 
117 | *Note: inference times are recorded with `torch.cuda.Event()`.*
118 | 
119 | 
120 | ## PSNR & SSIM
121 | 

123 | Average PSNR/SSIM on the Set5, Set14, BSD100, Urban100, and Manga109 datasets.
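The official numbers are computed in MATLAB (see the Results section above). A rough Python equivalent for Y-channel PSNR, assuming 8-bit RGB arrays, a border crop (`shave`, typically the scale factor, at least 1), and the BT.601 luma conversion that MATLAB's `rgb2ycbcr` uses:

```python
import numpy as np

def psnr_y(sr, hr, shave):
    """Y-channel PSNR between two HxWx3 uint8 RGB images (shave >= 1)."""
    def to_y(img):  # BT.601 RGB -> luma in [16, 235], as in rgb2ycbcr
        return img.astype(np.float64) @ np.array([65.481, 128.553, 24.966]) / 255.0 + 16.0
    diff = (to_y(sr) - to_y(hr))[shave:-shave, shave:-shave]
    return 10.0 * np.log10(255.0 ** 2 / np.mean(diff ** 2))
```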
125 | 126 | ## Memory consumption 127 |

129 | Memory consumption (MB) and average inference time (seconds).
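A minimal sketch of how peak GPU memory of this kind can be measured in PyTorch, assuming a reasonably recent version (the model and input are placeholders; `torch.cuda.max_memory_allocated()` reports the allocator's high-water mark):

```python
import torch
from model import architecture

model = architecture.IMDN(upscale=4).cuda().eval()
x = torch.randn(1, 3, 339, 510, device='cuda')  # placeholder DIV2K-sized LR input

torch.cuda.reset_max_memory_allocated()
with torch.no_grad():
    model(x)
print('peak GPU memory: {:.1f} MB'.format(torch.cuda.max_memory_allocated() / 1024 ** 2))
```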
131 | 132 | ## Model parameters 133 | 134 |

136 | Trade-off between performance and number of parameters on the Set5 ×4 dataset.
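Parameter counts like those plotted above need no profiler; they can be read off any of the models directly:

```python
from model import architecture

# total number of learnable parameters in the x4 IMDN model
model = architecture.IMDN(upscale=4)
print('IMDN x4 parameters:', sum(p.numel() for p in model.parameters()))
```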
138 | 139 | ## Running time 140 | 141 |

143 | Trade-off between performance and running time on the Set5 ×4 dataset. VDSR, DRCN, and LapSRN were implemented in MatConvNet, while DRRN and IDN used the Caffe package; EDSR-baseline, CARN, and our IMDN use PyTorch.
145 | 146 | ## Adaptive Cropping 147 |

149 | Diagrammatic sketch of the adaptive cropping strategy (ACS); the cropped image patches are shown in the green dotted boxes.
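The actual ACS implementation lives in `test_IMDN_AS.py` (not shown above); the sketch below is only an illustration under simplifying assumptions: four overlapping corner patches whose side lengths are rounded up to a multiple of `mod` (4 matches the two stride-2 convolutions in `IMDN_AS`), each patch super-resolved independently, and only the quadrant belonging to each corner copied into the stitched output. `scale` is the output/input size ratio (1 for `IMDN_AS` on RealSR, where the ×4 pixel-shuffle cancels the two stride-2 convolutions):

```python
import torch

@torch.no_grad()
def forward_adaptive_crop(model, lr, scale, mod=4):
    """Illustrative ACS: SR an NxCxHxW tensor via 4 overlapping corner patches."""
    _, c, h, w = lr.shape
    ph = (h // 2 // mod + 1) * mod                        # patch height >= ceil(h / 2)
    pw = (w // 2 // mod + 1) * mod                        # patch width  >= ceil(w / 2)
    sh, sw = (h + 1) // 2 * scale, (w + 1) // 2 * scale   # quadrant boundaries in SR space
    sr = lr.new_zeros(lr.size(0), c, h * scale, w * scale)
    for top in (0, h - ph):
        for left in (0, w - pw):
            out = model(lr[:, :, top:top + ph, left:left + pw])
            r0, r1 = (0, sh) if top == 0 else (sh, h * scale)
            c0, c1 = (0, sw) if left == 0 else (sw, w * scale)
            # copy only this corner's quadrant from the patch output
            sr[:, :, r0:r1, c0:c1] = out[:, :, r0 - top * scale:r1 - top * scale,
                                              c0 - left * scale:c1 - left * scale]
    return sr
```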
151 | 152 | ## Visualization of feature maps 153 |

155 | Visualization of the output feature maps of the 6-th progressive refinement module (PRM).
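A minimal sketch of how such feature maps can be captured with a PyTorch forward hook (here on the output of the 6-th IMDB of `IMDN`, the block containing the PRM); plotting and saving are left out:

```python
import torch
from model import architecture

model = architecture.IMDN(upscale=4).eval()
features = {}

def save_output(name):
    def hook(module, inputs, output):
        features[name] = output.detach()  # stash the block's output feature maps
    return hook

model.IMDB6.register_forward_hook(save_output('IMDB6'))
with torch.no_grad():
    model(torch.randn(1, 3, 64, 64))
print(features['IMDB6'].shape)  # torch.Size([1, 64, 64, 64]) for a 64x64 input
```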
157 | 158 | ## Citation 159 | 160 | If you find IMDN useful in your research, please consider citing: 161 | 162 | ``` 163 | @inproceedings{Hui-IMDN-2019, 164 | title={Lightweight Image Super-Resolution with Information Multi-distillation Network}, 165 | author={Hui, Zheng and Gao, Xinbo and Yang, Yunchu and Wang, Xiumei}, 166 | booktitle={Proceedings of the 27th ACM International Conference on Multimedia (ACM MM)}, 167 | pages={2024--2032}, 168 | year={2019} 169 | } 170 | 171 | @inproceedings{AIM19constrainedSR, 172 | title={AIM 2019 Challenge on Constrained Super-Resolution: Methods and Results}, 173 | author={Kai Zhang and Shuhang Gu and Radu Timofte and others}, 174 | booktitle={The IEEE International Conference on Computer Vision (ICCV) Workshops}, 175 | year={2019} 176 | } 177 | 178 | ``` 179 | -------------------------------------------------------------------------------- /TF_model/imdn_rtc_time.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/TF_model/imdn_rtc_time.tflite -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam1_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_01.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam1_010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_010.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam1_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_02.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam1_03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_03.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam1_04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_04.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam1_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_05.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam1_06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_06.png 
-------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam1_07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_07.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam1_08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_08.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam1_09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam1_09.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam2_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_01.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam2_010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_010.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam2_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_02.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam2_03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_03.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam2_04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_04.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam2_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_05.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam2_06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_06.png -------------------------------------------------------------------------------- 
/Test_Datasets/RealSR/ValidationGT/cam2_07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_07.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam2_08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_08.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationGT/cam2_09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationGT/cam2_09.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam1_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_01.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam1_010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_010.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam1_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_02.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam1_03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_03.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam1_04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_04.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam1_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_05.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam1_06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_06.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam1_07.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_07.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam1_08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_08.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam1_09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam1_09.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam2_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_01.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam2_010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_010.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam2_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_02.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam2_03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_03.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam2_04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_04.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam2_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_05.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam2_06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_06.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam2_07.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_07.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam2_08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_08.png -------------------------------------------------------------------------------- /Test_Datasets/RealSR/ValidationLR/cam2_09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/RealSR/ValidationLR/cam2_09.png -------------------------------------------------------------------------------- /Test_Datasets/Set5/baby.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5/baby.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5/bird.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5/bird.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5/butterfly.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5/butterfly.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5/head.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5/head.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5/woman.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5/woman.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5_LR/x2/babyx2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x2/babyx2.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5_LR/x2/birdx2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x2/birdx2.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5_LR/x2/butterflyx2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x2/butterflyx2.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5_LR/x2/headx2.bmp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x2/headx2.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5_LR/x2/womanx2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x2/womanx2.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5_LR/x3/babyx3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x3/babyx3.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5_LR/x3/birdx3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x3/birdx3.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5_LR/x3/butterflyx3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x3/butterflyx3.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5_LR/x3/headx3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x3/headx3.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5_LR/x3/womanx3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x3/womanx3.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5_LR/x4/babyx4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x4/babyx4.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5_LR/x4/birdx4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x4/birdx4.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5_LR/x4/butterflyx4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x4/butterflyx4.bmp -------------------------------------------------------------------------------- /Test_Datasets/Set5_LR/x4/headx4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x4/headx4.bmp 
-------------------------------------------------------------------------------- /Test_Datasets/Set5_LR/x4/womanx4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/Test_Datasets/Set5_LR/x4/womanx4.bmp -------------------------------------------------------------------------------- /calc_FLOPs.py: -------------------------------------------------------------------------------- 1 | from model import architecture 2 | from FLOPs.profile import profile 3 | 4 | width = 360 5 | height = 240 6 | model = architecture.IMDN_RTC(upscale=2) 7 | flops, params = profile(model, input_size=(1, 3, height, width)) 8 | print('IMDN_light: {} x {}, flops: {:.10f} GFLOPs, params: {}'.format(height,width,flops/(1e9),params)) 9 | -------------------------------------------------------------------------------- /checkpoints/IMDN_AS.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/checkpoints/IMDN_AS.pth -------------------------------------------------------------------------------- /checkpoints/IMDN_x2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/checkpoints/IMDN_x2.pth -------------------------------------------------------------------------------- /checkpoints/IMDN_x3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/checkpoints/IMDN_x3.pth -------------------------------------------------------------------------------- /checkpoints/IMDN_x4.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/checkpoints/IMDN_x4.pth -------------------------------------------------------------------------------- /checkpoints/model_RTC.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/checkpoints/model_RTC.pth -------------------------------------------------------------------------------- /checkpoints/model_RTE.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/checkpoints/model_RTE.pth -------------------------------------------------------------------------------- /data/DIV2K.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import os.path 3 | import cv2 4 | import numpy as np 5 | from data import common 6 | 7 | def default_loader(path): 8 | return cv2.imread(path, cv2.IMREAD_UNCHANGED)[:, :, [2, 1, 0]] 9 | 10 | def npy_loader(path): 11 | return np.load(path) 12 | 13 | IMG_EXTENSIONS = [ 14 | '.png', '.npy', 15 | ] 16 | 17 | def is_image_file(filename): 18 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 19 | 20 | def make_dataset(dir): 21 | images = [] 22 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 23 | 24 | for root, _, fnames in sorted(os.walk(dir)): 25 | for fname in fnames: 26 | if is_image_file(fname): 27 | path = os.path.join(root, fname) 
28 | images.append(path) 29 | return images 30 | 31 | 32 | class div2k(data.Dataset): 33 | def __init__(self, opt): 34 | self.opt = opt 35 | self.scale = self.opt.scale 36 | self.root = self.opt.root 37 | self.ext = self.opt.ext # '.png' or '.npy'(default) 38 | self.train = True if self.opt.phase == 'train' else False 39 | self.repeat = self.opt.test_every // (self.opt.n_train // self.opt.batch_size) 40 | self._set_filesystem(self.root) 41 | self.images_hr, self.images_lr = self._scan() 42 | 43 | def _set_filesystem(self, dir_data): 44 | self.root = dir_data + '/DIV2K_decoded' 45 | self.dir_hr = os.path.join(self.root, 'DIV2K_HR') 46 | self.dir_lr = os.path.join(self.root, 'DIV2K_LR_bicubic/x' + str(self.scale)) 47 | 48 | def __getitem__(self, idx): 49 | lr, hr = self._load_file(idx) 50 | lr, hr = self._get_patch(lr, hr) 51 | lr, hr = common.set_channel(lr, hr, n_channels=self.opt.n_colors) 52 | lr_tensor, hr_tensor = common.np2Tensor(lr, hr, rgb_range=self.opt.rgb_range) 53 | return lr_tensor, hr_tensor 54 | 55 | def __len__(self): 56 | if self.train: 57 | return self.opt.n_train * self.repeat 58 | 59 | def _get_index(self, idx): 60 | if self.train: 61 | return idx % self.opt.n_train 62 | else: 63 | return idx 64 | 65 | def _get_patch(self, img_in, img_tar): 66 | patch_size = self.opt.patch_size 67 | scale = self.scale 68 | if self.train: 69 | img_in, img_tar = common.get_patch( 70 | img_in, img_tar, patch_size=patch_size, scale=scale) 71 | img_in, img_tar = common.augment(img_in, img_tar) 72 | else: 73 | ih, iw = img_in.shape[:2] 74 | img_tar = img_tar[0:ih * scale, 0:iw * scale, :] 75 | return img_in, img_tar 76 | 77 | def _scan(self): 78 | list_hr = sorted(make_dataset(self.dir_hr)) 79 | list_lr = sorted(make_dataset(self.dir_lr)) 80 | return list_hr, list_lr 81 | 82 | def _load_file(self, idx): 83 | idx = self._get_index(idx) 84 | if self.ext == '.npy': 85 | lr = npy_loader(self.images_lr[idx]) 86 | hr = npy_loader(self.images_hr[idx]) 87 | else: 88 | lr = default_loader(self.images_lr[idx]) 89 | hr = default_loader(self.images_hr[idx]) 90 | return lr, hr 91 | -------------------------------------------------------------------------------- /data/RealSR.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import common 3 | import numpy as np 4 | import torch.utils.data as data 5 | from data.image_folder import make_dataset 6 | 7 | 8 | class srdata(data.Dataset): 9 | def __init__(self, args, train=True): 10 | self.args = args 11 | self.train = train 12 | self.root = self.args.root 13 | self.split = 'train' if train else 'test' 14 | self.scale = 1 15 | self.repeat = args.test_every // (args.n_train // args.batch_size) # 300 / (60/6) = 30 16 | self._set_filesystem(self.root) 17 | self.images_hr, self.images_lr = self._scan_npy() 18 | 19 | def _set_filesystem(self, dir_data): 20 | self.apath = dir_data + '/RealSR_decoded' 21 | self.dir_hr = os.path.join(self.apath, 'TrainGT') 22 | self.dir_lr = os.path.join(self.apath, 'TrainLR') 23 | self.ext = '.npy' 24 | 25 | def __getitem__(self, idx): 26 | lr, hr = self._load_file(idx) 27 | lr, hr = self._get_patch(lr, hr) 28 | lr, hr = common.set_channel(lr, hr, n_channels=self.args.n_colors) 29 | lr_tensor, hr_tensor = common.np2Tensor(lr, hr, rgb_range=self.args.rgb_range) 30 | return lr_tensor, hr_tensor 31 | 32 | def __len__(self): 33 | if self.train: 34 | return self.args.n_train * self.repeat 35 | else: 36 | return self.args.n_val 37 | 38 | def _get_index(self, idx): 39 | if 
self.train: 40 | return idx % self.args.n_train 41 | else: 42 | return idx 43 | 44 | def _get_patch(self, img_in, img_tar): 45 | patch_size = self.args.patch_size 46 | scale = self.scale 47 | if self.train: 48 | img_in, img_tar = common.get_patch( 49 | img_in, img_tar, patch_size=patch_size, scale=scale) 50 | img_in, img_tar = common.augment(img_in, img_tar) 51 | 52 | else: 53 | ih, iw = img_in.shape[:2] 54 | img_tar = img_tar[0:ih * scale, 0:iw * scale, :] 55 | 56 | return img_in, img_tar 57 | 58 | def _scan(self): 59 | list_hr = [] 60 | list_lr = [] 61 | if self.train: 62 | idx_begin = 0 63 | idx_end = self.args.n_train 64 | else: 65 | idx_begin = self.args.n_train 66 | idx_end = self.args.offset_val + self.args.n_val 67 | 68 | for i in range(idx_begin + 1, idx_end + 1): 69 | filename = '{:0>6}'.format(i) 70 | list_hr.append(os.path.join(self.dir_hr, filename + self.ext)) 71 | list_lr.append(os.path.join(self.dir_lr, '{}x{}{}'.format(filename, self.scale, self.ext))) 72 | 73 | return list_hr, list_lr 74 | 75 | def _scan_npy(self): 76 | list_hr = sorted(make_dataset(self.dir_hr)) 77 | list_lr = sorted(make_dataset(self.dir_lr)) 78 | return list_hr, list_lr 79 | 80 | def _load_file(self, idx): 81 | idx = self._get_index(idx) 82 | lr = np.load(self.images_lr[idx]) 83 | hr = np.load(self.images_hr[idx]) 84 | return lr, hr 85 | -------------------------------------------------------------------------------- /data/Set5_val.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from os.path import join 3 | from os import listdir 4 | from torchvision.transforms import Compose, ToTensor 5 | from PIL import Image 6 | import numpy as np 7 | 8 | 9 | def img_modcrop(image, modulo): 10 | sz = image.size 11 | w = np.int32(sz[0] / modulo) * modulo 12 | h = np.int32(sz[1] / modulo) * modulo 13 | out = image.crop((0, 0, w, h)) 14 | return out 15 | 16 | 17 | def np2tensor(): 18 | return Compose([ 19 | ToTensor(), 20 | ]) 21 | 22 | 23 | def is_image_file(filename): 24 | return any(filename.endswith(extension) for extension in [".bmp", ".png", ".jpg"]) 25 | 26 | 27 | def load_image(filepath): 28 | return Image.open(filepath).convert('RGB') 29 | 30 | 31 | class DatasetFromFolderVal(data.Dataset): 32 | def __init__(self, hr_dir, lr_dir, upscale): 33 | super(DatasetFromFolderVal, self).__init__() 34 | self.hr_filenames = sorted([join(hr_dir, x) for x in listdir(hr_dir) if is_image_file(x)]) 35 | self.lr_filenames = sorted([join(lr_dir, x) for x in listdir(lr_dir) if is_image_file(x)]) 36 | self.upscale = upscale 37 | 38 | def __getitem__(self, index): 39 | input = load_image(self.lr_filenames[index]) 40 | target = load_image(self.hr_filenames[index]) 41 | input = np2tensor()(input) 42 | target = np2tensor()(img_modcrop(target, self.upscale)) 43 | 44 | return input, target 45 | 46 | def __len__(self): 47 | return len(self.lr_filenames) 48 | -------------------------------------------------------------------------------- /data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | import skimage.color as sc 5 | 6 | 7 | def get_patch(*args, patch_size, scale): 8 | ih, iw = args[0].shape[:2] 9 | 10 | tp = patch_size # target patch (HR) 11 | ip = tp // scale # input patch (LR) 12 | 13 | ix = random.randrange(0, iw - ip + 1) 14 | iy = random.randrange(0, ih - ip + 1) 15 | tx, ty = scale * ix, scale * iy 16 | 17 | ret = [ 18 | args[0][iy:iy + ip, ix:ix 
+ ip, :], 19 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] 20 | ] # results 21 | return ret 22 | 23 | 24 | def set_channel(*args, n_channels=3): 25 | def _set_channel(img): 26 | if img.ndim == 2: 27 | img = np.expand_dims(img, axis=2) 28 | 29 | c = img.shape[2] 30 | if n_channels == 1 and c == 3: 31 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 32 | elif n_channels == 3 and c == 1: 33 | img = np.concatenate([img] * n_channels, 2) 34 | 35 | return img 36 | 37 | return [_set_channel(a) for a in args] 38 | 39 | 40 | def np2Tensor(*args, rgb_range): 41 | def _np2Tensor(img): 42 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 43 | tensor = torch.from_numpy(np_transpose).float() 44 | tensor.mul_(rgb_range / 255) 45 | 46 | return tensor 47 | 48 | return [_np2Tensor(a) for a in args] 49 | 50 | 51 | def augment(*args, hflip=True, rot=True): 52 | hflip = hflip and random.random() < 0.5 53 | vflip = rot and random.random() < 0.5 54 | rot90 = rot and random.random() < 0.5 55 | 56 | def _augment(img): 57 | if hflip: img = img[:, ::-1, :] 58 | if vflip: img = img[::-1, :, :] 59 | if rot90: img = img.transpose(1, 0, 2) 60 | 61 | return img 62 | 63 | return [_augment(a) for a in args] 64 | 65 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | def default_flist_reader(flist): 15 | imlist = [] 16 | with open(flist, 'r') as rf: 17 | for line in rf.readlines(): 18 | impath = line.strip() 19 | imlist.append(impath) 20 | 21 | return imlist 22 | 23 | 24 | IMG_EXTENSIONS = [ 25 | '.jpg', '.JPG', '.jpeg', '.JPEG', 26 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.npy' 27 | ] 28 | 29 | 30 | def is_image_file(filename): 31 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 32 | 33 | 34 | def make_dataset(dir): 35 | images = [] 36 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 37 | 38 | for root, _, fnames in sorted(os.walk(dir)): 39 | for fname in fnames: 40 | if is_image_file(fname): 41 | path = os.path.join(root, fname) 42 | images.append(path) 43 | 44 | return images 45 | 46 | 47 | def default_loader(path): 48 | return Image.open(path).convert('RGB') 49 | 50 | 51 | class ImageFolder(data.Dataset): 52 | 53 | def __init__(self, root, transform=None, return_paths=False, 54 | loader=default_loader): 55 | imgs = make_dataset(root) 56 | if len(imgs) == 0: 57 | raise(RuntimeError("Found 0 images in: " + root + "\n" 58 | "Supported image extensions are: " + 59 | ",".join(IMG_EXTENSIONS))) 60 | 61 | self.root = root 62 | self.imgs = imgs 63 | self.transform = transform 64 | self.return_paths = return_paths 65 | self.loader = loader 66 | 67 | def __getitem__(self, index): 68 | path = self.imgs[index] 69 | img = self.loader(path) 70 | if self.transform is not None: 71 | img = self.transform(img) 72 | if self.return_paths: 73 | return img, path 74 | else: 75 | return img 76 | 77 | def __len__(self): 78 | return 
len(self.imgs) -------------------------------------------------------------------------------- /images/Pressure_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/Pressure_test.png -------------------------------------------------------------------------------- /images/acmmm19_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/acmmm19_poster.pdf -------------------------------------------------------------------------------- /images/adaptive_cropping.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/adaptive_cropping.png -------------------------------------------------------------------------------- /images/imdb_plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/imdb_plus.png -------------------------------------------------------------------------------- /images/imdn_rtc.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/imdn_rtc.jpg -------------------------------------------------------------------------------- /images/lenna.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/lenna.png -------------------------------------------------------------------------------- /images/memory.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/memory.png -------------------------------------------------------------------------------- /images/parameters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/parameters.png -------------------------------------------------------------------------------- /images/psnr_ssim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/psnr_ssim.png -------------------------------------------------------------------------------- /images/reparam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/reparam.png -------------------------------------------------------------------------------- /images/time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/images/time.png -------------------------------------------------------------------------------- /model/__pycache__/architecture.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/model/__pycache__/architecture.cpython-35.pyc -------------------------------------------------------------------------------- /model/__pycache__/block.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng222/IMDN/8f158e6a5ac9db6e5857d9159fd4a6c4214da574/model/__pycache__/block.cpython-35.pyc -------------------------------------------------------------------------------- /model/architecture.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from . import block as B 3 | import torch 4 | 5 | # For any upscale factors 6 | class IMDN_AS(nn.Module): 7 | def __init__(self, in_nc=3, nf=64, num_modules=6, out_nc=3, upscale=4): 8 | super(IMDN_AS, self).__init__() 9 | 10 | self.fea_conv = nn.Sequential(B.conv_layer(in_nc, nf, kernel_size=3, stride=2), 11 | nn.LeakyReLU(0.05), 12 | B.conv_layer(nf, nf, kernel_size=3, stride=2)) 13 | 14 | # IMDBs 15 | self.IMDB1 = B.IMDModule(in_channels=nf) 16 | self.IMDB2 = B.IMDModule(in_channels=nf) 17 | self.IMDB3 = B.IMDModule(in_channels=nf) 18 | self.IMDB4 = B.IMDModule(in_channels=nf) 19 | self.IMDB5 = B.IMDModule(in_channels=nf) 20 | self.IMDB6 = B.IMDModule(in_channels=nf) 21 | self.c = B.conv_block(nf * num_modules, nf, kernel_size=1, act_type='lrelu') 22 | 23 | self.LR_conv = B.conv_layer(nf, nf, kernel_size=3) 24 | 25 | upsample_block = B.pixelshuffle_block 26 | self.upsampler = upsample_block(nf, out_nc, upscale_factor=upscale) 27 | 28 | 29 | def forward(self, input): 30 | out_fea = self.fea_conv(input) 31 | out_B1 = self.IMDB1(out_fea) 32 | out_B2 = self.IMDB2(out_B1) 33 | out_B3 = self.IMDB3(out_B2) 34 | out_B4 = self.IMDB4(out_B3) 35 | out_B5 = self.IMDB5(out_B4) 36 | out_B6 = self.IMDB6(out_B5) 37 | 38 | out_B = self.c(torch.cat([out_B1, out_B2, out_B3, out_B4, out_B5, out_B6], dim=1)) 39 | out_lr = self.LR_conv(out_B) + out_fea 40 | output = self.upsampler(out_lr) 41 | return output 42 | 43 | class IMDN(nn.Module): 44 | def __init__(self, in_nc=3, nf=64, num_modules=6, out_nc=3, upscale=4): 45 | super(IMDN, self).__init__() 46 | 47 | self.fea_conv = B.conv_layer(in_nc, nf, kernel_size=3) 48 | 49 | # IMDBs 50 | self.IMDB1 = B.IMDModule(in_channels=nf) 51 | self.IMDB2 = B.IMDModule(in_channels=nf) 52 | self.IMDB3 = B.IMDModule(in_channels=nf) 53 | self.IMDB4 = B.IMDModule(in_channels=nf) 54 | self.IMDB5 = B.IMDModule(in_channels=nf) 55 | self.IMDB6 = B.IMDModule(in_channels=nf) 56 | self.c = B.conv_block(nf * num_modules, nf, kernel_size=1, act_type='lrelu') 57 | 58 | self.LR_conv = B.conv_layer(nf, nf, kernel_size=3) 59 | 60 | upsample_block = B.pixelshuffle_block 61 | self.upsampler = upsample_block(nf, out_nc, upscale_factor=upscale) 62 | 63 | 64 | def forward(self, input): 65 | out_fea = self.fea_conv(input) 66 | out_B1 = self.IMDB1(out_fea) 67 | out_B2 = self.IMDB2(out_B1) 68 | out_B3 = self.IMDB3(out_B2) 69 | out_B4 = self.IMDB4(out_B3) 70 | out_B5 = self.IMDB5(out_B4) 71 | out_B6 = self.IMDB6(out_B5) 72 | 73 | out_B = self.c(torch.cat([out_B1, out_B2, out_B3, out_B4, out_B5, out_B6], dim=1)) 74 | out_lr = self.LR_conv(out_B) + out_fea 75 | output = self.upsampler(out_lr) 76 | return output 77 | 78 | # AI in RTC Image Super-Resolution Algorithm Performance Comparison Challenge (Winner solution) 79 | class IMDN_RTC(nn.Module): 80 | def __init__(self, in_nc=3, nf=12, num_modules=5, out_nc=3, upscale=2): 81 | 
super(IMDN_RTC, self).__init__()
82 | 
83 |         fea_conv = [B.conv_layer(in_nc, nf, kernel_size=3)]
84 |         rb_blocks = [B.IMDModule_speed(in_channels=nf) for _ in range(num_modules)]
85 |         LR_conv = B.conv_layer(nf, nf, kernel_size=1)
86 | 
87 |         upsample_block = B.pixelshuffle_block
88 |         upsampler = upsample_block(nf, out_nc, upscale_factor=upscale)
89 | 
90 |         self.model = B.sequential(*fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),
91 |                                   *upsampler)
92 | 
93 |     def forward(self, input):
94 |         output = self.model(input)
95 |         return output
96 | 
97 | 
98 | class IMDN_RTE(nn.Module):
99 |     def __init__(self, upscale=2, in_nc=3, nf=20, out_nc=3):
100 |         super(IMDN_RTE, self).__init__()
101 |         self.upscale = upscale
102 |         self.fea_conv = nn.Sequential(B.conv_layer(in_nc, nf, 3),
103 |                                       nn.ReLU(inplace=True),
104 |                                       B.conv_layer(nf, nf, 3, stride=2, bias=False))
105 | 
106 |         self.block1 = B.IMDModule_Large(nf)
107 |         self.block2 = B.IMDModule_Large(nf)
108 |         self.block3 = B.IMDModule_Large(nf)
109 |         self.block4 = B.IMDModule_Large(nf)
110 |         self.block5 = B.IMDModule_Large(nf)
111 |         self.block6 = B.IMDModule_Large(nf)
112 | 
113 |         self.LR_conv = B.conv_layer(nf, nf, 1, bias=False)
114 | 
115 |         self.upsampler = B.pixelshuffle_block(nf, out_nc, upscale_factor=upscale**2)  # fea_conv downsamples by 2, so the needed total factor is 2 * upscale; upscale ** 2 matches this only for the default upscale=2
116 | 
117 |     def forward(self, input):
118 | 
119 |         fea = self.fea_conv(input)
120 |         out_b1 = self.block1(fea)
121 |         out_b2 = self.block2(out_b1)
122 |         out_b3 = self.block3(out_b2)
123 |         out_b4 = self.block4(out_b3)
124 |         out_b5 = self.block5(out_b4)
125 |         out_b6 = self.block6(out_b5)
126 | 
127 |         out_lr = self.LR_conv(out_b6) + fea
128 | 
129 |         output = self.upsampler(out_lr)
130 | 
131 |         return output
132 | 
133 | 
--------------------------------------------------------------------------------
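All three networks above share the same calling convention: a 4-D float tensor in, an upscaled 4-D tensor out. A minimal smoke test (illustrative, not part of the repository; it assumes PyTorch is installed and is run from the repository root, mirroring the `from model import architecture` import used by the test scripts):

    import torch
    from model import architecture

    # Dummy 48x48 RGB patch; IMDN with upscale=2 should return a 96x96 image.
    model = architecture.IMDN(upscale=2).eval()
    x = torch.randn(1, 3, 48, 48)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # torch.Size([1, 3, 96, 96])

Note the design difference: IMDN_AS downsamples by 4 internally (two stride-2 convolutions) and shuffles back up by a factor of 4, so its output has the same resolution as its input; arbitrary scales are handled by feeding it an input pre-resized to the target resolution.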
/model/block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from collections import OrderedDict
3 | import torch
4 | 
5 | 
6 | def conv_layer(in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
7 |     padding = int((kernel_size - 1) / 2) * dilation  # "same" padding for odd kernel sizes
8 |     return nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias, dilation=dilation,
9 |                      groups=groups)
10 | 
11 | 
12 | def norm(norm_type, nc):
13 |     norm_type = norm_type.lower()
14 |     if norm_type == 'batch':
15 |         layer = nn.BatchNorm2d(nc, affine=True)
16 |     elif norm_type == 'instance':
17 |         layer = nn.InstanceNorm2d(nc, affine=False)
18 |     else:
19 |         raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
20 |     return layer
21 | 
22 | 
23 | def pad(pad_type, padding):
24 |     pad_type = pad_type.lower()
25 |     if padding == 0:
26 |         return None
27 |     if pad_type == 'reflect':
28 |         layer = nn.ReflectionPad2d(padding)
29 |     elif pad_type == 'replicate':
30 |         layer = nn.ReplicationPad2d(padding)
31 |     else:
32 |         raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
33 |     return layer
34 | 
35 | 
36 | def get_valid_padding(kernel_size, dilation):
37 |     kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
38 |     padding = (kernel_size - 1) // 2
39 |     return padding
40 | 
41 | 
42 | def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
43 |                pad_type='zero', norm_type=None, act_type='relu'):
44 |     padding = get_valid_padding(kernel_size, dilation)
45 |     p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
46 |     padding = padding if pad_type == 'zero' else 0
47 | 
48 |     c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
49 |                   dilation=dilation, bias=bias, groups=groups)
50 |     a = activation(act_type) if act_type else None
51 |     n = norm(norm_type, out_nc) if norm_type else None
52 |     return sequential(p, c, n, a)
53 | 
54 | 
55 | def activation(act_type, inplace=True, neg_slope=0.05, n_prelu=1):
56 |     act_type = act_type.lower()
57 |     if act_type == 'relu':
58 |         layer = nn.ReLU(inplace)
59 |     elif act_type == 'lrelu':
60 |         layer = nn.LeakyReLU(neg_slope, inplace)
61 |     elif act_type == 'prelu':
62 |         layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
63 |     else:
64 |         raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
65 |     return layer
66 | 
67 | 
68 | class ShortcutBlock(nn.Module):
69 |     def __init__(self, submodule):
70 |         super(ShortcutBlock, self).__init__()
71 |         self.sub = submodule
72 | 
73 |     def forward(self, x):
74 |         output = x + self.sub(x)
75 |         return output
76 | 
77 | def mean_channels(F):
78 |     assert(F.dim() == 4)
79 |     spatial_sum = F.sum(3, keepdim=True).sum(2, keepdim=True)
80 |     return spatial_sum / (F.size(2) * F.size(3))
81 | 
82 | def stdv_channels(F):
83 |     assert(F.dim() == 4)
84 |     F_mean = mean_channels(F)
85 |     F_variance = (F - F_mean).pow(2).sum(3, keepdim=True).sum(2, keepdim=True) / (F.size(2) * F.size(3))
86 |     return F_variance.pow(0.5)  # per-channel spatial standard deviation (the "contrast" statistic)
87 | 
88 | def sequential(*args):
89 |     if len(args) == 1:
90 |         if isinstance(args[0], OrderedDict):
91 |             raise NotImplementedError('sequential does not support OrderedDict input.')
92 |         return args[0]
93 |     modules = []
94 |     for module in args:
95 |         if isinstance(module, nn.Sequential):
96 |             for submodule in module.children():
97 |                 modules.append(submodule)
98 |         elif isinstance(module, nn.Module):
99 |             modules.append(module)
100 |     return nn.Sequential(*modules)
101 | 
102 | # contrast-aware channel attention module
103 | class CCALayer(nn.Module):
104 |     def __init__(self, channel, reduction=16):
105 |         super(CCALayer, self).__init__()
106 | 
107 |         self.contrast = stdv_channels
108 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
109 |         self.conv_du = nn.Sequential(
110 |             nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
111 |             nn.ReLU(inplace=True),
112 |             nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
113 |             nn.Sigmoid()
114 |         )
115 | 
116 | 
117 |     def forward(self, x):
118 |         y = self.contrast(x) + self.avg_pool(x)  # channel descriptor = contrast (std) + mean statistics
119 |         y = self.conv_du(y)
120 |         return x * y
121 | 
122 | 
123 | class IMDModule(nn.Module):
124 |     def __init__(self, in_channels, distillation_rate=0.25):
125 |         super(IMDModule, self).__init__()
126 |         self.distilled_channels = int(in_channels * distillation_rate)
127 |         self.remaining_channels = int(in_channels - self.distilled_channels)
128 |         self.c1 = conv_layer(in_channels, in_channels, 3)
129 |         self.c2 = conv_layer(self.remaining_channels, in_channels, 3)
130 |         self.c3 = conv_layer(self.remaining_channels, in_channels, 3)
131 |         self.c4 = conv_layer(self.remaining_channels, self.distilled_channels, 3)
132 |         self.act = activation('lrelu', neg_slope=0.05)
133 |         self.c5 = conv_layer(in_channels, in_channels, 1)
134 |         self.cca = CCALayer(self.distilled_channels * 4)  # 4 * distilled == in_channels when distillation_rate == 0.25
135 | 
136 |     def forward(self, input):
137 |         # each stage splits its output into a "distilled" part that is kept and a
138 |         # "remaining" part that is refined further by the next convolution
139 |         out_c1 = self.act(self.c1(input))
140 |         distilled_c1, remaining_c1 = torch.split(out_c1, (self.distilled_channels, self.remaining_channels), dim=1)
141 |         out_c2 = self.act(self.c2(remaining_c1))
142 |         distilled_c2, remaining_c2 = torch.split(out_c2, (self.distilled_channels, self.remaining_channels), dim=1)
143 |         out_c3 = self.act(self.c3(remaining_c2))
144 |         distilled_c3, remaining_c3 = torch.split(out_c3, (self.distilled_channels, self.remaining_channels), dim=1)
145 |         out_c4 = self.c4(remaining_c3)
146 |         out = torch.cat([distilled_c1, distilled_c2, distilled_c3, out_c4], dim=1)
147 |         out_fused = self.c5(self.cca(out)) + input
148 |         return out_fused
149 | 
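# A minimal shape check for IMDModule (illustrative addition, not in the upstream file).
# With the default distillation_rate=0.25 and 64 input channels, each split keeps 16
# distilled channels and forwards 48; c4 emits the last 16, so the concatenation is
# 4 * 16 = 64 channels and the residual addition above is valid.
if __name__ == "__main__":
    _m = IMDModule(in_channels=64)
    _y = _m(torch.randn(1, 64, 32, 32))
    print(_y.shape)  # torch.Size([1, 64, 32, 32])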
150 | class IMDModule_speed(nn.Module):
151 |     def __init__(self, in_channels, distillation_rate=0.25):
152 |         super(IMDModule_speed, self).__init__()
153 |         self.distilled_channels = int(in_channels * distillation_rate)
154 |         self.remaining_channels = int(in_channels - self.distilled_channels)
155 |         self.c1 = conv_layer(in_channels, in_channels, 3)
156 |         self.c2 = conv_layer(self.remaining_channels, in_channels, 3)
157 |         self.c3 = conv_layer(self.remaining_channels, in_channels, 3)
158 |         self.c4 = conv_layer(self.remaining_channels, self.distilled_channels, 3)
159 |         self.act = activation('lrelu', neg_slope=0.05)
160 |         self.c5 = conv_layer(self.distilled_channels * 4, in_channels, 1)
161 | 
162 |     def forward(self, input):
163 |         out_c1 = self.act(self.c1(input))
164 |         distilled_c1, remaining_c1 = torch.split(out_c1, (self.distilled_channels, self.remaining_channels), dim=1)
165 |         out_c2 = self.act(self.c2(remaining_c1))
166 |         distilled_c2, remaining_c2 = torch.split(out_c2, (self.distilled_channels, self.remaining_channels), dim=1)
167 |         out_c3 = self.act(self.c3(remaining_c2))
168 |         distilled_c3, remaining_c3 = torch.split(out_c3, (self.distilled_channels, self.remaining_channels), dim=1)
169 |         out_c4 = self.c4(remaining_c3)
170 | 
171 |         out = torch.cat([distilled_c1, distilled_c2, distilled_c3, out_c4], dim=1)
172 |         out_fused = self.c5(out) + input
173 |         return out_fused
174 | 
175 | class IMDModule_Large(nn.Module):
176 |     def __init__(self, in_channels, distillation_rate=1/4):
177 |         super(IMDModule_Large, self).__init__()
178 |         # channel comments assume the default in_channels=20 used by IMDN_RTE
179 |         self.distilled_channels = int(in_channels * distillation_rate)  # 5
180 |         self.remaining_channels = int(in_channels - self.distilled_channels)  # 15
181 |         self.c1 = conv_layer(in_channels, in_channels, 3, bias=False)  # 20 --> 20
182 |         self.c2 = conv_layer(self.remaining_channels, in_channels, 3, bias=False)  # 15 --> 20
183 |         self.c3 = conv_layer(self.remaining_channels, in_channels, 3, bias=False)  # 15 --> 20
184 |         self.c4 = conv_layer(self.remaining_channels, self.remaining_channels, 3, bias=False)  # 15 --> 15
185 |         self.c5 = conv_layer(self.remaining_channels-self.distilled_channels, self.remaining_channels-self.distilled_channels, 3, bias=False)  # 10 --> 10
186 |         self.c6 = conv_layer(self.distilled_channels, self.distilled_channels, 3, bias=False)  # 5 --> 5
187 |         self.act = activation('relu')
188 |         self.c7 = conv_layer(self.distilled_channels * 6, in_channels, 1, bias=False)  # 30 --> 20
189 | 
190 |     def forward(self, input):
191 |         out_c1 = self.act(self.c1(input))  # 20 --> 20
192 |         distilled_c1, remaining_c1 = torch.split(out_c1, (self.distilled_channels, self.remaining_channels), dim=1)  # 5, 15
193 |         out_c2 = self.act(self.c2(remaining_c1))  # 15 --> 20
194 |         distilled_c2, remaining_c2 = torch.split(out_c2, (self.distilled_channels, self.remaining_channels), dim=1)  # 5, 15
195 |         out_c3 = self.act(self.c3(remaining_c2))  # 15 --> 20
196 |         distilled_c3, remaining_c3 = torch.split(out_c3, (self.distilled_channels, self.remaining_channels), dim=1)  # 5, 15
197 |         out_c4 = self.act(self.c4(remaining_c3))  # 15 --> 15
198 |         distilled_c4, remaining_c4 = torch.split(out_c4, (self.distilled_channels, self.remaining_channels-self.distilled_channels), dim=1)  # 5, 10
199 |         out_c5 = self.act(self.c5(remaining_c4))  # 10 --> 10
200 |         distilled_c5, remaining_c5 = torch.split(out_c5, (self.distilled_channels, self.remaining_channels-self.distilled_channels*2), dim=1)  # 5, 5
201 |         out_c6 = self.act(self.c6(remaining_c5))  # 5 --> 5
202 | 
203 |         out = torch.cat([distilled_c1, distilled_c2, distilled_c3, distilled_c4, distilled_c5, out_c6], dim=1)
204 |         out_fused = self.c7(out) + input
205 |         return out_fused
206 | 
207 | def pixelshuffle_block(in_channels, out_channels, upscale_factor=2, kernel_size=3, stride=1):
208 |     conv = conv_layer(in_channels, out_channels * (upscale_factor ** 2), kernel_size, stride)
209 |     pixel_shuffle = nn.PixelShuffle(upscale_factor)
210 |     return sequential(conv, pixel_shuffle)
--------------------------------------------------------------------------------
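pixelshuffle_block trades channels for resolution: the convolution expands to out_channels * r^2 feature maps and nn.PixelShuffle rearranges them into an r-times-larger image. A small illustration (hypothetical snippet, not part of the repository):

    import torch
    from model.block import pixelshuffle_block

    up = pixelshuffle_block(in_channels=12, out_channels=3, upscale_factor=2)
    x = torch.randn(1, 12, 30, 40)
    # conv: 12 -> 3 * 2**2 = 12 channels, then PixelShuffle(2) rearranges
    # (1, 12, 30, 40) into (1, 3, 60, 80)
    print(up(x).shape)  # torch.Size([1, 3, 60, 80])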
/scripts/png2npy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import skimage.io as sio
4 | import numpy as np
5 | 
6 | parser = argparse.ArgumentParser(description='Pre-processing .png images')
7 | parser.add_argument('--pathFrom', default='',
8 |                     help='directory of images to convert')
9 | parser.add_argument('--pathTo', default='',
10 |                     help='directory of images to save')
11 | parser.add_argument('--split', default=True,  # note: any non-empty string passed on the CLI is truthy, so this effectively always stays on
12 |                     help='save individual images')
13 | parser.add_argument('--select', default='',
14 |                     help='select certain path')
15 | 
16 | args = parser.parse_args()
17 | 
18 | for (path, dirs, files) in os.walk(args.pathFrom):
19 |     print(path)
20 |     targetDir = os.path.join(args.pathTo, path[len(args.pathFrom) + 1:])
21 |     if len(args.select) > 0 and path.find(args.select) == -1:
22 |         continue
23 | 
24 |     if not os.path.exists(targetDir):
25 |         os.mkdir(targetDir)
26 | 
27 |     if len(dirs) == 0:
28 |         pack = {}
29 |         n = 0
30 |         for fileName in files:
31 |             (idx, ext) = os.path.splitext(fileName)
32 |             if ext == '.png':
33 |                 image = sio.imread(os.path.join(path, fileName))
34 |                 if args.split:
35 |                     np.save(os.path.join(targetDir, idx + '.npy'), image)
36 |                 n += 1
37 |                 if n % 100 == 0:
38 |                     print('Converted ' + str(n) + ' images.')
--------------------------------------------------------------------------------
/test_IMDN.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import numpy as np
5 | import utils
6 | import skimage.color as sc
7 | import cv2
8 | from model import architecture
9 | # Testing settings
10 | 
11 | parser = argparse.ArgumentParser(description='IMDN')
12 | parser.add_argument("--test_hr_folder", type=str, default='Test_Datasets/Set5/',
13 |                     help='the folder of the target images')
14 | parser.add_argument("--test_lr_folder", type=str, default='Test_Datasets/Set5_LR/x2/',
15 |                     help='the folder of the input images')
16 | parser.add_argument("--output_folder", type=str, default='results/Set5/x2')
17 | parser.add_argument("--checkpoint", type=str, default='checkpoints/IMDN_x2.pth',
18 |                     help='checkpoint file to use')
19 | parser.add_argument('--cuda', action='store_true', default=True,  # note: with default=True this flag is a no-op; set default=False to make it opt-in
20 |                     help='use cuda')
21 | parser.add_argument("--upscale_factor", type=int, default=2,
22 |                     help='upscaling factor')
23 | parser.add_argument("--is_y", action='store_true', default=True,
24 |                     help='evaluate on y channel, if False evaluate on RGB channels')
25 | opt = parser.parse_args()
26 | 
27 | print(opt)
28 | 
29 | cuda = opt.cuda
30 | device = torch.device('cuda' if cuda else 'cpu')
31 | 
32 | filepath = 
opt.test_hr_folder 33 | if filepath.split('/')[-2] == 'Set5' or filepath.split('/')[-2] == 'Set14': 34 | ext = '.bmp' 35 | else: 36 | ext = '.png' 37 | 38 | filelist = utils.get_list(filepath, ext=ext) 39 | psnr_list = np.zeros(len(filelist)) 40 | ssim_list = np.zeros(len(filelist)) 41 | time_list = np.zeros(len(filelist)) 42 | 43 | model = architecture.IMDN(upscale=opt.upscale_factor) 44 | model_dict = utils.load_state_dict(opt.checkpoint) 45 | model.load_state_dict(model_dict, strict=True) 46 | 47 | i = 0 48 | start = torch.cuda.Event(enable_timing=True) 49 | end = torch.cuda.Event(enable_timing=True) 50 | 51 | for imname in filelist: 52 | im_gt = cv2.imread(imname, cv2.IMREAD_COLOR)[:, :, [2, 1, 0]] # BGR to RGB 53 | im_gt = utils.modcrop(im_gt, opt.upscale_factor) 54 | im_l = cv2.imread(opt.test_lr_folder + imname.split('/')[-1].split('.')[0] + 'x' + str(opt.upscale_factor) + ext, cv2.IMREAD_COLOR)[:, :, [2, 1, 0]] # BGR to RGB 55 | if len(im_gt.shape) < 3: 56 | im_gt = im_gt[..., np.newaxis] 57 | im_gt = np.concatenate([im_gt] * 3, 2) 58 | im_l = im_l[..., np.newaxis] 59 | im_l = np.concatenate([im_l] * 3, 2) 60 | im_input = im_l / 255.0 61 | im_input = np.transpose(im_input, (2, 0, 1)) 62 | im_input = im_input[np.newaxis, ...] 63 | im_input = torch.from_numpy(im_input).float() 64 | 65 | if cuda: 66 | model = model.to(device) 67 | im_input = im_input.to(device) 68 | 69 | with torch.no_grad(): 70 | start.record() 71 | out = model(im_input) 72 | end.record() 73 | torch.cuda.synchronize() 74 | time_list[i] = start.elapsed_time(end) # milliseconds 75 | 76 | out_img = utils.tensor2np(out.detach()[0]) 77 | crop_size = opt.upscale_factor 78 | cropped_sr_img = utils.shave(out_img, crop_size) 79 | cropped_gt_img = utils.shave(im_gt, crop_size) 80 | if opt.is_y is True: 81 | im_label = utils.quantize(sc.rgb2ycbcr(cropped_gt_img)[:, :, 0]) 82 | im_pre = utils.quantize(sc.rgb2ycbcr(cropped_sr_img)[:, :, 0]) 83 | else: 84 | im_label = cropped_gt_img 85 | im_pre = cropped_sr_img 86 | psnr_list[i] = utils.compute_psnr(im_pre, im_label) 87 | ssim_list[i] = utils.compute_ssim(im_pre, im_label) 88 | 89 | 90 | output_folder = os.path.join(opt.output_folder, 91 | imname.split('/')[-1].split('.')[0] + 'x' + str(opt.upscale_factor) + '.png') 92 | 93 | if not os.path.exists(opt.output_folder): 94 | os.makedirs(opt.output_folder) 95 | 96 | cv2.imwrite(output_folder, out_img[:, :, [2, 1, 0]]) 97 | i += 1 98 | 99 | 100 | print("Mean PSNR: {}, SSIM: {}, TIME: {} ms".format(np.mean(psnr_list), np.mean(ssim_list), np.mean(time_list))) 101 | -------------------------------------------------------------------------------- /test_IMDN_AS.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import numpy as np 5 | import utils 6 | import skimage.color as sc 7 | import cv2 8 | from model import architecture 9 | 10 | parser = argparse.ArgumentParser(description='IMDN_AS') 11 | parser.add_argument("--test_hr_folder", type=str, default='Test_Datasets/RealSR/ValidationGT', 12 | help='the folder of the target images') 13 | parser.add_argument("--test_lr_folder", type=str, default='Test_Datasets/RealSR/ValidationLR/', 14 | help='the folder of the input images') 15 | parser.add_argument("--output_folder", type=str, default='results/RealSR') 16 | parser.add_argument("--checkpoint", type=str, default='checkpoints/IMDN_AS.pth', 17 | help='checkpoint folder to use') 18 | parser.add_argument('--cuda', action='store_true', default=True, 19 | 
help='use cuda') 20 | parser.add_argument("--is_y", action='store_true', default=False, 21 | help='evaluate on y channel, if False evaluate on RGB channels') 22 | opt = parser.parse_args() 23 | 24 | print(opt) 25 | 26 | cuda = opt.cuda 27 | device = torch.device('cuda' if cuda else 'cpu') 28 | 29 | 30 | def crop_forward(x, model, shave=32): 31 | b, c, h, w = x.size() 32 | h_half, w_half = h // 2, w // 2 33 | 34 | h_size, w_size = h_half + shave - (h_half + shave) % 4, w_half + shave - (w_half + shave) % 4 35 | 36 | inputlist = [ 37 | x[:, :, 0:h_size, 0:w_size], 38 | x[:, :, 0:h_size, (w - w_size):w], 39 | x[:, :, (h - h_size):h, 0:w_size], 40 | x[:, :, (h - h_size):h, (w - w_size):w]] 41 | 42 | outputlist = [] 43 | 44 | with torch.no_grad(): 45 | input_batch = torch.cat(inputlist, dim=0) 46 | output_batch = model(input_batch) 47 | outputlist.extend(output_batch.chunk(4, dim=0)) 48 | 49 | output = torch.zeros_like(x) 50 | 51 | output[:, :, 0:h_half, 0:w_half] \ 52 | = outputlist[0][:, :, 0:h_half, 0:w_half] 53 | output[:, :, 0:h_half, w_half:w] \ 54 | = outputlist[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 55 | output[:, :, h_half:h, 0:w_half] \ 56 | = outputlist[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 57 | output[:, :, h_half:h, w_half:w] \ 58 | = outputlist[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 59 | 60 | return output 61 | 62 | 63 | filepath = opt.test_hr_folder 64 | if filepath.split('/')[-2] == 'Set5' or filepath.split('/')[-2] == 'Set14': 65 | ext = '.bmp' 66 | else: 67 | ext = '.png' 68 | 69 | filelist = utils.get_list(filepath, ext=ext) 70 | psnr_list = np.zeros(len(filelist)) 71 | ssim_list = np.zeros(len(filelist)) 72 | time_list = np.zeros(len(filelist)) 73 | 74 | model = architecture.IMDN_AS() 75 | model_dict = utils.load_state_dict(opt.checkpoint) 76 | model.load_state_dict(model_dict, strict=True) 77 | 78 | start = torch.cuda.Event(enable_timing=True) 79 | end = torch.cuda.Event(enable_timing=True) 80 | i = 0 81 | for imname in filelist: 82 | im_gt = cv2.imread(imname)[:, :, [2, 1, 0]] 83 | im_l = cv2.imread(opt.test_lr_folder + imname.split('/')[-1])[:, :, [2, 1, 0]] 84 | if len(im_gt.shape) < 3: 85 | im_gt = im_gt[..., np.newaxis] 86 | im_gt = np.concatenate([im_gt] * 3, 2) 87 | im_l = im_l[..., np.newaxis] 88 | im_l = np.concatenate([im_l] * 3, 2) 89 | im_input = im_l / 255.0 90 | im_input = np.transpose(im_input, (2, 0, 1)) 91 | im_input = im_input[np.newaxis, ...] 
92 |     im_input = torch.from_numpy(im_input).float()
93 | 
94 |     if cuda:
95 |         model = model.to(device)
96 |         im_input = im_input.to(device)
97 | 
98 |     _, _, h, w = im_input.size()
99 |     with torch.no_grad():
100 | 
101 |         if h % 4 == 0 and w % 4 == 0:
102 |             start.record()
103 |             out = model(im_input)
104 |             end.record()
105 |             torch.cuda.synchronize()
106 |             time_list[i] = start.elapsed_time(end)  # milliseconds
107 |         else:
108 |             start.record()
109 |             out = crop_forward(im_input, model)
110 |             end.record()
111 |             torch.cuda.synchronize()
112 |             time_list[i] = start.elapsed_time(end)  # milliseconds
113 | 
114 |     sr_img = utils.tensor2np(out.detach()[0])
115 |     if opt.is_y is True:
116 |         im_label = utils.quantize(sc.rgb2ycbcr(im_gt)[:, :, 0])
117 |         im_pre = utils.quantize(sc.rgb2ycbcr(sr_img)[:, :, 0])
118 |     else:
119 |         im_label = im_gt
120 |         im_pre = sr_img
121 |     psnr_list[i] = utils.compute_psnr(im_pre, im_label)
122 |     ssim_list[i] = utils.compute_ssim(im_pre, im_label)
123 | 
124 |     output_folder = os.path.join(opt.output_folder,
125 |                                  imname.split('/')[-1])
126 | 
127 |     if not os.path.exists(opt.output_folder):
128 |         os.makedirs(opt.output_folder)
129 | 
130 |     cv2.imwrite(output_folder, sr_img[:, :, [2, 1, 0]])
131 |     i += 1
132 | 
133 | print("Mean PSNR: {}, SSIM: {}, Time: {} ms".format(np.mean(psnr_list), np.mean(ssim_list), np.mean(time_list)))
134 | 
--------------------------------------------------------------------------------
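Typical invocations for the two evaluation scripts, based on the argument defaults in the parsers above (paths are the repository defaults; adjust to your layout):

    # Fixed-scale model on Set5, x2:
    python test_IMDN.py --test_hr_folder Test_Datasets/Set5/ --test_lr_folder Test_Datasets/Set5_LR/x2/ \
                        --checkpoint checkpoints/IMDN_x2.pth --upscale_factor 2
    # Any-scale model on RealSR (falls back to crop_forward when H or W is not divisible by 4):
    python test_IMDN_AS.py --test_hr_folder Test_Datasets/RealSR/ValidationGT \
                           --test_lr_folder Test_Datasets/RealSR/ValidationLR/ --checkpoint checkpoints/IMDN_AS.pth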
Default=2e-4") 24 | parser.add_argument("--step_size", type=int, default=200, 25 | help="learning rate decay per N epochs") 26 | parser.add_argument("--gamma", type=int, default=0.5, 27 | help="learning rate decay factor for step decay") 28 | parser.add_argument("--cuda", action="store_true", default=True, 29 | help="use cuda") 30 | parser.add_argument("--resume", default="", type=str, 31 | help="path to checkpoint") 32 | parser.add_argument("--start-epoch", default=1, type=int, 33 | help="manual epoch number") 34 | parser.add_argument("--threads", type=int, default=8, 35 | help="number of threads for data loading") 36 | parser.add_argument("--root", type=str, default="training_data/", 37 | help='dataset directory') 38 | parser.add_argument("--n_train", type=int, default=800, 39 | help="number of training set") 40 | parser.add_argument("--n_val", type=int, default=1, 41 | help="number of validation set") 42 | parser.add_argument("--test_every", type=int, default=1000) 43 | parser.add_argument("--scale", type=int, default=2, 44 | help="super-resolution scale") 45 | parser.add_argument("--patch_size", type=int, default=192, 46 | help="output patch size") 47 | parser.add_argument("--rgb_range", type=int, default=1, 48 | help="maxium value of RGB") 49 | parser.add_argument("--n_colors", type=int, default=3, 50 | help="number of color channels to use") 51 | parser.add_argument("--pretrained", default="", type=str, 52 | help="path to pretrained models") 53 | parser.add_argument("--seed", type=int, default=1) 54 | parser.add_argument("--isY", action="store_true", default=True) 55 | parser.add_argument("--ext", type=str, default='.npy') 56 | parser.add_argument("--phase", type=str, default='train') 57 | 58 | args = parser.parse_args() 59 | print(args) 60 | torch.backends.cudnn.benchmark = True 61 | # random seed 62 | seed = args.seed 63 | if seed is None: 64 | seed = random.randint(1, 10000) 65 | print("Ramdom Seed: ", seed) 66 | random.seed(seed) 67 | torch.manual_seed(seed) 68 | 69 | cuda = args.cuda 70 | device = torch.device('cuda' if cuda else 'cpu') 71 | 72 | print("===> Loading datasets") 73 | 74 | trainset = DIV2K.div2k(args) 75 | testset = Set5_val.DatasetFromFolderVal("Test_Datasets/Set5/", 76 | "Test_Datasets/Set5_LR/x{}/".format(args.scale), 77 | args.scale) 78 | training_data_loader = DataLoader(dataset=trainset, num_workers=args.threads, batch_size=args.batch_size, shuffle=True, pin_memory=True, drop_last=True) 79 | testing_data_loader = DataLoader(dataset=testset, num_workers=args.threads, batch_size=args.testBatchSize, 80 | shuffle=False) 81 | 82 | print("===> Building models") 83 | args.is_train = True 84 | 85 | model = architecture.IMDN(upscale=args.scale) 86 | l1_criterion = nn.L1Loss() 87 | 88 | print("===> Setting GPU") 89 | if cuda: 90 | model = model.to(device) 91 | l1_criterion = l1_criterion.to(device) 92 | 93 | if args.pretrained: 94 | 95 | if os.path.isfile(args.pretrained): 96 | print("===> loading models '{}'".format(args.pretrained)) 97 | checkpoint = torch.load(args.pretrained) 98 | new_state_dcit = OrderedDict() 99 | for k, v in checkpoint.items(): 100 | if 'module' in k: 101 | name = k[7:] 102 | else: 103 | name = k 104 | new_state_dcit[name] = v 105 | model_dict = model.state_dict() 106 | pretrained_dict = {k: v for k, v in new_state_dcit.items() if k in model_dict} 107 | 108 | for k, v in model_dict.items(): 109 | if k not in pretrained_dict: 110 | print(k) 111 | model.load_state_dict(pretrained_dict, strict=True) 112 | 113 | else: 114 | print("===> no models 
found at '{}'".format(args.pretrained)) 115 | 116 | print("===> Setting Optimizer") 117 | 118 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 119 | 120 | 121 | def train(epoch): 122 | model.train() 123 | utils.adjust_learning_rate(optimizer, epoch, args.step_size, args.lr, args.gamma) 124 | print('epoch =', epoch, 'lr = ', optimizer.param_groups[0]['lr']) 125 | for iteration, (lr_tensor, hr_tensor) in enumerate(training_data_loader, 1): 126 | 127 | if args.cuda: 128 | lr_tensor = lr_tensor.to(device) # ranges from [0, 1] 129 | hr_tensor = hr_tensor.to(device) # ranges from [0, 1] 130 | 131 | optimizer.zero_grad() 132 | sr_tensor = model(lr_tensor) 133 | loss_l1 = l1_criterion(sr_tensor, hr_tensor) 134 | loss_sr = loss_l1 135 | 136 | loss_sr.backward() 137 | optimizer.step() 138 | if iteration % 100 == 0: 139 | print("===> Epoch[{}]({}/{}): Loss_l1: {:.5f}".format(epoch, iteration, len(training_data_loader), 140 | loss_l1.item())) 141 | 142 | 143 | def valid(): 144 | model.eval() 145 | 146 | avg_psnr, avg_ssim = 0, 0 147 | for batch in testing_data_loader: 148 | lr_tensor, hr_tensor = batch[0], batch[1] 149 | if args.cuda: 150 | lr_tensor = lr_tensor.to(device) 151 | hr_tensor = hr_tensor.to(device) 152 | 153 | with torch.no_grad(): 154 | pre = model(lr_tensor) 155 | 156 | sr_img = utils.tensor2np(pre.detach()[0]) 157 | gt_img = utils.tensor2np(hr_tensor.detach()[0]) 158 | crop_size = args.scale 159 | cropped_sr_img = utils.shave(sr_img, crop_size) 160 | cropped_gt_img = utils.shave(gt_img, crop_size) 161 | if args.isY is True: 162 | im_label = utils.quantize(sc.rgb2ycbcr(cropped_gt_img)[:, :, 0]) 163 | im_pre = utils.quantize(sc.rgb2ycbcr(cropped_sr_img)[:, :, 0]) 164 | else: 165 | im_label = cropped_gt_img 166 | im_pre = cropped_sr_img 167 | avg_psnr += utils.compute_psnr(im_pre, im_label) 168 | avg_ssim += utils.compute_ssim(im_pre, im_label) 169 | print("===> Valid. 
psnr: {:.4f}, ssim: {:.4f}".format(avg_psnr / len(testing_data_loader), avg_ssim / len(testing_data_loader)))
170 | 
171 | 
172 | def save_checkpoint(epoch):
173 |     model_folder = "checkpoint_x{}/".format(args.scale)
174 |     model_out_path = model_folder + "epoch_{}.pth".format(epoch)
175 |     if not os.path.exists(model_folder):
176 |         os.makedirs(model_folder)
177 |     torch.save(model.state_dict(), model_out_path)
178 |     print("===> Checkpoint saved to {}".format(model_out_path))
179 | 
180 | def print_network(net):
181 |     num_params = 0
182 |     for param in net.parameters():
183 |         num_params += param.numel()
184 |     print(net)
185 |     print('Total number of parameters: %d' % num_params)
186 | 
187 | 
188 | print("===> Training")
189 | print_network(model)
190 | for epoch in range(args.start_epoch, args.nEpochs + 1):
191 |     valid()  # note: runs before training, so epoch 1 reports the untrained model
192 |     train(epoch)
193 |     save_checkpoint(epoch)
194 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from skimage.measure import compare_psnr as psnr  # requires scikit-image < 0.18; newer releases moved this to skimage.metrics.peak_signal_noise_ratio
2 | from skimage.measure import compare_ssim as ssim  # likewise skimage.metrics.structural_similarity in newer releases
3 | import numpy as np
4 | import os
5 | import torch
6 | from collections import OrderedDict
7 | 
8 | def compute_psnr(im1, im2):
9 |     p = psnr(im1, im2)
10 |     return p
11 | 
12 | 
13 | def compute_ssim(im1, im2):
14 |     isRGB = len(im1.shape) == 3 and im1.shape[-1] == 3
15 |     s = ssim(im1, im2, K1=0.01, K2=0.03, gaussian_weights=True, sigma=1.5, use_sample_covariance=False,
16 |              multichannel=isRGB)
17 |     return s
18 | 
19 | 
20 | def shave(im, border):
21 |     border = [border, border]
22 |     im = im[border[0]:-border[0], border[1]:-border[1], ...]
23 |     return im
24 | 
25 | 
26 | def modcrop(im, modulo):
27 |     sz = im.shape
28 |     h = np.int32(sz[0] / modulo) * modulo
29 |     w = np.int32(sz[1] / modulo) * modulo
30 |     ims = im[0:h, 0:w, ...]
31 |     return ims
32 | 
33 | 
34 | def get_list(path, ext):
35 |     return [os.path.join(path, f) for f in os.listdir(path) if f.endswith(ext)]
36 | 
37 | 
38 | def convert_shape(img):
39 |     img = np.transpose((img * 255.0).round(), (1, 2, 0))
40 |     img = np.uint8(np.clip(img, 0, 255))
41 |     return img
42 | 
43 | 
44 | def quantize(img):
45 |     return img.clip(0, 255).round().astype(np.uint8)
46 | 
47 | 
48 | def tensor2np(tensor, out_type=np.uint8, min_max=(0, 1)):
49 |     tensor = tensor.float().cpu().clamp_(*min_max)
50 |     tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0, 1]
51 |     img_np = tensor.numpy()
52 |     img_np = np.transpose(img_np, (1, 2, 0))
53 |     if out_type == np.uint8:
54 |         img_np = (img_np * 255.0).round()
55 | 
56 |     return img_np.astype(out_type)
57 | 
58 | def convert2np(tensor):
59 |     return tensor.cpu().mul(255).clamp(0, 255).byte().squeeze().permute(1, 2, 0).numpy()
60 | 
61 | 
62 | def adjust_learning_rate(optimizer, epoch, step_size, lr_init, gamma):
63 |     factor = epoch // step_size
64 |     lr = lr_init * (gamma ** factor)
65 |     for param_group in optimizer.param_groups:
66 |         param_group['lr'] = lr
67 | 
68 | def load_state_dict(path):
69 | 
70 |     state_dict = torch.load(path)
71 |     new_state_dict = OrderedDict()
72 |     for k, v in state_dict.items():
73 |         if 'module' in k:
74 |             name = k[7:]  # strip the 'module.' prefix added by nn.DataParallel
75 |         else:
76 |             name = k
77 |         new_state_dict[name] = v
78 |     return new_state_dict
79 | 
--------------------------------------------------------------------------------
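adjust_learning_rate implements plain step decay: with the training defaults (lr=2e-4, step_size=200, gamma=0.5) the learning rate halves every 200 epochs. A quick check of the schedule (hypothetical snippet, not part of the repository):

    # Step decay used by train_IMDN.py: lr = lr_init * gamma ** (epoch // step_size)
    lr_init, step_size, gamma = 2e-4, 200, 0.5
    for epoch in (1, 199, 200, 400, 999):
        lr = lr_init * (gamma ** (epoch // step_size))
        print(epoch, lr)
    # 1 0.0002, 199 0.0002, 200 0.0001, 400 5e-05, 999 1.25e-05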