├── .idea
│   ├── CSD.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── workspace.xml
├── MindSpore version
│   ├── README.md
│   ├── csd_train.py
│   ├── eval.py
│   ├── export.py
│   ├── src
│   │   ├── args.py
│   │   ├── common.py
│   │   ├── config.py
│   │   ├── contras_loss.py
│   │   ├── data
│   │   │   ├── common.py
│   │   │   ├── div2k.py
│   │   │   └── srdata.py
│   │   ├── edsr_model.py
│   │   ├── edsr_slim.py
│   │   ├── metric.py
│   │   ├── metrics.py
│   │   ├── rcan_model.py
│   │   └── vgg_model.py
│   ├── train.py
│   └── utils
│       ├── __init__.py
│       └── var_init.py
├── PyTorch version
│   ├── README.md
│   ├── data
│   │   ├── __init__.py
│   │   ├── benchmark.py
│   │   ├── bsd500.py
│   │   ├── common.py
│   │   ├── demo.py
│   │   ├── div2k.py
│   │   ├── div2kjpeg.py
│   │   ├── sr291.py
│   │   ├── srdata.py
│   │   └── video.py
│   ├── dataloader.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── adversarial.py
│   │   ├── contrast_loss.py
│   │   ├── discriminator.py
│   │   ├── perceptual.py
│   │   └── vgg.py
│   ├── main.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── carn.py
│   │   ├── common.py
│   │   ├── edsr.py
│   │   └── rcan.py
│   ├── option.py
│   ├── trainer
│   │   └── slim_contrast_trainer.py
│   └── utils
│       ├── __init__.py
│       ├── niqe.py
│       ├── niqe_image_params.mat
│       ├── spatial_trans.py
│       ├── ssim.py
│       └── utility.py
├── README.md
└── images
    ├── debug.log
    ├── model.png
    ├── psnr-speed.png
    ├── psnr-tradeoff.png
    ├── table.png
    ├── tradeoff.png
    └── visual.png

--------------------------------------------------------------------------------
/MindSpore version/README.md:
--------------------------------------------------------------------------------
## Dependencies

- Python == 3.7.5
- MindSpore: https://www.mindspore.cn/install
- matplotlib
- imageio
- tensorboardX
- opencv-python
- scipy
- scikit-image

## Train

### Prepare data

We use the DIV2K training set as our training data.
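The data loaders in `src/data/srdata.py` and `src/data/div2k.py` look for `DIV2K_train_HR` and `DIV2K_train_LR_bicubic/X<scale>` folders under `--dir_data`, so for x4 training the expected layout is roughly the following (file names are illustrative):

```
LOCATION_OF_DATA/
└── DIV2K/
    ├── DIV2K_train_HR/            # 0001.png ... 0800.png
    └── DIV2K_train_LR_bicubic/
        └── X4/                    # 0001x4.png ... 0800x4.png
```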
For instructions on how to download the data, refer to [EDSR(PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch).

### Train baseline model

```bash
# train teacher model
python -u train.py --dir_data LOCATION_OF_DATA --data_test Set5 --test_every 1 --filename edsr_baseline --lr 0.0001 --epochs 5000
```

```bash
# train student model (width=0.25)
python -u train.py --dir_data LOCATION_OF_DATA --data_test Set5 --test_every 1 --filename edsr_baseline025 --lr 0.0001 --epochs 5000 --n_feats 64
```

### Train CSD

A VGG network pre-trained on ImageNet is used in our contrastive loss. For copyright reasons, the pre-trained VGG weights cannot be shared publicly (`src/contras_loss.py` loads them from `./vgg19_ImageNet.ckpt` in the working directory).

```bash
python -u csd_train.py --dir_data LOCATION_OF_DATA --data_test Set5 --test_every 1 --filename edsr_csd --lr 0.0001 --epochs 5000 --ckpt_path ckpt/TEACHER_MODEL_NAME.ckpt --contra_lambda 200
```

## Test

```bash
python eval.py --dir_data LOCATION_OF_DATA --test_only --ext "img" --data_test B100 --ckpt_path ckpt/MODEL_NAME.ckpt --task_id 0 --scale 4
```
--------------------------------------------------------------------------------
/MindSpore version/csd_train.py:
--------------------------------------------------------------------------------
from mindspore.train.callback import ModelCheckpoint, Callback, LossMonitor, TimeMonitor, CheckpointConfig, _InternalCallbackParam, RunContext
import mindspore.nn as nn
from mindspore import ParameterTuple, Tensor
import mindspore.ops as ops
import mindspore.numpy as numpy
from mindspore import load_checkpoint, load_param_into_net, save_checkpoint
from mindspore.common import set_seed
from mindspore import context
from mindspore.context import ParallelMode
import mindspore.dataset as ds

import os
import time

from src.metric import PSNR
from src.args import args
from src.data.div2k import DIV2K
from src.data.srdata import SRData
from src.edsr_slim import EDSR
from src.contras_loss import ContrastLoss
# from eval import do_eval

class NetWithLossCell(nn.Cell):
    def __init__(self, net):
        super(NetWithLossCell, self).__init__()
        self.net = net
        self.l1_loss = nn.L1Loss()

    def construct(self, lr, hr, stu_width_mult, tea_width_mult):
        sr = self.net(lr, stu_width_mult)
        tea_sr = self.net(lr, tea_width_mult)
        loss = self.l1_loss(sr, hr) + self.l1_loss(tea_sr, hr)
        return loss

class NetWithCSDLossCell(nn.Cell):
    def __init__(self, net, contrast_w=0, neg_num=0):
        super(NetWithCSDLossCell, self).__init__()
        self.net = net
        self.neg_num = neg_num
        self.l1_loss = nn.L1Loss()
        self.contrast_loss = ContrastLoss()
        self.contrast_w = contrast_w

    def construct(self, lr, hr, stu_width_mult, tea_width_mult):
        sr = self.net(lr, stu_width_mult)
        tea_sr = self.net(lr, tea_width_mult)
        loss = self.l1_loss(sr, hr) + self.l1_loss(tea_sr, hr)

        # Bilinearly upsampled LR images from the (reversed) batch serve as
        # negative samples for the contrastive loss.
        resize = nn.ResizeBilinear()
        bic = resize(lr, size=(lr.shape[-2] * 4, lr.shape[-1] * 4))
        neg = numpy.flip(bic, 0)
        neg = neg[:self.neg_num, :, :, :]
        loss += self.contrast_w * self.contrast_loss(tea_sr, sr, neg)
        return loss

class TrainOneStepCell(nn.Cell):
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens

    def set_sens(self, value):
        self.sens = value

    def construct(self, lr, hr, width_mult, tea_width_mult):
        weights = self.weights
        loss = self.network(lr, hr, width_mult, tea_width_mult)

        sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(lr, hr, width_mult, tea_width_mult, sens)
        self.optimizer(grads)
        return loss

def csd_train(train_loader, net, opt):
    set_seed(1)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    print("[CSD] Start Training...")

    step_size = train_loader.get_dataset_size()
    # Halve the learning rate every 200 epochs.
    lr_schedule = []
    for i in range(0, opt.epochs):
        cur_lr = opt.lr / (2 ** ((i + 1) // 200))
        lr_schedule.extend([cur_lr] * step_size)
    optim = nn.Adam(net.trainable_params(), learning_rate=lr_schedule, loss_scale=opt.loss_scale)

    # net_with_loss = NetWithLossCell(net)
    net_with_loss = NetWithCSDLossCell(net, args.contra_lambda, args.neg_num)
    train_cell = TrainOneStepCell(net_with_loss, optim)
    net.set_train()
    eval_net = net

    # time_cb = TimeMonitor(data_size=step_size)
    # loss_cb = LossMonitor()
    # metrics = {
    #     "psnr": PSNR(rgb_range=opt.rgb_range, shave=True),
    # }
    # eval_cb = EvalCallBack(eval_net, eval_ds, args.test_every, step_size / opt.batch_size, metrics=metrics,
    #                        rank_id=rank_id)
    # cb = [time_cb, loss_cb]
    # config_ck = CheckpointConfig(save_checkpoint_steps=opt.ckpt_save_interval * step_size,
    #                              keep_checkpoint_max=opt.ckpt_save_max)
    # ckpt_cb = ModelCheckpoint(prefix=opt.filename, directory=opt.ckpt_save_path, config=config_ck)
    # if device_id == 0:
    #     cb += [ckpt_cb]

    for epoch in range(0, opt.epochs):
        epoch_loss = 0
        for iteration, batch in enumerate(train_loader.create_dict_iterator(), 1):
            lr = batch["LR"]
            hr = batch["HR"]

            loss = train_cell(lr, hr, Tensor(opt.stu_width_mult), Tensor(1.0))
            epoch_loss += loss

        print(f"Epoch[{epoch}] loss: {epoch_loss.asnumpy()}")
        # with eval_net.set_train(False):
        #     do_eval(eval_ds, eval_net)

        if epoch % 10 == 0:
            print('===> Saving model...')
            save_checkpoint(net, f'./ckpt/{opt.filename}.ckpt')
            # cb_params.cur_epoch_num = epoch + 1
            # ckpt_cb.step_end(run_context)


if __name__ == '__main__':
    time_start = time.time()
    device_id = int(os.getenv('DEVICE_ID', '0'))
    rank_id = int(os.getenv('RANK_ID', '0'))
    device_num = int(os.getenv('RANK_SIZE', '1'))
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False, device_id=device_id)

    train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
    train_dataset.set_scale(args.task_id)
    print(len(train_dataset))
    train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num,
                                           shard_id=rank_id, shuffle=True)
    train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)

    eval_dataset = SRData(args, name=args.data_test, train=False, benchmark=True)
    print(len(eval_dataset))
    eval_ds = ds.GeneratorDataset(eval_dataset, ['LR', 'HR'], shuffle=False)
    eval_ds = eval_ds.batch(1, drop_remainder=True)

    # net_m = RCAN(args)
    net_m = EDSR(args)
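    # Note: this EDSR is the slimmable variant from src/edsr_slim.py, whose
    # construct() takes a width multiplier, so teacher and student share one
    # set of weights. Illustrative calls (hypothetical names, not executed here):
    #   tea_sr = net_m(lr_batch, Tensor(1.0))                  # full-width teacher
    #   stu_sr = net_m(lr_batch, Tensor(args.stu_width_mult))  # slimmed student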
150 | print("Init net weights successfully") 151 | 152 | if args.ckpt_path: 153 | param_dict = load_checkpoint(args.ckpt_path) 154 | load_param_into_net(net_m, param_dict) 155 | print("Load net weight successfully") 156 | 157 | csd_train(train_de_dataset, net_m, args) 158 | time_end = time.time() 159 | print('train_time: %f' % (time_end - time_start)) -------------------------------------------------------------------------------- /MindSpore version/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Huawei Technologies Co., Ltd 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """eval script""" 16 | import os 17 | import time 18 | import numpy as np 19 | import mindspore.dataset as ds 20 | from mindspore import Tensor, context 21 | from mindspore.common import dtype as mstype 22 | from mindspore.train.serialization import load_checkpoint, load_param_into_net 23 | import mindspore.nn as nn 24 | from mindspore.train.model import Model 25 | 26 | from src.args import args 27 | import src.rcan_model as rcan 28 | # from src.edsr_model import EDSR 29 | from src.edsr_slim import EDSR 30 | from src.data.srdata import SRData 31 | from src.metrics import calc_psnr, quantize, calc_ssim 32 | from src.data.div2k import DIV2K 33 | 34 | # device_id = int(os.getenv('DEVICE_ID', '0')) 35 | # context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=device_id, save_graphs=False) 36 | # context.set_context(max_call_depth=10000) 37 | def eval_net(width_mult=1.0): 38 | """eval""" 39 | if args.epochs == 0: 40 | args.epochs = 1e8 41 | for arg in vars(args): 42 | if vars(args)[arg] == 'True': 43 | vars(args)[arg] = True 44 | elif vars(args)[arg] == 'False': 45 | vars(args)[arg] = False 46 | if args.data_test[0] == 'DIV2K': 47 | train_dataset = DIV2K(args, name=args.data_test, train=False, benchmark=False) 48 | else: 49 | train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False) 50 | train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False) 51 | train_de_dataset = train_de_dataset.batch(1, drop_remainder=True) 52 | train_loader = train_de_dataset.create_dict_iterator(output_numpy=True) 53 | #net_m = rcan.RCAN(args) 54 | net_m = EDSR(args) 55 | net_m.set_train(False) 56 | if args.ckpt_path: 57 | print(f"Load from {args.ckpt_path}") 58 | param_dict = load_checkpoint(args.ckpt_path) 59 | load_param_into_net(net_m, param_dict) 60 | 61 | # opt = nn.Adam(net_m.trainable_params(), learning_rate=0.0001, loss_scale=args.loss_scale) 62 | # model = Model(net_m, nn.L1Loss(), opt) 63 | 64 | print('load mindspore net successfully.') 65 | num_imgs = train_de_dataset.get_dataset_size() 66 | psnrs = np.zeros((num_imgs, 1)) 67 | ssims = np.zeros((num_imgs, 1)) 68 | for batch_idx, imgs in enumerate(train_loader): 69 | lr = imgs['LR'] 70 | hr = imgs['HR'] 71 | lr = Tensor(lr, mstype.float32) 
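# The width multiplier below selects which sub-network of the shared checkpoint
# is evaluated: 1.0 is the full teacher, args.stu_width_mult the slimmed student
# (both are exercised from __main__ at the bottom of this file).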
72 | pred = net_m(lr, Tensor(width_mult)) 73 | pred_np = pred.asnumpy() 74 | pred_np = quantize(pred_np, 255) 75 | psnr = calc_psnr(pred_np, hr, args.scale[0], 255.0) 76 | pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0) 77 | hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0) 78 | ssim = calc_ssim(pred_np, hr, args.scale[0]) 79 | print("current psnr: ", psnr) 80 | print("current ssim: ", ssim) 81 | psnrs[batch_idx, 0] = psnr 82 | ssims[batch_idx, 0] = ssim 83 | print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0])) 84 | print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale[0], ssims.mean(axis=0)[0])) 85 | 86 | def do_eval(eval_ds, eval_net, width_mult=1.0): 87 | train_loader = eval_ds.create_dict_iterator(output_numpy=True) 88 | num_imgs = eval_ds.get_dataset_size() 89 | psnrs = np.zeros((num_imgs, 1)) 90 | ssims = np.zeros((num_imgs, 1)) 91 | for batch_idx, imgs in enumerate(train_loader): 92 | lr = imgs['LR'] 93 | hr = imgs['HR'] 94 | lr = Tensor(lr, mstype.float32) 95 | pred = eval_net(lr, Tensor(width_mult)) 96 | pred_np = pred.asnumpy() 97 | pred_np = quantize(pred_np, 255) 98 | psnr = calc_psnr(pred_np, hr, args.scale[0], 255.0) 99 | pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0) 100 | hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0) 101 | ssim = calc_ssim(pred_np, hr, args.scale[0]) 102 | # print("current psnr: ", psnr) 103 | # print("current ssim: ", ssim) 104 | psnrs[batch_idx, 0] = psnr 105 | ssims[batch_idx, 0] = ssim 106 | print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0])) 107 | print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale[0], ssims.mean(axis=0)[0])) 108 | return np.mean(psnrs) 109 | 110 | if __name__ == '__main__': 111 | device_id = int(os.getenv('DEVICE_ID', '0')) 112 | # context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=device_id, save_graphs=False) 113 | context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=device_id, save_graphs=False) 114 | context.set_context(max_call_depth=10000) 115 | time_start = time.time() 116 | print("Start eval function!") 117 | print("Eval 1.0 Teacher") 118 | eval_net() 119 | time_end = time.time() 120 | print('eval_time: %f' % (time_end - time_start)) 121 | 122 | print("Eval Student") 123 | time_start = time.time() 124 | eval_net(args.stu_width_mult) 125 | time_end = time.time() 126 | print('eval_time: %f' % (time_end - time_start)) -------------------------------------------------------------------------------- /MindSpore version/export.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Huawei Technologies Co., Ltd 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """export net together with checkpoint into air/mindir/onnx models""" 16 | import os 17 | import argparse 18 | import numpy as np 19 | from src.args import args as arg 20 | from src.rcan_model import RCAN 21 | from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export 22 | 23 | 24 | parser = argparse.ArgumentParser(description='rcan export') 25 | parser.add_argument("--batch_size", type=int, default=1, help="batch size") 26 | parser.add_argument("--ckpt_path", type=str, required=True, help="path of checkpoint file") 27 | parser.add_argument("--file_name", type=str, default="rcan", help="output file name.") 28 | parser.add_argument("--file_format", type=str, default="MINDIR", choices=['MINDIR', 'AIR', 'ONNX'], help="file format") 29 | args_1 = parser.parse_args() 30 | 31 | 32 | def run_export(args): 33 | """ export """ 34 | device_id = int(os.getenv('DEVICE_ID', '0')) 35 | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id) 36 | net = RCAN(arg) 37 | param_dict = load_checkpoint(args.ckpt_path) 38 | load_param_into_net(net, param_dict) 39 | net.set_train(False) 40 | print('load mindspore net and checkpoint successfully.') 41 | inputs = Tensor(np.zeros([args.batch_size, 3, 678, 1020], np.float32)) 42 | export(net, inputs, file_name=args.file_name, file_format=args.file_format) 43 | print('export successfully!') 44 | 45 | 46 | if __name__ == "__main__": 47 | run_export(args_1) 48 | -------------------------------------------------------------------------------- /MindSpore version/src/args.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Huawei Technologies Co., Ltd 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """args""" 16 | import argparse 17 | import ast 18 | 19 | parser = argparse.ArgumentParser(description='RCAN') 20 | 21 | # Hardware specifications 22 | parser.add_argument('--seed', type=int, default=1, 23 | help='random seed') 24 | 25 | # Data specifications 26 | parser.add_argument('--dir_data', type=str, default='/cache/data/', 27 | help='dataset directory') 28 | parser.add_argument('--data_train', type=str, default='DIV2K', 29 | help='train dataset name') 30 | parser.add_argument('--data_test', type=str, default='DIV2K', 31 | help='test dataset name') 32 | parser.add_argument('--data_range', type=str, default='1-800/801-810', 33 | help='train/test data range') 34 | parser.add_argument('--ext', type=str, default='sep', 35 | help='dataset file extension') 36 | parser.add_argument('--scale', type=str, default='4', 37 | help='super resolution scale') 38 | parser.add_argument('--patch_size', type=int, default=48, 39 | help='output patch size') 40 | parser.add_argument('--rgb_range', type=int, default=255, 41 | help='maximum value of RGB') 42 | parser.add_argument('--n_colors', type=int, default=3, 43 | help='number of color channels to use') 44 | parser.add_argument('--no_augment', action='store_true', 45 | help='do not use data augmentation') 46 | 47 | # Model specifications 48 | parser.add_argument('--model', default='RCAN', 49 | help='model name') 50 | parser.add_argument('--act', type=str, default='relu', 51 | help='activation function') 52 | parser.add_argument('--n_resblocks', type=int, default=20, 53 | help='number of residual blocks') 54 | parser.add_argument('--n_feats', type=int, default=64, 55 | help='number of feature maps') 56 | parser.add_argument('--res_scale', type=float, default=1, 57 | help='residual scaling') 58 | 59 | 60 | # Option for Residual channel attention network (RCAN) 61 | parser.add_argument('--n_resgroups', type=int, default=10, 62 | help='number of residual groups') 63 | parser.add_argument('--reduction', type=int, default=16, 64 | help='number of feature maps reduction') 65 | 66 | # Training specifications 67 | parser.add_argument('--test_every', type=int, default=4000, 68 | help='do test per every N batches') 69 | parser.add_argument('--epochs', type=int, default=1000, 70 | help='number of epochs to train') 71 | parser.add_argument('--batch_size', type=int, default=16, 72 | help='input batch size for training') 73 | parser.add_argument('--test_only', action='store_true', 74 | help='set this option to test the model') 75 | 76 | 77 | # Optimization specifications 78 | parser.add_argument('--lr', type=float, default=1e-5, 79 | help='learning rate') 80 | parser.add_argument('--loss_scale', type=float, default=1024.0, 81 | help='scaling factor for optim') 82 | parser.add_argument('--init_loss_scale', type=float, default=65536., 83 | help='scaling factor') 84 | parser.add_argument('--decay', type=str, default='200', 85 | help='learning rate decay type') 86 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), 87 | help='ADAM beta') 88 | parser.add_argument('--epsilon', type=float, default=1e-8, 89 | help='ADAM epsilon for numerical stability') 90 | parser.add_argument('--weight_decay', type=float, default=0, 91 | help='weight decay') 92 | parser.add_argument('--gclip', type=float, default=0, 93 | help='gradient clipping threshold (0 = no clipping)') 94 | 95 | # ckpt specifications 96 | parser.add_argument('--ckpt_save_path', type=str, default='./ckpt/', 97 | 
                    help='path to save ckpt')
parser.add_argument('--ckpt_save_interval', type=int, default=10,
                    help='save ckpt frequency, unit is epoch')
parser.add_argument('--ckpt_save_max', type=int, default=100,
                    help='max number of saved ckpt')
parser.add_argument('--ckpt_path', type=str, default='',
                    help='path of saved ckpt')
parser.add_argument('--filename', type=str, default='')

# Task
parser.add_argument('--task_id', type=int, default=0)

# CSD
parser.add_argument('--stu_width_mult', type=float, default=0.25)
parser.add_argument('--neg_num', type=int, default=10)
parser.add_argument('--contra_lambda', type=float, default=1, help='weight of contra_loss')

# ModelArts
parser.add_argument('--modelArts_mode', type=ast.literal_eval, default=False,
                    help='train on modelarts or not, default is False')
parser.add_argument('--data_url', type=str, default='', help='the directory path of saved file')


args, unparsed = parser.parse_known_args()

args.scale = [int(x) for x in args.scale.split("+")]
args.data_train = args.data_train.split('+')
args.data_test = args.data_test.split('+')

if args.epochs == 0:
    args.epochs = 1e8

for arg in vars(args):
    if vars(args)[arg] == 'True':
        vars(args)[arg] = True
    elif vars(args)[arg] == 'False':
        vars(args)[arg] = False
--------------------------------------------------------------------------------
/MindSpore version/src/common.py:
--------------------------------------------------------------------------------
import math

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
import mindspore.numpy as numpy
from mindspore.common.initializer import TruncatedNormal


def weight_variable():
    """weight initial"""
    return TruncatedNormal(0.02)

def conv(in_channels, out_channels, kernel_size, stride=1, padding=0,
         pad_mode='pad', has_bias=True):
    """weight initial for conv layer"""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=has_bias, pad_mode=pad_mode)

def pixel_shuffle(tensor, scale_factor):
    """
    Implementation of pixel shuffle using numpy

    Parameters:
    -----------
    tensor: input tensor, shape is [N, C, H, W]
    scale_factor: scale factor to up-sample tensor

    Returns:
    --------
    tensor: tensor after pixel shuffle, shape is [N, C/(r*r), r*H, r*W],
        where r refers to scale factor
    """
    num, ch, height, width = tensor.shape
    # assert ch % (scale_factor * scale_factor) == 0

    new_ch = ch // (scale_factor * scale_factor)
    new_height = height * scale_factor
    new_width = width * scale_factor

    reshape = ops.Reshape()
    tensor = reshape(tensor, (num, new_ch, scale_factor, scale_factor, height, width))
    # new axis: [num, new_ch, height, scale_factor, width, scale_factor]
    transpose = ops.Transpose()
    tensor = transpose(tensor, (0, 1, 4, 2, 5, 3))
    tensor = reshape(tensor, (num, new_ch, new_height, new_width))
    return tensor

def MeanShift(x, rgb_range,
              rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):

    # super(MeanShift, self).__init__(3, 3, kernel_size=1)

    std = Tensor(rgb_std)
    conv2d = ops.Conv2D(out_channel=3, kernel_size=1)
    biasadd = ops.BiasAdd()
    weight = numpy.eye(3, 3).view((3, 3, 1, 1)) / std.view(3, 1, 1, 1)
    bias = sign * rgb_range * Tensor(rgb_mean) / std
    weight = weight.astype(numpy.float32)
    bias = bias.astype(numpy.float32)

    x = conv2d(x, weight)
    x = biasadd(x, bias)
    return x

class ResidualBlock(nn.Cell):
    def __init__(self, n_feats, kernel_size, act, res_scale):
        super(ResidualBlock, self).__init__()
        self.n_feats = n_feats
        self.res_scale = res_scale
        self.kernel_size = kernel_size

        # self.conv1 = nn.Conv2d(in_channels=n_feats, out_channels=n_feats, kernel_size=kernel_size, pad_mode='pad', padding=1, has_bias=True)
        self.conv1 = conv(n_feats, n_feats, kernel_size, padding=1)
        self.act = act
        # self.conv2 = nn.Conv2d(in_channels=n_feats, out_channels=n_feats, kernel_size=kernel_size, pad_mode='pad', padding=1, has_bias=True)
        self.conv2 = conv(n_feats, n_feats, kernel_size, padding=1)

    def construct(self, x, width_mult=1):
        # round = ops.Round()
        width = int(self.n_feats * width_mult)
        # width = round(self.n_feats * width_mult)
        conv2d = ops.Conv2D(out_channel=width, kernel_size=self.kernel_size, mode=1, pad_mode='pad', pad=1)
        biasadd = ops.BiasAdd()
        weight = self.conv1.weight[:width, :width, :, :]
        bias = self.conv1.bias[:width]
        residual = conv2d(x, weight)
        if bias is not None:
            residual = biasadd(residual, bias)
        residual = self.act(residual)
        weight = self.conv2.weight[:width, :width, :, :]
        bias = self.conv2.bias[:width]
        residual = conv2d(residual, weight)
        if bias is not None:
            residual = biasadd(residual, bias)

        return x + residual * self.res_scale

class Upsampler(nn.SequentialCell):
    def __init__(self, scale_factor, nf):
        super(Upsampler, self).__init__()
        block = []
        self.nf = nf
        self.scale = scale_factor

        if scale_factor == 3:
            block += [
                # nn.Conv2d(in_channels=nf, out_channels=nf * 9, kernel_size=3, pad_mode='pad', padding=1, has_bias=True)
                conv(nf, nf*9, 3, padding=1)
            ]
            # self.pixel_shuffle = nn.PixelShuffle(3)
            # pixel_shuffle function
        else:
            self.block_num = scale_factor // 2
            # self.pixel_shuffle = nn.PixelShuffle(2)
            #self.act = nn.ReLU()

            for _ in range(self.block_num):
                block += [
                    # nn.Conv2d(in_channels=nf, out_channels=nf * 2 ** 2, kernel_size=3, pad_mode='pad', padding=1, has_bias=True)
                    conv(nf, nf*2**2, 3, padding=1)
                ]
        self.blocks = nn.SequentialCell(block)

    def construct(self, x, width_mult=1):
        res = x
        nf = self.nf
        if self.scale == 3:
            width = int(width_mult * nf)
            width9 = width * 9
            conv2d = ops.Conv2D(out_channel=width9, kernel_size=3, mode=1, pad_mode='pad', pad=1)
            biasadd = ops.BiasAdd()
            for block in self.blocks:
                weight = block.weight[:width9, :width, :, :]
                bias = block.bias[:width9]
                res = conv2d(res, weight)
                if bias is not None:
                    res = biasadd(res, bias)
            res = pixel_shuffle(res, self.scale)
        else:
            width = int(width_mult * nf)
            width4 = width * 4
            conv2d = ops.Conv2D(out_channel=width4, kernel_size=3, mode=1, pad_mode='pad', pad=1)
            biasadd = ops.BiasAdd()
            for block in self.blocks:
                weight = block.weight[:width4, :width, :, :]
                bias = block.bias[:width4]
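                # Slimmable execution in a nutshell: at width multiplier w, only the
                # first int(w * channels) slices of the stored full-width kernels and
                # biases are applied, e.g. (illustrative, names not from this file):
                #   weight = conv.weight[:int(w * out_ch), :int(w * in_ch), :, :]
                #   out = ops.Conv2D(out_channel=int(w * out_ch), kernel_size=3)(x, weight)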
# print(res.shape) 152 | res = conv2d(res, weight) 153 | # print(res.shape) 154 | if bias is not None: 155 | res = biasadd(res, bias) 156 | res = pixel_shuffle(res, 2) 157 | #res = self.act(res) 158 | 159 | return res 160 | 161 | def SlimModule(input, module, width_mult): 162 | weight = module.weight 163 | out_ch, in_ch = weight.shape[:2] 164 | out_ch = int(out_ch * width_mult) 165 | in_ch = int(in_ch * width_mult) 166 | weight = weight[:out_ch, :in_ch, :, :] 167 | bias = module.bias 168 | 169 | conv2d = ops.Conv2D(out_channel=out_ch, kernel_size=module.kernel_size, mode=1, pad_mode=module.pad_mode, pad=module.padding) 170 | biasadd =ops.BiasAdd() 171 | out = conv2d(input, weight) 172 | if bias is not None: 173 | bias = module.bias[:out_ch] 174 | out = biasadd(out, bias) 175 | return out 176 | -------------------------------------------------------------------------------- /MindSpore version/src/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huawei Technologies Co., Ltd 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """ 16 | network config setting, will be used in train.py and eval.py 17 | """ 18 | from easydict import EasyDict as edict 19 | 20 | # config for vgg16, cifar10 21 | cifar_cfg = edict({ 22 | "num_classes": 10, 23 | "lr": 0.01, 24 | "lr_init": 0.01, 25 | "lr_max": 0.1, 26 | "lr_epochs": '30,60,90,120', 27 | "lr_scheduler": "step", 28 | "warmup_epochs": 5, 29 | "batch_size": 64, 30 | "max_epoch": 70, 31 | "momentum": 0.9, 32 | "weight_decay": 5e-4, 33 | "loss_scale": 1.0, 34 | "label_smooth": 0, 35 | "label_smooth_factor": 0, 36 | "buffer_size": 10, 37 | "image_size": '224,224', 38 | "pad_mode": 'same', 39 | "padding": 0, 40 | "has_bias": False, 41 | "batch_norm": True, 42 | "keep_checkpoint_max": 10, 43 | "initialize_mode": "XavierUniform", 44 | "has_dropout": False 45 | }) 46 | 47 | # config for vgg16, imagenet2012 48 | imagenet_cfg = edict({ 49 | "num_classes": 1000, 50 | "lr": 0.01, 51 | "lr_init": 0.01, 52 | "lr_max": 0.1, 53 | "lr_epochs": '30,60,90,120', 54 | "lr_scheduler": 'cosine_annealing', 55 | "warmup_epochs": 0, 56 | "batch_size": 32, 57 | "max_epoch": 150, 58 | "momentum": 0.9, 59 | "weight_decay": 1e-4, 60 | "loss_scale": 1024, 61 | "label_smooth": 1, 62 | "label_smooth_factor": 0.1, 63 | "buffer_size": 10, 64 | "image_size": '224,224', 65 | "pad_mode": 'pad', 66 | "padding": 1, 67 | # "has_bias": True, 68 | "has_bias": False, 69 | "batch_norm": False, 70 | "keep_checkpoint_max": 10, 71 | "initialize_mode": "KaimingNormal", 72 | "has_dropout": True 73 | }) -------------------------------------------------------------------------------- /MindSpore version/src/contras_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import mindspore.ops as ops 4 | from mindspore.nn.loss.loss import _Loss 5 | # import mindspore_hub 
as mshub
import mindspore
from mindspore import context, Tensor, nn
from mindspore.train.model import Model
from mindspore.common import dtype as mstype
from mindspore.dataset.transforms import py_transforms
from mindspore import load_checkpoint, load_param_into_net
from mindspore.ops.functional import stop_gradient
import mindspore.numpy as np

from src.config import imagenet_cfg
from src.vgg_model import Vgg

# context.set_context(mode=context.GRAPH_MODE,
#                     device_target="Ascend",
#                     device_id=0)
# context.set_context(device_id=0)

cfg = {
    '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class Vgg19(nn.Cell):
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()

        ## load VGG-19
        vgg = Vgg(cfg['19'], phase="test", args=imagenet_cfg)
        # model = os.path.join(opt.data_url, 'vgg19_ImageNet.ckpt')
        model = os.path.join('./', 'vgg19_ImageNet.ckpt')
        print(model)
        param_dict = load_checkpoint(model)
        # print(param_dict)
        load_param_into_net(vgg, param_dict)
        vgg.set_train(False)

        vgg_pretrained_features = vgg.layers
        self.slice1 = nn.SequentialCell()
        self.slice2 = nn.SequentialCell()
        self.slice3 = nn.SequentialCell()
        self.slice4 = nn.SequentialCell()
        self.slice5 = nn.SequentialCell()
        for x in range(2):
            self.slice1.append(vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.append(vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.append(vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.append(vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.append(vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.get_parameters():
                param.requires_grad = False

    def construct(self, x):
        h_relu1 = self.slice1(x)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class ContrastLoss(_Loss):
    def __init__(self):
        super(ContrastLoss, self).__init__()
        self.vgg = Vgg19()
        self.l1 = nn.L1Loss()
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

    def construct(self, teacher, student, neg):
        expand_dims = ops.ExpandDims()  # unsqueeze operator
        teacher_vgg, student_vgg, neg_vgg = self.vgg(teacher), self.vgg(student), self.vgg(neg)

        # For each VGG feature level i:
        #   loss += weights[i] * d(teacher_i, student_i) / (d(student_i, negatives_i) + 1e-7)
        # i.e. pull the student towards the teacher and away from the negatives.
        loss = 0
        for i in range(len(teacher_vgg)):
            neg_i = expand_dims(neg_vgg[i], 0)  # [8, n_feats, w, h]
            # neg_i = neg_i.repeat(student_vgg[i].shape[0], axis=0)  # TODO: Tensor.repeat is only supported in MindSpore 1.3+
            neg_i = np.repeat(neg_i, student_vgg[i].shape[0], axis=0)  # [16, 8, n_feats, w, h]
            neg_i = neg_i.transpose((1, 0, 2, 3, 4))  # [8, 16, n_feats, w, h]

            d_ts = self.l1(stop_gradient(teacher_vgg[i]), student_vgg[i])
            # d_sn = (stop_gradient(neg_i) - student_vgg[i]).abs().sum(axis=0).mean()  # TODO: Tensor.sum is only supported in MindSpore 1.3+
            d_sn = (stop_gradient(neg_i) - student_vgg[i]).abs()  # [8, 16, n_feats, w, h]
            #
print(d_sn.shape) 96 | reduceSum = ops.ReduceSum() 97 | d_sn = reduceSum(d_sn, 0).mean() 98 | # print(d_sn) 99 | 100 | contrastive = d_ts / (d_sn + 1e-7) 101 | loss += self.weights[i] * contrastive 102 | 103 | return self.get_loss(loss) 104 | -------------------------------------------------------------------------------- /MindSpore version/src/data/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Huawei Technologies Co., Ltd 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """common""" 16 | import random 17 | import os 18 | import numpy as np 19 | 20 | 21 | def get_patch(*args, patch_size=96, scale=2, input_large=False): 22 | """get_patch""" 23 | ih, iw = args[0].shape[:2] 24 | 25 | tp = patch_size 26 | ip = tp // scale 27 | 28 | ix = random.randrange(0, iw - ip + 1) 29 | iy = random.randrange(0, ih - ip + 1) 30 | 31 | if not input_large: 32 | tx, ty = scale * ix, scale * iy 33 | else: 34 | tx, ty = ix, iy 35 | 36 | ret = [args[0][iy:iy + ip, ix:ix + ip, :], *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]] 37 | 38 | return ret 39 | 40 | 41 | def set_channel(*args, n_channels=3): 42 | """set_channel""" 43 | def _set_channel(img): 44 | if img.ndim == 2: 45 | img = np.expand_dims(img, axis=2) 46 | 47 | c = img.shape[2] 48 | if n_channels == 3 and c == 1: 49 | img = np.concatenate([img] * n_channels, 2) 50 | 51 | return img[:, :, :n_channels] 52 | 53 | return [_set_channel(a) for a in args] 54 | 55 | 56 | def np2Tensor(*args, rgb_range=255): 57 | """ np2Tensor""" 58 | def _np2Tensor(img): 59 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 60 | input_data = np_transpose.astype(np.float32) 61 | output = input_data * (rgb_range / 255) 62 | return output 63 | return [_np2Tensor(a) for a in args] 64 | 65 | 66 | def augment(*args, hflip=True, rot=True): 67 | """augment(""" 68 | hflip = hflip and random.random() < 0.5 69 | vflip = rot and random.random() < 0.5 70 | rot90 = rot and random.random() < 0.5 71 | 72 | def _augment(img): 73 | """augment""" 74 | if hflip: 75 | img = img[:, ::-1, :] 76 | if vflip: 77 | img = img[::-1, :, :] 78 | if rot90: 79 | img = img.transpose(1, 0, 2) 80 | return img 81 | 82 | return [_augment(a) for a in args] 83 | 84 | 85 | def search(root, target="JPEG"): 86 | """search""" 87 | item_list = [] 88 | items = os.listdir(root) 89 | for item in items: 90 | path = os.path.join(root, item) 91 | if os.path.isdir(path): 92 | item_list.extend(search(path, target)) 93 | elif path.split('/')[-1].startswith(target): 94 | item_list.append(path) 95 | elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]): 96 | item_list.append(path) 97 | return item_list 98 | -------------------------------------------------------------------------------- /MindSpore version/src/data/div2k.py: -------------------------------------------------------------------------------- 1 
| # Copyright 2021 Huawei Technologies Co., Ltd 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """div2k""" 16 | import os 17 | from src.data.srdata import SRData 18 | 19 | 20 | class DIV2K(SRData): 21 | """DIV2K""" 22 | def __init__(self, args, name='DIV2K', train=True, benchmark=False): 23 | self.dir_hr = None 24 | self.dir_lr = None 25 | data_range = [r.split('-') for r in args.data_range.split('/')] 26 | if train: 27 | data_range = data_range[0] 28 | else: 29 | if args.test_only and len(data_range) == 1: 30 | data_range = data_range[0] 31 | else: 32 | data_range = data_range[1] 33 | 34 | self.begin, self.end = list(map(int, data_range)) 35 | super(DIV2K, self).__init__(args, name=name, train=train, benchmark=benchmark) 36 | 37 | def _scan(self): 38 | names_hr, names_lr = super(DIV2K, self)._scan() 39 | names_hr = names_hr[self.begin - 1:self.end] 40 | names_lr = [n[self.begin - 1:self.end] for n in names_lr] 41 | 42 | return names_hr, names_lr 43 | 44 | def _set_filesystem(self, dir_data): 45 | super(DIV2K, self)._set_filesystem(dir_data) 46 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 47 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') 48 | -------------------------------------------------------------------------------- /MindSpore version/src/data/srdata.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Huawei Technologies Co., Ltd 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """"srdata""" 16 | import os 17 | import glob 18 | import random 19 | import pickle 20 | import imageio 21 | from src.data import common 22 | from PIL import ImageFile 23 | 24 | ImageFile.LOAD_TRUNCATED_IMAGES = True 25 | 26 | 27 | class SRData: 28 | """srdata""" 29 | 30 | def __init__(self, args, name='', train=True, benchmark=False): 31 | self.derain_lr_test = None 32 | self.derain_hr_test = None 33 | self.deblur_lr_test = None 34 | self.deblur_hr_test = None 35 | self.args = args 36 | self.name = name 37 | self.train = train 38 | self.split = 'train' if train else 'test' 39 | self.do_eval = True 40 | self.benchmark = benchmark 41 | self.input_large = (args.model == 'VDSR') 42 | self.scale = args.scale 43 | self.idx_scale = 0 44 | if benchmark: 45 | self._set_filesystem(os.path.join(args.dir_data, 'benchmark')) 46 | else: 47 | self._set_filesystem(args.dir_data) 48 | self._set_img(args) 49 | if train: 50 | self._repeat(args) 51 | 52 | def _set_img(self, args): 53 | """set_img""" 54 | if args.ext.find('img') < 0: 55 | path_bin = os.path.join(self.apath, 'bin') 56 | os.makedirs(path_bin, exist_ok=True) 57 | list_hr, list_lr = self._scan() 58 | if args.ext.find('img') >= 0 or self.benchmark: 59 | self.images_hr, self.images_lr = list_hr, list_lr 60 | elif args.ext.find('sep') >= 0: 61 | os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True) 62 | for s in self.scale: 63 | if s == 1: 64 | os.makedirs(os.path.join(self.dir_hr), exist_ok=True) 65 | else: 66 | os.makedirs( 67 | os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True) 68 | self.images_hr, self.images_lr = [], [[] for _ in self.scale] 69 | for h in list_hr: 70 | b = h.replace(self.apath, path_bin) 71 | b = b.replace(self.ext[0], '.pt') 72 | self.images_hr.append(b) 73 | self._check_and_load(args.ext, h, b, verbose=True) 74 | for i, ll in enumerate(list_lr): 75 | for l in ll: 76 | b = l.replace(self.apath, path_bin) 77 | b = b.replace(self.ext[1], '.pt') 78 | self.images_lr[i].append(b) 79 | self._check_and_load(args.ext, l, b, verbose=True) 80 | 81 | def _repeat(self, args): 82 | """repeat""" 83 | n_patches = args.batch_size * args.test_every 84 | n_images = len(args.data_train) * len(self.images_hr) 85 | if n_images == 0: 86 | self.repeat = 0 87 | else: 88 | self.repeat = max(n_patches // n_images, 1) 89 | 90 | def _scan(self): 91 | """_scan""" 92 | names_hr = sorted( 93 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))) 94 | names_lr = [[] for _ in self.scale] 95 | for f in names_hr: 96 | filename, _ = os.path.splitext(os.path.basename(f)) 97 | for si, s in enumerate(self.scale): 98 | if s != 1: 99 | scale = s 100 | names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \ 101 | .format(s, filename, scale, self.ext[1]))) 102 | for si, s in enumerate(self.scale): 103 | if s == 1: 104 | names_lr[si] = names_hr 105 | return names_hr, names_lr 106 | 107 | def _set_filesystem(self, dir_data): 108 | """set_filesystem""" 109 | self.apath = os.path.join(dir_data, self.name[0]) 110 | self.dir_hr = os.path.join(self.apath, 'HR') 111 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 112 | self.ext = ('.png', '.png') 113 | 114 | def _check_and_load(self, ext, img, f, verbose=True): 115 | """check_and_load""" 116 | if not os.path.isfile(f) or ext.find('reset') >= 0: 117 | if verbose: 118 | print('Making a binary: {}'.format(f)) 119 | with open(f, 'wb') as _f: 120 | 
pickle.dump(imageio.imread(img), _f) 121 | 122 | def __getitem__(self, idx): 123 | """get item""" 124 | lr, hr, _ = self._load_file(idx) 125 | pair = self.get_patch(lr, hr) 126 | pair = common.set_channel(*pair, n_channels=self.args.n_colors) 127 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) 128 | return pair_t[0], pair_t[1] 129 | 130 | def __len__(self): 131 | """length of hr""" 132 | if self.train: 133 | return len(self.images_hr) * self.repeat 134 | return len(self.images_hr) 135 | 136 | def _get_index(self, idx): 137 | """get_index""" 138 | if self.train: 139 | return idx % len(self.images_hr) 140 | return idx 141 | 142 | def _load_file_hr(self, idx): 143 | """load_file_hr""" 144 | idx = self._get_index(idx) 145 | f_hr = self.images_hr[idx] 146 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 147 | if self.args.ext == 'img' or self.benchmark: 148 | hr = imageio.imread(f_hr) 149 | elif self.args.ext.find('sep') >= 0: 150 | with open(f_hr, 'rb') as _f: 151 | hr = pickle.load(_f) 152 | return hr, filename 153 | 154 | def _load_file(self, idx): 155 | """load_file""" 156 | idx = self._get_index(idx) 157 | f_hr = self.images_hr[idx] 158 | f_lr = self.images_lr[self.idx_scale][idx] 159 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 160 | if self.args.ext == 'img' or self.benchmark: 161 | hr = imageio.imread(f_hr) 162 | lr = imageio.imread(f_lr) 163 | elif self.args.ext.find('sep') >= 0: 164 | with open(f_hr, 'rb') as _f: 165 | hr = pickle.load(_f) 166 | with open(f_lr, 'rb') as _f: 167 | lr = pickle.load(_f) 168 | return lr, hr, filename 169 | 170 | def get_patch_hr(self, hr): 171 | """get_patch_hr""" 172 | if self.train: 173 | hr = self.get_patch_img_hr(hr, patch_size=self.args.patch_size, scale=1) 174 | return hr 175 | 176 | def get_patch_img_hr(self, img, patch_size=96, scale=2): 177 | """get_patch_img_hr""" 178 | ih, iw = img.shape[:2] 179 | tp = patch_size 180 | ip = tp // scale 181 | ix = random.randrange(0, iw - ip + 1) 182 | iy = random.randrange(0, ih - ip + 1) 183 | ret = img[iy:iy + ip, ix:ix + ip, :] 184 | return ret 185 | 186 | def get_patch(self, lr, hr): 187 | """get_patch""" 188 | scale = self.scale[self.idx_scale] 189 | if self.train: 190 | lr, hr = common.get_patch( 191 | lr, hr, 192 | patch_size=self.args.patch_size * scale, 193 | scale=scale) 194 | if not self.args.no_augment: 195 | lr, hr = common.augment(lr, hr) 196 | else: 197 | ih, iw = lr.shape[:2] 198 | hr = hr[0:ih * scale, 0:iw * scale] 199 | return lr, hr 200 | 201 | def set_scale(self, idx_scale): 202 | """set_scale""" 203 | if not self.input_large: 204 | self.idx_scale = idx_scale 205 | else: 206 | self.idx_scale = random.randint(0, len(self.scale) - 1) 207 | -------------------------------------------------------------------------------- /MindSpore version/src/edsr_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Huawei Technologies Co., Ltd 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """edsr""" 16 | import math 17 | import mindspore.ops as ops 18 | import mindspore.nn as nn 19 | import mindspore.common.dtype as mstype 20 | from mindspore.ops import operations as P 21 | from mindspore.common import Tensor, Parameter 22 | 23 | 24 | def default_conv(in_channels, out_channels, kernel_size, has_bias=True): 25 | """edsr""" 26 | return nn.Conv2d( 27 | in_channels, out_channels, kernel_size, 28 | padding=(kernel_size // 2), has_bias=has_bias, pad_mode='pad') 29 | 30 | 31 | class MeanShift(nn.Conv2d): 32 | """edsr""" 33 | def __init__(self, 34 | rgb_range, 35 | rgb_mean=(0.4488, 0.4371, 0.4040), 36 | rgb_std=(1.0, 1.0, 1.0), 37 | sign=-1): 38 | """edsr""" 39 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 40 | self.reshape = P.Reshape() 41 | self.eye = P.Eye() 42 | std = Tensor(rgb_std, mstype.float32) 43 | self.weight.set_data( 44 | self.reshape(self.eye(3, 3, mstype.float32), (3, 3, 1, 1)) / self.reshape(std, (3, 1, 1, 1))) 45 | self.weight.requires_grad = False 46 | self.bias = Parameter( 47 | sign * rgb_range * Tensor(rgb_mean, mstype.float32) / std, name='bias', requires_grad=False) 48 | self.has_bias = True 49 | 50 | 51 | def _pixelsf_(x, scale): 52 | """edsr""" 53 | n, c, ih, iw = x.shape 54 | oh = ih * scale 55 | ow = iw * scale 56 | oc = c // (scale ** 2) 57 | output = P.Transpose()(x, (0, 2, 1, 3)) 58 | output = P.Reshape()(output, (n, ih, oc * scale, scale, iw)) 59 | output = P.Transpose()(output, (0, 1, 2, 4, 3)) 60 | output = P.Reshape()(output, (n, ih, oc, scale, ow)) 61 | output = P.Transpose()(output, (0, 2, 1, 3, 4)) 62 | output = P.Reshape()(output, (n, oc, oh, ow)) 63 | return output 64 | 65 | 66 | class SmallUpSampler(nn.Cell): 67 | """edsr""" 68 | def __init__(self, conv, upsize, n_feats, has_bias=True): 69 | """edsr""" 70 | super(SmallUpSampler, self).__init__() 71 | self.conv = conv(n_feats, upsize * upsize * n_feats, 3, has_bias) 72 | self.reshape = P.Reshape() 73 | self.upsize = upsize 74 | self.pixelsf = _pixelsf_ 75 | 76 | def construct(self, x): 77 | """edsr""" 78 | x = self.conv(x) 79 | output = self.pixelsf(x, self.upsize) 80 | return output 81 | 82 | 83 | class Upsampler(nn.Cell): 84 | """edsr""" 85 | def __init__(self, conv, scale, n_feats, has_bias=True): 86 | """edsr""" 87 | super(Upsampler, self).__init__() 88 | m = [] 89 | if (scale & (scale - 1)) == 0: 90 | for _ in range(int(math.log(scale, 2))): 91 | m.append(SmallUpSampler(conv, 2, n_feats, has_bias=has_bias)) 92 | elif scale == 3: 93 | m.append(SmallUpSampler(conv, 3, n_feats, has_bias=has_bias)) 94 | self.net = nn.SequentialCell(m) 95 | 96 | def construct(self, x): 97 | """edsr""" 98 | return self.net(x) 99 | 100 | 101 | class AdaptiveAvgPool2d(nn.Cell): 102 | """edsr""" 103 | def __init__(self): 104 | """edsr""" 105 | super().__init__() 106 | self.ReduceMean = ops.ReduceMean(keep_dims=True) 107 | 108 | def construct(self, x): 109 | """edsr""" 110 | return self.ReduceMean(x, 0) 111 | 112 | 113 | class ResidualBlock(nn.Cell): 114 | """edsr""" 115 | def __init__(self, conv, n_feat, kernel_size, has_bias=True 116 | , bn=False, act=nn.ReLU(), res_scale=0.1): 117 | """edsr""" 118 | super(ResidualBlock, self).__init__() 119 | self.modules_body = [] 120 | for i in range(2): 121 | self.modules_body.append(conv(n_feat, n_feat, kernel_size, has_bias=has_bias)) 122 | if bn: 
self.modules_body.append(nn.BatchNorm2d(n_feat)) 123 | if i == 0: self.modules_body.append(act) 124 | self.body = nn.SequentialCell(*self.modules_body) 125 | self.res_scale = res_scale 126 | 127 | def construct(self, x): 128 | """edsr""" 129 | res = self.body(x) 130 | res = res * self.res_scale 131 | return x + res 132 | 133 | class EDSR(nn.Cell): 134 | def __init__(self, args, conv=default_conv): 135 | super(EDSR, self).__init__() 136 | n_resblocks = args.n_resblocks 137 | n_feats = args.n_feats 138 | kernel_size = 3 139 | idx = args.task_id 140 | scale = args.scale[idx] 141 | self.dytpe = mstype.float16 142 | 143 | # RGB mean for DIV2K 144 | rgb_mean = (0.4488, 0.4371, 0.4040) 145 | rgb_std = (1.0, 1.0, 1.0) 146 | 147 | self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std).to_float(self.dytpe) 148 | 149 | # define head module 150 | modules_head = conv(args.n_colors, n_feats, kernel_size).to_float(self.dytpe) 151 | 152 | m_body = [ 153 | ResidualBlock( 154 | conv, n_feats, kernel_size, res_scale=args.res_scale, 155 | ).to_float(self.dytpe) for _ in range(n_resblocks) 156 | ] 157 | m_body.append(conv(n_feats, n_feats, kernel_size).to_float(self.dytpe)) 158 | 159 | m_tail = [ 160 | Upsampler(conv, scale, n_feats).to_float(self.dytpe), 161 | conv(n_feats, args.n_colors, kernel_size).to_float(self.dytpe) 162 | ] 163 | 164 | self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1).to_float(self.dytpe) 165 | 166 | self.head = modules_head 167 | self.body = nn.SequentialCell(m_body) 168 | self.tail = nn.SequentialCell(m_tail) 169 | 170 | def construct(self, x): 171 | x = self.sub_mean(x) 172 | x = self.head(x) 173 | x = x + self.body(x) 174 | x = self.tail(x) 175 | x = self.add_mean(x) 176 | return x 177 | -------------------------------------------------------------------------------- /MindSpore version/src/edsr_slim.py: -------------------------------------------------------------------------------- 1 | import mindspore 2 | from src import common 3 | from src.edsr_model import Upsampler, default_conv 4 | 5 | import mindspore.nn as nn 6 | import mindspore.ops as ops 7 | import mindspore.ops.operations as P 8 | from mindspore import Tensor 9 | 10 | class EDSR(nn.Cell): 11 | def __init__(self, args): 12 | super(EDSR, self).__init__() 13 | 14 | self.n_colors = args.n_colors 15 | n_resblocks = args.n_resblocks 16 | self.n_feats = args.n_feats 17 | self.kernel_size = 3 18 | scale = args.scale[0] 19 | act = nn.ReLU() 20 | self.rgb_range = args.rgb_range 21 | 22 | # self.head = nn.Conv2d(in_channels=args.n_colors, out_channels=self.n_feats, kernel_size=self.kernel_size, pad_mode='pad', padding=self.kernel_size // 2, has_bias=True) 23 | self.head = common.conv(args.n_colors, self.n_feats, self.kernel_size, padding=self.kernel_size//2) 24 | 25 | m_body = [ 26 | common.ResidualBlock( 27 | self.n_feats, self.kernel_size, act=act, res_scale=args.res_scale 28 | ) for _ in range(n_resblocks) 29 | ] 30 | self.body = nn.CellList(m_body) 31 | # self.body = m_body ###如果用这行,body这部分参数不会被训练 32 | self.body_conv = common.conv(self.n_feats, self.n_feats, self.kernel_size, padding=self.kernel_size//2) 33 | 34 | self.upsampler = common.Upsampler(scale, self.n_feats) 35 | self.tail_conv = common.conv(self.n_feats, args.n_colors, self.kernel_size, padding=self.kernel_size//2) 36 | 37 | def construct(self, x, width_mult=Tensor(1.0)): 38 | # def construct(self, x, width_mult): 39 | width_mult = width_mult.asnumpy().item() 40 | feature_width = int(self.n_feats * width_mult) 41 | conv2d = 
ops.Conv2D(out_channel=feature_width, kernel_size=self.kernel_size, mode=1, pad_mode='pad', 42 | pad=self.kernel_size // 2) 43 | biasadd = ops.BiasAdd() 44 | 45 | x = common.MeanShift(x, self.rgb_range) 46 | # this was originally written as weight.clone()[] 47 | weight = self.head.weight[:feature_width, :self.n_colors, :, :] 48 | bias = self.head.bias[:feature_width] 49 | x = conv2d(x, weight) 50 | x = biasadd(x, bias) 51 | 52 | residual = x 53 | for block in self.body: 54 | residual = block(residual, width_mult) 55 | weight = self.body_conv.weight[:feature_width, :feature_width, :, :] 56 | bias = self.body_conv.bias[:feature_width] 57 | residual = conv2d(residual, weight) 58 | residual = biasadd(residual, bias) 59 | residual += x 60 | 61 | x = self.upsampler(residual, width_mult) 62 | weight = self.tail_conv.weight[:self.n_colors, :feature_width, :, :] 63 | bias = self.tail_conv.bias[:self.n_colors] 64 | conv2d = ops.Conv2D(out_channel=self.n_colors, kernel_size=self.kernel_size, mode=1, pad_mode='pad', pad=self.kernel_size//2) 65 | x = conv2d(x, weight) 66 | x = biasadd(x, bias) 67 | x = common.MeanShift(x, self.rgb_range, sign=1) 68 | 69 | return x -------------------------------------------------------------------------------- /MindSpore version/src/metric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Huawei Technologies Co., Ltd 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================ 15 | """Metric for evaluation.""" 16 | import os 17 | import math 18 | 19 | from PIL import Image 20 | import numpy as np 21 | from mindspore import nn, Tensor, ops 22 | from mindspore import dtype as mstype 23 | from mindspore.ops.operations.comm_ops import ReduceOp 24 | 25 | try: 26 | from model_utils.device_adapter import get_rank_id, get_device_num 27 | except ImportError: 28 | get_rank_id = None 29 | get_device_num = None 30 | finally: 31 | pass 32 | 33 | 34 | class SelfEnsembleWrapperNumpy: 35 | """ 36 | SelfEnsembleWrapperNumpy using numpy 37 | """ 38 | 39 | def __init__(self, net): 40 | super(SelfEnsembleWrapperNumpy, self).__init__() 41 | self.net = net 42 | 43 | def hflip(self, x): 44 | return x[:, :, :, ::-1] 45 | 46 | def vflip(self, x): 47 | return x[:, :, ::-1, :] 48 | 49 | def trnsps(self, x): 50 | return x.transpose(0, 1, 3, 2) 51 | 52 | def aug_x8(self, x): 53 | """ 54 | do x8 augments for input image 55 | """ 56 | # hflip 57 | hx = self.hflip(x) 58 | # vflip 59 | vx = self.vflip(x) 60 | vhx = self.vflip(hx) 61 | # trnsps 62 | tx = self.trnsps(x) 63 | thx = self.trnsps(hx) 64 | tvx = self.trnsps(vx) 65 | tvhx = self.trnsps(vhx) 66 | return x, hx, vx, vhx, tx, thx, tvx, tvhx 67 | 68 | def aug_x8_reverse(self, x, hx, vx, vhx, tx, thx, tvx, tvhx): 69 | """ 70 | undo x8 augments for input images 71 | """ 72 | # trnsps 73 | tvhx = self.trnsps(tvhx) 74 | tvx = self.trnsps(tvx) 75 | thx = self.trnsps(thx) 76 | tx = self.trnsps(tx) 77 | # vflip 78 | tvhx = self.vflip(tvhx) 79 | tvx = self.vflip(tvx) 80 | vhx = self.vflip(vhx) 81 | vx = self.vflip(vx) 82 | # hflip 83 | tvhx = self.hflip(tvhx) 84 | thx = self.hflip(thx) 85 | vhx = self.hflip(vhx) 86 | hx = self.hflip(hx) 87 | return x, hx, vx, vhx, tx, thx, tvx, tvhx 88 | 89 | def to_numpy(self, *inputs): 90 | if not inputs: 91 | return None 92 | if len(inputs) == 1: 93 | return inputs[0].asnumpy() 94 | return [x.asnumpy() for x in inputs] 95 | 96 | def to_tensor(self, *inputs): 97 | if not inputs: 98 | return None 99 | if len(inputs) == 1: 100 | return Tensor(inputs[0]) 101 | return [Tensor(x) for x in inputs] 102 | 103 | def set_train(self, mode=True): 104 | self.net.set_train(mode) 105 | return self 106 | 107 | def __call__(self, x): 108 | x = self.to_numpy(x) 109 | x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8(x) 110 | x0, x1, x2, x3, x4, x5, x6, x7 = self.to_tensor(x0, x1, x2, x3, x4, x5, x6, x7) 111 | x0 = self.net(x0) 112 | x1 = self.net(x1) 113 | x2 = self.net(x2) 114 | x3 = self.net(x3) 115 | x4 = self.net(x4) 116 | x5 = self.net(x5) 117 | x6 = self.net(x6) 118 | x7 = self.net(x7) 119 | x0, x1, x2, x3, x4, x5, x6, x7 = self.to_numpy(x0, x1, x2, x3, x4, x5, x6, x7) 120 | x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8_reverse(x0, x1, x2, x3, x4, x5, x6, x7) 121 | x0, x1, x2, x3, x4, x5, x6, x7 = self.to_tensor(x0, x1, x2, x3, x4, x5, x6, x7) 122 | return (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8 123 | 124 | 125 | class SelfEnsembleWrapper(nn.Cell): 126 | """ 127 | because of [::-1] operator error, use "SelfEnsembleWrapperNumpy" instead 128 | """ 129 | def __init__(self, net): 130 | super(SelfEnsembleWrapper, self).__init__() 131 | self.net = net 132 | 133 | def hflip(self, x): 134 | raise NotImplementedError("https://gitee.com/mindspore/mindspore/issues/I41ONQ?from=project-issue") 135 | 136 | def vflip(self, x): 137 | raise NotImplementedError("https://gitee.com/mindspore/mindspore/issues/I41ONQ?from=project-issue") 138 | 139 | def trnsps(self, x): 140
| return x.transpose(0, 1, 3, 2) 141 | 142 | def aug_x8(self, x): 143 | """ 144 | do x8 augments for input image 145 | """ 146 | # hflip 147 | hx = self.hflip(x) 148 | # vflip 149 | vx = self.vflip(x) 150 | vhx = self.vflip(hx) 151 | # trnsps 152 | tx = self.trnsps(x) 153 | thx = self.trnsps(hx) 154 | tvx = self.trnsps(vx) 155 | tvhx = self.trnsps(vhx) 156 | return x, hx, vx, vhx, tx, thx, tvx, tvhx 157 | 158 | def aug_x8_reverse(self, x, hx, vx, vhx, tx, thx, tvx, tvhx): 159 | """ 160 | undo x8 augments for input images 161 | """ 162 | # trnsps 163 | tvhx = self.trnsps(tvhx) 164 | tvx = self.trnsps(tvx) 165 | thx = self.trnsps(thx) 166 | tx = self.trnsps(tx) 167 | # vflip 168 | tvhx = self.vflip(tvhx) 169 | tvx = self.vflip(tvx) 170 | vhx = self.vflip(vhx) 171 | vx = self.vflip(vx) 172 | # hflip 173 | tvhx = self.hflip(tvhx) 174 | thx = self.hflip(thx) 175 | vhx = self.hflip(vhx) 176 | hx = self.hflip(hx) 177 | return x, hx, vx, vhx, tx, thx, tvx, tvhx 178 | 179 | def construct(self, x): 180 | """ 181 | do x8 aug, run network, undo x8 aug, calculate the mean of the 8 outputs 182 | """ 183 | x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8(x) 184 | x0 = self.net(x0) 185 | x1 = self.net(x1) 186 | x2 = self.net(x2) 187 | x3 = self.net(x3) 188 | x4 = self.net(x4) 189 | x5 = self.net(x5) 190 | x6 = self.net(x6) 191 | x7 = self.net(x7) 192 | x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8_reverse(x0, x1, x2, x3, x4, x5, x6, x7) 193 | return (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8 194 | 195 | 196 | class Quantizer(nn.Cell): 197 | """ 198 | clip to [0.0, 255.0], round to int 199 | """ 200 | def __init__(self, _min=0.0, _max=255.0): 201 | super(Quantizer, self).__init__() 202 | self.round = ops.Round() 203 | self._min = _min 204 | self._max = _max 205 | 206 | def construct(self, x): 207 | x = ops.clip_by_value(x, self._min, self._max) 208 | x = self.round(x) 209 | return x 210 | 211 | 212 | class TensorSyncer(nn.Cell): 213 | """ 214 | sync metric values across all MindSpore processes 215 | """ 216 | def __init__(self, _type="sum"): 217 | super(TensorSyncer, self).__init__() 218 | self._type = _type.lower() 219 | if self._type == "sum": 220 | self.ops = ops.AllReduce(ReduceOp.SUM) 221 | elif self._type == "gather": 222 | self.ops = ops.AllGather() 223 | else: 224 | raise ValueError(f"TensorSyncer._type == {self._type} is not supported") 225 | 226 | def construct(self, x): 227 | return self.ops(x) 228 | 229 | 230 | class _DistMetric(nn.Metric): 231 | """ 232 | gather data from all ranks when eval(True) 233 | _type(str): choice from ["avg", "sum"].
234 | """ 235 | def __init__(self, _type): 236 | super(_DistMetric, self).__init__() 237 | self._type = _type.lower() 238 | self.all_reduce_sum = None 239 | if get_device_num is not None and get_device_num() > 1: 240 | self.all_reduce_sum = TensorSyncer(_type="sum") 241 | self.clear() 242 | 243 | def _accumulate(self, value): 244 | if isinstance(value, (list, tuple)): 245 | self._acc_value += sum(value) 246 | self._count += len(value) 247 | else: 248 | self._acc_value += value 249 | self._count += 1 250 | 251 | def clear(self): 252 | self._acc_value = 0.0 253 | self._count = 0 254 | 255 | def eval(self, sync=True): 256 | """ 257 | sync: True, return metric value merged from all mindspore-processes 258 | sync: False, return metric value in this single mindspore-processes 259 | """ 260 | if self._count == 0: 261 | raise RuntimeError('self._count == 0') 262 | if self.sum is not None and sync: 263 | data = Tensor([self._acc_value, self._count], mstype.float32) 264 | data = self.all_reduce_sum(data) 265 | acc_value, count = self._convert_data(data).tolist() 266 | else: 267 | acc_value, count = self._acc_value, self._count 268 | if self._type == "avg": 269 | return acc_value / count 270 | if self._type == "sum": 271 | return acc_value 272 | raise RuntimeError(f"_DistMetric._type={self._type} is not support") 273 | 274 | 275 | class PSNR(_DistMetric): 276 | """ 277 | Define PSNR metric for SR network. 278 | """ 279 | def __init__(self, rgb_range, shave): 280 | super(PSNR, self).__init__(_type="avg") 281 | self.shave = shave 282 | self.rgb_range = rgb_range 283 | self.quantize = Quantizer(0.0, 255.0) 284 | 285 | def update(self, *inputs): 286 | """ 287 | update psnr 288 | """ 289 | if len(inputs) != 2: 290 | raise ValueError('PSNR need 2 inputs (sr, hr), but got {}'.format(len(inputs))) 291 | sr, hr = inputs 292 | sr = self.quantize(sr) 293 | diff = (sr - hr) / self.rgb_range 294 | valid = diff 295 | if self.shave is not None and self.shave != 0: 296 | valid = valid[..., self.shave:(-self.shave), self.shave:(-self.shave)] 297 | mse_list = (valid ** 2).mean(axis=(1, 2, 3)) 298 | mse_list = self._convert_data(mse_list).tolist() 299 | psnr_list = [float(1e32) if mse == 0 else(- 10.0 * math.log10(mse)) for mse in mse_list] 300 | self._accumulate(psnr_list) 301 | 302 | 303 | class SaveSrHr(_DistMetric): 304 | """ 305 | help to save sr and hr 306 | """ 307 | def __init__(self, save_dir): 308 | super(SaveSrHr, self).__init__(_type="sum") 309 | self.save_dir = save_dir 310 | self.quantize = Quantizer(0.0, 255.0) 311 | self.rank_id = 0 if get_rank_id is None else get_rank_id() 312 | self.device_num = 1 if get_device_num is None else get_device_num() 313 | 314 | def update(self, *inputs): 315 | """ 316 | update images to save 317 | """ 318 | if len(inputs) != 2: 319 | raise ValueError('SaveSrHr need 2 inputs (sr, hr), but got {}'.format(len(inputs))) 320 | sr, hr = inputs 321 | sr = self.quantize(sr) 322 | sr = self._convert_data(sr).astype(np.uint8) 323 | hr = self._convert_data(hr).astype(np.uint8) 324 | for s, h in zip(sr.transpose(0, 2, 3, 1), hr.transpose(0, 2, 3, 1)): 325 | idx = self._count * self.device_num + self.rank_id 326 | sr_path = os.path.join(self.save_dir, f"{idx:0>4}_sr.png") 327 | Image.fromarray(s).save(sr_path) 328 | hr_path = os.path.join(self.save_dir, f"{idx:0>4}_hr.png") 329 | Image.fromarray(h).save(hr_path) 330 | self._accumulate(1) 331 | -------------------------------------------------------------------------------- /MindSpore version/src/metrics.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2021 Huawei Technologies Co., Ltd 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """metrics""" 16 | import math 17 | import numpy as np 18 | import cv2 19 | 20 | 21 | def quantize(img, rgb_range): 22 | """quantize image range to 0-255""" 23 | pixel_range = 255 / rgb_range 24 | img = np.multiply(img, pixel_range) 25 | img = np.clip(img, 0, 255) 26 | img = np.round(img) / pixel_range 27 | return img 28 | 29 | 30 | def calc_psnr(sr, hr, scale, rgb_range): 31 | """calculate psnr""" 32 | hr = np.float32(hr) 33 | sr = np.float32(sr) 34 | diff = (sr - hr) / rgb_range 35 | gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256 36 | diff = np.multiply(diff, gray_coeffs).sum(1) 37 | if hr.size == 1: 38 | return 0 39 | 40 | shave = scale 41 | valid = diff[..., shave:-shave, shave:-shave] 42 | mse = np.mean(pow(valid, 2)) 43 | return -10 * math.log10(mse) 44 | 45 | 46 | def rgb2ycbcr(img, y_only=True): 47 | """from rgb space to ycbcr space""" 48 | img = img.astype(np.float32)  # astype returns a copy, so the result must be assigned 49 | if y_only: 50 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 51 | return rlt 52 | 53 | 54 | def calc_ssim(img1, img2, scale): 55 | """calculate ssim""" 56 | def ssim(img1, img2): 57 | """calculate ssim""" 58 | C1 = (0.01 * 255) ** 2 59 | C2 = (0.03 * 255) ** 2 60 | 61 | img1 = img1.astype(np.float64) 62 | img2 = img2.astype(np.float64) 63 | kernel = cv2.getGaussianKernel(11, 1.5) 64 | window = np.outer(kernel, kernel.transpose()) 65 | 66 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 67 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 68 | mu1_sq = mu1 ** 2 69 | mu2_sq = mu2 ** 2 70 | mu1_mu2 = mu1 * mu2 71 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 72 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 73 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 74 | 75 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 76 | (sigma1_sq + sigma2_sq + C2)) 77 | return ssim_map.mean() 78 | 79 | border = scale 80 | img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 256.0 + 16.0 81 | img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 256.0 + 16.0 82 | if not img1.shape == img2.shape: 83 | raise ValueError('Input images must have the same dimensions.') 84 | h, w = img1.shape[:2] 85 | img1_y = img1_y[border:h - border, border:w - border] 86 | img2_y = img2_y[border:h - border, border:w - border] 87 | 88 | if img1_y.ndim == 2: 89 | return ssim(img1_y, img2_y) 90 | if img1.ndim == 3: 91 | if img1.shape[2] == 3: 92 | ssims = [] 93 | for i in range(3): 94 | ssims.append(ssim(img1[:, :, i], img2[:, :, i]))  # per-channel SSIM, averaged below 95 | 96 | return np.array(ssims).mean() 97 | if img1.shape[2] == 1: 98 | return ssim(np.squeeze(img1), np.squeeze(img2)) 99 | 100 | raise ValueError('Wrong input image dimensions.') 101 |
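A quick way to sanity-check the NumPy metrics above is to run them on dummy data. The snippet below is an illustrative sketch only; the `from src.metrics import ...` path, the x4 scale, and the array shapes are assumptions, not repo defaults:

```python
import numpy as np

# Assumed import path for the helpers defined in src/metrics.py above.
from src.metrics import quantize, calc_psnr, calc_ssim

rng = np.random.default_rng(0)
hr = rng.uniform(0, 255, (1, 3, 48, 48)).astype(np.float32)   # NCHW batch, rgb_range=255
sr = np.clip(hr + rng.normal(0, 5, hr.shape), 0, 255).astype(np.float32)

sr = quantize(sr, rgb_range=255)                    # snap predictions to the 0-255 grid
print(calc_psnr(sr, hr, scale=4, rgb_range=255))    # PSNR in dB on the Y channel, `scale` border pixels shaved
print(calc_ssim(sr[0].transpose(1, 2, 0),           # calc_ssim expects a single HWC image
                hr[0].transpose(1, 2, 0), scale=4))
```

Note that `calc_psnr` takes `math.log10` of the MSE, so feeding it two identical images raises a math domain error; keep some noise in any test input.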
-------------------------------------------------------------------------------- /MindSpore version/src/rcan_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Huawei Technologies Co., Ltd 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """rcan""" 16 | import math 17 | import mindspore.ops as ops 18 | import mindspore.nn as nn 19 | import mindspore.common.dtype as mstype 20 | from mindspore.ops import operations as P 21 | from mindspore.common import Tensor, Parameter 22 | 23 | 24 | def default_conv(in_channels, out_channels, kernel_size, has_bias=True): 25 | """rcan""" 26 | return nn.Conv2d( 27 | in_channels, out_channels, kernel_size, 28 | padding=(kernel_size // 2), has_bias=has_bias, pad_mode='pad') 29 | 30 | 31 | class MeanShift(nn.Conv2d): 32 | """rcan""" 33 | def __init__(self, 34 | rgb_range, 35 | rgb_mean=(0.4488, 0.4371, 0.4040), 36 | rgb_std=(1.0, 1.0, 1.0), 37 | sign=-1): 38 | """rcan""" 39 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 40 | self.reshape = P.Reshape() 41 | self.eye = P.Eye() 42 | std = Tensor(rgb_std, mstype.float32) 43 | self.weight.set_data( 44 | self.reshape(self.eye(3, 3, mstype.float32), (3, 3, 1, 1)) / self.reshape(std, (3, 1, 1, 1))) 45 | self.weight.requires_grad = False 46 | self.bias = Parameter( 47 | sign * rgb_range * Tensor(rgb_mean, mstype.float32) / std, name='bias', requires_grad=False) 48 | self.has_bias = True 49 | 50 | 51 | def _pixelsf_(x, scale): 52 | """rcan""" 53 | n, c, ih, iw = x.shape 54 | oh = ih * scale 55 | ow = iw * scale 56 | oc = c // (scale ** 2) 57 | output = P.Transpose()(x, (0, 2, 1, 3)) 58 | output = P.Reshape()(output, (n, ih, oc * scale, scale, iw)) 59 | output = P.Transpose()(output, (0, 1, 2, 4, 3)) 60 | output = P.Reshape()(output, (n, ih, oc, scale, ow)) 61 | output = P.Transpose()(output, (0, 2, 1, 3, 4)) 62 | output = P.Reshape()(output, (n, oc, oh, ow)) 63 | return output 64 | 65 | 66 | class SmallUpSampler(nn.Cell): 67 | """rcan""" 68 | def __init__(self, conv, upsize, n_feats, has_bias=True): 69 | """rcan""" 70 | super(SmallUpSampler, self).__init__() 71 | self.conv = conv(n_feats, upsize * upsize * n_feats, 3, has_bias) 72 | self.reshape = P.Reshape() 73 | self.upsize = upsize 74 | self.pixelsf = _pixelsf_ 75 | 76 | def construct(self, x): 77 | """rcan""" 78 | x = self.conv(x) 79 | output = self.pixelsf(x, self.upsize) 80 | return output 81 | 82 | 83 | class Upsampler(nn.Cell): 84 | """rcan""" 85 | def __init__(self, conv, scale, n_feats, has_bias=True): 86 | """rcan""" 87 | super(Upsampler, self).__init__() 88 | m = [] 89 | if (scale & (scale - 1)) == 0: 90 | for _ in range(int(math.log(scale, 2))): 91 | m.append(SmallUpSampler(conv, 2, n_feats, has_bias=has_bias)) 92 | elif scale == 3: 93 | m.append(SmallUpSampler(conv, 3, n_feats, has_bias=has_bias)) 94 | self.net = nn.SequentialCell(m) 95 | 96 | def construct(self, x): 
97 | """rcan""" 98 | return self.net(x) 99 | 100 | 101 | class AdaptiveAvgPool2d(nn.Cell): 102 | """rcan""" 103 | def __init__(self): 104 | """rcan""" 105 | super().__init__() 106 | self.ReduceMean = ops.ReduceMean(keep_dims=True) 107 | 108 | def construct(self, x): 109 | """rcan""" 110 | return self.ReduceMean(x, 0) 111 | 112 | 113 | class CALayer(nn.Cell): 114 | """rcan""" 115 | def __init__(self, channel, reduction=16): 116 | """rcan""" 117 | super(CALayer, self).__init__() 118 | # global average pooling: feature --> point 119 | self.avg_pool = AdaptiveAvgPool2d() 120 | # feature channel downscale and upscale --> channel weight 121 | self.conv_du = nn.SequentialCell([ 122 | nn.Conv2d(channel, channel // reduction, 1, padding=0, has_bias=True, pad_mode='pad'), 123 | nn.ReLU(), 124 | nn.Conv2d(channel // reduction, channel, 1, padding=0, has_bias=True, pad_mode='pad'), 125 | nn.Sigmoid() 126 | ]) 127 | 128 | def construct(self, x): 129 | """rcan""" 130 | y = self.avg_pool(x) 131 | y = self.conv_du(y) 132 | return x * y 133 | 134 | 135 | class RCAB(nn.Cell): 136 | """rcan""" 137 | def __init__(self, conv, n_feat, kernel_size, reduction, has_bias=True 138 | , bn=False, act=nn.ReLU(), res_scale=1): 139 | """rcan""" 140 | super(RCAB, self).__init__() 141 | self.modules_body = [] 142 | for i in range(2): 143 | self.modules_body.append(conv(n_feat, n_feat, kernel_size, has_bias=has_bias)) 144 | if bn: self.modules_body.append(nn.BatchNorm2d(n_feat)) 145 | if i == 0: self.modules_body.append(act) 146 | self.modules_body.append(CALayer(n_feat, reduction)) 147 | self.body = nn.SequentialCell(*self.modules_body) 148 | self.res_scale = res_scale 149 | 150 | def construct(self, x): 151 | """rcan""" 152 | res = self.body(x) 153 | res += x 154 | return res 155 | 156 | 157 | class ResidualGroup(nn.Cell): 158 | """rcan""" 159 | def __init__(self, conv, n_feat, kernel_size, reduction, n_resblocks): 160 | """rcan""" 161 | super(ResidualGroup, self).__init__() 162 | modules_body = [] 163 | modules_body = [ 164 | RCAB( 165 | conv, n_feat, kernel_size, reduction, has_bias=True, bn=False, act=nn.ReLU(), res_scale=1) \ 166 | for _ in range(n_resblocks)] 167 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 168 | self.body = nn.SequentialCell(*modules_body) 169 | 170 | def construct(self, x): 171 | """rcan""" 172 | res = self.body(x) 173 | res += x 174 | return res 175 | 176 | 177 | class RCAN(nn.Cell): 178 | """rcan""" 179 | def __init__(self, args, conv=default_conv): 180 | """rcan""" 181 | super(RCAN, self).__init__() 182 | 183 | n_resgroups = args.n_resgroups 184 | n_resblocks = args.n_resblocks 185 | n_feats = args.n_feats 186 | kernel_size = 3 187 | reduction = args.reduction 188 | idx = args.task_id 189 | scale = args.scale[idx] 190 | self.dytpe = mstype.float16 191 | 192 | # RGB mean for DIV2K 193 | rgb_mean = (0.4488, 0.4371, 0.4040) 194 | rgb_std = (1.0, 1.0, 1.0) 195 | 196 | self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std).to_float(self.dytpe) 197 | 198 | # define head module 199 | modules_head = conv(args.n_colors, n_feats, kernel_size).to_float(self.dytpe) 200 | 201 | # define body module 202 | modules_body = [ 203 | ResidualGroup( 204 | conv, n_feats, kernel_size, reduction, n_resblocks=n_resblocks).to_float(self.dytpe) \ 205 | for _ in range(n_resgroups)] 206 | 207 | modules_body.append(conv(n_feats, n_feats, kernel_size).to_float(self.dytpe)) 208 | 209 | # define tail module 210 | modules_tail = [ 211 | Upsampler(conv, scale, n_feats).to_float(self.dytpe), 212 | conv(n_feats, 
args.n_colors, kernel_size).to_float(self.dtype)] 213 | 214 | self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1).to_float(self.dtype) 215 | 216 | self.head = modules_head 217 | self.body = nn.SequentialCell(modules_body) 218 | self.tail = nn.SequentialCell(modules_tail) 219 | 220 | def construct(self, x): 221 | """rcan""" 222 | x = self.sub_mean(x) 223 | x = self.head(x) 224 | res = self.body(x) 225 | res += x 226 | x = self.tail(res) 227 | x = self.add_mean(x) 228 | return x 229 | -------------------------------------------------------------------------------- /MindSpore version/src/vgg_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huawei Technologies Co., Ltd 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """ 16 | Image classification. 17 | """ 18 | import math 19 | import mindspore.nn as nn 20 | import mindspore.common.dtype as mstype 21 | from mindspore.common import initializer as init 22 | from mindspore.common.initializer import initializer 23 | from utils.var_init import default_recurisive_init, KaimingNormal 24 | 25 | 26 | def _make_layer(base, args, batch_norm): 27 | """Make stage network of VGG.""" 28 | layers = [] 29 | in_channels = 3 30 | for v in base: 31 | if v == 'M': 32 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 33 | else: 34 | weight = 'ones' 35 | if args.initialize_mode == "XavierUniform": 36 | weight_shape = (v, in_channels, 3, 3) 37 | weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32) 38 | 39 | conv2d = nn.Conv2d(in_channels=in_channels, 40 | out_channels=v, 41 | kernel_size=3, 42 | padding=args.padding, 43 | pad_mode=args.pad_mode, 44 | has_bias=args.has_bias, 45 | weight_init=weight) 46 | if batch_norm: 47 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()] 48 | else: 49 | layers += [conv2d, nn.ReLU()] 50 | in_channels = v 51 | return nn.SequentialCell(layers) 52 | 53 | 54 | class Vgg(nn.Cell): 55 | """ 56 | VGG network definition. 57 | 58 | Args: 59 | base (list): Configuration for different layers, mainly the channel number of Conv layer. 60 | num_classes (int): Class numbers. Default: 1000. 61 | batch_norm (bool): Whether to do the batchnorm. Default: False. 62 | batch_size (int): Batch size. Default: 1. 63 | include_top(bool): Whether to include the 3 fully-connected layers at the top of the network. Default: True. 64 | 65 | Returns: 66 | Tensor, the inference output tensor.
67 | 68 | Examples: 69 | >>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 70 | >>> num_classes=1000, batch_norm=False, batch_size=1) 71 | """ 72 | 73 | def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train", 74 | include_top=True): 75 | super(Vgg, self).__init__() 76 | _ = batch_size 77 | self.layers = _make_layer(base, args, batch_norm=batch_norm) 78 | self.include_top = include_top 79 | self.flatten = nn.Flatten() 80 | dropout_ratio = 0.5 81 | # if not args.has_dropout or phase == "test": 82 | # dropout_ratio = 1.0 83 | self.classifier = nn.SequentialCell([ 84 | nn.Dense(512 * 7 * 7, 4096), 85 | nn.ReLU(), 86 | nn.Dropout(dropout_ratio), 87 | nn.Dense(4096, 4096), 88 | nn.ReLU(), 89 | nn.Dropout(dropout_ratio), 90 | nn.Dense(4096, num_classes)]) 91 | # if args.initialize_mode == "KaimingNormal": 92 | # default_recurisive_init(self) 93 | # self.custom_init_weight() 94 | 95 | def construct(self, x): 96 | x = self.layers(x) 97 | if self.include_top: 98 | x = self.flatten(x) 99 | x = self.classifier(x) 100 | return x 101 | 102 | def custom_init_weight(self): 103 | """ 104 | Init the weight of Conv2d and Dense in the net. 105 | """ 106 | for _, cell in self.cells_and_names(): 107 | if isinstance(cell, nn.Conv2d): 108 | cell.weight.set_data(init.initializer( 109 | KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'), 110 | cell.weight.shape, cell.weight.dtype)) 111 | if cell.bias is not None: 112 | cell.bias.set_data(init.initializer( 113 | 'zeros', cell.bias.shape, cell.bias.dtype)) 114 | elif isinstance(cell, nn.Dense): 115 | cell.weight.set_data(init.initializer( 116 | init.Normal(0.01), cell.weight.shape, cell.weight.dtype)) 117 | if cell.bias is not None: 118 | cell.bias.set_data(init.initializer( 119 | 'zeros', cell.bias.shape, cell.bias.dtype)) 120 | 121 | 122 | cfg = { 123 | '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 124 | '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 125 | '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 126 | '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 127 | } 128 | 129 | 130 | def vgg16(num_classes=1000, args=None, phase="train", **kwargs): 131 | """ 132 | Get Vgg16 neural network with batch normalization. 133 | 134 | Args: 135 | num_classes (int): Class numbers. Default: 1000. 136 | args(namespace): param for net init. 137 | phase(str): train or test mode. 138 | 139 | Returns: 140 | Cell, cell instance of Vgg16 neural network with batch normalization. 141 | 142 | Examples: 143 | >>> vgg16(num_classes=1000, args=args, **kwargs) 144 | """ 145 | 146 | if args is None: 147 | from .config import cifar_cfg 148 | args = cifar_cfg 149 | net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase, **kwargs) 150 | return net 151 | -------------------------------------------------------------------------------- /MindSpore version/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Huawei Technologies Co., Ltd 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """train""" 16 | import os 17 | import time 18 | from mindspore import context 19 | from mindspore.context import ParallelMode 20 | import mindspore.dataset as ds 21 | import mindspore.nn as nn 22 | from mindspore.train.serialization import load_checkpoint, load_param_into_net 23 | from mindspore.communication.management import init 24 | from mindspore.common import set_seed 25 | from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Callback 26 | from mindspore.train.model import Model 27 | from mindspore.train.loss_scale_manager import DynamicLossScaleManager 28 | from src.args import args 29 | from src.data.div2k import DIV2K 30 | from src.data.srdata import SRData 31 | from src.rcan_model import RCAN 32 | # from src.edsr_model import EDSR 33 | from src.edsr_slim import EDSR 34 | from src.metric import PSNR 35 | from eval import do_eval 36 | 37 | # def do_eval(eval_network, ds_val, metrics, rank_id, cur_epoch=None): 38 | # """ 39 | # do eval for psnr and save hr, sr 40 | # """ 41 | # eval_network.set_train(False) 42 | # total_step = ds_val.get_dataset_size() 43 | # setw = len(str(total_step)) 44 | # begin = time.time() 45 | # step_begin = time.time() 46 | # # rank_id = get_rank_id() 47 | # ds_val = ds_val.create_dict_iterator(output_numpy=True) 48 | # for i, (lr, hr) in enumerate(ds_val): 49 | # sr = eval_network(lr) 50 | # _ = [m.update(sr, hr) for m in metrics.values()] 51 | # result = {k: m.eval(sync=False) for k, m in metrics.items()} 52 | # result["time"] = time.time() - step_begin 53 | # step_begin = time.time() 54 | # print(f"[{i+1:>{setw}}/{total_step:>{setw}}] rank = {rank_id} result = {result}", flush=True) 55 | # result = {k: m.eval(sync=True) for k, m in metrics.items()} 56 | # result["time"] = time.time() - begin 57 | # if cur_epoch is not None: 58 | # result["epoch"] = cur_epoch 59 | # if rank_id == 0: 60 | # print(f"evaluation result = {result}", flush=True) 61 | # eval_network.set_train(True) 62 | # return result 63 | 64 | class EvalCallBack(Callback): 65 | """ 66 | eval callback 67 | """ 68 | def __init__(self, eval_network, ds_val, eval_epoch_frq, epoch_size, metrics, rank_id, result_evaluation=None): 69 | self.eval_network = eval_network 70 | self.ds_val = ds_val 71 | self.eval_epoch_frq = eval_epoch_frq 72 | self.epoch_size = epoch_size 73 | self.result_evaluation = result_evaluation 74 | self.metrics = metrics 75 | self.best_result = 0 76 | self.rank_id = rank_id 77 | self.eval_network.set_train(False) 78 | 79 | def epoch_end(self, run_context): 80 | """ 81 | do eval in epoch end 82 | """ 83 | cb_param = run_context.original_args() 84 | cur_epoch = cb_param.cur_epoch_num 85 | if cur_epoch % self.eval_epoch_frq == 0 or cur_epoch == self.epoch_size: 86 | # result = do_eval(self.eval_network, self.ds_val, self.metrics, self.rank_id, cur_epoch=cur_epoch) 87 | result = do_eval(self.ds_val, self.eval_network) 88 | if self.best_result is None or self.best_result < result: 89 | self.best_result = result 90 
| if self.rank_id == 0: 91 | print(f"best evaluation result = {self.best_result}", flush=True) 92 | if isinstance(self.result_evaluation, dict): 93 | for k, v in result.items(): 94 | r_list = self.result_evaluation.get(k) 95 | if r_list is None: 96 | r_list = [] 97 | self.result_evaluation[k] = r_list 98 | r_list.append(v) 99 | 100 | 101 | def train(): 102 | """train""" 103 | set_seed(1) 104 | device_id = int(os.getenv('DEVICE_ID', '0')) 105 | rank_id = int(os.getenv('RANK_ID', '0')) 106 | device_num = int(os.getenv('RANK_SIZE', '1')) 107 | # context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False, device_id=device_id) 108 | context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False, device_id=device_id) 109 | 110 | if device_num > 1: 111 | init() 112 | context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, 113 | device_num=device_num, global_rank=device_id, 114 | gradients_mean=True) 115 | if args.modelArts_mode: 116 | import moxing as mox 117 | local_data_url = '/cache/data' 118 | mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_url) 119 | 120 | train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False) 121 | train_dataset.set_scale(args.task_id) 122 | print(len(train_dataset)) 123 | train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num, 124 | shard_id=rank_id, shuffle=True) 125 | train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True) 126 | 127 | eval_dataset = SRData(args, name=args.data_test, train=False, benchmark=True) 128 | print(len(eval_dataset)) 129 | eval_ds = ds.GeneratorDataset(eval_dataset, ['LR', 'HR'], shuffle=False) 130 | eval_ds = eval_ds.batch(1, drop_remainder=True) 131 | 132 | # net_m = RCAN(args) 133 | net_m = EDSR(args) 134 | print("Init net weights successfully") 135 | 136 | if args.ckpt_path: 137 | param_dict = load_checkpoint(args.ckpt_path) 138 | load_param_into_net(net_m, param_dict) 139 | print("Load net weights successfully") 140 | step_size = train_de_dataset.get_dataset_size() 141 | lr = [] 142 | for i in range(0, args.epochs): 143 | cur_lr = args.lr / (2 ** ((i + 1) // 200)) 144 | lr.extend([cur_lr] * step_size) 145 | opt = nn.Adam(net_m.trainable_params(), learning_rate=lr, loss_scale=args.loss_scale) 146 | loss = nn.L1Loss() 147 | loss_scale_manager = DynamicLossScaleManager(init_loss_scale=args.init_loss_scale, \ 148 | scale_factor=2, scale_window=1000) 149 | 150 | eval_net = net_m 151 | model = Model(net_m, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager) 152 | 153 | time_cb = TimeMonitor(data_size=step_size) 154 | loss_cb = LossMonitor() 155 | metrics = { 156 | "psnr": PSNR(rgb_range=args.rgb_range, shave=True), 157 | } 158 | eval_cb = EvalCallBack(eval_net, eval_ds, args.test_every, step_size/args.batch_size, metrics=metrics, rank_id=rank_id) 159 | cb = [time_cb, loss_cb, eval_cb] 160 | config_ck = CheckpointConfig(save_checkpoint_steps=args.ckpt_save_interval * step_size, 161 | keep_checkpoint_max=args.ckpt_save_max) 162 | ckpt_cb = ModelCheckpoint(prefix=args.filename, directory=args.ckpt_save_path, config=config_ck) 163 | if device_id == 0: 164 | cb += [ckpt_cb] 165 | model.train(args.epochs, train_de_dataset, callbacks=cb, dataset_sink_mode=True) 166 | 167 | 168 | if __name__ == "__main__": 169 | time_start = time.time() 170 | train() 171 | time_end = time.time() 172 | print('train_time: %f' % (time_end - time_start)) 173 |
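One detail of train() worth spelling out is the learning-rate list it builds: the base rate is halved every 200 epochs (the `(i + 1) // 200` term steps at epoch indices 199, 399, ...), and each per-epoch value is repeated `step_size` times so the list holds one entry per training step. A standalone sketch of the schedule, with 1e-4 as a hypothetical stand-in for `args.lr`:

```python
# Step-decay schedule as built in train(): the base LR is halved every 200 epochs.
base_lr = 1e-4  # hypothetical value; train.py reads this from args.lr

for epoch in (0, 198, 199, 398, 399):
    cur_lr = base_lr / (2 ** ((epoch + 1) // 200))
    print(f"epoch {epoch:>3}: lr = {cur_lr:.2e}")

# epoch indices   0-198 -> 1.00e-04
# epoch indices 199-398 -> 5.00e-05
# epoch indices 399-598 -> 2.50e-05
```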
-------------------------------------------------------------------------------- /MindSpore version/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/MindSpore version/utils/__init__.py -------------------------------------------------------------------------------- /MindSpore version/utils/var_init.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Huawei Technologies Co., Ltd 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """ 16 | Initialize. 17 | """ 18 | import math 19 | from functools import reduce 20 | import numpy as np 21 | import mindspore.nn as nn 22 | from mindspore.common import initializer as init 23 | 24 | def _calculate_gain(nonlinearity, param=None): 25 | r""" 26 | Return the recommended gain value for the given nonlinearity function. 27 | 28 | The values are as follows: 29 | ================= ==================================================== 30 | nonlinearity gain 31 | ================= ==================================================== 32 | Linear / Identity :math:`1` 33 | Conv{1,2,3}D :math:`1` 34 | Sigmoid :math:`1` 35 | Tanh :math:`\frac{5}{3}` 36 | ReLU :math:`\sqrt{2}` 37 | Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` 38 | ================= ==================================================== 39 | 40 | Args: 41 | nonlinearity: the non-linear function 42 | param: optional parameter for the non-linear function 43 | 44 | Examples: 45 | >>> gain = _calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 46 | """ 47 | linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] 48 | if nonlinearity in linear_fns or nonlinearity == 'sigmoid': 49 | return 1 50 | if nonlinearity == 'tanh': 51 | return 5.0 / 3 52 | if nonlinearity == 'relu': 53 | return math.sqrt(2.0) 54 | if nonlinearity == 'leaky_relu': 55 | if param is None: 56 | negative_slope = 0.01 57 | elif not isinstance(param, bool) and (isinstance(param, int) or isinstance(param, float)): 58 | negative_slope = param 59 | else: 60 | raise ValueError("negative_slope {} not a valid number".format(param)) 61 | return math.sqrt(2.0 / (1 + negative_slope ** 2)) 62 | 63 | raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) 64 | 65 | def _assignment(arr, num): 66 | """Assign the value of `num` to `arr`.""" 67 | if arr.shape == (): 68 | arr = arr.reshape((1,)) 69 | arr[:] = num 70 | arr = arr.reshape(()) 71 | else: 72 | if isinstance(num, np.ndarray): 73 | arr[:] = num[:] 74 | else: 75 | arr[:] = num 76 | return arr 77 | 78 | def _calculate_in_and_out(arr): 79 | """ 80 | Calculate n_in and n_out. 81 | 82 | Args: 83 | arr (Array): Input array.
84 | 85 | Returns: 86 | Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`. 87 | """ 88 | dim = len(arr.shape) 89 | if dim < 2: 90 | raise ValueError("When initializing data with xavier uniform, the dimension of the data must be greater than 1.") 91 | 92 | n_in = arr.shape[1] 93 | n_out = arr.shape[0] 94 | 95 | if dim > 2: 96 | counter = reduce(lambda x, y: x * y, arr.shape[2:]) 97 | n_in *= counter 98 | n_out *= counter 99 | return n_in, n_out 100 | 101 | def _select_fan(array, mode): 102 | mode = mode.lower() 103 | valid_modes = ['fan_in', 'fan_out'] 104 | if mode not in valid_modes: 105 | raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) 106 | 107 | fan_in, fan_out = _calculate_in_and_out(array) 108 | return fan_in if mode == 'fan_in' else fan_out 109 | 110 | class KaimingInit(init.Initializer): 111 | r""" 112 | Base Class. Initialize the array with He kaiming algorithm. 113 | 114 | Args: 115 | a: the negative slope of the rectifier used after this layer (only 116 | used with ``'leaky_relu'``) 117 | mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` 118 | preserves the magnitude of the variance of the weights in the 119 | forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the 120 | backwards pass. 121 | nonlinearity: the non-linear function, recommended to use only with 122 | ``'relu'`` or ``'leaky_relu'`` (default). 123 | """ 124 | def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'): 125 | super(KaimingInit, self).__init__() 126 | self.mode = mode 127 | self.gain = _calculate_gain(nonlinearity, a) 128 | def _initialize(self, arr): 129 | pass 130 | 131 | 132 | class KaimingUniform(KaimingInit): 133 | r""" 134 | Initialize the array with He kaiming uniform algorithm. The resulting tensor will 135 | have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where 136 | 137 | .. math:: 138 | \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} 139 | 140 | Input: 141 | arr (Array): The array to be assigned. 142 | 143 | Returns: 144 | Array, assigned array. 145 | 146 | Examples: 147 | >>> w = np.empty((3, 5)) 148 | >>> KaimingUniform(w, mode='fan_in', nonlinearity='relu') 149 | """ 150 | 151 | def _initialize(self, arr): 152 | fan = _select_fan(arr, self.mode) 153 | bound = math.sqrt(3.0) * self.gain / math.sqrt(fan) 154 | data = np.random.uniform(-bound, bound, arr.shape) 155 | 156 | _assignment(arr, data) 157 | 158 | 159 | class KaimingNormal(KaimingInit): 160 | r""" 161 | Initialize the array with He kaiming normal algorithm. The resulting tensor will 162 | have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where 163 | 164 | .. math:: 165 | \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} 166 | 167 | Input: 168 | arr (Array): The array to be assigned. 169 | 170 | Returns: 171 | Array, assigned array.
172 | 173 | Examples: 174 | >>> w = np.empty((3, 5)) 175 | >>> KaimingNormal(w, mode='fan_out', nonlinearity='relu') 176 | """ 177 | 178 | def _initialize(self, arr): 179 | fan = _select_fan(arr, self.mode) 180 | std = self.gain / math.sqrt(fan) 181 | data = np.random.normal(0, std, arr.shape) 182 | 183 | _assignment(arr, data) 184 | 185 | 186 | def default_recurisive_init(custom_cell): 187 | """default_recurisive_init""" 188 | for _, cell in custom_cell.cells_and_names(): 189 | if isinstance(cell, nn.Conv2d): 190 | cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)), 191 | cell.weight.shape, 192 | cell.weight.dtype)) 193 | if cell.bias is not None: 194 | fan_in, _ = _calculate_in_and_out(cell.weight) 195 | bound = 1 / math.sqrt(fan_in) 196 | cell.bias.set_data(init.initializer(init.Uniform(bound), 197 | cell.bias.shape, 198 | cell.bias.dtype)) 199 | elif isinstance(cell, nn.Dense): 200 | cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)), 201 | cell.weight.shape, 202 | cell.weight.dtype)) 203 | if cell.bias is not None: 204 | fan_in, _ = _calculate_in_and_out(cell.weight) 205 | bound = 1 / math.sqrt(fan_in) 206 | cell.bias.set_data(init.initializer(init.Uniform(bound), 207 | cell.bias.shape, 208 | cell.bias.dtype)) 209 | elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): 210 | pass 211 | -------------------------------------------------------------------------------- /PyTorch version/README.md: -------------------------------------------------------------------------------- 1 | ## Dependencies 2 | 3 | PyTorch 4 | 5 | matplotlib 6 | 7 | imageio 8 | 9 | tensorboardX 10 | 11 | opencv-python 12 | 13 | scipy 14 | 15 | scikit-image 16 | 17 | ## Train 18 | 19 | ### Prepare data 20 | 21 | We use DIV2K training set as our training data. 22 | 23 | About how to download data, you could refer to [EDSR(PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) 24 | 25 | ### Begin to train 26 | 27 | ```bash 28 | # use '--teacher_model' to specify the teacher model 29 | python main.py --model EDSR --scale 4 --reset --dir_data LOCATION_OF_DATA --model_filename edsr_x4_0.25student --pre_train output/model/edsr/ --epochs 400 --model_stat --neg_num 10 --contra_lambda 200 --t_lambda 1 --t_l_remove 400 --contrast_t_detach 30 | ``` 31 | 32 | ## Test 33 | 34 | You could download models of our paper from [BaiduYun](https://pan.baidu.com/s/1gYenkfLac1s19lfzczxtHA) (code: zser). 35 | 36 | 1. Test on benchmarks: 37 | 38 | ```bash 39 | python main.py --scale 4 --pre_train FOLDER_OF_THE_PRETRAINED_MODEL --model_filename edsr_x4_0.25student --test_only --self_ensemble --dir_demo test --model EDSR --dir_data LOCATION_OF_DATA --n_GPUs 1 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --stu_width_mult 0.25 --model_stat --data_test Set5 40 | ``` 41 | 42 | 2.
Test your own image: 43 | 44 | ```bash 45 | python main.py --scale 4 --pre_train output/model --model_filename edsr_x4_0.25student --test_only --self_ensemble --model EDSR --n_GPUs 1 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --stu_width_mult 0.25 --model_stat --data_test Demo --save_results --dir_demo LOCATION_OF_YOUR_IMAGE 46 | ``` 47 | 48 | The output SR images will be saved in './test/result/'. 49 | 50 | -------------------------------------------------------------------------------- /PyTorch version/data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | #from dataloader import MSDataLoader 3 | from torch.utils.data import dataloader 4 | from torch.utils.data import ConcatDataset 5 | 6 | # This is a simple wrapper class for ConcatDataset 7 | class MyConcatDataset(ConcatDataset): 8 | def __init__(self, datasets): 9 | super(MyConcatDataset, self).__init__(datasets) 10 | self.train = datasets[0].train 11 | 12 | def set_scale(self, idx_scale): 13 | for d in self.datasets: 14 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale) 15 | 16 | class Data: 17 | def __init__(self, args): 18 | self.loader_train = None 19 | if not args.test_only: 20 | datasets = [] 21 | for d in args.data_train: 22 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 23 | m = import_module('data.' + module_name.lower()) 24 | datasets.append(getattr(m, module_name)(args, name=d)) 25 | 26 | self.loader_train = dataloader.DataLoader( 27 | # MyConcatDataset(datasets), 28 | datasets[0], 29 | batch_size=args.batch_size, 30 | shuffle=True, 31 | pin_memory=not args.cpu, 32 | num_workers=args.n_threads, 33 | ) 34 | 35 | self.loader_test = [] 36 | for d in args.data_test: 37 | if d in ['Set5', 'Set14', 'B100', 'Urban100']: 38 | m = import_module('data.benchmark') 39 | testset = getattr(m, 'Benchmark')(args, train=False, name=d) 40 | else: 41 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 42 | m = import_module('data.'
+ module_name.lower()) 43 | testset = getattr(m, module_name)(args, train=False, name=d) 44 | 45 | self.loader_test.append( 46 | dataloader.DataLoader( 47 | testset, 48 | batch_size=1, 49 | shuffle=False, 50 | pin_memory=not args.cpu, 51 | num_workers=args.n_threads, 52 | ) 53 | ) 54 | 55 | # if args.data_aux == 'BSD500': 56 | # m = import_module('data.bsd500') 57 | # auxset = getattr(m, 'BSD500')(args, train=True, name='t') 58 | # self.loader_aux = dataloader.DataLoader( 59 | # auxset, 60 | # batch_size=args.neg_num, 61 | # shuffle=True, 62 | # pin_memory=not args.cpu, 63 | # num_workers=args.n_threads, 64 | # ) -------------------------------------------------------------------------------- /PyTorch version/data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Benchmark(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(Benchmark, self).__init__( 14 | args, name=name, train=train, benchmark=True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'benchmark', self.name) 19 | self.dir_hr = os.path.join(self.apath, 'HR') 20 | if self.input_large: 21 | self.dir_lr = os.path.join(self.apath, 'LR_bicubicL') 22 | else: 23 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 24 | self.ext = ('', '.png') 25 | 26 | -------------------------------------------------------------------------------- /PyTorch version/data/bsd500.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | import PIL 5 | import random 6 | 7 | import torch 8 | import torch.utils.data as data 9 | import torchvision.transforms as transforms 10 | 11 | from data import common 12 | 13 | class BSD500(data.Dataset): 14 | def __init__(self, args, name='BSR', train=True): 15 | super(BSD500, self).__init__() 16 | self.args = args 17 | self.name = name 18 | self.train = train 19 | self.scale = args.scale[0] 20 | self.benchmark = False 21 | 22 | if train: 23 | self.dir_data = os.path.join(args.dir_data, 'BSR/BSDS500/data/images/train/') 24 | else: 25 | self.dir_data = os.path.join(args.dir_data, 'BSR/BSDS500/data/images/test/') 26 | self.dataset_from_folder = common.DatasetFromFolder(self.dir_data, '.jpg') 27 | self.cropper = transforms.RandomCrop(size=args.patch_size) 28 | self.resizer = transforms.Resize(size=args.patch_size // args.scale[0], interpolation=PIL.Image.BICUBIC) 29 | self.totensor = transforms.ToTensor() 30 | # @staticmethod 31 | # def to_tensor(patch): 32 | # return torch.Tensor(np.asarray(patch).swapaxes(0, 2)).float() / 255 - 0.5 33 | # 34 | # @staticmethod 35 | # def to_image(tensor): 36 | # if type(tensor) == torch.Tensor: 37 | # tensor = tensor.numpy() 38 | # 39 | # if len(tensor.shape) == 4: 40 | # tensor = tensor.swapaxes(1, 3) 41 | # elif len(tensor.shape) == 3: 42 | # tensor = tensor.swapaxes(0, 2) 43 | # else: 44 | # raise Exception("Predictions have shape not in set {3,4}") 45 | # 46 | # tensor = (tensor + 0.5) * 255 47 | # tensor[tensor > 255] = 255 48 | # tensor[tensor < 0] = 0 49 | # return tensor.round().astype(int) 50 | 51 | def __getitem__(self, index): 52 | img, filename = self.dataset_from_folder[index] 53 | 54 | if self.train: 55 | hr_patch = self.cropper(img) 56 | else: 57 | hr_patch = img 58 | 59 | lr_patch = 
self.resizer(hr_patch) 60 | hr_patch = self.totensor(hr_patch) 61 | lr_patch = self.totensor(lr_patch) 62 | if (not self.args.no_augment) and self.train: 63 | lr_patch, hr_patch = self.data_aug(lr_patch, hr_patch) 64 | 65 | hr_patch = hr_patch.mul_(self.args.rgb_range) 66 | lr_patch = lr_patch.mul_(self.args.rgb_range) 67 | 68 | return lr_patch, hr_patch, filename 69 | 70 | def __len__(self): 71 | return len(self.dataset_from_folder) 72 | 73 | def data_aug(self, lr, hr, hflip=True, rot=True): 74 | hflip = hflip and random.random() < 0.5 75 | vflip = rot and random.random() < 0.5 76 | rot90 = rot and random.random() < 0.5 77 | 78 | if hflip: 79 | lr = torch.flip(lr, [1]) 80 | hr = torch.flip(hr, [1]) 81 | # lr = lr[:, ::-1, :] 82 | # hr = hr[:, ::-1, :] 83 | if vflip: 84 | lr = torch.flip(lr, [2]) 85 | hr = torch.flip(hr, [2]) 86 | # lr = lr[::-1, :, :] 87 | # hr = hr[::-1, :, :] 88 | if rot90: 89 | lr = lr.permute(0, 2, 1) 90 | hr = hr.permute(0, 2, 1) 91 | 92 | return lr, hr -------------------------------------------------------------------------------- /PyTorch version/data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | 4 | import numpy as np 5 | import skimage.color as sc 6 | 7 | import os 8 | import imageio 9 | import numpy as np 10 | import PIL 11 | 12 | import torch 13 | import torch.utils.data as data 14 | 15 | def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): 16 | ih, iw = args[0].shape[:2] 17 | 18 | if not input_large: 19 | p = scale if multi else 1 20 | tp = p * patch_size 21 | ip = tp // scale 22 | else: 23 | tp = patch_size 24 | ip = patch_size 25 | 26 | ix = random.randrange(0, iw - ip + 1) 27 | iy = random.randrange(0, ih - ip + 1) 28 | 29 | if not input_large: 30 | tx, ty = scale * ix, scale * iy 31 | else: 32 | tx, ty = ix, iy 33 | 34 | ret = [ 35 | args[0][iy:iy + ip, ix:ix + ip, :], 36 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] 37 | ] 38 | 39 | return ret 40 | 41 | def set_channel(*args, n_channels=3): 42 | def _set_channel(img): 43 | if img.ndim == 2: 44 | img = np.expand_dims(img, axis=2) 45 | 46 | c = img.shape[2] 47 | if n_channels == 1 and c == 3: 48 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 49 | elif n_channels == 3 and c == 1: 50 | img = np.concatenate([img] * n_channels, 2) 51 | 52 | return img 53 | 54 | return [_set_channel(a) for a in args] 55 | 56 | def np2Tensor(*args, rgb_range=255): 57 | def _np2Tensor(img): 58 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 59 | tensor = torch.from_numpy(np_transpose).float() 60 | tensor.mul_(rgb_range / 255) 61 | 62 | return tensor 63 | 64 | return [_np2Tensor(a) for a in args] 65 | 66 | def augment(*args, hflip=True, rot=True): 67 | hflip = hflip and random.random() < 0.5 68 | vflip = rot and random.random() < 0.5 69 | rot90 = rot and random.random() < 0.5 70 | 71 | def _augment(img): 72 | if hflip: img = img[:, ::-1, :] 73 | if vflip: img = img[::-1, :, :] 74 | if rot90: img = img.transpose(1, 0, 2) 75 | 76 | return img 77 | 78 | return [_augment(a) for a in args] 79 | 80 | def is_image_file(filename, extension): 81 | return filename.endswith(extension) 82 | 83 | class DatasetFromFolder(data.Dataset): 84 | def __init__(self, image_dir, extension = ".png"): 85 | super(DatasetFromFolder, self).__init__() 86 | self.image_dir = image_dir 87 | self.image_filenames = [x for x in sorted(os.listdir(image_dir)) if is_image_file(x, extension)] 88 | self.extension = extension 89 | 90 
| def __getitem__(self, index): 91 | #return imageio.imread(os.path.join(self.image_dir, self.image_filenames[index])), self.image_filenames[index] 92 | filename = self.image_filenames[index] 93 | filename = filename[:-len(self.extension)]  # drop the extension suffix; rstrip() would also strip matching trailing letters 94 | return PIL.Image.fromarray(imageio.imread(os.path.join(self.image_dir, self.image_filenames[index]))), filename 95 | 96 | def __len__(self): 97 | return len(self.image_filenames) -------------------------------------------------------------------------------- /PyTorch version/data/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import numpy as np 6 | import imageio 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Demo(data.Dataset): 12 | def __init__(self, args, name='Demo', train=False, benchmark=False): 13 | self.args = args 14 | self.name = name 15 | self.scale = args.scale 16 | self.idx_scale = 0 17 | self.train = False 18 | self.benchmark = benchmark 19 | 20 | self.filelist = [] 21 | for f in os.listdir(args.dir_demo): 22 | if f.find('.png') >= 0 or f.find('.jp') >= 0: 23 | self.filelist.append(os.path.join(args.dir_demo, f)) 24 | self.filelist.sort() 25 | 26 | def __getitem__(self, idx): 27 | filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0] 28 | lr = imageio.imread(self.filelist[idx]) 29 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 30 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 31 | 32 | return lr_t, -1, filename 33 | 34 | def __len__(self): 35 | return len(self.filelist) 36 | 37 | def set_scale(self, idx_scale): 38 | self.idx_scale = idx_scale 39 | 40 | -------------------------------------------------------------------------------- /PyTorch version/data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DIV2K(srdata.SRData): 5 | def __init__(self, args, name='DIV2K', train=True, benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')] 7 | if train: 8 | data_range = data_range[0] 9 | else: 10 | if args.test_only and len(data_range) == 1: 11 | data_range = data_range[0] 12 | else: 13 | data_range = data_range[1] 14 | 15 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 16 | super(DIV2K, self).__init__( 17 | args, name=name, train=train, benchmark=benchmark 18 | ) 19 | 20 | def _scan(self): 21 | names_hr, names_lr = super(DIV2K, self)._scan() 22 | names_hr = names_hr[self.begin - 1:self.end] 23 | names_lr = [n[self.begin - 1:self.end] for n in names_lr] 24 | 25 | return names_hr, names_lr 26 | 27 | def _set_filesystem(self, dir_data): 28 | super(DIV2K, self)._set_filesystem(dir_data) 29 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 30 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') 31 | if self.input_large: self.dir_lr += 'L' 32 | 33 | -------------------------------------------------------------------------------- /PyTorch version/data/div2kjpeg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | from data import div2k 4 | 5 | class DIV2KJPEG(div2k.DIV2K): 6 | def __init__(self, args, name='', train=True, benchmark=False): 7 | self.q_factor = int(name.replace('DIV2K-Q', '')) 8 | super(DIV2KJPEG, self).__init__( 9 | args, name=name, train=train, benchmark=benchmark 10 | ) 11 | 12 | def _set_filesystem(self, dir_data):
self.apath = os.path.join(dir_data, 'DIV2K') 14 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 15 | self.dir_lr = os.path.join( 16 | self.apath, 'DIV2K_Q{}'.format(self.q_factor) 17 | ) 18 | if self.input_large: self.dir_lr += 'L' 19 | self.ext = ('.png', '.jpg') 20 | 21 | -------------------------------------------------------------------------------- /PyTorch version/data/sr291.py: -------------------------------------------------------------------------------- 1 | from data import srdata 2 | 3 | class SR291(srdata.SRData): 4 | def __init__(self, args, name='SR291', train=True, benchmark=False): 5 | super(SR291, self).__init__(args, name=name) 6 | 7 | -------------------------------------------------------------------------------- /PyTorch version/data/srdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import pickle 5 | 6 | from data import common 7 | 8 | import numpy as np 9 | import imageio 10 | import torch 11 | import torch.utils.data as data 12 | 13 | class SRData(data.Dataset): 14 | def __init__(self, args, name='', train=True, benchmark=False): 15 | self.args = args 16 | self.name = name 17 | self.train = train 18 | self.split = 'train' if train else 'test' 19 | self.do_eval = True 20 | self.benchmark = benchmark 21 | self.input_large = (args.model == 'VDSR') 22 | self.scale = args.scale 23 | self.idx_scale = 0 24 | 25 | self._set_filesystem(args.dir_data) 26 | if args.ext.find('img') < 0: 27 | path_bin = os.path.join(self.apath, 'bin') 28 | os.makedirs(path_bin, exist_ok=True) 29 | 30 | list_hr, list_lr = self._scan() 31 | if args.ext.find('img') >= 0 or benchmark: 32 | self.images_hr, self.images_lr = list_hr, list_lr 33 | elif args.ext.find('sep') >= 0: 34 | os.makedirs( 35 | self.dir_hr.replace(self.apath, path_bin), 36 | exist_ok=True 37 | ) 38 | for s in self.scale: 39 | os.makedirs( 40 | os.path.join( 41 | self.dir_lr.replace(self.apath, path_bin), 42 | 'X{}'.format(s) 43 | ), 44 | exist_ok=True 45 | ) 46 | 47 | self.images_hr, self.images_lr = [], [[] for _ in self.scale] 48 | for h in list_hr: 49 | b = h.replace(self.apath, path_bin) 50 | b = b.replace(self.ext[0], '.pt') 51 | self.images_hr.append(b) 52 | self._check_and_load(args.ext, h, b, verbose=True) 53 | for i, ll in enumerate(list_lr): 54 | for l in ll: 55 | b = l.replace(self.apath, path_bin) 56 | b = b.replace(self.ext[1], '.pt') 57 | self.images_lr[i].append(b) 58 | self._check_and_load(args.ext, l, b, verbose=True) 59 | if train: 60 | n_patches = args.batch_size * args.test_every 61 | n_images = len(args.data_train) * len(self.images_hr) 62 | if n_images == 0: 63 | self.repeat = 0 64 | else: 65 | self.repeat = max(n_patches // n_images, 1) 66 | 67 | # Below functions as used to prepare images 68 | def _scan(self): 69 | names_hr = sorted( 70 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) 71 | ) 72 | names_lr = [[] for _ in self.scale] 73 | for f in names_hr: 74 | filename, _ = os.path.splitext(os.path.basename(f)) 75 | for si, s in enumerate(self.scale): 76 | names_lr[si].append(os.path.join( 77 | self.dir_lr, 'X{}/{}x{}{}'.format( 78 | s, filename, s, self.ext[1] 79 | ) 80 | )) 81 | 82 | return names_hr, names_lr 83 | 84 | def _set_filesystem(self, dir_data): 85 | self.apath = os.path.join(dir_data, self.name) 86 | self.dir_hr = os.path.join(self.apath, 'HR') 87 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 88 | if self.input_large: self.dir_lr += 'L' 89 | self.ext = ('.png', '.png') 90 
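# NOTE: _check_and_load below implements the '--ext sep' cache: each image is
# decoded once and pickled into <dir_data>/<dataset>/bin as a .pt file, so
# later epochs unpickle instead of re-decoding PNGs; an ext containing 'reset'
# (e.g. 'sep_reset') forces the cache to be rebuilt. Separately, in video.py
# further below, the else-branch of __getitem__ calls vidcap.release() on an
# undefined name; the attribute is self.vidcap, so it should read
# self.vidcap.release().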
| 91 | def _check_and_load(self, ext, img, f, verbose=True): 92 | if not os.path.isfile(f) or ext.find('reset') >= 0: 93 | if verbose: 94 | print('Making a binary: {}'.format(f)) 95 | with open(f, 'wb') as _f: 96 | pickle.dump(imageio.imread(img), _f) 97 | 98 | def __getitem__(self, idx): 99 | lr, hr, filename = self._load_file(idx) 100 | pair = self.get_patch(lr, hr) 101 | pair = common.set_channel(*pair, n_channels=self.args.n_colors) 102 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) 103 | 104 | return pair_t[0], pair_t[1], filename 105 | 106 | def __len__(self): 107 | if self.train: 108 | return len(self.images_hr) * self.repeat 109 | else: 110 | return len(self.images_hr) 111 | 112 | def _get_index(self, idx): 113 | if self.train: 114 | return idx % len(self.images_hr) 115 | else: 116 | return idx 117 | 118 | def _load_file(self, idx): 119 | idx = self._get_index(idx) 120 | f_hr = self.images_hr[idx] 121 | f_lr = self.images_lr[self.idx_scale][idx] 122 | 123 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 124 | if self.args.ext == 'img' or self.benchmark: 125 | hr = imageio.imread(f_hr) 126 | lr = imageio.imread(f_lr) 127 | elif self.args.ext.find('sep') >= 0: 128 | with open(f_hr, 'rb') as _f: 129 | hr = pickle.load(_f) 130 | with open(f_lr, 'rb') as _f: 131 | lr = pickle.load(_f) 132 | 133 | return lr, hr, filename 134 | 135 | def get_patch(self, lr, hr): 136 | scale = self.scale[self.idx_scale] 137 | 138 | if self.train: 139 | lr, hr = common.get_patch( 140 | lr, hr, 141 | patch_size=self.args.patch_size, 142 | scale=scale, 143 | multi=(len(self.scale) > 1), 144 | input_large=self.input_large 145 | ) 146 | if not self.args.no_augment: lr, hr = common.augment(lr, hr) 147 | else: 148 | ih, iw = lr.shape[:2] 149 | hr = hr[0:ih * scale, 0:iw * scale] 150 | return lr, hr 151 | 152 | def set_scale(self, idx_scale): 153 | if not self.input_large: 154 | self.idx_scale = idx_scale 155 | else: 156 | self.idx_scale = random.randint(0, len(self.scale) - 1) 157 | 158 | -------------------------------------------------------------------------------- /PyTorch version/data/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import cv2 6 | import numpy as np 7 | import imageio 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | class Video(data.Dataset): 13 | def __init__(self, args, name='Video', train=False, benchmark=False): 14 | self.args = args 15 | self.name = name 16 | self.scale = args.scale 17 | self.idx_scale = 0 18 | self.train = False 19 | self.do_eval = False 20 | self.benchmark = benchmark 21 | 22 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) 23 | self.vidcap = cv2.VideoCapture(args.dir_demo) 24 | self.n_frames = 0 25 | self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 26 | 27 | def __getitem__(self, idx): 28 | success, lr = self.vidcap.read() 29 | if success: 30 | self.n_frames += 1 31 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 32 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 33 | 34 | return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames) 35 | else: 36 | vidcap.release() 37 | return None 38 | 39 | def __len__(self): 40 | return self.total_frames 41 | 42 | def set_scale(self, idx_scale): 43 | self.idx_scale = idx_scale 44 | 45 | -------------------------------------------------------------------------------- /PyTorch version/dataloader.py: 
-------------------------------------------------------------------------------- 1 | import threading 2 | import random 3 | 4 | import torch 5 | import torch.multiprocessing as multiprocessing 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data import SequentialSampler 8 | from torch.utils.data import RandomSampler 9 | from torch.utils.data import BatchSampler 10 | from torch.utils.data import _utils 11 | from torch.utils.data.dataloader import _DataLoaderIter 12 | 13 | from torch.utils.data._utils import collate 14 | from torch.utils.data._utils import signal_handling 15 | from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL 16 | from torch.utils.data._utils import ExceptionWrapper 17 | from torch.utils.data._utils import IS_WINDOWS 18 | from torch.utils.data._utils.worker import ManagerWatchdog 19 | 20 | from torch._six import queue 21 | 22 | def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id): 23 | try: 24 | collate._use_shared_memory = True 25 | signal_handling._set_worker_signal_handlers() 26 | 27 | torch.set_num_threads(1) 28 | random.seed(seed) 29 | torch.manual_seed(seed) 30 | 31 | data_queue.cancel_join_thread() 32 | 33 | if init_fn is not None: 34 | init_fn(worker_id) 35 | 36 | watchdog = ManagerWatchdog() 37 | 38 | while watchdog.is_alive(): 39 | try: 40 | r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) 41 | except queue.Empty: 42 | continue 43 | 44 | if r is None: 45 | assert done_event.is_set() 46 | return 47 | elif done_event.is_set(): 48 | continue 49 | 50 | idx, batch_indices = r 51 | try: 52 | idx_scale = 0 53 | if len(scale) > 1 and dataset.train: 54 | idx_scale = random.randrange(0, len(scale)) 55 | dataset.set_scale(idx_scale) 56 | 57 | samples = collate_fn([dataset[i] for i in batch_indices]) 58 | samples.append(idx_scale) 59 | except Exception: 60 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 61 | else: 62 | data_queue.put((idx, samples)) 63 | del samples 64 | 65 | except KeyboardInterrupt: 66 | pass 67 | 68 | class _MSDataLoaderIter(_DataLoaderIter): 69 | 70 | def __init__(self, loader): 71 | self.dataset = loader.dataset 72 | self.scale = loader.scale 73 | self.collate_fn = loader.collate_fn 74 | self.batch_sampler = loader.batch_sampler 75 | self.num_workers = loader.num_workers 76 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 77 | self.timeout = loader.timeout 78 | 79 | self.sample_iter = iter(self.batch_sampler) 80 | 81 | base_seed = torch.LongTensor(1).random_().item() 82 | 83 | if self.num_workers > 0: 84 | self.worker_init_fn = loader.worker_init_fn 85 | self.worker_queue_idx = 0 86 | self.worker_result_queue = multiprocessing.Queue() 87 | self.batches_outstanding = 0 88 | self.worker_pids_set = False 89 | self.shutdown = False 90 | self.send_idx = 0 91 | self.rcvd_idx = 0 92 | self.reorder_dict = {} 93 | self.done_event = multiprocessing.Event() 94 | 95 | base_seed = torch.LongTensor(1).random_()[0] 96 | 97 | self.index_queues = [] 98 | self.workers = [] 99 | for i in range(self.num_workers): 100 | index_queue = multiprocessing.Queue() 101 | index_queue.cancel_join_thread() 102 | w = multiprocessing.Process( 103 | target=_ms_loop, 104 | args=( 105 | self.dataset, 106 | index_queue, 107 | self.worker_result_queue, 108 | self.done_event, 109 | self.collate_fn, 110 | self.scale, 111 | base_seed + i, 112 | self.worker_init_fn, 113 | i 114 | ) 115 | ) 116 | w.daemon = True 117 | w.start() 118 | self.index_queues.append(index_queue) 119 | 
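# NOTE: _ms_loop above reports worker exceptions with sys.exc_info(), but this
# module never imports sys, so the error path itself raises NameError; an
# `import sys` belongs next to `import threading` at the top of the file.
# Also, base_seed is assigned twice (torch.LongTensor(1).random_().item()
# earlier, then torch.LongTensor(1).random_()[0] here); only the second value
# actually seeds the workers.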
self.workers.append(w) 120 | 121 | if self.pin_memory: 122 | self.data_queue = queue.Queue() 123 | pin_memory_thread = threading.Thread( 124 | target=_utils.pin_memory._pin_memory_loop, 125 | args=( 126 | self.worker_result_queue, 127 | self.data_queue, 128 | torch.cuda.current_device(), 129 | self.done_event 130 | ) 131 | ) 132 | pin_memory_thread.daemon = True 133 | pin_memory_thread.start() 134 | self.pin_memory_thread = pin_memory_thread 135 | else: 136 | self.data_queue = self.worker_result_queue 137 | 138 | _utils.signal_handling._set_worker_pids( 139 | id(self), tuple(w.pid for w in self.workers) 140 | ) 141 | _utils.signal_handling._set_SIGCHLD_handler() 142 | self.worker_pids_set = True 143 | 144 | for _ in range(2 * self.num_workers): 145 | self._put_indices() 146 | 147 | 148 | class MSDataLoader(DataLoader): 149 | 150 | def __init__(self, cfg, *args, **kwargs): 151 | super(MSDataLoader, self).__init__( 152 | *args, **kwargs, num_workers=cfg.n_threads 153 | ) 154 | self.scale = cfg.scale 155 | 156 | def __iter__(self): 157 | return _MSDataLoaderIter(self) 158 | 159 | -------------------------------------------------------------------------------- /PyTorch version/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/PyTorch version/loss/__init__.py -------------------------------------------------------------------------------- /PyTorch version/loss/adversarial.py: -------------------------------------------------------------------------------- 1 | import utils.utility as utility 2 | from types import SimpleNamespace 3 | 4 | from model import common 5 | from loss import discriminator 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | class Adversarial(nn.Module): 13 | def __init__(self, args, gan_type): 14 | super(Adversarial, self).__init__() 15 | self.gan_type = gan_type 16 | self.gan_k = args.gan_k 17 | self.dis = discriminator.Discriminator(args) 18 | if torch.cuda.is_available(): 19 | self.dis.cuda() 20 | if gan_type == 'WGAN_GP': 21 | # see https://arxiv.org/pdf/1704.00028.pdf pp.4 22 | optim_dict = { 23 | 'optimizer': 'ADAM', 24 | 'betas': (0, 0.9), 25 | 'epsilon': 1e-8, 26 | 'lr': 1e-5, 27 | 'weight_decay': args.weight_decay, 28 | 'decay': args.decay, 29 | 'gamma': args.gamma 30 | } 31 | optim_args = SimpleNamespace(**optim_dict) 32 | else: 33 | optim_args = args 34 | 35 | self.optimizer = utility.make_optimizer(optim_args, self.dis) 36 | 37 | def forward(self, fake, real): 38 | # updating discriminator... 
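# NOTE: in the WGAN-GP branch below, torch.rand_like(fake).view(-1, 1, 1, 1)
# produces B*C*H*W random values, which cannot broadcast against the
# (B, C, H, W) image batch. The usual gradient-penalty interpolation (a sketch
# of the standard form, not the authors' code) draws one epsilon per sample:
#
#     epsilon = torch.rand(fake_detach.size(0), 1, 1, 1, device=fake_detach.device)
#     hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)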
39 | self.loss = 0 40 | fake_detach = fake.detach() # do not backpropagate through G 41 | for _ in range(self.gan_k): 42 | self.optimizer.zero_grad() 43 | # d: B x 1 tensor 44 | d_fake = self.dis(fake_detach) 45 | d_real = self.dis(real) 46 | retain_graph = False 47 | if self.gan_type == 'GAN': 48 | loss_d = self.bce(d_real, d_fake) 49 | elif self.gan_type.find('WGAN') >= 0: 50 | loss_d = (d_fake - d_real).mean() 51 | if self.gan_type.find('GP') >= 0: 52 | epsilon = torch.rand_like(fake).view(-1, 1, 1, 1) 53 | hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon) 54 | hat.requires_grad = True 55 | d_hat = self.dis(hat) 56 | gradients = torch.autograd.grad( 57 | outputs=d_hat.sum(), inputs=hat, 58 | retain_graph=True, create_graph=True, only_inputs=True 59 | )[0] 60 | gradients = gradients.view(gradients.size(0), -1) 61 | gradient_norm = gradients.norm(2, dim=1) 62 | gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean() 63 | loss_d += gradient_penalty 64 | # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks 65 | elif self.gan_type == 'RGAN': 66 | better_real = d_real - d_fake.mean(dim=0, keepdim=True) 67 | better_fake = d_fake - d_real.mean(dim=0, keepdim=True) 68 | loss_d = self.bce(better_real, better_fake) 69 | retain_graph = True 70 | 71 | # Discriminator update 72 | self.loss += loss_d.item() 73 | loss_d.backward(retain_graph=retain_graph) 74 | self.optimizer.step() 75 | 76 | if self.gan_type == 'WGAN': 77 | for p in self.dis.parameters(): 78 | p.data.clamp_(-1, 1) 79 | 80 | self.loss /= self.gan_k 81 | 82 | # updating generator... 83 | d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is 84 | if self.gan_type == 'GAN': 85 | label_real = torch.ones_like(d_fake_bp) 86 | loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real) 87 | elif self.gan_type.find('WGAN') >= 0: 88 | loss_g = -d_fake_bp.mean() 89 | elif self.gan_type == 'RGAN': 90 | better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True) 91 | better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True) 92 | loss_g = self.bce(better_fake, better_real) 93 | 94 | # Generator loss 95 | return loss_g 96 | 97 | def state_dict(self, *args, **kwargs): 98 | state_discriminator = self.dis.state_dict(*args, **kwargs) 99 | state_optimizer = self.optimizer.state_dict() 100 | 101 | return dict(**state_discriminator, **state_optimizer) 102 | 103 | def bce(self, real, fake): 104 | label_real = torch.ones_like(real) 105 | label_fake = torch.zeros_like(fake) 106 | bce_real = F.binary_cross_entropy_with_logits(real, label_real) 107 | bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake) 108 | bce_loss = bce_real + bce_fake 109 | return bce_loss 110 | 111 | # Some references 112 | # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py 113 | # OR 114 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py 115 | -------------------------------------------------------------------------------- /PyTorch version/loss/contrast_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.nn import functional as F 4 | import torch.nn.functional as fnn 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from torchvision import models 8 | 9 | import math 10 | 11 | class Vgg19(torch.nn.Module): 12 | def __init__(self, requires_grad=False): 13 | super(Vgg19, self).__init__() 14 | vgg_pretrained_features = models.vgg19(pretrained=True).features 15 | self.slice1 = 
torch.nn.Sequential() 16 | self.slice2 = torch.nn.Sequential() 17 | self.slice3 = torch.nn.Sequential() 18 | self.slice4 = torch.nn.Sequential() 19 | self.slice5 = torch.nn.Sequential() 20 | for x in range(2): 21 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 22 | for x in range(2, 7): 23 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 24 | for x in range(7, 12): 25 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 26 | for x in range(12, 21): 27 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 28 | for x in range(21, 30): 29 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 30 | if not requires_grad: 31 | for param in self.parameters(): 32 | param.requires_grad = False 33 | 34 | def forward(self, X): 35 | h_relu1 = self.slice1(X) 36 | h_relu2 = self.slice2(h_relu1) 37 | h_relu3 = self.slice3(h_relu2) 38 | h_relu4 = self.slice4(h_relu3) 39 | h_relu5 = self.slice5(h_relu4) 40 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 41 | return out 42 | 43 | 44 | class ContrastLoss(nn.Module): 45 | def __init__(self, weights, d_func, t_detach = False, is_one=False): 46 | super(ContrastLoss, self).__init__() 47 | self.vgg = Vgg19().cuda() 48 | self.l1 = nn.L1Loss() 49 | self.weights = weights 50 | self.d_func = d_func 51 | self.is_one = is_one 52 | self.t_detach = t_detach 53 | 54 | def forward(self, teacher, student, neg, blur_neg=None): 55 | teacher_vgg, student_vgg, neg_vgg, = self.vgg(teacher), self.vgg(student), self.vgg(neg) 56 | blur_neg_vgg = None 57 | if blur_neg is not None: 58 | blur_neg_vgg = self.vgg(blur_neg) 59 | if self.d_func == "L1": 60 | self.forward_func = self.L1_forward 61 | elif self.d_func == 'cos': 62 | self.forward_func = self.cos_forward 63 | 64 | return self.forward_func(teacher_vgg, student_vgg, neg_vgg, blur_neg_vgg) 65 | 66 | def L1_forward(self, teacher, student, neg, blur_neg=None): 67 | """ 68 | :param teacher: 5*batchsize*color*patchsize*patchsize 69 | :param student: 5*batchsize*color*patchsize*patchsize 70 | :param neg: 5*negnum*color*patchsize*patchsize 71 | :return: 72 | """ 73 | loss = 0 74 | for i in range(len(teacher)): 75 | neg_i = neg[i].unsqueeze(0) 76 | neg_i = neg_i.repeat(student[i].shape[0], 1, 1, 1, 1) 77 | neg_i = neg_i.permute(1, 0, 2, 3, 4)### batchsize*negnum*color*patchsize*patchsize 78 | if blur_neg is not None: 79 | blur_neg_i = blur_neg[i].unsqueeze(0) 80 | neg_i = torch.cat((neg_i, blur_neg_i)) 81 | 82 | 83 | if self.t_detach: 84 | d_ts = self.l1(teacher[i].detach(), student[i]) 85 | else: 86 | d_ts = self.l1(teacher[i], student[i]) 87 | d_sn = torch.mean(torch.abs(neg_i.detach() - student[i]).sum(0)) 88 | 89 | contrastive = d_ts / (d_sn + 1e-7) 90 | loss += self.weights[i] * contrastive 91 | return loss 92 | 93 | 94 | def cos_forward(self, teacher, student, neg, blur_neg=None): 95 | loss = 0 96 | for i in range(len(teacher)): 97 | neg_i = neg[i].unsqueeze(0) 98 | neg_i = neg_i.repeat(student[i].shape[0], 1, 1, 1, 1) 99 | neg_i = neg_i.permute(1, 0, 2, 3, 4) 100 | if blur_neg is not None: 101 | blur_neg_i = blur_neg[i].unsqueeze(0) 102 | neg_i = torch.cat((neg_i, blur_neg_i)) 103 | 104 | if self.t_detach: 105 | d_ts = torch.cosine_similarity(teacher[i].detach(), student[i], dim=0).mean() 106 | else: 107 | d_ts = torch.cosine_similarity(teacher[i], student[i], dim=0).mean() 108 | d_sn = self.calc_cos_stu_neg(student[i], neg_i.detach()) 109 | 110 | contrastive = -torch.log(torch.exp(d_ts)/(torch.exp(d_sn)+1e-7)) 111 | loss += self.weights[i] * contrastive 112 | return loss 113 | 
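# NOTE: schematically, the two distance modes above accumulate, per VGG stage i
# (teacher features T_i, student features S_i, negatives N_i, stage weight w_i):
#
#     d_func == 'L1' :   w_i * L1(T_i, S_i) / (L1(S_i, N_i) + 1e-7)
#     d_func == 'cos':  -w_i * log( exp(cos(T_i, S_i)) / (exp(cos(S_i, N_i)) + 1e-7) )
#
# i.e. the student is pulled toward the teacher (positive) and pushed away
# from the bicubically upsampled negatives.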
114 | def calc_cos_stu_neg(self, stu, neg): 115 | n = stu.shape[0] 116 | m = neg.shape[0] 117 | 118 | stu = stu.view(n, -1) 119 | neg = neg.view(m, n, -1) 120 | # normalize 121 | stu = F.normalize(stu, p=2, dim=1) 122 | neg = F.normalize(neg, p=2, dim=2) 123 | # multiply 124 | d_sn = torch.mean((stu * neg).sum(0)) 125 | return d_sn 126 | -------------------------------------------------------------------------------- /PyTorch version/loss/discriminator.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class Discriminator(nn.Module): 7 | ''' 8 | output is not normalized 9 | ''' 10 | def __init__(self, args): 11 | super(Discriminator, self).__init__() 12 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 13 | in_channels = args.n_colors 14 | out_channels = 64 15 | depth = 7 16 | 17 | def _block(_in_channels, _out_channels, stride=1): 18 | return nn.Sequential( 19 | nn.Conv2d( 20 | _in_channels, 21 | _out_channels, 22 | 3, 23 | padding=1, 24 | stride=stride, 25 | bias=False 26 | ), 27 | nn.BatchNorm2d(_out_channels), 28 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 29 | ) 30 | 31 | m_features = [_block(in_channels, out_channels)] 32 | for i in range(depth): 33 | in_channels = out_channels 34 | if i % 2 == 1: 35 | stride = 1 36 | out_channels *= 2 37 | else: 38 | stride = 2 39 | m_features.append(_block(in_channels, out_channels, stride=stride)) 40 | 41 | patch_size = args.patch_size // (2**((depth + 1) // 2)) 42 | m_classifier = [ 43 | nn.Linear(out_channels * patch_size**2, 1024), 44 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 45 | nn.Linear(1024, 1) 46 | ] 47 | 48 | self.features = nn.Sequential(*m_features) 49 | self.classifier = nn.Sequential(*m_classifier) 50 | 51 | def forward(self, x): 52 | features = self.features(x) 53 | output = self.classifier(features.view(features.size(0), -1)) 54 | 55 | return output 56 | 57 | -------------------------------------------------------------------------------- /PyTorch version/loss/perceptual.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision.models.vgg import vgg19 3 | import torch 4 | 5 | 6 | class PerceptualLoss(nn.Module): 7 | def __init__(self): 8 | super(PerceptualLoss, self).__init__() 9 | 10 | vgg = vgg19(pretrained=True) 11 | loss_network = nn.Sequential(*list(vgg.features)[:35]).eval() 12 | for param in loss_network.parameters(): 13 | param.requires_grad = False 14 | self.loss_network = loss_network 15 | if torch.cuda.is_available(): 16 | self.loss_network.cuda() 17 | self.l1_loss = nn.L1Loss() 18 | 19 | def forward(self, high_resolution, fake_high_resolution): 20 | perception_loss = self.l1_loss(self.loss_network(high_resolution), self.loss_network(fake_high_resolution)) 21 | return perception_loss 22 | -------------------------------------------------------------------------------- /PyTorch version/loss/vgg.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | class VGG(nn.Module): 9 | def __init__(self, conv_index, rgb_range=1): 10 | super(VGG, self).__init__() 11 | vgg_features = models.vgg19(pretrained=True).features 12 | modules = [m for m in vgg_features] 13 | if conv_index.find('22') >= 0: 14 | self.vgg = 
nn.Sequential(*modules[:8]) 15 | elif conv_index.find('54') >= 0: 16 | self.vgg = nn.Sequential(*modules[:35]) 17 | 18 | vgg_mean = (0.485, 0.456, 0.406) 19 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 20 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | def forward(self, sr, hr): 25 | def _forward(x): 26 | x = self.sub_mean(x) 27 | x = self.vgg(x) 28 | return x 29 | 30 | vgg_sr = _forward(sr) 31 | with torch.no_grad(): 32 | vgg_hr = _forward(hr.detach()) 33 | 34 | loss = F.mse_loss(vgg_sr, vgg_hr) 35 | 36 | return loss 37 | -------------------------------------------------------------------------------- /PyTorch version/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | 4 | import utils.utility as utility 5 | import data 6 | import loss 7 | from option import args 8 | from trainer.slim_contrast_trainer import SlimContrastiveTrainer 9 | 10 | torch.manual_seed(args.seed) 11 | checkpoint = utility.checkpoint(args) 12 | 13 | from signal import signal, SIGPIPE, SIG_DFL, SIG_IGN 14 | signal(SIGPIPE, SIG_IGN) 15 | 16 | if __name__ == '__main__': 17 | loader = data.Data(args) 18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 19 | if args.seperate: 20 | t = SeperateContrastiveTrainer(args, loader, device) 21 | else: 22 | t = SlimContrastiveTrainer(args, loader, device) 23 | # t = SlimContrastiveTrainer(args, loader, device, neg_loader) 24 | 25 | if args.model_stat: 26 | total_param = 0 27 | if args.seperate: 28 | for name, param in t.s_model.named_parameters(): 29 | total_param += torch.numel(param) 30 | else: 31 | for name, param in t.model.named_parameters(): 32 | # print(name, ' ', torch.numel(param)) 33 | total_param += torch.numel(param) 34 | print(total_param) 35 | 36 | 37 | 38 | if not args.test_only: 39 | t.train() 40 | else: 41 | t.test(args.stu_width_mult) 42 | checkpoint.done() -------------------------------------------------------------------------------- /PyTorch version/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/PyTorch version/model/__init__.py -------------------------------------------------------------------------------- /PyTorch version/model/carn.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class BasicBlock(nn.Module): 7 | def __init__(self, 8 | in_channels, out_channels, 9 | ksize=3, stride=1, pad=1): 10 | super(BasicBlock, self).__init__() 11 | 12 | self.conv = nn.Conv2d(in_channels, out_channels, ksize, stride, pad) 13 | self.act = nn.ReLU(inplace=True) 14 | 15 | def forward(self, x, width_mult): 16 | out = common.SlimModule(x, self.conv, width_mult) 17 | out = self.act(out) 18 | return out 19 | 20 | class Block(nn.Module): 21 | def __init__(self, n_feats): 22 | super(Block, self).__init__() 23 | 24 | self.act = nn.ReLU(inplace=True) 25 | 26 | self.b1 = common.ResidualBlock(n_feats, 3, act=self.act, res_scale=1) 27 | self.b2 = common.ResidualBlock(n_feats, 3, self.act, 1) 28 | self.b3 = common.ResidualBlock(n_feats, 3, self.act, 1) 29 | self.c1 = BasicBlock(n_feats*2, n_feats, 1, 1, 0) 30 | self.c2 = BasicBlock(n_feats*3, n_feats, 1, 1, 0) 31 | self.c3 = 
BasicBlock(n_feats*4, n_feats, 1, 1, 0) 32 | 33 | def forward(self, x, width_mult=1): 34 | c0 = o0 = x 35 | 36 | b1 = self.b1(o0, width_mult) 37 | c1 = torch.cat([c0, b1], dim=1) 38 | o1 = self.c1(c1, width_mult) 39 | 40 | b2 = self.b2(o1, width_mult) 41 | c2 = torch.cat([c1, b2], dim=1) 42 | o2 = self.c2(c2, width_mult) 43 | 44 | b3 = self.b3(o2, width_mult) 45 | c3 = torch.cat([c2, b3], dim=1) 46 | o3 = self.c3(c3, width_mult) 47 | 48 | return o3 49 | 50 | class CARN(nn.Module): 51 | def __init__(self, args): 52 | super(CARN, self).__init__() 53 | 54 | scale = args.scale[0] 55 | self.act = nn.ReLU(inplace=True) 56 | self.n_feat = args.n_feats 57 | 58 | self.sub_mean = common.MeanShift(args.rgb_range) 59 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 60 | 61 | self.entry = nn.Conv2d(3, self.n_feat, 3, 1, 1) 62 | self.b1 = Block(self.n_feat) 63 | self.b2 = Block(self.n_feat) 64 | self.b3 = Block(self.n_feat) 65 | self.c1 = BasicBlock(self.n_feat*2, self.n_feat, 1, 1, 0) 66 | self.c2 = BasicBlock(self.n_feat*3, self.n_feat, 1, 1, 0) 67 | self.c3 = BasicBlock(self.n_feat*4, self.n_feat, 1, 1, 0) 68 | 69 | self.upsample = common.Upsampler(scale, self.n_feat) 70 | self.exit = nn.Conv2d(self.n_feat, 3, 3, 1, 1) 71 | 72 | def forward(self, x, width_mult=1): 73 | x = self.sub_mean(x) 74 | nf = int(self.n_feat * width_mult) 75 | weight = self.entry.weight[:nf, :3, :, :] 76 | bias = self.entry.bias[:nf] 77 | x = nn.functional.conv2d(x, weight, bias, stride=self.entry.stride, padding=self.entry.padding) 78 | c0 = o0 = x 79 | 80 | b1 = self.b1(o0, width_mult) 81 | c1 = torch.cat([c0, b1], dim=1) 82 | o1 = self.c1(c1, width_mult) 83 | 84 | b2 = self.b2(o1, width_mult) 85 | c2 = torch.cat([c1, b2], dim=1) 86 | o2 = self.c2(c2, width_mult) 87 | 88 | b3 = self.b3(o2, width_mult) 89 | c3 = torch.cat([c2, b3], dim=1) 90 | o3 =self.c3(c3, width_mult) 91 | 92 | out = self.upsample(o3, width_mult) 93 | weight = self.exit.weight[:3, :nf, :, :] 94 | bias = self.exit.bias[:3] 95 | out = nn.functional.conv2d(out, weight, bias, stride=self.exit.stride, padding=self.exit.padding) 96 | out = self.add_mean(out) 97 | 98 | return out 99 | 100 | -------------------------------------------------------------------------------- /PyTorch version/model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class MeanShift(nn.Conv2d): 9 | def __init__( 10 | self, rgb_range, 11 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 12 | 13 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 14 | std = torch.Tensor(rgb_std) 15 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 16 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 17 | for p in self.parameters(): 18 | p.requires_grad = False 19 | 20 | class ResidualBlock(nn.Module): 21 | def __init__(self, n_feats, kernel_size, act, res_scale): 22 | super(ResidualBlock, self).__init__() 23 | self.n_feats = n_feats 24 | self.res_scale = res_scale 25 | self.kernel_size = kernel_size 26 | 27 | self.conv1 = nn.Conv2d(n_feats, n_feats, kernel_size=kernel_size, padding=1) 28 | self.act = act 29 | self.conv2 = nn.Conv2d(n_feats, n_feats, kernel_size=kernel_size, padding=1) 30 | 31 | def forward(self, x, width_mult=1): 32 | width = int(self.n_feats* width_mult) 33 | weight = self.conv1.weight[:width, :width, :, :] 34 | bias = self.conv1.bias[:width] 35 | residual = 
nn.functional.conv2d(x, weight, bias, padding=(self.kernel_size//2)) 36 | residual = self.act(residual) 37 | weight = self.conv2.weight[:width, :width, :, :] 38 | bias = self.conv2.bias[:width] 39 | residual = nn.functional.conv2d(residual, weight, bias, padding=(self.kernel_size//2)) 40 | 41 | return x + residual.mul(self.res_scale) 42 | 43 | class Upsampler(nn.Sequential): 44 | def __init__(self, scale_factor, nf): 45 | super(Upsampler, self).__init__() 46 | block = [] 47 | self.nf = nf 48 | self.scale = scale_factor 49 | 50 | if scale_factor == 3: 51 | block += [ 52 | nn.Conv2d(nf, nf * 9, 3, padding=1, bias=True) 53 | ] 54 | self.pixel_shuffle = nn.PixelShuffle(3) 55 | else: 56 | self.block_num = scale_factor // 2 57 | self.pixel_shuffle = nn.PixelShuffle(2) 58 | #self.act = nn.ReLU() 59 | 60 | for _ in range(self.block_num): 61 | block += [ 62 | nn.Conv2d(nf, nf * (2 ** 2), 3, padding=1, bias=True) 63 | ] 64 | self.blocks = nn.ModuleList(block) 65 | 66 | def forward(self, x, width_mult=1): 67 | res = x 68 | nf = self.nf 69 | if self.scale == 3: 70 | width = int(width_mult * nf) 71 | width9 = width * 9 72 | for block in self.blocks: 73 | weight = block.weight[:width9, :width, :, :] 74 | bias = block.bias[:width9] 75 | res = nn.functional.conv2d(res, weight, bias, padding=1) 76 | res = self.pixel_shuffle(res) 77 | else: 78 | for block in self.blocks: 79 | width = int(width_mult * nf) 80 | width4 = width * 4 81 | weight = block.weight[:width4, :width, :, :] 82 | bias = block.bias[:width4] 83 | res = nn.functional.conv2d(res, weight, bias, padding=1) 84 | res = self.pixel_shuffle(res) 85 | #res = self.act(res) 86 | 87 | return res 88 | 89 | def SlimModule(input, module, width_mult): 90 | weight = module.weight 91 | out_ch, in_ch = weight.shape[:2] 92 | out_ch = int(out_ch * width_mult) 93 | in_ch = int(in_ch * width_mult) 94 | weight = weight[:out_ch, :in_ch, :, :] 95 | bias = module.bias 96 | if bias is not None: 97 | bias = module.bias[:out_ch] 98 | return nn.functional.conv2d(input, weight, bias, stride=module.stride, padding=module.padding) -------------------------------------------------------------------------------- /PyTorch version/model/edsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | class EDSR(nn.Module): 6 | def __init__(self, args): 7 | super(EDSR, self).__init__() 8 | 9 | self.n_colors = args.n_colors 10 | n_resblocks = args.n_resblocks 11 | self.n_feats = args.n_feats 12 | self.kernel_size = 3 13 | scale = args.scale[0] 14 | act = nn.ReLU(True) 15 | self.sub_mean = common.MeanShift(args.rgb_range) 16 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 17 | 18 | self.head = nn.Conv2d(args.n_colors, self.n_feats, self.kernel_size, padding=(self.kernel_size//2), bias=True) 19 | 20 | m_body = [ 21 | common.ResidualBlock( 22 | self.n_feats, self.kernel_size, act=act, res_scale=args.res_scale 23 | ) for _ in range(n_resblocks) 24 | ] 25 | self.body = nn.ModuleList(m_body) 26 | self.body_conv = nn.Conv2d(self.n_feats, self.n_feats, self.kernel_size, padding=(self.kernel_size//2), bias=True) 27 | 28 | self.upsampler = common.Upsampler(scale, self.n_feats) 29 | self.tail_conv = nn.Conv2d(self.n_feats, args.n_colors, self.kernel_size, padding=(self.kernel_size//2), bias=True) 30 | 31 | def forward(self, x, width_mult=1): 32 | feature_width = int(self.n_feats * width_mult) 33 | 34 | x = self.sub_mean(x) 35 | weight = self.head.weight[:feature_width, :self.n_colors, :, :] 
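# NOTE: this slicing is the slimmable-network trick used throughout the repo:
# a student at width multiplier w reuses the first int(n_feats * w) filters of
# the shared (teacher) convolution -- e.g. 64 of the 256 head filters at the
# default --stu_width_mult 0.25 -- so teacher and student share one weight set.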
36 | bias = self.head.bias[:feature_width] 37 | x = nn.functional.conv2d(x, weight, bias, padding=(self.kernel_size//2)) 38 | 39 | residual = x 40 | for block in self.body: 41 | residual = block(residual, width_mult) 42 | weight = self.body_conv.weight[:feature_width, :feature_width, :, :] 43 | bias = self.body_conv.bias[:feature_width] 44 | residual = nn.functional.conv2d(residual, weight, bias, padding=(self.kernel_size//2)) 45 | residual += x 46 | 47 | x = self.upsampler(residual, width_mult) 48 | weight = self.tail_conv.weight[:self.n_colors, :feature_width, :, :] 49 | bias = self.tail_conv.bias[:self.n_colors] 50 | x = nn.functional.conv2d(x, weight, bias, padding=(self.kernel_size//2)) 51 | x = self.add_mean(x) 52 | 53 | return x 54 | 55 | def load_state_dict(self, state_dict, strict=True): 56 | own_state = self.state_dict() 57 | for name, param in state_dict.items(): 58 | if name in own_state: 59 | if isinstance(param, nn.Parameter): 60 | param = param.data 61 | try: 62 | own_state[name].copy_(param) 63 | except Exception: 64 | if name.find('tail') == -1: 65 | raise RuntimeError('While copying the parameter named {}, ' 66 | 'whose dimensions in the model are {} and ' 67 | 'whose dimensions in the checkpoint are {}.' 68 | .format(name, own_state[name].size(), param.size())) 69 | elif strict: 70 | if name.find('tail') == -1: 71 | raise KeyError('unexpected key "{}" in state_dict' 72 | .format(name)) 73 | -------------------------------------------------------------------------------- /PyTorch version/model/rcan.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | # class SlimConv(nn.Module): 6 | # def __init__(self, in_channel, out_channel, kernel_size, padding=(kernel_size//2), bias=True): 7 | # self.in_channel = in_channel 8 | # self.out_channel = out_channel 9 | # self.padding = padding 10 | # self.bias = bias 11 | # 12 | # self.conv = nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, bias=bias) 13 | # 14 | # def forward(self, x, width_mult=1): 15 | # in_channel_width = int(width_mult * self.in_channel) 16 | # out_channel_width = int(width_mult * self.out_channel) 17 | # weight = self.conv.weight[:out_channel_width, :in_channel_width, :, :] 18 | # bias = None 19 | # if self.bias: 20 | # bias = self.conv.bias[:out_channel_width] 21 | # 22 | # return nn.functional.conv2d(x, weight, bias, self.conv.stride) 23 | 24 | class CALayer(nn.Module): 25 | def __init__(self, channel, reduction=16): 26 | super(CALayer, self).__init__() 27 | 28 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 29 | self.conv_du = nn.ModuleList([ 30 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 31 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 32 | ]) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.sigmoid = nn.Sigmoid() 35 | 36 | def forward(self, x, width_mult=1): 37 | y = self.avg_pool(x) 38 | 39 | module = getattr(self.conv_du, '0') 40 | y = common.SlimModule(y, module, width_mult) 41 | y = self.relu(y) 42 | 43 | module = getattr(self.conv_du, '1') 44 | y = common.SlimModule(y, module, width_mult) 45 | y = self.sigmoid(y) 46 | 47 | return x * y 48 | 49 | class RCAB(nn.Module): 50 | def __init__(self, n_feats, kernel_size, reduction, bias=True, act=nn.ReLU(True), res_scale=1): 51 | super(RCAB, self).__init__() 52 | 53 | modules_body = [] 54 | for i in range(2): 55 | modules_body.append(nn.Conv2d(n_feats, n_feats, kernel_size, padding=(kernel_size//2), 
bias=bias)) 56 | self.caLayer = CALayer(n_feats, reduction) 57 | self.body = nn.ModuleList(modules_body) 58 | self.act = act 59 | self.res_scale = res_scale 60 | 61 | def forward(self, x, width_mult=1): 62 | module = self.body[0] 63 | res = common.SlimModule(x, module, width_mult) 64 | res = self.act(res) 65 | module = self.body[1] 66 | res = common.SlimModule(res, module, width_mult) 67 | res = self.caLayer(res, width_mult) 68 | res += x 69 | return res 70 | 71 | class ResidualGroup(nn.Module): 72 | def __init__(self, n_feats, kernel_size, reduction, act, res_scale, n_resblocks): 73 | super(ResidualGroup, self).__init__() 74 | 75 | modules_body = [ 76 | RCAB( 77 | n_feats, kernel_size, reduction, bias=True, act=nn.ReLU(True), res_scale=1 78 | ) for _ in range(n_resblocks) 79 | ] 80 | self.body = nn.ModuleList(modules_body) 81 | self.conv = nn.Conv2d(n_feats, n_feats, kernel_size, padding=(kernel_size//2), bias=True) 82 | 83 | def forward(self, x, width_mult): 84 | res = x 85 | for module in self.body: 86 | res = module(res, width_mult) 87 | res = common.SlimModule(res, self.conv, width_mult) 88 | res += x 89 | return res 90 | 91 | 92 | class RCAN(nn.Module): 93 | def __init__(self, args): 94 | super(RCAN, self).__init__() 95 | 96 | self.n_colors = args.n_colors 97 | n_resgroups = args.n_resgroups 98 | n_resblocks = args.n_resblocks 99 | n_feats = args.n_feats 100 | kernel_size = 3 101 | reduction = args.reduction 102 | scale = args.scale[0] 103 | act = nn.ReLU(True) 104 | 105 | self.sub_mean = common.MeanShift(args.rgb_range) 106 | 107 | self.head_conv = nn.Conv2d(args.n_colors, n_feats, kernel_size, padding=(kernel_size//2), bias=True) 108 | 109 | modules_body = [ 110 | ResidualGroup( 111 | n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks 112 | ) for _ in range(n_resgroups) 113 | ] 114 | # modules_body = [ 115 | # common.ResidualBlock( 116 | # n_feats, kernel_size, act, args.res_scale 117 | # ) for _ in range(n_resgroups) 118 | # ] 119 | self.body = nn.ModuleList(modules_body) 120 | self.body_conv = nn.Conv2d(n_feats, n_feats, kernel_size, padding=(kernel_size//2), bias=True) 121 | 122 | self.upsampler = common.Upsampler(scale, n_feats) 123 | self.tail_conv = nn.Conv2d(n_feats, args.n_colors, kernel_size, padding=(kernel_size // 2), 124 | bias=True) 125 | 126 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 127 | 128 | def forward(self, x, width_mult=1): 129 | x = self.sub_mean(x) 130 | weight = self.head_conv.weight 131 | n_feats = weight.shape[0] 132 | out_ch = int(n_feats * width_mult) 133 | weight = weight[:out_ch, :self.n_colors, :, :] 134 | bias = self.head_conv.bias[:out_ch] 135 | x = nn.functional.conv2d(x, weight, bias, stride=self.head_conv.stride, padding=self.head_conv.padding) 136 | 137 | res = x 138 | for module in self.body: 139 | res = module(res, width_mult) 140 | res = common.SlimModule(res, self.body_conv, width_mult) 141 | res += x 142 | 143 | x = self.upsampler(res, width_mult) 144 | weight = self.tail_conv.weight[:self.n_colors, :out_ch, :, :] 145 | bias = self.tail_conv.bias[:self.n_colors] 146 | x = nn.functional.conv2d(x, weight, bias, stride=self.tail_conv.stride, padding=self.tail_conv.padding) 147 | x = self.add_mean(x) 148 | 149 | return x 150 | 151 | def load_state_dict(self, state_dict, strict=False): 152 | own_state = self.state_dict() 153 | for name, param in state_dict.items(): 154 | if name in own_state: 155 | if isinstance(param, nn.Parameter): 156 | param = param.data 157 | try: 158 | 
own_state[name].copy_(param) 159 | except Exception: 160 | if name.find('tail') >= 0: 161 | print('Replace pre-trained upsampler to new one...') 162 | else: 163 | raise RuntimeError('While copying the parameter named {}, ' 164 | 'whose dimensions in the model are {} and ' 165 | 'whose dimensions in the checkpoint are {}.' 166 | .format(name, own_state[name].size(), param.size())) 167 | elif strict: 168 | if name.find('tail') == -1: 169 | raise KeyError('unexpected key "{}" in state_dict' 170 | .format(name)) 171 | 172 | if strict: 173 | missing = set(own_state.keys()) - set(state_dict.keys()) 174 | if len(missing) > 0: 175 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 176 | -------------------------------------------------------------------------------- /PyTorch version/option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser(description='EDSR and MDSR') 4 | 5 | parser.add_argument('--debug', action='store_true', 6 | help='Enables debug mode') 7 | parser.add_argument('--template', default='.', 8 | help='You can set various templates in option.py') 9 | 10 | # Hardware specifications 11 | parser.add_argument('--n_threads', type=int, default=6, 12 | help='number of threads for data loading') 13 | parser.add_argument('--cpu', action='store_true', 14 | help='use cpu only') 15 | parser.add_argument('--n_GPUs', type=int, default=1, 16 | help='number of GPUs') 17 | parser.add_argument('--seed', type=int, default=1, 18 | help='random seed') 19 | 20 | # Data specifications 21 | parser.add_argument('--dir_data', type=str, default='../../../dataset', 22 | help='dataset directory') 23 | parser.add_argument('--dir_demo', type=str, default='../test', 24 | help='demo image directory') 25 | parser.add_argument('--data_train', type=str, default='DIV2K', 26 | help='train dataset name') 27 | parser.add_argument('--data_test', type=str, default='DIV2K', 28 | help='test dataset name') 29 | parser.add_argument('--data_range', type=str, default='1-800/801-810', 30 | help='train/test data range') 31 | parser.add_argument('--ext', type=str, default='sep', 32 | help='dataset file extension') 33 | parser.add_argument('--scale', type=str, default='4', 34 | help='super resolution scale') 35 | parser.add_argument('--patch_size', type=int, default=192, 36 | help='output patch size') 37 | parser.add_argument('--rgb_range', type=int, default=255, 38 | help='maximum value of RGB') 39 | parser.add_argument('--n_colors', type=int, default=3, 40 | help='number of color channels to use') 41 | parser.add_argument('--chop', action='store_true', 42 | help='enable memory-efficient forward') 43 | parser.add_argument('--no_augment', action='store_true', 44 | help='do not use data augmentation') 45 | 46 | # Model specifications 47 | parser.add_argument('--model', type=str, default='EDSR', 48 | help='model name') 49 | 50 | parser.add_argument('--act', type=str, default='relu', 51 | help='activation function') 52 | parser.add_argument('--pre_train', type=str, default='', 53 | help='pre-trained model directory') 54 | parser.add_argument('--extend', type=str, default='.', 55 | help='pre-trained model directory') 56 | parser.add_argument('--n_resblocks', type=int, default=32, 57 | help='number of residual blocks') 58 | parser.add_argument('--n_feats', type=int, default=256, 59 | help='number of feature maps') 60 | parser.add_argument('--res_scale', type=float, default=0.1, 61 | help='residual scaling') 62 | 
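# NOTE: the model defaults above (--n_resblocks 32, --n_feats 256,
# --res_scale 0.1) match the full-size EDSR from the original paper; the
# lighter EDSR "baseline" uses 16 residual blocks, 64 feature maps and
# res_scale 1.0.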
parser.add_argument('--shift_mean', default=True, 63 | help='subtract pixel mean from the input') 64 | parser.add_argument('--dilation', action='store_true', 65 | help='use dilated convolution') 66 | parser.add_argument('--precision', type=str, default='single', 67 | choices=('single', 'half'), 68 | help='FP precision for test (single | half)') 69 | 70 | # Option for Residual dense network (RDN) 71 | parser.add_argument('--G0', type=int, default=64, 72 | help='default number of filters. (Use in RDN)') 73 | parser.add_argument('--RDNkSize', type=int, default=3, 74 | help='default kernel size. (Use in RDN)') 75 | parser.add_argument('--RDNconfig', type=str, default='B', 76 | help='parameters config of RDN. (Use in RDN)') 77 | 78 | # Option for Residual channel attention network (RCAN) 79 | parser.add_argument('--n_resgroups', type=int, default=10, 80 | help='number of residual groups') 81 | parser.add_argument('--reduction', type=int, default=16, 82 | help='number of feature maps reduction') 83 | 84 | # Training specifications 85 | parser.add_argument('--reset', action='store_true', 86 | help='reset the training') 87 | parser.add_argument('--test_every', type=int, default=1000, 88 | help='do test per every N batches') 89 | parser.add_argument('--epochs', type=int, default=300, 90 | help='number of epochs to train') 91 | parser.add_argument('--batch_size', type=int, default=16, 92 | help='input batch size for training') 93 | parser.add_argument('--split_batch', type=int, default=1, 94 | help='split the batch into smaller chunks') 95 | parser.add_argument('--self_ensemble', action='store_true', 96 | help='use self-ensemble method for test') 97 | parser.add_argument('--test_only', action='store_true', 98 | help='set this option to test the model') 99 | parser.add_argument('--gan_k', type=int, default=1, 100 | help='k value for adversarial loss') 101 | 102 | # Optimization specifications 103 | parser.add_argument('--lr', type=float, default=1e-4, 104 | help='learning rate') 105 | parser.add_argument('--decay', type=str, default='200', 106 | help='learning rate decay type') 107 | parser.add_argument('--gamma', type=float, default=0.5, 108 | help='learning rate decay factor for step decay') 109 | parser.add_argument('--optimizer', default='ADAM', 110 | choices=('SGD', 'ADAM', 'RMSprop'), 111 | help='optimizer to use (SGD | ADAM | RMSprop)') 112 | parser.add_argument('--momentum', type=float, default=0.9, 113 | help='SGD momentum') 114 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), 115 | help='ADAM beta') 116 | parser.add_argument('--epsilon', type=float, default=1e-8, 117 | help='ADAM epsilon for numerical stability') 118 | parser.add_argument('--weight_decay', type=float, default=0, 119 | help='weight decay') 120 | parser.add_argument('--gclip', type=float, default=0, 121 | help='gradient clipping threshold (0 = no clipping)') 122 | 123 | # Loss specifications 124 | parser.add_argument('--loss', type=str, default='1*L1', 125 | help='loss function configuration') 126 | parser.add_argument('--skip_threshold', type=float, default='1e8', 127 | help='skipping batch that has large error') 128 | 129 | # Log specifications 130 | parser.add_argument('--save', type=str, default='test', 131 | help='file name to save') 132 | parser.add_argument('--load', type=str, default='', 133 | help='file name to load') 134 | parser.add_argument('--resume', type=int, default=0, 135 | help='resume from specific checkpoint') 136 | parser.add_argument('--save_models', action='store_true', 137 | 
help='save all intermediate models') 138 | parser.add_argument('--print_every', type=int, default=100, 139 | help='how many batches to wait before logging training status') 140 | parser.add_argument('--save_results', action='store_true', 141 | help='save output results') 142 | parser.add_argument('--save_gt', action='store_true', 143 | help='save low-resolution and high-resolution images together') 144 | 145 | parser.add_argument('--stu_width_mult', type=float, default=0.25, 146 | help='width_mult of student model') 147 | parser.add_argument('--bn', action='store_true', help='use bn in residual block') 148 | parser.add_argument('--slim_num', type=int, default=1, help='divide the network into {slim_part} parts') 149 | parser.add_argument('--full_width_mult', type=float, default=1, help='set full width mult') 150 | parser.add_argument('--batch_test', action='store_true', help='test a batch') 151 | parser.add_argument('--print_each', action='store_true', help='print result of every sample') 152 | parser.add_argument('--model_stat', action='store_true', help='count model params and flops') 153 | 154 | parser.add_argument('--content_loss_factor', type=float, default=1e-1, help='content loss factor') 155 | parser.add_argument('--feature_loss_factor', type=float, default=1, help='feature loss factor') 156 | parser.add_argument('--adversarial_loss_factor', type=float, default=5e-3, help='adversarial loss factor') 157 | 158 | parser.add_argument('--model_filename', type=str, default='', help='pre-train model filename') 159 | parser.add_argument('--model_str', type=str, default='', help='save the model as filename{model_str}') 160 | parser.add_argument('--teacher_model', type=str, default='output/model/edsr/baseline/edsr_x4_baseline.pth', 161 | help='load teacher model from {teacher_model}') 162 | 163 | parser.add_argument('--neg_num', type=int, default=8, 164 | help='negative samples number') 165 | 166 | parser.add_argument('--t_lambda', type=float, default=0, help='weight of l1(hr, teacher_sr)') 167 | parser.add_argument('--contra_lambda', type=float, default=1, help='weight of contra_loss') 168 | parser.add_argument('--ad_lambda', type=float, default=0, help='weight of adversarial loss') 169 | parser.add_argument('--percep_lambda', type=float, default=0, help='weight of perceptual loss') 170 | parser.add_argument('--kd_lambda', type=float, default=0, help='weight of kd loss') 171 | parser.add_argument('--vgg_weight', nargs='+', type=float, default=[1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0], 172 | help='weight of vgg features in contrastive loss') 173 | parser.add_argument('--d_func', type=str, default="L1", help='the distance function in contrastive loss') 174 | parser.add_argument('--mean_outside', action='store_true', help='calc mean for negative samples outside the contrast_loss') 175 | parser.add_argument('--t_l_remove', type=int, default=0, help='remove teacher loss @ epoch {t_l_remove}') 176 | parser.add_argument('--contrast_t_detach', action='store_true', help='detach teacher in contrast_loss') 177 | parser.add_argument('--gt_as_pos', action='store_true', help='use gt as positive sample') 178 | parser.add_argument('--blur_sigma', type=float, default=0, help='blur sigma of neg sample') 179 | parser.add_argument('--noise_sigma', type=float, default=0, help='noise sigma of neg sample') 180 | 181 | parser.add_argument('--seperate', action='store_true', help='seperate teacher and student') 182 | 183 | args = parser.parse_args() 184 | 185 | args.scale = list(map(lambda x: int(x), 
args.scale.split('+'))) 186 | args.data_train = args.data_train.split('+') 187 | args.data_test = args.data_test.split('+') 188 | 189 | if args.epochs == 0: 190 | args.epochs = 1e8 191 | 192 | for arg in vars(args): 193 | if vars(args)[arg] == 'True': 194 | vars(args)[arg] = True 195 | elif vars(args)[arg] == 'False': 196 | vars(args)[arg] = False 197 | 198 | -------------------------------------------------------------------------------- /PyTorch version/trainer/slim_contrast_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from decimal import Decimal 4 | from glob import glob 5 | import datetime, time 6 | from importlib import import_module 7 | import numpy as np 8 | 9 | # import lpips 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torchvision.transforms as transforms 14 | from tensorboardX import SummaryWriter 15 | 16 | import utils.utility as utility 17 | from loss.contrast_loss import ContrastLoss 18 | from loss.adversarial import Adversarial 19 | from loss.perceptual import PerceptualLoss 20 | from model.edsr import EDSR 21 | from model.rcan import RCAN 22 | from utils.niqe import niqe 23 | from utils.ssim import calc_ssim 24 | 25 | 26 | class SlimContrastiveTrainer: 27 | def __init__(self, args, loader, device, neg_loader=None): 28 | self.model_str = args.model.lower() 29 | self.pic_path = f'./output/{self.model_str}/{args.model_filename}/' 30 | if not os.path.exists(self.pic_path): 31 | self.makedirs = os.makedirs(self.pic_path) 32 | self.teacher_model = args.teacher_model 33 | self.checkpoint_dir = args.pre_train 34 | self.model_filename = args.model_filename 35 | self.model_filepath = f'{self.model_filename}.pth' 36 | self.writer = SummaryWriter(f'log/{self.model_filename}') 37 | 38 | self.start_epoch = -1 39 | self.device = device 40 | self.epochs = args.epochs 41 | self.init_lr = args.lr 42 | self.rgb_range = args.rgb_range 43 | self.scale = args.scale[0] 44 | self.stu_width_mult = args.stu_width_mult 45 | self.batch_size = args.batch_size 46 | self.neg_num = args.neg_num 47 | self.save_results = args.save_results 48 | self.self_ensemble = args.self_ensemble 49 | self.print_every = args.print_every 50 | self.best_psnr = 0 51 | self.best_psnr_epoch = -1 52 | 53 | self.loader = loader 54 | self.mean = [0.404, 0.436, 0.446] 55 | self.std = [0.288, 0.263, 0.275] 56 | 57 | self.build_model(args) 58 | self.upsampler = nn.Upsample(scale_factor=self.scale, mode='bicubic') 59 | self.optimizer = utility.make_optimizer(args, self.model) 60 | 61 | self.t_lambda = args.t_lambda 62 | self.contra_lambda = args.contra_lambda 63 | self.ad_lambda = args.ad_lambda 64 | self.percep_lambda = args.percep_lambda 65 | self.t_detach = args.contrast_t_detach 66 | self.contra_loss = ContrastLoss(args.vgg_weight, args.d_func, self.t_detach) 67 | self.l1_loss = nn.L1Loss() 68 | self.ad_loss = Adversarial(args, 'GAN') 69 | self.percep_loss = PerceptualLoss() 70 | self.t_l_remove = args.t_l_remove 71 | 72 | def train(self): 73 | self.model.train() 74 | 75 | total_iter = (self.start_epoch+1)*len(self.loader.loader_train) 76 | for epoch in range(self.start_epoch + 1, self.epochs): 77 | if epoch >= self.t_l_remove: 78 | self.t_lambda = 0 79 | 80 | starttime = datetime.datetime.now() 81 | 82 | lrate = utility.adjust_learning_rate(self.optimizer, epoch, self.epochs, self.init_lr) 83 | print("[Epoch {}]\tlr:{}\t".format(epoch, lrate)) 84 | psnr, t_psnr = 0.0, 0.0 85 | step = 0 86 | for 
batch, (lr, hr, _,) in enumerate(self.loader.loader_train): 87 | torch.cuda.empty_cache() 88 | step += 1 89 | total_iter += 1 90 | lr = lr.to(self.device) 91 | hr = hr.to(self.device) 92 | 93 | self.optimizer.zero_grad() 94 | teacher_sr = self.model(lr) 95 | 96 | student_sr = self.model(lr, self.stu_width_mult) 97 | l1_loss = self.l1_loss(hr, student_sr) 98 | teacher_l1_loss = self.l1_loss(hr, teacher_sr) 99 | 100 | bic_sample = lr[torch.randperm(self.neg_num), :, :, :] 101 | bic_sample = self.upsampler(bic_sample) 102 | contras_loss = 0.0 103 | 104 | if self.neg_num > 0: 105 | contras_loss = self.contra_loss(teacher_sr, student_sr, bic_sample) 106 | 107 | loss = l1_loss + self.contra_lambda * contras_loss + self.t_lambda * teacher_l1_loss 108 | if self.ad_lambda > 0: 109 | ad_loss = self.ad_loss(student_sr, hr) 110 | loss += self.ad_lambda * ad_loss 111 | self.writer.add_scalar('Train/Ad_loss', ad_loss, total_iter) 112 | if self.percep_lambda > 0: 113 | percep_loss = self.percep_loss(hr, student_sr) 114 | loss += self.percep_lambda * percep_loss 115 | self.writer.add_scalar('Train/Percep_loss', percep_loss, total_iter) 116 | 117 | loss.backward() 118 | self.optimizer.step() 119 | 120 | self.writer.add_scalar('Train/L1_loss', l1_loss, total_iter) 121 | self.writer.add_scalar('Train/Contras_loss', contras_loss, total_iter) 122 | self.writer.add_scalar('Train/Teacher_l1_loss', teacher_l1_loss, total_iter) 123 | self.writer.add_scalar('Train/Total_loss', loss, total_iter) 124 | 125 | student_sr = utility.quantize(student_sr, self.rgb_range) 126 | psnr += utility.calc_psnr(student_sr, hr, self.scale, self.rgb_range) 127 | teacher_sr = utility.quantize(teacher_sr, self.rgb_range) 128 | t_psnr += utility.calc_psnr(teacher_sr, hr, self.scale, self.rgb_range) 129 | if (batch + 1) % self.print_every == 0: 130 | print( 131 | f"[Epoch {epoch}/{self.epochs}] [Batch {batch * self.batch_size}/{len(self.loader.loader_train.dataset)}] " 132 | f"[psnr {psnr / step}]" 133 | f"[t_psnr {t_psnr / step}]" 134 | ) 135 | utility.save_results(f'result_{batch}', hr, self.scale, width=1, rgb_range=self.rgb_range, 136 | postfix='hr', dir=self.pic_path) 137 | utility.save_results(f'result_{batch}', teacher_sr, self.scale, width=1, rgb_range=self.rgb_range, 138 | postfix='t_sr', dir=self.pic_path) 139 | utility.save_results(f'result_{batch}', student_sr, self.scale, width=1, rgb_range=self.rgb_range, 140 | postfix='s_sr', dir=self.pic_path) 141 | 142 | print(f"training PSNR @epoch {epoch}: {psnr / step}") 143 | 144 | test_psnr = self.test(self.stu_width_mult) 145 | if test_psnr > self.best_psnr: 146 | print(f"saving models @epoch {epoch} with psnr: {test_psnr}") 147 | self.best_psnr = test_psnr 148 | self.best_psnr_epoch = epoch 149 | torch.save({ 150 | 'epoch': epoch, 151 | 'model_state_dict': self.model.state_dict(), 152 | 'optimizer_state_dict': self.optimizer.state_dict(), 153 | 'best_psnr': self.best_psnr, 154 | 'best_psnr_epoch': self.best_psnr_epoch, 155 | }, f'{self.checkpoint_dir}{self.model_filepath}') 156 | 157 | endtime = datetime.datetime.now() 158 | cost = (endtime - starttime).seconds 159 | print(f"time of epoch{epoch}: {cost}") 160 | 161 | def test(self, width_mult=1): 162 | self.model.eval() 163 | with torch.no_grad(): 164 | psnr = 0 165 | niqe_score = 0 166 | ssim = 0 167 | t0 = time.time() 168 | 169 | starttime = datetime.datetime.now() 170 | for d in self.loader.loader_test: 171 | for lr, hr, filename in d: 172 | lr = lr.to(self.device) 173 | hr = hr.to(self.device) 174 | 175 | x = [lr] 176 | for tf 
161 |     def test(self, width_mult=1):
162 |         self.model.eval()
163 |         with torch.no_grad():
164 |             psnr = 0
165 |             niqe_score = 0
166 |             ssim = 0
167 |             t0 = time.time()
168 | 
169 |             starttime = datetime.datetime.now()
170 |             for d in self.loader.loader_test:
171 |                 for lr, hr, filename in d:
172 |                     lr = lr.to(self.device)
173 |                     hr = hr.to(self.device)
174 | 
175 |                     x = [lr]
176 |                     for tf in 'v', 'h', 't':  # build the 8 flip/transpose variants for self-ensemble
177 |                         x.extend([utility.transform(_x, tf, self.device) for _x in x])
178 |                     op = ['', 'v', 'h', 'hv', 't', 'tv', 'th', 'thv']
179 | 
180 |                     if self.self_ensemble:
181 |                         res = self.model(lr, width_mult)
182 |                         for i in range(1, len(x)):
183 |                             _x = x[i]
184 |                             _sr = self.model(_x, width_mult)
185 |                             for _op in op[i]:  # undo the transforms before averaging
186 |                                 _sr = utility.transform(_sr, _op, self.device)
187 |                             res = torch.cat((res, _sr), 0)
188 |                         sr = torch.mean(res, 0).unsqueeze(0)
189 |                     else:
190 |                         sr = self.model(lr, width_mult)
191 | 
192 |                     sr = utility.quantize(sr, self.rgb_range)
193 |                     if self.save_results:
194 |                         # if not os.path.exists(f'./output/test/{self.model_str}/{self.model_filename}'):
195 |                         #     self.makedirs = os.makedirs(f'./output/test/{self.model_str}/{self.model_filename}')
196 |                         utility.save_results(str(filename), sr, self.scale, width_mult,
197 |                                              self.rgb_range, 'SR')
198 | 
199 |                     psnr += utility.calc_psnr(sr, hr, self.scale, self.rgb_range, dataset=d)
200 |                     niqe_score += niqe(sr.squeeze(0).permute(1, 2, 0).cpu().numpy())
201 |                     ssim += calc_ssim(sr, hr, self.scale, dataset=d)
202 | 
203 |                 psnr /= len(d)
204 |                 niqe_score /= len(d)
205 |                 ssim /= len(d)
206 |                 print(width_mult, d.dataset.name, psnr, niqe_score, ssim)
207 | 
208 |             endtime = datetime.datetime.now()
209 |             cost = (endtime - starttime).seconds
210 |             t1 = time.time()
211 |             total_time = (t1 - t0)
212 |             print(f"time of test: {total_time}")
213 |             return psnr
214 | 
215 |     def build_model(self, args):
216 |         m = import_module('model.' + self.model_str)
217 |         self.model = getattr(m, self.model_str.upper())(args).to(self.device)
218 |         self.model = nn.DataParallel(self.model, device_ids=range(args.n_GPUs))
219 |         self.load_model()
220 | 
221 |         # test teacher
222 |         # self.test()
223 | 
224 |     def load_model(self):
225 |         checkpoint_dir = self.checkpoint_dir
226 |         print(f"[*] Load model from {checkpoint_dir}")
227 |         if not os.path.exists(checkpoint_dir):
228 |             self.makedirs = os.makedirs(checkpoint_dir)
229 | 
230 |         if not os.listdir(checkpoint_dir):
231 |             print(f"[!] No checkpoint in {checkpoint_dir}")
232 |             return
233 | 
234 |         model = glob(os.path.join(checkpoint_dir, self.model_filepath))
235 | 
236 |         no_student = False
237 |         if not model:
238 |             no_student = True
239 |             print(f"[!] No checkpoint ")
240 |             print("Loading pre-trained teacher model")
241 |             model = glob(self.teacher_model)
242 |             if not model:
243 |                 print(f"[!] No teacher model ")
244 |                 return
245 | 
246 |         model_state_dict = torch.load(model[0])
247 |         if not no_student:
248 |             self.start_epoch = model_state_dict['epoch']
249 |             self.best_psnr = model_state_dict['best_psnr']
250 |             self.best_psnr_epoch = model_state_dict['best_psnr_epoch']
251 | 
252 |         self.model.load_state_dict(model_state_dict['model_state_dict'], False)
--------------------------------------------------------------------------------
/PyTorch version/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/PyTorch version/utils/__init__.py
--------------------------------------------------------------------------------
/PyTorch version/utils/niqe.py:
--------------------------------------------------------------------------------
1 | import math
2 | from os.path import dirname, join
3 | 
4 | import cv2
5 | import numpy as np
6 | import scipy
7 | import scipy.io
8 | import scipy.linalg
9 | import scipy.ndimage
10 | import scipy.special
11 | from PIL import Image
12 | 
13 | gamma_range = np.arange(0.2, 10, 0.001)
14 | a = scipy.special.gamma(2.0/gamma_range)
15 | a *= a
16 | b = scipy.special.gamma(1.0/gamma_range)
17 | c = scipy.special.gamma(3.0/gamma_range)
18 | prec_gammas = a/(b*c)
19 | 
20 | 
21 | def aggd_features(imdata):
22 |     # flatten imdata
23 |     imdata.shape = (len(imdata.flat),)
24 |     imdata2 = imdata*imdata
25 |     left_data = imdata2[imdata < 0]
26 |     right_data = imdata2[imdata >= 0]
27 |     left_mean_sqrt = 0
28 |     right_mean_sqrt = 0
29 |     if len(left_data) > 0:
30 |         left_mean_sqrt = np.sqrt(np.average(left_data))
31 |     if len(right_data) > 0:
32 |         right_mean_sqrt = np.sqrt(np.average(right_data))
33 | 
34 |     if right_mean_sqrt != 0:
35 |         gamma_hat = left_mean_sqrt/right_mean_sqrt
36 |     else:
37 |         gamma_hat = np.inf
38 |     # solve r-hat norm
39 | 
40 |     imdata2_mean = np.mean(imdata2)
41 |     if imdata2_mean != 0:
42 |         r_hat = (np.average(np.abs(imdata))**2) / (np.average(imdata2))
43 |     else:
44 |         r_hat = np.inf
45 |     rhat_norm = r_hat * (((math.pow(gamma_hat, 3) + 1) *
46 |                           (gamma_hat + 1)) / math.pow(math.pow(gamma_hat, 2) + 1, 2))
47 | 
48 |     # solve alpha by guessing values that minimize ro
49 |     pos = np.argmin((prec_gammas - rhat_norm)**2)
50 |     alpha = gamma_range[pos]
51 | 
52 |     gam1 = scipy.special.gamma(1.0/alpha)
53 |     gam2 = scipy.special.gamma(2.0/alpha)
54 |     gam3 = scipy.special.gamma(3.0/alpha)
55 | 
56 |     aggdratio = np.sqrt(gam1) / np.sqrt(gam3)
57 |     bl = aggdratio * left_mean_sqrt
58 |     br = aggdratio * right_mean_sqrt
59 | 
60 |     # mean parameter
61 |     N = (br - bl)*(gam2 / gam1)  # *aggdratio
62 |     return (alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt)
63 | 
64 | 
65 | def ggd_features(imdata):
66 |     nr_gam = 1/prec_gammas
67 |     sigma_sq = np.var(imdata)
68 |     E = np.mean(np.abs(imdata))
69 |     rho = sigma_sq/E**2
70 |     pos = np.argmin(np.abs(nr_gam - rho))
71 |     return gamma_range[pos], sigma_sq
72 | 
73 | 
74 | def paired_product(new_im):
75 |     shift1 = np.roll(new_im.copy(), 1, axis=1)
76 |     shift2 = np.roll(new_im.copy(), 1, axis=0)
77 |     shift3 = np.roll(np.roll(new_im.copy(), 1, axis=0), 1, axis=1)
78 |     shift4 = np.roll(np.roll(new_im.copy(), 1, axis=0), -1, axis=1)
79 | 
80 |     H_img = shift1 * new_im
81 |     V_img = shift2 * new_im
82 |     D1_img = shift3 * new_im
83 |     D2_img = shift4 * new_im
84 | 
85 |     return (H_img, V_img, D1_img, D2_img)
86 | 
87 | 
88 | def gen_gauss_window(lw, sigma):
89 |     sd = np.float32(sigma)
90 |     lw = int(lw)
91 |     weights = [0.0] * (2 * lw + 1)
92 |     weights[lw] = 1.0
93 |     sum = 1.0
94 |     sd *= sd
95 |     for ii in range(1, lw + 1):
96 |         tmp = np.exp(-0.5 * np.float32(ii * ii) / sd)
97 |         weights[lw + ii] = tmp
98 |         weights[lw - ii] = tmp
99 |         sum += 2.0 * tmp
100 |     for ii in range(2 * lw + 1):
101 |         weights[ii] /= sum
102 |     return weights
103 | 
104 | 
105 | def compute_image_mscn_transform(image, C=1, avg_window=None, extend_mode='constant'):
106 |     if avg_window is None:
107 |         avg_window = gen_gauss_window(3, 7.0/6.0)
108 |     assert len(np.shape(image)) == 2
109 |     h, w = np.shape(image)
110 |     mu_image = np.zeros((h, w), dtype=np.float32)
111 |     var_image = np.zeros((h, w), dtype=np.float32)
112 |     image = np.array(image).astype('float32')
113 |     scipy.ndimage.correlate1d(image, avg_window, 0, mu_image, mode=extend_mode)
114 |     scipy.ndimage.correlate1d(mu_image, avg_window, 1,
115 |                               mu_image, mode=extend_mode)
116 |     scipy.ndimage.correlate1d(image**2, avg_window, 0,
117 |                               var_image, mode=extend_mode)
118 |     scipy.ndimage.correlate1d(var_image, avg_window,
119 |                               1, var_image, mode=extend_mode)
120 |     var_image = np.sqrt(np.abs(var_image - mu_image**2))
121 |     return (image - mu_image)/(var_image + C), var_image, mu_image
122 | 
123 | 
124 | def _niqe_extract_subband_feats(mscncoefs):
125 |     # alpha_m, = extract_ggd_features(mscncoefs)
126 |     alpha_m, N, bl, br, lsq, rsq = aggd_features(mscncoefs.copy())
127 |     pps1, pps2, pps3, pps4 = paired_product(mscncoefs)
128 |     alpha1, N1, bl1, br1, lsq1, rsq1 = aggd_features(pps1)
129 |     alpha2, N2, bl2, br2, lsq2, rsq2 = aggd_features(pps2)
130 |     alpha3, N3, bl3, br3, lsq3, rsq3 = aggd_features(pps3)
131 |     alpha4, N4, bl4, br4, lsq4, rsq4 = aggd_features(pps4)
132 |     return np.array([alpha_m, (bl+br)/2.0,
133 |                      alpha1, N1, bl1, br1,  # (V)
134 |                      alpha2, N2, bl2, br2,  # (H)
135 |                      alpha3, N3, bl3, br3,  # (D1)
136 |                      alpha4, N4, bl4, br4,  # (D2)
137 |                      ])
138 | 
139 | 
140 | def get_patches_train_features(img, patch_size, stride=8):
141 |     return _get_patches_generic(img, patch_size, 1, stride)
142 | 
143 | 
144 | def get_patches_test_features(img, patch_size, stride=8):
145 |     return _get_patches_generic(img, patch_size, 0, stride)
146 | 
147 | 
148 | def extract_on_patches(img, patch_size):
149 |     h, w = img.shape
150 |     patch_size = int(patch_size)
151 |     patches = []
152 |     for j in range(0, h-patch_size+1, patch_size):
153 |         for i in range(0, w-patch_size+1, patch_size):
154 |             patch = img[j:j+patch_size, i:i+patch_size]
155 |             patches.append(patch)
156 | 
157 |     patches = np.array(patches)
158 | 
159 |     patch_features = []
160 |     for p in patches:
161 |         patch_features.append(_niqe_extract_subband_feats(p))
162 |     patch_features = np.array(patch_features)
163 | 
164 |     return patch_features
165 | 
166 | 
167 | def _get_patches_generic(img, patch_size, is_train, stride):
168 |     h, w = np.shape(img)
169 |     if h < patch_size or w < patch_size:
170 |         print("Input image is too small")
171 |         exit(0)
172 | 
173 |     # ensure that the patch divides evenly into img
174 |     hoffset = (h % patch_size)
175 |     woffset = (w % patch_size)
176 | 
177 |     if hoffset > 0:
178 |         img = img[:-hoffset, :]
179 |     if woffset > 0:
180 |         img = img[:, :-woffset]
181 | 
182 |     img = img.astype(np.float32)
183 |     # img2 = scipy.misc.imresize(img, 0.5, interp='bicubic', mode='F')
184 |     img2 = cv2.resize(img, (0, 0), fx=0.5, fy=0.5)
185 | 
186 |     mscn1, var, mu = compute_image_mscn_transform(img)
187 |     mscn1 = mscn1.astype(np.float32)
188 | 
189 |     mscn2, _, _ = compute_image_mscn_transform(img2)
190 |     mscn2 = mscn2.astype(np.float32)
191 | 
192 |     feats_lvl1 = extract_on_patches(mscn1, patch_size)
193 |     feats_lvl2 = extract_on_patches(mscn2, patch_size/2)  # second scale: half-resolution image
194 | 
195 |     feats = np.hstack((feats_lvl1, feats_lvl2))  # feats_lvl3))
196 | 
197 |     return feats
198 | 
199 | 
200 | def niqe(inputImgData):
201 | 
202 |     patch_size = 96
203 |     module_path = dirname(__file__)
204 | 
205 |     # TODO: memoize
206 |     params = scipy.io.loadmat(
207 |         join(module_path, 'niqe_image_params.mat'))
208 |     pop_mu = np.ravel(params["pop_mu"])
209 |     pop_cov = params["pop_cov"]
210 | 
211 |     if inputImgData.ndim == 3:
212 |         inputImgData = cv2.cvtColor(inputImgData, cv2.COLOR_BGR2GRAY)
213 |     M, N = inputImgData.shape
214 | 
215 |     # assert C == 1, "niqe called with videos containing %d channels. Please supply only the luminance channel" % (C,)
216 |     assert M > (patch_size*2+1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters"
217 |     assert N > (patch_size*2+1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters"
218 | 
219 |     feats = get_patches_test_features(inputImgData, patch_size)
220 |     sample_mu = np.mean(feats, axis=0)
221 |     sample_cov = np.cov(feats.T)
222 | 
223 |     X = sample_mu - pop_mu
224 |     covmat = ((pop_cov+sample_cov)/2.0)
225 |     pinvmat = scipy.linalg.pinv(covmat)
226 |     niqe_score = np.sqrt(np.dot(np.dot(X, pinvmat), X))
227 | 
228 |     return niqe_score
229 | 
--------------------------------------------------------------------------------
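A minimal usage sketch for the `niqe` score defined above (the import path assumes the working directory is `PyTorch version`, and `example.png` is a placeholder; the input must be larger than 193x193 pixels because features are pooled over 96-pixel patches at two scales):

```python
import cv2
from utils.niqe import niqe

# Any sufficiently large test image; color inputs are converted to grayscale internally.
img = cv2.imread('example.png')  # BGR, HxWx3
score = niqe(img)                # lower NIQE roughly indicates better perceptual quality
print(f'NIQE: {score:.3f}')
```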
/PyTorch version/utils/niqe_image_params.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/PyTorch version/utils/niqe_image_params.mat
--------------------------------------------------------------------------------
/PyTorch version/utils/spatial_trans.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | def spatial_similarity(fm):
5 |     fm = fm.view(fm.size(0), fm.size(1), -1)
6 |     norm_fm = fm / (torch.sqrt(torch.sum(torch.pow(fm, 2), 1)).unsqueeze(1).expand(fm.shape) + 0.0000001)
7 |     s = norm_fm.transpose(1, 2).bmm(norm_fm)
8 |     s = s.unsqueeze(1)
9 |     return s
10 | 
11 | def channel_similarity(fm):
12 |     fm = fm.view(fm.size(0), fm.size(1), -1)
13 |     norm_fm = fm / (torch.sqrt(torch.sum(torch.pow(fm, 2), 2)).unsqueeze(2).expand(fm.shape) + 0.0000001)
14 |     s = norm_fm.bmm(norm_fm.transpose(1, 2))
15 |     s = s.unsqueeze(1)
16 |     return s
17 | 
18 | def batch_similarity(fm):
19 |     fm = fm.view(fm.size(0), -1)
20 |     Q = torch.mm(fm, fm.transpose(0, 1))
21 |     normalized_Q = Q / torch.norm(Q, 2, dim=1).unsqueeze(1).expand(Q.shape)
22 |     return normalized_Q
23 | 
24 | def FSP(fm1, fm2):
25 |     if fm1.size(2) > fm2.size(2):
26 |         fm1 = F.adaptive_avg_pool2d(fm1, (fm2.size(2), fm2.size(3)))
27 | 
28 |     fm1 = fm1.view(fm1.size(0), fm1.size(1), -1)
29 |     fm2 = fm2.view(fm2.size(0), fm2.size(1), -1).transpose(1, 2)
30 | 
31 |     fsp = torch.bmm(fm1, fm2) / fm1.size(2)
32 | 
33 |     return fsp
34 | 
35 | def AT(fm):
36 |     eps = 1e-6
37 |     am = torch.pow(torch.abs(fm), 2)
38 |     am = torch.sum(am, dim=1, keepdim=True)
39 |     norm = torch.norm(am, dim=(2, 3), keepdim=True)
40 |     am = torch.div(am, norm + eps)
41 |     return am
--------------------------------------------------------------------------------
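The helpers above compute the feature-map similarity matrices commonly used by feature-distillation losses. A quick shape check (illustrative sizes only, assuming the working directory is `PyTorch version`):

```python
import torch
from utils.spatial_trans import spatial_similarity, channel_similarity, AT

fm = torch.randn(2, 8, 16, 16)  # batch of 2 feature maps, 8 channels, 16x16 spatial
print(spatial_similarity(fm).shape)  # torch.Size([2, 1, 256, 256]): position-wise cosine similarity
print(channel_similarity(fm).shape)  # torch.Size([2, 1, 8, 8]): channel-wise cosine similarity
print(AT(fm).shape)                  # torch.Size([2, 1, 16, 16]): normalized attention map
```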
/PyTorch version/utils/ssim.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import signal
3 | 
4 | def trans2Y(img):
5 |     img_r = img[:, 0, :, :]
6 |     img_g = img[:, 1, :, :]
7 |     img_b = img[:, 2, :, :]
8 |     img_y = 0.256789 * img_r + 0.504129 * img_g + 0.097906 * img_b + 16
9 |     return img_y
10 | 
11 | def matlab_style_gauss2D(shape=(3, 3), sigma=0.5):
12 |     """
13 |     2D gaussian mask - should give the same result as MATLAB's fspecial('gaussian', [shape], [sigma])
14 |     Acknowledgement : https://stackoverflow.com/questions/17190649/how-to-obtain-a-gaussian-filter-in-python (Author@ali_m)
15 |     """
16 |     m, n = [(ss-1.)/2. for ss in shape]
17 |     y, x = np.ogrid[-m:m+1, -n:n+1]
18 |     h = np.exp(-(x*x + y*y) / (2.*sigma*sigma))
19 |     h[h < np.finfo(h.dtype).eps*h.max()] = 0
20 |     sumh = h.sum()
21 |     if sumh != 0:
22 |         h /= sumh
23 |     return h
24 | 
25 | def calc_ssim(X, Y, scale, dataset=None, sigma=1.5, K1=0.01, K2=0.03, R=255):
26 |     '''
27 |     X : y channel (i.e., luminance) of transformed YCbCr space of X
28 |     Y : y channel (i.e., luminance) of transformed YCbCr space of Y
29 |     Please follow the setting of psnr_ssim.m in EDSR (Enhanced Deep Residual Networks for Single Image Super-Resolution CVPRW 2017).
30 |     Official Link : https://github.com/LimBee/NTIRE2017/tree/db34606c2844e89317aac8728a2de562ef1f8aba
31 |     The authors of EDSR use MATLAB's ssim as the evaluation tool,
32 |     thus this function is the same as ssim.m in MATLAB with C(3) == C(2)/2.
33 |     '''
34 |     gaussian_filter = matlab_style_gauss2D((11, 11), sigma)
35 | 
36 |     X = trans2Y(X).squeeze()
37 |     Y = trans2Y(Y).squeeze()
38 |     X = X.cpu().numpy().astype(np.float64)
39 |     Y = Y.cpu().numpy().astype(np.float64)
40 | 
41 |     shave = scale
42 |     if dataset and not dataset.dataset.benchmark:
43 |         shave = scale + 6
44 |     X = X[shave:-shave, shave:-shave]
45 |     Y = Y[shave:-shave, shave:-shave]
46 | 
47 |     window = gaussian_filter / np.sum(np.sum(gaussian_filter))
48 | 
49 |     window = np.fliplr(window)
50 |     window = np.flipud(window)
51 | 
52 |     ux = signal.convolve2d(X, window, mode='valid', boundary='fill', fillvalue=0)
53 |     uy = signal.convolve2d(Y, window, mode='valid', boundary='fill', fillvalue=0)
54 | 
55 |     uxx = signal.convolve2d(X * X, window, mode='valid', boundary='fill', fillvalue=0)
56 |     uyy = signal.convolve2d(Y * Y, window, mode='valid', boundary='fill', fillvalue=0)
57 |     uxy = signal.convolve2d(X * Y, window, mode='valid', boundary='fill', fillvalue=0)
58 | 
59 |     vx = uxx - ux * ux
60 |     vy = uyy - uy * uy
61 |     vxy = uxy - ux * uy
62 | 
63 |     C1 = (K1 * R) ** 2
64 |     C2 = (K2 * R) ** 2
65 | 
66 |     A1, A2, B1, B2 = ((2 * ux * uy + C1, 2 * vxy + C2, ux ** 2 + uy ** 2 + C1, vx + vy + C2))
67 |     D = B1 * B2
68 |     S = (A1 * A2) / D
69 |     mssim = S.mean()
70 | 
71 |     # window = gaussian_filter
72 |     #
73 |     # ux = signal.convolve2d(X, window, mode='same', boundary='symm')
74 |     # uy = signal.convolve2d(Y, window, mode='same', boundary='symm')
75 |     #
76 |     # uxx = signal.convolve2d(X*X, window, mode='same', boundary='symm')
77 |     # uyy = signal.convolve2d(Y*Y, window, mode='same', boundary='symm')
78 |     # uxy = signal.convolve2d(X*Y, window, mode='same', boundary='symm')
79 |     #
80 |     # vx = uxx - ux * ux
81 |     # vy = uyy - uy * uy
82 |     # vxy = uxy - ux * uy
83 |     #
84 |     # C1 = (K1 * R) ** 2
85 |     # C2 = (K2 * R) ** 2
86 |     #
87 |     # A1, A2, B1, B2 = ((2 * ux * uy + C1, 2 * vxy + C2, ux ** 2 + uy ** 2 + C1, vx + vy + C2))
88 |     # D = B1 * B2
89 |     # S = (A1 * A2) / D
90 |     # mssim = S.mean()
91 | 
92 |     return mssim
--------------------------------------------------------------------------------
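A hedged usage sketch for `calc_ssim` (the tensors and scale are placeholders; the function expects `(N, 3, H, W)` RGB tensors in the 0-255 range, matching what the evaluation loop in `trainer/slim_contrast_trainer.py` passes in with the default `rgb_range` of 255):

```python
import torch
from utils.ssim import calc_ssim

sr = torch.rand(1, 3, 128, 128) * 255           # stand-in for a super-resolved output
hr = (sr + torch.randn_like(sr)).clamp(0, 255)  # near-identical "ground truth"
print(calc_ssim(sr, hr, scale=4))               # close to 1.0; with dataset=None, `scale` border pixels are shaved
```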
/README.md:
--------------------------------------------------------------------------------
1 | # CSD
2 | This is the official implementation (including the [PyTorch version](https://github.com/Booooooooooo/CSD/tree/main/PyTorch%20version) and the [MindSpore version](https://github.com/Booooooooooo/CSD/tree/main/MindSpore%20version)) of [Towards Compact Single Image Super-Resolution via Contrastive Self-distillation, IJCAI 2021](https://arxiv.org/abs/2105.11683).
3 | 
4 | ## Abstract
5 | 
6 | Convolutional neural networks (CNNs) are highly successful for super-resolution (SR) but often require sophisticated architectures with heavy memory cost and computational overhead, which significantly restricts their practical deployment on resource-limited devices. In this paper, we propose a novel contrastive self-distillation (CSD) framework to simultaneously compress and accelerate various off-the-shelf SR models. In particular, a channel-splitting super-resolution network can first be constructed from a target teacher network as a compact student network. Then, we propose a novel contrastive loss to improve the quality of SR images and PSNR/SSIM via explicit knowledge transfer. Extensive experiments demonstrate that the proposed CSD scheme effectively compresses and accelerates several standard SR models such as EDSR, RCAN and CARN.
7 | 
8 | ![model](https://github.com/Booooooooooo/CSD/blob/main/images/model.png)
9 | 
10 | ## Results
11 | 
12 | ![tradeoff](https://github.com/Booooooooooo/CSD/blob/main/images/tradeoff.png)
13 | 
14 | ![table](https://github.com/Booooooooooo/CSD/blob/main/images/table.png)
15 | 
16 | ![visual](https://github.com/Booooooooooo/CSD/blob/main/images/visual.png)
17 | 
18 | ## Citation
19 | 
20 | If you find the code helpful in your research or work, please cite as:
21 | 
22 | ```
23 | @misc{wang2021compact,
24 |     title={Towards Compact Single Image Super-Resolution via Contrastive Self-distillation},
25 |     author={Yanbo Wang and Shaohui Lin and Yanyun Qu and Haiyan Wu and Zhizhong Zhang and Yuan Xie and Angela Yao},
26 |     year={2021},
27 |     eprint={2105.11683},
28 |     archivePrefix={arXiv},
29 |     primaryClass={cs.CV}
30 | }
31 | ```
32 | 
33 | ## Acknowledgements
34 | 
35 | This code is built on [EDSR(PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch). For the training part of the MindSpore version we referred to [DBPN-MindSpore](https://gitee.com/amythist/DBPN-MindSpore/tree/master), [ModelZoo-RCAN](https://gitee.com/mindspore/models/tree/master/research/cv/RCAN) and the official [tutorial](https://www.mindspore.cn/tutorials/zh-CN/master/index.html). We thank the authors for sharing their code.
36 | 
37 | 
--------------------------------------------------------------------------------
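The channel-splitting student described in the README's abstract shares its weights with the teacher: in this repo, the models take a `width_mult` argument in `forward` (see `trainer/slim_contrast_trainer.py`, which calls `self.model(lr)` for the teacher pass and `self.model(lr, self.stu_width_mult)` for the student pass). A minimal sketch of the idea, not the repo's actual `model/edsr.py` implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SplitConv(nn.Module):
    """Toy convolution that can run at a reduced channel width."""
    def __init__(self, in_ch, out_ch, k=3):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k, padding=k // 2)
        self.out_ch = out_ch

    def forward(self, x, width_mult=1.0):
        # The "student" reuses the first width_mult fraction of the
        # "teacher" filters; both paths share the same parameters.
        out_ch = max(1, int(self.out_ch * width_mult))
        weight = self.conv.weight[:out_ch, :x.size(1)]
        bias = self.conv.bias[:out_ch]
        return F.conv2d(x, weight, bias, padding=self.conv.padding)

layer = SplitConv(3, 64)
x = torch.randn(1, 3, 48, 48)
teacher_feat = layer(x)                   # full-width teacher path: 64 channels
student_feat = layer(x, width_mult=0.25)  # channel-split student path: 16 channels
```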
/images/debug.log:
--------------------------------------------------------------------------------
[127 lines of repeated Chromium "Failed to read DnsConfig" warnings omitted]
--------------------------------------------------------------------------------
/images/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/images/model.png
--------------------------------------------------------------------------------
/images/psnr-speed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/images/psnr-speed.png
--------------------------------------------------------------------------------
/images/psnr-tradeoff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/images/psnr-tradeoff.png
--------------------------------------------------------------------------------
/images/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/images/table.png
--------------------------------------------------------------------------------
/images/tradeoff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/images/tradeoff.png
--------------------------------------------------------------------------------
/images/visual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/images/visual.png
--------------------------------------------------------------------------------