├── .gitignore ├── README.md ├── calc_psnr_ssim.m ├── data ├── __init__.py ├── base_dataset.py ├── benchmark_dataset.py ├── div2k_dataset.py ├── imlib.py └── sr_dataset.py ├── figs ├── architecture.png ├── results.png └── visualization.png ├── flops.py ├── masked_conv2d ├── masked_conv2d │ ├── __init__.py │ ├── masked_conv.py │ └── src │ │ ├── masked_conv2d_cuda.cpp │ │ └── masked_conv2d_kernel.cu └── setup.py ├── models ├── MPNCOV │ ├── __init__.py │ └── python │ │ ├── MPNCOV.py │ │ └── __init__.py ├── __init__.py ├── adaedsr_fixd_model.py ├── adaedsr_model.py ├── adarcan_model.py ├── base_model.py ├── common.py ├── dsr_model.py ├── edsr_model.py ├── losses.py ├── networks.py ├── non_local │ ├── network.py │ ├── non_local.py │ ├── non_local_simple_version.py │ └── utils.py ├── rcan_model.py ├── rdn_model.py ├── san_model.py ├── srcnn_model.py ├── srresnet_model.py └── vdsr_model.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── scripts ├── test_adaedsr.sh ├── test_adaedsr_fixd.sh ├── test_adarcan.sh ├── test_edsr.sh ├── test_rcan.sh ├── test_rdn.sh ├── test_san.sh ├── test_srcnn.sh ├── test_vdsr.sh ├── train_adaedsr.sh ├── train_adaedsr_fixd.sh ├── train_adarcan.sh ├── train_edsr.sh ├── train_rcan.sh ├── train_rdn.sh ├── train_san.sh ├── train_srcnn.sh └── train_vdsr.sh ├── test.py ├── train.py └── util ├── __init__.py ├── util.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | */__pycache__ 4 | */*/__pycache__ 5 | */*/*/__pycache__ 6 | .vscode 7 | checkpoints 8 | ckpt 9 | pretrained 10 | *.out 11 | *.err 12 | *.log 13 | log* 14 | tmp -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaDSR (ECCV 2020 AIM workshop) 2 | 3 | **PyTorch** implementation of [Deep Adaptive Inference Networks for Single 
Image Super-Resolution](https://arxiv.org/abs/2004.03915) 4 | 5 |

6 |

Overall structure of our AdaDSR.

7 | 8 | ## Results 9 | 10 |

11 |

SISR results on Set5. For more results, please refer to the paper.

12 | 13 |

14 |

An example visualization of the SR results and the corresponding depth map.

15 | 16 | ## Preparation 17 | 18 | - **Prerequisites** 19 | - PyTorch (v1.2) 20 | - Python 3.x with OpenCV, NumPy, Pillow, tqdm and matplotlib; tensorboardX is used for visualization 21 | - [optional] If you want to calculate the PSNR/SSIM indices with the argument `--matlab True`, make sure that MATLAB is in your PATH 22 | - [Sparse Conv] `cd masked_conv2d; python setup.py install`. Note that we currently provide a version modified from [open-mmlab/mmdetection](https://github.com/open-mmlab/mmdetection/tree/master/mmdet/ops/masked_conv), which supports inference with 3x3 sparse convolution layers only. We will provide a more general version in the future. 23 | - **Dataset** 24 | - Training 25 | - [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) is used for training; you can download the dataset from [ETH_CVL](https://data.vision.ee.ethz.ch/cvl/DIV2K/) or [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar). 26 | - The data should be organized as `DIV2K_ROOT/DIV2K_train_HR/*.png` and `DIV2K_ROOT/DIV2K_train_LR_bicubic/X[234]/*.png`, which is identical to the official format. 27 | - Testing 28 | - [Set5](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html), [Set14](https://sites.google.com/site/romanzeyde/research-interests), [B100](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), [Urban100](https://sites.google.com/site/jbhuang0604/publications/struct_sr) and [Manga109](http://www.manga109.org/en/index.html) are used for testing. 29 | - You can download these datasets (~300MB) from [Google Drive](https://drive.google.com/open?id=1oacPCU5VPPy5swx8X8YU8ax7aVyCpPKJ) or [Baidu Yun](https://pan.baidu.com/s/1go1Y1reQUk68FX_n7ORMVQ) (43qz), then run `Prepare_TestData_HR_LR.m` in MATLAB to generate two folders named `HR` and `LR`. Place these two folders in `BENCHMARK_ROOT`. 
30 | - Models 31 | - Download the pre-trained models (~2.2GB) from [Google Drive](https://drive.google.com/open?id=1LmrkG5w0-JbP6t5413KOPpNlVG4hK3lZ) or [Baidu Yun](https://pan.baidu.com/s/1dcmO9Pc74Ta5p9liGwnDwA) (cyps), and put the two folders in the root folder of this repository. 32 | 33 | ## Quick Start 34 | 35 | Some example commands are shown below; more scripts are provided in the [scripts](./scripts) folder. 36 | 37 | ### Testing 38 | 39 | - AdaEDSR 40 | 41 | ```console 42 | python test.py --model adaedsr --name adaedsr_x2 --scale 2 --load_path ./ckpt/adaedsr_x2/AdaEDSR_model.pth --dataset_name set5 set14 b100 urban100 manga109 --depth 32 --chop True --sparse_conv True --matlab True --gpu_ids 0 43 | ``` 44 | 45 | - AdaRCAN 46 | 47 | ```console 48 | python test.py --model adarcan --name adarcan_x2 --scale 2 --load_path ./ckpt/adarcan_x2/AdaRCAN_model.pth --dataset_name set5 set14 b100 urban100 manga109 --depth 20 --chop True --sparse_conv True --matlab True --gpu_ids 0 49 | ``` 50 | 51 | ### Training 52 | 53 | - AdaEDSR (loads a pre-trained EDSR model for more stable training) 54 | 55 | ```console 56 | python train.py --model adaedsr --name adaedsr_x2 --scale 2 --load_path ./pretrained/EDSR_official_32_x2.pth 57 | ``` 58 | 59 | - AdaRCAN (loads a pre-trained RCAN model for more stable training) 60 | 61 | ```console 62 | python train.py --model adarcan --name adarcan_x2 --scale 2 --load_path ./pretrained/RCAN_BIX2.pth 63 | ``` 64 | 65 | ### Note 66 | 67 | - Set the data root with `--dataroot DIV2K_ROOT` (train) or `--dataroot BENCHMARK_ROOT` (test), or add your own path to the rootlist in [div2k_dataset](./data/div2k_dataset.py#L11-L12) or [benchmark_dataset](./data/benchmark_dataset.py#L15-L16). 68 | - You can specify which GPUs to use with `--gpu_ids`, e.g., `--gpu_ids 0,1`, `--gpu_ids 3`, or `--gpu_ids -1` (CPU mode). By default, all GPUs are used. 
69 | - You can refer to [options](./options/base_options.py) for more arguments. 70 | 71 | ## Citation 72 | If you find AdaDSR useful in your research, please consider citing: 73 | 74 | @inproceedings{AdaDSR, 75 | title={Deep Adaptive Inference Networks for Single Image Super-Resolution}, 76 | author={Liu, Ming and Zhang, Zhilu and Hou, Liya and Zuo, Wangmeng and Zhang, Lei}, 77 | booktitle={European Conference on Computer Vision Workshops}, 78 | year={2020} 79 | } 80 | 81 | ## Acknowledgement 82 | 83 | This repo is built upon the framework of [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix), and we borrow some code from [DPSR](https://github.com/cszn/DPSR), [mmdetection](https://github.com/open-mmlab/mmdetection), [EDSR](https://github.com/thstkdgus35/EDSR-PyTorch), [RCAN](https://github.com/yulunzhang/RCAN) and [SAN](https://github.com/daitao/SAN), thanks for their excellent work! 84 | -------------------------------------------------------------------------------- /calc_psnr_ssim.m: -------------------------------------------------------------------------------- 1 | function Evaluate_PSNR_SSIM() 2 | 3 | clear all; close all; clc 4 | 5 | %% set path 6 | ext = {'*.jpg', '*.png', '*.bmp'}; 7 | record_results_txt = ['result.txt']; 8 | results = fopen(fullfile(record_results_txt), 'wt'); 9 | 10 | num_imgs = length(dir('./tmp/HR'))-2; 11 | 12 | PSNR_all = zeros(1, num_imgs); 13 | SSIM_all = zeros(1, num_imgs); 14 | for idx_im = 1:num_imgs 15 | im_HR = imread(fullfile('./tmp/HR', [num2str(idx_im-1), '.png'])); 16 | im_SR = imread(fullfile('./tmp/SR', [num2str(idx_im-1), '.png'])); 17 | % change channel for evaluation 18 | if 3 == size(im_HR, 3) 19 | im_HR_YCbCr = single(rgb2ycbcr(im2double(im_HR))); 20 | im_HR_Y = im_HR_YCbCr(:,:,1); 21 | im_SR_YCbCr = single(rgb2ycbcr(im2double(im_SR))); 22 | im_SR_Y = im_SR_YCbCr(:,:,1); 23 | else 24 | im_HR_Y = single(im2double(im_HR)); 25 | im_SR_Y = single(im2double(im_SR)); 26 | end 27 | % calculate PSNR, SSIM 
28 | [PSNR_all(idx_im), SSIM_all(idx_im)] = ... 29 | Cal_Y_PSNRSSIM(im_HR_Y*255, im_SR_Y*255); 30 | end 31 | fprintf(results, '%f %f', mean(PSNR_all), mean(SSIM_all)); 32 | fclose(results); 33 | 34 | end 35 | 36 | function [psnr_cur, ssim_cur] = Cal_Y_PSNRSSIM(A,B) 37 | % RGB --> YCbCr 38 | if 3 == size(A, 3) 39 | A = rgb2ycbcr(A); 40 | A = A(:,:,1); 41 | end 42 | if 3 == size(B, 3) 43 | B = rgb2ycbcr(B); 44 | B = B(:,:,1); 45 | end 46 | % calculate PSNR 47 | A=double(A); % Ground-truth 48 | B=double(B); % 49 | 50 | e=A(:)-B(:); 51 | mse=mean(e.^2); 52 | psnr_cur=10*log10(255^2/mse); 53 | 54 | % calculate SSIM 55 | [ssim_cur, ~] = ssim_index(A, B); 56 | end 57 | 58 | 59 | function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L) 60 | 61 | %======================================================================== 62 | %SSIM Index, Version 1.0 63 | %Copyright(c) 2003 Zhou Wang 64 | %All Rights Reserved. 65 | % 66 | %The author is with Howard Hughes Medical Institute, and Laboratory 67 | %for Computational Vision at Center for Neural Science and Courant 68 | %Institute of Mathematical Sciences, New York University. 69 | % 70 | %---------------------------------------------------------------------- 71 | %Permission to use, copy, or modify this software and its documentation 72 | %for educational and research purposes only and without fee is hereby 73 | %granted, provided that this copyright notice and the original authors' 74 | %names appear on all copies and supporting documentation. This program 75 | %shall not be used, rewritten, or adapted as the basis of a commercial 76 | %software or hardware product without first obtaining permission of the 77 | %authors. The authors make no representations about the suitability of 78 | %this software for any purpose. It is provided "as is" without express 79 | %or implied warranty. 
80 | %---------------------------------------------------------------------- 81 | % 82 | %This is an implementation of the algorithm for calculating the 83 | %Structural SIMilarity (SSIM) index between two images. Please refer 84 | %to the following paper: 85 | % 86 | %Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image 87 | %quality assessment: From error measurement to structural similarity" 88 | %IEEE Transactios on Image Processing, vol. 13, no. 1, Jan. 2004. 89 | % 90 | %Kindly report any suggestions or corrections to zhouwang@ieee.org 91 | % 92 | %---------------------------------------------------------------------- 93 | % 94 | %Input : (1) img1: the first image being compared 95 | % (2) img2: the second image being compared 96 | % (3) K: constants in the SSIM index formula (see the above 97 | % reference). defualt value: K = [0.01 0.03] 98 | % (4) window: local window for statistics (see the above 99 | % reference). default widnow is Gaussian given by 100 | % window = fspecial('gaussian', 11, 1.5); 101 | % (5) L: dynamic range of the images. default: L = 255 102 | % 103 | %Output: (1) mssim: the mean SSIM index value between 2 images. 104 | % If one of the images being compared is regarded as 105 | % perfect quality, then mssim can be considered as the 106 | % quality measure of the other image. 107 | % If img1 = img2, then mssim = 1. 108 | % (2) ssim_map: the SSIM index map of the test image. The map 109 | % has a smaller size than the input images. The actual size: 110 | % size(img1) - size(window) + 1. 111 | % 112 | %Default Usage: 113 | % Given 2 test images img1 and img2, whose dynamic range is 0-255 114 | % 115 | % [mssim ssim_map] = ssim_index(img1, img2); 116 | % 117 | %Advanced Usage: 118 | % User defined parameters. 
For example 119 | % 120 | % K = [0.05 0.05]; 121 | % window = ones(8); 122 | % L = 100; 123 | % [mssim ssim_map] = ssim_index(img1, img2, K, window, L); 124 | % 125 | %See the results: 126 | % 127 | % mssim %Gives the mssim value 128 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map 129 | % 130 | %======================================================================== 131 | 132 | 133 | if (nargin < 2 || nargin > 5) 134 | ssim_index = -Inf; 135 | ssim_map = -Inf; 136 | return; 137 | end 138 | 139 | if (size(img1) ~= size(img2)) 140 | ssim_index = -Inf; 141 | ssim_map = -Inf; 142 | return; 143 | end 144 | 145 | [M N] = size(img1); 146 | 147 | if (nargin == 2) 148 | if ((M < 11) || (N < 11)) 149 | ssim_index = -Inf; 150 | ssim_map = -Inf; 151 | return 152 | end 153 | window = fspecial('gaussian', 11, 1.5); % 154 | K(1) = 0.01; % default settings 155 | K(2) = 0.03; % 156 | L = 255; % 157 | end 158 | 159 | if (nargin == 3) 160 | if ((M < 11) || (N < 11)) 161 | ssim_index = -Inf; 162 | ssim_map = -Inf; 163 | return 164 | end 165 | window = fspecial('gaussian', 11, 1.5); 166 | L = 255; 167 | if (length(K) == 2) 168 | if (K(1) < 0 || K(2) < 0) 169 | ssim_index = -Inf; 170 | ssim_map = -Inf; 171 | return; 172 | end 173 | else 174 | ssim_index = -Inf; 175 | ssim_map = -Inf; 176 | return; 177 | end 178 | end 179 | 180 | if (nargin == 4) 181 | [H W] = size(window); 182 | if ((H*W) < 4 || (H > M) || (W > N)) 183 | ssim_index = -Inf; 184 | ssim_map = -Inf; 185 | return 186 | end 187 | L = 255; 188 | if (length(K) == 2) 189 | if (K(1) < 0 || K(2) < 0) 190 | ssim_index = -Inf; 191 | ssim_map = -Inf; 192 | return; 193 | end 194 | else 195 | ssim_index = -Inf; 196 | ssim_map = -Inf; 197 | return; 198 | end 199 | end 200 | 201 | if (nargin == 5) 202 | [H W] = size(window); 203 | if ((H*W) < 4 || (H > M) || (W > N)) 204 | ssim_index = -Inf; 205 | ssim_map = -Inf; 206 | return 207 | end 208 | if (length(K) == 2) 209 | if (K(1) < 0 || K(2) < 0) 210 | ssim_index = -Inf; 211 | 
ssim_map = -Inf; 212 | return; 213 | end 214 | else 215 | ssim_index = -Inf; 216 | ssim_map = -Inf; 217 | return; 218 | end 219 | end 220 | 221 | C1 = (K(1)*L)^2; 222 | C2 = (K(2)*L)^2; 223 | window = window/sum(sum(window)); 224 | img1 = double(img1); 225 | img2 = double(img2); 226 | 227 | mu1 = filter2(window, img1, 'valid'); 228 | mu2 = filter2(window, img2, 'valid'); 229 | mu1_sq = mu1.*mu1; 230 | mu2_sq = mu2.*mu2; 231 | mu1_mu2 = mu1.*mu2; 232 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq; 233 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq; 234 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2; 235 | 236 | if (C1 > 0 & C2 > 0) 237 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./ ... 238 | ((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2)); 239 | else 240 | numerator1 = 2*mu1_mu2 + C1; 241 | numerator2 = 2*sigma12 + C2; 242 | denominator1 = mu1_sq + mu2_sq + C1; 243 | denominator2 = sigma1_sq + sigma2_sq + C2; 244 | ssim_map = ones(size(mu1)); 245 | index = (denominator1.*denominator2 > 0); 246 | ssim_map(index) = (numerator1(index).*numerator2(index))./ ... 247 | (denominator1(index).*denominator2(index)); 248 | index = (denominator1 ~= 0) & (denominator2 == 0); 249 | ssim_map(index) = numerator1(index)./denominator1(index); 250 | end 251 | 252 | mssim = mean2(ssim_map); 253 | 254 | end -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch.utils.data 3 | from data.base_dataset import BaseDataset 4 | 5 | def find_dataset_using_name(dataset_name, split='train'): 6 | benchmark_datasets = ('set5', 'set14', 'urban100', 'b100', 'manga109') 7 | if (dataset_name.lower() in benchmark_datasets) or \ 8 | (dataset_name.lower() == 'div2k' and split == 'test'): 9 | dataset_name = 'benchmark' 10 | dataset_filename = "data." 
+ dataset_name + "_dataset" 11 | datasetlib = importlib.import_module(dataset_filename) 12 | 13 | dataset = None 14 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 15 | for name, cls in datasetlib.__dict__.items(): 16 | if name.lower() == target_dataset_name.lower() \ 17 | and issubclass(cls, BaseDataset): 18 | dataset = cls 19 | 20 | if dataset is None: 21 | raise NotImplementedError("In %s.py, there should be a subclass of " 22 | "BaseDataset with class name that matches %s in " 23 | "lowercase." % (dataset_filename, target_dataset_name)) 24 | return dataset 25 | 26 | 27 | def create_dataset(dataset_name, split, opt): 28 | data_loader = CustomDatasetDataLoader(dataset_name, split, opt) 29 | dataset = data_loader.load_data() 30 | return dataset 31 | 32 | 33 | class CustomDatasetDataLoader(): 34 | def __init__(self, dataset_name, split, opt): 35 | self.opt = opt 36 | dataset_class = find_dataset_using_name(dataset_name, split) 37 | self.dataset = dataset_class(opt, split, dataset_name) 38 | self.imio = self.dataset.imio 39 | print("dataset [%s(%s)] created" % (dataset_name, split)) 40 | self.dataloader = torch.utils.data.DataLoader( 41 | self.dataset, 42 | batch_size=opt.batch_size if split=='train' else 1, 43 | shuffle=opt.shuffle and split=='train', 44 | num_workers=int(opt.num_dataloader), 45 | drop_last=opt.drop_last) 46 | 47 | def load_data(self): 48 | return self 49 | 50 | def __len__(self): 51 | """Return the number of data in the dataset""" 52 | return min(len(self.dataset), self.opt.max_dataset_size) 53 | 54 | def __iter__(self): 55 | """Return a batch of data""" 56 | for i, data in enumerate(self.dataloader): 57 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 58 | break 59 | yield data 60 | 61 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 
import torch.utils.data as data 4 | from abc import ABC, abstractmethod 5 | 6 | 7 | class BaseDataset(data.Dataset, ABC): 8 | def __init__(self, opt, split, dataset_name): 9 | self.opt = opt 10 | self.split = split 11 | self.root = opt.dataroot 12 | self.dataset_name = dataset_name.lower() 13 | 14 | @abstractmethod 15 | def __len__(self): 16 | return 0 17 | 18 | @abstractmethod 19 | def __getitem__(self, index): 20 | pass -------------------------------------------------------------------------------- /data/benchmark_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from os.path import join 5 | from .sr_dataset import SRDataset 6 | 7 | class BenchmarkDataset(SRDataset): 8 | 9 | name2dir = {'div2k': 'DIV2K_valid_HR', 'set5': 'Set5', 'set14': 'Set14', 10 | 'b100': 'B100', 'urban100': 'Urban100', 'manga109': 'Manga109'} 11 | 12 | def __init__(self, opt, split, dataset_name): 13 | super(BenchmarkDataset, self).__init__(opt, split, dataset_name) 14 | if self.root == '': 15 | rootlist = ['D:/Datasets/SR/SR', 16 | '/data/SR'] 17 | for root in rootlist: 18 | if os.path.isdir(root): 19 | self.root = root 20 | break 21 | self.hr_root = join(self.root, 'HR/%s/x%d' % \ 22 | (self.name2dir[self.dataset_name], self.scale)) 23 | self.lr_root = join(self.root, 'LR/LRBI/%s/x%d' % \ 24 | (self.name2dir[self.dataset_name], self.scale)) 25 | 26 | if split == 'test': 27 | self.HR_images, self.LR_images, self.names = self._scan() 28 | self._getitem = self._getitem_test 29 | self.num = self.len_data = len(self.names) 30 | else: 31 | raise ValueError 32 | self.load_data() 33 | 34 | def _scan(self): 35 | fnames = [] 36 | list_hr = [] 37 | list_lr = [] 38 | for filename in os.listdir(self.hr_root): 39 | if not self.imio.is_image(filename): continue 40 | list_hr.append(join(self.hr_root, filename)) 41 | *fname, _, ext = filename.split('_') # e.g., 0801_HR_x2.png 42 | fname = '_'.join(fname) 43 
| fnames.append(join(self.dataset_name, fname + '_SRBI_' + ext)) 44 | list_lr.append(join(self.lr_root, fname + '_LRBI_' + ext)) 45 | return list_hr, list_lr, fnames 46 | 47 | 48 | if __name__ == '__main__': 49 | pass 50 | -------------------------------------------------------------------------------- /data/div2k_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import os 4 | from os.path import join 5 | from .sr_dataset import SRDataset 6 | 7 | class DIV2KDataset(SRDataset): 8 | def __init__(self, opt, split='train', dataset_name='div2k'): 9 | super(DIV2KDataset, self).__init__(opt, split, dataset_name) 10 | if self.root == '': 11 | rootlist = ['D:/Datasets/SR/DIV2K', 12 | '/data/DIV2K'] 13 | for root in rootlist: 14 | if os.path.isdir(root): 15 | self.root = root 16 | break 17 | self.patch_size = opt.patch_size 18 | self.patch_size_lr = self.patch_size // self.scale 19 | assert self.patch_size % self.scale == 0 20 | self.hr_root = join(self.root, 'DIV2K_train_HR') 21 | self.lr_root = join(self.root, 'DIV2K_train_LR_bicubic/X%d'%self.scale) 22 | 23 | if split == 'train': 24 | self.start, self.num = 1, 800 25 | self._getitem = self._getitem_train 26 | self.len_data = self.num * (opt.test_every // 27 | (self.num // self.batch_size)) 28 | else: 29 | if split == 'val': 30 | self.start, self.num = 801, 5 31 | else: 32 | raise ValueError 33 | self._getitem = self._getitem_test 34 | self.len_data = self.num 35 | 36 | self.names = ['%04d'%i for i in range(self.start, self.start+self.num)] 37 | self.HR_images = [join(self.hr_root, '%s.png'%(n)) for n in self.names] 38 | self.LR_images = [join(self.lr_root, '%sx%d.png' % (n, self.scale)) \ 39 | for n in self.names] 40 | self.names = [join('DIV2K_%s'%split, i+'_SRBI_x%d.png'%self.scale) \ 41 | for i in self.names] 42 | 43 | self.load_data() 44 | 45 | def _crop(self, HR, LR): 46 | ih, iw = LR.shape[-2:] 47 | ix = random.randrange(0, iw - 
self.patch_size_lr + 1) 48 | iy = random.randrange(0, ih - self.patch_size_lr + 1) 49 | tx, ty = self.scale * ix, self.scale * iy 50 | return HR[..., ty:ty+self.patch_size, tx:tx+self.patch_size], \ 51 | LR[..., iy:iy+self.patch_size_lr, ix:ix+self.patch_size_lr] 52 | 53 | def _augment_func(self, img, hflip, vflip, rot90): 54 | if hflip: img = img[:, :, ::-1] 55 | if vflip: img = img[:, ::-1, :] 56 | if rot90: img = img.transpose(0, 2, 1) # CHW 57 | return np.ascontiguousarray(img) 58 | 59 | def _augment(self, *imgs): 60 | hflip = random.random() < 0.5 61 | vflip = random.random() < 0.5 62 | rot90 = random.random() < 0.5 63 | return (self._augment_func(img, hflip, vflip, rot90) for img in imgs) 64 | 65 | 66 | if __name__ == '__main__': 67 | pass -------------------------------------------------------------------------------- /data/imlib.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import os 5 | import cv2 6 | from PIL import Image 7 | from functools import partial 8 | 9 | class imlib(): 10 | """ 11 | Note that YCxCx in OpenCV and PIL are different. 12 | Therefore, be careful if a model is trained with OpenCV and tested with 13 | PIL in Y mode, and vise versa 14 | 15 | force_color = True: return a 3 channel YCxCx image 16 | For mode 'Y', if a gray image is given, repeat the channel for 3 times, 17 | and then converted to YCxCx mode. 18 | force_color = False: return a 3 channel YCxCx image or a 1 channel gray one. 19 | For mode 'Y', if a gray image is given, the gray image is directly used. 
20 | """ 21 | def __init__(self, mode='RGB', fmt='CHW', lib='cv2', force_color=True): 22 | assert mode.upper() in ('RGB', 'L', 'Y') 23 | self.mode = mode.upper() 24 | 25 | assert fmt.upper() in ('HWC', 'CHW', 'NHWC', 'NCHW') 26 | self.fmt = 'CHW' if fmt.upper() in ('CHW', 'NCHW') else 'HWC' 27 | 28 | assert lib.lower() in ('cv2', 'pillow') 29 | self.lib = lib.lower() 30 | 31 | self.force_color = force_color 32 | 33 | self.dtype = np.uint8 34 | 35 | self._imread = getattr(self, '_imread_%s_%s'%(self.lib, self.mode)) 36 | self._imwrite = getattr(self, '_imwrite_%s_%s'%(self.lib, self.mode)) 37 | self._trans_batch = getattr(self, '_trans_batch_%s_%s' 38 | % (self.mode, self.fmt)) 39 | self._trans_image = getattr(self, '_trans_image_%s_%s' 40 | % (self.mode, self.fmt)) 41 | self._trans_back = getattr(self, '_trans_back_%s_%s' 42 | % (self.mode, self.fmt)) 43 | 44 | def _imread_cv2_RGB(self, path): 45 | return cv2.imread(path, cv2.IMREAD_COLOR)[..., ::-1] 46 | def _imread_cv2_Y(self, path): 47 | if self.force_color: 48 | img = cv2.imread(path, cv2.IMREAD_COLOR) 49 | else: 50 | img = cv2.imread(path, cv2.IMREAD_ANYCOLOR) 51 | if len(img.shape) == 2: 52 | return np.expand_dims(img, 3) 53 | elif len(img.shape) == 3: 54 | return cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb) 55 | else: 56 | raise ValueError('The dimension should be either 2 or 3.') 57 | def _imread_cv2_L(self, path): 58 | return cv2.imread(path, cv2.IMREAD_GRAYSCALE) 59 | 60 | def _imread_pillow_RGB(self, path): 61 | img = Image.open(path) 62 | im = np.array(img.convert(self.mode)) 63 | img.close() 64 | return im 65 | _imread_pillow_L = _imread_pillow_RGB 66 | # WARNING: the RGB->YCbCr procedure of PIL may be different with OpenCV 67 | def _imread_pillow_Y(self, path): 68 | img = Image.open(path) 69 | if img.mode == 'RGB': 70 | im = np.array(img.convert('YCbCr')) 71 | elif img.mode == 'L': 72 | if self.force_color: 73 | im = np.array(img.convert('RGB').convert('YCbCr')) 74 | else: 75 | im = 
np.expand_dims(np.array(img), 3) 76 | else: 77 | img.close() 78 | raise NotImplementedError('Only support RGB and gray images now.') 79 | img.close() 80 | return im 81 | 82 | def _imwrite_cv2_RGB(self, image, path): 83 | cv2.imwrite(path, image[..., ::-1]) 84 | def _imwrite_cv2_Y(self, image, path): 85 | if image.shape[2] == 1: 86 | cv2.imwrite(path, image[..., 0]) 87 | elif image.shape[2] == 3: 88 | cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_YCrCb2BGR)) 89 | else: 90 | raise ValueError('There should be 1 or 3 channels.') 91 | def _imwrite_cv2_L(self, image, path): 92 | cv2.imwrite(path, image) 93 | 94 | def _imwrite_pillow_RGB(self, image, path): 95 | Image.fromarray(image).save(path) 96 | _imwrite_pillow_L = _imwrite_pillow_RGB 97 | def _imwrite_pillow_Y(self, image, path): 98 | if image.shape[2] == 1: 99 | self._imwrite_pillow_L(np.squeeze(image, 2), path) 100 | elif image.shape[2] == 3: 101 | Image.fromarray(image, mode='YCbCr').convert('RGB').save(path) 102 | else: 103 | raise ValueError('There should be 1 or 3 channels.') 104 | 105 | def _trans_batch_RGB_HWC(self, images): 106 | return np.ascontiguousarray(images) 107 | def _trans_batch_RGB_CHW(self, images): 108 | return np.ascontiguousarray(np.transpose(images, (0, 3, 1, 2))) 109 | _trans_batch_Y_HWC = _trans_batch_RGB_HWC 110 | _trans_batch_Y_CHW = _trans_batch_RGB_CHW 111 | def _trans_batch_L_HWC(self, images): 112 | return np.ascontiguousarray(np.expand_dims(images, 3)) 113 | def _trans_batch_L_CHW(slef, images): 114 | return np.ascontiguousarray(np.expand_dims(images, 1)) 115 | 116 | def _trans_image_RGB_HWC(self, image): 117 | return np.ascontiguousarray(image) 118 | def _trans_image_RGB_CHW(self, image): 119 | return np.ascontiguousarray(np.transpose(image, (2, 0, 1))) 120 | _trans_image_Y_HWC = _trans_image_RGB_HWC 121 | _trans_image_Y_CHW = _trans_image_RGB_CHW 122 | def _trans_image_L_HWC(self, image): 123 | return np.ascontiguousarray(np.expand_dims(image, 2)) 124 | def 
_trans_image_L_CHW(self, image): 125 | return np.ascontiguousarray(np.expand_dims(image, 0)) 126 | 127 | def _trans_back_RGB_HWC(self, image): 128 | return image 129 | def _trans_back_RGB_CHW(self, image): 130 | return np.transpose(image, (1, 2, 0)) 131 | _trans_back_Y_HWC = _trans_back_RGB_HWC 132 | _trans_back_Y_CHW = _trans_back_RGB_CHW 133 | def _trans_back_L_HWC(self, image): 134 | return np.squeeze(image, 2) 135 | def _trans_back_L_CHW(self, image): 136 | return np.squeeze(image, 0) 137 | 138 | img_ext = ('png', 'PNG', 'jpg', 'JPG', 'bmp', 'BMP', 'jpeg', 'JPEG') 139 | 140 | def is_image(self, fname): 141 | return any(fname.endswith(i) for i in self.img_ext) 142 | 143 | def read(self, paths): 144 | if isinstance(paths, (list, tuple)): 145 | images = [self._imread(path) for path in paths] 146 | return self._trans_batch(np.array(images)) 147 | return self._trans_image(self._imread(paths)) 148 | 149 | def back(self, image): 150 | return self._trans_back(image) 151 | 152 | def write(self, image, path): 153 | os.makedirs(os.path.dirname(path), exist_ok=True) 154 | self._imwrite(self.back(image), path) 155 | 156 | if __name__ == '__main__': 157 | import matplotlib.pyplot as plt 158 | im_rgb_chw_cv2 = imlib('rgb', fmt='chw', lib='cv2') 159 | im_rgb_hwc_cv2 = imlib('rgb', fmt='hwc', lib='cv2') 160 | im_rgb_chw_pil = imlib('rgb', fmt='chw', lib='pillow') 161 | im_rgb_hwc_pil = imlib('rgb', fmt='hwc', lib='pillow') 162 | im_y_chw_cv2 = imlib('y', fmt='chw', lib='cv2') 163 | im_y_hwc_cv2 = imlib('y', fmt='hwc', lib='cv2') 164 | im_y_chw_pil = imlib('y', fmt='chw', lib='pillow') 165 | im_y_hwc_pil = imlib('y', fmt='hwc', lib='pillow') 166 | im_l_chw_cv2 = imlib('l', fmt='chw', lib='cv2') 167 | im_l_hwc_cv2 = imlib('l', fmt='hwc', lib='cv2') 168 | im_l_chw_pil = imlib('l', fmt='chw', lib='pillow') 169 | im_l_hwc_pil = imlib('l', fmt='hwc', lib='pillow') 170 | path = 'D:/Datasets/test/000001.jpg' 171 | 172 | img_rgb_chw_cv2 = im_rgb_chw_cv2.read(path) 173 | 
print(img_rgb_chw_cv2.shape) 174 | plt.imshow(im_rgb_chw_cv2.back(img_rgb_chw_cv2)) 175 | plt.show() 176 | im_rgb_chw_cv2.write(img_rgb_chw_cv2, 177 | (path.replace('000001.jpg', 'img_rgb_chw_cv2.jpg'))) 178 | img_rgb_hwc_cv2 = im_rgb_hwc_cv2.read(path) 179 | print(img_rgb_hwc_cv2.shape) 180 | plt.imshow(im_rgb_hwc_cv2.back(img_rgb_hwc_cv2)) 181 | plt.show() 182 | im_rgb_hwc_cv2.write(img_rgb_hwc_cv2, 183 | (path.replace('000001.jpg', 'img_rgb_hwc_cv2.jpg'))) 184 | img_rgb_chw_pil = im_rgb_chw_pil.read(path) 185 | print(img_rgb_chw_pil.shape) 186 | plt.imshow(im_rgb_chw_pil.back(img_rgb_chw_pil)) 187 | plt.show() 188 | im_rgb_chw_pil.write(img_rgb_chw_pil, 189 | (path.replace('000001.jpg', 'img_rgb_chw_pil.jpg'))) 190 | img_rgb_hwc_pil = im_rgb_hwc_pil.read(path) 191 | print(img_rgb_hwc_pil.shape) 192 | plt.imshow(im_rgb_hwc_pil.back(img_rgb_hwc_pil)) 193 | plt.show() 194 | im_rgb_hwc_pil.write(img_rgb_hwc_pil, 195 | (path.replace('000001.jpg', 'img_rgb_hwc_pil.jpg'))) 196 | 197 | 198 | img_y_chw_cv2 = im_y_chw_cv2.read(path) 199 | print(img_y_chw_cv2.shape) 200 | plt.imshow(np.squeeze(im_y_chw_cv2.back(img_y_chw_cv2))) 201 | plt.show() 202 | im_y_chw_cv2.write(img_y_chw_cv2, 203 | (path.replace('000001.jpg', 'img_y_chw_cv2.jpg'))) 204 | img_y_hwc_cv2 = im_y_hwc_cv2.read(path) 205 | print(img_y_hwc_cv2.shape) 206 | plt.imshow(np.squeeze(im_y_hwc_cv2.back(img_y_hwc_cv2))) 207 | plt.show() 208 | im_y_hwc_cv2.write(img_y_hwc_cv2, 209 | (path.replace('000001.jpg', 'img_y_hwc_cv2.jpg'))) 210 | img_y_chw_pil = im_y_chw_pil.read(path) 211 | print(img_y_chw_pil.shape) 212 | plt.imshow(np.squeeze(im_y_chw_pil.back(img_y_chw_pil))) 213 | plt.show() 214 | im_y_chw_pil.write(img_y_chw_pil, 215 | (path.replace('000001.jpg', 'img_y_chw_pil.jpg'))) 216 | img_y_hwc_pil = im_y_hwc_pil.read(path) 217 | print(img_y_hwc_pil.shape) 218 | plt.imshow(np.squeeze(im_y_hwc_pil.back(img_y_hwc_pil))) 219 | plt.show() 220 | im_y_hwc_pil.write(img_y_hwc_pil, 221 | (path.replace('000001.jpg', 
'img_y_hwc_pil.jpg'))) 222 | 223 | 224 | img_l_chw_cv2 = im_l_chw_cv2.read(path) 225 | print(img_l_chw_cv2.shape) 226 | plt.imshow(im_l_chw_cv2.back(img_l_chw_cv2)) 227 | plt.show() 228 | im_l_chw_cv2.write(img_l_chw_cv2, 229 | (path.replace('000001.jpg', 'img_l_chw_cv2.jpg'))) 230 | img_l_hwc_cv2 = im_l_hwc_cv2.read(path) 231 | print(img_l_hwc_cv2.shape) 232 | plt.imshow(im_l_hwc_cv2.back(img_l_hwc_cv2)) 233 | plt.show() 234 | im_l_hwc_cv2.write(img_l_hwc_cv2, 235 | (path.replace('000001.jpg', 'img_l_hwc_cv2.jpg'))) 236 | img_l_chw_pil = im_l_chw_pil.read(path) 237 | print(img_l_chw_pil.shape) 238 | plt.imshow(im_l_chw_pil.back(img_l_chw_pil)) 239 | plt.show() 240 | im_l_chw_pil.write(img_l_chw_pil, 241 | (path.replace('000001.jpg', 'img_l_chw_pil.jpg'))) 242 | img_l_hwc_pil = im_l_hwc_pil.read(path) 243 | print(img_l_hwc_pil.shape) 244 | plt.imshow(im_l_hwc_pil.back(img_l_hwc_pil)) 245 | plt.show() 246 | im_l_hwc_pil.write(img_l_hwc_pil, 247 | (path.replace('000001.jpg', 'img_l_hwc_pil.jpg'))) 248 | -------------------------------------------------------------------------------- /data/sr_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | import numpy as np 5 | from .imlib import imlib 6 | from os.path import join 7 | from data.base_dataset import BaseDataset 8 | 9 | class SRDataset(BaseDataset): 10 | def __init__(self, opt, split, dataset_name): 11 | super(SRDataset, self).__init__(opt, split, dataset_name) 12 | self.mode = opt.mode # RGB, Y or L 13 | self.imio = imlib(self.mode, lib=opt.imlib) 14 | self.scale = opt.scale 15 | self.preload = opt.preload 16 | self.batch_size = opt.batch_size 17 | self.lr_mode = opt.lr_mode 18 | if self.lr_mode == 'lr': # 'lr': LR image is only cast to float32; any other mode bicubic-upscales it via lr_process_sr 19 | self.lr_process = lambda lr_img: lr_img.astype(np.float32) 20 | else: 21 | self.lr_process = self.lr_process_sr 22 | 23 | self.getimage = self.getimage_read # per-item disk reads; load_data() switches to getimage_preload when opt.preload is set 24 | self.multi_imreader = opt.multi_imreader 25 | 26 | def
load_data(self): # with opt.preload, replace the HR/LR path lists by the decoded images 27 | if self.preload: 28 | if self.multi_imreader: 29 | read_images(self) 30 | else: 31 | self.HR_images = [self.imio.read(p) for p in self.HR_images] 32 | self.LR_images = [self.imio.read(p) for p in self.LR_images] 33 | self.getimage = self.getimage_preload 34 | 35 | def lr_process_sr(self, lr_img): # bicubic-upscale a channel-first LR image by self.scale, returning float32 36 | if lr_img.shape[0] == 1: 37 | return np.expand_dims(cv2.resize(lr_img[0].astype(np.float32), 38 | dsize=(0, 0), fx=self.scale, fy=self.scale, 39 | interpolation=cv2.INTER_CUBIC), 0) 40 | return cv2.resize(lr_img.transpose(1, 2, 0).astype(np.float32), 41 | dsize=(0, 0), fx=self.scale, fy=self.scale, 42 | interpolation=cv2.INTER_CUBIC).transpose(2, 0, 1) 43 | 44 | def getimage_preload(self, index): 45 | return self.HR_images[index], self.LR_images[index], self.names[index] 46 | 47 | def getimage_read(self, index): 48 | return self.imio.read(self.HR_images[index]), \ 49 | self.imio.read(self.LR_images[index]), self.names[index] 50 | 51 | 52 | def _getitem_train(self, index): # index wraps modulo self.num, then random crop + augmentation 53 | index = index % self.num 54 | hr_img, lr_img, f_name = self.getimage(index) 55 | hr_img, lr_img = self._crop(hr_img, lr_img) 56 | hr_img, lr_img = self._augment(hr_img, lr_img) 57 | return {'hr': hr_img.astype(np.float32), 58 | 'lr': self.lr_process(lr_img), 59 | 'fname': f_name} 60 | 61 | def _getitem_test(self, index): 62 | hr_img, lr_img, f_name = self.getimage(index) 63 | return {'hr': hr_img.astype(np.float32), 64 | 'lr': self.lr_process(lr_img), 65 | 'fname': f_name} 66 | 67 | 68 | def __getitem__(self, index): 69 | return self._getitem(index) 70 | 71 | def __len__(self): 72 | return self.len_data 73 | 74 | 75 | def iter_obj(num, objs): 76 | for i in range(num): 77 | yield (i, objs) 78 | 79 | def imreader(arg): # worker: decode the i-th HR/LR entry of obj in place; arg is (i, obj) 80 | i, obj = arg 81 | obj.HR_images[i] = obj.imio.read(obj.HR_images[i]) 82 | obj.LR_images[i] = obj.imio.read(obj.LR_images[i]) 83 | # for _ in range(3): 84 | # try: 85 | # obj.HR_images[i] = obj.imio.read(obj.HR_images[i]) 86 | # obj.LR_images[i] = 
obj.imio.read(obj.LR_images[i]) 87 | # failed = False 88 | # break 89 | # except: 90 | # failed = True 91 | # if failed: print('%s fails!' % obj.names[i]) 92 | 93 | def read_images(obj): # decode all obj.num HR/LR images concurrently with a thread pool 94 | # `from multiprocessing import Pool` could be used instead, but it is less efficient, and 95 | # NOTE: `multiprocessing.Pool` will duplicate given object for each process. 96 | from multiprocessing.dummy import Pool 97 | from tqdm import tqdm 98 | print('Starting to load images via multiple imreaders') 99 | pool = Pool() # use all threads by default 100 | for _ in tqdm(pool.imap(imreader, iter_obj(obj.num, obj)), total=obj.num): 101 | pass 102 | pool.close() 103 | pool.join() 104 | 105 | if __name__ == '__main__': 106 | pass 107 | -------------------------------------------------------------------------------- /figs/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaDSR/8997b5f978cb7fae2e61b6753a950dcae7ead470/figs/architecture.png -------------------------------------------------------------------------------- /figs/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaDSR/8997b5f978cb7fae2e61b6753a950dcae7ead470/figs/results.png -------------------------------------------------------------------------------- /figs/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaDSR/8997b5f978cb7fae2e61b6753a950dcae7ead470/figs/visualization.png -------------------------------------------------------------------------------- /flops.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: csmliu 4 | @e-mail: csmliu@outlook.com 5 | """ 6 | import numpy as np 7 | 8 | def Conv(in_shape, inc, outc, ks, stride=1, padding=None, 9 | groups=1, bias=True, mask=None): # op count of one conv layer; stride and padding do not enter the count 10 |
if padding is None: 11 | padding = ks//2 12 | if groups != 1: 13 | assert inc % groups == 0 and outc % groups == 0 14 | inc = inc // groups 15 | outc = outc // groups 16 | 17 | _per_pos = ks * ks * inc * outc * groups 18 | if mask is not None: 19 | assert all(in_shape == mask.shape) 20 | n_pos = (mask > 0).sum() 21 | else: 22 | n_pos = np.array(in_shape).prod() 23 | _sum = _per_pos * n_pos 24 | if bias: 25 | return _sum + n_pos * outc 26 | return _sum 27 | 28 | def BN(in_shape, inc): 29 | return np.array(in_shape).prod() * inc * 2 # affine 30 | 31 | def ReLU(in_shape, inc): 32 | return np.array(in_shape).prod() * inc 33 | 34 | def pixsf(in_shape, inc, scale): 35 | _sum_conv = Conv(in_shape, inc, inc*scale**2, 3) 36 | return np.array(in_shape).prod() * inc + _sum_conv 37 | 38 | def pool(in_shape, inc): 39 | return np.array(in_shape).prod() * inc 40 | 41 | def linear(inc, outc, bias=True): 42 | _sum = inc * outc 43 | if bias: 44 | return _sum + outc 45 | return _sum 46 | 47 | def upsample(in_shape, inc, scale=2): 48 | return (np.array(in_shape)*scale).prod() * inc 49 | 50 | def ResBlock(in_shape, inc, mode='CRC', mask=None): 51 | _sum = 0 52 | for m in mode: 53 | if m == 'C': 54 | _sum += Conv(in_shape, inc, inc, ks=3, mask=mask) 55 | elif m in 'RPL': 56 | _sum += ReLU(in_shape, inc) 57 | elif m == 'B': 58 | _sum += BN(in_shape, inc) 59 | else: 60 | print('mode %s is not supported in ResBlock.'%m) 61 | return _sum + np.array(in_shape).prod() * inc 62 | 63 | def CA(in_shape, inc): 64 | _sum = np.array(in_shape).prod() * inc # AvgPool 65 | _sum += linear(inc, inc//16) # 1st conv 66 | _sum += inc // 16 # ReLU 67 | _sum += linear(inc//16, inc) # 2nd conv 68 | _sum += inc // 16 # Sigmoid 69 | _sum += np.array(in_shape).prod() * inc 70 | return _sum 71 | 72 | def clip(x, layer): 73 | return np.clip(x, layer, layer+1) - layer 74 | 75 | class FLOPs(): 76 | @staticmethod 77 | def EDSR(in_shape, scale, mask=None, nb=32): 78 | _sum = 0 79 | _sum += Conv(in_shape, 3, 256, 3) 80 
| if mask is None: 81 | _sum += ResBlock(in_shape, 256) * nb 82 | else: 83 | for i in range(nb): 84 | _sum += ResBlock(in_shape, 256, mask=clip(mask, i)) 85 | _sum += Conv(in_shape, 256, 256, 3) + in_shape.prod() * 256 86 | if scale == 3: 87 | _sum += pixsf(in_shape, 256, 3) 88 | in_shape *= 3 89 | else: 90 | assert scale in (2, 4) 91 | for i in range(1, scale, 2): 92 | _sum += pixsf(in_shape, 256, 2) 93 | in_shape *= 2 94 | _sum += Conv(in_shape, 256, 3, 3) 95 | return _sum 96 | 97 | @staticmethod 98 | def AdaEDSR(in_shape, scale, mask=None): 99 | return Conv(in_shape, 256, 128, 3) + \ 100 | Conv(in_shape, 128, 128, 3) * 3 + ReLU(in_shape, 128) * 4 + \ 101 | Conv(in_shape, 128, 1, 3) 102 | 103 | @staticmethod 104 | def AdaEDSR_fixd(in_shape, scale, mask=None): 105 | return Conv(in_shape, 256, 128, 3) + \ 106 | Conv(in_shape, 128, 128, 3) * 3 + ReLU(in_shape, 128) * 4 + \ 107 | Conv(in_shape, 128, 1, 3) 108 | 109 | @staticmethod 110 | def RCAN(in_shape, scale, mask=None): 111 | _sum = 0 112 | _sum += Conv(in_shape, 3, 64, 3) 113 | if mask is None: 114 | _sum += (ResBlock(in_shape, 64) + CA(in_shape, 64)) * 10 * 20 115 | _sum += (Conv(in_shape, 64, 64, 3) + in_shape.prod() * 64) * 11 116 | else: 117 | for i in range(mask.shape[0]): 118 | for j in range(20): 119 | _sum += ResBlock(in_shape, 64, mask=clip(mask[i], j)) 120 | _sum += CA(in_shape, 64) * 10 * 20 121 | _sum += (Conv(in_shape, 64, 64, 3) + in_shape.prod() * 64) * 11 122 | if scale == 3: 123 | _sum += pixsf(in_shape, 256, 3) 124 | in_shape *= 3 125 | else: 126 | assert scale in (2, 4) 127 | for i in range(1, scale, 2): 128 | _sum += pixsf(in_shape, 256, 2) 129 | in_shape *= 2 130 | _sum += Conv(in_shape, 64, 3, 3) 131 | return _sum 132 | 133 | @staticmethod 134 | def AdaRCAN(in_shape, scale, mask=None): 135 | return Conv(in_shape, 64, 64, 3) * 4 + ReLU(in_shape, 128) * 4 + \ 136 | Conv(in_shape, 64, 10, 3) 137 | 138 | 139 | @staticmethod 140 | def SRCNN(in_shape, scale, mask=None): 141 | _sum = 0 142 | _sum 
+= Conv(in_shape, 1, 64, 9) + ReLU(in_shape, 64) 143 | _sum += Conv(in_shape, 64, 32, 5) + ReLU(in_shape, 32) 144 | _sum += Conv(in_shape, 32, 1, 5) 145 | return _sum 146 | 147 | @staticmethod 148 | def VDSR(in_shape, scale, mask=None): 149 | _sum = 0 150 | _sum += Conv(in_shape, 1, 64, 3) + ReLU(in_shape, 64) 151 | # NOTE that ReLU is omitted due to that there is no residual 152 | _sum += ResBlock(in_shape, 64, mode='C') * 18 153 | _sum += Conv(in_shape, 64, 1, 3) 154 | _sum += in_shape.prod() 155 | return _sum 156 | 157 | @staticmethod 158 | def RDN(in_shape, scale, mask=None): 159 | def RDB_Conv(in_shape, inc): 160 | _sum = 0 161 | _sum += Conv(in_shape, inc, 64, 3) + in_shape.prod() * 64 162 | _sum += in_shape.prod() * (inc+64) 163 | return _sum 164 | 165 | def RDB(in_shape): 166 | _sum = 0 167 | for i in range(8): 168 | _sum += RDB_Conv(in_shape, i*64+64) 169 | _sum += Conv(in_shape, 64*9, 64, 1) 170 | _sum += in_shape.prod() * 64 171 | return _sum 172 | 173 | _sum = 0 174 | _sum += Conv(in_shape, 3, 64, 3) 175 | _sum += Conv(in_shape, 64, 64, 3) 176 | _sum += RDB(in_shape) * 16 177 | _sum += Conv(in_shape, 16*64, 64, 1) + Conv(in_shape, 64, 64, 3) 178 | _sum += in_shape.prod() * 64 179 | if scale == 3: 180 | _sum += pixsf(in_shape, 256, 3) 181 | in_shape *= 3 182 | else: 183 | assert scale in (2, 4) 184 | for i in range(1, scale, 2): 185 | _sum += pixsf(in_shape, 256, 2) 186 | in_shape *= 2 187 | _sum += Conv(in_shape, 64, 3, 3) 188 | return _sum 189 | 190 | @staticmethod 191 | def SAN(in_shape, scale, mask=None): 192 | def SOCA(in_shape): 193 | def Covpool(in_shape): 194 | _sum = 0 195 | size = in_shape.prod() 196 | area = size ** 2 197 | # can be optimized to area + size 198 | _sum += area * 3 199 | _sum += size * size * size * 2 200 | return _sum 201 | def Sqrtm(in_shape, iterN=5): 202 | _sum = 0 203 | ch = 64 204 | _sum += ch*ch 205 | _sum += ch*ch*3 206 | _sum += ch*ch*3 207 | _sum += (iterN-2)*(ch*ch*5) 208 | _sum += (ch*ch*5) 209 | _sum += ch*ch 210 | 
return _sum 211 | _sum = 0 212 | in_shape = np.min([in_shape, np.array([1000, 1000])], axis=0) 213 | _sum += Covpool(in_shape) 214 | in_shape = np.array([64, 64]) 215 | _sum += Sqrtm(in_shape) 216 | _sum += in_shape.prod() 217 | in_shape = np.array([1, 1]) 218 | _sum += Conv(in_shape, 64, 64//16, 1)*2 + ReLU(in_shape, 64//16+64) 219 | return _sum 220 | 221 | def LSRAG(in_shape): 222 | def RB(in_shape): 223 | _sum = 0 224 | _sum += Conv(in_shape, 64, 64, 3) * 2 225 | _sum += ReLU(in_shape, 64) 226 | return _sum + in_shape.prod() * 64 227 | _sum = 0 228 | _sum += RB(in_shape) * 10 229 | _sum += SOCA(in_shape) 230 | _sum += Conv(in_shape, 64, 64, 3) 231 | return _sum + in_shape.prod() 232 | 233 | def Nonlocal(in_shape): 234 | def NB(in_shape): 235 | _sum = 0 236 | _sum += Conv(in_shape, 64, 32, 1) * 3 237 | _sum += in_shape.prod()**2 * 32 * 2 238 | _sum += ReLU(in_shape, 32) 239 | _sum += Conv(in_shape, 32, 64, 1) 240 | return _sum 241 | _sum = 0 242 | in_shape //= 2 243 | _sum += NB(in_shape) * 4 244 | return _sum 245 | _sum = 0 246 | _sum += Conv(in_shape, 3, 64, 3) 247 | _sum += Nonlocal(in_shape) * 2 248 | _sum += (LSRAG(in_shape) + in_shape.prod()*64) * 20 249 | if scale == 3: 250 | _sum += pixsf(in_shape, 256, 3) 251 | in_shape *= 3 252 | else: 253 | assert scale in (2, 4) 254 | for i in range(1, scale, 2): 255 | _sum += pixsf(in_shape, 256, 2) 256 | in_shape *= 2 257 | # Nonlocal has been calculated before 258 | _sum += in_shape.prod() * 64 259 | _sum += Conv(in_shape, 64, 3, 3) 260 | return _sum 261 | 262 | 263 | def find(name): 264 | for func in FLOPs.__dict__.keys(): 265 | if func.lower() == name.lower(): 266 | return func 267 | raise ValueError('No function named %s is found'%name) 268 | 269 | # def cvt(num): 270 | # units = ['', 'K', 'M', 'G', 'T', 'P', 'Z'] 271 | # cur = 0 272 | # while num > 1024: 273 | # cur += 1 274 | # num /= 1024 275 | # return '%.3f %s FLOPs' % (num, units[cur]) 276 | 277 | def cvt(num, binary=True): 278 | step = 1024 if binary else 
1000 279 | return '%.2f GFLOPs' %(num / step**3) 280 | 281 | def chop(input_shape, shave=10, min_size=160000): 282 | h, w = input_shape 283 | h_half, w_half = h//2, w//2 284 | h_size, w_size = h_half+shave, w_half+shave 285 | if h_size * w_size < min_size: 286 | return np.array([np.array([h_size, w_size])]*4) 287 | else: 288 | ret = np.array([chop(np.array([h_size, w_size]))]*4) 289 | return ret 290 | 291 | def chop_pred(pred, shave=10, min_size=160000): 292 | if pred is None: return None 293 | h, w = pred.shape 294 | h_half, w_half = h//2, w//2 295 | h_size, w_size = h_half+shave, w_half+shave 296 | if h_size * w_size < min_size: 297 | return np.array([ 298 | pred[0:h_size, 0:w_size], 299 | pred[0:h_size, (w-w_size):w], 300 | pred[(h-h_size):h, 0:w_size], 301 | pred[(h-h_size):h, (w-w_size):w] 302 | ]) 303 | else: 304 | return np.array([ 305 | chop_pred(pred[0:h_size, 0:w_size]), 306 | chop_pred(pred[0:h_size, (w-w_size):w]), 307 | chop_pred(pred[(h-h_size):h, 0:w_size]), 308 | chop_pred(pred[(h-h_size):h, (w-w_size):w]) 309 | ]) 310 | 311 | 312 | methods = { 313 | 'hr': ['srcnn', 'vdsr'], 314 | 'lr': ['edsr', 'adaedsr', 'adaedsr_fixd', 'rdn', 'rcan', 'san', 'adarcan'], 315 | } 316 | methods = {i:j for j in methods.keys() for i in methods[j]} -------------------------------------------------------------------------------- /masked_conv2d/masked_conv2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .masked_conv import MaskedConv2d, masked_conv2d 2 | 3 | __all__ = ['masked_conv2d', 'MaskedConv2d'] 4 | -------------------------------------------------------------------------------- /masked_conv2d/masked_conv2d/masked_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.nn.modules.utils import _pair 8 | 
9 | from . import masked_conv2d_cuda 10 | 11 | 12 | class MaskedConv2dFunction(Function): 13 | 14 | @staticmethod 15 | def forward(ctx, features, mask, weight, bias, padding=0, stride=1): 16 | assert mask.dim() == 3 and mask.size(0) == 1 17 | assert features.dim() == 4 and features.size(0) == 1 18 | assert features.size()[2:] == mask.size()[1:] 19 | pad_h, pad_w = _pair(padding) 20 | stride_h, stride_w = _pair(stride) 21 | if stride_h != 1 or stride_w != 1: 22 | raise ValueError( 23 | 'Stride could not only be 1 in masked_conv2d currently.') 24 | if not features.is_cuda: 25 | raise NotImplementedError 26 | 27 | out_channel, in_channel, kernel_h, kernel_w = weight.size() 28 | 29 | batch_size = features.size(0) 30 | out_h = int( 31 | math.floor((features.size(2) + 2 * pad_h - 32 | (kernel_h - 1) - 1) / stride_h + 1)) 33 | out_w = int( 34 | math.floor((features.size(3) + 2 * pad_w - 35 | (kernel_h - 1) - 1) / stride_w + 1)) 36 | mask_inds = torch.nonzero(mask[0] > 0) 37 | output = features.new_zeros(batch_size, out_channel, out_h, out_w) 38 | if mask_inds.numel() > 0: 39 | mask_h_idx = mask_inds[:, 0].contiguous() 40 | mask_w_idx = mask_inds[:, 1].contiguous() 41 | data_col = features.new_zeros(in_channel * kernel_h * kernel_w, 42 | mask_inds.size(0)) 43 | masked_conv2d_cuda.masked_im2col_forward(features, mask_h_idx, 44 | mask_w_idx, kernel_h, 45 | kernel_w, pad_h, pad_w, 46 | data_col) 47 | if bias is None: 48 | masked_output = torch.mm(weight.view(out_channel, -1), data_col) 49 | else: 50 | masked_output = torch.addmm(1, bias[:, None], 1, 51 | weight.view(out_channel, -1), 52 | data_col) 53 | masked_conv2d_cuda.masked_col2im_forward(masked_output, mask_h_idx, 54 | mask_w_idx, out_h, out_w, 55 | out_channel, output) 56 | return output 57 | 58 | @staticmethod 59 | @once_differentiable 60 | def backward(ctx, grad_output): 61 | return (None, ) * 5 62 | 63 | 64 | masked_conv2d = MaskedConv2dFunction.apply 65 | 66 | 67 | class MaskedConv2d(nn.Conv2d): 68 | """A 
MaskedConv2d which inherits the official Conv2d. 69 | 70 | The masked forward doesn't implement the backward function and only 71 | supports the stride parameter to be 1 currently. 72 | """ 73 | 74 | def __init__(self, 75 | in_channels, 76 | out_channels, 77 | kernel_size, 78 | stride=1, 79 | padding=0, 80 | dilation=1, 81 | groups=1, 82 | bias=True): 83 | super(MaskedConv2d, 84 | self).__init__(in_channels, out_channels, kernel_size, stride, 85 | padding, dilation, groups, bias) 86 | 87 | def forward(self, input, mask=None): 88 | # if mask is None: # fallback to the normal Conv2d 89 | # return super(MaskedConv2d, self).forward(input) 90 | # else: 91 | # 92 | if mask is None: 93 | mask = torch.ones((1, *input.shape[-2:]), device=input.device) 94 | return masked_conv2d(input, mask, self.weight, self.bias, self.padding) 95 | -------------------------------------------------------------------------------- /masked_conv2d/masked_conv2d/src/masked_conv2d_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | int MaskedIm2colForwardLaucher(const at::Tensor im, const int height, 7 | const int width, const int channels, 8 | const int kernel_h, const int kernel_w, 9 | const int pad_h, const int pad_w, 10 | const at::Tensor mask_h_idx, 11 | const at::Tensor mask_w_idx, const int mask_cnt, 12 | at::Tensor col); 13 | 14 | int MaskedCol2imForwardLaucher(const at::Tensor col, const int height, 15 | const int width, const int channels, 16 | const at::Tensor mask_h_idx, 17 | const at::Tensor mask_w_idx, const int mask_cnt, 18 | at::Tensor im); 19 | 20 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") 21 | #define CHECK_CONTIGUOUS(x) \ 22 | TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") 23 | #define CHECK_INPUT(x) \ 24 | CHECK_CUDA(x); \ 25 | CHECK_CONTIGUOUS(x) 26 | 27 | int masked_im2col_forward_cuda(const at::Tensor im, const at::Tensor 
mask_h_idx, 28 | const at::Tensor mask_w_idx, const int kernel_h, 29 | const int kernel_w, const int pad_h, 30 | const int pad_w, at::Tensor col) { 31 | CHECK_INPUT(im); 32 | CHECK_INPUT(mask_h_idx); 33 | CHECK_INPUT(mask_w_idx); 34 | CHECK_INPUT(col); 35 | // im: (n, ic, h, w), kernel size (kh, kw) 36 | // kernel: (oc, ic * kh * kw), col: (ic * kh * kw, mask_cnt) 37 | 38 | int channels = im.size(1); 39 | int height = im.size(2); 40 | int width = im.size(3); 41 | int mask_cnt = mask_h_idx.size(0); 42 | 43 | MaskedIm2colForwardLaucher(im, height, width, channels, kernel_h, kernel_w, 44 | pad_h, pad_w, mask_h_idx, mask_w_idx, mask_cnt, 45 | col); 46 | 47 | return 1; 48 | } 49 | 50 | int masked_col2im_forward_cuda(const at::Tensor col, 51 | const at::Tensor mask_h_idx, 52 | const at::Tensor mask_w_idx, int height, 53 | int width, int channels, at::Tensor im) { 54 | CHECK_INPUT(col); 55 | CHECK_INPUT(mask_h_idx); 56 | CHECK_INPUT(mask_w_idx); 57 | CHECK_INPUT(im); 58 | // col: (oc, mask_cnt) -> im: (n, oc, h, w), with channels == oc 59 | // column m of col is written at (mask_h_idx[m], mask_w_idx[m]); other positions of im are left untouched 60 | 61 | int mask_cnt = mask_h_idx.size(0); 62 | 63 | MaskedCol2imForwardLaucher(col, height, width, channels, mask_h_idx, 64 | mask_w_idx, mask_cnt, im); 65 | 66 | return 1; 67 | } 68 | 69 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 70 | m.def("masked_im2col_forward", &masked_im2col_forward_cuda, 71 | "masked_im2col forward (CUDA)"); 72 | m.def("masked_col2im_forward", &masked_col2im_forward_cuda, 73 | "masked_col2im forward (CUDA)"); 74 | } -------------------------------------------------------------------------------- /masked_conv2d/masked_conv2d/src/masked_conv2d_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 5 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 6 | i += blockDim.x * gridDim.x) 7 | 8 | #define THREADS_PER_BLOCK 1024 9 | 10 | inline int GET_BLOCKS(const
int N) { 11 | int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; 12 | int max_block_num = 65000; // caps the grid; CUDA_1D_KERNEL_LOOP steps by blockDim.x * gridDim.x, so all N items are still visited 13 | return min(optimal_block_num, max_block_num); 14 | } 15 | 16 | template 17 | __global__ void MaskedIm2colForward(const int n, const scalar_t *data_im, 18 | const int height, const int width, 19 | const int kernel_h, const int kernel_w, 20 | const int pad_h, const int pad_w, 21 | const int64_t *mask_h_idx, 22 | const int64_t *mask_w_idx, 23 | const int mask_cnt, scalar_t *data_col) { 24 | // n == mask_cnt * channels: one work item per (input channel, masked position) 25 | CUDA_1D_KERNEL_LOOP(index, n) { 26 | const int m_index = index % mask_cnt; 27 | const int h_col = mask_h_idx[m_index]; 28 | const int w_col = mask_w_idx[m_index]; 29 | const int c_im = index / mask_cnt; 30 | const int c_col = c_im * kernel_h * kernel_w; 31 | const int h_offset = h_col - pad_h; 32 | const int w_offset = w_col - pad_w; 33 | scalar_t *data_col_ptr = data_col + c_col * mask_cnt + m_index; 34 | for (int i = 0; i < kernel_h; ++i) { 35 | int h_im = h_offset + i; 36 | for (int j = 0; j < kernel_w; ++j) { 37 | int w_im = w_offset + j; 38 | if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { 39 | *data_col_ptr = 40 | (scalar_t)data_im[(c_im * height + h_im) * width + w_im]; 41 | } else { 42 | *data_col_ptr = 0.0; 43 | } 44 | data_col_ptr += mask_cnt; 45 | } 46 | } 47 | } 48 | } 49 | 50 | int MaskedIm2colForwardLaucher(const at::Tensor bottom_data, const int height, 51 | const int width, const int channels, 52 | const int kernel_h, const int kernel_w, 53 | const int pad_h, const int pad_w, 54 | const at::Tensor mask_h_idx, 55 | const at::Tensor mask_w_idx, const int mask_cnt, 56 | at::Tensor top_data) { 57 | const int output_size = mask_cnt * channels; 58 | 59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 60 | bottom_data.scalar_type(), "MaskedIm2colLaucherForward", ([&] { 61 | const scalar_t *bottom_data_ = bottom_data.data(); 62 | const int64_t *mask_h_idx_ = mask_h_idx.data(); 63 | const int64_t *mask_w_idx_ =
mask_w_idx.data(); 64 | scalar_t *top_data_ = top_data.data(); 65 | MaskedIm2colForward 66 | <<>>( 67 | output_size, bottom_data_, height, width, kernel_h, kernel_w, 68 | pad_h, pad_w, mask_h_idx_, mask_w_idx_, mask_cnt, top_data_); 69 | })); 70 | THCudaCheck(cudaGetLastError()); // NOTE(review): THCudaCheck is legacy THC API; newer PyTorch may need C10_CUDA_CHECK instead, confirm for the targeted version 71 | return 1; 72 | } 73 | 74 | template 75 | __global__ void MaskedCol2imForward(const int n, const scalar_t *data_col, 76 | const int height, const int width, 77 | const int channels, 78 | const int64_t *mask_h_idx, 79 | const int64_t *mask_w_idx, 80 | const int mask_cnt, scalar_t *data_im) { 81 | CUDA_1D_KERNEL_LOOP(index, n) { 82 | const int m_index = index % mask_cnt; 83 | const int h_im = mask_h_idx[m_index]; 84 | const int w_im = mask_w_idx[m_index]; 85 | const int c_im = index / mask_cnt; 86 | // scatter this (channel, masked position) value into the dense output 87 | data_im[(c_im * height + h_im) * width + w_im] = data_col[index]; 88 | } 89 | } 90 | 91 | int MaskedCol2imForwardLaucher(const at::Tensor bottom_data, const int height, 92 | const int width, const int channels, 93 | const at::Tensor mask_h_idx, 94 | const at::Tensor mask_w_idx, const int mask_cnt, 95 | at::Tensor top_data) { 96 | const int output_size = mask_cnt * channels; 97 | 98 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 99 | bottom_data.scalar_type(), "MaskedCol2imLaucherForward", ([&] { 100 | const scalar_t *bottom_data_ = bottom_data.data(); 101 | const int64_t *mask_h_idx_ = mask_h_idx.data(); 102 | const int64_t *mask_w_idx_ = mask_w_idx.data(); 103 | scalar_t *top_data_ = top_data.data(); 104 | 105 | MaskedCol2imForward 106 | <<>>( 107 | output_size, bottom_data_, height, width, channels, mask_h_idx_, 108 | mask_w_idx_, mask_cnt, top_data_); 109 | })); 110 | THCudaCheck(cudaGetLastError()); // NOTE(review): legacy THC API, see the im2col launcher 111 | return 1; 112 | } 113 | -------------------------------------------------------------------------------- /masked_conv2d/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016-present,
Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch, os 8 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 9 | from setuptools import setup, find_packages 10 | 11 | if torch.cuda.is_available(): 12 | assert torch.matmul(torch.ones(2097153,2).cuda(),torch.ones(2,2).cuda()).min().item()==2, 'Please upgrade from CUDA 9.0 to CUDA 10.0+' 13 | 14 | this_dir = os.path.dirname(os.path.realpath(__file__)) 15 | torch_dir = os.path.dirname(torch.__file__) 16 | conda_include_dir = '/'.join(torch_dir.split('/')[:-4]) + '/include' 17 | 18 | extra = {'cxx': ['-std=c++11', '-fopenmp'], 'nvcc': ['-std=c++11', '-Xcompiler', '-fopenmp']} 19 | 20 | setup( 21 | name='masked_conv2d', 22 | version='0.1', 23 | description='Mask Conv', 24 | packages=find_packages(), 25 | ext_modules=[ 26 | CUDAExtension('masked_conv2d.masked_conv2d_cuda', [ 'masked_conv2d/src/masked_conv2d_kernel.cu', 'masked_conv2d/src/masked_conv2d_cuda.cpp'], 27 | #include_dirs=[conda_include_dir],#, this_dir+'/'], 28 | extra_compile_args=extra)], 29 | cmdclass={'build_ext': BuildExtension}, 30 | zip_safe=False, 31 | ) 32 | 33 | -------------------------------------------------------------------------------- /models/MPNCOV/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaDSR/8997b5f978cb7fae2e61b6753a950dcae7ead470/models/MPNCOV/__init__.py -------------------------------------------------------------------------------- /models/MPNCOV/python/MPNCOV.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @file: MPNCOV.py 3 | @author: Jiangtao Xie 4 | @author: Peihua Li 5 | 6 | Copyright (C) 2018 Peihua Li and Jiangtao Xie 7 | 8 | All rights reserved. 
9 | ''' 10 | ''' 11 | @update by: Ming Liu 12 | @email: csmliu@outlook.com 13 | @comment: Reduce memory useage and boost running speed 14 | ''' 15 | 16 | import torch 17 | import numpy as np 18 | from torch.autograd import Function 19 | 20 | class Covpool(Function): 21 | @staticmethod 22 | def forward(ctx, input): 23 | x = input 24 | batchSize = x.data.shape[0] 25 | dim = x.data.shape[1] 26 | h = x.data.shape[2] 27 | w = x.data.shape[3] 28 | M = h*w 29 | x = x.reshape(batchSize,dim,M) 30 | I_hat = torch.empty(M, M, device=x.device).fill_(-1./M/M) 31 | I_hat_diag = I_hat.diagonal() 32 | I_hat_diag += (1./M) 33 | y = x @ I_hat @ x.transpose(1,2) 34 | # I_hat = (-1./M/M)*torch.ones(M,M,device = x.device) + (1./M)*torch.eye(M,M,device = x.device) 35 | # I_hat = I_hat.view(1,M,M).repeat(batchSize,1,1).type(x.dtype) 36 | # y = x.bmm(I_hat).bmm(x.transpose(1,2)) 37 | ctx.save_for_backward(input,I_hat) 38 | return y 39 | @staticmethod 40 | def backward(ctx, grad_output): 41 | input,I_hat = ctx.saved_tensors 42 | x = input 43 | batchSize = x.data.shape[0] 44 | dim = x.data.shape[1] 45 | h = x.data.shape[2] 46 | w = x.data.shape[3] 47 | M = h*w 48 | x = x.reshape(batchSize,dim,M) 49 | grad_input = grad_output + grad_output.transpose(1,2) 50 | grad_input = grad_input @ x @ I_hat 51 | # grad_input = grad_input.bmm(x).bmm(I_hat) 52 | grad_input = grad_input.reshape(batchSize,dim,h,w) 53 | return grad_input 54 | 55 | class Sqrtm(Function): 56 | @staticmethod 57 | def forward(ctx, input, iterN): 58 | x = input 59 | batchSize = x.data.shape[0] 60 | dim = x.data.shape[1] 61 | dtype = x.dtype 62 | I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 63 | normA = (1.0/3.0)*x.mul(I3).sum(dim=1).sum(dim=1) 64 | A = x.div(normA.view(batchSize,1,1).expand_as(x)) 65 | Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad = False, device = x.device) 66 | Z = torch.eye(dim,dim,device = x.device).view(1,dim,dim).repeat(batchSize,iterN,1,1) 67 
| if iterN < 2: 68 | ZY = 0.5*(I3 - A) 69 | Y[:,0,:,:] = A.bmm(ZY) 70 | else: 71 | ZY = 0.5*(I3 - A) 72 | Y[:,0,:,:] = A.bmm(ZY) 73 | Z[:,0,:,:] = ZY 74 | for i in range(1, iterN-1): 75 | ZY = 0.5*(I3 - Z[:,i-1,:,:].bmm(Y[:,i-1,:,:])) 76 | Y[:,i,:,:] = Y[:,i-1,:,:].bmm(ZY) 77 | Z[:,i,:,:] = ZY.bmm(Z[:,i-1,:,:]) 78 | ZY = 0.5*Y[:,iterN-2,:,:].bmm(I3 - Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:])) 79 | y = ZY*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x) 80 | ctx.save_for_backward(input, A, ZY, normA, Y, Z) 81 | ctx.iterN = iterN 82 | return y 83 | @staticmethod 84 | def backward(ctx, grad_output): 85 | input, A, ZY, normA, Y, Z = ctx.saved_tensors 86 | iterN = ctx.iterN 87 | x = input 88 | batchSize = x.data.shape[0] 89 | dim = x.data.shape[1] 90 | dtype = x.dtype 91 | der_postCom = grad_output*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x) 92 | der_postComAux = (grad_output*ZY).sum(dim=1).sum(dim=1).div(2*torch.sqrt(normA)) 93 | I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 94 | if iterN < 2: 95 | der_NSiter = 0.5*(der_postCom.bmm(I3 - A) - A.bmm(der_sacleTrace)) 96 | else: 97 | dldY = 0.5*(der_postCom.bmm(I3 - Y[:,iterN-2,:,:].bmm(Z[:,iterN-2,:,:])) - 98 | Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]).bmm(der_postCom)) 99 | dldZ = -0.5*Y[:,iterN-2,:,:].bmm(der_postCom).bmm(Y[:,iterN-2,:,:]) 100 | for i in range(iterN-3, -1, -1): 101 | YZ = I3 - Y[:,i,:,:].bmm(Z[:,i,:,:]) 102 | ZY = Z[:,i,:,:].bmm(Y[:,i,:,:]) 103 | dldY_ = 0.5*(dldY.bmm(YZ) - 104 | Z[:,i,:,:].bmm(dldZ).bmm(Z[:,i,:,:]) - 105 | ZY.bmm(dldY)) 106 | dldZ_ = 0.5*(YZ.bmm(dldZ) - 107 | Y[:,i,:,:].bmm(dldY).bmm(Y[:,i,:,:]) - 108 | dldZ.bmm(ZY)) 109 | dldY = dldY_ 110 | dldZ = dldZ_ 111 | der_NSiter = 0.5*(dldY.bmm(I3 - A) - dldZ - A.bmm(dldY)) 112 | grad_input = der_NSiter.div(normA.view(batchSize,1,1).expand_as(x)) 113 | grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1) 114 | for i in range(batchSize): 115 | grad_input[i,:,:] += (der_postComAux[i] \ 116 
| - grad_aux[i] / (normA[i] * normA[i])) \ 117 | *torch.ones(dim,device = x.device).diag() 118 | return grad_input, None 119 | 120 | class Triuvec(Function): 121 | @staticmethod 122 | def forward(ctx, input): 123 | x = input 124 | batchSize = x.data.shape[0] 125 | dim = x.data.shape[1] 126 | dtype = x.dtype 127 | x = x.reshape(batchSize, dim*dim) 128 | I = torch.ones(dim,dim).triu().t().reshape(dim*dim) 129 | index = I.nonzero() 130 | y = torch.zeros(batchSize,dim*(dim+1)/2,device = x.device) 131 | for i in range(batchSize): 132 | y[i, :] = x[i, index].t() 133 | ctx.save_for_backward(input,index) 134 | return y 135 | @staticmethod 136 | def backward(ctx, grad_output): 137 | input,index = ctx.saved_tensors 138 | x = input 139 | batchSize = x.data.shape[0] 140 | dim = x.data.shape[1] 141 | dtype = x.dtype 142 | grad_input = torch.zeros(batchSize,dim,dim,device = x.device,requires_grad=False) 143 | grad_input = grad_input.reshape(batchSize,dim*dim) 144 | for i in range(batchSize): 145 | grad_input[i,index] = grad_output[i,:].reshape(index.size(),1) 146 | grad_input = grad_input.reshape(batchSize,dim,dim) 147 | return grad_input 148 | 149 | def CovpoolLayer(var): 150 | return Covpool.apply(var) 151 | 152 | def SqrtmLayer(var, iterN): 153 | return Sqrtm.apply(var, iterN) 154 | 155 | def TriuvecLayer(var): 156 | return Triuvec.apply(var) 157 | -------------------------------------------------------------------------------- /models/MPNCOV/python/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaDSR/8997b5f978cb7fae2e61b6753a950dcae7ead470/models/MPNCOV/python/__init__.py -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from models.base_model import BaseModel 3 | 4 | 5 | def find_model_using_name(model_name): 6 | """Import the 
module "models/[model_name]_model.py". 7 | 8 | In the file, the class called DatasetNameModel() will 9 | be instantiated. It has to be a subclass of BaseModel, 10 | and it is case-insensitive. 11 | """ 12 | model_filename = "models." + model_name + "_model" 13 | modellib = importlib.import_module(model_filename) 14 | model = None 15 | target_model_name = model_name.replace('_', '') + 'model' 16 | for name, cls in modellib.__dict__.items(): 17 | if name.lower() == target_model_name.lower() \ 18 | and issubclass(cls, BaseModel): 19 | model = cls 20 | 21 | if model is None: 22 | raise NotImplementedError("In %s.py, there should be a subclass of " 23 | "BaseModel with class name that matches %s in " 24 | "lowercase." % (model_filename, target_model_name)) 25 | 26 | return model 27 | 28 | 29 | def get_option_setter(model_name): 30 | model_class = find_model_using_name(model_name) 31 | return model_class.modify_commandline_options 32 | 33 | 34 | def create_model(opt): 35 | """Create a model given the option. 36 | 37 | This function warps the class CustomDatasetDataLoader. 
38 | This is the main interface between this package and 'train.py'/'test.py' 39 | 40 | Example: 41 | >>> from models import create_model 42 | >>> model = create_model(opt) 43 | """ 44 | model = find_model_using_name(opt.model) 45 | instance = model(opt) 46 | print("model [%s] was created" % type(instance).__name__) 47 | return instance 48 | -------------------------------------------------------------------------------- /models/adaedsr_fixd_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .dsr_model import DSRModel, base_SRModel 3 | from .networks import AdaResBlock, AdaRCAGroup 4 | 5 | class AdaEDSRFixDModel(DSRModel): 6 | @staticmethod 7 | def modify_commandline_options(parser, is_train=True): 8 | parser.set_defaults( 9 | n_resblocks = 32, 10 | n_feats = 256, 11 | block_mode = 'CRC', 12 | nc_adapter = 1, 13 | constrain = 'soft', 14 | depth = [1], 15 | adapter_reduction = 2, 16 | ) 17 | return parser 18 | 19 | def __init__(self, opt): 20 | super(AdaEDSRFixDModel, self).__init__(opt, SRModel=SRModel) 21 | assert len(opt.depth) == 1 22 | self.model_names = ['AdaEDSRFixD'] 23 | self.optimizer_names = ['AdaEDSRFixD_optimizer_%s' % opt.optimizer] 24 | self.netAdaEDSRFixD = self.netDSR 25 | 26 | 27 | class SRModel(base_SRModel): 28 | def __init__(self, opt): 29 | self.block = AdaResBlock 30 | self.n_blocks = opt.n_resblocks 31 | self.block_name = 'block' 32 | super(SRModel, self).__init__(opt) -------------------------------------------------------------------------------- /models/adaedsr_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .dsr_model import DSRModel, base_SRModel 3 | from .networks import AdaResBlock, AdaRCAGroup 4 | 5 | class AdaEDSRModel(DSRModel): 6 | @staticmethod 7 | def modify_commandline_options(parser, is_train=True): 8 | parser.set_defaults( 9 | n_resblocks = 32, 10 | n_feats = 256, 11 | block_mode = 'CRC', 12 | 
nc_adapter = 1, 13 | constrain = 'soft', 14 | depth = [0, 32], 15 | adapter_pos = 5, 16 | adapter_reduction = 2, 17 | ) 18 | return parser 19 | 20 | def __init__(self, opt): 21 | super(AdaEDSRModel, self).__init__(opt, SRModel=SRModel) 22 | assert self.isTrain and len(opt.depth) == 2 or \ 23 | not self.isTrain and len(opt.depth) == 1 24 | self.model_names = ['AdaEDSR'] 25 | self.optimizer_names = ['AdaEDSR_optimizer_%s' % opt.optimizer] 26 | self.netAdaEDSR = self.netDSR 27 | 28 | 29 | class SRModel(base_SRModel): 30 | def __init__(self, opt): 31 | self.block = AdaResBlock 32 | self.n_blocks = opt.n_resblocks 33 | self.block_name = 'block' 34 | super(SRModel, self).__init__(opt) -------------------------------------------------------------------------------- /models/adarcan_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .dsr_model import DSRModel, base_SRModel 3 | from .networks import AdaResBlock, AdaRCAGroup 4 | 5 | class AdaRCANModel(DSRModel): 6 | @staticmethod 7 | def modify_commandline_options(parser, is_train=True): 8 | parser.set_defaults( 9 | n_groups = 10, 10 | n_resblocks = 20, 11 | n_feats = 64, 12 | block_mode = 'CRC', 13 | channel_attention = 'ca', 14 | nc_adapter = 10, 15 | constrain = 'soft', 16 | adapter_pos = 0, 17 | adapter_reduction = 1, 18 | depth = [0.1, 20], 19 | lambda_pred = 0.03, 20 | ) 21 | return parser 22 | 23 | def __init__(self, opt): 24 | super(AdaRCANModel, self).__init__(opt, SRModel=SRModel) 25 | assert self.isTrain and len(opt.depth) == 2 or \ 26 | not self.isTrain and len(opt.depth) == 1 27 | self.model_names = ['AdaRCAN'] 28 | self.optimizer_names = ['AdaRCAN_optimizer_%s' % opt.optimizer] 29 | self.netAdaRCAN = self.netDSR 30 | 31 | 32 | class SRModel(base_SRModel): 33 | def __init__(self, opt): 34 | self.block = AdaRCAGroup 35 | self.n_blocks = opt.n_groups 36 | self.block_name = 'group' 37 | super(SRModel, self).__init__(opt) 38 | 
-------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from abc import ABC, abstractmethod 5 | from . import networks 6 | import numpy as np 7 | import torch.nn as nn 8 | 9 | class BaseModel(ABC): 10 | def __init__(self, opt): 11 | self.opt = opt 12 | self.gpu_ids = opt.gpu_ids 13 | self.isTrain = opt.isTrain 14 | self.scale = opt.scale 15 | 16 | if len(self.gpu_ids) > 0: 17 | self.device = torch.device('cuda', self.gpu_ids[0]) 18 | else: 19 | self.device = torch.device('cpu') 20 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 21 | self.loss_names = [] 22 | self.model_names = [] 23 | self.visual_names = [] 24 | self.optimizers = [] 25 | self.optimizer_names = [] 26 | self.image_paths = [] 27 | self.metric = 0 # used for learning rate policy 'plateau' 28 | self.start_epoch = 0 29 | 30 | @staticmethod 31 | def modify_commandline_options(parser, is_train): 32 | return parser 33 | 34 | @abstractmethod 35 | def set_input(self, input): 36 | pass 37 | 38 | @abstractmethod 39 | def forward(self): 40 | pass 41 | 42 | @abstractmethod 43 | def optimize_parameters(self): 44 | pass 45 | 46 | def setup(self, opt=None): 47 | opt = opt if opt is not None else self.opt 48 | if self.isTrain: 49 | self.schedulers = [networks.get_scheduler(optimizer, opt) \ 50 | for optimizer in self.optimizers] 51 | for scheduler in self.schedulers: 52 | scheduler.last_epoch = opt.load_iter 53 | if opt.load_iter > 0 or opt.load_path != '': 54 | load_suffix = opt.load_iter 55 | if opt.model == 'rdn' or opt.model == 'vdsr': 56 | self.load_networks_rdn(load_suffix) 57 | else: 58 | self.load_networks(load_suffix) 59 | if opt.load_optimizers: 60 | self.load_optimizers(opt.load_iter) 61 | 62 | self.print_networks(opt.verbose) 63 | 64 | def eval(self): 65 | for name in self.model_names: 66 | 
net = getattr(self, 'net' + name) 67 | net.eval() 68 | 69 | def train(self): 70 | for name in self.model_names: 71 | net = getattr(self, 'net' + name) 72 | net.train() 73 | 74 | def test(self, FLOPs_only=False): 75 | with torch.no_grad(): 76 | try: 77 | self.forward(FLOPs_only) 78 | except: 79 | self.forward() 80 | 81 | def get_image_paths(self): 82 | return self.image_paths 83 | 84 | def update_learning_rate(self): 85 | for i, scheduler in enumerate(self.schedulers): 86 | if scheduler.__class__.__name__ == 'ReduceLROnPlateau': 87 | scheduler.step(self.metric) 88 | else: 89 | scheduler.step() 90 | print('lr of %s = %.7f' % ( 91 | self.optimizer_names[i], scheduler.get_lr()[0])) 92 | 93 | def get_current_visuals(self): 94 | visual_ret = OrderedDict() 95 | for name in self.visual_names: 96 | if name == 'pred': 97 | visual_ret[name] = getattr(self, name).detach() 98 | else: 99 | visual_ret[name] = torch.clamp( 100 | getattr(self, name).detach(), 0, 255).round() 101 | return visual_ret 102 | 103 | def get_current_losses(self): 104 | errors_ret = OrderedDict() 105 | for name in self.loss_names: 106 | errors_ret[name] = float(getattr(self, 'loss_' + name)) 107 | return errors_ret 108 | 109 | def save_networks(self, epoch): 110 | for name in self.model_names: 111 | save_filename = '%s_model_%d.pth' % (name, epoch) 112 | save_path = os.path.join(self.save_dir, save_filename) 113 | net = getattr(self, 'net' + name) 114 | if self.device.type == 'cuda': 115 | state = {'scale': self.scale, 116 | 'state_dict': net.module.cpu().state_dict()} 117 | torch.save(state, save_path) 118 | net.to(self.device) 119 | else: 120 | state = {'scale': self.scale, 121 | 'state_dict': net.state_dict()} 122 | torch.save(state, save_path) 123 | self.save_optimizers(epoch) 124 | 125 | def load_networks_rdn(self, epoch): 126 | for name in self.model_names: 127 | load_filename = '%s_model_%d.pth' % (name, epoch) 128 | if self.opt.load_path != '': 129 | load_path = self.opt.load_path 130 | else: 131 | 
load_path = os.path.join(self.save_dir, load_filename) 132 | net = getattr(self, 'net' + name) 133 | if isinstance(net, torch.nn.DataParallel): 134 | net = net.module 135 | 136 | state_dict = torch.load(load_path, map_location=self.device) 137 | if hasattr(state_dict, '_metadata'): 138 | del state_dict._metadata 139 | 140 | print('loading the model from %s' % (load_path)) 141 | 142 | net_state = net.state_dict() 143 | 144 | net_name = [] 145 | for name, param in net_state.items(): 146 | net_name.append(name) 147 | is_loaded = {n:False for n in net_state.keys()} 148 | 149 | for idx, (name, param) in enumerate(state_dict.items()): 150 | try: 151 | net_state[net_name[idx]].copy_(param) 152 | is_loaded[net_name[idx]] = True 153 | except Exception: 154 | if name.find('UPNet') != -1: 155 | continue 156 | raise RuntimeError( 157 | 'While copying the parameter named [%s], ' 158 | 'whose dimensions in the model are %s and ' 159 | 'whose dimensions in the checkpoint are %s.' 160 | % (name, list(net_state[name].shape), 161 | list(param.shape))) 162 | 163 | mark = True 164 | for name in is_loaded: 165 | if not is_loaded[name]: 166 | print('Parameter named [%s] is randomly initialized' % name) 167 | mark = False 168 | if mark: 169 | print('All parameters are initialized using [%s]' % load_path) 170 | 171 | self.start_epoch = epoch 172 | 173 | def load_networks(self, epoch): 174 | for name in self.model_names: 175 | load_filename = '%s_model_%d.pth' % (name, epoch) 176 | if self.opt.load_path != '': 177 | load_path = self.opt.load_path 178 | else: 179 | load_path = os.path.join(self.save_dir, load_filename) 180 | net = getattr(self, 'net' + name) 181 | if isinstance(net, torch.nn.DataParallel): 182 | net = net.module 183 | state_dict = torch.load(load_path, map_location=self.device) 184 | print('loading the model from %s (scale: %s)' 185 | % (load_path, state_dict['scale'])) 186 | if hasattr(state_dict, '_metadata'): 187 | del state_dict._metadata 188 | 189 | net_state = 
net.state_dict() 190 | is_loaded = {n:False for n in net_state.keys()} 191 | for name, param in state_dict['state_dict'].items(): 192 | if name in net_state: 193 | try: 194 | net_state[name].copy_(param) 195 | is_loaded[name] = True 196 | except Exception: 197 | print('While copying the parameter named [%s], ' 198 | 'whose dimensions in the model are %s and ' 199 | 'whose dimensions in the checkpoint are %s.' 200 | % (name, list(net_state[name].shape), 201 | list(param.shape))) 202 | if name.find('up') != -1: 203 | continue 204 | raise RuntimeError 205 | else: 206 | print('Saved parameter named [%s] is skipped' % name) 207 | mark = True 208 | for name in is_loaded: 209 | if not is_loaded[name]: 210 | print('Parameter named [%s] is randomly initialized' % name) 211 | mark = False 212 | if mark: 213 | print('All parameters are initialized using [%s]' % load_path) 214 | 215 | self.start_epoch = epoch 216 | 217 | def save_optimizers(self, epoch): 218 | assert len(self.optimizers) == len(self.optimizer_names) 219 | for id, optimizer in enumerate(self.optimizers): 220 | save_filename = self.optimizer_names[id] 221 | state = {'name': save_filename, 222 | 'epoch': epoch, 223 | 'state_dict': optimizer.state_dict()} 224 | save_path = os.path.join(self.save_dir, save_filename+'.pth') 225 | torch.save(state, save_path) 226 | 227 | def load_optimizers(self, epoch): 228 | assert len(self.optimizers) == len(self.optimizer_names) 229 | for id, optimizer in enumerate(self.optimizer_names): 230 | load_filename = self.optimizer_names[id] 231 | load_path = os.path.join(self.save_dir, load_filename+'.pth') 232 | print('loading the optimizer from %s' % load_path) 233 | state_dict = torch.load(load_path) 234 | assert optimizer == state_dict['name'] 235 | assert epoch == state_dict['epoch'] 236 | self.optimizers[id].load_state_dict(state_dict['state_dict']) 237 | 238 | def print_networks(self, verbose): 239 | print('---------- Networks initialized -------------') 240 | for name in 
self.model_names: 241 | if isinstance(name, str): 242 | net = getattr(self, 'net' + name) 243 | num_params = 0 244 | for param in net.parameters(): 245 | num_params += param.numel() 246 | if verbose: 247 | print(net) 248 | print('[Network %s] Total number of parameters : %.3f M' 249 | % (name, num_params / 1e6)) 250 | print('-----------------------------------------------') 251 | 252 | def set_requires_grad(self, nets, requires_grad=False): 253 | if not isinstance(nets, list): 254 | nets = [nets] 255 | for net in nets: 256 | if net is not None: 257 | for param in net.parameters(): 258 | param.requires_grad = requires_grad -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | # Copied from the official SAN repo, only used for implementing SAN in our repo. 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | def default_conv(in_channels, out_channels, kernel_size, padding=0, bias=True): 9 | return nn.Conv2d( 10 | in_channels, out_channels, kernel_size, 11 | padding=(kernel_size//2), bias=bias) 12 | 13 | class MeanShift(nn.Conv2d): 14 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 15 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 16 | std = torch.Tensor(rgb_std) 17 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 18 | self.weight.data.div_(std.view(3, 1, 1, 1)) 19 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 20 | self.bias.data.div_(std) 21 | self.requires_grad = False 22 | 23 | class BasicBlock(nn.Sequential): 24 | def __init__( 25 | self, in_channels, out_channels, kernel_size, stride=1, bias=False, 26 | bn=True, act=nn.ReLU(True)): 27 | 28 | m = [nn.Conv2d( 29 | in_channels, out_channels, kernel_size, 30 | padding=(kernel_size//2), stride=stride, bias=bias) 31 | ] 32 | if bn: 
m.append(nn.BatchNorm2d(out_channels)) 33 | if act is not None: m.append(act) 34 | super(BasicBlock, self).__init__(*m) 35 | 36 | class ResBlock(nn.Module): 37 | def __init__( 38 | self, conv, n_feat, kernel_size, 39 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 40 | 41 | super(ResBlock, self).__init__() 42 | m = [] 43 | for i in range(2): 44 | m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 45 | if bn: m.append(nn.BatchNorm2d(n_feat)) 46 | if i == 0: m.append(act) 47 | 48 | self.body = nn.Sequential(*m) 49 | self.res_scale = res_scale 50 | 51 | def forward(self, x): 52 | res = self.body(x).mul(self.res_scale) 53 | res += x 54 | 55 | return res 56 | 57 | class Upsampler(nn.Sequential): 58 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): 59 | 60 | m = [] 61 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 62 | for _ in range(int(math.log(scale, 2))): 63 | m.append(conv(n_feat, 4 * n_feat, 3, bias)) 64 | m.append(nn.PixelShuffle(2)) 65 | if bn: m.append(nn.BatchNorm2d(n_feat)) 66 | if act: m.append(act()) 67 | elif scale == 3: 68 | m.append(conv(n_feat, 9 * n_feat, 3, bias)) 69 | m.append(nn.PixelShuffle(3)) 70 | if bn: m.append(nn.BatchNorm2d(n_feat)) 71 | if act: m.append(act()) 72 | else: 73 | raise NotImplementedError 74 | 75 | super(Upsampler, self).__init__(*m) 76 | 77 | ## add SELayer 78 | class SELayer(nn.Module): 79 | def __init__(self, channel, reduction=16): 80 | super(SELayer, self).__init__() 81 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 82 | self.conv_du = nn.Sequential( 83 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 84 | nn.ReLU(inplace=True), 85 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 86 | nn.Sigmoid() 87 | ) 88 | 89 | def forward(self, x): 90 | y = self.avg_pool(x) 91 | y = self.conv_du(y) 92 | return x * y 93 | 94 | ## add SEResBlock 95 | class SEResBlock(nn.Module): 96 | def __init__( 97 | self, conv, n_feat, kernel_size, reduction, 98 | bias=True, 
bn=False, act=nn.ReLU(True), res_scale=1): 99 | 100 | super(SEResBlock, self).__init__() 101 | modules_body = [] 102 | for i in range(2): 103 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 104 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 105 | if i == 0: modules_body.append(act) 106 | modules_body.append(SELayer(n_feat, reduction)) 107 | self.body = nn.Sequential(*modules_body) 108 | self.res_scale = res_scale 109 | 110 | def forward(self, x): 111 | res = self.body(x) 112 | #res = self.body(x).mul(self.res_scale) 113 | res += x 114 | 115 | return res -------------------------------------------------------------------------------- /models/dsr_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . import networks as N 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import math 7 | import torch.nn.functional as F 8 | from . import losses as L 9 | 10 | 11 | class DSRModel(BaseModel): 12 | @staticmethod 13 | def modify_commandline_options(parser, is_train=True): 14 | return parser 15 | 16 | def __init__(self, opt, SRModel=None): 17 | super(DSRModel, self).__init__(opt) 18 | 19 | self.opt = opt 20 | self.nc_adapter = opt.nc_adapter 21 | self.constrain = opt.constrain 22 | 23 | if self.nc_adapter != 0: 24 | self.loss_names = [opt.loss, 'Pred', 'Total'] 25 | self.visual_names = ['data_lr', 'data_hr', 'data_sr', 'pred'] 26 | else: 27 | self.loss_names = [opt.loss, 'Total'] 28 | self.visual_names = ['data_lr', 'data_hr', 'data_sr'] 29 | self.model_names = ['DSR'] # will rename in subclasses 30 | self.optimizer_names = ['DSR_optimizer_%s' % opt.optimizer] 31 | 32 | DSR = SRModel(opt) 33 | self.netDSR = N.init_net(DSR, opt.init_type, opt.init_gain, opt.gpu_ids) 34 | 35 | if self.constrain != 'none': 36 | self.depth_gen = N.num_generator(opt.depth) 37 | else: 38 | self.depth_gen = None 39 | self.depth = None 40 | 41 | if self.isTrain: 
42 | if opt.optimizer == 'Adam': 43 | self.optimizer = optim.Adam(self.netDSR.parameters(), 44 | lr=opt.lr, 45 | betas=(opt.beta1, opt.beta2), 46 | weight_decay=opt.weight_decay) 47 | elif opt.optimizer == 'SGD': 48 | self.optimizer = optim.SGD(self.netDSR.parameters(), 49 | lr=opt.lr, 50 | momentum=opt.momentum, 51 | weight_decay=opt.weight_decay) 52 | elif opt.optimizer == 'RMSprop': 53 | self.optimizer = optim.RMSprop(self.netDSR.parameters(), 54 | lr=opt.lr, 55 | alpha=opt.alpha, 56 | momentum=opt.momentum, 57 | weight_decay=opt.weight_decay) 58 | else: 59 | raise NotImplementedError( 60 | 'optimizer named [%s] is not supported' % opt.optimizer) 61 | 62 | self.optimizers = [self.optimizer] 63 | 64 | def set_input(self, input): 65 | self.data_lr = input['lr'].to(self.device) 66 | self.data_hr = input['hr'].to(self.device) 67 | self.image_paths = input['fname'] 68 | if self.depth_gen is not None: 69 | batch_size = self.data_lr.shape[0] 70 | self.depth = self.depth_gen((batch_size, 1), device=self.device) 71 | 72 | def forward(self, FLOPs_only=False): 73 | if self.isTrain: 74 | self.data_sr, self.pred, *losses = \ 75 | self.netDSR(self.data_lr, self.data_hr, self.depth, FLOPs_only) 76 | for i, loss_name in enumerate(self.loss_names): 77 | setattr(self, 'loss_'+loss_name, losses[i].mean()) 78 | elif self.opt.model.lower() in ('adaedsr', 'adarcan', 'adaedsr_fixd'): 79 | # We write a chop function for AdaEDSR and AdaRCAN for running 80 | # the adapter only once, see `class base_SRModel` for details. 
81 | self.data_sr, self.pred = self.netDSR(self.data_lr, 82 | depth=self.depth, 83 | FLOPs_only=FLOPs_only, 84 | chop=self.opt.chop) 85 | else: 86 | if not self.opt.chop: 87 | self.data_sr, self.pred = self.netDSR(self.data_lr, 88 | depth=self.depth) 89 | else: 90 | self.data_sr, self.pred = N.forward_chop(self.opt, self.netDSR, 91 | self.data_lr, self.depth, shave=10, min_size=160000) 92 | 93 | def backward(self): 94 | self.loss_Total.backward() 95 | 96 | def optimize_parameters(self): 97 | self.forward() 98 | self.optimizer.zero_grad() 99 | self.backward() 100 | self.optimizer.step() 101 | 102 | 103 | class base_SRModel(nn.Module): 104 | def __init__(self, opt): 105 | super(base_SRModel, self).__init__() 106 | 107 | self.opt = opt 108 | self.lambda_pred = opt.lambda_pred 109 | self.nc_adapter = opt.nc_adapter 110 | self.multi_adapter = opt.multi_adapter 111 | self.constrain = opt.constrain 112 | self.with_depth = opt.with_depth 113 | self.scale = opt.scale 114 | 115 | if self.nc_adapter > 0 and self.multi_adapter: 116 | assert self.n_blocks == self.nc_adapter 117 | 118 | n_feats = opt.n_feats 119 | n_upscale = int(math.log(opt.scale, 2)) 120 | 121 | m_head = [N.MeanShift(), 122 | N.conv(opt.input_nc, n_feats, mode='C')] 123 | self.head = N.seq(m_head) 124 | 125 | for i in range(self.n_blocks): 126 | setattr(self, '%s%d'%(self.block_name, i), self.block( 127 | n_feats, n_feats, res_scale=opt.res_scale, mode=opt.block_mode, 128 | clamp=self.clamp_wrapper(i) if self.nc_adapter != 0 else None, 129 | channel_attention=opt.channel_attention, 130 | sparse_conv=opt.sparse_conv, 131 | n_resblocks=opt.n_resblocks, 132 | clamp_wrapper=self.clamp_wrapper, 133 | side_ca=opt.side_ca)) 134 | if self.nc_adapter != 0 and self.multi_adapter: 135 | setattr(self, 'predictor%d'%i, Predictor( 136 | n_feats=n_feats, n_layers=opt.adapter_layers, 137 | reduction=opt.adapter_reduction, 138 | hard_constrain=(self.constrain=='hard'), 139 | nc_adapter=1, 140 | depth_pos=opt.adapter_pos, 141 
| upper_bound=opt.adapter_bound)) 142 | self.body_lastconv = N.conv(n_feats, n_feats, mode='C') 143 | 144 | if opt.scale == 3: 145 | m_up = N.upsample_pixelshuffle(n_feats, n_feats, mode='3') 146 | else: 147 | m_up = [N.upsample_pixelshuffle(n_feats, n_feats, mode='2') \ 148 | for _ in range(n_upscale)] 149 | self.up = N.seq(m_up) 150 | 151 | m_tail = [N.conv(n_feats, opt.output_nc, mode='C'), 152 | N.MeanShift(sign=1)] 153 | self.tail = N.seq(m_tail) 154 | 155 | if self.nc_adapter != 0 and not self.multi_adapter: 156 | assert self.nc_adapter in (1, self.n_blocks) 157 | self.predictor = Predictor( 158 | n_feats=n_feats, n_layers=opt.adapter_layers, 159 | reduction=opt.adapter_reduction, 160 | hard_constrain=(self.constrain=='hard'), 161 | nc_adapter=self.nc_adapter, 162 | depth_pos=opt.adapter_pos, 163 | upper_bound=opt.adapter_bound) 164 | 165 | self.isTrain = opt.isTrain 166 | self.loss = opt.loss 167 | if self.isTrain: 168 | setattr(self, 'criterion%s'%self.loss, 169 | getattr(L, '%sLoss'%self.loss)()) 170 | 171 | def clamp_wrapper(self, i): 172 | def clamp(x): 173 | return torch.clamp(x-i, 0, 1) 174 | return clamp 175 | 176 | def forward_main_tail(self, x, pred): 177 | res = x 178 | for i in range(self.n_blocks): 179 | if self.nc_adapter <= 1 and not self.multi_adapter: 180 | res = getattr(self, '%s%d'%(self.block_name, i))( 181 | res, pred) 182 | elif self.multi_adapter: 183 | setattr(self, 'pred%d'%i, 184 | getattr(self, 'predictor%d'%i)(res, 185 | depth if self.with_depth else None)) 186 | res = getattr(self, '%s%d'%(self.block_name, i))( 187 | res, getattr(self, 'pred%d'%i)) 188 | else: 189 | res = getattr(self, '%s%d'%(self.block_name, i))( 190 | res, pred[:, i:i+1, ...]) 191 | res = self.body_lastconv(res) 192 | res += x 193 | 194 | res = self.up(res) 195 | res = self.tail(res) 196 | return res 197 | 198 | def forward_chop(self, x, pred, shave=10, min_size=160000): 199 | scale = self.scale 200 | n_GPUs = len(self.opt.gpu_ids) 201 | n, c, h, w = x.shape 
202 | h_half, w_half = h//2, w//2 203 | h_size, w_size = h_half + shave, w_half + shave 204 | lr_list = [ 205 | x[..., 0:h_size, 0:w_size], 206 | x[..., 0:h_size, (w - w_size):w], 207 | x[..., (h - h_size):h, 0:w_size], 208 | x[..., (h - h_size):h, (w - w_size):w] 209 | ] 210 | pred_list = [ 211 | pred[..., 0:h_size, 0:w_size], 212 | pred[..., 0:h_size, (w - w_size):w], 213 | pred[..., (h - h_size):h, 0:w_size], 214 | pred[..., (h - h_size):h, (w - w_size):w] 215 | ] 216 | if w_size * h_size < min_size: 217 | sr_list = [] 218 | for i in range(0, 4, n_GPUs): 219 | lr_batch = torch.cat(lr_list[i:(i+n_GPUs)], dim=0) 220 | pred_batch = torch.cat(pred_list[i:(i+n_GPUs)], dim=0) 221 | res = self.forward_main_tail(lr_batch, pred_batch) 222 | sr_list.extend(res.chunk(n_GPUs, dim=0)) 223 | else: 224 | sr_list = [ 225 | self.forward_chop(lr_, pred_, shave, min_size) \ 226 | for lr_, pred_ in zip(lr_list, pred_list)] 227 | 228 | h, w = scale * h, scale * w 229 | h_half, w_half = scale * h_half, scale * w_half 230 | h_size, w_size = scale * h_size, scale * w_size 231 | shave *= scale 232 | c = sr_list[0].shape[1] 233 | 234 | output = x.new(n, c, h, w) 235 | output[:, :, 0:h_half, 0:w_half] \ 236 | = sr_list[0][:, :, 0:h_half, 0:w_half] 237 | output[:, :, 0:h_half, w_half:w] \ 238 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 239 | output[:, :, h_half:h, 0:w_half] \ 240 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 241 | output[:, :, h_half:h, w_half:w] \ 242 | = sr_list[3][:, :, (h_size - h + h_half):h_size, 243 | (w_size - w + w_half):w_size] 244 | return output 245 | 246 | 247 | 248 | def forward(self, x, hr=None, depth=None, FLOPs_only=False, chop=False): 249 | x = self.head(x) 250 | 251 | if not self.multi_adapter: 252 | if self.nc_adapter: 253 | if self.with_depth: 254 | pred = self.predictor(x, depth) # N*1*H*W, and depth is N*1 255 | else: 256 | pred = self.predictor(x) 257 | else: 258 | pred = None 259 | if FLOPs_only: 260 | return x, pred 
261 | 262 | if chop: 263 | x = self.forward_chop(x, pred) 264 | else: 265 | x = self.forward_main_tail(x, pred) 266 | 267 | if self.isTrain: 268 | criterion1 = getattr(self, 'criterion%s'%self.loss) 269 | loss1 = criterion1(x, hr) 270 | if self.nc_adapter != 0: 271 | if self.constrain == 'none': 272 | loss_Pred = self.lambda_pred * pred.abs() 273 | loss = loss1 + loss_Pred 274 | elif self.constrain == 'soft': 275 | if self.multi_adapter: 276 | pred = torch.cat([getattr(self, 'pred%d'%i) \ 277 | for i in range(self.nc_adapter)], dim=1) 278 | loss_Pred = self.lambda_pred * \ 279 | (pred.mean((2,3)) - depth).clamp_min_(0).sum(dim=1) 280 | else: 281 | loss_Pred = self.lambda_pred * \ 282 | (pred.mean((2,3)) - depth).clamp_min_(0).mean(dim=1) 283 | #(pred.mean((1,2,3)) - depth).clamp_min_(0) 284 | # loss_Pred = self.lambda_pred * \ 285 | # (pred.mean((1,2,3)) - depth).abs() 286 | loss = loss1 + loss_Pred 287 | else: 288 | loss = loss1 289 | loss_Pred = torch.zeros_like(loss1) 290 | return x, pred, loss1, loss_Pred, loss 291 | return x, pred, loss1, loss1 292 | else: 293 | if self.multi_adapter: 294 | pred = torch.cat([getattr(self, 'pred%d'%i) \ 295 | for i in range(self.nc_adapter)], dim=1) 296 | return x, pred 297 | 298 | class Predictor(nn.Module): 299 | def __init__(self, n_feats, n_layers=5, reduction=2, hard_constrain=False, 300 | nc_adapter=1, depth_pos=-1, upper_bound=float('inf')): 301 | super(Predictor, self).__init__() 302 | 303 | self.hard_constrain = hard_constrain 304 | self.depth_pos = depth_pos 305 | self.upper_bound = upper_bound 306 | self.n_layers = n_layers 307 | 308 | pred_feats = n_feats // reduction 309 | layers = [ 310 | N.conv(n_feats, pred_feats, 3, mode='C'), 311 | *(N.conv(pred_feats, pred_feats, 3, mode='PC') \ 312 | for _ in range(n_layers - 2)), 313 | N.conv(pred_feats, nc_adapter, 3, mode='PC') 314 | ] 315 | for i, layer in enumerate(layers): 316 | setattr(self, 'layer%d'%i, layer) 317 | 318 | def forward(self, x, depth=None): 319 | for i 
in range(self.n_layers): 320 | if self.depth_pos == i: 321 | x = x * depth.view(-1, 1, 1, 1) 322 | x = getattr(self, 'layer%d'%i)(x) 323 | if self.depth_pos >= self.n_layers: 324 | x = x * depth.view(-1, 1, 1, 1) 325 | if self.hard_constrain: 326 | return x / x.mean((1, 2, 3), keepdim=True) * depth.view(-1, 1, 1, 1) 327 | return x.clamp(0, self.upper_bound) 328 | -------------------------------------------------------------------------------- /models/edsr_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .dsr_model import DSRModel, base_SRModel 3 | from .networks import AdaResBlock, AdaRCAGroup 4 | 5 | class EDSRModel(DSRModel): 6 | @staticmethod 7 | def modify_commandline_options(parser, is_train=True): 8 | parser.set_defaults( 9 | n_resblocks = 32, 10 | n_feats = 256, 11 | block_mode = 'CRC' 12 | ) 13 | return parser 14 | 15 | def __init__(self, opt): 16 | super(EDSRModel, self).__init__(opt, SRModel=SRModel) 17 | self.model_names = ['EDSR'] 18 | self.optimizer_names = ['EDSR_optimizer_%s' % opt.optimizer] 19 | self.netEDSR = self.netDSR 20 | 21 | 22 | class SRModel(base_SRModel): 23 | def __init__(self, opt): 24 | self.block = AdaResBlock 25 | self.n_blocks = opt.n_resblocks 26 | self.block_name = 'block' 27 | super(SRModel, self).__init__(opt) -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import numpy as np 8 | from math import exp 9 | from torch.nn import L1Loss, MSELoss 10 | 11 | def gaussian(window_size, sigma): 12 | gauss = torch.Tensor([exp( 13 | -(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) \ 14 | for x in range(window_size)]) 15 | return gauss / gauss.sum() 16 | 17 | def 
create_window(window_size, channel): 18 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 19 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 20 | window = Variable(_2D_window.expand( 21 | channel, 1, window_size, window_size).contiguous()) 22 | return window 23 | 24 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 25 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 26 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 27 | 28 | mu1_sq = mu1.pow(2) 29 | mu2_sq = mu2.pow(2) 30 | mu1_mu2 = mu1 * mu2 31 | 32 | sigma1_sq = F.conv2d( 33 | img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 34 | sigma2_sq = F.conv2d( 35 | img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 36 | sigma12 = F.conv2d( 37 | img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 38 | 39 | C1 = 0.01 ** 2 40 | C2 = 0.03 ** 2 41 | 42 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ 43 | ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 44 | 45 | if size_average: 46 | return ssim_map.mean() 47 | else: 48 | return ssim_map.mean(1).mean(1).mean(1) 49 | 50 | def ssim(img1, img2, window_size=11, size_average=True): 51 | (_, channel, _, _) = img1.size() 52 | window = create_window(window_size, channel) 53 | 54 | if img1.is_cuda: 55 | window = window.cuda(img1.get_device()) 56 | window = window.type_as(img1) 57 | 58 | return _ssim(img1, img2, window, window_size, channel, size_average) 59 | 60 | class SSIMLoss(torch.nn.Module): 61 | def __init__(self, window_size=11, size_average=True): 62 | super(SSIMLoss, self).__init__() 63 | self.window_size = window_size 64 | self.size_average = size_average 65 | self.channel = 1 66 | self.window = create_window(window_size, self.channel) 67 | 68 | def forward(self, img1, img2): 69 | (_, channel, _, _) = img1.size() 70 | 71 | if channel == self.channel and \ 72 | self.window.data.type() 
== img1.data.type(): 73 | window = self.window 74 | else: 75 | window = create_window(self.window_size, channel) 76 | 77 | if img1.is_cuda: 78 | window = window.cuda(img1.get_device()) 79 | window = window.type_as(img1) 80 | 81 | self.window = window 82 | self.channel = channel 83 | 84 | return -_ssim(img1, img2, window, self.window_size, 85 | channel, self.size_average) 86 | 87 | 88 | def calc_psnr(sr, hr): 89 | diff = (sr - hr) / 255. 90 | diff *= torch.tensor([65.738, 129.057, 25.064], 91 | device='cuda').view(1, 3, 1, 1) / 256 92 | diff = diff.sum(dim=1, keepdim=True) 93 | mse = torch.pow(diff, 2).mean() 94 | return (-10 * torch.log10(mse)) 95 | 96 | class PSNRLoss(torch.nn.Module): 97 | def __init__(self): 98 | super(PSNRLoss, self).__init__() 99 | 100 | def forward(self, img1, img2): 101 | (batch, channel, _, _) = img1.size() 102 | psnrs = [] 103 | for i in range(batch): 104 | psnrs.append(calc_psnr(img1[i:i+1,...], img2[i:i+1,...])) 105 | return -sum(psnrs)/batch -------------------------------------------------------------------------------- /models/non_local/network.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from lib.non_local_simple_version import NONLocalBlock2D 3 | # from lib.non_local import NONLocalBlock2D 4 | 5 | 6 | class Network(nn.Module): 7 | def __init__(self): 8 | super(Network, self).__init__() 9 | 10 | self.convs = nn.Sequential( 11 | nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1), 12 | nn.BatchNorm2d(32), 13 | nn.ReLU(), 14 | nn.MaxPool2d(2), 15 | 16 | NONLocalBlock2D(in_channels=32), 17 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1), 18 | nn.BatchNorm2d(64), 19 | nn.ReLU(), 20 | nn.MaxPool2d(2), 21 | 22 | NONLocalBlock2D(in_channels=64), 23 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 24 | nn.BatchNorm2d(128), 25 | nn.ReLU(), 26 | nn.MaxPool2d(2), 27 | ) 28 | 29 | 
self.fc = nn.Sequential( 30 | nn.Linear(in_features=128*3*3, out_features=256), 31 | nn.ReLU(), 32 | nn.Dropout(0.5), 33 | 34 | nn.Linear(in_features=256, out_features=10) 35 | ) 36 | 37 | def forward(self, x): 38 | batch_size = x.size(0) 39 | output = self.convs(x).view(batch_size, -1) 40 | output = self.fc(output) 41 | return output 42 | 43 | if __name__ == '__main__': 44 | import torch 45 | from torch.autograd import Variable 46 | 47 | img = Variable(torch.randn(3, 1, 28, 28)) 48 | net = Network() 49 | out = net(img) 50 | print(out.size()) 51 | 52 | -------------------------------------------------------------------------------- /models/non_local/non_local.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian', 8 | sub_sample=True, bn_layer=True): 9 | super(_NonLocalBlockND, self).__init__() 10 | 11 | assert dimension in [1, 2, 3] 12 | assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation'] 13 | 14 | # print('Dimension: %d, mode: %s' % (dimension, mode)) 15 | 16 | self.mode = mode 17 | self.dimension = dimension 18 | self.sub_sample = sub_sample 19 | 20 | self.in_channels = in_channels 21 | self.inter_channels = inter_channels 22 | 23 | if self.inter_channels is None: 24 | self.inter_channels = in_channels // 2 25 | if self.inter_channels == 0: 26 | self.inter_channels = 1 27 | 28 | if dimension == 3: 29 | conv_nd = nn.Conv3d 30 | max_pool = nn.MaxPool3d 31 | bn = nn.BatchNorm3d 32 | elif dimension == 2: 33 | conv_nd = nn.Conv2d 34 | max_pool = nn.MaxPool2d 35 | bn = nn.BatchNorm2d 36 | else: 37 | conv_nd = nn.Conv1d 38 | max_pool = nn.MaxPool1d 39 | bn = nn.BatchNorm1d 40 | 41 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 42 | kernel_size=1, 
stride=1, padding=0) 43 | 44 | if bn_layer: 45 | self.W = nn.Sequential( 46 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 47 | kernel_size=1, stride=1, padding=0), 48 | bn(self.in_channels) 49 | ) 50 | nn.init.constant(self.W[1].weight, 0) 51 | nn.init.constant(self.W[1].bias, 0) 52 | else: 53 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 54 | kernel_size=1, stride=1, padding=0) 55 | nn.init.constant(self.W.weight, 0) 56 | nn.init.constant(self.W.bias, 0) 57 | 58 | self.theta = None 59 | self.phi = None 60 | self.concat_project = None 61 | 62 | if mode in ['embedded_gaussian', 'dot_product', 'concatenation']: 63 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 64 | kernel_size=1, stride=1, padding=0) 65 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 66 | kernel_size=1, stride=1, padding=0) 67 | 68 | if mode == 'embedded_gaussian': 69 | self.operation_function = self._embedded_gaussian 70 | elif mode == 'dot_product': 71 | self.operation_function = self._dot_product 72 | elif mode == 'concatenation': 73 | self.operation_function = self._concatenation 74 | self.concat_project = nn.Sequential( 75 | nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False), 76 | nn.ReLU() 77 | ) 78 | elif mode == 'gaussian': 79 | self.operation_function = self._gaussian 80 | 81 | if sub_sample: 82 | self.g = nn.Sequential(self.g, max_pool(kernel_size=2)) 83 | if self.phi is None: 84 | self.phi = max_pool(kernel_size=2) 85 | else: 86 | self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2)) 87 | 88 | def forward(self, x): 89 | ''' 90 | :param x: (b, c, t, h, w) 91 | :return: 92 | ''' 93 | 94 | output = self.operation_function(x) 95 | return output 96 | 97 | def _embedded_gaussian(self, x): 98 | batch_size = x.size(0) 99 | 100 | # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c) 101 | g_x = self.g(x).view(batch_size, self.inter_channels, 
-1) 102 | g_x = g_x.permute(0, 2, 1) 103 | 104 | # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c) 105 | # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) 106 | # f=>(b, thw, 0.5c)dot(b, 0.5c, twh) = (b, thw, thw) 107 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 108 | theta_x = theta_x.permute(0, 2, 1) 109 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 110 | f = torch.matmul(theta_x, phi_x) 111 | f_div_C = F.softmax(f, dim=-1) 112 | 113 | # (b, thw, thw)dot(b, thw, 0.5c) = (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w) 114 | y = torch.matmul(f_div_C, g_x) 115 | y = y.permute(0, 2, 1).contiguous() 116 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 117 | W_y = self.W(y) 118 | z = W_y + x 119 | 120 | return z 121 | 122 | def _gaussian(self, x): 123 | batch_size = x.size(0) 124 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 125 | g_x = g_x.permute(0, 2, 1) 126 | 127 | theta_x = x.view(batch_size, self.in_channels, -1) 128 | theta_x = theta_x.permute(0, 2, 1) 129 | 130 | if self.sub_sample: 131 | phi_x = self.phi(x).view(batch_size, self.in_channels, -1) 132 | else: 133 | phi_x = x.view(batch_size, self.in_channels, -1) 134 | 135 | f = torch.matmul(theta_x, phi_x) 136 | f_div_C = F.softmax(f, dim=-1) 137 | 138 | y = torch.matmul(f_div_C, g_x) 139 | y = y.permute(0, 2, 1).contiguous() 140 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 141 | W_y = self.W(y) 142 | z = W_y + x 143 | 144 | return z 145 | 146 | def _dot_product(self, x): 147 | batch_size = x.size(0) 148 | 149 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 150 | g_x = g_x.permute(0, 2, 1) 151 | 152 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 153 | theta_x = theta_x.permute(0, 2, 1) 154 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 155 | f = torch.matmul(theta_x, phi_x) 156 | N = f.size(-1) 157 | f_div_C = f / N 158 | 159 | y = 
torch.matmul(f_div_C, g_x) 160 | y = y.permute(0, 2, 1).contiguous() 161 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 162 | W_y = self.W(y) 163 | z = W_y + x 164 | 165 | return z 166 | 167 | def _concatenation(self, x): 168 | batch_size = x.size(0) 169 | 170 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 171 | g_x = g_x.permute(0, 2, 1) 172 | 173 | # (b, c, N, 1) 174 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1) 175 | # (b, c, 1, N) 176 | phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1) 177 | 178 | h = theta_x.size(2) 179 | w = phi_x.size(3) 180 | theta_x = theta_x.repeat(1, 1, 1, w) 181 | phi_x = phi_x.repeat(1, 1, h, 1) 182 | 183 | concat_feature = torch.cat([theta_x, phi_x], dim=1) 184 | f = self.concat_project(concat_feature) 185 | b, _, h, w = f.size() 186 | f = f.view(b, h, w) 187 | 188 | N = f.size(-1) 189 | f_div_C = f / N 190 | 191 | y = torch.matmul(f_div_C, g_x) 192 | y = y.permute(0, 2, 1).contiguous() 193 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 194 | W_y = self.W(y) 195 | z = W_y + x 196 | 197 | return z 198 | 199 | 200 | class NONLocalBlock1D(_NonLocalBlockND): 201 | def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample=True, bn_layer=True): 202 | super(NONLocalBlock1D, self).__init__(in_channels, 203 | inter_channels=inter_channels, 204 | dimension=1, mode=mode, 205 | sub_sample=sub_sample, 206 | bn_layer=bn_layer) 207 | 208 | 209 | class NONLocalBlock2D(_NonLocalBlockND): 210 | def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample=True, bn_layer=True): 211 | super(NONLocalBlock2D, self).__init__(in_channels, 212 | inter_channels=inter_channels, 213 | dimension=2, mode=mode, 214 | sub_sample=sub_sample, 215 | bn_layer=bn_layer) 216 | 217 | 218 | class NONLocalBlock3D(_NonLocalBlockND): 219 | def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample=True, 
bn_layer=True): 220 | super(NONLocalBlock3D, self).__init__(in_channels, 221 | inter_channels=inter_channels, 222 | dimension=3, mode=mode, 223 | sub_sample=sub_sample, 224 | bn_layer=bn_layer) 225 | 226 | 227 | if __name__ == '__main__': 228 | from torch.autograd import Variable 229 | 230 | mode_list = ['concatenation', 'embedded_gaussian', 'gaussian', 'dot_product', ] 231 | # mode_list = ['concatenation'] 232 | 233 | for mode in mode_list: 234 | print(mode) 235 | img = Variable(torch.zeros(2, 4, 5)) 236 | net = NONLocalBlock1D(4, mode=mode, sub_sample=True) 237 | out = net(img) 238 | print(out.size()) 239 | 240 | img = Variable(torch.zeros(2, 4, 10, 10)) 241 | net = NONLocalBlock2D(4, mode=mode, sub_sample=False, bn_layer=False) 242 | out = net(img) 243 | print(out.size()) 244 | 245 | img = Variable(torch.zeros(2, 4, 5, 4, 5)) 246 | net = NONLocalBlock3D(4, mode=mode) 247 | out = net(img) 248 | print(out.size()) 249 | 250 | -------------------------------------------------------------------------------- /models/non_local/non_local_simple_version.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | super(_NonLocalBlockND, self).__init__() 9 | 10 | assert dimension in [1, 2, 3] 11 | 12 | self.dimension = dimension 13 | self.sub_sample = sub_sample 14 | 15 | self.in_channels = in_channels 16 | self.inter_channels = inter_channels 17 | 18 | if self.inter_channels is None: 19 | self.inter_channels = in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool = nn.MaxPool3d 26 | bn = nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = nn.Conv2d 29 | max_pool = nn.MaxPool2d 30 | bn = nn.BatchNorm2d 31 | else: 32 | 
conv_nd = nn.Conv1d 33 | max_pool = nn.MaxPool1d 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant_(self.W[1].weight, 0) 46 | nn.init.constant_(self.W[1].bias, 0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant_(self.W.weight, 0) 51 | nn.init.constant_(self.W.bias, 0) 52 | 53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 54 | kernel_size=1, stride=1, padding=0) 55 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 56 | kernel_size=1, stride=1, padding=0) 57 | 58 | if sub_sample: 59 | 60 | self.g = nn.Sequential(self.g, nn.AvgPool2d(kernel_size=4,stride=4)) 61 | self.phi = nn.Sequential(self.phi, nn.AvgPool2d(kernel_size=4,stride=4)) 62 | 63 | # self.g = nn.Sequential(self.g, torch.nn.UpsamplingBilinear2d(kernel_size=2)) 64 | # self.phi = nn.Sequential(self.phi, torch.nn.UpsamplingBilinear2d(kernel_size=2)) 65 | 66 | def forward(self, x): 67 | ''' 68 | :param x: (b, c, t, h, w) 69 | :return: 70 | ''' 71 | 72 | batch_size = x.size(0) 73 | 74 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 75 | g_x = g_x.permute(0, 2, 1) 76 | 77 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 78 | theta_x = theta_x.permute(0, 2, 1) 79 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 80 | f = torch.matmul(theta_x, phi_x) 81 | f_div_C = F.softmax(f, dim=-1) 82 | 83 | y = torch.matmul(f_div_C, g_x) 84 | y = y.permute(0, 2, 1).contiguous() 85 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 86 | W_y = self.W(y) 87 | z = 
W_y + x 88 | 89 | return z 90 | 91 | 92 | class NONLocalBlock1D(_NonLocalBlockND): 93 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 94 | super(NONLocalBlock1D, self).__init__(in_channels, 95 | inter_channels=inter_channels, 96 | dimension=1, sub_sample=sub_sample, 97 | bn_layer=bn_layer) 98 | 99 | 100 | class NONLocalBlock2D(_NonLocalBlockND): 101 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 102 | super(NONLocalBlock2D, self).__init__(in_channels, 103 | inter_channels=inter_channels, 104 | dimension=2, sub_sample=sub_sample, 105 | bn_layer=bn_layer) 106 | 107 | 108 | class NONLocalBlock3D(_NonLocalBlockND): 109 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 110 | super(NONLocalBlock3D, self).__init__(in_channels, 111 | inter_channels=inter_channels, 112 | dimension=3, sub_sample=sub_sample, 113 | bn_layer=bn_layer) 114 | 115 | 116 | if __name__ == '__main__': 117 | from torch.autograd import Variable 118 | import torch 119 | sub_sample = False 120 | 121 | img = Variable(torch.zeros(2, 4, 5)) 122 | net = NONLocalBlock1D(4, sub_sample=sub_sample, bn_layer=False) 123 | out = net(img) 124 | print(out.size()) 125 | 126 | img = Variable(torch.zeros(2, 4, 5, 3)) 127 | net = NONLocalBlock2D(4, sub_sample=sub_sample) 128 | out = net(img) 129 | print(out.size()) 130 | 131 | img = Variable(torch.zeros(2, 4, 5, 4, 5)) 132 | net = NONLocalBlock3D(4, sub_sample=sub_sample) 133 | out = net(img) 134 | print(out.size()) 135 | 136 | -------------------------------------------------------------------------------- /models/non_local/utils.py: -------------------------------------------------------------------------------- 1 | import config as cfg 2 | import os 3 | 4 | 5 | def create_architecture(): 6 | if not os.path.exists(cfg.model_dir): 7 | os.mkdir(cfg.model_dir) 8 | 9 | -------------------------------------------------------------------------------- 
/models/rcan_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .dsr_model import DSRModel, base_SRModel 3 | from .networks import AdaResBlock, AdaRCAGroup 4 | 5 | class RCANModel(DSRModel): 6 | @staticmethod 7 | def modify_commandline_options(parser, is_train=True): 8 | parser.set_defaults( 9 | n_groups = 10, 10 | n_resblocks = 20, 11 | n_feats = 64, 12 | block_mode = 'CRC', 13 | channel_attention = 'ca', 14 | ) 15 | return parser 16 | 17 | def __init__(self, opt): 18 | super(RCANModel, self).__init__(opt, SRModel=SRModel) 19 | self.model_names = ['RCAN'] 20 | self.optimizer_names = ['RCAN_optimizer_%s' % opt.optimizer] 21 | self.netRCAN = self.netDSR 22 | 23 | 24 | class SRModel(base_SRModel): 25 | def __init__(self, opt): 26 | self.block = AdaRCAGroup 27 | self.n_blocks = opt.n_groups 28 | self.block_name = 'group' 29 | super(SRModel, self).__init__(opt) -------------------------------------------------------------------------------- /models/rdn_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . import networks as N 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import math 7 | import torch.nn.functional as F 8 | from . 
import losses as L 9 | 10 | class RDB_Conv(nn.Module): 11 | def __init__(self, inChannels, growRate, sparse_conv): 12 | super(RDB_Conv, self).__init__() 13 | mode = 'CR' 14 | if sparse_conv: 15 | mode = mode.replace('C', 'c') 16 | self.conv = N.conv(inChannels, growRate, 3, 1, 1, mode=mode) 17 | 18 | def forward(self, x): 19 | out = self.conv(x) 20 | return torch.cat((x, out), 1) 21 | 22 | 23 | class RDB(nn.Module): 24 | def __init__(self, growRate0, growRate, nConvLayers, sparse_conv): 25 | super(RDB, self).__init__() 26 | G0 = growRate0 # 64 27 | G = growRate # 64 28 | C = nConvLayers # 8 29 | 30 | convs = [] 31 | for c in range(C): 32 | convs.append(RDB_Conv(G0 + c * G, G, sparse_conv)) 33 | self.convs = N.seq(convs) 34 | 35 | # Local Feature Fusion 36 | self.LFF = N.conv(G0 + C * G, G0, 1, stride=1, padding=0, mode='C') 37 | 38 | def forward(self, x): 39 | return self.LFF(self.convs(x)) + x 40 | 41 | 42 | class RDN(nn.Module): 43 | def __init__(self, opt): 44 | super(RDN, self).__init__() 45 | input_nc = 3 46 | r = opt.scale 47 | sparse_conv = opt.sparse_conv 48 | G0 = 64 49 | 50 | # number of RDB blocks, conv layers, out channels 51 | RDNconfig ='B' 52 | self.D, C, G = { 53 | 'A': (20, 6, 32), 54 | 'B': (16, 8, 64), 55 | }[RDNconfig] 56 | 57 | self.sub_mean = N.MeanShift_rdn() 58 | # Shallow feature extraction net 59 | self.SFENet1 = N.conv(input_nc, G0, 3, 1, 1, mode='C') 60 | self.SFENet2 = N.conv(G0, G0, 3, 1, 1, mode='C') 61 | 62 | # Redidual dense blocks and dense feature fusion 63 | for i in range(self.D): 64 | setattr(self, 'RDB%d'%i, RDB(G0, G, C, sparse_conv)) 65 | 66 | # Global Feature Fusion 67 | self.GFF = N.seq( 68 | N.conv(self.D * G0, G0, 1, stride=1, padding=0, mode='C'), 69 | N.conv(G0, G0, 3, 1, 1, mode='C') 70 | ) 71 | 72 | # Up-sampling net 73 | UPNet = [] 74 | if r == 2 or r == 3: 75 | UPNet.append(N.upsample_pixelshuffle(G0, G, 3, 1, 1, mode=str(r))) 76 | elif r == 4: 77 | UPNet.append(N.upsample_pixelshuffle(G0, G, 3, 1, 1, mode='2')) 78 
| UPNet.append(N.upsample_pixelshuffle(G, G, 3, 1, 1, mode='2')) 79 | else: 80 | raise ValueError("scale must be 2 or 3 or 4.") 81 | UPNet.append(N.conv(G, input_nc, 3, 1, 1, mode='C')) 82 | self.UPNet = N.seq(UPNet) 83 | 84 | self.add_mean = N.MeanShift_rdn(sign=1) 85 | 86 | self.isTrain = opt.isTrain 87 | self.loss = opt.loss 88 | if self.isTrain: 89 | setattr(self, 'criterion%s'%self.loss, 90 | getattr(L, '%sLoss'%self.loss)()) 91 | 92 | def forward(self, x, hr=None, depth=None): 93 | x = self.sub_mean(x) 94 | f__1 = self.SFENet1(x) 95 | x = self.SFENet2(f__1) 96 | 97 | RDBs_out = [] 98 | for i in range(self.D): 99 | x = getattr(self, 'RDB%d'%i)(x) 100 | RDBs_out.append(x) 101 | 102 | x = self.GFF(torch.cat(RDBs_out, 1)) 103 | x += f__1 104 | 105 | x = self.UPNet(x) 106 | x = self.add_mean(x) 107 | if self.isTrain: 108 | criterion1 = getattr(self, 'criterion%s'%self.loss) 109 | loss1 = criterion1(x, hr) 110 | return x, None, loss1, loss1 111 | return x, None 112 | 113 | 114 | 115 | 116 | class RdnModel(BaseModel): 117 | @staticmethod 118 | def modify_commandline_options(parser, is_train=True): 119 | return parser 120 | 121 | def __init__(self, opt, SRModel=RDN): 122 | super(RdnModel, self).__init__(opt) 123 | 124 | self.opt = opt 125 | self.loss_names = [opt.loss, 'Total'] 126 | self.visual_names = ['data_lr', 'data_hr', 'data_sr'] 127 | self.model_names = ['DSR'] 128 | self.optimizer_names = ['DSR_optimizer_%s' % opt.optimizer] 129 | 130 | DSR = RDN(opt) 131 | self.netDSR = N.init_net(DSR, opt.init_type, opt.init_gain, opt.gpu_ids) 132 | 133 | if self.isTrain: 134 | if opt.optimizer == 'Adam': 135 | self.optimizer = optim.Adam(self.netDSR.parameters(), 136 | lr=opt.lr, 137 | betas=(opt.beta1, opt.beta2), 138 | weight_decay=opt.weight_decay) 139 | elif opt.optimizer == 'SGD': 140 | self.optimizer = optim.SGD(self.netDSR.parameters(), 141 | lr=opt.lr, 142 | momentum=opt.momentum, 143 | weight_decay=opt.weight_decay) 144 | elif opt.optimizer == 'RMSprop': 145 | 
self.optimizer = optim.RMSprop(self.netDSR.parameters(), 146 | lr=opt.lr, 147 | alpha=opt.alpha, 148 | momentum=opt.momentum, 149 | weight_decay=opt.weight_decay) 150 | else: 151 | raise NotImplementedError( 152 | 'optimizer named [%s] is not supported' % opt.optimizer) 153 | 154 | self.optimizers = [self.optimizer] 155 | 156 | def set_input(self, input): 157 | self.data_lr = input['lr'].to(self.device) 158 | self.data_hr = input['hr'].to(self.device) 159 | self.image_paths = input['fname'] 160 | 161 | def forward(self, FLOPs_only=False): 162 | if self.isTrain: 163 | self.data_sr, self.pred, *losses = \ 164 | self.netDSR(self.data_lr, self.data_hr) 165 | for i, loss_name in enumerate(self.loss_names): 166 | setattr(self, 'loss_'+loss_name, losses[i].mean()) 167 | else: 168 | if not self.opt.chop: 169 | self.data_sr, self.pred = self.netDSR(self.data_lr) 170 | else: 171 | self.data_sr, self.pred = N.forward_chop( 172 | self.opt, self.netDSR, self.data_lr, 173 | None, shave=10, min_size=160000) 174 | 175 | def backward(self): 176 | self.loss_L1.backward() 177 | 178 | def optimize_parameters(self): 179 | self.forward() 180 | self.optimizer.zero_grad() 181 | self.backward() 182 | self.optimizer.step() -------------------------------------------------------------------------------- /models/san_model.py: -------------------------------------------------------------------------------- 1 | # Modified from the authors' version with minor changes 2 | import torch 3 | from .base_model import BaseModel 4 | from . import networks as N 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import math 8 | import torch.nn.functional as F 9 | from . import losses as L 10 | from . 
import common 11 | from .MPNCOV.python import MPNCOV 12 | from masked_conv2d import MaskedConv2d 13 | 14 | ## non_local module 15 | class _NonLocalBlockND(nn.Module): 16 | def __init__(self, in_channels, inter_channels=None, dimension=3, 17 | mode='embedded_gaussian', sub_sample=True, bn_layer=True): 18 | super(_NonLocalBlockND, self).__init__() 19 | assert dimension in [1, 2, 3] 20 | assert mode in ['embedded_gaussian', 'gaussian', 21 | 'dot_product', 'concatenation'] 22 | 23 | self.mode = mode 24 | self.dimension = dimension 25 | self.sub_sample = sub_sample 26 | 27 | self.in_channels = in_channels 28 | self.inter_channels = inter_channels 29 | 30 | if self.inter_channels is None: 31 | self.inter_channels = in_channels // 2 32 | if self.inter_channels == 0: 33 | self.inter_channels = 1 34 | 35 | if dimension == 3: 36 | conv_nd = nn.Conv3d 37 | max_pool = nn.MaxPool3d 38 | bn = nn.BatchNorm3d 39 | elif dimension == 2: 40 | conv_nd = nn.Conv2d 41 | max_pool = nn.MaxPool2d 42 | sub_sample = nn.Upsample 43 | bn = nn.BatchNorm2d 44 | else: 45 | conv_nd = nn.Conv1d 46 | max_pool = nn.MaxPool1d 47 | bn = nn.BatchNorm1d 48 | 49 | self.g = conv_nd(in_channels=self.in_channels, 50 | out_channels=self.inter_channels, 51 | kernel_size=1, stride=1, padding=0) 52 | 53 | if bn_layer: 54 | self.W = nn.Sequential( 55 | conv_nd(in_channels=self.inter_channels, 56 | out_channels=self.in_channels, 57 | kernel_size=1, stride=1, padding=0), 58 | bn(self.in_channels) 59 | ) 60 | nn.init.constant_(self.W[1].weight, 0) 61 | nn.init.constant_(self.W[1].bias, 0) 62 | else: 63 | self.W = conv_nd(in_channels=self.inter_channels, 64 | out_channels=self.in_channels, 65 | kernel_size=1, stride=1, padding=0) 66 | nn.init.constant_(self.W.weight, 0) 67 | nn.init.constant_(self.W.bias, 0) 68 | 69 | self.theta = None 70 | self.phi = None 71 | self.concat_project = None 72 | 73 | if mode in ['embedded_gaussian', 'dot_product', 'concatenation']: 74 | self.theta = conv_nd(in_channels=self.in_channels, 
75 | out_channels=self.inter_channels, 76 | kernel_size=1, stride=1, padding=0) 77 | self.phi = conv_nd(in_channels=self.in_channels, 78 | out_channels=self.inter_channels, 79 | kernel_size=1, stride=1, padding=0) 80 | 81 | if mode == 'embedded_gaussian': 82 | self.operation_function = self._embedded_gaussian 83 | elif mode == 'dot_product': 84 | self.operation_function = self._dot_product 85 | elif mode == 'concatenation': 86 | self.operation_function = self._concatenation 87 | self.concat_project = nn.Sequential( 88 | nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False), 89 | nn.ReLU() 90 | ) 91 | elif mode == 'gaussian': 92 | self.operation_function = self._gaussian 93 | 94 | if sub_sample: 95 | self.g = nn.Sequential(self.g, max_pool(kernel_size=2)) 96 | if self.phi is None: 97 | self.phi = max_pool(kernel_size=2) 98 | else: 99 | self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2)) 100 | 101 | def forward(self, x): 102 | ''' 103 | :param x: (b, c, t, h, w) 104 | :return: 105 | ''' 106 | 107 | output = self.operation_function(x) 108 | return output 109 | 110 | def _embedded_gaussian(self, x): 111 | batch_size,C,H,W = x.shape 112 | 113 | ## 114 | # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c) 115 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 116 | g_x = g_x.permute(0, 2, 1) 117 | 118 | # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c) 119 | # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) 120 | # f=>(b, thw, 0.5c)dot(b, 0.5c, twh) = (b, thw, thw) 121 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 122 | theta_x = theta_x.permute(0, 2, 1) 123 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 124 | f = torch.matmul(theta_x, phi_x) 125 | # return f 126 | f_div_C = F.softmax(f, dim=-1) 127 | # return f_div_C 128 | # (b, thw, thw)dot(b, thw, 0.5c) = 129 | # (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w) 130 | y = torch.matmul(f_div_C, g_x) 131 | y = y.permute(0, 2, 
1).contiguous() 132 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 133 | W_y = self.W(y) 134 | z = W_y + x 135 | 136 | return z 137 | 138 | def _gaussian(self, x): 139 | batch_size = x.size(0) 140 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 141 | g_x = g_x.permute(0, 2, 1) 142 | 143 | theta_x = x.view(batch_size, self.in_channels, -1) 144 | theta_x = theta_x.permute(0, 2, 1) 145 | 146 | if self.sub_sample: 147 | phi_x = self.phi(x).view(batch_size, self.in_channels, -1) 148 | else: 149 | phi_x = x.view(batch_size, self.in_channels, -1) 150 | 151 | f = torch.matmul(theta_x, phi_x) 152 | f_div_C = F.softmax(f, dim=-1) 153 | 154 | y = torch.matmul(f_div_C, g_x) 155 | y = y.permute(0, 2, 1).contiguous() 156 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 157 | W_y = self.W(y) 158 | z = W_y + x 159 | 160 | return z 161 | 162 | def _dot_product(self, x): 163 | batch_size = x.size(0) 164 | 165 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 166 | g_x = g_x.permute(0, 2, 1) 167 | 168 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 169 | theta_x = theta_x.permute(0, 2, 1) 170 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 171 | f = torch.matmul(theta_x, phi_x) 172 | N = f.size(-1) 173 | f_div_C = f / N 174 | 175 | y = torch.matmul(f_div_C, g_x) 176 | y = y.permute(0, 2, 1).contiguous() 177 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 178 | W_y = self.W(y) 179 | z = W_y + x 180 | 181 | return z 182 | 183 | def _concatenation(self, x): 184 | batch_size = x.size(0) 185 | 186 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 187 | g_x = g_x.permute(0, 2, 1) 188 | 189 | # (b, c, N, 1) 190 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1) 191 | # (b, c, 1, N) 192 | phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1) 193 | 194 | h = theta_x.size(2) 195 | w = phi_x.size(3) 196 | theta_x = theta_x.repeat(1, 1, 1, w) 197 | phi_x = 
phi_x.repeat(1, 1, h, 1) 198 | 199 | concat_feature = torch.cat([theta_x, phi_x], dim=1) 200 | f = self.concat_project(concat_feature) 201 | b, _, h, w = f.size() 202 | f = f.view(b, h, w) 203 | 204 | N = f.size(-1) 205 | f_div_C = f / N 206 | 207 | y = torch.matmul(f_div_C, g_x) 208 | y = y.permute(0, 2, 1).contiguous() 209 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 210 | W_y = self.W(y) 211 | z = W_y + x 212 | 213 | return z 214 | 215 | 216 | class NONLocalBlock1D(_NonLocalBlockND): 217 | def __init__(self, in_channels, inter_channels=None, 218 | mode='embedded_gaussian', sub_sample=True, bn_layer=True): 219 | super(NONLocalBlock1D, self).__init__(in_channels, 220 | inter_channels=inter_channels, 221 | dimension=1, mode=mode, 222 | sub_sample=sub_sample, 223 | bn_layer=bn_layer) 224 | 225 | 226 | class NONLocalBlock2D(_NonLocalBlockND): 227 | def __init__(self, in_channels, inter_channels=None, 228 | mode='embedded_gaussian', sub_sample=True, bn_layer=True): 229 | super(NONLocalBlock2D, self).__init__(in_channels, 230 | inter_channels=inter_channels, 231 | dimension=2, mode=mode, 232 | sub_sample=sub_sample, 233 | bn_layer=bn_layer) 234 | 235 | 236 | ## Channel Attention (CA) Layer 237 | class CALayer(nn.Module): 238 | def __init__(self, channel, reduction=8): 239 | super(CALayer, self).__init__() 240 | # global average pooling: feature --> point 241 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 242 | self.max_pool = nn.AdaptiveMaxPool2d(1) 243 | # feature channel downscale and upscale --> channel weight 244 | self.conv_du = nn.Sequential( 245 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 246 | nn.ReLU(inplace=True), 247 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 248 | ) 249 | 250 | def forward(self, x): 251 | _,_,h,w = x.shape 252 | y_ave = self.avg_pool(x) 253 | y_ave = self.conv_du(y_ave) 254 | return y_ave 255 | 256 | 257 | 258 | ## second-order Channel attention (SOCA) 259 | class 
SOCA(nn.Module): 260 | def __init__(self, channel, reduction=8): 261 | super(SOCA, self).__init__() 262 | # global average pooling: feature --> point 263 | self.max_pool = nn.MaxPool2d(kernel_size=2) 264 | 265 | # feature channel downscale and upscale --> channel weight 266 | self.conv_du = nn.Sequential( 267 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 268 | nn.ReLU(inplace=True), 269 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 270 | nn.Sigmoid() 271 | ) 272 | 273 | def forward(self, x): 274 | batch_size, C, h, w = x.shape # x: NxCxHxW 275 | N = int(h * w) 276 | min_h = min(h, w) 277 | h1 = 1000 278 | w1 = 1000 279 | if h < h1 and w < w1: 280 | x_sub = x 281 | elif h < h1 and w > w1: 282 | # H = (h - h1) // 2 283 | W = (w - w1) // 2 284 | x_sub = x[:, :, :, W:(W + w1)] 285 | elif w < w1 and h > h1: 286 | H = (h - h1) // 2 287 | # W = (w - w1) // 2 288 | x_sub = x[:, :, H:H + h1, :] 289 | else: 290 | H = (h - h1) // 2 291 | W = (w - w1) // 2 292 | x_sub = x[:, :, H:(H + h1), W:(W + w1)] 293 | 294 | ## 295 | ## MPN-COV 296 | cov_mat = MPNCOV.CovpoolLayer(x_sub) # Global Covariance pooling layer 297 | # Matrix square root layer( including pre-norm,Newton-Schulz iter. and 298 | # post-com. 
with 5 iteration) 299 | cov_mat_sqrt = MPNCOV.SqrtmLayer(cov_mat,5) 300 | ## 301 | cov_mat_sum = torch.mean(cov_mat_sqrt,1) 302 | cov_mat_sum = cov_mat_sum.view(batch_size,C,1,1) 303 | 304 | y_cov = self.conv_du(cov_mat_sum) 305 | 306 | return y_cov*x 307 | 308 | 309 | 310 | ## self-attention+ channel attention module 311 | class Nonlocal_CA(nn.Module): 312 | def __init__(self, in_feat=64, inter_feat=32, reduction=8, 313 | sub_sample=False, bn_layer=True): 314 | super(Nonlocal_CA, self).__init__() 315 | # second-order channel attention 316 | self.soca=SOCA(in_feat, reduction=reduction) 317 | # nonlocal module 318 | self.non_local = NONLocalBlock2D( 319 | in_channels=in_feat, inter_channels=inter_feat, 320 | sub_sample=sub_sample, bn_layer=bn_layer) 321 | 322 | self.sigmoid = nn.Sigmoid() 323 | def forward(self,x): 324 | ## divide feature map into 4 part 325 | batch_size,C,H,W = x.shape 326 | H1 = int(H / 2) 327 | W1 = int(W / 2) 328 | nonlocal_feat = torch.zeros_like(x) 329 | 330 | feat_sub_lu = x[:, :, :H1, :W1] 331 | feat_sub_ld = x[:, :, H1:, :W1] 332 | feat_sub_ru = x[:, :, :H1, W1:] 333 | feat_sub_rd = x[:, :, H1:, W1:] 334 | 335 | 336 | nonlocal_lu = self.non_local(feat_sub_lu) 337 | nonlocal_ld = self.non_local(feat_sub_ld) 338 | nonlocal_ru = self.non_local(feat_sub_ru) 339 | nonlocal_rd = self.non_local(feat_sub_rd) 340 | nonlocal_feat[:, :, :H1, :W1] = nonlocal_lu 341 | nonlocal_feat[:, :, H1:, :W1] = nonlocal_ld 342 | nonlocal_feat[:, :, :H1, W1:] = nonlocal_ru 343 | nonlocal_feat[:, :, H1:, W1:] = nonlocal_rd 344 | 345 | return nonlocal_feat 346 | 347 | 348 | ## Residual Block (RB) 349 | class RB(nn.Module): 350 | def __init__(self, conv, n_feat, kernel_size, reduction, bias=True, 351 | bn=False, act=nn.ReLU(inplace=True), res_scale=1, dilation=2): 352 | super(RB, self).__init__() 353 | modules_body = [] 354 | 355 | self.gamma1 = 1.0 356 | 357 | self.conv_first = nn.Sequential( 358 | conv(n_feat, n_feat, kernel_size, padding=1, bias=bias), 359 | act, 
360 | conv(n_feat, n_feat, kernel_size, padding=1, bias=bias) 361 | ) 362 | 363 | 364 | self.res_scale = res_scale 365 | 366 | def forward(self, x): 367 | y = self.conv_first(x) 368 | y = y + x 369 | 370 | return y 371 | 372 | ## Local-source Residual Attention Group (LSRARG) 373 | class LSRAG(nn.Module): 374 | def __init__(self, conv, n_feat, kernel_size, reduction, 375 | act, res_scale, n_resblocks): 376 | super(LSRAG, self).__init__() 377 | ## 378 | self.rcab= nn.ModuleList([RB( 379 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, 380 | act=nn.ReLU(True), res_scale=1) for _ in range(n_resblocks)]) 381 | self.soca = (SOCA(n_feat,reduction=reduction)) 382 | self.conv_last = (conv(n_feat, n_feat, kernel_size, padding=1)) 383 | self.n_resblocks = n_resblocks 384 | ## 385 | self.gamma = nn.Parameter(torch.zeros(1)) 386 | 387 | def make_layer(self, block, num_of_layer): 388 | layers = [] 389 | for _ in range(num_of_layer): 390 | layers.append(block) 391 | return nn.ModuleList(layers) 392 | 393 | def forward(self, x): 394 | residual = x 395 | 396 | ## share-source skip connection 397 | 398 | for i,l in enumerate(self.rcab): 399 | x = l(x) 400 | x = self.soca(x) 401 | x = self.conv_last(x) 402 | 403 | x = x + residual 404 | 405 | return x 406 | ## 407 | 408 | 409 | ## Second-order Channel Attention Network (SAN) 410 | class SAN(nn.Module): 411 | def __init__(self, opt): 412 | super(SAN, self).__init__() 413 | n_resgroups = 20 414 | n_resblocks = 10 415 | n_feats = 64 416 | kernel_size = 3 417 | reduction = 16 418 | scale = opt.scale 419 | act = nn.ReLU(inplace=True) 420 | if opt.sparse_conv: 421 | conv = MaskedConv2d 422 | else: 423 | conv = common.default_conv 424 | 425 | # RGB mean for DIV2K 426 | rgb_mean = (0.4488, 0.4371, 0.4040) 427 | rgb_std = (1.0, 1.0, 1.0) 428 | self.sub_mean = common.MeanShift(255, rgb_mean, rgb_std) 429 | 430 | # define head module 431 | modules_head = [common.default_conv(3, n_feats, kernel_size)] 432 | 433 | # define body 
module 434 | ## share-source skip connection 435 | 436 | ## 437 | self.gamma = nn.Parameter(torch.zeros(1)) 438 | # self.gamma = 0.2 439 | self.n_resgroups = n_resgroups 440 | self.RG = nn.ModuleList([LSRAG( 441 | conv, n_feats, kernel_size, reduction, act=act, res_scale=1, 442 | n_resblocks=n_resblocks) for _ in range(n_resgroups)]) 443 | self.conv_last = conv(n_feats, n_feats, kernel_size, padding=1) 444 | 445 | # define tail module 446 | modules_tail = [ 447 | common.Upsampler(common.default_conv, scale, n_feats, act=False), 448 | common.default_conv(n_feats, 3, kernel_size)] 449 | 450 | self.add_mean = common.MeanShift(255, rgb_mean, rgb_std, 1) 451 | self.non_local = Nonlocal_CA( 452 | in_feat=n_feats, inter_feat=n_feats//8, 453 | reduction=8,sub_sample=False, bn_layer=False) 454 | 455 | 456 | self.head = nn.Sequential(*modules_head) 457 | self.tail = nn.Sequential(*modules_tail) 458 | 459 | self.isTrain = opt.isTrain 460 | self.loss = opt.loss 461 | if self.isTrain: 462 | setattr(self, 'criterion%s'%self.loss, 463 | getattr(L, '%sLoss'%self.loss)()) 464 | 465 | 466 | def make_layer(self, block, num_of_layer): 467 | layers = [] 468 | for _ in range(num_of_layer): 469 | layers.append(block) 470 | 471 | return nn.ModuleList(layers) 472 | 473 | def forward(self, x, hr=None, depth=None): 474 | x = self.sub_mean(x) 475 | x = self.head(x) 476 | 477 | ## add nonlocal 478 | xx = self.non_local(x) 479 | 480 | # share-source skip connection 481 | residual = xx 482 | 483 | ## share-source residual gruop 484 | for i,l in enumerate(self.RG): 485 | xx = l(xx) + self.gamma*residual 486 | # xx = self.gamma*xx + residual 487 | # body part 488 | ## 489 | ## add nonlocal 490 | res = self.non_local(xx) 491 | ## 492 | res = res + x 493 | 494 | x = self.tail(res) 495 | x = self.add_mean(x) 496 | 497 | if self.isTrain: 498 | criterion1 = getattr(self, 'criterion%s'%self.loss) 499 | loss1 = criterion1(x, hr) 500 | return x, None, loss1, loss1 501 | return x, None 502 | 503 | def 
load_state_dict(self, state_dict, strict=False): 504 | own_state = self.state_dict() 505 | for name, param in state_dict.items(): 506 | if name in own_state: 507 | if isinstance(param, nn.Parameter): 508 | param = param.data 509 | try: 510 | own_state[name].copy_(param) 511 | except Exception: 512 | if name.find('tail') >= 0: 513 | print('Replace pre-trained upsampler to new one...') 514 | else: 515 | raise RuntimeError( 516 | 'While copying the parameter named {}, ' 517 | 'whose dimensions in the model are {} and ' 518 | 'whose dimensions in the checkpoint are {}.' 519 | .format(name, own_state[name].size(), param.size())) 520 | elif strict: 521 | if name.find('tail') == -1: 522 | raise KeyError('unexpected key "{}" in state_dict' 523 | .format(name)) 524 | 525 | if strict: 526 | missing = set(own_state.keys()) - set(state_dict.keys()) 527 | if len(missing) > 0: 528 | raise KeyError( 529 | 'missing keys in state_dict: "{}"'.format(missing)) 530 | 531 | class SANModel(BaseModel): 532 | @staticmethod 533 | def modify_commandline_options(parser, is_train=True): 534 | return parser 535 | 536 | def __init__(self, opt, SRModel=None): 537 | super(SANModel, self).__init__(opt) 538 | 539 | self.opt = opt 540 | self.loss_names = [opt.loss, 'Total'] 541 | self.visual_names = ['data_lr', 'data_hr', 'data_sr'] 542 | self.model_names = ['DSR'] 543 | self.optimizer_names = ['DSR_optimizer_%s' % opt.optimizer] 544 | 545 | DSR = SAN(opt) 546 | self.netDSR = N.init_net(DSR, opt.init_type, opt.init_gain, opt.gpu_ids) 547 | 548 | if self.isTrain: 549 | if opt.optimizer == 'Adam': 550 | self.optimizer = optim.Adam(self.netDSR.parameters(), 551 | lr=opt.lr, 552 | betas=(opt.beta1, opt.beta2), 553 | weight_decay=opt.weight_decay) 554 | elif opt.optimizer == 'SGD': 555 | self.optimizer = optim.SGD(self.netDSR.parameters(), 556 | lr=opt.lr, 557 | momentum=opt.momentum, 558 | weight_decay=opt.weight_decay) 559 | elif opt.optimizer == 'RMSprop': 560 | self.optimizer = 
optim.RMSprop(self.netDSR.parameters(), 561 | lr=opt.lr, 562 | alpha=opt.alpha, 563 | momentum=opt.momentum, 564 | weight_decay=opt.weight_decay) 565 | else: 566 | raise NotImplementedError( 567 | 'optimizer named [%s] is not supported' % opt.optimizer) 568 | 569 | self.optimizers = [self.optimizer] 570 | 571 | def set_input(self, input): 572 | self.data_lr = input['lr'].to(self.device) 573 | self.data_hr = input['hr'].to(self.device) 574 | self.image_paths = input['fname'] 575 | 576 | def forward(self, FLOPs_only=False): 577 | if self.isTrain: 578 | self.data_sr, self.pred, *losses = \ 579 | self.netDSR(self.data_lr, self.data_hr) 580 | for i, loss_name in enumerate(self.loss_names): 581 | setattr(self, 'loss_'+loss_name, losses[i].mean()) 582 | else: 583 | if not self.opt.chop: 584 | self.data_sr, self.pred = self.netDSR(self.data_lr) 585 | else: 586 | self.data_sr, self.pred = N.forward_chop( 587 | self.opt, self.netDSR, self.data_lr, 588 | None, shave=10, min_size=160000) 589 | def backward(self): 590 | self.loss_Total.backward() 591 | 592 | def optimize_parameters(self): 593 | self.forward() 594 | self.optimizer.zero_grad() 595 | self.backward() 596 | self.optimizer.step() 597 | -------------------------------------------------------------------------------- /models/srcnn_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . import networks as N 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import math 7 | import torch.nn.functional as F 8 | from . 
import losses as L 9 | 10 | class SRCNN(nn.Module): 11 | def __init__(self, opt): 12 | super(SRCNN, self).__init__() 13 | self.conv1 = N.conv(1, 64, 9, padding=4, mode='CR') 14 | self.conv2 = N.conv(64, 32, 5, padding=2, mode='CR') 15 | self.conv3 = N.conv(32, 1, 5, padding=2, mode='C') 16 | 17 | self.isTrain = opt.isTrain 18 | self.loss = opt.loss 19 | if self.isTrain: 20 | setattr(self, 'criterion%s'%self.loss, 21 | getattr(L, '%sLoss'%self.loss)()) 22 | 23 | def forward(self, x, hr=None, depth=None): 24 | x = self.conv1(x) 25 | x = self.conv2(x) 26 | x = self.conv3(x) 27 | 28 | if self.isTrain: 29 | criterion1 = getattr(self, 'criterion%s'%self.loss) 30 | loss1 = criterion1(x, hr) 31 | return x, None, loss1, loss1 32 | return x, None 33 | 34 | class SRCNNModel(BaseModel): 35 | @staticmethod 36 | def modify_commandline_options(parser, is_train=True): 37 | parser.set_defaults( 38 | lr_mode = 'sr', 39 | mode = 'Y', 40 | ) 41 | return parser 42 | 43 | def __init__(self, opt, SRModel=SRCNN): 44 | super(SRCNNModel, self).__init__(opt) 45 | 46 | self.opt = opt 47 | self.loss_names = [opt.loss, 'Total'] 48 | self.visual_names = ['data_lr', 'data_hr', 'data_sr'] 49 | self.model_names = ['DSR'] 50 | self.optimizer_names = ['DSR_optimizer_%s' % opt.optimizer] 51 | 52 | DSR = SRModel(opt) 53 | self.netDSR = N.init_net(DSR, opt.init_type, opt.init_gain, opt.gpu_ids) 54 | 55 | if self.isTrain: 56 | if opt.optimizer == 'Adam': 57 | self.optimizer = optim.Adam(self.netDSR.parameters(), 58 | lr=opt.lr, 59 | betas=(opt.beta1, opt.beta2), 60 | weight_decay=opt.weight_decay) 61 | elif opt.optimizer == 'SGD': 62 | self.optimizer = optim.SGD(self.netDSR.parameters(), 63 | lr=opt.lr, 64 | momentum=opt.momentum, 65 | weight_decay=opt.weight_decay) 66 | elif opt.optimizer == 'RMSprop': 67 | self.optimizer = optim.RMSprop(self.netDSR.parameters(), 68 | lr=opt.lr, 69 | alpha=opt.alpha, 70 | momentum=opt.momentum, 71 | weight_decay=opt.weight_decay) 72 | else: 73 | raise 
NotImplementedError( 74 | 'optimizer named [%s] is not supported' % opt.optimizer) 75 | 76 | self.optimizers = [self.optimizer] 77 | 78 | def set_input(self, input): 79 | self.data_lr = input['lr'].to(self.device) # save the Cx channels 80 | self.data_hr = input['hr'].to(self.device) 81 | self.data_lr_input = self.data_lr[:, :1, ...] 82 | self.data_hr_input = self.data_hr[:, :1, ...] 83 | self.image_paths = input['fname'] 84 | 85 | def forward(self): 86 | if self.isTrain: 87 | self.data_sr_output, self.pred, *losses = \ 88 | self.netDSR(self.data_lr_input, self.data_hr_input) 89 | for i, loss_name in enumerate(self.loss_names): 90 | setattr(self, 'loss_'+loss_name, losses[i].mean()) 91 | else: 92 | if not self.opt.chop: 93 | self.data_sr_output, self.pred = self.netDSR(self.data_lr_input) 94 | else: 95 | self.data_sr_output, self.pred = N.forward_chop( 96 | self.opt, self.netDSR, self.data_lr_input, 97 | None, shave=10, min_size=160000) 98 | # Y channel is from the network output, while Cb and Cr channels are 99 | # from the tensor super resolved with bicubic algorithm. 100 | self.data_sr = self.data_lr.clone().detach() 101 | self.data_sr[:, :1, ...] 
= self.data_sr_output 102 | 103 | def backward(self): 104 | self.loss_Total.backward() 105 | 106 | def optimize_parameters(self): 107 | self.forward() 108 | self.optimizer.zero_grad() 109 | self.backward() 110 | self.optimizer.step() 111 | 112 | -------------------------------------------------------------------------------- /models/srresnet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .dsr_model import DSRModel, base_SRModel 3 | from .networks import AdaResBlock, AdaRCAGroup 4 | 5 | class SRResNetModel(DSRModel): 6 | @staticmethod 7 | def modify_commandline_options(parser, is_train=True): 8 | parser.set_defaults( 9 | n_resblocks = 16, 10 | n_feats = 64, 11 | block_mode = 'CBRCB' 12 | ) 13 | return parser 14 | 15 | def __init__(self, opt): 16 | super(SRResNetModel, self).__init__(opt, SRModel=SRModel) 17 | self.model_names = ['SRResNet'] 18 | self.optimizer_names = ['SRResNet_optimizer_%s' % opt.optimizer] 19 | self.netSRResNet = self.netDSR 20 | 21 | 22 | class SRModel(base_SRModel): 23 | def __init__(self, opt): 24 | self.block = AdaResBlock 25 | self.n_blocks = opt.n_resblocks 26 | self.block_name = 'block' 27 | super(SRModel, self).__init__(opt) -------------------------------------------------------------------------------- /models/vdsr_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . import networks as N 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import math 7 | import torch.nn.functional as F 8 | from . 
import losses as L 9 | 10 | class Conv_ReLU_Block(nn.Module): 11 | def __init__(self, sparse_conv): 12 | super(Conv_ReLU_Block, self).__init__() 13 | self.sparse_conv = sparse_conv 14 | if sparse_conv: 15 | mode = 'c' 16 | else: 17 | mode = 'C' 18 | self.conv = N.conv(64, 64, 3, 1, 1, bias=False, mode=mode) 19 | self.relu = nn.ReLU(True) 20 | 21 | def forward(self, x): 22 | # mask = torch.ones((1, *x.shape[2:]), device=x.device) 23 | if self.sparse_conv: 24 | return self.relu(self.conv(x))#, mask)) 25 | return self.relu(self.conv(x)) 26 | 27 | 28 | class VDSR(nn.Module): 29 | def __init__(self, opt): 30 | super(VDSR, self).__init__() 31 | sparse_conv = opt.sparse_conv 32 | 33 | layers = [] 34 | for ii in range(18): 35 | layers.append(Conv_ReLU_Block(sparse_conv)) 36 | self.layers = layers 37 | self.residual_layer = N.seq(layers) 38 | 39 | self.input = N.conv(1, 64, 3, 1, 1, bias=False, mode='C') 40 | self.relu = nn.ReLU(True) 41 | self.output = N.conv(64, 1, 3, 1, 1, bias=False, mode='C') 42 | 43 | for m in self.modules(): 44 | if isinstance(m, nn.Conv2d): 45 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 46 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 47 | 48 | self.isTrain = opt.isTrain 49 | self.loss = opt.loss 50 | if self.isTrain: 51 | setattr(self, 'criterion%s'%self.loss, 52 | getattr(L, '%sLoss'%self.loss)()) 53 | 54 | def forward(self, x, hr=None, depth=None): 55 | out = self.relu(self.input(x)) 56 | out = self.residual_layer(out) 57 | out = self.output(out) 58 | out += x 59 | if self.isTrain: 60 | criterion1 = getattr(self, 'criterion%s'%self.loss) 61 | loss1 = criterion1(out, hr) 62 | return out, None, loss1, loss1 63 | return out, None 64 | 65 | 66 | class VDSRModel(BaseModel): 67 | @staticmethod 68 | def modify_commandline_options(parser, is_train=True): 69 | parser.set_defaults( 70 | lr_mode = 'sr', 71 | mode = 'Y', 72 | ) 73 | return parser 74 | 75 | def __init__(self, opt, SRModel=VDSR): 76 | super(VDSRModel, self).__init__(opt) 77 | 78 | self.opt = opt 79 | self.loss_names = [opt.loss, 'Total'] 80 | self.visual_names = ['data_lr', 'data_hr', 'data_sr'] 81 | self.model_names = ['DSR'] 82 | self.optimizer_names = ['DSR_optimizer_%s' % opt.optimizer] 83 | 84 | DSR = SRModel(opt) 85 | self.netDSR = N.init_net(DSR, opt.init_type, opt.init_gain, opt.gpu_ids) 86 | 87 | if self.isTrain: 88 | if opt.optimizer == 'Adam': 89 | self.optimizer = optim.Adam(self.netDSR.parameters(), 90 | lr=opt.lr, 91 | betas=(opt.beta1, opt.beta2), 92 | weight_decay=opt.weight_decay) 93 | elif opt.optimizer == 'SGD': 94 | self.optimizer = optim.SGD(self.netDSR.parameters(), 95 | lr=opt.lr, 96 | momentum=opt.momentum, 97 | weight_decay=opt.weight_decay) 98 | elif opt.optimizer == 'RMSprop': 99 | self.optimizer = optim.RMSprop(self.netDSR.parameters(), 100 | lr=opt.lr, 101 | alpha=opt.alpha, 102 | momentum=opt.momentum, 103 | weight_decay=opt.weight_decay) 104 | else: 105 | raise NotImplementedError( 106 | 'optimizer named [%s] is not supported' % opt.optimizer) 107 | 108 | self.optimizers = [self.optimizer] 109 | 110 | def set_input(self, input): 111 | self.data_lr = input['lr'].to(self.device) # save the Cx channels 112 
| self.data_hr = input['hr'].to(self.device) 113 | self.data_lr_input = self.data_lr[:, :1, ...] 114 | self.data_hr_input = self.data_hr[:, :1, ...] 115 | self.image_paths = input['fname'] 116 | 117 | def forward(self, FLOPs_only=False): 118 | if self.isTrain: 119 | self.data_sr_output, self.pred, *losses = \ 120 | self.netDSR(self.data_lr_input, self.data_hr_input) 121 | for i, loss_name in enumerate(self.loss_names): 122 | setattr(self, 'loss_'+loss_name, losses[i].mean()) 123 | else: 124 | if not self.opt.chop: 125 | self.data_sr_output, self.pred = self.netDSR(self.data_lr_input) 126 | else: 127 | self.data_sr_output, self.pred = N.forward_chop( 128 | self.opt, self.netDSR, self.data_lr_input, 129 | None, shave=10, min_size=160000) 130 | # Y channel is from the network output, while Cb and Cr channels are 131 | # from the tensor super resolved with bicubic algorithm. 132 | self.data_sr = self.data_lr.clone().detach() 133 | self.data_sr[:, :1, ...] = self.data_sr_output 134 | 135 | def backward(self): 136 | self.loss_Total.backward() 137 | 138 | def optimize_parameters(self): 139 | self.forward() 140 | self.optimizer.zero_grad() 141 | self.backward() 142 | self.optimizer.step() 143 | 144 | 145 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package options includes option modules.""" 2 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | from util import util 5 | import torch 6 | import models 7 | import data 8 | import time 9 | 10 | def str2bool(v): 11 | return v.lower() in ('yes', 'y', 'true', 't', '1') 12 | 13 | inf = float('inf') 14 | 15 | class BaseOptions(): 16 | def __init__(self): 17 | """Reset the class; indicates the 
class hasn't been initailized""" 18 | self.initialized = False 19 | 20 | def initialize(self, parser): 21 | """Define the common options that are used in both training and test.""" 22 | # data parameters 23 | parser.add_argument('--dataroot', type=str, default='') 24 | parser.add_argument('--dataset_name', type=str, default=['div2k'], 25 | nargs='+', choices=['div2k', 'set5', 'set14', 26 | 'urban100', 'b100', 'manga109']) 27 | parser.add_argument('--max_dataset_size', type=int, default=inf) 28 | parser.add_argument('--scale', type=int, required=True, 29 | choices=[2, 3, 4], help='Super-resolution scale.') 30 | parser.add_argument('--mode', default='RGB', choices=['RGB', 'L', 'Y'], 31 | help='Currently, only RGB mode is supported.') 32 | parser.add_argument('--imlib', default='cv2', choices=['cv2', 'pillow'], 33 | help='Keep using cv2 unless encountered with problems.') 34 | parser.add_argument('--preload', type=str2bool, default=True, 35 | help='Load all images into memory for efficiency.') 36 | parser.add_argument('--multi_imreader', type=str2bool, default=True, 37 | help='Use multiple cores/threads to load images, will be very ' 38 | 'fast when the images are loaded into cache.') 39 | parser.add_argument('--batch_size', type=int, default=16) 40 | parser.add_argument('--patch_size', type=int, default=None) 41 | parser.add_argument('--lr_mode', default='lr', choices=['lr', 'sr'], 42 | help='lr: take the lr image directly as input. ' 43 | 'sr: upsample the lr image via bicubic at first.') 44 | parser.add_argument('--shuffle', type=str2bool, default=True) 45 | parser.add_argument('-j', '--num_dataloader', default=4, type=int) 46 | parser.add_argument('--drop_last', type=str2bool, default=True) 47 | 48 | # device parameters 49 | parser.add_argument('--gpu_ids', type=str, default='all', 50 | help='Separate the GPU ids by `,`, using all GPUs by default. 
' 51 | 'eg, `--gpu_ids 0`, `--gpu_ids 2,3`, `--gpu_ids -1`(CPU)') 52 | parser.add_argument('--checkpoints_dir', type=str, default='./ckpt') 53 | parser.add_argument('-v', '--verbose', type=str2bool, default=True) 54 | parser.add_argument('--suffix', default='', type=str) 55 | 56 | # model parameters 57 | parser.add_argument('--name', type=str, required=True, 58 | help='Name of the folder to save models and logs.') 59 | parser.add_argument('--model', type=str, required=True) 60 | parser.add_argument('--load_path', type=str, default='', 61 | help='Will load pre-trained model if load_path is set') 62 | parser.add_argument('--load_iter', type=int, default=[0], nargs='+', 63 | help='Load parameters if > 0 and load_path is not set. ' 64 | 'Set the value of `last_epoch`') 65 | parser.add_argument('--n_groups', type=int, default=0) 66 | parser.add_argument('--n_resblocks', type=int, default=16) 67 | parser.add_argument('--n_feats', type=int, default=64) 68 | parser.add_argument('--res_scale', type=float, default=1) 69 | parser.add_argument('--block_mode', type=str, default='CRC') 70 | parser.add_argument('--side_ca', type=str2bool, default=False, 71 | help='If True, put Channel Attention module alongside the ' 72 | 'convolution layers in the residual blocks.') 73 | 74 | # AdaDSR parameters 75 | parser.add_argument('--depth', type=float, nargs='+', default=[0]) 76 | parser.add_argument('--sparse_conv', type=str2bool, default=False,\ 77 | help='Replace convolution layers in main body with sparse conv') 78 | parser.add_argument('--channel_attention', type=str, default='none', 79 | choices=['none', '0', '1', 'ca']) 80 | parser.add_argument('--constrain', type=str, default='none', 81 | choices=['none', 'soft', 'hard'], 82 | help='none: no constrain on adapter output; ' 83 | 'soft: constrain with depth loss; ' 84 | 'hard: rescale the depth map to a desired average.') 85 | parser.add_argument('--chop', type=str2bool, default=False) 86 | # adapter parameters 87 | 
parser.add_argument('--nc_adapter', type=int, default=0, 88 | help='0: no adapter, n: output n channels') 89 | parser.add_argument('--with_depth', type=str2bool, default=True, 90 | help='whether adapter take desired depth as input') 91 | parser.add_argument('--adapter_layers', type=int, default=5) 92 | parser.add_argument('--adapter_reduction', type=int, default=2) 93 | parser.add_argument('--adapter_pos', type=int, default=0) 94 | parser.add_argument('--adapter_bound', type=int, default=None) 95 | parser.add_argument('--multi_adapter', type=str2bool, default=False) 96 | 97 | # training parameters 98 | parser.add_argument('--init_type', type=str, default='default', 99 | choices=['default', 'normal', 'xavier', 100 | 'kaiming', 'orthogonal', 'uniform'], 101 | help='`default` means using PyTorch default init functions.') 102 | parser.add_argument('--init_gain', type=float, default=0.02) 103 | parser.add_argument('--loss', type=str, default='L1', 104 | help='choose from [L1, MSE, SSIM, PSNR]') 105 | parser.add_argument('--optimizer', type=str, default='Adam', 106 | choices=['Adam', 'SGD', 'RMSprop']) 107 | parser.add_argument('--niter', type=int, default=1000) 108 | parser.add_argument('--niter_decay', type=int, default=0) 109 | parser.add_argument('--lr_policy', type=str, default='step') 110 | parser.add_argument('--lr_decay_iters', type=int, default=200) 111 | parser.add_argument('--lr', type=float, default=0.0001) 112 | parser.add_argument('--lambda_pred', type=float, default=0.01) 113 | 114 | # Optimizer 115 | parser.add_argument('--load_optimizers', type=str2bool, default=False, 116 | help='Loading optimizer parameters for continuing training.') 117 | parser.add_argument('--weight_decay', type=float, default=0) 118 | # Adam 119 | parser.add_argument('--beta1', type=float, default=0.9) 120 | parser.add_argument('--beta2', type=float, default=0.999) 121 | # SGD & RMSprop 122 | parser.add_argument('--momentum', type=float, default=0) 123 | # RMSprop 124 | 
parser.add_argument('--alpha', type=float, default=0.99) 125 | 126 | # visualization parameters 127 | parser.add_argument('--print_freq', type=int, default=100) 128 | parser.add_argument('--test_every', type=int, default=1000) 129 | parser.add_argument('--save_epoch_freq', type=int, default=1) 130 | parser.add_argument('--calc_psnr', type=str2bool, default=False) 131 | parser.add_argument('--save_imgs', type=str2bool, default=False) 132 | 133 | parser.add_argument('--FLOPs_only', type=str2bool, default=False) 134 | parser.add_argument('--matlab', type=str2bool, default=False) 135 | 136 | self.initialized = True 137 | return parser 138 | 139 | def gather_options(self): 140 | """Initialize our parser with basic options(only once). 141 | Add additional model-specific and dataset-specific options. 142 | These options are difined in the function 143 | in model and dataset classes. 144 | """ 145 | if not self.initialized: # check if it has been initialized 146 | parser = argparse.ArgumentParser(formatter_class= 147 | argparse.ArgumentDefaultsHelpFormatter) 148 | parser = self.initialize(parser) 149 | 150 | # get the basic options 151 | opt, _ = parser.parse_known_args() 152 | 153 | # modify model-related parser options 154 | model_name = opt.model 155 | model_option_setter = models.get_option_setter(model_name) 156 | parser = model_option_setter(parser, self.isTrain) 157 | opt, _ = parser.parse_known_args() # parse again with new defaults 158 | 159 | # save and return the parser 160 | self.parser = parser 161 | return parser.parse_args() 162 | 163 | def print_options(self, opt): 164 | """Print and save options 165 | 166 | It will print both current options and default values(if different). 
167 | It will save options into a text file / [checkpoints_dir] / opt.txt 168 | """ 169 | message = '' 170 | message += '----------------- Options ---------------\n' 171 | for k, v in sorted(vars(opt).items()): 172 | comment = '' 173 | default = self.parser.get_default(k) 174 | if v != default: 175 | comment = '\t[default: %s]' % str(default) 176 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 177 | message += '----------------- End -------------------' 178 | print(message) 179 | 180 | # save to the disk 181 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 182 | util.mkdirs(expr_dir) 183 | file_name = os.path.join(expr_dir, 'opt_%s.txt' 184 | % ('train' if self.isTrain else 'test')) 185 | with open(file_name, 'wt') as opt_file: 186 | opt_file.write(message) 187 | opt_file.write('\n') 188 | 189 | def parse(self): 190 | opt = self.gather_options() 191 | opt.isTrain = self.isTrain # train or test 192 | opt.serial_batches = not opt.shuffle 193 | 194 | if self.isTrain and (opt.load_iter != [0] or opt.load_path != '') \ 195 | and not opt.load_optimizers: 196 | util.prompt('You are loading a checkpoint and continuing training, ' 197 | 'and no optimizer parameters are loaded. 
Please make ' 198 | 'sure that the hyper parameters are correctly set.', 80) 199 | time.sleep(3) 200 | 201 | if opt.mode == 'RGB': 202 | opt.input_nc = opt.output_nc = 3 203 | else: # mode = 'L' or 'Y' 204 | opt.input_nc = opt.output_nc = 1 205 | opt.model = opt.model.lower() 206 | opt.name = opt.name.lower() 207 | 208 | scale_patch = {2: 96, 3: 144, 4: 192} 209 | if opt.patch_size is None: 210 | opt.patch_size = scale_patch[opt.scale] 211 | 212 | if opt.name.startswith(opt.checkpoints_dir): 213 | opt.name = opt.name.replace(opt.checkpoints_dir+'/', '') 214 | if opt.name.endswith('/'): 215 | opt.name = opt.name[:-1] 216 | 217 | if len(opt.dataset_name) == 1: 218 | opt.dataset_name = opt.dataset_name[0] 219 | 220 | if len(opt.load_iter) == 1: 221 | opt.load_iter = opt.load_iter[0] 222 | 223 | # process opt.suffix 224 | if opt.suffix != '': 225 | suffix = ('_' + opt.suffix.format(**vars(opt))) 226 | opt.name = opt.name + suffix 227 | 228 | self.print_options(opt) 229 | 230 | # set gpu ids 231 | cuda_device_count = torch.cuda.device_count() 232 | if opt.gpu_ids == 'all': 233 | # GT 710 (3.5), GT 610 (2.1) 234 | gpu_ids = [i for i in range(cuda_device_count)] 235 | else: 236 | p = re.compile('[^-0-9]+') 237 | gpu_ids = [int(i) for i in re.split(p, opt.gpu_ids) if int(i) >= 0] 238 | opt.gpu_ids = [i for i in gpu_ids \ 239 | if torch.cuda.get_device_capability(i) >= (4,0)] 240 | 241 | if len(opt.gpu_ids) == 0 and len(gpu_ids) > 0: 242 | opt.gpu_ids = gpu_ids 243 | util.prompt('You\'re using GPUs with computing capability < 4') 244 | elif len(opt.gpu_ids) != len(gpu_ids): 245 | util.prompt('GPUs(computing capability < 4) have been disabled') 246 | 247 | if len(opt.gpu_ids) > 0: 248 | assert torch.cuda.is_available(), 'No cuda available !!!' 
249 | torch.cuda.set_device(opt.gpu_ids[0]) 250 | print('The GPUs you are using:') 251 | for gpu_id in opt.gpu_ids: 252 | print(' %2d *%s* with capability %d.%d' % ( 253 | gpu_id, 254 | torch.cuda.get_device_name(gpu_id), 255 | *torch.cuda.get_device_capability(gpu_id))) 256 | else: 257 | util.prompt('You are using CPU mode') 258 | 259 | self.opt = opt 260 | return self.opt 261 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | self.isTrain = False 8 | return parser 9 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions, str2bool 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | self.isTrain = True 8 | return parser 9 | -------------------------------------------------------------------------------- /scripts/test_adaedsr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -n "$1" ]; then 3 | scale=$1 4 | else 5 | scale=2 6 | fi 7 | 8 | if [ -n "$2" ]; then 9 | depth=$2 10 | else 11 | depth=32 12 | fi 13 | 14 | cd .. 
15 | 16 | echo "testing with scale $scale" 17 | python test.py \ 18 | --model adaedsr \ 19 | --name adaedsr_x${scale} \ 20 | --scale $scale \ 21 | --load_path ./ckpt/adaedsr_x${scale}/AdaEDSR_model.pth \ 22 | --dataset_name set5 \ 23 | --depth $depth \ 24 | --chop True \ 25 | --sparse_conv True \ 26 | --matlab True \ 27 | --gpu_ids 0 -------------------------------------------------------------------------------- /scripts/test_adaedsr_fixd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is for `FAdaEDSR` in the ablation study of the paper. 3 | 4 | scale=2 5 | # Only models of scale x2 are provided; you can train other scales with 6 | # train_adaedsr_fixd.sh if needed. 7 | # Note that for convenience, `--depth` is set to 1 by default in all conditions, 8 | # which is equivalent to removing the desired depth $d$. Though 'depth' is used 9 | # as a parameter here, it means the desired depth $d$ in the training procedure. 10 | 11 | if [ -n "$1" ]; then 12 | depth=$1 13 | else 14 | depth=32 15 | fi 16 | 17 | cd .. 18 | 19 | echo "testing with scale $scale" 20 | python test.py \ 21 | --model adaedsr_fixd \ 22 | --name adaedsr_fixd_32_x2_d${depth} \ 23 | --scale $scale \ 24 | --load_path ./ckpt/adaedsr_fixd_32_x2_d${depth}/AdaEDSRFixD_model.pth \ 25 | --dataset_name set5 \ 26 | --chop True \ 27 | --sparse_conv True \ 28 | --matlab True \ 29 | --gpu_ids 0 -------------------------------------------------------------------------------- /scripts/test_adarcan.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -n "$1" ]; then 3 | scale=$1 4 | else 5 | scale=2 6 | fi 7 | 8 | if [ -n "$2" ]; then 9 | depth=$2 10 | else 11 | depth=20 12 | fi 13 | 14 | cd .. 
15 | 16 | echo "testing with scale $scale" 17 | python test.py \ 18 | --model adarcan \ 19 | --name adarcan_x${scale} \ 20 | --scale $scale \ 21 | --load_path ./ckpt/adarcan_x${scale}/AdaRCAN_model.pth \ 22 | --dataset_name set5 \ 23 | --depth $depth \ 24 | --chop True \ 25 | --sparse_conv True \ 26 | --matlab True \ 27 | --gpu_ids 0 -------------------------------------------------------------------------------- /scripts/test_edsr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script is only used for counting inference time and calculating FLOPs, 4 | # and the given checkpoint file is converted from the authors' pytorch version, 5 | # whose results are slightly higher than those of their torch version (used in their paper). 6 | # See https://github.com/thstkdgus35/EDSR-PyTorch for official pytorch version. 7 | # To get qualitative results and PSNR/SSIM indices, please refer to the authors' 8 | # torch version: https://github.com/LimBee/NTIRE2017 9 | 10 | # Reference: 11 | # Lim B, Son S, Kim H, et al. Enhanced deep residual networks for single image 12 | # super-resolution[C]//Proceedings of the IEEE conference on computer vision and 13 | # pattern recognition workshops. 2017: 136-144. 14 | 15 | if [ -n "$1" ]; then 16 | scale=$1 17 | else 18 | scale=2 19 | fi 20 | 21 | cd .. 
22 | 23 | echo "testing with scale $scale" 24 | python test.py \ 25 | --model edsr \ 26 | --name edsr_x${scale} \ 27 | --scale $scale \ 28 | --load_path ./pretrained/EDSR_official_32_x${scale}.pth \ 29 | --dataset_name set5 \ 30 | --chop True \ 31 | --sparse_conv True \ 32 | --matlab True \ 33 | --gpu_ids 0 -------------------------------------------------------------------------------- /scripts/test_rcan.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script can generate exactly the same results as the official RCAN code, 4 | # which can be found at https://github.com/yulunzhang/RCAN 5 | 6 | # Reference: 7 | # Zhang Y, Li K, Li K, et al. Image super-resolution using very deep residual 8 | # channel attention networks[C]//Proceedings of the European Conference on 9 | # Computer Vision (ECCV). 2018: 286-301. 10 | 11 | if [ -n "$1" ]; then 12 | scale=$1 13 | else 14 | scale=2 15 | fi 16 | 17 | cd .. 18 | 19 | echo "testing with scale $scale" 20 | python test.py \ 21 | --model rcan \ 22 | --name rcan_x${scale} \ 23 | --scale $scale \ 24 | --load_path ./pretrained/RCAN_BIX${scale}.pth \ 25 | --dataset_name set5 \ 26 | --chop True \ 27 | --sparse_conv True \ 28 | --matlab True \ 29 | --gpu_ids 0 -------------------------------------------------------------------------------- /scripts/test_rdn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script generates !! nearly !! the same results as the official code, 4 | # which can be found at https://github.com/yulunzhang/RDN (a torch version), as 5 | # we converted the official torch models to pytorch models. 6 | # If the exact official results are required, please refer to the authors' repo. 7 | 8 | # Reference: 9 | # Zhang Y, Tian Y, Kong Y, et al. 
Residual dense network for image 10 | # super-resolution[C]//Proceedings of the IEEE conference on computer vision 11 | # and pattern recognition. 2018: 2472-2481. 12 | 13 | if [ -n "$1" ]; then 14 | scale=$1 15 | else 16 | scale=2 17 | fi 18 | 19 | cd .. 20 | 21 | echo "testing with scale $scale" 22 | python test.py \ 23 | --model rdn \ 24 | --name rdn_x${scale} \ 25 | --scale $scale \ 26 | --load_path ./ckpt/rdn_x${scale}/RDN_BIX${scale}.pth \ 27 | --dataset_name set5 \ 28 | --chop True \ 29 | --sparse_conv True \ 30 | --matlab True \ 31 | --gpu_ids 0 -------------------------------------------------------------------------------- /scripts/test_san.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script can generate exactly the same results as the official SAN code, 4 | # which can be found at https://github.com/daitao/SAN 5 | 6 | # NOTE that we optimized `class Covpool` (AdaDSR/models/MPNCOV/python/MPNCOV.py) 7 | # for faster inference, so the inference time measured with this script may be 8 | # much shorter than that reported in the paper. 9 | 10 | # Reference: 11 | # Dai T, Cai J, Zhang Y, et al. Second-order attention network for single image 12 | # super-resolution[C]//Proceedings of the IEEE Conference on Computer Vision and 13 | # Pattern Recognition. 2019: 11065-11074. 14 | 15 | if [ -n "$1" ]; then 16 | scale=$1 17 | else 18 | scale=2 19 | fi 20 | 21 | cd .. 
22 | 23 | echo "testing with scale $scale" 24 | python test.py \ 25 | --model san \ 26 | --name san_x${scale} \ 27 | --scale $scale \ 28 | --load_path ./ckpt/san_model/SAN_BIX${scale}.pth \ 29 | --dataset_name set5 \ 30 | --chop True \ 31 | --sparse_conv True \ 32 | --matlab True \ 33 | --gpu_ids 0 -------------------------------------------------------------------------------- /scripts/test_srcnn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script is only used for counting inference time and calculating FLOPs. 4 | # To get qualitative results and PSNR/SSIM indices, please refer to the authors' 5 | # project: http://mmlab.ie.cuhk.edu.hk/projects/SRCNN.html 6 | 7 | # Reference: 8 | # Dong C, Loy C C, He K, et al. Image super-resolution using deep convolutional 9 | # networks[J]. IEEE transactions on pattern analysis and machine intelligence, 10 | # 2015, 38(2): 295-307. 11 | 12 | # Note that SRCNN takes the bicubic-upsampled image as input. 13 | # --name below follows the srcnn_x${scale} convention used by train_srcnn.sh. 14 | if [ -n "$1" ]; then 15 | scale=$1 16 | else 17 | scale=2 18 | fi 19 | 20 | cd .. 21 | 22 | echo "testing with scale $scale" 23 | python test.py \ 24 | --model srcnn \ 25 | --name srcnn_x${scale} \ 26 | --scale $scale \ 27 | --dataset_name set5 \ 28 | --gpu_ids 0 -------------------------------------------------------------------------------- /scripts/test_vdsr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script is only used for counting inference time and calculating FLOPs, 4 | # and the given checkpoint file is not able to generate official results. 5 | # To get qualitative results and PSNR/SSIM indices, please refer to the authors' 6 | # project: https://cv.snu.ac.kr/research/VDSR/ 7 | 8 | # Reference: 9 | # Kim J, Kwon Lee J, Mu Lee K.
Accurate image super-resolution using very deep 10 | # convolutional networks[C]//Proceedings of the IEEE conference on computer 11 | # vision and pattern recognition. 2016: 1646-1654. 12 | 13 | # Note that VDSR takes the bicubic-upsampled image as input. 14 | 15 | if [ -n "$1" ]; then 16 | scale=$1 17 | else 18 | scale=2 19 | fi 20 | 21 | cd .. 22 | 23 | echo "testing with scale $scale" 24 | python test.py \ 25 | --model vdsr \ 26 | --name vdsr_x${scale} \ 27 | --scale $scale \ 28 | --dataset_name set5 \ 29 | --load_path ./ckpt/vdsr/vdsr_model.pth \ 30 | --matlab True \ 31 | --sparse_conv True \ 32 | --gpu_ids 0 -------------------------------------------------------------------------------- /scripts/train_adaedsr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -n "$1" ]; then # optional positional arg: $1 = scale (default 2) 3 | scale=$1 4 | else 5 | scale=2 6 | fi 7 | 8 | cd .. 9 | 10 | echo "training with scale $scale" 11 | python train.py \ 12 | --model adaedsr \ 13 | --name adaedsr_x${scale} \ 14 | --scale $scale \ 15 | --load_path ./pretrained/EDSR_official_32_x${scale}.pth -------------------------------------------------------------------------------- /scripts/train_adaedsr_fixd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -n "$1" ]; then # optional positional args: $1 = scale (default 2), $2 = desired depth (default 32) 3 | scale=$1 4 | else 5 | scale=2 6 | fi 7 | 8 | if [ -n "$2" ]; then 9 | depth=$2 10 | else 11 | depth=32 12 | fi 13 | 14 | cd .. 
15 | 16 | echo "training with scale $scale" 17 | python train.py \ 18 | --model adaedsr_fixd \ 19 | --name adaedsr_fixd_32_x${scale}_d${depth} \ 20 | --scale $scale \ 21 | --depth $depth \ 22 | --load_path ./pretrained/EDSR_official_32_x${scale}.pth -------------------------------------------------------------------------------- /scripts/train_adarcan.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -n "$1" ]; then 3 | scale=$1 4 | else 5 | scale=2 6 | fi 7 | 8 | cd .. 9 | # Use the full --load_path flag: an abbreviated --load also prefixes --load_iter. 10 | echo "training with scale $scale" 11 | python train.py \ 12 | --model adarcan \ 13 | --name adarcan_x${scale} \ 14 | --scale $scale \ 15 | --load_path ./pretrained/RCAN_BIX${scale}.pth -------------------------------------------------------------------------------- /scripts/train_edsr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script trains an edsr model. 4 | 5 | # Reference: 6 | # Lim B, Son S, Kim H, et al. Enhanced deep residual networks for single image 7 | # super-resolution[C]//Proceedings of the IEEE conference on computer vision and 8 | # pattern recognition workshops. 2017: 136-144. 9 | 10 | if [ -n "$1" ]; then 11 | scale=$1 12 | else 13 | scale=2 14 | fi 15 | 16 | cd .. 17 | 18 | echo "training with scale $scale" 19 | python train.py \ 20 | --model edsr \ 21 | --name edsr_x${scale} \ 22 | --scale $scale -------------------------------------------------------------------------------- /scripts/train_rcan.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script trains a rcan model. 4 | 5 | # Reference: 6 | # Zhang Y, Li K, Li K, et al. Image super-resolution using very deep residual 7 | # channel attention networks[C]//Proceedings of the European Conference on 8 | # Computer Vision (ECCV). 
2018: 286-301.
9 | 10 | if [ -n "$1" ]; then 11 | scale=$1 12 | else 13 | scale=2 14 | fi 15 | 16 | cd .. 17 | 18 | echo "training with scale $scale" 19 | python train.py \ 20 | --model rcan \ 21 | --name rcan_x${scale} \ 22 | --scale $scale -------------------------------------------------------------------------------- /scripts/train_rdn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script trains a rdn model. 4 | 5 | # Reference: 6 | # Zhang Y, Tian Y, Kong Y, et al. Residual dense network for image 7 | # super-resolution[C]//Proceedings of the IEEE conference on computer vision 8 | # and pattern recognition. 2018: 2472-2481. 9 | 10 | if [ -n "$1" ]; then 11 | scale=$1 12 | else 13 | scale=2 14 | fi 15 | 16 | cd .. 17 | 18 | echo "training with scale $scale" 19 | python train.py \ 20 | --model rdn \ 21 | --name rdn_x${scale} \ 22 | --scale $scale -------------------------------------------------------------------------------- /scripts/train_san.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script trains an SAN model. 4 | 5 | # NOTE: this script currently supports single-GPU training only. 6 | 7 | # Reference: 8 | # Dai T, Cai J, Zhang Y, et al. Second-order attention network for single image 9 | # super-resolution[C]//Proceedings of the IEEE Conference on Computer Vision and 10 | # Pattern Recognition. 2019: 11065-11074. 11 | 12 | if [ -n "$1" ]; then 13 | scale=$1 14 | else 15 | scale=2 16 | fi 17 | 18 | cd .. 
19 | 20 | echo "training with scale $scale" 21 | python train.py \ 22 | --model san \ 23 | --name san_x${scale} \ 24 | --scale $scale \ 25 | --gpu_ids 0 \ 26 | --chop True # otherwise, may cause `Out Of Memory (OOM)` error -------------------------------------------------------------------------------- /scripts/train_srcnn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script trains a srcnn model. 4 | 5 | # The input is image super-resolved by bicubic algorithm, and only Y channel 6 | # (of YCbCr color space) is used. 7 | # In the evaluation procedure (set `--calc_psnr True`), the Y channel output is 8 | # directly used to calculate the PSNR index. 9 | 10 | # Reference: 11 | # Dong C, Loy C C, He K, et al. Image super-resolution using deep convolutional 12 | # networks[J]. IEEE transactions on pattern analysis and machine intelligence, 13 | # 2015, 38(2): 295-307. 14 | 15 | if [ -n "$1" ]; then 16 | scale=$1 17 | else 18 | scale=2 19 | fi 20 | 21 | cd .. 22 | 23 | echo "training with scale $scale" 24 | python train.py \ 25 | --model srcnn \ 26 | --name srcnn_x${scale} \ 27 | --scale $scale -------------------------------------------------------------------------------- /scripts/train_vdsr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script trains a vdsr model. 4 | 5 | # The input is image super-resolved by bicubic algorithm, and only Y channel 6 | # (of YCbCr color space) is used. 7 | # In the evaluation procedure (set `--calc_psnr True`), the Y channel output is 8 | # directly used to calculate the PSNR index. 9 | 10 | # Reference: 11 | # Kim J, Kwon Lee J, Mu Lee K. Accurate image super-resolution using very deep 12 | # convolutional networks[C]//Proceedings of the IEEE conference on computer 13 | # vision and pattern recognition. 2016: 1646-1654. 
14 | 15 | if [ -n "$1" ]; then 16 | scale=$1 17 | else 18 | scale=2 19 | fi 20 | 21 | cd .. 22 | 23 | echo "training with scale $scale" 24 | python train.py \ 25 | --model vdsr \ 26 | --name vdsr_x${scale} \ 27 | --scale $scale -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from options.test_options import TestOptions 4 | from data import create_dataset 5 | from models import create_model, networks as N 6 | from util.visualizer import Visualizer 7 | from tqdm import tqdm 8 | from train import calc_psnr # train.py itself imports this name from util.util 9 | import time 10 | import numpy as np 11 | from matplotlib import pyplot as plt 12 | from collections import OrderedDict as odict 13 | from copy import deepcopy 14 | import shutil 15 | 16 | # for FLOPs 17 | from flops import FLOPs, find, methods, chop, chop_pred, cvt 18 | 19 | if __name__ == '__main__': 20 | opt = TestOptions().parse() 21 | # log_dir = '%s/%s/psnr_x%s.txt' % (opt.checkpoints_dir, opt.name, opt.scale) 22 | # f = open(log_dir, 'a') 23 | 24 | opt_depths = deepcopy(opt.depth) # every requested depth; models are created below with only the first one 25 | opt.depth = [opt.depth[0]] 26 | if not isinstance(opt.load_iter, list): 27 | load_iters = [opt.load_iter] 28 | else: 29 | load_iters = deepcopy(opt.load_iter) 30 | 31 | if not isinstance(opt.dataset_name, list): 32 | dataset_names = [opt.dataset_name] 33 | else: 34 | dataset_names = deepcopy(opt.dataset_name) 35 | datasets = odict() 36 | for dataset_name in dataset_names: 37 | dataset = create_dataset(dataset_name, 'test', opt) 38 | datasets[dataset_name] = tqdm(dataset) 39 | 40 | # FLOPs 41 | if opt.model in ('adaedsr', 'adarcan'): 42 | func = getattr(FLOPs, find(opt.model[3:])) # [3:] drops the 'ada' prefix to select the backbone's FLOPs function 43 | elif 
opt.model == 'adaedsr_fixd': 44 | func = FLOPs.EDSR 45 | else: 46 | func = getattr(FLOPs, find(opt.model)) 47 | 48 | for load_iter in load_iters: 49 | opt.load_iter = load_iter 50 | model = create_model(opt) 51 | model.setup(opt) 52
| model.eval() 53 | with_depth = hasattr(model, 'nc_adapter') and model.nc_adapter 54 | log_dir = '%s/%s/logs/log_x%d_epoch%d.txt' % ( 55 | opt.checkpoints_dir, opt.name, opt.scale, load_iter) 56 | os.makedirs(os.path.split(log_dir)[0], exist_ok=True) 57 | f = open(log_dir, 'a') 58 | 59 | for depth in opt_depths: 60 | if with_depth: 61 | opt.depth = [depth] 62 | model.depth_gen = N.num_generator(opt.depth) 63 | 64 | for dataset_name in dataset_names: 65 | opt.dataset_name = dataset_name 66 | tqdm_val = datasets[dataset_name] 67 | dataset_test = tqdm_val.iterable 68 | dataset_size_test = len(dataset_test) 69 | 70 | print('='*80) 71 | print(dataset_name, depth) 72 | tqdm_val.reset() 73 | 74 | 75 | if opt.matlab: 76 | shutil.rmtree('./tmp', ignore_errors=True) 77 | os.makedirs('./tmp/HR', exist_ok=True) 78 | os.makedirs('./tmp/SR', exist_ok=True) 79 | 80 | psnr = [0.0] * dataset_size_test 81 | ssim = [0.0] * dataset_size_test 82 | _sum = [0.0] * dataset_size_test # FLOPs 83 | if with_depth: 84 | depths = [0.0] * dataset_size_test 85 | time_val = 0 86 | for i, data in enumerate(tqdm_val): 87 | if not opt.FLOPs_only or opt.model not in ( 88 | 'srcnn', 'vdsr', 'rdn', 'san', 'edsr', 'rcan'): 89 | torch.cuda.empty_cache() 90 | model.set_input(data) 91 | torch.cuda.synchronize() 92 | time_val_start = time.time() 93 | model.test(opt.FLOPs_only) 94 | torch.cuda.synchronize() 95 | time_val += time.time() - time_val_start 96 | res = model.get_current_visuals() 97 | if with_depth: 98 | depths[i] = (torch.ceil(torch.clamp( 99 | res['pred'], 0, opt.n_resblocks)).mean()).item() 100 | if not opt.matlab: 101 | if opt.mode in ('L', 'RGB'): 102 | psnr[i] = calc_psnr(res['data_hr'], 103 | res['data_sr'], 104 | opt.scale) 105 | else: # opt.mode == 'Y': 106 | assert opt.mode == 'Y' 107 | psnr[i] = calc_psnr(res['data_hr'][:, :1], 108 | res['data_sr'][:, :1], 109 | opt.scale) 110 | # FLOPs 111 | in_shape = np.array(data[methods[opt.model]].shape[-2:]) 112 | scale = opt.scale 113 | if 
with_depth: 114 | mask = np.array(res['pred'].cpu().squeeze()) 115 | else: 116 | mask = None 117 | if opt.chop: 118 | in_shapes = chop(in_shape) 119 | if mask is not None: 120 | if len(mask.shape) == 2: 121 | masks = chop_pred(mask) 122 | elif len(mask.shape) == 3: 123 | masks = np.array([chop_pred(m) for m in mask]) 124 | masks = masks.transpose(1, 0, 2, 3) 125 | else: 126 | raise ValueError 127 | for ii in range(in_shapes.shape[0]): 128 | maskii = masks[ii] if mask is not None else None 129 | _sum[i] += func(in_shapes[ii], scale, maskii) 130 | if opt.model in ('adarcan', 'adaedsr', 'adaedsr_fixd'): 131 | _sum[i] += getattr(FLOPs, 132 | find(opt.model))(in_shape, scale) 133 | else: 134 | _sum[i] = func(in_shape, scale, mask) 135 | if opt.FLOPs_only: 136 | continue 137 | if opt.save_imgs: 138 | folder_dir = '%s/compare/x%d/%s/%s' % ( 139 | opt.checkpoints_dir, 140 | opt.scale, 141 | opt.dataset_name, 142 | os.path.basename(data['fname'][0]).split('.')[0]) 143 | depth_folder_dir = folder_dir+'_depth' 144 | os.makedirs(depth_folder_dir, exist_ok=True) 145 | if with_depth: 146 | save_dir = '%s/%s_%ddepth.png' % ( 147 | folder_dir, opt.name, depth) 148 | for idx in range(res['pred'].shape[1]): 149 | pred_dir = '%s/%s_d%d_p%d' % ( 150 | depth_folder_dir, opt.name, depth, idx) 151 | plt.figure(1) 152 | plt.clf() 153 | plt.axis('off') 154 | img = plt.imshow(res['pred'][0, idx].cpu(), 155 | vmin=0, vmax=opt.n_resblocks, 156 | cmap=plt.cm.hot) 157 | plt.colorbar() 158 | plt.savefig(pred_dir) 159 | else: 160 | save_dir = '%s/%s.png' % (folder_dir, opt.name) 161 | dataset_test.imio.write(np.array(res['data_sr'][0].cpu() 162 | ).astype(np.uint8), save_dir) 163 | if opt.matlab: 164 | dataset_test.imio.write(np.array(res['data_sr'][0][:, 165 | opt.scale:-opt.scale, opt.scale:-opt.scale].cpu() 166 | ).astype(np.uint8), './tmp/SR/%d.png'%i) 167 | dataset_test.imio.write(np.array(res['data_hr'][0][:, 168 | opt.scale:-opt.scale, opt.scale:-opt.scale].cpu() 169 | ).astype(np.uint8), 
'./tmp/HR/%d.png'%i) 170 | if opt.FLOPs_only: 171 | print('dataset: %s, depth: %d\n%s %s' % ( 172 | dataset_name, depth, 173 | cvt(np.sum(_sum)), cvt(np.mean(_sum)))) 174 | f.write('dataset: %s, depth: %d\n%s %s\n' % ( 175 | dataset_name, depth, 176 | cvt(np.sum(_sum)), cvt(np.mean(_sum)))) 177 | f.flush() 178 | continue 179 | 180 | if opt.matlab: 181 | print('Calcualting PSNR and SSIM with matlab ...') 182 | os.system('matlab -nodesktop -nosplash -r' 183 | ' "run(\'calc_psnr_ssim.m\');exit;"' 184 | ' > /dev/null') 185 | fres = open('result.txt', 'r') # NOTE(review): expects calc_psnr_ssim.m to leave "PSNR SSIM" on the first line of result.txt — confirm 186 | m_psnr, m_ssim = fres.readlines()[0].strip().split() 187 | fres.close() 188 | avg_psnr, avg_ssim = m_psnr, m_ssim 189 | else: 190 | avg_psnr = '%.6f'%np.mean(psnr) 191 | avg_ssim = '%.6f'%np.mean(ssim) # NOTE(review): ssim is never filled on this path, so this reports 0 192 | 193 | if with_depth: 194 | print('desired depth:', depth, 195 | 'mean depth:', np.mean(depths)) 196 | f.write('dataset: %s, depth: %d, mean_depth: %.4f, ' 197 | 'PSNR: %s, SSIM: %s, Time: %.3f sec.\n%s %s\n' 198 | % (dataset_name, depth, np.mean(depths), 199 | avg_psnr, avg_ssim, time_val, 200 | cvt(np.sum(_sum)), cvt(np.mean(_sum)))) 201 | else: 202 | f.write('dataset: %s, PSNR: %s, SSIM: %s, ' 203 | 'Time: %.3f sec.\n%s %s\n' 204 | % (dataset_name, avg_psnr, avg_ssim, time_val, 205 | cvt(np.sum(_sum)), cvt(np.mean(_sum)))) 206 | print('Time: %.3f s AVG Time: %.3f ms PSNR: %s SSIM: %s\n%s %s' 207 | % (time_val, time_val/dataset_size_test*1000, avg_psnr, 208 | avg_ssim, cvt(np.sum(_sum)), cvt(np.mean(_sum)))) 209 | f.flush() 210 | f.write('\n') 211 | f.close() 212 | for dataset in datasets: 213 | datasets[dataset].close() 214 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from 
options.train_options import TrainOptions 4 | from data import create_dataset 5 | from models import create_model 6 | from util.visualizer import Visualizer 7 | from tqdm
import tqdm 8 | import numpy as np 9 | import math 10 | import sys 11 | import torch.multiprocessing as mp 12 | 13 | from util.util import calc_psnr as calc_psnr 14 | #from util.util import calc_psnr_np as calc_psnr 15 | 16 | if __name__ == '__main__': 17 | opt = TrainOptions().parse() 18 | dataset_train = create_dataset('div2k', 'train', opt) 19 | dataset_size_train = len(dataset_train) 20 | print('The number of training images = %d' % dataset_size_train) 21 | dataset_val = create_dataset('div2k', 'val', opt) 22 | dataset_size_val = len(dataset_val) 23 | print('The number of val images = %d' % dataset_size_val) 24 | 25 | model = create_model(opt) 26 | model.setup(opt) 27 | visualizer = Visualizer(opt) 28 | total_iters = 0 29 | 30 | for epoch in range(model.start_epoch + 1, opt.niter + opt.niter_decay + 1): # epochs are 1-based and resume right after model.start_epoch 31 | 32 | # training 33 | epoch_start_time = time.time() 34 | epoch_iter = 0 35 | model.train() 36 | if hasattr(model, 'depth_gen') and model.depth_gen is not None: 37 | model.depth_gen.train() 38 | 39 | iter_data_time = iter_start_time = time.time() 40 | for i, data in enumerate(dataset_train): 41 | if total_iters % opt.print_freq == 0: 42 | t_data = time.time() - iter_data_time 43 | total_iters += 1 #opt.batch_size 44 | epoch_iter += 1 #opt.batch_size 45 | model.set_input(data) 46 | model.optimize_parameters() 47 | 48 | if total_iters % opt.print_freq == 0: 49 | losses = model.get_current_losses() 50 | t_comp = (time.time() - iter_start_time) 51 | visualizer.print_current_losses( 52 | epoch, epoch_iter, losses, t_comp, t_data, total_iters) 53 | # if opt.save_imgs: # Too many images 54 | # visualizer.display_current_results( 55 | # 'train', model.get_current_visuals(), total_iters) 56 | iter_start_time = time.time() 57 | 58 | iter_data_time = time.time() 59 | if epoch % opt.save_epoch_freq == 0: 60 | print('saving the model at the end of epoch %d, iters %d' 61 | % (epoch, total_iters)) 62 | model.save_networks(epoch) 63 | 64 | print('End of 
epoch %d / %d \t Time Taken: %.3f sec'
65 | % (epoch, opt.niter + opt.niter_decay, 66 | time.time() - epoch_start_time)) 67 | model.update_learning_rate() 68 | 69 | # val 70 | if opt.calc_psnr or opt.save_imgs: 71 | model.eval() 72 | if hasattr(model, 'depth_gen') and model.depth_gen is not None: 73 | model.depth_gen.eval() # returns the upper bound of depth 74 | val_iter_time = time.time() # NOTE(review): val_iter_time is never read 75 | tqdm_val = tqdm(dataset_val) 76 | psnr = [0.0] * dataset_size_val 77 | time_val = 0 78 | for i, data in enumerate(tqdm_val): 79 | model.set_input(data) 80 | time_val_start = time.time() 81 | with torch.no_grad(): 82 | model.test() 83 | time_val += time.time() - time_val_start 84 | res = model.get_current_visuals() 85 | if opt.mode in ('L', 'RGB'): 86 | psnr[i] = calc_psnr(res['data_hr'], 87 | res['data_sr'], 88 | opt.scale) 89 | else: # opt.mode == 'Y': 90 | assert opt.mode == 'Y' 91 | psnr[i] = calc_psnr(res['data_hr'][:, :1], 92 | res['data_sr'][:, :1], 93 | opt.scale) 94 | if opt.save_imgs: 95 | visualizer.display_current_results('val', res, epoch) 96 | visualizer.writer.add_scalar('val/psnr', np.mean(psnr), epoch) 97 | print('End of epoch %d / %d (Val) \t Time Taken: %.3f s \t PSNR: %f' 98 | % (epoch, opt.niter + opt.niter_decay, time_val, np.mean(psnr))) 99 | 100 | sys.stdout.flush() 101 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of helper functions.""" 2 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | 8 | def calc_psnr_np(sr, hr, scale): 9 | """ calculate psnr by numpy 10 | 11 |
Params: 12 | sr : numpy.uint8 13 | super-resolved image 14 | hr : numpy.uint8 15 | high-resolution ground truth 16 | scale : int 17 | super-resolution scale 18 | """ 19 | diff = (sr.astype(np.float32) - hr.astype(np.float32)) / 255. 20 | shave = scale 21 | if diff.shape[1] > 1: 22 | convert = np.zeros((1, 3, 1, 1), diff.dtype) 23 | convert[0, 0, 0, 0] = 65.738 24 | convert[0, 1, 0, 0] = 129.057 25 | convert[0, 2, 0, 0] = 25.064 26 | diff = diff * (convert) / 256 27 | diff = diff.sum(axis=1, keepdims=True) 28 | 29 | valid = diff[:, :, shave:-shave, shave:-shave] 30 | mse = np.power(valid, 2).mean() 31 | return -10 * math.log10(mse) 32 | 33 | def calc_psnr(sr, hr, scale): 34 | """ calculate psnr by torch 35 | 36 | Params: 37 | sr : torch.float32 38 | super-resolved image 39 | hr : torch.float32 40 | high-resolution ground truth 41 | scale : int 42 | super-resolution scale 43 | """ 44 | with torch.no_grad(): 45 | diff = (sr - hr) / 255. 46 | shave = scale 47 | if diff.shape[1] > 1: 48 | diff *= torch.tensor([65.738, 129.057, 25.064], 49 | device=sr.device).view(1, 3, 1, 1) / 256 50 | diff = diff.sum(dim=1, keepdim=True) 51 | valid = diff[..., shave:-shave, shave:-shave] 52 | mse = torch.pow(valid, 2).mean() 53 | return (-10 * torch.log10(mse)).item() 54 | 55 | 56 | def diagnose_network(net, name='network'): 57 | """Calculate and print the mean of average absolute(gradients) 58 | 59 | Parameters: 60 | net (torch network) -- Torch network 61 | name (str) -- the name of the network 62 | """ 63 | mean = 0.0 64 | count = 0 65 | for param in net.parameters(): 66 | if param.grad is not None: 67 | mean += torch.mean(torch.abs(param.grad.data)) 68 | count += 1 69 | if count > 0: 70 | mean = mean / count 71 | print(name) 72 | print(mean) 73 | 74 | 75 | def print_numpy(x, val=True, shp=False): 76 | """Print the mean, min, max, median, std, and size of a numpy array 77 | 78 | Parameters: 79 | val (bool) -- if print the values of the numpy array 80 | shp (bool) -- if print the 
shape of the numpy array 81 | """ 82 | x = x.astype(np.float64) 83 | if shp: 84 | print('shape,', x.shape) 85 | if val: 86 | x = x.flatten() 87 | print('mean = %3.3f, min = %3.3f, max = %3.3f, mid = %3.3f, std=%3.3f' 88 | % (np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 89 | 90 | 91 | def mkdirs(paths): 92 | """create empty directories if they don't exist 93 | 94 | Parameters: 95 | paths (str list) -- a list of directory paths 96 | """ 97 | if isinstance(paths, list) and not isinstance(paths, str): 98 | for path in paths: 99 | mkdir(path) 100 | else: 101 | mkdir(paths) 102 | 103 | 104 | def mkdir(path): 105 | """create a single empty directory if it didn't exist 106 | 107 | Parameters: 108 | path (str) -- a single directory path 109 | """ 110 | if not os.path.exists(path): # NOTE(review): os.makedirs(path, exist_ok=True) would avoid the check-then-create race 111 | os.makedirs(path) 112 | 113 | def prompt(s, width=66): # print s inside a '='-bordered box; a single short line is centered, otherwise lines are wrapped via split_str 114 | print('='*(width+4)) 115 | ss = s.split('\n') 116 | if len(ss) == 1 and len(s) <= width: 117 | print('= ' + s.center(width) + ' =') 118 | else: 119 | for s in ss: 120 | for i in split_str(s, width): 121 | print('= ' + i.ljust(width) + ' =') 122 | print('='*(width+4)) 123 | 124 | def split_str(s, width): # split s into chunks of at most width chars, breaking at a space only if it lies past width // 2 125 | ss = [] 126 | while len(s) > width: 127 | idx = s.rfind(' ', 0, width+1) 128 | if idx > width >> 1: 129 | ss.append(s[:idx]) 130 | s = s[idx+1:] 131 | else: 132 | ss.append(s[:width]) 133 | s = s[width:] 134 | if s.strip() != '': 135 | ss.append(s) 136 | return ss -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os.path import join 3 | from 
tensorboardX import SummaryWriter 4 | from matplotlib import pyplot as plt 5 | from io import BytesIO 6 | from PIL import Image 7 | 8 | class Visualizer(): 9 | def __init__(self, opt): 10 | self.opt = opt 11 | if opt.isTrain: 12 | self.name = opt.name 13 | self.save_dir = join(opt.checkpoints_dir, opt.name,
'log') 14 | self.writer = SummaryWriter(logdir=join(self.save_dir)) 15 | else: 16 | self.name = '%s_%s_%d' % ( 17 | opt.name, opt.dataset_name, opt.load_iter) 18 | self.save_dir = join(opt.checkpoints_dir, opt.name) 19 | if opt.save_imgs: # NOTE(review): when isTrain and save_imgs are both False, self.writer is never set 20 | self.writer = SummaryWriter(logdir=join( 21 | self.save_dir, 'ckpts', self.name)) 22 | 23 | def display_current_results(self, phase, visuals, iters): 24 | for k, v in visuals.items(): 25 | v = v.cpu() 26 | if k == 'pred': 27 | self.process_preds(self.writer, phase, k, v, iters) 28 | else: 29 | self.writer.add_image('%s/%s'%(phase, k), v[0]/255, iters) 30 | self.writer.flush() 31 | 32 | def process_pred(self, pred): # render one depth map as a 'hot' colormap image, returned as a CHW array scaled to [0, 1] 33 | buffer = BytesIO() 34 | plt.figure(1) 35 | plt.clf() 36 | plt.axis('off') 37 | img = plt.imshow(pred, cmap=plt.cm.hot) 38 | plt.colorbar() 39 | plt.savefig(buffer) 40 | im = np.array(Image.open(buffer).convert('RGB')).transpose(2, 0, 1) 41 | buffer.close() 42 | return im / 255 43 | 44 | def process_preds(self, writer, phase, k, v, iters): 45 | preds = v[0] 46 | if len(preds) == 1: 47 | writer.add_image('%s/%s'%(phase, k), 48 | self.process_pred(preds[0]), 49 | iters) 50 | else: 51 | writer.add_images('%s/%s'%(phase, k), 52 | np.stack([self.process_pred(pred)\ 53 | for pred in preds]), 54 | iters) 55 | 56 | def print_current_losses(self, epoch, iters, losses, 57 | t_comp, t_data, total_iters): 58 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' \ 59 | % (epoch, iters, t_comp, t_data) 60 | for k, v in losses.items(): 61 | message += '%s: %.4e ' % (k, v) 62 | self.writer.add_scalar('loss/%s'%k, v, total_iters) 63 | 64 | print(message) 65 | --------------------------------------------------------------------------------