├── .idea
│   ├── CSD.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── workspace.xml
├── MindSpore version
│   ├── README.md
│   ├── csd_train.py
│   ├── eval.py
│   ├── export.py
│   ├── src
│   │   ├── args.py
│   │   ├── common.py
│   │   ├── config.py
│   │   ├── contras_loss.py
│   │   ├── data
│   │   │   ├── common.py
│   │   │   ├── div2k.py
│   │   │   └── srdata.py
│   │   ├── edsr_model.py
│   │   ├── edsr_slim.py
│   │   ├── metric.py
│   │   ├── metrics.py
│   │   ├── rcan_model.py
│   │   └── vgg_model.py
│   ├── train.py
│   └── utils
│       ├── __init__.py
│       └── var_init.py
├── PyTorch version
│   ├── README.md
│   ├── data
│   │   ├── __init__.py
│   │   ├── benchmark.py
│   │   ├── bsd500.py
│   │   ├── common.py
│   │   ├── demo.py
│   │   ├── div2k.py
│   │   ├── div2kjpeg.py
│   │   ├── sr291.py
│   │   ├── srdata.py
│   │   └── video.py
│   ├── dataloader.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── adversarial.py
│   │   ├── contrast_loss.py
│   │   ├── discriminator.py
│   │   ├── perceptual.py
│   │   └── vgg.py
│   ├── main.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── carn.py
│   │   ├── common.py
│   │   ├── edsr.py
│   │   └── rcan.py
│   ├── option.py
│   ├── trainer
│   │   └── slim_contrast_trainer.py
│   └── utils
│       ├── __init__.py
│       ├── niqe.py
│       ├── niqe_image_params.mat
│       ├── spatial_trans.py
│       ├── ssim.py
│       └── utility.py
├── README.md
└── images
    ├── debug.log
    ├── model.png
    ├── psnr-speed.png
    ├── psnr-tradeoff.png
    ├── table.png
    ├── tradeoff.png
    └── visual.png
/MindSpore version/README.md:
--------------------------------------------------------------------------------
1 | ## Dependencies
2 |
3 | - Python == 3.7.5
4 |
5 | - MindSpore: https://www.mindspore.cn/install
6 |
7 | - matplotlib
8 |
9 | - imageio
10 |
11 | - tensorboardX
12 |
13 | - opencv-python
14 |
15 | - scipy
16 |
17 | - scikit-image
18 |
19 | ## Train
20 |
21 | ### Prepare data
22 |
23 | We use the DIV2K training set as our training data.
24 |
25 | For instructions on downloading the data, refer to [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch).
26 |
27 | ### Train baseline model
28 |
29 | ```bash
30 | # train teacher model
31 | python -u train.py --dir_data LOCATION_OF_DATA --data_test Set5 --test_every 1 --filename edsr_baseline --lr 0.0001 --epochs 5000
32 | ```
33 |
34 | ```bash
35 | # train student model (width=0.25)
36 | python -u train.py --dir_data LOCATION_OF_DATA --data_test Set5 --test_every 1 --filename edsr_baseline025 --lr 0.0001 --epochs 5000 --n_feats 64
37 | ```
38 |
39 | ### Train CSD
40 |
41 | A VGG-19 network pre-trained on ImageNet is used in our contrastive loss. For copyright reasons the pre-trained VGG weights cannot be shared publicly; place your own checkpoint at `./vgg19_ImageNet.ckpt`, the path loaded in `src/contras_loss.py`.
42 |
43 | ```bash
44 | python -u csd_train.py --dir_data LOCATION_OF_DATA --data_test Set5 --test_every 1 --filename edsr_csd --lr 0.0001 --epochs 5000 --ckpt_path ckpt/TEACHER_MODEL_NAME.ckpt --contra_lambda 200
45 | ```
46 |
47 | ## Test
48 |
49 | ```bash
50 | python eval.py --dir_data LOCATION_OF_DATA --test_only --ext "img" --data_test B100 --ckpt_path ckpt/MODEL_NAME.ckpt --task_id 0 --scale 4
51 | ```
--------------------------------------------------------------------------------
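
For reference, here is a sketch of the directory layout that `src/data/div2k.py` and `src/data/srdata.py` (both included below) expect under `--dir_data`; the DIV2K and benchmark folder names are read from that code, while the exact image filenames follow the usual EDSR convention:

```
LOCATION_OF_DATA/
├── DIV2K/
│   ├── DIV2K_train_HR/           # 0001.png ... 0800.png
│   └── DIV2K_train_LR_bicubic/
│       └── X4/                   # 0001x4.png ...
└── benchmark/
    └── Set5/
        ├── HR/
        └── LR_bicubic/
            └── X4/
```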
/MindSpore version/csd_train.py:
--------------------------------------------------------------------------------
1 | from mindspore.train.callback import ModelCheckpoint, Callback, LossMonitor, TimeMonitor, CheckpointConfig, _InternalCallbackParam, RunContext
2 | import mindspore.nn as nn
3 | from mindspore import ParameterTuple, Tensor
4 | import mindspore.ops as ops
5 | import mindspore.numpy as numpy
6 | from mindspore import load_checkpoint, load_param_into_net, save_checkpoint
7 | from mindspore.common import set_seed
8 | from mindspore import context
9 | from mindspore.context import ParallelMode
10 | import mindspore.dataset as ds
11 |
12 | import os
13 | import time
14 |
15 | from src.metric import PSNR
16 | from src.args import args
17 | from src.data.div2k import DIV2K
18 | from src.data.srdata import SRData
19 | from src.edsr_slim import EDSR
20 | from src.contras_loss import ContrastLoss
21 | # from eval import do_eval
22 |
23 | class NetWithLossCell(nn.Cell):
24 | def __init__(self, net):
25 | super(NetWithLossCell, self).__init__()
26 | self.net = net
27 | self.l1_loss = nn.L1Loss()
28 |
29 | def construct(self, lr, hr, stu_width_mult, tea_width_mult):
30 | sr = self.net(lr, stu_width_mult)
31 | tea_sr = self.net(lr, tea_width_mult)
32 | loss = self.l1_loss(sr, hr) + self.l1_loss(tea_sr, hr)
33 | return loss
34 |
35 | class NetWithCSDLossCell(nn.Cell):
36 | def __init__(self, net, contrast_w=0, neg_num=0):
37 | super(NetWithCSDLossCell, self).__init__()
38 | self.net = net
39 | self.neg_num = neg_num
40 | self.l1_loss = nn.L1Loss()
41 | self.contrast_loss = ContrastLoss()
42 | self.contrast_w = contrast_w
43 |
44 | def construct(self, lr, hr, stu_width_mult, tea_width_mult):
45 | sr = self.net(lr, stu_width_mult)
46 | tea_sr = self.net(lr, tea_width_mult)
47 | loss = self.l1_loss(sr, hr) + self.l1_loss(tea_sr, hr)
48 |
49 | resize = nn.ResizeBilinear()
50 | bic = resize(lr, size=(lr.shape[-2] * 4, lr.shape[-1] * 4))
51 | neg = numpy.flip(bic, 0)
52 | neg = neg[:self.neg_num, :, :, :]
53 | loss += self.contrast_w * self.contrast_loss(tea_sr, sr, neg)
54 | return loss
55 |
56 | class TrainOneStepCell(nn.Cell):
57 | def __init__(self, network, optimizer, sens=1.0):
58 | super(TrainOneStepCell, self).__init__(auto_prefix=False)
59 | self.network = network
60 | self.weights = optimizer.parameters
61 | self.optimizer = optimizer
62 | self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
63 | self.sens = sens
64 |
65 | def set_sens(self, value):
66 | self.sens = value
67 |
68 | def construct(self, lr, hr, width_mult, tea_width_mult):
69 | weights = self.weights
70 | loss = self.network(lr, hr, width_mult, tea_width_mult)
71 |
72 | sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
73 | grads = self.grad(self.network, weights)(lr, hr, width_mult, tea_width_mult, sens)
74 | self.optimizer(grads)
75 | return loss
76 |
77 | def csd_train(train_loader, net, opt):
78 | set_seed(1)
79 | device_id = int(os.getenv('DEVICE_ID', '0'))
80 | print("[CSD] Start Training...")
81 |
82 | step_size = train_loader.get_dataset_size()
83 | lr = []
84 | for i in range(0, opt.epochs):
85 | cur_lr = opt.lr / (2 ** ((i + 1) // 200))
86 | lr.extend([cur_lr] * step_size)
87 | optim = nn.Adam(net.trainable_params(), learning_rate=lr, loss_scale=opt.loss_scale)
88 |
89 | # net_with_loss = NetWithLossCell(net)
90 | net_with_loss = NetWithCSDLossCell(net, args.contra_lambda, args.neg_num)
91 | train_cell = TrainOneStepCell(net_with_loss, optim)
92 | net.set_train()
93 | eval_net = net
94 |
95 | # time_cb = TimeMonitor(data_size=step_size)
96 | # loss_cb = LossMonitor()
97 | # metrics = {
98 | # "psnr": PSNR(rgb_range=opt.rgb_range, shave=True),
99 | # }
100 | # eval_cb = EvalCallBack(eval_net, eval_ds, args.test_every, step_size / opt.batch_size, metrics=metrics,
101 | # rank_id=rank_id)
102 | # cb = [time_cb, loss_cb]
103 | # config_ck = CheckpointConfig(save_checkpoint_steps=opt.ckpt_save_interval * step_size,
104 | # keep_checkpoint_max=opt.ckpt_save_max)
105 | # ckpt_cb = ModelCheckpoint(prefix=opt.filename, directory=opt.ckpt_save_path, config=config_ck)
106 | # if device_id == 0:
107 | # cb += [ckpt_cb]
108 |
109 | for epoch in range(0, opt.epochs):
110 | epoch_loss = 0
111 | for iteration, batch in enumerate(train_loader.create_dict_iterator(), 1):
112 | lr = batch["LR"]
113 | hr = batch["HR"]
114 |
115 | loss = train_cell(lr, hr, Tensor(opt.stu_width_mult), Tensor(1.0))
116 | epoch_loss += loss
117 |
118 | print(f"Epoch[{epoch}] loss: {epoch_loss.asnumpy()}")
119 | # with eval_net.set_train(False):
120 | # do_eval(eval_ds, eval_net)
121 |
122 | if epoch % 10 == 0:
123 | print('===> Saving model...')
124 | save_checkpoint(net, f'./ckpt/{opt.filename}.ckpt')
125 | # cb_params.cur_epoch_num = epoch + 1
126 | # ckpt_cb.step_end(run_context)
127 |
128 |
129 | if __name__ == '__main__':
130 | time_start = time.time()
131 | device_id = int(os.getenv('DEVICE_ID', '0'))
132 | rank_id = int(os.getenv('RANK_ID', '0'))
133 | device_num = int(os.getenv('RANK_SIZE', '1'))
134 | context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False, device_id=device_id)
135 |
136 | train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
137 | train_dataset.set_scale(args.task_id)
138 | print(len(train_dataset))
139 | train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num,
140 | shard_id=rank_id, shuffle=True)
141 | train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
142 |
143 | eval_dataset = SRData(args, name=args.data_test, train=False, benchmark=True)
144 | print(len(eval_dataset))
145 | eval_ds = ds.GeneratorDataset(eval_dataset, ['LR', 'HR'], shuffle=False)
146 | eval_ds = eval_ds.batch(1, drop_remainder=True)
147 |
148 | # net_m = RCAN(args)
149 | net_m = EDSR(args)
150 | print("Init net weights successfully")
151 |
152 | if args.ckpt_path:
153 | param_dict = load_checkpoint(args.ckpt_path)
154 | load_param_into_net(net_m, param_dict)
155 | print("Load net weight successfully")
156 |
157 | csd_train(train_de_dataset, net_m, args)
158 | time_end = time.time()
159 | print('train_time: %f' % (time_end - time_start))
--------------------------------------------------------------------------------
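
A note on the negatives used by `NetWithCSDLossCell` above: the LR batch is bilinearly upscaled 4x, the batch order is reversed with `numpy.flip(bic, 0)` so each sample is contrasted against other images' bicubic upsamplings, and the first `neg_num` entries are kept. A minimal NumPy sketch of that batch manipulation (the nearest-neighbour repeat here is only a stand-in for `nn.ResizeBilinear`):

```python
import numpy as np

batch, c, h, w = 16, 3, 24, 24
lr = np.random.rand(batch, c, h, w).astype(np.float32)

# stand-in for the bilinear 4x upscaling done by nn.ResizeBilinear
bic = lr.repeat(4, axis=2).repeat(4, axis=3)   # [16, 3, 96, 96]

# reverse the batch so each image is paired with other images as negatives
neg = np.flip(bic, axis=0)[:10]                # keep neg_num = 10 negatives

print(bic.shape, neg.shape)                    # (16, 3, 96, 96) (10, 3, 96, 96)
```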
/MindSpore version/eval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Huawei Technologies Co., Ltd
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """eval script"""
16 | import os
17 | import time
18 | import numpy as np
19 | import mindspore.dataset as ds
20 | from mindspore import Tensor, context
21 | from mindspore.common import dtype as mstype
22 | from mindspore.train.serialization import load_checkpoint, load_param_into_net
23 | import mindspore.nn as nn
24 | from mindspore.train.model import Model
25 |
26 | from src.args import args
27 | import src.rcan_model as rcan
28 | # from src.edsr_model import EDSR
29 | from src.edsr_slim import EDSR
30 | from src.data.srdata import SRData
31 | from src.metrics import calc_psnr, quantize, calc_ssim
32 | from src.data.div2k import DIV2K
33 |
34 | # device_id = int(os.getenv('DEVICE_ID', '0'))
35 | # context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=device_id, save_graphs=False)
36 | # context.set_context(max_call_depth=10000)
37 | def eval_net(width_mult=1.0):
38 | """eval"""
39 | if args.epochs == 0:
40 | args.epochs = 1e8
41 | for arg in vars(args):
42 | if vars(args)[arg] == 'True':
43 | vars(args)[arg] = True
44 | elif vars(args)[arg] == 'False':
45 | vars(args)[arg] = False
46 | if args.data_test[0] == 'DIV2K':
47 | train_dataset = DIV2K(args, name=args.data_test, train=False, benchmark=False)
48 | else:
49 | train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
50 | train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False)
51 | train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
52 | train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
53 | #net_m = rcan.RCAN(args)
54 | net_m = EDSR(args)
55 | net_m.set_train(False)
56 | if args.ckpt_path:
57 | print(f"Load from {args.ckpt_path}")
58 | param_dict = load_checkpoint(args.ckpt_path)
59 | load_param_into_net(net_m, param_dict)
60 |
61 | # opt = nn.Adam(net_m.trainable_params(), learning_rate=0.0001, loss_scale=args.loss_scale)
62 | # model = Model(net_m, nn.L1Loss(), opt)
63 |
64 | print('load mindspore net successfully.')
65 | num_imgs = train_de_dataset.get_dataset_size()
66 | psnrs = np.zeros((num_imgs, 1))
67 | ssims = np.zeros((num_imgs, 1))
68 | for batch_idx, imgs in enumerate(train_loader):
69 | lr = imgs['LR']
70 | hr = imgs['HR']
71 | lr = Tensor(lr, mstype.float32)
72 | pred = net_m(lr, Tensor(width_mult))
73 | pred_np = pred.asnumpy()
74 | pred_np = quantize(pred_np, 255)
75 | psnr = calc_psnr(pred_np, hr, args.scale[0], 255.0)
76 | pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0)
77 | hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0)
78 | ssim = calc_ssim(pred_np, hr, args.scale[0])
79 | print("current psnr: ", psnr)
80 | print("current ssim: ", ssim)
81 | psnrs[batch_idx, 0] = psnr
82 | ssims[batch_idx, 0] = ssim
83 | print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
84 | print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale[0], ssims.mean(axis=0)[0]))
85 |
86 | def do_eval(eval_ds, eval_net, width_mult=1.0):
87 | train_loader = eval_ds.create_dict_iterator(output_numpy=True)
88 | num_imgs = eval_ds.get_dataset_size()
89 | psnrs = np.zeros((num_imgs, 1))
90 | ssims = np.zeros((num_imgs, 1))
91 | for batch_idx, imgs in enumerate(train_loader):
92 | lr = imgs['LR']
93 | hr = imgs['HR']
94 | lr = Tensor(lr, mstype.float32)
95 | pred = eval_net(lr, Tensor(width_mult))
96 | pred_np = pred.asnumpy()
97 | pred_np = quantize(pred_np, 255)
98 | psnr = calc_psnr(pred_np, hr, args.scale[0], 255.0)
99 | pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0)
100 | hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0)
101 | ssim = calc_ssim(pred_np, hr, args.scale[0])
102 | # print("current psnr: ", psnr)
103 | # print("current ssim: ", ssim)
104 | psnrs[batch_idx, 0] = psnr
105 | ssims[batch_idx, 0] = ssim
106 | print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
107 | print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale[0], ssims.mean(axis=0)[0]))
108 | return np.mean(psnrs)
109 |
110 | if __name__ == '__main__':
111 | device_id = int(os.getenv('DEVICE_ID', '0'))
112 | # context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=device_id, save_graphs=False)
113 | context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=device_id, save_graphs=False)
114 | context.set_context(max_call_depth=10000)
115 | time_start = time.time()
116 | print("Start eval function!")
117 | print("Eval 1.0 Teacher")
118 | eval_net()
119 | time_end = time.time()
120 | print('eval_time: %f' % (time_end - time_start))
121 |
122 | print("Eval Student")
123 | time_start = time.time()
124 | eval_net(args.stu_width_mult)
125 | time_end = time.time()
126 | print('eval_time: %f' % (time_end - time_start))
--------------------------------------------------------------------------------
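
`src/metrics.py` is not shown in this listing; as a rough sketch only, an EDSR-style `calc_psnr` shaves a `scale`-pixel border before computing PSNR, along these lines (the repo's actual implementation may differ, e.g. by converting benchmark images to the Y channel first):

```python
import numpy as np

def psnr_sketch(sr, hr, scale, rgb_range=255.0):
    """Hedged EDSR-style PSNR: normalize, shave a `scale`-pixel border, -10*log10(MSE)."""
    diff = (sr.astype(np.float64) - hr.astype(np.float64)) / rgb_range
    valid = diff[..., scale:-scale, scale:-scale]   # crop the border ("shave")
    return -10.0 * np.log10(np.mean(valid ** 2))
```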
/MindSpore version/export.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Huawei Technologies Co., Ltd
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """export net together with checkpoint into air/mindir/onnx models"""
16 | import os
17 | import argparse
18 | import numpy as np
19 | from src.args import args as arg
20 | from src.rcan_model import RCAN
21 | from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
22 |
23 |
24 | parser = argparse.ArgumentParser(description='rcan export')
25 | parser.add_argument("--batch_size", type=int, default=1, help="batch size")
26 | parser.add_argument("--ckpt_path", type=str, required=True, help="path of checkpoint file")
27 | parser.add_argument("--file_name", type=str, default="rcan", help="output file name.")
28 | parser.add_argument("--file_format", type=str, default="MINDIR", choices=['MINDIR', 'AIR', 'ONNX'], help="file format")
29 | args_1 = parser.parse_args()
30 |
31 |
32 | def run_export(args):
33 | """ export """
34 | device_id = int(os.getenv('DEVICE_ID', '0'))
35 | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
36 | net = RCAN(arg)
37 | param_dict = load_checkpoint(args.ckpt_path)
38 | load_param_into_net(net, param_dict)
39 | net.set_train(False)
40 | print('load mindspore net and checkpoint successfully.')
41 | inputs = Tensor(np.zeros([args.batch_size, 3, 678, 1020], np.float32))
42 | export(net, inputs, file_name=args.file_name, file_format=args.file_format)
43 | print('export successfully!')
44 |
45 |
46 | if __name__ == "__main__":
47 | run_export(args_1)
48 |
--------------------------------------------------------------------------------
/MindSpore version/src/args.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Huawei Technologies Co., Ltd
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """args"""
16 | import argparse
17 | import ast
18 |
19 | parser = argparse.ArgumentParser(description='RCAN')
20 |
21 | # Hardware specifications
22 | parser.add_argument('--seed', type=int, default=1,
23 | help='random seed')
24 |
25 | # Data specifications
26 | parser.add_argument('--dir_data', type=str, default='/cache/data/',
27 | help='dataset directory')
28 | parser.add_argument('--data_train', type=str, default='DIV2K',
29 | help='train dataset name')
30 | parser.add_argument('--data_test', type=str, default='DIV2K',
31 | help='test dataset name')
32 | parser.add_argument('--data_range', type=str, default='1-800/801-810',
33 | help='train/test data range')
34 | parser.add_argument('--ext', type=str, default='sep',
35 | help='dataset file extension')
36 | parser.add_argument('--scale', type=str, default='4',
37 | help='super resolution scale')
38 | parser.add_argument('--patch_size', type=int, default=48,
39 | help='output patch size')
40 | parser.add_argument('--rgb_range', type=int, default=255,
41 | help='maximum value of RGB')
42 | parser.add_argument('--n_colors', type=int, default=3,
43 | help='number of color channels to use')
44 | parser.add_argument('--no_augment', action='store_true',
45 | help='do not use data augmentation')
46 |
47 | # Model specifications
48 | parser.add_argument('--model', default='RCAN',
49 | help='model name')
50 | parser.add_argument('--act', type=str, default='relu',
51 | help='activation function')
52 | parser.add_argument('--n_resblocks', type=int, default=20,
53 | help='number of residual blocks')
54 | parser.add_argument('--n_feats', type=int, default=64,
55 | help='number of feature maps')
56 | parser.add_argument('--res_scale', type=float, default=1,
57 | help='residual scaling')
58 |
59 |
60 | # Option for Residual channel attention network (RCAN)
61 | parser.add_argument('--n_resgroups', type=int, default=10,
62 | help='number of residual groups')
63 | parser.add_argument('--reduction', type=int, default=16,
64 | help='number of feature maps reduction')
65 |
66 | # Training specifications
67 | parser.add_argument('--test_every', type=int, default=4000,
68 | help='do test per every N batches')
69 | parser.add_argument('--epochs', type=int, default=1000,
70 | help='number of epochs to train')
71 | parser.add_argument('--batch_size', type=int, default=16,
72 | help='input batch size for training')
73 | parser.add_argument('--test_only', action='store_true',
74 | help='set this option to test the model')
75 |
76 |
77 | # Optimization specifications
78 | parser.add_argument('--lr', type=float, default=1e-5,
79 | help='learning rate')
80 | parser.add_argument('--loss_scale', type=float, default=1024.0,
81 | help='scaling factor for optim')
82 | parser.add_argument('--init_loss_scale', type=float, default=65536.,
83 | help='scaling factor')
84 | parser.add_argument('--decay', type=str, default='200',
85 | help='learning rate decay type')
86 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
87 | help='ADAM beta')
88 | parser.add_argument('--epsilon', type=float, default=1e-8,
89 | help='ADAM epsilon for numerical stability')
90 | parser.add_argument('--weight_decay', type=float, default=0,
91 | help='weight decay')
92 | parser.add_argument('--gclip', type=float, default=0,
93 | help='gradient clipping threshold (0 = no clipping)')
94 |
95 | # ckpt specifications
96 | parser.add_argument('--ckpt_save_path', type=str, default='./ckpt/',
97 | help='path to save ckpt')
98 | parser.add_argument('--ckpt_save_interval', type=int, default=10,
99 | help='save ckpt frequency, unit is epoch')
100 | parser.add_argument('--ckpt_save_max', type=int, default=100,
101 | help='max number of saved ckpt')
102 | parser.add_argument('--ckpt_path', type=str, default='',
103 | help='path of saved ckpt')
104 | parser.add_argument('--filename', type=str, default='')
105 |
106 | # Task
107 | parser.add_argument('--task_id', type=int, default=0)
108 |
109 | # CSD
110 | parser.add_argument('--stu_width_mult', type=float, default=0.25)
111 | parser.add_argument('--neg_num', type=int, default=10)
112 | parser.add_argument('--contra_lambda', type=float, default=1, help='weight of contra_loss')
113 |
114 | # ModelArts
115 | parser.add_argument('--modelArts_mode', type=ast.literal_eval, default=False,
116 | help='train on modelarts or not, default is False')
117 | parser.add_argument('--data_url', type=str, default='', help='the directory path of saved file')
118 |
119 |
120 | args, unparsed = parser.parse_known_args()
121 |
122 | args.scale = [int(x) for x in args.scale.split("+")]
123 | args.data_train = args.data_train.split('+')
124 | args.data_test = args.data_test.split('+')
125 |
126 | if args.epochs == 0:
127 | args.epochs = 1e8
128 |
129 | for arg in vars(args):
130 | if vars(args)[arg] == 'True':
131 | vars(args)[arg] = True
132 | elif vars(args)[arg] == 'False':
133 | vars(args)[arg] = False
134 |
--------------------------------------------------------------------------------
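
The post-processing at the bottom of `args.py` turns the '+'-separated CLI strings into lists; a quick illustration of the split logic with hypothetical values:

```python
# Hypothetical invocation: python train.py --scale 2+4 --data_test Set5+B100
scale = "2+4"
print([int(x) for x in scale.split("+")])   # [2, 4]
print("Set5+B100".split("+"))               # ['Set5', 'B100']
```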
/MindSpore version/src/common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import mindspore
4 | import mindspore.nn as nn
5 | import mindspore.ops as ops
6 | from mindspore import Tensor
7 | import mindspore.numpy as numpy
8 | from mindspore.common.initializer import TruncatedNormal
9 |
10 |
11 | def weight_variable():
12 | """weight initial"""
13 | return TruncatedNormal(0.02)
14 |
15 | def conv(in_channels, out_channels, kernel_size, stride=1, padding=0,
16 | pad_mode='pad', has_bias=True):
17 | """weight initial for conv layer"""
18 | weight = weight_variable()
19 | return nn.Conv2d(in_channels, out_channels,
20 | kernel_size=kernel_size, stride=stride, padding=padding,
21 | weight_init=weight, has_bias=has_bias, pad_mode=pad_mode)
22 |
23 | def pixel_shuffle(tensor, scale_factor):
24 | """
25 | Implementation of pixel shuffle using reshape/transpose ops
26 |
27 | Parameters:
28 | -----------
29 | tensor: input tensor, shape is [N, C, H, W]
30 | scale_factor: scale factor to up-sample tensor
31 |
32 | Returns:
33 | --------
34 | tensor: tensor after pixel shuffle, shape is [N, C/(r*r), r*H, r*W],
35 | where r refers to scale factor
36 | """
37 | num, ch, height, width = tensor.shape
38 | # assert ch % (scale_factor * scale_factor) == 0
39 |
40 | new_ch = ch // (scale_factor * scale_factor)
41 | new_height = height * scale_factor
42 | new_width = width * scale_factor
43 |
44 | reshape = ops.Reshape()
45 | tensor = reshape(tensor, (num, new_ch, scale_factor, scale_factor, height, width))
46 | # new axis: [num, new_ch, height, scale_factor, width, scale_factor]
47 | transpose = ops.Transpose()
48 | tensor = transpose(tensor, (0, 1, 4, 2, 5, 3))
49 | tensor = reshape(tensor, (num, new_ch, new_height, new_width))
50 | return tensor
51 |
52 | def MeanShift(x, rgb_range,
53 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
54 |
55 | # super(MeanShift, self).__init__(3, 3, kernel_size=1)
56 |
57 | std = Tensor(rgb_std)
58 | conv2d = ops.Conv2D(out_channel=3, kernel_size=1)
59 | biasadd = ops.BiasAdd()
60 | weight = numpy.eye(3, 3).view((3, 3, 1, 1)) / std.view(3, 1, 1, 1)
61 | bias = sign * rgb_range * Tensor(rgb_mean) / std
62 | weight = weight.astype(numpy.float32)
63 | bias = bias.astype(numpy.float32)
64 |
65 | x = conv2d(x, weight)
66 | x = biasadd(x, bias)
67 | return x
68 |
69 | class ResidualBlock(nn.Cell):
70 | def __init__(self, n_feats, kernel_size, act, res_scale):
71 | super(ResidualBlock, self).__init__()
72 | self.n_feats = n_feats
73 | self.res_scale = res_scale
74 | self.kernel_size = kernel_size
75 |
76 | # self.conv1 = nn.Conv2d(in_channels=n_feats, out_channels=n_feats, kernel_size=kernel_size, pad_mode='pad', padding=1, has_bias=True)
77 | self.conv1 = conv(n_feats, n_feats, kernel_size, padding=1)
78 | self.act = act
79 | # self.conv2 = nn.Conv2d(in_channels=n_feats, out_channels=n_feats, kernel_size=kernel_size, pad_mode='pad', padding=1, has_bias=True)
80 | self.conv2 = conv(n_feats, n_feats, kernel_size, padding=1)
81 |
82 | def construct(self, x, width_mult=1):
83 | # round = ops.Round()
84 | width = int(self.n_feats * width_mult)
85 | # width = round(self.n_feats * width_mult)
86 | conv2d = ops.Conv2D(out_channel=width, kernel_size=self.kernel_size, mode=1, pad_mode='pad', pad=1)
87 | biasadd = ops.BiasAdd()
88 | weight = self.conv1.weight[:width, :width, :, :]
89 | bias = self.conv1.bias[:width]
90 | residual = conv2d(x, weight)
91 | if bias is not None:
92 | residual = biasadd(residual, bias)
93 | residual = self.act(residual)
94 | weight = self.conv2.weight[:width, :width, :, :]
95 | bias = self.conv2.bias[:width]
96 | residual = conv2d(residual, weight)
97 | if bias is not None:
98 | residual = biasadd(residual, bias)
99 |
100 | return x + residual * self.res_scale
101 |
102 | class Upsampler(nn.SequentialCell):
103 | def __init__(self, scale_factor, nf):
104 | super(Upsampler, self).__init__()
105 | block = []
106 | self.nf = nf
107 | self.scale = scale_factor
108 |
109 | if scale_factor == 3:
110 | block += [
111 | # nn.Conv2d(in_channels=nf, out_channels=nf * 9, kernel_size=3, pad_mode='pad', padding=1, has_bias=True)
112 | conv(nf, nf*9, 3, padding=1)
113 | ]
114 | # self.pixel_shuffle = nn.PixelShuffle(3)
115 | # pixel_shuffle function
116 | else:
117 | self.block_num = scale_factor // 2
118 | # self.pixel_shuffle = nn.PixelShuffle(2)
119 | #self.act = nn.ReLU()
120 |
121 | for _ in range(self.block_num):
122 | block += [
123 | # nn.Conv2d(in_channels=nf, out_channels=nf * 2 ** 2, kernel_size=3, pad_mode='pad', padding=1, has_bias=True)
124 | conv(nf, nf*2**2, 3, padding=1)
125 | ]
126 | self.blocks = nn.SequentialCell(block)
127 |
128 | def construct(self, x, width_mult=1):
129 | res = x
130 | nf = self.nf
131 | if self.scale == 3:
132 | width = int(width_mult * nf)
133 | width9 = width * 9
134 | conv2d = ops.Conv2D(out_channel=width9, kernel_size=3, mode=1, pad_mode='pad', pad=1)
135 | biasadd = ops.BiasAdd()
136 | for block in self.blocks:
137 | weight = block.weight[:width9, :width, :, :]
138 | bias = block.bias[:width9]
139 | res = conv2d(res, weight)
140 | if bias is not None:
141 | res = biasadd(res, bias)
142 | res = pixel_shuffle(res, self.scale)
143 | else:
144 | width = int(width_mult * nf)
145 | width4 = width * 4
146 | conv2d = ops.Conv2D(out_channel=width4, kernel_size=3, mode=1, pad_mode='pad', pad=1)
147 | biasadd = ops.BiasAdd()
148 | for block in self.blocks:
149 | weight = block.weight[:width4, :width, :, :]
150 | bias = block.bias[:width4]
151 | # print(res.shape)
152 | res = conv2d(res, weight)
153 | # print(res.shape)
154 | if bias is not None:
155 | res = biasadd(res, bias)
156 | res = pixel_shuffle(res, 2)
157 | #res = self.act(res)
158 |
159 | return res
160 |
161 | def SlimModule(input, module, width_mult):
162 | weight = module.weight
163 | out_ch, in_ch = weight.shape[:2]
164 | out_ch = int(out_ch * width_mult)
165 | in_ch = int(in_ch * width_mult)
166 | weight = weight[:out_ch, :in_ch, :, :]
167 | bias = module.bias
168 |
169 | conv2d = ops.Conv2D(out_channel=out_ch, kernel_size=module.kernel_size, mode=1, pad_mode=module.pad_mode, pad=module.padding)
170 | biasadd = ops.BiasAdd()
171 | out = conv2d(input, weight)
172 | if bias is not None:
173 | bias = module.bias[:out_ch]
174 | out = biasadd(out, bias)
175 | return out
176 |
--------------------------------------------------------------------------------
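
The `pixel_shuffle` function above is the standard reshape/transpose formulation; the same trick in plain NumPy, useful for checking shapes:

```python
import numpy as np

def pixel_shuffle_np(x, r):
    n, c, h, w = x.shape
    x = x.reshape(n, c // (r * r), r, r, h, w)
    x = x.transpose(0, 1, 4, 2, 5, 3)   # interleave the r*r sub-pixel grids
    return x.reshape(n, c // (r * r), h * r, w * r)

x = np.arange(2 * 16 * 3 * 3, dtype=np.float32).reshape(2, 16, 3, 3)
print(pixel_shuffle_np(x, 2).shape)     # (2, 4, 6, 6)
```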
/MindSpore version/src/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Huawei Technologies Co., Ltd
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """
16 | network config setting, will be used in train.py and eval.py
17 | """
18 | from easydict import EasyDict as edict
19 |
20 | # config for vgg16, cifar10
21 | cifar_cfg = edict({
22 | "num_classes": 10,
23 | "lr": 0.01,
24 | "lr_init": 0.01,
25 | "lr_max": 0.1,
26 | "lr_epochs": '30,60,90,120',
27 | "lr_scheduler": "step",
28 | "warmup_epochs": 5,
29 | "batch_size": 64,
30 | "max_epoch": 70,
31 | "momentum": 0.9,
32 | "weight_decay": 5e-4,
33 | "loss_scale": 1.0,
34 | "label_smooth": 0,
35 | "label_smooth_factor": 0,
36 | "buffer_size": 10,
37 | "image_size": '224,224',
38 | "pad_mode": 'same',
39 | "padding": 0,
40 | "has_bias": False,
41 | "batch_norm": True,
42 | "keep_checkpoint_max": 10,
43 | "initialize_mode": "XavierUniform",
44 | "has_dropout": False
45 | })
46 |
47 | # config for vgg16, imagenet2012
48 | imagenet_cfg = edict({
49 | "num_classes": 1000,
50 | "lr": 0.01,
51 | "lr_init": 0.01,
52 | "lr_max": 0.1,
53 | "lr_epochs": '30,60,90,120',
54 | "lr_scheduler": 'cosine_annealing',
55 | "warmup_epochs": 0,
56 | "batch_size": 32,
57 | "max_epoch": 150,
58 | "momentum": 0.9,
59 | "weight_decay": 1e-4,
60 | "loss_scale": 1024,
61 | "label_smooth": 1,
62 | "label_smooth_factor": 0.1,
63 | "buffer_size": 10,
64 | "image_size": '224,224',
65 | "pad_mode": 'pad',
66 | "padding": 1,
67 | # "has_bias": True,
68 | "has_bias": False,
69 | "batch_norm": False,
70 | "keep_checkpoint_max": 10,
71 | "initialize_mode": "KaimingNormal",
72 | "has_dropout": True
73 | })
--------------------------------------------------------------------------------
/MindSpore version/src/contras_loss.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import mindspore.ops as ops
4 | from mindspore.nn.loss.loss import _Loss
5 | # import mindspore_hub as mshub
6 | import mindspore
7 | from mindspore import context, Tensor, nn
8 | from mindspore.train.model import Model
9 | from mindspore.common import dtype as mstype
10 | from mindspore.dataset.transforms import py_transforms
11 | from mindspore import load_checkpoint, load_param_into_net
12 | from mindspore.ops.functional import stop_gradient
13 | import mindspore.numpy as np
14 |
15 | from src.config import imagenet_cfg
16 | from src.vgg_model import Vgg
17 |
18 | # context.set_context(mode=context.GRAPH_MODE,
19 | # device_target="Ascend",
20 | # device_id=0)
21 | # context.set_context(device_id=0)
22 |
23 | cfg = {
24 | '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
25 | '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
26 | '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
27 | '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
28 | }
29 |
30 | class Vgg19(nn.Cell):
31 | def __init__(self, requires_grad=False):
32 | super(Vgg19, self).__init__()
33 |
34 | ## load VGG-19
35 | vgg = Vgg(cfg['19'], phase="test", args=imagenet_cfg)
36 | # model = os.path.join(opt.data_url, 'vgg19_ImageNet.ckpt')
37 | model = os.path.join('./', 'vgg19_ImageNet.ckpt')
38 | print(model)
39 | param_dict = load_checkpoint(model)
40 | # print(param_dict)
41 | load_param_into_net(vgg, param_dict)
42 | vgg.set_train(False)
43 |
44 | vgg_pretrained_features = vgg.layers
45 | self.slice1 = nn.SequentialCell()
46 | self.slice2 = nn.SequentialCell()
47 | self.slice3 = nn.SequentialCell()
48 | self.slice4 = nn.SequentialCell()
49 | self.slice5 = nn.SequentialCell()
50 | for x in range(2):
51 | self.slice1.append(vgg_pretrained_features[x])
52 | for x in range(2, 7):
53 | self.slice2.append(vgg_pretrained_features[x])
54 | for x in range(7, 12):
55 | self.slice3.append(vgg_pretrained_features[x])
56 | for x in range(12, 21):
57 | self.slice4.append(vgg_pretrained_features[x])
58 | for x in range(21, 30):
59 | self.slice5.append(vgg_pretrained_features[x])
60 | if not requires_grad:
61 | for param in self.get_parameters():
62 | param.requires_grad = False
63 |
64 | def construct(self, x):
65 | h_relu1 = self.slice1(x)
66 | h_relu2 = self.slice2(h_relu1)
67 | h_relu3 = self.slice3(h_relu2)
68 | h_relu4 = self.slice4(h_relu3)
69 | h_relu5 = self.slice5(h_relu4)
70 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
71 | return out
72 |
73 |
74 | class ContrastLoss(_Loss):
75 | def __init__(self):
76 | super(ContrastLoss, self).__init__()
77 | self.vgg = Vgg19()
78 | self.l1 = nn.L1Loss()
79 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
80 |
81 | def construct(self, teacher, student, neg):
82 | expand_dims = ops.ExpandDims()  # unsqueeze operator
83 | teacher_vgg, student_vgg, neg_vgg = self.vgg(teacher), self.vgg(student), self.vgg(neg)
84 |
85 | loss = 0
86 | for i in range(len(teacher_vgg)):
87 | neg_i = expand_dims(neg_vgg[i], 0) # [8, n_feats, w, h]
88 | # neg_i = neg_i.repeat(student_vgg[i].shape[0], axis=0)  # TODO: Tensor.repeat requires MindSpore 1.3+
89 | neg_i = np.repeat(neg_i, student_vgg[i].shape[0], axis=0) # [16, 8, n_feats, w, h]
90 | neg_i = neg_i.transpose((1, 0, 2, 3, 4)) # [8, 16, n_feats, w, h]
91 |
92 | d_ts = self.l1(stop_gradient(teacher_vgg[i]), student_vgg[i])
93 | # d_sn = (stop_gradient(neg_i) - student_vgg[i]).abs().sum(axis=0).mean()  # TODO: Tensor.sum requires MindSpore 1.3+
94 | d_sn = (stop_gradient(neg_i) - student_vgg[i]).abs() # [8, 16, n_feats, w, h]
95 | # print(d_sn.shape)
96 | reduceSum = ops.ReduceSum()
97 | d_sn = reduceSum(d_sn, 0).mean()
98 | # print(d_sn)
99 |
100 | contrastive = d_ts / (d_sn + 1e-7)
101 | loss += self.weights[i] * contrastive
102 |
103 | return self.get_loss(loss)
104 |
--------------------------------------------------------------------------------
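
Reading off `ContrastLoss.construct` above, the loss it computes is the following, where phi_i is the i-th VGG-19 feature slice, T the teacher output, S the student output, and N_j the j-th of K negatives; the L1 norms are element-wise means, and gradients are stopped through the teacher and negative branches:

```latex
\mathcal{L}_{\mathrm{con}}
  = \sum_{i=1}^{5} w_i \,
    \frac{\lVert \phi_i(T) - \phi_i(S) \rVert_1}
         {\sum_{j=1}^{K} \lVert \phi_i(N_j) - \phi_i(S) \rVert_1 + 10^{-7}},
\qquad
w = \left(\tfrac{1}{32}, \tfrac{1}{16}, \tfrac{1}{8}, \tfrac{1}{4}, 1\right)
```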
/MindSpore version/src/data/common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Huawei Technologies Co., Ltd
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """common"""
16 | import random
17 | import os
18 | import numpy as np
19 |
20 |
21 | def get_patch(*args, patch_size=96, scale=2, input_large=False):
22 | """get_patch"""
23 | ih, iw = args[0].shape[:2]
24 |
25 | tp = patch_size
26 | ip = tp // scale
27 |
28 | ix = random.randrange(0, iw - ip + 1)
29 | iy = random.randrange(0, ih - ip + 1)
30 |
31 | if not input_large:
32 | tx, ty = scale * ix, scale * iy
33 | else:
34 | tx, ty = ix, iy
35 |
36 | ret = [args[0][iy:iy + ip, ix:ix + ip, :], *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]]
37 |
38 | return ret
39 |
40 |
41 | def set_channel(*args, n_channels=3):
42 | """set_channel"""
43 | def _set_channel(img):
44 | if img.ndim == 2:
45 | img = np.expand_dims(img, axis=2)
46 |
47 | c = img.shape[2]
48 | if n_channels == 3 and c == 1:
49 | img = np.concatenate([img] * n_channels, 2)
50 |
51 | return img[:, :, :n_channels]
52 |
53 | return [_set_channel(a) for a in args]
54 |
55 |
56 | def np2Tensor(*args, rgb_range=255):
57 | """ np2Tensor"""
58 | def _np2Tensor(img):
59 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
60 | input_data = np_transpose.astype(np.float32)
61 | output = input_data * (rgb_range / 255)
62 | return output
63 | return [_np2Tensor(a) for a in args]
64 |
65 |
66 | def augment(*args, hflip=True, rot=True):
67 | """augment("""
68 | hflip = hflip and random.random() < 0.5
69 | vflip = rot and random.random() < 0.5
70 | rot90 = rot and random.random() < 0.5
71 |
72 | def _augment(img):
73 | """augment"""
74 | if hflip:
75 | img = img[:, ::-1, :]
76 | if vflip:
77 | img = img[::-1, :, :]
78 | if rot90:
79 | img = img.transpose(1, 0, 2)
80 | return img
81 |
82 | return [_augment(a) for a in args]
83 |
84 |
85 | def search(root, target="JPEG"):
86 | """search"""
87 | item_list = []
88 | items = os.listdir(root)
89 | for item in items:
90 | path = os.path.join(root, item)
91 | if os.path.isdir(path):
92 | item_list.extend(search(path, target))
93 | elif path.split('/')[-1].startswith(target):
94 | item_list.append(path)
95 | elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
96 | item_list.append(path)
97 | return item_list
98 |
--------------------------------------------------------------------------------
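
`get_patch` samples an aligned LR/HR patch pair: an `ip = patch_size // scale` crop from the LR image at (ix, iy) and the corresponding `patch_size` crop from the HR image at (scale*ix, scale*iy). A quick shape check (run from the `MindSpore version` directory so `src` is importable):

```python
import numpy as np
from src.data.common import get_patch

lr = np.zeros((50, 60, 3))     # small LR image, (H, W, C)
hr = np.zeros((200, 240, 3))   # matching x4 HR image
lr_p, hr_p = get_patch(lr, hr, patch_size=96, scale=4)
print(lr_p.shape, hr_p.shape)  # (24, 24, 3) (96, 96, 3)
```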
/MindSpore version/src/data/div2k.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Huawei Technologies Co., Ltd
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """div2k"""
16 | import os
17 | from src.data.srdata import SRData
18 |
19 |
20 | class DIV2K(SRData):
21 | """DIV2K"""
22 | def __init__(self, args, name='DIV2K', train=True, benchmark=False):
23 | self.dir_hr = None
24 | self.dir_lr = None
25 | data_range = [r.split('-') for r in args.data_range.split('/')]
26 | if train:
27 | data_range = data_range[0]
28 | else:
29 | if args.test_only and len(data_range) == 1:
30 | data_range = data_range[0]
31 | else:
32 | data_range = data_range[1]
33 |
34 | self.begin, self.end = list(map(int, data_range))
35 | super(DIV2K, self).__init__(args, name=name, train=train, benchmark=benchmark)
36 |
37 | def _scan(self):
38 | names_hr, names_lr = super(DIV2K, self)._scan()
39 | names_hr = names_hr[self.begin - 1:self.end]
40 | names_lr = [n[self.begin - 1:self.end] for n in names_lr]
41 |
42 | return names_hr, names_lr
43 |
44 | def _set_filesystem(self, dir_data):
45 | super(DIV2K, self)._set_filesystem(dir_data)
46 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
47 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
48 |
--------------------------------------------------------------------------------
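
How `DIV2K.__init__` interprets the default `--data_range '1-800/801-810'`:

```python
data_range = [r.split('-') for r in '1-800/801-810'.split('/')]
print(data_range)                       # [['1', '800'], ['801', '810']]
begin, end = map(int, data_range[0])    # training split: images 0001-0800
print(begin, end)                       # 1 800
```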
/MindSpore version/src/data/srdata.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Huawei Technologies Co., Ltd
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """"srdata"""
16 | import os
17 | import glob
18 | import random
19 | import pickle
20 | import imageio
21 | from src.data import common
22 | from PIL import ImageFile
23 |
24 | ImageFile.LOAD_TRUNCATED_IMAGES = True
25 |
26 |
27 | class SRData:
28 | """srdata"""
29 |
30 | def __init__(self, args, name='', train=True, benchmark=False):
31 | self.derain_lr_test = None
32 | self.derain_hr_test = None
33 | self.deblur_lr_test = None
34 | self.deblur_hr_test = None
35 | self.args = args
36 | self.name = name
37 | self.train = train
38 | self.split = 'train' if train else 'test'
39 | self.do_eval = True
40 | self.benchmark = benchmark
41 | self.input_large = (args.model == 'VDSR')
42 | self.scale = args.scale
43 | self.idx_scale = 0
44 | if benchmark:
45 | self._set_filesystem(os.path.join(args.dir_data, 'benchmark'))
46 | else:
47 | self._set_filesystem(args.dir_data)
48 | self._set_img(args)
49 | if train:
50 | self._repeat(args)
51 |
52 | def _set_img(self, args):
53 | """set_img"""
54 | if args.ext.find('img') < 0:
55 | path_bin = os.path.join(self.apath, 'bin')
56 | os.makedirs(path_bin, exist_ok=True)
57 | list_hr, list_lr = self._scan()
58 | if args.ext.find('img') >= 0 or self.benchmark:
59 | self.images_hr, self.images_lr = list_hr, list_lr
60 | elif args.ext.find('sep') >= 0:
61 | os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True)
62 | for s in self.scale:
63 | if s == 1:
64 | os.makedirs(os.path.join(self.dir_hr), exist_ok=True)
65 | else:
66 | os.makedirs(
67 | os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True)
68 | self.images_hr, self.images_lr = [], [[] for _ in self.scale]
69 | for h in list_hr:
70 | b = h.replace(self.apath, path_bin)
71 | b = b.replace(self.ext[0], '.pt')
72 | self.images_hr.append(b)
73 | self._check_and_load(args.ext, h, b, verbose=True)
74 | for i, ll in enumerate(list_lr):
75 | for l in ll:
76 | b = l.replace(self.apath, path_bin)
77 | b = b.replace(self.ext[1], '.pt')
78 | self.images_lr[i].append(b)
79 | self._check_and_load(args.ext, l, b, verbose=True)
80 |
81 | def _repeat(self, args):
82 | """repeat"""
83 | n_patches = args.batch_size * args.test_every
84 | n_images = len(args.data_train) * len(self.images_hr)
85 | if n_images == 0:
86 | self.repeat = 0
87 | else:
88 | self.repeat = max(n_patches // n_images, 1)
89 |
90 | def _scan(self):
91 | """_scan"""
92 | names_hr = sorted(
93 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])))
94 | names_lr = [[] for _ in self.scale]
95 | for f in names_hr:
96 | filename, _ = os.path.splitext(os.path.basename(f))
97 | for si, s in enumerate(self.scale):
98 | if s != 1:
99 | scale = s
100 | names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \
101 | .format(s, filename, scale, self.ext[1])))
102 | for si, s in enumerate(self.scale):
103 | if s == 1:
104 | names_lr[si] = names_hr
105 | return names_hr, names_lr
106 |
107 | def _set_filesystem(self, dir_data):
108 | """set_filesystem"""
109 | self.apath = os.path.join(dir_data, self.name[0])
110 | self.dir_hr = os.path.join(self.apath, 'HR')
111 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
112 | self.ext = ('.png', '.png')
113 |
114 | def _check_and_load(self, ext, img, f, verbose=True):
115 | """check_and_load"""
116 | if not os.path.isfile(f) or ext.find('reset') >= 0:
117 | if verbose:
118 | print('Making a binary: {}'.format(f))
119 | with open(f, 'wb') as _f:
120 | pickle.dump(imageio.imread(img), _f)
121 |
122 | def __getitem__(self, idx):
123 | """get item"""
124 | lr, hr, _ = self._load_file(idx)
125 | pair = self.get_patch(lr, hr)
126 | pair = common.set_channel(*pair, n_channels=self.args.n_colors)
127 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
128 | return pair_t[0], pair_t[1]
129 |
130 | def __len__(self):
131 | """length of hr"""
132 | if self.train:
133 | return len(self.images_hr) * self.repeat
134 | return len(self.images_hr)
135 |
136 | def _get_index(self, idx):
137 | """get_index"""
138 | if self.train:
139 | return idx % len(self.images_hr)
140 | return idx
141 |
142 | def _load_file_hr(self, idx):
143 | """load_file_hr"""
144 | idx = self._get_index(idx)
145 | f_hr = self.images_hr[idx]
146 | filename, _ = os.path.splitext(os.path.basename(f_hr))
147 | if self.args.ext == 'img' or self.benchmark:
148 | hr = imageio.imread(f_hr)
149 | elif self.args.ext.find('sep') >= 0:
150 | with open(f_hr, 'rb') as _f:
151 | hr = pickle.load(_f)
152 | return hr, filename
153 |
154 | def _load_file(self, idx):
155 | """load_file"""
156 | idx = self._get_index(idx)
157 | f_hr = self.images_hr[idx]
158 | f_lr = self.images_lr[self.idx_scale][idx]
159 | filename, _ = os.path.splitext(os.path.basename(f_hr))
160 | if self.args.ext == 'img' or self.benchmark:
161 | hr = imageio.imread(f_hr)
162 | lr = imageio.imread(f_lr)
163 | elif self.args.ext.find('sep') >= 0:
164 | with open(f_hr, 'rb') as _f:
165 | hr = pickle.load(_f)
166 | with open(f_lr, 'rb') as _f:
167 | lr = pickle.load(_f)
168 | return lr, hr, filename
169 |
170 | def get_patch_hr(self, hr):
171 | """get_patch_hr"""
172 | if self.train:
173 | hr = self.get_patch_img_hr(hr, patch_size=self.args.patch_size, scale=1)
174 | return hr
175 |
176 | def get_patch_img_hr(self, img, patch_size=96, scale=2):
177 | """get_patch_img_hr"""
178 | ih, iw = img.shape[:2]
179 | tp = patch_size
180 | ip = tp // scale
181 | ix = random.randrange(0, iw - ip + 1)
182 | iy = random.randrange(0, ih - ip + 1)
183 | ret = img[iy:iy + ip, ix:ix + ip, :]
184 | return ret
185 |
186 | def get_patch(self, lr, hr):
187 | """get_patch"""
188 | scale = self.scale[self.idx_scale]
189 | if self.train:
190 | lr, hr = common.get_patch(
191 | lr, hr,
192 | patch_size=self.args.patch_size * scale,
193 | scale=scale)
194 | if not self.args.no_augment:
195 | lr, hr = common.augment(lr, hr)
196 | else:
197 | ih, iw = lr.shape[:2]
198 | hr = hr[0:ih * scale, 0:iw * scale]
199 | return lr, hr
200 |
201 | def set_scale(self, idx_scale):
202 | """set_scale"""
203 | if not self.input_large:
204 | self.idx_scale = idx_scale
205 | else:
206 | self.idx_scale = random.randint(0, len(self.scale) - 1)
207 |
--------------------------------------------------------------------------------
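
A worked example of `_repeat`: with the defaults from `src/args.py` (`batch_size=16`, `test_every=4000`) and the 800 DIV2K training images, one pass over the repeated dataset yields roughly `test_every` batches:

```python
batch_size, test_every, n_images = 16, 4000, 800
n_patches = batch_size * test_every      # 64000 patches per epoch
repeat = max(n_patches // n_images, 1)   # 80
print(repeat, n_images * repeat)         # 80 64000 == len(dataset) when training
```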
/MindSpore version/src/edsr_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Huawei Technologies Co., Ltd
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """edsr"""
16 | import math
17 | import mindspore.ops as ops
18 | import mindspore.nn as nn
19 | import mindspore.common.dtype as mstype
20 | from mindspore.ops import operations as P
21 | from mindspore.common import Tensor, Parameter
22 |
23 |
24 | def default_conv(in_channels, out_channels, kernel_size, has_bias=True):
25 | """edsr"""
26 | return nn.Conv2d(
27 | in_channels, out_channels, kernel_size,
28 | padding=(kernel_size // 2), has_bias=has_bias, pad_mode='pad')
29 |
30 |
31 | class MeanShift(nn.Conv2d):
32 | """edsr"""
33 | def __init__(self,
34 | rgb_range,
35 | rgb_mean=(0.4488, 0.4371, 0.4040),
36 | rgb_std=(1.0, 1.0, 1.0),
37 | sign=-1):
38 | """edsr"""
39 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
40 | self.reshape = P.Reshape()
41 | self.eye = P.Eye()
42 | std = Tensor(rgb_std, mstype.float32)
43 | self.weight.set_data(
44 | self.reshape(self.eye(3, 3, mstype.float32), (3, 3, 1, 1)) / self.reshape(std, (3, 1, 1, 1)))
45 | self.weight.requires_grad = False
46 | self.bias = Parameter(
47 | sign * rgb_range * Tensor(rgb_mean, mstype.float32) / std, name='bias', requires_grad=False)
48 | self.has_bias = True
49 |
50 |
51 | def _pixelsf_(x, scale):
52 | """edsr"""
53 | n, c, ih, iw = x.shape
54 | oh = ih * scale
55 | ow = iw * scale
56 | oc = c // (scale ** 2)
57 | output = P.Transpose()(x, (0, 2, 1, 3))
58 | output = P.Reshape()(output, (n, ih, oc * scale, scale, iw))
59 | output = P.Transpose()(output, (0, 1, 2, 4, 3))
60 | output = P.Reshape()(output, (n, ih, oc, scale, ow))
61 | output = P.Transpose()(output, (0, 2, 1, 3, 4))
62 | output = P.Reshape()(output, (n, oc, oh, ow))
63 | return output
64 |
65 |
66 | class SmallUpSampler(nn.Cell):
67 | """edsr"""
68 | def __init__(self, conv, upsize, n_feats, has_bias=True):
69 | """edsr"""
70 | super(SmallUpSampler, self).__init__()
71 | self.conv = conv(n_feats, upsize * upsize * n_feats, 3, has_bias)
72 | self.reshape = P.Reshape()
73 | self.upsize = upsize
74 | self.pixelsf = _pixelsf_
75 |
76 | def construct(self, x):
77 | """edsr"""
78 | x = self.conv(x)
79 | output = self.pixelsf(x, self.upsize)
80 | return output
81 |
82 |
83 | class Upsampler(nn.Cell):
84 | """edsr"""
85 | def __init__(self, conv, scale, n_feats, has_bias=True):
86 | """edsr"""
87 | super(Upsampler, self).__init__()
88 | m = []
89 | if (scale & (scale - 1)) == 0:
90 | for _ in range(int(math.log(scale, 2))):
91 | m.append(SmallUpSampler(conv, 2, n_feats, has_bias=has_bias))
92 | elif scale == 3:
93 | m.append(SmallUpSampler(conv, 3, n_feats, has_bias=has_bias))
94 | self.net = nn.SequentialCell(m)
95 |
96 | def construct(self, x):
97 | """edsr"""
98 | return self.net(x)
99 |
100 |
101 | class AdaptiveAvgPool2d(nn.Cell):
102 | """edsr"""
103 | def __init__(self):
104 | """edsr"""
105 | super().__init__()
106 | self.ReduceMean = ops.ReduceMean(keep_dims=True)
107 |
108 | def construct(self, x):
109 | """edsr"""
110 | return self.ReduceMean(x, 0)
111 |
112 |
113 | class ResidualBlock(nn.Cell):
114 | """edsr"""
115 | def __init__(self, conv, n_feat, kernel_size, has_bias=True
116 | , bn=False, act=nn.ReLU(), res_scale=0.1):
117 | """edsr"""
118 | super(ResidualBlock, self).__init__()
119 | self.modules_body = []
120 | for i in range(2):
121 | self.modules_body.append(conv(n_feat, n_feat, kernel_size, has_bias=has_bias))
122 | if bn: self.modules_body.append(nn.BatchNorm2d(n_feat))
123 | if i == 0: self.modules_body.append(act)
124 | self.body = nn.SequentialCell(*self.modules_body)
125 | self.res_scale = res_scale
126 |
127 | def construct(self, x):
128 | """edsr"""
129 | res = self.body(x)
130 | res = res * self.res_scale
131 | return x + res
132 |
133 | class EDSR(nn.Cell):
134 | def __init__(self, args, conv=default_conv):
135 | super(EDSR, self).__init__()
136 | n_resblocks = args.n_resblocks
137 | n_feats = args.n_feats
138 | kernel_size = 3
139 | idx = args.task_id
140 | scale = args.scale[idx]
141 | self.dtype = mstype.float16
142 |
143 | # RGB mean for DIV2K
144 | rgb_mean = (0.4488, 0.4371, 0.4040)
145 | rgb_std = (1.0, 1.0, 1.0)
146 |
147 | self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std).to_float(self.dtype)
148 |
149 | # define head module
150 | modules_head = conv(args.n_colors, n_feats, kernel_size).to_float(self.dtype)
151 |
152 | m_body = [
153 | ResidualBlock(
154 | conv, n_feats, kernel_size, res_scale=args.res_scale,
155 | ).to_float(self.dtype) for _ in range(n_resblocks)
156 | ]
157 | m_body.append(conv(n_feats, n_feats, kernel_size).to_float(self.dtype))
158 |
159 | m_tail = [
160 | Upsampler(conv, scale, n_feats).to_float(self.dtype),
161 | conv(n_feats, args.n_colors, kernel_size).to_float(self.dtype)
162 | ]
163 |
164 | self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1).to_float(self.dtype)
165 |
166 | self.head = modules_head
167 | self.body = nn.SequentialCell(m_body)
168 | self.tail = nn.SequentialCell(m_tail)
169 |
170 | def construct(self, x):
171 | x = self.sub_mean(x)
172 | x = self.head(x)
173 | x = x + self.body(x)
174 | x = self.tail(x)
175 | x = self.add_mean(x)
176 | return x
177 |
--------------------------------------------------------------------------------
/MindSpore version/src/edsr_slim.py:
--------------------------------------------------------------------------------
1 | import mindspore
2 | from src import common
3 | from src.edsr_model import Upsampler, default_conv
4 |
5 | import mindspore.nn as nn
6 | import mindspore.ops as ops
7 | import mindspore.ops.operations as P
8 | from mindspore import Tensor
9 |
10 | class EDSR(nn.Cell):
11 | def __init__(self, args):
12 | super(EDSR, self).__init__()
13 |
14 | self.n_colors = args.n_colors
15 | n_resblocks = args.n_resblocks
16 | self.n_feats = args.n_feats
17 | self.kernel_size = 3
18 | scale = args.scale[0]
19 | act = nn.ReLU()
20 | self.rgb_range = args.rgb_range
21 |
22 | # self.head = nn.Conv2d(in_channels=args.n_colors, out_channels=self.n_feats, kernel_size=self.kernel_size, pad_mode='pad', padding=self.kernel_size // 2, has_bias=True)
23 | self.head = common.conv(args.n_colors, self.n_feats, self.kernel_size, padding=self.kernel_size//2)
24 |
25 | m_body = [
26 | common.ResidualBlock(
27 | self.n_feats, self.kernel_size, act=act, res_scale=args.res_scale
28 | ) for _ in range(n_resblocks)
29 | ]
30 | self.body = nn.CellList(m_body)
31 | # self.body = m_body  # if this line is used instead, the body parameters are not trained
32 | self.body_conv = common.conv(self.n_feats, self.n_feats, self.kernel_size, padding=self.kernel_size//2)
33 |
34 | self.upsampler = common.Upsampler(scale, self.n_feats)
35 | self.tail_conv = common.conv(self.n_feats, args.n_colors, self.kernel_size, padding=self.kernel_size//2)
36 |
37 | def construct(self, x, width_mult=Tensor(1.0)):
38 | # def construct(self, x, width_mult):
39 | width_mult = width_mult.asnumpy().item()
40 | feature_width = int(self.n_feats * width_mult)
41 | conv2d = ops.Conv2D(out_channel=feature_width, kernel_size=self.kernel_size, mode=1, pad_mode='pad',
42 | pad=self.kernel_size // 2)
43 | biasadd = ops.BiasAdd()
44 |
45 | x = common.MeanShift(x, self.rgb_range)
46 | # originally this was written as weight.clone()[...]
47 | weight = self.head.weight[:feature_width, :self.n_colors, :, :]
48 | bias = self.head.bias[:feature_width]
49 | x = conv2d(x, weight)
50 | x = biasadd(x, bias)
51 |
52 | residual = x
53 | for block in self.body:
54 | residual = block(residual, width_mult)
55 | weight = self.body_conv.weight[:feature_width, :feature_width, :, :]
56 | bias = self.body_conv.bias[:feature_width]
57 | residual = conv2d(residual, weight)
58 | residual = biasadd(residual, bias)
59 | residual += x
60 |
61 | x = self.upsampler(residual, width_mult)
62 | weight = self.tail_conv.weight[:self.n_colors, :feature_width, :, :]
63 | bias = self.tail_conv.bias[:self.n_colors]
64 | conv2d = ops.Conv2D(out_channel=self.n_colors, kernel_size=self.kernel_size, mode=1, pad_mode='pad', pad=self.kernel_size//2)
65 | x = conv2d(x, weight)
66 | x = biasadd(x, bias)
67 | x = common.MeanShift(x, self.rgb_range, sign=1)
68 |
69 | return x
--------------------------------------------------------------------------------
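
edsr_slim.py is the slimmable student network: rather than keeping separate student weights, a forward pass at width multiplier w convolves with only the leading int(n_feats * w) channels of each full-width weight tensor (see lines 40, 47-48 and 55-56 above). A minimal NumPy sketch of the slicing, with hypothetical sizes:

```python
import numpy as np

n_feats, width_mult = 64, 0.25
feature_width = int(n_feats * width_mult)                # 16 channels at 0.25 width
w_full = np.random.randn(n_feats, n_feats, 3, 3)         # full-width conv weight
b_full = np.random.randn(n_feats)
w_slim = w_full[:feature_width, :feature_width, :, :]    # (16, 16, 3, 3) student view
b_slim = b_full[:feature_width]
print(w_slim.shape, b_slim.shape)
```
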
/MindSpore version/src/metric.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Huawei Technologies Co., Ltd
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Metric for evaluation."""
16 | import os
17 | import math
18 |
19 | from PIL import Image
20 | import numpy as np
21 | from mindspore import nn, Tensor, ops
22 | from mindspore import dtype as mstype
23 | from mindspore.ops.operations.comm_ops import ReduceOp
24 |
25 | try:
26 | from model_utils.device_adapter import get_rank_id, get_device_num
27 | except ImportError:
28 | get_rank_id = None
29 | get_device_num = None
30 | 
31 | 
32 |
33 |
34 | class SelfEnsembleWrapperNumpy:
35 | """
36 | SelfEnsembleWrapperNumpy using numpy
37 | """
38 |
39 | def __init__(self, net):
40 | super(SelfEnsembleWrapperNumpy, self).__init__()
41 | self.net = net
42 |
43 | def hflip(self, x):
44 | return x[:, :, :, ::-1]
45 |
46 | def vflip(self, x):
47 | return x[:, :, ::-1, :]
48 |
49 | def trnsps(self, x):
50 | return x.transpose(0, 1, 3, 2)
51 |
52 | def aug_x8(self, x):
53 | """
54 | do x8 augments for input image
55 | """
56 | # hflip
57 | hx = self.hflip(x)
58 | # vflip
59 | vx = self.vflip(x)
60 | vhx = self.vflip(hx)
61 | # trnsps
62 | tx = self.trnsps(x)
63 | thx = self.trnsps(hx)
64 | tvx = self.trnsps(vx)
65 | tvhx = self.trnsps(vhx)
66 | return x, hx, vx, vhx, tx, thx, tvx, tvhx
67 |
68 | def aug_x8_reverse(self, x, hx, vx, vhx, tx, thx, tvx, tvhx):
69 | """
70 | undo x8 augments for input images
71 | """
72 | # trnsps
73 | tvhx = self.trnsps(tvhx)
74 | tvx = self.trnsps(tvx)
75 | thx = self.trnsps(thx)
76 | tx = self.trnsps(tx)
77 | # vflip
78 | tvhx = self.vflip(tvhx)
79 | tvx = self.vflip(tvx)
80 | vhx = self.vflip(vhx)
81 | vx = self.vflip(vx)
82 | # hflip
83 | tvhx = self.hflip(tvhx)
84 | thx = self.hflip(thx)
85 | vhx = self.hflip(vhx)
86 | hx = self.hflip(hx)
87 | return x, hx, vx, vhx, tx, thx, tvx, tvhx
88 |
89 | def to_numpy(self, *inputs):
90 | if not inputs:
91 | return None
92 | if len(inputs) == 1:
93 | return inputs[0].asnumpy()
94 | return [x.asnumpy() for x in inputs]
95 | 
96 | def to_tensor(self, *inputs):
97 | if not inputs:
98 | return None
99 | if len(inputs) == 1:
100 | return Tensor(inputs[0])
101 | return [Tensor(x) for x in inputs]
102 |
103 | def set_train(self, mode=True):
104 | self.net.set_train(mode)
105 | return self
106 |
107 | def __call__(self, x):
108 | x = self.to_numpy(x)
109 | x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8(x)
110 | x0, x1, x2, x3, x4, x5, x6, x7 = self.to_tensor(x0, x1, x2, x3, x4, x5, x6, x7)
111 | x0 = self.net(x0)
112 | x1 = self.net(x1)
113 | x2 = self.net(x2)
114 | x3 = self.net(x3)
115 | x4 = self.net(x4)
116 | x5 = self.net(x5)
117 | x6 = self.net(x6)
118 | x7 = self.net(x7)
119 | x0, x1, x2, x3, x4, x5, x6, x7 = self.to_numpy(x0, x1, x2, x3, x4, x5, x6, x7)
120 | x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8_reverse(x0, x1, x2, x3, x4, x5, x6, x7)
121 | x0, x1, x2, x3, x4, x5, x6, x7 = self.to_tensor(x0, x1, x2, x3, x4, x5, x6, x7)
122 | return (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8
123 |
124 |
125 | class SelfEnsembleWrapper(nn.Cell):
126 | """
127 | because the [::-1] slicing operator is not supported in construct (see the issue linked below), use "SelfEnsembleWrapperNumpy" instead
128 | """
129 | def __init__(self, net):
130 | super(SelfEnsembleWrapper, self).__init__()
131 | self.net = net
132 |
133 | def hflip(self, x):
134 | raise NotImplementedError("https://gitee.com/mindspore/mindspore/issues/I41ONQ?from=project-issue")
135 |
136 | def vflip(self, x):
137 | raise NotImplementedError("https://gitee.com/mindspore/mindspore/issues/I41ONQ?from=project-issue")
138 |
139 | def trnsps(self, x):
140 | return x.transpose(0, 1, 3, 2)
141 |
142 | def aug_x8(self, x):
143 | """
144 | do x8 augments for input image
145 | """
146 | # hflip
147 | hx = self.hflip(x)
148 | # vflip
149 | vx = self.vflip(x)
150 | vhx = self.vflip(hx)
151 | # trnsps
152 | tx = self.trnsps(x)
153 | thx = self.trnsps(hx)
154 | tvx = self.trnsps(vx)
155 | tvhx = self.trnsps(vhx)
156 | return x, hx, vx, vhx, tx, thx, tvx, tvhx
157 |
158 | def aug_x8_reverse(self, x, hx, vx, vhx, tx, thx, tvx, tvhx):
159 | """
160 | undo x8 augments for input images
161 | """
162 | # trnsps
163 | tvhx = self.trnsps(tvhx)
164 | tvx = self.trnsps(tvx)
165 | thx = self.trnsps(thx)
166 | tx = self.trnsps(tx)
167 | # vflip
168 | tvhx = self.vflip(tvhx)
169 | tvx = self.vflip(tvx)
170 | vhx = self.vflip(vhx)
171 | vx = self.vflip(vx)
172 | # hflip
173 | tvhx = self.hflip(tvhx)
174 | thx = self.hflip(thx)
175 | vhx = self.hflip(vhx)
176 | hx = self.hflip(hx)
177 | return x, hx, vx, vhx, tx, thx, tvx, tvhx
178 |
179 | def construct(self, x):
180 | """
181 | do x8 aug, run network, undo x8 aug, calculate mean for 8 output
182 | """
183 | x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8(x)
184 | x0 = self.net(x0)
185 | x1 = self.net(x1)
186 | x2 = self.net(x2)
187 | x3 = self.net(x3)
188 | x4 = self.net(x4)
189 | x5 = self.net(x5)
190 | x6 = self.net(x6)
191 | x7 = self.net(x7)
192 | x0, x1, x2, x3, x4, x5, x6, x7 = self.aug_x8_reverse(x0, x1, x2, x3, x4, x5, x6, x7)
193 | return (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8
194 |
195 |
196 | class Quantizer(nn.Cell):
197 | """
198 | clip to [0.0, 255.0], round to int
199 | """
200 | def __init__(self, _min=0.0, _max=255.0):
201 | super(Quantizer, self).__init__()
202 | self.round = ops.Round()
203 | self._min = _min
204 | self._max = _max
205 |
206 | def construct(self, x):
207 | x = ops.clip_by_value(x, self._min, self._max)
208 | x = self.round(x)
209 | return x
210 |
211 |
212 | class TensorSyncer(nn.Cell):
213 | """
214 | sync metric values from all mindspore-processes
215 | """
216 | def __init__(self, _type="sum"):
217 | super(TensorSyncer, self).__init__()
218 | self._type = _type.lower()
219 | if self._type == "sum":
220 | self.ops = ops.AllReduce(ReduceOp.SUM)
221 | elif self._type == "gather":
222 | self.ops = ops.AllGather()
223 | else:
224 | raise ValueError(f"TensorSyncer._type == {self._type} is not supported")
225 |
226 | def construct(self, x):
227 | return self.ops(x)
228 |
229 |
230 | class _DistMetric(nn.Metric):
231 | """
232 | gather data from all rank while eval(True)
233 | _type(str): choice from ["avg", "sum"].
234 | """
235 | def __init__(self, _type):
236 | super(_DistMetric, self).__init__()
237 | self._type = _type.lower()
238 | self.all_reduce_sum = None
239 | if get_device_num is not None and get_device_num() > 1:
240 | self.all_reduce_sum = TensorSyncer(_type="sum")
241 | self.clear()
242 |
243 | def _accumulate(self, value):
244 | if isinstance(value, (list, tuple)):
245 | self._acc_value += sum(value)
246 | self._count += len(value)
247 | else:
248 | self._acc_value += value
249 | self._count += 1
250 |
251 | def clear(self):
252 | self._acc_value = 0.0
253 | self._count = 0
254 |
255 | def eval(self, sync=True):
256 | """
257 | sync: True, return metric value merged from all mindspore-processes
258 | sync: False, return metric value in this single mindspore-processes
259 | """
260 | if self._count == 0:
261 | raise RuntimeError('self._count == 0')
262 | if self.all_reduce_sum is not None and sync:
263 | data = Tensor([self._acc_value, self._count], mstype.float32)
264 | data = self.all_reduce_sum(data)
265 | acc_value, count = self._convert_data(data).tolist()
266 | else:
267 | acc_value, count = self._acc_value, self._count
268 | if self._type == "avg":
269 | return acc_value / count
270 | if self._type == "sum":
271 | return acc_value
272 | raise RuntimeError(f"_DistMetric._type={self._type} is not supported")
273 |
274 |
275 | class PSNR(_DistMetric):
276 | """
277 | Define PSNR metric for SR network.
278 | """
279 | def __init__(self, rgb_range, shave):
280 | super(PSNR, self).__init__(_type="avg")
281 | self.shave = shave
282 | self.rgb_range = rgb_range
283 | self.quantize = Quantizer(0.0, 255.0)
284 |
285 | def update(self, *inputs):
286 | """
287 | update psnr
288 | """
289 | if len(inputs) != 2:
290 | raise ValueError('PSNR need 2 inputs (sr, hr), but got {}'.format(len(inputs)))
291 | sr, hr = inputs
292 | sr = self.quantize(sr)
293 | diff = (sr - hr) / self.rgb_range
294 | valid = diff
295 | if self.shave is not None and self.shave != 0:
296 | valid = valid[..., self.shave:(-self.shave), self.shave:(-self.shave)]
297 | mse_list = (valid ** 2).mean(axis=(1, 2, 3))
298 | mse_list = self._convert_data(mse_list).tolist()
299 | psnr_list = [float(1e32) if mse == 0 else (-10.0 * math.log10(mse)) for mse in mse_list]
300 | self._accumulate(psnr_list)
301 |
302 |
303 | class SaveSrHr(_DistMetric):
304 | """
305 | help to save sr and hr
306 | """
307 | def __init__(self, save_dir):
308 | super(SaveSrHr, self).__init__(_type="sum")
309 | self.save_dir = save_dir
310 | self.quantize = Quantizer(0.0, 255.0)
311 | self.rank_id = 0 if get_rank_id is None else get_rank_id()
312 | self.device_num = 1 if get_device_num is None else get_device_num()
313 |
314 | def update(self, *inputs):
315 | """
316 | update images to save
317 | """
318 | if len(inputs) != 2:
319 | raise ValueError('SaveSrHr need 2 inputs (sr, hr), but got {}'.format(len(inputs)))
320 | sr, hr = inputs
321 | sr = self.quantize(sr)
322 | sr = self._convert_data(sr).astype(np.uint8)
323 | hr = self._convert_data(hr).astype(np.uint8)
324 | for s, h in zip(sr.transpose(0, 2, 3, 1), hr.transpose(0, 2, 3, 1)):
325 | idx = self._count * self.device_num + self.rank_id
326 | sr_path = os.path.join(self.save_dir, f"{idx:0>4}_sr.png")
327 | Image.fromarray(s).save(sr_path)
328 | hr_path = os.path.join(self.save_dir, f"{idx:0>4}_hr.png")
329 | Image.fromarray(h).save(hr_path)
330 | self._accumulate(1)
331 |
--------------------------------------------------------------------------------
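
The x8 self-ensemble in metric.py only averages correctly because aug_x8_reverse undoes aug_x8 exactly for every one of the eight variants. A quick roundtrip check, as a sketch that assumes MindSpore is installed and 'MindSpore version' is the working directory so src.metric imports (the net argument is never called here):

```python
import numpy as np
from src.metric import SelfEnsembleWrapperNumpy

w = SelfEnsembleWrapperNumpy(net=None)             # net unused by the augment helpers
x = np.random.rand(1, 3, 8, 8)                     # NCHW; square spatial dims
outs = w.aug_x8_reverse(*w.aug_x8(x))
assert all(np.array_equal(o, x) for o in outs)     # all 8 augments are undone exactly
```
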
/MindSpore version/src/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Huawei Technologies Co., Ltd
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """metrics"""
16 | import math
17 | import numpy as np
18 | import cv2
19 |
20 |
21 | def quantize(img, rgb_range):
22 | """quantize image range to 0-255"""
23 | pixel_range = 255 / rgb_range
24 | img = np.multiply(img, pixel_range)
25 | img = np.clip(img, 0, 255)
26 | img = np.round(img) / pixel_range
27 | return img
28 |
29 |
30 | def calc_psnr(sr, hr, scale, rgb_range):
31 | """calculate psnr"""
32 | hr = np.float32(hr)
33 | sr = np.float32(sr)
34 | diff = (sr - hr) / rgb_range
35 | gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256
36 | diff = np.multiply(diff, gray_coeffs).sum(1)
37 | if hr.size == 1:
38 | return 0
39 |
40 | shave = scale
41 | valid = diff[..., shave:-shave, shave:-shave]
42 | mse = np.mean(pow(valid, 2))
43 | return -10 * math.log10(mse)
44 |
45 |
46 | def rgb2ycbcr(img, y_only=True):
47 | """from rgb space to ycbcr space"""
48 | img = img.astype(np.float32)
49 | if y_only:
50 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
51 | return rlt
52 |
53 |
54 | def calc_ssim(img1, img2, scale):
55 | """calculate ssim"""
56 | def ssim(img1, img2):
57 | """calculate ssim"""
58 | C1 = (0.01 * 255) ** 2
59 | C2 = (0.03 * 255) ** 2
60 |
61 | img1 = img1.astype(np.float64)
62 | img2 = img2.astype(np.float64)
63 | kernel = cv2.getGaussianKernel(11, 1.5)
64 | window = np.outer(kernel, kernel.transpose())
65 |
66 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
67 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
68 | mu1_sq = mu1 ** 2
69 | mu2_sq = mu2 ** 2
70 | mu1_mu2 = mu1 * mu2
71 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
72 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
73 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
74 |
75 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
76 | (sigma1_sq + sigma2_sq + C2))
77 | return ssim_map.mean()
78 |
79 | border = scale
80 | img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 256.0 + 16.0
81 | img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 256.0 + 16.0
82 | if not img1.shape == img2.shape:
83 | raise ValueError('Input images must have the same dimensions.')
84 | h, w = img1.shape[:2]
85 | img1_y = img1_y[border:h - border, border:w - border]
86 | img2_y = img2_y[border:h - border, border:w - border]
87 |
88 | if img1_y.ndim == 2:
89 | return ssim(img1_y, img2_y)
90 | if img1.ndim == 3:
91 | if img1.shape[2] == 3:
92 | ssims = []
93 | for i in range(3):
94 | ssims.append(ssim(img1[..., i], img2[..., i]))
95 |
96 | return np.array(ssims).mean()
97 | if img1.shape[2] == 1:
98 | return ssim(np.squeeze(img1), np.squeeze(img2))
99 |
100 | raise ValueError('Wrong input image dimensions.')
101 |
--------------------------------------------------------------------------------
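
quantize and calc_psnr are meant to be used together: the prediction is first snapped to the valid pixel grid, then scored on the Y channel with a `scale`-pixel border shaved off. A small numeric sketch (assuming the repo's dependencies are installed and 'MindSpore version' is the working directory):

```python
import numpy as np
from src.metrics import quantize, calc_psnr

np.random.seed(0)
hr = np.random.rand(1, 3, 32, 32) * 255                 # fake NCHW ground truth
sr = quantize(hr + np.random.randn(1, 3, 32, 32), rgb_range=255)
print(calc_psnr(sr, hr, scale=2, rgb_range=255))        # PSNR in dB; higher is better
```
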
/MindSpore version/src/rcan_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Huawei Technologies Co., Ltd
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """rcan"""
16 | import math
17 | import mindspore.ops as ops
18 | import mindspore.nn as nn
19 | import mindspore.common.dtype as mstype
20 | from mindspore.ops import operations as P
21 | from mindspore.common import Tensor, Parameter
22 |
23 |
24 | def default_conv(in_channels, out_channels, kernel_size, has_bias=True):
25 | """rcan"""
26 | return nn.Conv2d(
27 | in_channels, out_channels, kernel_size,
28 | padding=(kernel_size // 2), has_bias=has_bias, pad_mode='pad')
29 |
30 |
31 | class MeanShift(nn.Conv2d):
32 | """rcan"""
33 | def __init__(self,
34 | rgb_range,
35 | rgb_mean=(0.4488, 0.4371, 0.4040),
36 | rgb_std=(1.0, 1.0, 1.0),
37 | sign=-1):
38 | """rcan"""
39 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
40 | self.reshape = P.Reshape()
41 | self.eye = P.Eye()
42 | std = Tensor(rgb_std, mstype.float32)
43 | self.weight.set_data(
44 | self.reshape(self.eye(3, 3, mstype.float32), (3, 3, 1, 1)) / self.reshape(std, (3, 1, 1, 1)))
45 | self.weight.requires_grad = False
46 | self.bias = Parameter(
47 | sign * rgb_range * Tensor(rgb_mean, mstype.float32) / std, name='bias', requires_grad=False)
48 | self.has_bias = True
49 |
50 |
51 | def _pixelsf_(x, scale):
52 | """rcan"""
53 | n, c, ih, iw = x.shape
54 | oh = ih * scale
55 | ow = iw * scale
56 | oc = c // (scale ** 2)
57 | output = P.Transpose()(x, (0, 2, 1, 3))
58 | output = P.Reshape()(output, (n, ih, oc * scale, scale, iw))
59 | output = P.Transpose()(output, (0, 1, 2, 4, 3))
60 | output = P.Reshape()(output, (n, ih, oc, scale, ow))
61 | output = P.Transpose()(output, (0, 2, 1, 3, 4))
62 | output = P.Reshape()(output, (n, oc, oh, ow))
63 | return output
64 |
65 |
66 | class SmallUpSampler(nn.Cell):
67 | """rcan"""
68 | def __init__(self, conv, upsize, n_feats, has_bias=True):
69 | """rcan"""
70 | super(SmallUpSampler, self).__init__()
71 | self.conv = conv(n_feats, upsize * upsize * n_feats, 3, has_bias)
72 | self.reshape = P.Reshape()
73 | self.upsize = upsize
74 | self.pixelsf = _pixelsf_
75 |
76 | def construct(self, x):
77 | """rcan"""
78 | x = self.conv(x)
79 | output = self.pixelsf(x, self.upsize)
80 | return output
81 |
82 |
83 | class Upsampler(nn.Cell):
84 | """rcan"""
85 | def __init__(self, conv, scale, n_feats, has_bias=True):
86 | """rcan"""
87 | super(Upsampler, self).__init__()
88 | m = []
89 | if (scale & (scale - 1)) == 0:
90 | for _ in range(int(math.log(scale, 2))):
91 | m.append(SmallUpSampler(conv, 2, n_feats, has_bias=has_bias))
92 | elif scale == 3:
93 | m.append(SmallUpSampler(conv, 3, n_feats, has_bias=has_bias))
94 | self.net = nn.SequentialCell(m)
95 |
96 | def construct(self, x):
97 | """rcan"""
98 | return self.net(x)
99 |
100 |
101 | class AdaptiveAvgPool2d(nn.Cell):
102 | """rcan"""
103 | def __init__(self):
104 | """rcan"""
105 | super().__init__()
106 | self.ReduceMean = ops.ReduceMean(keep_dims=True)
107 |
108 | def construct(self, x):
109 | """rcan"""
110 | return self.ReduceMean(x, (2, 3))
111 |
112 |
113 | class CALayer(nn.Cell):
114 | """rcan"""
115 | def __init__(self, channel, reduction=16):
116 | """rcan"""
117 | super(CALayer, self).__init__()
118 | # global average pooling: feature --> point
119 | self.avg_pool = AdaptiveAvgPool2d()
120 | # feature channel downscale and upscale --> channel weight
121 | self.conv_du = nn.SequentialCell([
122 | nn.Conv2d(channel, channel // reduction, 1, padding=0, has_bias=True, pad_mode='pad'),
123 | nn.ReLU(),
124 | nn.Conv2d(channel // reduction, channel, 1, padding=0, has_bias=True, pad_mode='pad'),
125 | nn.Sigmoid()
126 | ])
127 |
128 | def construct(self, x):
129 | """rcan"""
130 | y = self.avg_pool(x)
131 | y = self.conv_du(y)
132 | return x * y
133 |
134 |
135 | class RCAB(nn.Cell):
136 | """rcan"""
137 | def __init__(self, conv, n_feat, kernel_size, reduction, has_bias=True,
138 | bn=False, act=nn.ReLU(), res_scale=1):
139 | """rcan"""
140 | super(RCAB, self).__init__()
141 | self.modules_body = []
142 | for i in range(2):
143 | self.modules_body.append(conv(n_feat, n_feat, kernel_size, has_bias=has_bias))
144 | if bn: self.modules_body.append(nn.BatchNorm2d(n_feat))
145 | if i == 0: self.modules_body.append(act)
146 | self.modules_body.append(CALayer(n_feat, reduction))
147 | self.body = nn.SequentialCell(*self.modules_body)
148 | self.res_scale = res_scale
149 |
150 | def construct(self, x):
151 | """rcan"""
152 | res = self.body(x)
153 | res += x
154 | return res
155 |
156 |
157 | class ResidualGroup(nn.Cell):
158 | """rcan"""
159 | def __init__(self, conv, n_feat, kernel_size, reduction, n_resblocks):
160 | """rcan"""
161 | super(ResidualGroup, self).__init__()
162 | # body: RCAB blocks followed by a closing conv
163 | modules_body = [
164 | RCAB(
165 | conv, n_feat, kernel_size, reduction, has_bias=True, bn=False, act=nn.ReLU(), res_scale=1) \
166 | for _ in range(n_resblocks)]
167 | modules_body.append(conv(n_feat, n_feat, kernel_size))
168 | self.body = nn.SequentialCell(*modules_body)
169 |
170 | def construct(self, x):
171 | """rcan"""
172 | res = self.body(x)
173 | res += x
174 | return res
175 |
176 |
177 | class RCAN(nn.Cell):
178 | """rcan"""
179 | def __init__(self, args, conv=default_conv):
180 | """rcan"""
181 | super(RCAN, self).__init__()
182 |
183 | n_resgroups = args.n_resgroups
184 | n_resblocks = args.n_resblocks
185 | n_feats = args.n_feats
186 | kernel_size = 3
187 | reduction = args.reduction
188 | idx = args.task_id
189 | scale = args.scale[idx]
190 | self.dtype = mstype.float16
191 | 
192 | # RGB mean for DIV2K
193 | rgb_mean = (0.4488, 0.4371, 0.4040)
194 | rgb_std = (1.0, 1.0, 1.0)
195 | 
196 | self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std).to_float(self.dtype)
197 | 
198 | # define head module
199 | modules_head = conv(args.n_colors, n_feats, kernel_size).to_float(self.dtype)
200 | 
201 | # define body module
202 | modules_body = [
203 | ResidualGroup(
204 | conv, n_feats, kernel_size, reduction, n_resblocks=n_resblocks).to_float(self.dtype) \
205 | for _ in range(n_resgroups)]
206 | 
207 | modules_body.append(conv(n_feats, n_feats, kernel_size).to_float(self.dtype))
208 | 
209 | # define tail module
210 | modules_tail = [
211 | Upsampler(conv, scale, n_feats).to_float(self.dtype),
212 | conv(n_feats, args.n_colors, kernel_size).to_float(self.dtype)]
213 | 
214 | self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1).to_float(self.dtype)
215 |
216 | self.head = modules_head
217 | self.body = nn.SequentialCell(modules_body)
218 | self.tail = nn.SequentialCell(modules_tail)
219 |
220 | def construct(self, x):
221 | """rcan"""
222 | x = self.sub_mean(x)
223 | x = self.head(x)
224 | res = self.body(x)
225 | res += x
226 | x = self.tail(res)
227 | x = self.add_mean(x)
228 | return x
229 |
--------------------------------------------------------------------------------
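
_pixelsf_ above re-implements depth-to-space (PixelShuffle) out of Transpose/Reshape primitives. Its layout agrees with the standard reference out[n, oc, ih*s + s1, iw*s + s2] = x[n, oc*s*s + s1*s + s2, ih, iw], which can be checked in NumPy (a sketch with arbitrary sizes):

```python
import numpy as np

def pixel_shuffle_ref(x, s):                       # standard depth-to-space layout
    n, c, ih, iw = x.shape
    oc = c // (s * s)
    return (x.reshape(n, oc, s, s, ih, iw)
             .transpose(0, 1, 4, 2, 5, 3)
             .reshape(n, oc, ih * s, iw * s))

def pixel_shuffle_ms(x, s):                        # NumPy mirror of _pixelsf_
    n, c, ih, iw = x.shape
    oc = c // (s * s)
    out = x.transpose(0, 2, 1, 3)
    out = out.reshape(n, ih, oc * s, s, iw)
    out = out.transpose(0, 1, 2, 4, 3)
    out = out.reshape(n, ih, oc, s, iw * s)
    out = out.transpose(0, 2, 1, 3, 4)
    return out.reshape(n, oc, ih * s, iw * s)

x = np.random.rand(2, 16, 5, 7)
assert np.array_equal(pixel_shuffle_ref(x, 2), pixel_shuffle_ms(x, 2))
```
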
/MindSpore version/src/vgg_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Huawei Technologies Co., Ltd
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """
16 | Image classification.
17 | """
18 | import math
19 | import mindspore.nn as nn
20 | import mindspore.common.dtype as mstype
21 | from mindspore.common import initializer as init
22 | from mindspore.common.initializer import initializer
23 | from utils.var_init import default_recurisive_init, KaimingNormal
24 |
25 |
26 | def _make_layer(base, args, batch_norm):
27 | """Make stage network of VGG."""
28 | layers = []
29 | in_channels = 3
30 | for v in base:
31 | if v == 'M':
32 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
33 | else:
34 | weight = 'ones'
35 | if args.initialize_mode == "XavierUniform":
36 | weight_shape = (v, in_channels, 3, 3)
37 | weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32)
38 |
39 | conv2d = nn.Conv2d(in_channels=in_channels,
40 | out_channels=v,
41 | kernel_size=3,
42 | padding=args.padding,
43 | pad_mode=args.pad_mode,
44 | has_bias=args.has_bias,
45 | weight_init=weight)
46 | if batch_norm:
47 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()]
48 | else:
49 | layers += [conv2d, nn.ReLU()]
50 | in_channels = v
51 | return nn.SequentialCell(layers)
52 |
53 |
54 | class Vgg(nn.Cell):
55 | """
56 | VGG network definition.
57 |
58 | Args:
59 | base (list): Configuration for different layers, mainly the channel number of Conv layer.
60 | num_classes (int): Class numbers. Default: 1000.
61 | batch_norm (bool): Whether to do the batchnorm. Default: False.
62 | batch_size (int): Batch size. Default: 1.
63 | include_top(bool): Whether to include the 3 fully-connected layers at the top of the network. Default: True.
64 |
65 | Returns:
66 | Tensor, infer output tensor.
67 |
68 | Examples:
69 | >>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
70 | >>> num_classes=1000, batch_norm=False, batch_size=1)
71 | """
72 |
73 | def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train",
74 | include_top=True):
75 | super(Vgg, self).__init__()
76 | _ = batch_size
77 | self.layers = _make_layer(base, args, batch_norm=batch_norm)
78 | self.include_top = include_top
79 | self.flatten = nn.Flatten()
80 | dropout_ratio = 0.5
81 | # if not args.has_dropout or phase == "test":
82 | # dropout_ratio = 1.0
83 | self.classifier = nn.SequentialCell([
84 | nn.Dense(512 * 7 * 7, 4096),
85 | nn.ReLU(),
86 | nn.Dropout(dropout_ratio),
87 | nn.Dense(4096, 4096),
88 | nn.ReLU(),
89 | nn.Dropout(dropout_ratio),
90 | nn.Dense(4096, num_classes)])
91 | # if args.initialize_mode == "KaimingNormal":
92 | # default_recurisive_init(self)
93 | # self.custom_init_weight()
94 |
95 | def construct(self, x):
96 | x = self.layers(x)
97 | if self.include_top:
98 | x = self.flatten(x)
99 | x = self.classifier(x)
100 | return x
101 |
102 | def custom_init_weight(self):
103 | """
104 | Init the weight of Conv2d and Dense in the net.
105 | """
106 | for _, cell in self.cells_and_names():
107 | if isinstance(cell, nn.Conv2d):
108 | cell.weight.set_data(init.initializer(
109 | KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
110 | cell.weight.shape, cell.weight.dtype))
111 | if cell.bias is not None:
112 | cell.bias.set_data(init.initializer(
113 | 'zeros', cell.bias.shape, cell.bias.dtype))
114 | elif isinstance(cell, nn.Dense):
115 | cell.weight.set_data(init.initializer(
116 | init.Normal(0.01), cell.weight.shape, cell.weight.dtype))
117 | if cell.bias is not None:
118 | cell.bias.set_data(init.initializer(
119 | 'zeros', cell.bias.shape, cell.bias.dtype))
120 |
121 |
122 | cfg = {
123 | '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
124 | '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
125 | '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
126 | '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
127 | }
128 |
129 |
130 | def vgg16(num_classes=1000, args=None, phase="train", **kwargs):
131 | """
132 | Get Vgg16 neural network with batch normalization.
133 |
134 | Args:
135 | num_classes (int): Class numbers. Default: 1000.
136 | args(namespace): param for net init.
137 | phase(str): train or test mode.
138 |
139 | Returns:
140 | Cell, cell instance of Vgg16 neural network with batch normalization.
141 |
142 | Examples:
143 | >>> vgg16(num_classes=1000, args=args, **kwargs)
144 | """
145 |
146 | if args is None:
147 | from .config import cifar_cfg
148 | args = cifar_cfg
149 | net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase, **kwargs)
150 | return net
151 |
--------------------------------------------------------------------------------
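
For reference, the Vgg cell can be built standalone as a feature extractor (e.g. for the perceptual/contrastive losses elsewhere in this repo). A minimal construction sketch, assuming MindSpore is installed and run from 'MindSpore version'; the args values here are hypothetical stand-ins for what src/config.py provides:

```python
from types import SimpleNamespace
import numpy as np
import mindspore as ms
from src.vgg_model import vgg16

# hypothetical args; real runs take these from src/config.py (cifar_cfg)
args = SimpleNamespace(initialize_mode='XavierUniform', padding=1,
                       pad_mode='pad', has_bias=False, batch_norm=True)
net = vgg16(num_classes=1000, args=args)
x = ms.Tensor(np.zeros((1, 3, 224, 224), np.float32))
print(net(x).shape)        # (1, 1000); five maxpools leave a 7x7x512 feature map
```
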
/MindSpore version/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Huawei Technologies Co., Ltd
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """train"""
16 | import os
17 | import time
18 | from mindspore import context
19 | from mindspore.context import ParallelMode
20 | import mindspore.dataset as ds
21 | import mindspore.nn as nn
22 | from mindspore.train.serialization import load_checkpoint, load_param_into_net
23 | from mindspore.communication.management import init
24 | from mindspore.common import set_seed
25 | from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Callback
26 | from mindspore.train.model import Model
27 | from mindspore.train.loss_scale_manager import DynamicLossScaleManager
28 | from src.args import args
29 | from src.data.div2k import DIV2K
30 | from src.data.srdata import SRData
31 | from src.rcan_model import RCAN
32 | # from src.edsr_model import EDSR
33 | from src.edsr_slim import EDSR
34 | from src.metric import PSNR
35 | from eval import do_eval
36 |
37 | # def do_eval(eval_network, ds_val, metrics, rank_id, cur_epoch=None):
38 | # """
39 | # do eval for psnr and save hr, sr
40 | # """
41 | # eval_network.set_train(False)
42 | # total_step = ds_val.get_dataset_size()
43 | # setw = len(str(total_step))
44 | # begin = time.time()
45 | # step_begin = time.time()
46 | # # rank_id = get_rank_id()
47 | # ds_val = ds_val.create_dict_iterator(output_numpy=True)
48 | # for i, (lr, hr) in enumerate(ds_val):
49 | # sr = eval_network(lr)
50 | # _ = [m.update(sr, hr) for m in metrics.values()]
51 | # result = {k: m.eval(sync=False) for k, m in metrics.items()}
52 | # result["time"] = time.time() - step_begin
53 | # step_begin = time.time()
54 | # print(f"[{i+1:>{setw}}/{total_step:>{setw}}] rank = {rank_id} result = {result}", flush=True)
55 | # result = {k: m.eval(sync=True) for k, m in metrics.items()}
56 | # result["time"] = time.time() - begin
57 | # if cur_epoch is not None:
58 | # result["epoch"] = cur_epoch
59 | # if rank_id == 0:
60 | # print(f"evaluation result = {result}", flush=True)
61 | # eval_network.set_train(True)
62 | # return result
63 |
64 | class EvalCallBack(Callback):
65 | """
66 | eval callback
67 | """
68 | def __init__(self, eval_network, ds_val, eval_epoch_frq, epoch_size, metrics, rank_id, result_evaluation=None):
69 | self.eval_network = eval_network
70 | self.ds_val = ds_val
71 | self.eval_epoch_frq = eval_epoch_frq
72 | self.epoch_size = epoch_size
73 | self.result_evaluation = result_evaluation
74 | self.metrics = metrics
75 | self.best_result = 0
76 | self.rank_id = rank_id
77 | self.eval_network.set_train(False)
78 |
79 | def epoch_end(self, run_context):
80 | """
81 | do eval in epoch end
82 | """
83 | cb_param = run_context.original_args()
84 | cur_epoch = cb_param.cur_epoch_num
85 | if cur_epoch % self.eval_epoch_frq == 0 or cur_epoch == self.epoch_size:
86 | # result = do_eval(self.eval_network, self.ds_val, self.metrics, self.rank_id, cur_epoch=cur_epoch)
87 | result = do_eval(self.ds_val, self.eval_network)
88 | if self.best_result is None or self.best_result < result:
89 | self.best_result = result
90 | if self.rank_id == 0:
91 | print(f"best evaluation result = {self.best_result}", flush=True)
92 | if isinstance(self.result_evaluation, dict):
93 | for k, v in result.items():
94 | r_list = self.result_evaluation.get(k)
95 | if r_list is None:
96 | r_list = []
97 | self.result_evaluation[k] = r_list
98 | r_list.append(v)
99 |
100 |
101 | def train():
102 | """train"""
103 | set_seed(1)
104 | device_id = int(os.getenv('DEVICE_ID', '0'))
105 | rank_id = int(os.getenv('RANK_ID', '0'))
106 | device_num = int(os.getenv('RANK_SIZE', '1'))
107 | # context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False, device_id=device_id)
108 | context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False, device_id=device_id)
109 |
110 | if device_num > 1:
111 | init()
112 | context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
113 | device_num=device_num, global_rank=device_id,
114 | gradients_mean=True)
115 | if args.modelArts_mode:
116 | import moxing as mox
117 | local_data_url = '/cache/data'
118 | mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_url)
119 |
120 | train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
121 | train_dataset.set_scale(args.task_id)
122 | print(len(train_dataset))
123 | train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num,
124 | shard_id=rank_id, shuffle=True)
125 | train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
126 |
127 | eval_dataset = SRData(args, name=args.data_test, train=False, benchmark=True)
128 | print(len(eval_dataset))
129 | eval_ds = ds.GeneratorDataset(eval_dataset, ['LR', 'HR'], shuffle=False)
130 | eval_ds = eval_ds.batch(1, drop_remainder=True)
131 |
132 | # net_m = RCAN(args)
133 | net_m = EDSR(args)
134 | print("Init net weights successfully")
135 |
136 | if args.ckpt_path:
137 | param_dict = load_checkpoint(args.ckpt_path)
138 | load_param_into_net(net_m, param_dict)
139 | print("Load net weight successfully")
140 | step_size = train_de_dataset.get_dataset_size()
141 | lr = []
142 | for i in range(0, args.epochs):
143 | cur_lr = args.lr / (2 ** ((i + 1) // 200))
144 | lr.extend([cur_lr] * step_size)
145 | opt = nn.Adam(net_m.trainable_params(), learning_rate=lr, loss_scale=args.loss_scale)
146 | loss = nn.L1Loss()
147 | loss_scale_manager = DynamicLossScaleManager(init_loss_scale=args.init_loss_scale, \
148 | scale_factor=2, scale_window=1000)
149 |
150 | eval_net = net_m
151 | model = Model(net_m, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager)
152 |
153 | time_cb = TimeMonitor(data_size=step_size)
154 | loss_cb = LossMonitor()
155 | metrics = {
156 | "psnr": PSNR(rgb_range=args.rgb_range, shave=True),
157 | }
158 | eval_cb = EvalCallBack(eval_net, eval_ds, args.test_every, step_size / args.batch_size, metrics=metrics, rank_id=rank_id)
159 | cb = [time_cb, loss_cb, eval_cb]
160 | config_ck = CheckpointConfig(save_checkpoint_steps=args.ckpt_save_interval * step_size,
161 | keep_checkpoint_max=args.ckpt_save_max)
162 | ckpt_cb = ModelCheckpoint(prefix=args.filename, directory=args.ckpt_save_path, config=config_ck)
163 | if device_id == 0:
164 | cb += [ckpt_cb]
165 | model.train(args.epochs, train_de_dataset, callbacks=cb, dataset_sink_mode=True)
166 |
167 |
168 | if __name__ == "__main__":
169 | time_start = time.time()
170 | train()
171 | time_end = time.time()
172 | print('train_time: %f' % (time_end - time_start))
173 |
--------------------------------------------------------------------------------
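
The learning-rate list built in train() encodes a step decay: each per-epoch value is repeated step_size times (MindSpore indexes the list by global step), and the rate halves every 200 epochs. The schedule in isolation:

```python
base_lr = 1e-4                      # args.lr
for epoch in (0, 100, 198, 199, 200, 399, 400):
    print(epoch, base_lr / (2 ** ((epoch + 1) // 200)))
# epochs 0-198 run at 1e-4; the rate halves at epoch 199 and again at 399
```
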
/MindSpore version/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/MindSpore version/utils/__init__.py
--------------------------------------------------------------------------------
/MindSpore version/utils/var_init.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Huawei Technologies Co., Ltd
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """
16 | Initialize.
17 | """
18 | import math
19 | from functools import reduce
20 | import numpy as np
21 | import mindspore.nn as nn
22 | from mindspore.common import initializer as init
23 |
24 | def _calculate_gain(nonlinearity, param=None):
25 | r"""
26 | Return the recommended gain value for the given nonlinearity function.
27 |
28 | The values are as follows:
29 | ================= ====================================================
30 | nonlinearity gain
31 | ================= ====================================================
32 | Linear / Identity :math:`1`
33 | Conv{1,2,3}D :math:`1`
34 | Sigmoid :math:`1`
35 | Tanh :math:`\frac{5}{3}`
36 | ReLU :math:`\sqrt{2}`
37 | Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
38 | ================= ====================================================
39 |
40 | Args:
41 | nonlinearity: the non-linear function
42 | param: optional parameter for the non-linear function
43 |
44 | Examples:
45 | >>> gain = _calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
46 | """
47 | linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
48 | if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
49 | return 1
50 | if nonlinearity == 'tanh':
51 | return 5.0 / 3
52 | if nonlinearity == 'relu':
53 | return math.sqrt(2.0)
54 | if nonlinearity == 'leaky_relu':
55 | if param is None:
56 | negative_slope = 0.01
57 | elif not isinstance(param, bool) and isinstance(param, (int, float)):
58 | negative_slope = param
59 | else:
60 | raise ValueError("negative_slope {} not a valid number".format(param))
61 | return math.sqrt(2.0 / (1 + negative_slope ** 2))
62 |
63 | raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
64 |
65 | def _assignment(arr, num):
66 | """Assign the value of `num` to `arr`."""
67 | if arr.shape == ():
68 | arr = arr.reshape((1))
69 | arr[:] = num
70 | arr = arr.reshape(())
71 | else:
72 | if isinstance(num, np.ndarray):
73 | arr[:] = num[:]
74 | else:
75 | arr[:] = num
76 | return arr
77 |
78 | def _calculate_in_and_out(arr):
79 | """
80 | Calculate n_in and n_out.
81 |
82 | Args:
83 | arr (Array): Input array.
84 |
85 | Returns:
86 | Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
87 | """
88 | dim = len(arr.shape)
89 | if dim < 2:
90 | raise ValueError("If initializing data with Xavier uniform, the dimension of data must be greater than 1.")
91 |
92 | n_in = arr.shape[1]
93 | n_out = arr.shape[0]
94 |
95 | if dim > 2:
96 | counter = reduce(lambda x, y: x * y, arr.shape[2:])
97 | n_in *= counter
98 | n_out *= counter
99 | return n_in, n_out
100 |
101 | def _select_fan(array, mode):
102 | mode = mode.lower()
103 | valid_modes = ['fan_in', 'fan_out']
104 | if mode not in valid_modes:
105 | raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
106 |
107 | fan_in, fan_out = _calculate_in_and_out(array)
108 | return fan_in if mode == 'fan_in' else fan_out
109 |
110 | class KaimingInit(init.Initializer):
111 | r"""
112 | Base Class. Initialize the array with He kaiming algorithm.
113 |
114 | Args:
115 | a: the negative slope of the rectifier used after this layer (only
116 | used with ``'leaky_relu'``)
117 | mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
118 | preserves the magnitude of the variance of the weights in the
119 | forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
120 | backwards pass.
121 | nonlinearity: the non-linear function, recommended to use only with
122 | ``'relu'`` or ``'leaky_relu'`` (default).
123 | """
124 | def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
125 | super(KaimingInit, self).__init__()
126 | self.mode = mode
127 | self.gain = _calculate_gain(nonlinearity, a)
128 | def _initialize(self, arr):
129 | pass
130 |
131 |
132 | class KaimingUniform(KaimingInit):
133 | r"""
134 | Initialize the array with He kaiming uniform algorithm. The resulting tensor will
135 | have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
136 |
137 | .. math::
138 | \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
139 |
140 | Input:
141 | arr (Array): The array to be assigned.
142 |
143 | Returns:
144 | Array, assigned array.
145 |
146 | Examples:
147 | >>> w = np.empty((3, 5))
148 | >>> KaimingUniform(w, mode='fan_in', nonlinearity='relu')
149 | """
150 |
151 | def _initialize(self, arr):
152 | fan = _select_fan(arr, self.mode)
153 | bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
154 | data = np.random.uniform(-bound, bound, arr.shape)
155 |
156 | _assignment(arr, data)
157 |
158 |
159 | class KaimingNormal(KaimingInit):
160 | r"""
161 | Initialize the array with He kaiming normal algorithm. The resulting tensor will
162 | have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
163 |
164 | .. math::
165 | \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
166 |
167 | Input:
168 | arr (Array): The array to be assigned.
169 |
170 | Returns:
171 | Array, assigned array.
172 |
173 | Examples:
174 | >>> w = np.empty((3, 5))
175 | >>> KaimingNormal(w, mode='fan_out', nonlinearity='relu')
176 | """
177 |
178 | def _initialize(self, arr):
179 | fan = _select_fan(arr, self.mode)
180 | std = self.gain / math.sqrt(fan)
181 | data = np.random.normal(0, std, arr.shape)
182 |
183 | _assignment(arr, data)
184 |
185 |
186 | def default_recurisive_init(custom_cell):
187 | """default_recurisive_init"""
188 | for _, cell in custom_cell.cells_and_names():
189 | if isinstance(cell, nn.Conv2d):
190 | cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
191 | cell.weight.shape,
192 | cell.weight.dtype))
193 | if cell.bias is not None:
194 | fan_in, _ = _calculate_in_and_out(cell.weight)
195 | bound = 1 / math.sqrt(fan_in)
196 | cell.bias.set_data(init.initializer(init.Uniform(bound),
197 | cell.bias.shape,
198 | cell.bias.dtype))
199 | elif isinstance(cell, nn.Dense):
200 | cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
201 | cell.weight.shape,
202 | cell.weight.dtype))
203 | if cell.bias is not None:
204 | fan_in, _ = _calculate_in_and_out(cell.weight)
205 | bound = 1 / math.sqrt(fan_in)
206 | cell.bias.set_data(init.initializer(init.Uniform(bound),
207 | cell.bias.shape,
208 | cell.bias.dtype))
209 | elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
210 | pass
211 |
--------------------------------------------------------------------------------
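
For a conv weight of shape (out_channels, in_channels, kh, kw), _calculate_in_and_out gives fan_in = in_channels * kh * kw, and KaimingUniform then samples U(-bound, bound) with bound = sqrt(3) * gain / sqrt(fan). A worked example with the a = sqrt(5) default used by default_recurisive_init:

```python
import math

out_c, in_c, kh, kw = 64, 32, 3, 3
fan_in = in_c * kh * kw                                # 288
gain = math.sqrt(2.0 / (1 + math.sqrt(5) ** 2))        # leaky_relu gain with a=sqrt(5)
bound = math.sqrt(3.0) * gain / math.sqrt(fan_in)
print(fan_in, round(bound, 5))                         # 288 0.05893
```
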
/PyTorch version/README.md:
--------------------------------------------------------------------------------
1 | ## Dependencies
2 |
3 | PyTorch
4 |
5 | matplotlib
6 |
7 | imageio
8 |
9 | tensorboardX
10 |
11 | opencv-python
12 |
13 | scipy
14 |
15 | scikit-image
16 |
17 | ## Train
18 |
19 | ### Prepare data
20 |
21 | We use DIV2K training set as our training data.
22 |
23 | For instructions on downloading the data, refer to [EDSR(PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch)
24 |
25 | ### Begin to train
26 |
27 | ```bash
28 | # use '--teacher_model' to specify the teacher model
29 | python main.py --model EDSR --scale 4 --reset --dir_data LOCATION_OF_DATA --model_filename edsr_x4_0.25student --pre_train output/model/edsr/ --epochs 400 --model_stat --neg_num 10 --contra_lambda 200 --t_lambda 1 --t_l_remove 400 --contrast_t_detach
30 | ```
31 |
32 | ## Test
33 |
34 | You can download the models from our paper at [BaiduYun](https://pan.baidu.com/s/1gYenkfLac1s19lfzczxtHA) (code: zser).
35 |
36 | 1. Test on benchmarks:
37 |
38 | ```bash
39 | python main.py --scale 4 --pre_train FOLDER_OF_THE_PRETRAINED_MODEL --model_filename edsr_x4_0.25student --test_only --self_ensemble --dir_demo test --model EDSR --dir_data LOCATION_OF_DATA --n_GPUs 1 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --stu_width_mult 0.25 --model_stat --data_test Set5
40 | ```
41 |
42 | 2. Test your own image:
43 |
44 | ```bash
45 | python main.py --scale 4 --pre_train output/model --model_filename edsr_x4_0.25student --test_only --self_ensemble --model EDSR --n_GPUs 1 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --stu_width_mult 0.25 --model_stat --data_test Demo --save_results --dir_demo LOCATION_OF_YOUR_IMAGE
46 | ```
47 |
48 | The output SR image will be saved in './test/result/'.
49 |
50 |
--------------------------------------------------------------------------------
/PyTorch version/data/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib import import_module
2 | #from dataloader import MSDataLoader
3 | from torch.utils.data import dataloader
4 | from torch.utils.data import ConcatDataset
5 |
6 | # This is a simple wrapper function for ConcatDataset
7 | class MyConcatDataset(ConcatDataset):
8 | def __init__(self, datasets):
9 | super(MyConcatDataset, self).__init__(datasets)
10 | self.train = datasets[0].train
11 |
12 | def set_scale(self, idx_scale):
13 | for d in self.datasets:
14 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale)
15 |
16 | class Data:
17 | def __init__(self, args):
18 | self.loader_train = None
19 | if not args.test_only:
20 | datasets = []
21 | for d in args.data_train:
22 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
23 | m = import_module('data.' + module_name.lower())
24 | datasets.append(getattr(m, module_name)(args, name=d))
25 |
26 | self.loader_train = dataloader.DataLoader(
27 | # MyConcatDataset(datasets),
28 | datasets[0],
29 | batch_size=args.batch_size,
30 | shuffle=True,
31 | pin_memory=not args.cpu,
32 | num_workers=args.n_threads,
33 | )
34 |
35 | self.loader_test = []
36 | for d in args.data_test:
37 | if d in ['Set5', 'Set14', 'B100', 'Urban100']:
38 | m = import_module('data.benchmark')
39 | testset = getattr(m, 'Benchmark')(args, train=False, name=d)
40 | else:
41 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
42 | m = import_module('data.' + module_name.lower())
43 | testset = getattr(m, module_name)(args, train=False, name=d)
44 |
45 | self.loader_test.append(
46 | dataloader.DataLoader(
47 | testset,
48 | batch_size=1,
49 | shuffle=False,
50 | pin_memory=not args.cpu,
51 | num_workers=args.n_threads,
52 | )
53 | )
54 |
55 | # if args.data_aux == 'BSD500':
56 | # m = import_module('data.bsd500')
57 | # auxset = getattr(m, 'BSD500')(args, train=True, name='t')
58 | # self.loader_aux = dataloader.DataLoader(
59 | # auxset,
60 | # batch_size=args.neg_num,
61 | # shuffle=True,
62 | # pin_memory=not args.cpu,
63 | # num_workers=args.n_threads,
64 | # )
--------------------------------------------------------------------------------
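
Dataset classes in the PyTorch version are resolved by name at runtime: an entry 'NAME' in args.data_train maps to module data/name.py, which must define a class called NAME. The lookup pattern in isolation (a sketch; run from 'PyTorch version' so the data package imports):

```python
from importlib import import_module

d = 'DIV2K'                                          # one entry of args.data_train
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
m = import_module('data.' + module_name.lower())     # -> data/div2k.py
dataset_cls = getattr(m, module_name)                # -> class DIV2K
print(dataset_cls)
```
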
/PyTorch version/data/benchmark.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data import common
4 | from data import srdata
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.utils.data as data
10 |
11 | class Benchmark(srdata.SRData):
12 | def __init__(self, args, name='', train=True, benchmark=True):
13 | super(Benchmark, self).__init__(
14 | args, name=name, train=train, benchmark=True
15 | )
16 |
17 | def _set_filesystem(self, dir_data):
18 | self.apath = os.path.join(dir_data, 'benchmark', self.name)
19 | self.dir_hr = os.path.join(self.apath, 'HR')
20 | if self.input_large:
21 | self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
22 | else:
23 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
24 | self.ext = ('', '.png')
25 |
26 |
--------------------------------------------------------------------------------
/PyTorch version/data/bsd500.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | import numpy as np
4 | import PIL
5 | import random
6 |
7 | import torch
8 | import torch.utils.data as data
9 | import torchvision.transforms as transforms
10 |
11 | from data import common
12 |
13 | class BSD500(data.Dataset):
14 | def __init__(self, args, name='BSR', train=True):
15 | super(BSD500, self).__init__()
16 | self.args = args
17 | self.name = name
18 | self.train = train
19 | self.scale = args.scale[0]
20 | self.benchmark = False
21 |
22 | if train:
23 | self.dir_data = os.path.join(args.dir_data, 'BSR/BSDS500/data/images/train/')
24 | else:
25 | self.dir_data = os.path.join(args.dir_data, 'BSR/BSDS500/data/images/test/')
26 | self.dataset_from_folder = common.DatasetFromFolder(self.dir_data, '.jpg')
27 | self.cropper = transforms.RandomCrop(size=args.patch_size)
28 | self.resizer = transforms.Resize(size=args.patch_size // args.scale[0], interpolation=PIL.Image.BICUBIC)
29 | self.totensor = transforms.ToTensor()
30 | # @staticmethod
31 | # def to_tensor(patch):
32 | # return torch.Tensor(np.asarray(patch).swapaxes(0, 2)).float() / 255 - 0.5
33 | #
34 | # @staticmethod
35 | # def to_image(tensor):
36 | # if type(tensor) == torch.Tensor:
37 | # tensor = tensor.numpy()
38 | #
39 | # if len(tensor.shape) == 4:
40 | # tensor = tensor.swapaxes(1, 3)
41 | # elif len(tensor.shape) == 3:
42 | # tensor = tensor.swapaxes(0, 2)
43 | # else:
44 | # raise Exception("Predictions have shape not in set {3,4}")
45 | #
46 | # tensor = (tensor + 0.5) * 255
47 | # tensor[tensor > 255] = 255
48 | # tensor[tensor < 0] = 0
49 | # return tensor.round().astype(int)
50 |
51 | def __getitem__(self, index):
52 | img, filename = self.dataset_from_folder[index]
53 |
54 | if self.train:
55 | hr_patch = self.cropper(img)
56 | else:
57 | hr_patch = img
58 |
59 | lr_patch = self.resizer(hr_patch)
60 | hr_patch = self.totensor(hr_patch)
61 | lr_patch = self.totensor(lr_patch)
62 | if (not self.args.no_augment) and self.train:
63 | lr_patch, hr_patch = self.data_aug(lr_patch, hr_patch)
64 |
65 | hr_patch = hr_patch.mul_(self.args.rgb_range)
66 | lr_patch = lr_patch.mul_(self.args.rgb_range)
67 |
68 | return lr_patch, hr_patch, filename
69 |
70 | def __len__(self):
71 | return len(self.dataset_from_folder)
72 |
73 | def data_aug(self, lr, hr, hflip=True, rot=True):
74 | hflip = hflip and random.random() < 0.5
75 | vflip = rot and random.random() < 0.5
76 | rot90 = rot and random.random() < 0.5
77 |
78 | if hflip:
79 | lr = torch.flip(lr, [1])
80 | hr = torch.flip(hr, [1])
81 | # lr = lr[:, ::-1, :]
82 | # hr = hr[:, ::-1, :]
83 | if vflip:
84 | lr = torch.flip(lr, [2])
85 | hr = torch.flip(hr, [2])
86 | # lr = lr[::-1, :, :]
87 | # hr = hr[::-1, :, :]
88 | if rot90:
89 | lr = lr.permute(0, 2, 1)
90 | hr = hr.permute(0, 2, 1)
91 |
92 | return lr, hr
--------------------------------------------------------------------------------
/PyTorch version/data/common.py:
--------------------------------------------------------------------------------
1 | import random
2 | import string
3 |
4 | import numpy as np
5 | import skimage.color as sc
6 |
7 | import os
8 | import imageio
9 | import numpy as np
10 | import PIL
11 |
12 | import torch
13 | import torch.utils.data as data
14 |
15 | def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
16 | ih, iw = args[0].shape[:2]
17 |
18 | if not input_large:
19 | p = scale if multi else 1
20 | tp = p * patch_size
21 | ip = tp // scale
22 | else:
23 | tp = patch_size
24 | ip = patch_size
25 |
26 | ix = random.randrange(0, iw - ip + 1)
27 | iy = random.randrange(0, ih - ip + 1)
28 |
29 | if not input_large:
30 | tx, ty = scale * ix, scale * iy
31 | else:
32 | tx, ty = ix, iy
33 |
34 | ret = [
35 | args[0][iy:iy + ip, ix:ix + ip, :],
36 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
37 | ]
38 |
39 | return ret
40 |
41 | def set_channel(*args, n_channels=3):
42 | def _set_channel(img):
43 | if img.ndim == 2:
44 | img = np.expand_dims(img, axis=2)
45 |
46 | c = img.shape[2]
47 | if n_channels == 1 and c == 3:
48 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
49 | elif n_channels == 3 and c == 1:
50 | img = np.concatenate([img] * n_channels, 2)
51 |
52 | return img
53 |
54 | return [_set_channel(a) for a in args]
55 |
56 | def np2Tensor(*args, rgb_range=255):
57 | def _np2Tensor(img):
58 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
59 | tensor = torch.from_numpy(np_transpose).float()
60 | tensor.mul_(rgb_range / 255)
61 |
62 | return tensor
63 |
64 | return [_np2Tensor(a) for a in args]
65 |
66 | def augment(*args, hflip=True, rot=True):
67 | hflip = hflip and random.random() < 0.5
68 | vflip = rot and random.random() < 0.5
69 | rot90 = rot and random.random() < 0.5
70 |
71 | def _augment(img):
72 | if hflip: img = img[:, ::-1, :]
73 | if vflip: img = img[::-1, :, :]
74 | if rot90: img = img.transpose(1, 0, 2)
75 |
76 | return img
77 |
78 | return [_augment(a) for a in args]
79 |
80 | def is_image_file(filename, extension):
81 | return filename.endswith(extension)
82 |
83 | class DatasetFromFolder(data.Dataset):
84 | def __init__(self, image_dir, extension=".png"):
85 | super(DatasetFromFolder, self).__init__()
86 | self.image_dir = image_dir
87 | self.image_filenames = [x for x in sorted(os.listdir(image_dir)) if is_image_file(x, extension)]
88 | self.extension = extension
89 |
90 | def __getitem__(self, index):
91 | #return imageio.imread(os.path.join(self.image_dir, self.image_filenames[index])), self.image_filenames[index]
92 | filename = self.image_filenames[index]
93 | filename = filename[:-len(self.extension)]  # rstrip() strips a character set, not a suffix
94 | return PIL.Image.fromarray(imageio.imread(os.path.join(self.image_dir, self.image_filenames[index]))), filename
95 |
96 | def __len__(self):
97 | return len(self.image_filenames)
--------------------------------------------------------------------------------
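
get_patch crops aligned LR/HR pairs: patch_size is measured in HR pixels, so the LR crop is patch_size // scale and the HR corner is the LR corner times scale. A quick shape check (assuming the repo's dependencies are installed and 'PyTorch version' is the working directory):

```python
import numpy as np
from data.common import get_patch

lr = np.zeros((50, 60, 3))                     # HWC low-res image
hr = np.zeros((100, 120, 3))                   # the matching x2 high-res image
lr_p, hr_p = get_patch(lr, hr, patch_size=96, scale=2)
print(lr_p.shape, hr_p.shape)                  # (48, 48, 3) (96, 96, 3)
```
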
/PyTorch version/data/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data import common
4 |
5 | import numpy as np
6 | import imageio
7 |
8 | import torch
9 | import torch.utils.data as data
10 |
11 | class Demo(data.Dataset):
12 | def __init__(self, args, name='Demo', train=False, benchmark=False):
13 | self.args = args
14 | self.name = name
15 | self.scale = args.scale
16 | self.idx_scale = 0
17 | self.train = False
18 | self.benchmark = benchmark
19 |
20 | self.filelist = []
21 | for f in os.listdir(args.dir_demo):
22 | if f.find('.png') >= 0 or f.find('.jp') >= 0:
23 | self.filelist.append(os.path.join(args.dir_demo, f))
24 | self.filelist.sort()
25 |
26 | def __getitem__(self, idx):
27 | filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
28 | lr = imageio.imread(self.filelist[idx])
29 | lr, = common.set_channel(lr, n_channels=self.args.n_colors)
30 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
31 |
32 | return lr_t, -1, filename
33 |
34 | def __len__(self):
35 | return len(self.filelist)
36 |
37 | def set_scale(self, idx_scale):
38 | self.idx_scale = idx_scale
39 |
40 |
--------------------------------------------------------------------------------
/PyTorch version/data/div2k.py:
--------------------------------------------------------------------------------
1 | import os
2 | from data import srdata
3 |
4 | class DIV2K(srdata.SRData):
5 | def __init__(self, args, name='DIV2K', train=True, benchmark=False):
6 | data_range = [r.split('-') for r in args.data_range.split('/')]
7 | if train:
8 | data_range = data_range[0]
9 | else:
10 | if args.test_only and len(data_range) == 1:
11 | data_range = data_range[0]
12 | else:
13 | data_range = data_range[1]
14 |
15 | self.begin, self.end = list(map(lambda x: int(x), data_range))
16 | super(DIV2K, self).__init__(
17 | args, name=name, train=train, benchmark=benchmark
18 | )
19 |
20 | def _scan(self):
21 | names_hr, names_lr = super(DIV2K, self)._scan()
22 | names_hr = names_hr[self.begin - 1:self.end]
23 | names_lr = [n[self.begin - 1:self.end] for n in names_lr]
24 |
25 | return names_hr, names_lr
26 |
27 | def _set_filesystem(self, dir_data):
28 | super(DIV2K, self)._set_filesystem(dir_data)
29 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
30 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
31 | if self.input_large: self.dir_lr += 'L'
32 |
33 |
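34 | # Note: with the default --data_range '1-800/801-810', a training DIV2K
35 | # instance covers HR images 0001-0800 and a test instance covers 0801-0810;
36 | # _scan() slices the sorted filename lists with [self.begin - 1:self.end]
37 | # accordingly.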
--------------------------------------------------------------------------------
/PyTorch version/data/div2kjpeg.py:
--------------------------------------------------------------------------------
1 | import os
2 | from data import srdata
3 | from data import div2k
4 |
5 | class DIV2KJPEG(div2k.DIV2K):
6 | def __init__(self, args, name='', train=True, benchmark=False):
7 | self.q_factor = int(name.replace('DIV2K-Q', ''))
8 | super(DIV2KJPEG, self).__init__(
9 | args, name=name, train=train, benchmark=benchmark
10 | )
11 |
12 | def _set_filesystem(self, dir_data):
13 | self.apath = os.path.join(dir_data, 'DIV2K')
14 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
15 | self.dir_lr = os.path.join(
16 | self.apath, 'DIV2K_Q{}'.format(self.q_factor)
17 | )
18 | if self.input_large: self.dir_lr += 'L'
19 | self.ext = ('.png', '.jpg')
20 |
21 |
--------------------------------------------------------------------------------
/PyTorch version/data/sr291.py:
--------------------------------------------------------------------------------
1 | from data import srdata
2 |
3 | class SR291(srdata.SRData):
4 | def __init__(self, args, name='SR291', train=True, benchmark=False):
5 | super(SR291, self).__init__(args, name=name)
6 |
7 |
--------------------------------------------------------------------------------
/PyTorch version/data/srdata.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import random
4 | import pickle
5 |
6 | from data import common
7 |
8 | import numpy as np
9 | import imageio
10 | import torch
11 | import torch.utils.data as data
12 |
13 | class SRData(data.Dataset):
14 | def __init__(self, args, name='', train=True, benchmark=False):
15 | self.args = args
16 | self.name = name
17 | self.train = train
18 | self.split = 'train' if train else 'test'
19 | self.do_eval = True
20 | self.benchmark = benchmark
21 | self.input_large = (args.model == 'VDSR')
22 | self.scale = args.scale
23 | self.idx_scale = 0
24 |
25 | self._set_filesystem(args.dir_data)
26 | if args.ext.find('img') < 0:
27 | path_bin = os.path.join(self.apath, 'bin')
28 | os.makedirs(path_bin, exist_ok=True)
29 |
30 | list_hr, list_lr = self._scan()
31 | if args.ext.find('img') >= 0 or benchmark:
32 | self.images_hr, self.images_lr = list_hr, list_lr
33 | elif args.ext.find('sep') >= 0:
34 | os.makedirs(
35 | self.dir_hr.replace(self.apath, path_bin),
36 | exist_ok=True
37 | )
38 | for s in self.scale:
39 | os.makedirs(
40 | os.path.join(
41 | self.dir_lr.replace(self.apath, path_bin),
42 | 'X{}'.format(s)
43 | ),
44 | exist_ok=True
45 | )
46 |
47 | self.images_hr, self.images_lr = [], [[] for _ in self.scale]
48 | for h in list_hr:
49 | b = h.replace(self.apath, path_bin)
50 | b = b.replace(self.ext[0], '.pt')
51 | self.images_hr.append(b)
52 | self._check_and_load(args.ext, h, b, verbose=True)
53 | for i, ll in enumerate(list_lr):
54 | for l in ll:
55 | b = l.replace(self.apath, path_bin)
56 | b = b.replace(self.ext[1], '.pt')
57 | self.images_lr[i].append(b)
58 | self._check_and_load(args.ext, l, b, verbose=True)
59 | if train:
60 | n_patches = args.batch_size * args.test_every
61 | n_images = len(args.data_train) * len(self.images_hr)
62 | if n_images == 0:
63 | self.repeat = 0
64 | else:
65 | self.repeat = max(n_patches // n_images, 1)
66 |
67 |     # The functions below are used to prepare the image files
68 | def _scan(self):
69 | names_hr = sorted(
70 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
71 | )
72 | names_lr = [[] for _ in self.scale]
73 | for f in names_hr:
74 | filename, _ = os.path.splitext(os.path.basename(f))
75 | for si, s in enumerate(self.scale):
76 | names_lr[si].append(os.path.join(
77 | self.dir_lr, 'X{}/{}x{}{}'.format(
78 | s, filename, s, self.ext[1]
79 | )
80 | ))
81 |
82 | return names_hr, names_lr
83 |
84 | def _set_filesystem(self, dir_data):
85 | self.apath = os.path.join(dir_data, self.name)
86 | self.dir_hr = os.path.join(self.apath, 'HR')
87 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
88 | if self.input_large: self.dir_lr += 'L'
89 | self.ext = ('.png', '.png')
90 |
91 | def _check_and_load(self, ext, img, f, verbose=True):
92 | if not os.path.isfile(f) or ext.find('reset') >= 0:
93 | if verbose:
94 | print('Making a binary: {}'.format(f))
95 | with open(f, 'wb') as _f:
96 | pickle.dump(imageio.imread(img), _f)
97 |
98 | def __getitem__(self, idx):
99 | lr, hr, filename = self._load_file(idx)
100 | pair = self.get_patch(lr, hr)
101 | pair = common.set_channel(*pair, n_channels=self.args.n_colors)
102 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
103 |
104 | return pair_t[0], pair_t[1], filename
105 |
106 | def __len__(self):
107 | if self.train:
108 | return len(self.images_hr) * self.repeat
109 | else:
110 | return len(self.images_hr)
111 |
112 | def _get_index(self, idx):
113 | if self.train:
114 | return idx % len(self.images_hr)
115 | else:
116 | return idx
117 |
118 | def _load_file(self, idx):
119 | idx = self._get_index(idx)
120 | f_hr = self.images_hr[idx]
121 | f_lr = self.images_lr[self.idx_scale][idx]
122 |
123 | filename, _ = os.path.splitext(os.path.basename(f_hr))
124 | if self.args.ext == 'img' or self.benchmark:
125 | hr = imageio.imread(f_hr)
126 | lr = imageio.imread(f_lr)
127 | elif self.args.ext.find('sep') >= 0:
128 | with open(f_hr, 'rb') as _f:
129 | hr = pickle.load(_f)
130 | with open(f_lr, 'rb') as _f:
131 | lr = pickle.load(_f)
132 |
133 | return lr, hr, filename
134 |
135 | def get_patch(self, lr, hr):
136 | scale = self.scale[self.idx_scale]
137 |
138 | if self.train:
139 | lr, hr = common.get_patch(
140 | lr, hr,
141 | patch_size=self.args.patch_size,
142 | scale=scale,
143 | multi=(len(self.scale) > 1),
144 | input_large=self.input_large
145 | )
146 | if not self.args.no_augment: lr, hr = common.augment(lr, hr)
147 | else:
148 | ih, iw = lr.shape[:2]
149 | hr = hr[0:ih * scale, 0:iw * scale]
150 | return lr, hr
151 |
152 | def set_scale(self, idx_scale):
153 | if not self.input_large:
154 | self.idx_scale = idx_scale
155 | else:
156 | self.idx_scale = random.randint(0, len(self.scale) - 1)
157 |
158 |
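159 | # Note: with the default --ext 'sep', the first run decodes every image once
160 | # and pickles the numpy array to <apath>/bin/*.pt via _check_and_load();
161 | # later epochs unpickle those binaries in _load_file(), avoiding repeated PNG
162 | # decoding. --ext img (or a benchmark set) reads the image files directly
163 | # instead.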
--------------------------------------------------------------------------------
/PyTorch version/data/video.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data import common
4 |
5 | import cv2
6 | import numpy as np
7 | import imageio
8 |
9 | import torch
10 | import torch.utils.data as data
11 |
12 | class Video(data.Dataset):
13 | def __init__(self, args, name='Video', train=False, benchmark=False):
14 | self.args = args
15 | self.name = name
16 | self.scale = args.scale
17 | self.idx_scale = 0
18 | self.train = False
19 | self.do_eval = False
20 | self.benchmark = benchmark
21 |
22 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
23 | self.vidcap = cv2.VideoCapture(args.dir_demo)
24 | self.n_frames = 0
25 | self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
26 |
27 | def __getitem__(self, idx):
28 | success, lr = self.vidcap.read()
29 | if success:
30 | self.n_frames += 1
31 | lr, = common.set_channel(lr, n_channels=self.args.n_colors)
32 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
33 |
34 | return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
35 | else:
36 |             self.vidcap.release()
37 | return None
38 |
39 | def __len__(self):
40 | return self.total_frames
41 |
42 | def set_scale(self, idx_scale):
43 | self.idx_scale = idx_scale
44 |
45 |
--------------------------------------------------------------------------------
/PyTorch version/dataloader.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import random
3 | import sys  # needed by ExceptionWrapper(sys.exc_info()) in _ms_loop
4 | import torch
5 | import torch.multiprocessing as multiprocessing
6 | from torch.utils.data import DataLoader
7 | from torch.utils.data import SequentialSampler
8 | from torch.utils.data import RandomSampler
9 | from torch.utils.data import BatchSampler
10 | from torch.utils.data import _utils
11 | from torch.utils.data.dataloader import _DataLoaderIter
12 |
13 | from torch.utils.data._utils import collate
14 | from torch.utils.data._utils import signal_handling
15 | from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
16 | from torch.utils.data._utils import ExceptionWrapper
17 | from torch.utils.data._utils import IS_WINDOWS
18 | from torch.utils.data._utils.worker import ManagerWatchdog
19 |
20 | from torch._six import queue
21 |
22 | def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
23 | try:
24 | collate._use_shared_memory = True
25 | signal_handling._set_worker_signal_handlers()
26 |
27 | torch.set_num_threads(1)
28 | random.seed(seed)
29 | torch.manual_seed(seed)
30 |
31 | data_queue.cancel_join_thread()
32 |
33 | if init_fn is not None:
34 | init_fn(worker_id)
35 |
36 | watchdog = ManagerWatchdog()
37 |
38 | while watchdog.is_alive():
39 | try:
40 | r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
41 | except queue.Empty:
42 | continue
43 |
44 | if r is None:
45 | assert done_event.is_set()
46 | return
47 | elif done_event.is_set():
48 | continue
49 |
50 | idx, batch_indices = r
51 | try:
52 | idx_scale = 0
53 | if len(scale) > 1 and dataset.train:
54 | idx_scale = random.randrange(0, len(scale))
55 | dataset.set_scale(idx_scale)
56 |
57 | samples = collate_fn([dataset[i] for i in batch_indices])
58 | samples.append(idx_scale)
59 | except Exception:
60 | data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
61 | else:
62 | data_queue.put((idx, samples))
63 | del samples
64 |
65 | except KeyboardInterrupt:
66 | pass
67 |
68 | class _MSDataLoaderIter(_DataLoaderIter):
69 |
70 | def __init__(self, loader):
71 | self.dataset = loader.dataset
72 | self.scale = loader.scale
73 | self.collate_fn = loader.collate_fn
74 | self.batch_sampler = loader.batch_sampler
75 | self.num_workers = loader.num_workers
76 | self.pin_memory = loader.pin_memory and torch.cuda.is_available()
77 | self.timeout = loader.timeout
78 |
79 | self.sample_iter = iter(self.batch_sampler)
80 |
81 | base_seed = torch.LongTensor(1).random_().item()
82 |
83 | if self.num_workers > 0:
84 | self.worker_init_fn = loader.worker_init_fn
85 | self.worker_queue_idx = 0
86 | self.worker_result_queue = multiprocessing.Queue()
87 | self.batches_outstanding = 0
88 | self.worker_pids_set = False
89 | self.shutdown = False
90 | self.send_idx = 0
91 | self.rcvd_idx = 0
92 | self.reorder_dict = {}
93 | self.done_event = multiprocessing.Event()
94 |
95 |             # base_seed drawn once above (as a Python int) is reused for all workers
96 |
97 | self.index_queues = []
98 | self.workers = []
99 | for i in range(self.num_workers):
100 | index_queue = multiprocessing.Queue()
101 | index_queue.cancel_join_thread()
102 | w = multiprocessing.Process(
103 | target=_ms_loop,
104 | args=(
105 | self.dataset,
106 | index_queue,
107 | self.worker_result_queue,
108 | self.done_event,
109 | self.collate_fn,
110 | self.scale,
111 | base_seed + i,
112 | self.worker_init_fn,
113 | i
114 | )
115 | )
116 | w.daemon = True
117 | w.start()
118 | self.index_queues.append(index_queue)
119 | self.workers.append(w)
120 |
121 | if self.pin_memory:
122 | self.data_queue = queue.Queue()
123 | pin_memory_thread = threading.Thread(
124 | target=_utils.pin_memory._pin_memory_loop,
125 | args=(
126 | self.worker_result_queue,
127 | self.data_queue,
128 | torch.cuda.current_device(),
129 | self.done_event
130 | )
131 | )
132 | pin_memory_thread.daemon = True
133 | pin_memory_thread.start()
134 | self.pin_memory_thread = pin_memory_thread
135 | else:
136 | self.data_queue = self.worker_result_queue
137 |
138 | _utils.signal_handling._set_worker_pids(
139 | id(self), tuple(w.pid for w in self.workers)
140 | )
141 | _utils.signal_handling._set_SIGCHLD_handler()
142 | self.worker_pids_set = True
143 |
144 | for _ in range(2 * self.num_workers):
145 | self._put_indices()
146 |
147 |
148 | class MSDataLoader(DataLoader):
149 |
150 | def __init__(self, cfg, *args, **kwargs):
151 | super(MSDataLoader, self).__init__(
152 | *args, **kwargs, num_workers=cfg.n_threads
153 | )
154 | self.scale = cfg.scale
155 |
156 | def __iter__(self):
157 | return _MSDataLoaderIter(self)
158 |
159 |
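160 | # Note: MSDataLoader exists so that multi-scale training (e.g. --scale 2+3+4)
161 | # can pick a random scale per batch: _ms_loop calls dataset.set_scale() in
162 | # each worker and appends the chosen idx_scale to the collated batch. It
163 | # relies on pre-1.2 PyTorch internals (_DataLoaderIter, torch._six.queue),
164 | # so it only runs on older PyTorch releases.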
--------------------------------------------------------------------------------
/PyTorch version/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/PyTorch version/loss/__init__.py
--------------------------------------------------------------------------------
/PyTorch version/loss/adversarial.py:
--------------------------------------------------------------------------------
1 | import utils.utility as utility
2 | from types import SimpleNamespace
3 |
4 | from model import common
5 | from loss import discriminator
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 |
12 | class Adversarial(nn.Module):
13 | def __init__(self, args, gan_type):
14 | super(Adversarial, self).__init__()
15 | self.gan_type = gan_type
16 | self.gan_k = args.gan_k
17 | self.dis = discriminator.Discriminator(args)
18 | if torch.cuda.is_available():
19 | self.dis.cuda()
20 | if gan_type == 'WGAN_GP':
21 | # see https://arxiv.org/pdf/1704.00028.pdf pp.4
22 | optim_dict = {
23 | 'optimizer': 'ADAM',
24 | 'betas': (0, 0.9),
25 | 'epsilon': 1e-8,
26 | 'lr': 1e-5,
27 | 'weight_decay': args.weight_decay,
28 | 'decay': args.decay,
29 | 'gamma': args.gamma
30 | }
31 | optim_args = SimpleNamespace(**optim_dict)
32 | else:
33 | optim_args = args
34 |
35 | self.optimizer = utility.make_optimizer(optim_args, self.dis)
36 |
37 | def forward(self, fake, real):
38 | # updating discriminator...
39 | self.loss = 0
40 | fake_detach = fake.detach() # do not backpropagate through G
41 | for _ in range(self.gan_k):
42 | self.optimizer.zero_grad()
43 | # d: B x 1 tensor
44 | d_fake = self.dis(fake_detach)
45 | d_real = self.dis(real)
46 | retain_graph = False
47 | if self.gan_type == 'GAN':
48 | loss_d = self.bce(d_real, d_fake)
49 | elif self.gan_type.find('WGAN') >= 0:
50 | loss_d = (d_fake - d_real).mean()
51 | if self.gan_type.find('GP') >= 0:
52 |                     epsilon = torch.rand(fake.size(0), 1, 1, 1, device=fake.device)  # one interpolation weight per sample
53 | hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
54 | hat.requires_grad = True
55 | d_hat = self.dis(hat)
56 | gradients = torch.autograd.grad(
57 | outputs=d_hat.sum(), inputs=hat,
58 | retain_graph=True, create_graph=True, only_inputs=True
59 | )[0]
60 | gradients = gradients.view(gradients.size(0), -1)
61 | gradient_norm = gradients.norm(2, dim=1)
62 | gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
63 | loss_d += gradient_penalty
64 | # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
65 | elif self.gan_type == 'RGAN':
66 | better_real = d_real - d_fake.mean(dim=0, keepdim=True)
67 | better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
68 | loss_d = self.bce(better_real, better_fake)
69 | retain_graph = True
70 |
71 | # Discriminator update
72 | self.loss += loss_d.item()
73 | loss_d.backward(retain_graph=retain_graph)
74 | self.optimizer.step()
75 |
76 | if self.gan_type == 'WGAN':
77 | for p in self.dis.parameters():
78 | p.data.clamp_(-1, 1)
79 |
80 | self.loss /= self.gan_k
81 |
82 | # updating generator...
83 | d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is
84 | if self.gan_type == 'GAN':
85 | label_real = torch.ones_like(d_fake_bp)
86 | loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
87 | elif self.gan_type.find('WGAN') >= 0:
88 | loss_g = -d_fake_bp.mean()
89 | elif self.gan_type == 'RGAN':
90 | better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
91 | better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
92 | loss_g = self.bce(better_fake, better_real)
93 |
94 | # Generator loss
95 | return loss_g
96 |
97 | def state_dict(self, *args, **kwargs):
98 | state_discriminator = self.dis.state_dict(*args, **kwargs)
99 | state_optimizer = self.optimizer.state_dict()
100 |
101 | return dict(**state_discriminator, **state_optimizer)
102 |
103 | def bce(self, real, fake):
104 | label_real = torch.ones_like(real)
105 | label_fake = torch.zeros_like(fake)
106 | bce_real = F.binary_cross_entropy_with_logits(real, label_real)
107 | bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
108 | bce_loss = bce_real + bce_fake
109 | return bce_loss
110 |
111 | # Some references
112 | # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
113 | # OR
114 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
115 |
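116 | # Note: the WGAN_GP branch follows Gulrajani et al. (arXiv:1704.00028): for
117 | # hat = epsilon*real + (1-epsilon)*fake it adds
118 | # 10 * E[(||grad_hat D(hat)||_2 - 1)^2] to the discriminator loss, enforcing
119 | # the 1-Lipschitz constraint without weight clipping.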
--------------------------------------------------------------------------------
/PyTorch version/loss/contrast_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from torch.nn import functional as F
4 | import torch.nn.functional as fnn
5 | from torch.autograd import Variable
6 | import numpy as np
7 | from torchvision import models
8 |
9 | import math
10 |
11 | class Vgg19(torch.nn.Module):
12 | def __init__(self, requires_grad=False):
13 | super(Vgg19, self).__init__()
14 | vgg_pretrained_features = models.vgg19(pretrained=True).features
15 | self.slice1 = torch.nn.Sequential()
16 | self.slice2 = torch.nn.Sequential()
17 | self.slice3 = torch.nn.Sequential()
18 | self.slice4 = torch.nn.Sequential()
19 | self.slice5 = torch.nn.Sequential()
20 | for x in range(2):
21 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
22 | for x in range(2, 7):
23 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
24 | for x in range(7, 12):
25 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
26 | for x in range(12, 21):
27 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
28 | for x in range(21, 30):
29 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
30 | if not requires_grad:
31 | for param in self.parameters():
32 | param.requires_grad = False
33 |
34 | def forward(self, X):
35 | h_relu1 = self.slice1(X)
36 | h_relu2 = self.slice2(h_relu1)
37 | h_relu3 = self.slice3(h_relu2)
38 | h_relu4 = self.slice4(h_relu3)
39 | h_relu5 = self.slice5(h_relu4)
40 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
41 | return out
42 |
43 |
44 | class ContrastLoss(nn.Module):
45 |     def __init__(self, weights, d_func, t_detach=False, is_one=False):
46 | super(ContrastLoss, self).__init__()
47 | self.vgg = Vgg19().cuda()
48 | self.l1 = nn.L1Loss()
49 | self.weights = weights
50 | self.d_func = d_func
51 | self.is_one = is_one
52 | self.t_detach = t_detach
53 |
54 | def forward(self, teacher, student, neg, blur_neg=None):
55 |         teacher_vgg, student_vgg, neg_vgg = self.vgg(teacher), self.vgg(student), self.vgg(neg)
56 | blur_neg_vgg = None
57 | if blur_neg is not None:
58 | blur_neg_vgg = self.vgg(blur_neg)
59 | if self.d_func == "L1":
60 | self.forward_func = self.L1_forward
61 | elif self.d_func == 'cos':
62 | self.forward_func = self.cos_forward
63 |
64 | return self.forward_func(teacher_vgg, student_vgg, neg_vgg, blur_neg_vgg)
65 |
66 | def L1_forward(self, teacher, student, neg, blur_neg=None):
67 | """
68 | :param teacher: 5*batchsize*color*patchsize*patchsize
69 | :param student: 5*batchsize*color*patchsize*patchsize
70 | :param neg: 5*negnum*color*patchsize*patchsize
71 | :return:
72 | """
73 | loss = 0
74 | for i in range(len(teacher)):
75 | neg_i = neg[i].unsqueeze(0)
76 | neg_i = neg_i.repeat(student[i].shape[0], 1, 1, 1, 1)
77 |             neg_i = neg_i.permute(1, 0, 2, 3, 4)  # negnum*batchsize*color*patchsize*patchsize after permute
78 | if blur_neg is not None:
79 | blur_neg_i = blur_neg[i].unsqueeze(0)
80 | neg_i = torch.cat((neg_i, blur_neg_i))
81 |
82 |
83 | if self.t_detach:
84 | d_ts = self.l1(teacher[i].detach(), student[i])
85 | else:
86 | d_ts = self.l1(teacher[i], student[i])
87 | d_sn = torch.mean(torch.abs(neg_i.detach() - student[i]).sum(0))
88 |
89 | contrastive = d_ts / (d_sn + 1e-7)
90 | loss += self.weights[i] * contrastive
91 | return loss
92 |
93 |
94 | def cos_forward(self, teacher, student, neg, blur_neg=None):
95 | loss = 0
96 | for i in range(len(teacher)):
97 | neg_i = neg[i].unsqueeze(0)
98 | neg_i = neg_i.repeat(student[i].shape[0], 1, 1, 1, 1)
99 | neg_i = neg_i.permute(1, 0, 2, 3, 4)
100 | if blur_neg is not None:
101 | blur_neg_i = blur_neg[i].unsqueeze(0)
102 | neg_i = torch.cat((neg_i, blur_neg_i))
103 |
104 | if self.t_detach:
105 | d_ts = torch.cosine_similarity(teacher[i].detach(), student[i], dim=0).mean()
106 | else:
107 | d_ts = torch.cosine_similarity(teacher[i], student[i], dim=0).mean()
108 | d_sn = self.calc_cos_stu_neg(student[i], neg_i.detach())
109 |
110 | contrastive = -torch.log(torch.exp(d_ts)/(torch.exp(d_sn)+1e-7))
111 | loss += self.weights[i] * contrastive
112 | return loss
113 |
114 | def calc_cos_stu_neg(self, stu, neg):
115 | n = stu.shape[0]
116 | m = neg.shape[0]
117 |
118 | stu = stu.view(n, -1)
119 | neg = neg.view(m, n, -1)
120 | # normalize
121 | stu = F.normalize(stu, p=2, dim=1)
122 | neg = F.normalize(neg, p=2, dim=2)
123 | # multiply
124 | d_sn = torch.mean((stu * neg).sum(0))
125 | return d_sn
126 |
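127 | # Usage sketch (illustrative, not part of the original file): with d_func='L1'
128 | # each VGG level i contributes w_i * d_ts / (d_sn + 1e-7), where d_ts is the
129 | # L1 distance between teacher and student features (positive pair) and d_sn
130 | # aggregates the distances to the negative features, pulling the student
131 | # toward the teacher and away from the bicubic negatives. A minimal call
132 | # (CUDA required, since Vgg19 is moved to the GPU) might look like:
133 | #   loss_fn = ContrastLoss(weights=[1/32, 1/16, 1/8, 1/4, 1.0],
134 | #                          d_func='L1', t_detach=True)
135 | #   loss = loss_fn(teacher_sr, student_sr, neg_batch)  # NCHW tensors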
--------------------------------------------------------------------------------
/PyTorch version/loss/discriminator.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | class Discriminator(nn.Module):
7 | '''
8 | output is not normalized
9 | '''
10 | def __init__(self, args):
11 | super(Discriminator, self).__init__()
12 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13 | in_channels = args.n_colors
14 | out_channels = 64
15 | depth = 7
16 |
17 | def _block(_in_channels, _out_channels, stride=1):
18 | return nn.Sequential(
19 | nn.Conv2d(
20 | _in_channels,
21 | _out_channels,
22 | 3,
23 | padding=1,
24 | stride=stride,
25 | bias=False
26 | ),
27 | nn.BatchNorm2d(_out_channels),
28 | nn.LeakyReLU(negative_slope=0.2, inplace=True)
29 | )
30 |
31 | m_features = [_block(in_channels, out_channels)]
32 | for i in range(depth):
33 | in_channels = out_channels
34 | if i % 2 == 1:
35 | stride = 1
36 | out_channels *= 2
37 | else:
38 | stride = 2
39 | m_features.append(_block(in_channels, out_channels, stride=stride))
40 |
41 | patch_size = args.patch_size // (2**((depth + 1) // 2))
42 | m_classifier = [
43 | nn.Linear(out_channels * patch_size**2, 1024),
44 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
45 | nn.Linear(1024, 1)
46 | ]
47 |
48 | self.features = nn.Sequential(*m_features)
49 | self.classifier = nn.Sequential(*m_classifier)
50 |
51 | def forward(self, x):
52 | features = self.features(x)
53 | output = self.classifier(features.view(features.size(0), -1))
54 |
55 | return output
56 |
57 |
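58 | # Note: with depth=7 the feature stack halves the spatial size four times
59 | # (strides 2,1,2,1,2,1,2), so the classifier expects feature maps of
60 | # args.patch_size // 2**4 per side; changing --patch_size therefore changes
61 | # the shape of the first Linear layer.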
--------------------------------------------------------------------------------
/PyTorch version/loss/perceptual.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torchvision.models.vgg import vgg19
3 | import torch
4 |
5 |
6 | class PerceptualLoss(nn.Module):
7 | def __init__(self):
8 | super(PerceptualLoss, self).__init__()
9 |
10 | vgg = vgg19(pretrained=True)
11 | loss_network = nn.Sequential(*list(vgg.features)[:35]).eval()
12 | for param in loss_network.parameters():
13 | param.requires_grad = False
14 | self.loss_network = loss_network
15 | if torch.cuda.is_available():
16 | self.loss_network.cuda()
17 | self.l1_loss = nn.L1Loss()
18 |
19 | def forward(self, high_resolution, fake_high_resolution):
20 | perception_loss = self.l1_loss(self.loss_network(high_resolution), self.loss_network(fake_high_resolution))
21 | return perception_loss
22 |
--------------------------------------------------------------------------------
/PyTorch version/loss/vgg.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision.models as models
7 |
8 | class VGG(nn.Module):
9 | def __init__(self, conv_index, rgb_range=1):
10 | super(VGG, self).__init__()
11 | vgg_features = models.vgg19(pretrained=True).features
12 | modules = [m for m in vgg_features]
13 | if conv_index.find('22') >= 0:
14 | self.vgg = nn.Sequential(*modules[:8])
15 | elif conv_index.find('54') >= 0:
16 | self.vgg = nn.Sequential(*modules[:35])
17 |
18 | vgg_mean = (0.485, 0.456, 0.406)
19 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
20 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
21 | for p in self.parameters():
22 | p.requires_grad = False
23 |
24 | def forward(self, sr, hr):
25 | def _forward(x):
26 | x = self.sub_mean(x)
27 | x = self.vgg(x)
28 | return x
29 |
30 | vgg_sr = _forward(sr)
31 | with torch.no_grad():
32 | vgg_hr = _forward(hr.detach())
33 |
34 | loss = F.mse_loss(vgg_sr, vgg_hr)
35 |
36 | return loss
37 |
--------------------------------------------------------------------------------
/PyTorch version/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 |
4 | import utils.utility as utility
5 | import data
6 | import loss
7 | from option import args
8 | from trainer.slim_contrast_trainer import SlimContrastiveTrainer
9 |
10 | torch.manual_seed(args.seed)
11 | checkpoint = utility.checkpoint(args)
12 |
13 | from signal import signal, SIGPIPE, SIG_DFL, SIG_IGN
14 | signal(SIGPIPE, SIG_IGN)
15 |
16 | if __name__ == '__main__':
17 | loader = data.Data(args)
18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19 | if args.seperate:
20 |         t = SeperateContrastiveTrainer(args, loader, device)  # NOTE: this class is not imported above; the --seperate path needs its trainer module imported first
21 | else:
22 | t = SlimContrastiveTrainer(args, loader, device)
23 | # t = SlimContrastiveTrainer(args, loader, device, neg_loader)
24 |
25 | if args.model_stat:
26 | total_param = 0
27 | if args.seperate:
28 | for name, param in t.s_model.named_parameters():
29 | total_param += torch.numel(param)
30 | else:
31 | for name, param in t.model.named_parameters():
32 | # print(name, ' ', torch.numel(param))
33 | total_param += torch.numel(param)
34 | print(total_param)
35 |
36 |
37 |
38 | if not args.test_only:
39 | t.train()
40 | else:
41 | t.test(args.stu_width_mult)
42 | checkpoint.done()
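43 |
44 | # Example invocation (illustrative; flag values are assumptions, see option.py
45 | # for the full list):
46 | #   python main.py --model EDSR --scale 4 --stu_width_mult 0.25 \
47 | #       --teacher_model output/model/edsr/baseline/edsr_x4_baseline.pth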
--------------------------------------------------------------------------------
/PyTorch version/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/PyTorch version/model/__init__.py
--------------------------------------------------------------------------------
/PyTorch version/model/carn.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | class BasicBlock(nn.Module):
7 | def __init__(self,
8 | in_channels, out_channels,
9 | ksize=3, stride=1, pad=1):
10 | super(BasicBlock, self).__init__()
11 |
12 | self.conv = nn.Conv2d(in_channels, out_channels, ksize, stride, pad)
13 | self.act = nn.ReLU(inplace=True)
14 |
15 | def forward(self, x, width_mult):
16 | out = common.SlimModule(x, self.conv, width_mult)
17 | out = self.act(out)
18 | return out
19 |
20 | class Block(nn.Module):
21 | def __init__(self, n_feats):
22 | super(Block, self).__init__()
23 |
24 | self.act = nn.ReLU(inplace=True)
25 |
26 | self.b1 = common.ResidualBlock(n_feats, 3, act=self.act, res_scale=1)
27 | self.b2 = common.ResidualBlock(n_feats, 3, self.act, 1)
28 | self.b3 = common.ResidualBlock(n_feats, 3, self.act, 1)
29 | self.c1 = BasicBlock(n_feats*2, n_feats, 1, 1, 0)
30 | self.c2 = BasicBlock(n_feats*3, n_feats, 1, 1, 0)
31 | self.c3 = BasicBlock(n_feats*4, n_feats, 1, 1, 0)
32 |
33 | def forward(self, x, width_mult=1):
34 | c0 = o0 = x
35 |
36 | b1 = self.b1(o0, width_mult)
37 | c1 = torch.cat([c0, b1], dim=1)
38 | o1 = self.c1(c1, width_mult)
39 |
40 | b2 = self.b2(o1, width_mult)
41 | c2 = torch.cat([c1, b2], dim=1)
42 | o2 = self.c2(c2, width_mult)
43 |
44 | b3 = self.b3(o2, width_mult)
45 | c3 = torch.cat([c2, b3], dim=1)
46 | o3 = self.c3(c3, width_mult)
47 |
48 | return o3
49 |
50 | class CARN(nn.Module):
51 | def __init__(self, args):
52 | super(CARN, self).__init__()
53 |
54 | scale = args.scale[0]
55 | self.act = nn.ReLU(inplace=True)
56 | self.n_feat = args.n_feats
57 |
58 | self.sub_mean = common.MeanShift(args.rgb_range)
59 | self.add_mean = common.MeanShift(args.rgb_range, sign=1)
60 |
61 | self.entry = nn.Conv2d(3, self.n_feat, 3, 1, 1)
62 | self.b1 = Block(self.n_feat)
63 | self.b2 = Block(self.n_feat)
64 | self.b3 = Block(self.n_feat)
65 | self.c1 = BasicBlock(self.n_feat*2, self.n_feat, 1, 1, 0)
66 | self.c2 = BasicBlock(self.n_feat*3, self.n_feat, 1, 1, 0)
67 | self.c3 = BasicBlock(self.n_feat*4, self.n_feat, 1, 1, 0)
68 |
69 | self.upsample = common.Upsampler(scale, self.n_feat)
70 | self.exit = nn.Conv2d(self.n_feat, 3, 3, 1, 1)
71 |
72 | def forward(self, x, width_mult=1):
73 | x = self.sub_mean(x)
74 | nf = int(self.n_feat * width_mult)
75 | weight = self.entry.weight[:nf, :3, :, :]
76 | bias = self.entry.bias[:nf]
77 | x = nn.functional.conv2d(x, weight, bias, stride=self.entry.stride, padding=self.entry.padding)
78 | c0 = o0 = x
79 |
80 | b1 = self.b1(o0, width_mult)
81 | c1 = torch.cat([c0, b1], dim=1)
82 | o1 = self.c1(c1, width_mult)
83 |
84 | b2 = self.b2(o1, width_mult)
85 | c2 = torch.cat([c1, b2], dim=1)
86 | o2 = self.c2(c2, width_mult)
87 |
88 | b3 = self.b3(o2, width_mult)
89 | c3 = torch.cat([c2, b3], dim=1)
90 |         o3 = self.c3(c3, width_mult)
91 |
92 | out = self.upsample(o3, width_mult)
93 | weight = self.exit.weight[:3, :nf, :, :]
94 | bias = self.exit.bias[:3]
95 | out = nn.functional.conv2d(out, weight, bias, stride=self.exit.stride, padding=self.exit.padding)
96 | out = self.add_mean(out)
97 |
98 | return out
99 |
100 |
--------------------------------------------------------------------------------
/PyTorch version/model/common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class MeanShift(nn.Conv2d):
9 | def __init__(
10 | self, rgb_range,
11 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
12 |
13 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
14 | std = torch.Tensor(rgb_std)
15 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
16 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
17 | for p in self.parameters():
18 | p.requires_grad = False
19 |
20 | class ResidualBlock(nn.Module):
21 | def __init__(self, n_feats, kernel_size, act, res_scale):
22 | super(ResidualBlock, self).__init__()
23 | self.n_feats = n_feats
24 | self.res_scale = res_scale
25 | self.kernel_size = kernel_size
26 |
27 | self.conv1 = nn.Conv2d(n_feats, n_feats, kernel_size=kernel_size, padding=1)
28 | self.act = act
29 | self.conv2 = nn.Conv2d(n_feats, n_feats, kernel_size=kernel_size, padding=1)
30 |
31 | def forward(self, x, width_mult=1):
32 | width = int(self.n_feats* width_mult)
33 | weight = self.conv1.weight[:width, :width, :, :]
34 | bias = self.conv1.bias[:width]
35 | residual = nn.functional.conv2d(x, weight, bias, padding=(self.kernel_size//2))
36 | residual = self.act(residual)
37 | weight = self.conv2.weight[:width, :width, :, :]
38 | bias = self.conv2.bias[:width]
39 | residual = nn.functional.conv2d(residual, weight, bias, padding=(self.kernel_size//2))
40 |
41 | return x + residual.mul(self.res_scale)
42 |
43 | class Upsampler(nn.Sequential):
44 | def __init__(self, scale_factor, nf):
45 | super(Upsampler, self).__init__()
46 | block = []
47 | self.nf = nf
48 | self.scale = scale_factor
49 |
50 | if scale_factor == 3:
51 | block += [
52 | nn.Conv2d(nf, nf * 9, 3, padding=1, bias=True)
53 | ]
54 | self.pixel_shuffle = nn.PixelShuffle(3)
55 | else:
56 | self.block_num = scale_factor // 2
57 | self.pixel_shuffle = nn.PixelShuffle(2)
58 | #self.act = nn.ReLU()
59 |
60 | for _ in range(self.block_num):
61 | block += [
62 | nn.Conv2d(nf, nf * (2 ** 2), 3, padding=1, bias=True)
63 | ]
64 | self.blocks = nn.ModuleList(block)
65 |
66 | def forward(self, x, width_mult=1):
67 | res = x
68 | nf = self.nf
69 | if self.scale == 3:
70 | width = int(width_mult * nf)
71 | width9 = width * 9
72 | for block in self.blocks:
73 | weight = block.weight[:width9, :width, :, :]
74 | bias = block.bias[:width9]
75 | res = nn.functional.conv2d(res, weight, bias, padding=1)
76 | res = self.pixel_shuffle(res)
77 | else:
78 | for block in self.blocks:
79 | width = int(width_mult * nf)
80 | width4 = width * 4
81 | weight = block.weight[:width4, :width, :, :]
82 | bias = block.bias[:width4]
83 | res = nn.functional.conv2d(res, weight, bias, padding=1)
84 | res = self.pixel_shuffle(res)
85 | #res = self.act(res)
86 |
87 | return res
88 |
89 | def SlimModule(input, module, width_mult):
90 | weight = module.weight
91 | out_ch, in_ch = weight.shape[:2]
92 | out_ch = int(out_ch * width_mult)
93 | in_ch = int(in_ch * width_mult)
94 | weight = weight[:out_ch, :in_ch, :, :]
95 | bias = module.bias
96 | if bias is not None:
97 | bias = module.bias[:out_ch]
98 | return nn.functional.conv2d(input, weight, bias, stride=module.stride, padding=module.padding)
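99 |
100 | # Usage sketch (illustrative, not part of the original file): SlimModule makes
101 | # a convolution slimmable by slicing the first width_mult fraction of its
102 | # filters, e.g. for a 64->64 conv and width_mult=0.25:
103 | #   conv = nn.Conv2d(64, 64, 3, padding=1)
104 | #   y = SlimModule(x, conv, 0.25)  # uses weight[:16, :16, :, :] and bias[:16]
105 | # so teacher (width 1.0) and student (width 0.25) share one weight tensor.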
--------------------------------------------------------------------------------
/PyTorch version/model/edsr.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch.nn as nn
4 |
5 | class EDSR(nn.Module):
6 | def __init__(self, args):
7 | super(EDSR, self).__init__()
8 |
9 | self.n_colors = args.n_colors
10 | n_resblocks = args.n_resblocks
11 | self.n_feats = args.n_feats
12 | self.kernel_size = 3
13 | scale = args.scale[0]
14 | act = nn.ReLU(True)
15 | self.sub_mean = common.MeanShift(args.rgb_range)
16 | self.add_mean = common.MeanShift(args.rgb_range, sign=1)
17 |
18 | self.head = nn.Conv2d(args.n_colors, self.n_feats, self.kernel_size, padding=(self.kernel_size//2), bias=True)
19 |
20 | m_body = [
21 | common.ResidualBlock(
22 | self.n_feats, self.kernel_size, act=act, res_scale=args.res_scale
23 | ) for _ in range(n_resblocks)
24 | ]
25 | self.body = nn.ModuleList(m_body)
26 | self.body_conv = nn.Conv2d(self.n_feats, self.n_feats, self.kernel_size, padding=(self.kernel_size//2), bias=True)
27 |
28 | self.upsampler = common.Upsampler(scale, self.n_feats)
29 | self.tail_conv = nn.Conv2d(self.n_feats, args.n_colors, self.kernel_size, padding=(self.kernel_size//2), bias=True)
30 |
31 | def forward(self, x, width_mult=1):
32 | feature_width = int(self.n_feats * width_mult)
33 |
34 | x = self.sub_mean(x)
35 | weight = self.head.weight[:feature_width, :self.n_colors, :, :]
36 | bias = self.head.bias[:feature_width]
37 | x = nn.functional.conv2d(x, weight, bias, padding=(self.kernel_size//2))
38 |
39 | residual = x
40 | for block in self.body:
41 | residual = block(residual, width_mult)
42 | weight = self.body_conv.weight[:feature_width, :feature_width, :, :]
43 | bias = self.body_conv.bias[:feature_width]
44 | residual = nn.functional.conv2d(residual, weight, bias, padding=(self.kernel_size//2))
45 | residual += x
46 |
47 | x = self.upsampler(residual, width_mult)
48 | weight = self.tail_conv.weight[:self.n_colors, :feature_width, :, :]
49 | bias = self.tail_conv.bias[:self.n_colors]
50 | x = nn.functional.conv2d(x, weight, bias, padding=(self.kernel_size//2))
51 | x = self.add_mean(x)
52 |
53 | return x
54 |
55 | def load_state_dict(self, state_dict, strict=True):
56 | own_state = self.state_dict()
57 | for name, param in state_dict.items():
58 | if name in own_state:
59 | if isinstance(param, nn.Parameter):
60 | param = param.data
61 | try:
62 | own_state[name].copy_(param)
63 | except Exception:
64 | if name.find('tail') == -1:
65 | raise RuntimeError('While copying the parameter named {}, '
66 | 'whose dimensions in the model are {} and '
67 | 'whose dimensions in the checkpoint are {}.'
68 | .format(name, own_state[name].size(), param.size()))
69 | elif strict:
70 | if name.find('tail') == -1:
71 | raise KeyError('unexpected key "{}" in state_dict'
72 | .format(name))
73 |
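74 | # Note: forward() slices the head/body/tail conv weights to
75 | # int(n_feats * width_mult) channels, so one set of parameters serves both the
76 | # full-width teacher and the slimmed student; load_state_dict() deliberately
77 | # tolerates mismatching 'tail' parameters so checkpoints trained for another
78 | # scale can still initialize the body.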
--------------------------------------------------------------------------------
/PyTorch version/model/rcan.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch.nn as nn
4 |
5 | # class SlimConv(nn.Module):
6 | # def __init__(self, in_channel, out_channel, kernel_size, padding=(kernel_size//2), bias=True):
7 | # self.in_channel = in_channel
8 | # self.out_channel = out_channel
9 | # self.padding = padding
10 | # self.bias = bias
11 | #
12 | # self.conv = nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, bias=bias)
13 | #
14 | # def forward(self, x, width_mult=1):
15 | # in_channel_width = int(width_mult * self.in_channel)
16 | # out_channel_width = int(width_mult * self.out_channel)
17 | # weight = self.conv.weight[:out_channel_width, :in_channel_width, :, :]
18 | # bias = None
19 | # if self.bias:
20 | # bias = self.conv.bias[:out_channel_width]
21 | #
22 | # return nn.functional.conv2d(x, weight, bias, self.conv.stride)
23 |
24 | class CALayer(nn.Module):
25 | def __init__(self, channel, reduction=16):
26 | super(CALayer, self).__init__()
27 |
28 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
29 | self.conv_du = nn.ModuleList([
30 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
31 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
32 | ])
33 | self.relu = nn.ReLU(inplace=True)
34 | self.sigmoid = nn.Sigmoid()
35 |
36 | def forward(self, x, width_mult=1):
37 | y = self.avg_pool(x)
38 |
39 | module = getattr(self.conv_du, '0')
40 | y = common.SlimModule(y, module, width_mult)
41 | y = self.relu(y)
42 |
43 | module = getattr(self.conv_du, '1')
44 | y = common.SlimModule(y, module, width_mult)
45 | y = self.sigmoid(y)
46 |
47 | return x * y
48 |
49 | class RCAB(nn.Module):
50 | def __init__(self, n_feats, kernel_size, reduction, bias=True, act=nn.ReLU(True), res_scale=1):
51 | super(RCAB, self).__init__()
52 |
53 | modules_body = []
54 | for i in range(2):
55 | modules_body.append(nn.Conv2d(n_feats, n_feats, kernel_size, padding=(kernel_size//2), bias=bias))
56 | self.caLayer = CALayer(n_feats, reduction)
57 | self.body = nn.ModuleList(modules_body)
58 | self.act = act
59 | self.res_scale = res_scale
60 |
61 | def forward(self, x, width_mult=1):
62 | module = self.body[0]
63 | res = common.SlimModule(x, module, width_mult)
64 | res = self.act(res)
65 | module = self.body[1]
66 | res = common.SlimModule(res, module, width_mult)
67 | res = self.caLayer(res, width_mult)
68 | res += x
69 | return res
70 |
71 | class ResidualGroup(nn.Module):
72 | def __init__(self, n_feats, kernel_size, reduction, act, res_scale, n_resblocks):
73 | super(ResidualGroup, self).__init__()
74 |
75 | modules_body = [
76 | RCAB(
77 | n_feats, kernel_size, reduction, bias=True, act=nn.ReLU(True), res_scale=1
78 | ) for _ in range(n_resblocks)
79 | ]
80 | self.body = nn.ModuleList(modules_body)
81 | self.conv = nn.Conv2d(n_feats, n_feats, kernel_size, padding=(kernel_size//2), bias=True)
82 |
83 | def forward(self, x, width_mult):
84 | res = x
85 | for module in self.body:
86 | res = module(res, width_mult)
87 | res = common.SlimModule(res, self.conv, width_mult)
88 | res += x
89 | return res
90 |
91 |
92 | class RCAN(nn.Module):
93 | def __init__(self, args):
94 | super(RCAN, self).__init__()
95 |
96 | self.n_colors = args.n_colors
97 | n_resgroups = args.n_resgroups
98 | n_resblocks = args.n_resblocks
99 | n_feats = args.n_feats
100 | kernel_size = 3
101 | reduction = args.reduction
102 | scale = args.scale[0]
103 | act = nn.ReLU(True)
104 |
105 | self.sub_mean = common.MeanShift(args.rgb_range)
106 |
107 | self.head_conv = nn.Conv2d(args.n_colors, n_feats, kernel_size, padding=(kernel_size//2), bias=True)
108 |
109 | modules_body = [
110 | ResidualGroup(
111 | n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks
112 | ) for _ in range(n_resgroups)
113 | ]
114 | # modules_body = [
115 | # common.ResidualBlock(
116 | # n_feats, kernel_size, act, args.res_scale
117 | # ) for _ in range(n_resgroups)
118 | # ]
119 | self.body = nn.ModuleList(modules_body)
120 | self.body_conv = nn.Conv2d(n_feats, n_feats, kernel_size, padding=(kernel_size//2), bias=True)
121 |
122 | self.upsampler = common.Upsampler(scale, n_feats)
123 | self.tail_conv = nn.Conv2d(n_feats, args.n_colors, kernel_size, padding=(kernel_size // 2),
124 | bias=True)
125 |
126 | self.add_mean = common.MeanShift(args.rgb_range, sign=1)
127 |
128 | def forward(self, x, width_mult=1):
129 | x = self.sub_mean(x)
130 | weight = self.head_conv.weight
131 | n_feats = weight.shape[0]
132 | out_ch = int(n_feats * width_mult)
133 | weight = weight[:out_ch, :self.n_colors, :, :]
134 | bias = self.head_conv.bias[:out_ch]
135 | x = nn.functional.conv2d(x, weight, bias, stride=self.head_conv.stride, padding=self.head_conv.padding)
136 |
137 | res = x
138 | for module in self.body:
139 | res = module(res, width_mult)
140 | res = common.SlimModule(res, self.body_conv, width_mult)
141 | res += x
142 |
143 | x = self.upsampler(res, width_mult)
144 | weight = self.tail_conv.weight[:self.n_colors, :out_ch, :, :]
145 | bias = self.tail_conv.bias[:self.n_colors]
146 | x = nn.functional.conv2d(x, weight, bias, stride=self.tail_conv.stride, padding=self.tail_conv.padding)
147 | x = self.add_mean(x)
148 |
149 | return x
150 |
151 | def load_state_dict(self, state_dict, strict=False):
152 | own_state = self.state_dict()
153 | for name, param in state_dict.items():
154 | if name in own_state:
155 | if isinstance(param, nn.Parameter):
156 | param = param.data
157 | try:
158 | own_state[name].copy_(param)
159 | except Exception:
160 | if name.find('tail') >= 0:
161 |                         print('Replacing the pre-trained upsampler with a new one...')
162 | else:
163 | raise RuntimeError('While copying the parameter named {}, '
164 | 'whose dimensions in the model are {} and '
165 | 'whose dimensions in the checkpoint are {}.'
166 | .format(name, own_state[name].size(), param.size()))
167 | elif strict:
168 | if name.find('tail') == -1:
169 | raise KeyError('unexpected key "{}" in state_dict'
170 | .format(name))
171 |
172 | if strict:
173 | missing = set(own_state.keys()) - set(state_dict.keys())
174 | if len(missing) > 0:
175 | raise KeyError('missing keys in state_dict: "{}"'.format(missing))
176 |
--------------------------------------------------------------------------------
/PyTorch version/option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser(description='EDSR and MDSR')
4 |
5 | parser.add_argument('--debug', action='store_true',
6 | help='Enables debug mode')
7 | parser.add_argument('--template', default='.',
8 | help='You can set various templates in option.py')
9 |
10 | # Hardware specifications
11 | parser.add_argument('--n_threads', type=int, default=6,
12 | help='number of threads for data loading')
13 | parser.add_argument('--cpu', action='store_true',
14 | help='use cpu only')
15 | parser.add_argument('--n_GPUs', type=int, default=1,
16 | help='number of GPUs')
17 | parser.add_argument('--seed', type=int, default=1,
18 | help='random seed')
19 |
20 | # Data specifications
21 | parser.add_argument('--dir_data', type=str, default='../../../dataset',
22 | help='dataset directory')
23 | parser.add_argument('--dir_demo', type=str, default='../test',
24 | help='demo image directory')
25 | parser.add_argument('--data_train', type=str, default='DIV2K',
26 | help='train dataset name')
27 | parser.add_argument('--data_test', type=str, default='DIV2K',
28 | help='test dataset name')
29 | parser.add_argument('--data_range', type=str, default='1-800/801-810',
30 | help='train/test data range')
31 | parser.add_argument('--ext', type=str, default='sep',
32 | help='dataset file extension')
33 | parser.add_argument('--scale', type=str, default='4',
34 | help='super resolution scale')
35 | parser.add_argument('--patch_size', type=int, default=192,
36 | help='output patch size')
37 | parser.add_argument('--rgb_range', type=int, default=255,
38 | help='maximum value of RGB')
39 | parser.add_argument('--n_colors', type=int, default=3,
40 | help='number of color channels to use')
41 | parser.add_argument('--chop', action='store_true',
42 | help='enable memory-efficient forward')
43 | parser.add_argument('--no_augment', action='store_true',
44 | help='do not use data augmentation')
45 |
46 | # Model specifications
47 | parser.add_argument('--model', type=str, default='EDSR',
48 | help='model name')
49 |
50 | parser.add_argument('--act', type=str, default='relu',
51 | help='activation function')
52 | parser.add_argument('--pre_train', type=str, default='',
53 | help='pre-trained model directory')
54 | parser.add_argument('--extend', type=str, default='.',
55 | help='pre-trained model directory')
56 | parser.add_argument('--n_resblocks', type=int, default=32,
57 | help='number of residual blocks')
58 | parser.add_argument('--n_feats', type=int, default=256,
59 | help='number of feature maps')
60 | parser.add_argument('--res_scale', type=float, default=0.1,
61 | help='residual scaling')
62 | parser.add_argument('--shift_mean', default=True,
63 | help='subtract pixel mean from the input')
64 | parser.add_argument('--dilation', action='store_true',
65 | help='use dilated convolution')
66 | parser.add_argument('--precision', type=str, default='single',
67 | choices=('single', 'half'),
68 | help='FP precision for test (single | half)')
69 |
70 | # Option for Residual dense network (RDN)
71 | parser.add_argument('--G0', type=int, default=64,
72 | help='default number of filters. (Use in RDN)')
73 | parser.add_argument('--RDNkSize', type=int, default=3,
74 | help='default kernel size. (Use in RDN)')
75 | parser.add_argument('--RDNconfig', type=str, default='B',
76 | help='parameters config of RDN. (Use in RDN)')
77 |
78 | # Option for Residual channel attention network (RCAN)
79 | parser.add_argument('--n_resgroups', type=int, default=10,
80 | help='number of residual groups')
81 | parser.add_argument('--reduction', type=int, default=16,
82 | help='number of feature maps reduction')
83 |
84 | # Training specifications
85 | parser.add_argument('--reset', action='store_true',
86 | help='reset the training')
87 | parser.add_argument('--test_every', type=int, default=1000,
88 | help='do test per every N batches')
89 | parser.add_argument('--epochs', type=int, default=300,
90 | help='number of epochs to train')
91 | parser.add_argument('--batch_size', type=int, default=16,
92 | help='input batch size for training')
93 | parser.add_argument('--split_batch', type=int, default=1,
94 | help='split the batch into smaller chunks')
95 | parser.add_argument('--self_ensemble', action='store_true',
96 | help='use self-ensemble method for test')
97 | parser.add_argument('--test_only', action='store_true',
98 | help='set this option to test the model')
99 | parser.add_argument('--gan_k', type=int, default=1,
100 | help='k value for adversarial loss')
101 |
102 | # Optimization specifications
103 | parser.add_argument('--lr', type=float, default=1e-4,
104 | help='learning rate')
105 | parser.add_argument('--decay', type=str, default='200',
106 | help='learning rate decay type')
107 | parser.add_argument('--gamma', type=float, default=0.5,
108 | help='learning rate decay factor for step decay')
109 | parser.add_argument('--optimizer', default='ADAM',
110 | choices=('SGD', 'ADAM', 'RMSprop'),
111 | help='optimizer to use (SGD | ADAM | RMSprop)')
112 | parser.add_argument('--momentum', type=float, default=0.9,
113 | help='SGD momentum')
114 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
115 | help='ADAM beta')
116 | parser.add_argument('--epsilon', type=float, default=1e-8,
117 | help='ADAM epsilon for numerical stability')
118 | parser.add_argument('--weight_decay', type=float, default=0,
119 | help='weight decay')
120 | parser.add_argument('--gclip', type=float, default=0,
121 | help='gradient clipping threshold (0 = no clipping)')
122 |
123 | # Loss specifications
124 | parser.add_argument('--loss', type=str, default='1*L1',
125 | help='loss function configuration')
126 | parser.add_argument('--skip_threshold', type=float, default=1e8,
127 | help='skipping batch that has large error')
128 |
129 | # Log specifications
130 | parser.add_argument('--save', type=str, default='test',
131 | help='file name to save')
132 | parser.add_argument('--load', type=str, default='',
133 | help='file name to load')
134 | parser.add_argument('--resume', type=int, default=0,
135 | help='resume from specific checkpoint')
136 | parser.add_argument('--save_models', action='store_true',
137 | help='save all intermediate models')
138 | parser.add_argument('--print_every', type=int, default=100,
139 | help='how many batches to wait before logging training status')
140 | parser.add_argument('--save_results', action='store_true',
141 | help='save output results')
142 | parser.add_argument('--save_gt', action='store_true',
143 | help='save low-resolution and high-resolution images together')
144 |
145 | parser.add_argument('--stu_width_mult', type=float, default=0.25,
146 | help='width_mult of student model')
147 | parser.add_argument('--bn', action='store_true', help='use bn in residual block')
148 | parser.add_argument('--slim_num', type=int, default=1, help='divide the network into {slim_num} parts')
149 | parser.add_argument('--full_width_mult', type=float, default=1, help='set full width mult')
150 | parser.add_argument('--batch_test', action='store_true', help='test a batch')
151 | parser.add_argument('--print_each', action='store_true', help='print result of every sample')
152 | parser.add_argument('--model_stat', action='store_true', help='count model params and flops')
153 |
154 | parser.add_argument('--content_loss_factor', type=float, default=1e-1, help='content loss factor')
155 | parser.add_argument('--feature_loss_factor', type=float, default=1, help='feature loss factor')
156 | parser.add_argument('--adversarial_loss_factor', type=float, default=5e-3, help='adversarial loss factor')
157 |
158 | parser.add_argument('--model_filename', type=str, default='', help='pre-train model filename')
159 | parser.add_argument('--model_str', type=str, default='', help='save the model as filename{model_str}')
160 | parser.add_argument('--teacher_model', type=str, default='output/model/edsr/baseline/edsr_x4_baseline.pth',
161 | help='load teacher model from {teacher_model}')
162 |
163 | parser.add_argument('--neg_num', type=int, default=8,
164 | help='negative samples number')
165 |
166 | parser.add_argument('--t_lambda', type=float, default=0, help='weight of l1(hr, teacher_sr)')
167 | parser.add_argument('--contra_lambda', type=float, default=1, help='weight of contra_loss')
168 | parser.add_argument('--ad_lambda', type=float, default=0, help='weight of adversarial loss')
169 | parser.add_argument('--percep_lambda', type=float, default=0, help='weight of perceptual loss')
170 | parser.add_argument('--kd_lambda', type=float, default=0, help='weight of kd loss')
171 | parser.add_argument('--vgg_weight', nargs='+', type=float, default=[1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0],
172 | help='weight of vgg features in contrastive loss')
173 | parser.add_argument('--d_func', type=str, default="L1", help='the distance function in contrastive loss')
174 | parser.add_argument('--mean_outside', action='store_true', help='calc mean for negative samples outside the contrast_loss')
175 | parser.add_argument('--t_l_remove', type=int, default=0, help='remove teacher loss @ epoch {t_l_remove}')
176 | parser.add_argument('--contrast_t_detach', action='store_true', help='detach teacher in contrast_loss')
177 | parser.add_argument('--gt_as_pos', action='store_true', help='use gt as positive sample')
178 | parser.add_argument('--blur_sigma', type=float, default=0, help='blur sigma of neg sample')
179 | parser.add_argument('--noise_sigma', type=float, default=0, help='noise sigma of neg sample')
180 |
181 | parser.add_argument('--seperate', action='store_true', help='separate teacher and student')
182 |
183 | args = parser.parse_args()
184 |
185 | args.scale = list(map(lambda x: int(x), args.scale.split('+')))
186 | args.data_train = args.data_train.split('+')
187 | args.data_test = args.data_test.split('+')
188 |
189 | if args.epochs == 0:
190 | args.epochs = 1e8
191 |
192 | for arg in vars(args):
193 | if vars(args)[arg] == 'True':
194 | vars(args)[arg] = True
195 | elif vars(args)[arg] == 'False':
196 | vars(args)[arg] = False
197 |
198 |
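199 | # Note: the string-valued flags are post-processed above: --scale '2+3+4'
200 | # becomes [2, 3, 4], and --data_train / --data_test are split on '+' into
201 | # lists, so multi-scale and multi-dataset runs share the same '+'-separated
202 | # command-line syntax.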
--------------------------------------------------------------------------------
/PyTorch version/trainer/slim_contrast_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from decimal import Decimal
4 | from glob import glob
5 | import datetime, time
6 | from importlib import import_module
7 | import numpy as np
8 |
9 | # import lpips
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torchvision.transforms as transforms
14 | from tensorboardX import SummaryWriter
15 |
16 | import utils.utility as utility
17 | from loss.contrast_loss import ContrastLoss
18 | from loss.adversarial import Adversarial
19 | from loss.perceptual import PerceptualLoss
20 | from model.edsr import EDSR
21 | from model.rcan import RCAN
22 | from utils.niqe import niqe
23 | from utils.ssim import calc_ssim
24 |
25 |
26 | class SlimContrastiveTrainer:
27 | def __init__(self, args, loader, device, neg_loader=None):
28 | self.model_str = args.model.lower()
29 | self.pic_path = f'./output/{self.model_str}/{args.model_filename}/'
30 |         if not os.path.exists(self.pic_path):
31 |             os.makedirs(self.pic_path)  # os.makedirs returns None, so nothing useful to assign
32 | self.teacher_model = args.teacher_model
33 | self.checkpoint_dir = args.pre_train
34 | self.model_filename = args.model_filename
35 | self.model_filepath = f'{self.model_filename}.pth'
36 | self.writer = SummaryWriter(f'log/{self.model_filename}')
37 |
38 | self.start_epoch = -1
39 | self.device = device
40 | self.epochs = args.epochs
41 | self.init_lr = args.lr
42 | self.rgb_range = args.rgb_range
43 | self.scale = args.scale[0]
44 | self.stu_width_mult = args.stu_width_mult
45 | self.batch_size = args.batch_size
46 | self.neg_num = args.neg_num
47 | self.save_results = args.save_results
48 | self.self_ensemble = args.self_ensemble
49 | self.print_every = args.print_every
50 | self.best_psnr = 0
51 | self.best_psnr_epoch = -1
52 |
53 | self.loader = loader
54 | self.mean = [0.404, 0.436, 0.446]
55 | self.std = [0.288, 0.263, 0.275]
56 |
57 | self.build_model(args)
58 | self.upsampler = nn.Upsample(scale_factor=self.scale, mode='bicubic')
59 | self.optimizer = utility.make_optimizer(args, self.model)
60 |
61 | self.t_lambda = args.t_lambda
62 | self.contra_lambda = args.contra_lambda
63 | self.ad_lambda = args.ad_lambda
64 | self.percep_lambda = args.percep_lambda
65 | self.t_detach = args.contrast_t_detach
66 | self.contra_loss = ContrastLoss(args.vgg_weight, args.d_func, self.t_detach)
67 | self.l1_loss = nn.L1Loss()
68 | self.ad_loss = Adversarial(args, 'GAN')
69 | self.percep_loss = PerceptualLoss()
70 | self.t_l_remove = args.t_l_remove
71 |
72 | def train(self):
73 | self.model.train()
74 |
75 | total_iter = (self.start_epoch+1)*len(self.loader.loader_train)
76 | for epoch in range(self.start_epoch + 1, self.epochs):
77 | if epoch >= self.t_l_remove:
78 | self.t_lambda = 0
79 |
80 | starttime = datetime.datetime.now()
81 |
82 | lrate = utility.adjust_learning_rate(self.optimizer, epoch, self.epochs, self.init_lr)
83 | print("[Epoch {}]\tlr:{}\t".format(epoch, lrate))
84 | psnr, t_psnr = 0.0, 0.0
85 | step = 0
86 | for batch, (lr, hr, _,) in enumerate(self.loader.loader_train):
87 | torch.cuda.empty_cache()
88 | step += 1
89 | total_iter += 1
90 | lr = lr.to(self.device)
91 | hr = hr.to(self.device)
92 |
93 | self.optimizer.zero_grad()
94 | teacher_sr = self.model(lr)
95 |
96 | student_sr = self.model(lr, self.stu_width_mult)
97 | l1_loss = self.l1_loss(hr, student_sr)
98 | teacher_l1_loss = self.l1_loss(hr, teacher_sr)
99 |
100 | bic_sample = lr[torch.randperm(self.neg_num), :, :, :]
101 | bic_sample = self.upsampler(bic_sample)
102 | contras_loss = 0.0
103 |
104 | if self.neg_num > 0:
105 | contras_loss = self.contra_loss(teacher_sr, student_sr, bic_sample)
106 |
107 | loss = l1_loss + self.contra_lambda * contras_loss + self.t_lambda * teacher_l1_loss
108 | if self.ad_lambda > 0:
109 | ad_loss = self.ad_loss(student_sr, hr)
110 | loss += self.ad_lambda * ad_loss
111 | self.writer.add_scalar('Train/Ad_loss', ad_loss, total_iter)
112 | if self.percep_lambda > 0:
113 | percep_loss = self.percep_loss(hr, student_sr)
114 | loss += self.percep_lambda * percep_loss
115 | self.writer.add_scalar('Train/Percep_loss', percep_loss, total_iter)
116 |
117 | loss.backward()
118 | self.optimizer.step()
119 |
120 | self.writer.add_scalar('Train/L1_loss', l1_loss, total_iter)
121 | self.writer.add_scalar('Train/Contras_loss', contras_loss, total_iter)
122 | self.writer.add_scalar('Train/Teacher_l1_loss', teacher_l1_loss, total_iter)
123 | self.writer.add_scalar('Train/Total_loss', loss, total_iter)
124 |
125 | student_sr = utility.quantize(student_sr, self.rgb_range)
126 | psnr += utility.calc_psnr(student_sr, hr, self.scale, self.rgb_range)
127 | teacher_sr = utility.quantize(teacher_sr, self.rgb_range)
128 | t_psnr += utility.calc_psnr(teacher_sr, hr, self.scale, self.rgb_range)
129 | if (batch + 1) % self.print_every == 0:
130 | print(
131 | f"[Epoch {epoch}/{self.epochs}] [Batch {batch * self.batch_size}/{len(self.loader.loader_train.dataset)}] "
132 | f"[psnr {psnr / step}]"
133 | f"[t_psnr {t_psnr / step}]"
134 | )
135 | utility.save_results(f'result_{batch}', hr, self.scale, width=1, rgb_range=self.rgb_range,
136 | postfix='hr', dir=self.pic_path)
137 | utility.save_results(f'result_{batch}', teacher_sr, self.scale, width=1, rgb_range=self.rgb_range,
138 | postfix='t_sr', dir=self.pic_path)
139 | utility.save_results(f'result_{batch}', student_sr, self.scale, width=1, rgb_range=self.rgb_range,
140 | postfix='s_sr', dir=self.pic_path)
141 |
142 | print(f"training PSNR @epoch {epoch}: {psnr / step}")
143 |
144 | test_psnr = self.test(self.stu_width_mult)
145 | if test_psnr > self.best_psnr:
146 | print(f"saving models @epoch {epoch} with psnr: {test_psnr}")
147 | self.best_psnr = test_psnr
148 | self.best_psnr_epoch = epoch
149 | torch.save({
150 | 'epoch': epoch,
151 | 'model_state_dict': self.model.state_dict(),
152 | 'optimizer_state_dict': self.optimizer.state_dict(),
153 | 'best_psnr': self.best_psnr,
154 | 'best_psnr_epoch': self.best_psnr_epoch,
155 | }, f'{self.checkpoint_dir}{self.model_filepath}')
156 |
157 | endtime = datetime.datetime.now()
158 | cost = (endtime - starttime).seconds
159 | print(f"time of epoch{epoch}: {cost}")
160 |
161 | def test(self, width_mult=1):
162 | self.model.eval()
163 | with torch.no_grad():
164 | psnr = 0
165 | niqe_score = 0
166 | ssim = 0
167 | t0 = time.time()
168 |
169 | starttime = datetime.datetime.now()
170 | for d in self.loader.loader_test:
171 | for lr, hr, filename in d:
172 | lr = lr.to(self.device)
173 | hr = hr.to(self.device)
174 |
175 | x = [lr]
176 | for tf in 'v', 'h', 't':
177 | x.extend([utility.transform(_x, tf, self.device) for _x in x])
178 | op = ['', 'v', 'h', 'hv', 't', 'tv', 'th', 'thv']
179 |
180 | if self.self_ensemble:
181 | res = self.model(lr, width_mult)
182 | for i in range(1, len(x)):
183 | _x = x[i]
184 | _sr = self.model(_x, width_mult)
185 | for _op in op[i]:
186 | _sr = utility.transform(_sr, _op, self.device)
187 | res = torch.cat((res, _sr), 0)
188 | sr = torch.mean(res, 0).unsqueeze(0)
189 | else:
190 | sr = self.model(lr, width_mult)
191 |
192 | sr = utility.quantize(sr, self.rgb_range)
193 | if self.save_results:
194 | # if not os.path.exists(f'./output/test/{self.model_str}/{self.model_filename}'):
195 | # self.makedirs = os.makedirs(f'./output/test/{self.model_str}/{self.model_filename}')
196 | utility.save_results(str(filename), sr, self.scale, width_mult,
197 | self.rgb_range, 'SR')
198 |
199 | psnr += utility.calc_psnr(sr, hr, self.scale, self.rgb_range, dataset=d)
200 | niqe_score += niqe(sr.squeeze(0).permute(1, 2, 0).cpu().numpy())
201 | ssim += calc_ssim(sr, hr, self.scale, dataset=d)
202 |
203 | psnr /= len(d)
204 | niqe_score /= len(d)
205 | ssim /= len(d)
206 | print(width_mult, d.dataset.name, psnr, niqe_score, ssim)
207 |
208 | endtime = datetime.datetime.now()
209 | cost = (endtime - starttime).seconds
210 | t1 = time.time()
211 | total_time = (t1 - t0)
212 | print(f"time of test: {total_time}")
213 | return psnr
214 |
215 | def build_model(self, args):
216 | m = import_module('model.' + self.model_str)
217 | self.model = getattr(m, self.model_str.upper())(args).to(self.device)
218 | self.model = nn.DataParallel(self.model, device_ids=range(args.n_GPUs))
219 | self.load_model()
220 |
221 | # test teacher
222 | # self.test()
223 |
224 | def load_model(self):
225 | checkpoint_dir = self.checkpoint_dir
226 | print(f"[*] Load model from {checkpoint_dir}")
227 | if not os.path.exists(checkpoint_dir):
228 | self.makedirs = os.makedirs(checkpoint_dir)
229 |
230 | if not os.listdir(checkpoint_dir):
231 | print(f"[!] No checkpoint in {checkpoint_dir}")
232 | return
233 |
234 | model = glob(os.path.join(checkpoint_dir, self.model_filepath))
235 |
236 | no_student = False
237 | if not model:
238 | no_student = True
239 | print(f"[!] No checkpoint ")
240 | print("Loading pre-trained teacher model")
241 | model = glob(self.teacher_model)
242 | if not model:
243 | print(f"[!] No teacher model ")
244 | return
245 |
246 | model_state_dict = torch.load(model[0])
247 | if not no_student:
248 | self.start_epoch = model_state_dict['epoch']
249 | self.best_psnr = model_state_dict['best_psnr']
250 | self.best_psnr_epoch = model_state_dict['best_psnr_epoch']
251 |
252 | self.model.load_state_dict(model_state_dict['model_state_dict'], strict=False)
--------------------------------------------------------------------------------
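At its core, `SlimContrastiveTrainer.train()` optimizes a weighted sum of a student reconstruction loss, a teacher reconstruction loss, and the contrastive distillation term (plus optional adversarial and perceptual terms). A stripped-down sketch of one step, assuming a slimmable model whose second argument selects the student width; the weights here are illustrative, not the repository's defaults:

```python
import torch.nn as nn

l1 = nn.L1Loss()

def csd_step(model, contra_loss, lr, hr, neg_samples,
             stu_width_mult=0.25, contra_lambda=1.0, t_lambda=1.0):
    # full-width forward pass acts as the teacher
    teacher_sr = model(lr)
    # channel-split forward pass acts as the student
    student_sr = model(lr, stu_width_mult)

    # reconstruction terms for both branches
    loss = l1(student_sr, hr) + t_lambda * l1(teacher_sr, hr)
    # contrastive term: pull the student toward the teacher,
    # push it away from the bicubic negatives
    loss = loss + contra_lambda * contra_loss(teacher_sr, student_sr, neg_samples)
    return loss
```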
/PyTorch version/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/PyTorch version/utils/__init__.py
--------------------------------------------------------------------------------
/PyTorch version/utils/niqe.py:
--------------------------------------------------------------------------------
1 | import math
2 | from os.path import dirname, join
3 |
4 | import cv2
5 | import numpy as np
6 | import scipy
7 | import scipy.io
8 | import scipy.linalg  # needed for scipy.linalg.pinv in niqe()
9 | import scipy.ndimage
10 | import scipy.special
11 | from PIL import Image
12 |
13 | gamma_range = np.arange(0.2, 10, 0.001)
14 | a = scipy.special.gamma(2.0/gamma_range)
15 | a *= a
16 | b = scipy.special.gamma(1.0/gamma_range)
17 | c = scipy.special.gamma(3.0/gamma_range)
18 | prec_gammas = a/(b*c)
19 |
20 |
21 | def aggd_features(imdata):
22 | # flatten imdata
23 | imdata.shape = (len(imdata.flat),)
24 | imdata2 = imdata*imdata
25 | left_data = imdata2[imdata < 0]
26 | right_data = imdata2[imdata >= 0]
27 | left_mean_sqrt = 0
28 | right_mean_sqrt = 0
29 | if len(left_data) > 0:
30 | left_mean_sqrt = np.sqrt(np.average(left_data))
31 | if len(right_data) > 0:
32 | right_mean_sqrt = np.sqrt(np.average(right_data))
33 |
34 | if right_mean_sqrt != 0:
35 | gamma_hat = left_mean_sqrt/right_mean_sqrt
36 | else:
37 | gamma_hat = np.inf
38 | # solve r-hat norm
39 |
40 | imdata2_mean = np.mean(imdata2)
41 | if imdata2_mean != 0:
42 | r_hat = (np.average(np.abs(imdata))**2) / (np.average(imdata2))
43 | else:
44 | r_hat = np.inf
45 | rhat_norm = r_hat * (((math.pow(gamma_hat, 3) + 1) *
46 | (gamma_hat + 1)) / math.pow(math.pow(gamma_hat, 2) + 1, 2))
47 |
48 | # solve alpha by guessing values that minimize ro
49 | pos = np.argmin((prec_gammas - rhat_norm)**2)
50 | alpha = gamma_range[pos]
51 |
52 | gam1 = scipy.special.gamma(1.0/alpha)
53 | gam2 = scipy.special.gamma(2.0/alpha)
54 | gam3 = scipy.special.gamma(3.0/alpha)
55 |
56 | aggdratio = np.sqrt(gam1) / np.sqrt(gam3)
57 | bl = aggdratio * left_mean_sqrt
58 | br = aggdratio * right_mean_sqrt
59 |
60 | # mean parameter
61 | N = (br - bl)*(gam2 / gam1) # *aggdratio
62 | return (alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt)
63 |
64 |
65 | def ggd_features(imdata):
66 | nr_gam = 1/prec_gammas
67 | sigma_sq = np.var(imdata)
68 | E = np.mean(np.abs(imdata))
69 | rho = sigma_sq/E**2
70 | pos = np.argmin(np.abs(nr_gam - rho))
71 | return gamma_range[pos], sigma_sq
72 |
73 |
74 | def paired_product(new_im):
75 | shift1 = np.roll(new_im.copy(), 1, axis=1)
76 | shift2 = np.roll(new_im.copy(), 1, axis=0)
77 | shift3 = np.roll(np.roll(new_im.copy(), 1, axis=0), 1, axis=1)
78 | shift4 = np.roll(np.roll(new_im.copy(), 1, axis=0), -1, axis=1)
79 |
80 | H_img = shift1 * new_im
81 | V_img = shift2 * new_im
82 | D1_img = shift3 * new_im
83 | D2_img = shift4 * new_im
84 |
85 | return (H_img, V_img, D1_img, D2_img)
86 |
87 |
88 | def gen_gauss_window(lw, sigma):
89 | sd = np.float32(sigma)
90 | lw = int(lw)
91 | weights = [0.0] * (2 * lw + 1)
92 | weights[lw] = 1.0
93 | sum = 1.0
94 | sd *= sd
95 | for ii in range(1, lw + 1):
96 | tmp = np.exp(-0.5 * np.float32(ii * ii) / sd)
97 | weights[lw + ii] = tmp
98 | weights[lw - ii] = tmp
99 | sum += 2.0 * tmp
100 | for ii in range(2 * lw + 1):
101 | weights[ii] /= sum
102 | return weights
103 |
104 |
105 | def compute_image_mscn_transform(image, C=1, avg_window=None, extend_mode='constant'):
106 | if avg_window is None:
107 | avg_window = gen_gauss_window(3, 7.0/6.0)
108 | assert len(np.shape(image)) == 2
109 | h, w = np.shape(image)
110 | mu_image = np.zeros((h, w), dtype=np.float32)
111 | var_image = np.zeros((h, w), dtype=np.float32)
112 | image = np.array(image).astype('float32')
113 | scipy.ndimage.correlate1d(image, avg_window, 0, mu_image, mode=extend_mode)
114 | scipy.ndimage.correlate1d(mu_image, avg_window, 1,
115 | mu_image, mode=extend_mode)
116 | scipy.ndimage.correlate1d(image**2, avg_window, 0,
117 | var_image, mode=extend_mode)
118 | scipy.ndimage.correlate1d(var_image, avg_window,
119 | 1, var_image, mode=extend_mode)
120 | var_image = np.sqrt(np.abs(var_image - mu_image**2))
121 | return (image - mu_image)/(var_image + C), var_image, mu_image
122 |
123 |
124 | def _niqe_extract_subband_feats(mscncoefs):
125 | # alpha_m, = extract_ggd_features(mscncoefs)
126 | alpha_m, N, bl, br, lsq, rsq = aggd_features(mscncoefs.copy())
127 | pps1, pps2, pps3, pps4 = paired_product(mscncoefs)
128 | alpha1, N1, bl1, br1, lsq1, rsq1 = aggd_features(pps1)
129 | alpha2, N2, bl2, br2, lsq2, rsq2 = aggd_features(pps2)
130 | alpha3, N3, bl3, br3, lsq3, rsq3 = aggd_features(pps3)
131 | alpha4, N4, bl4, br4, lsq4, rsq4 = aggd_features(pps4)
132 | return np.array([alpha_m, (bl+br)/2.0,
133 | alpha1, N1, bl1, br1, # (V)
134 | alpha2, N2, bl2, br2, # (H)
135 | alpha3, N3, bl3, br3, # (D1)
136 | alpha4, N4, bl4, br4, # (D2)
137 | ])
138 |
139 |
140 | def get_patches_train_features(img, patch_size, stride=8):
141 | return _get_patches_generic(img, patch_size, 1, stride)
142 |
143 |
144 | def get_patches_test_features(img, patch_size, stride=8):
145 | return _get_patches_generic(img, patch_size, 0, stride)
146 |
147 |
148 | def extract_on_patches(img, patch_size):
149 | h, w = img.shape
150 | patch_size = int(patch_size)
151 | patches = []
152 | for j in range(0, h-patch_size+1, patch_size):
153 | for i in range(0, w-patch_size+1, patch_size):
154 | patch = img[j:j+patch_size, i:i+patch_size]
155 | patches.append(patch)
156 |
157 | patches = np.array(patches)
158 |
159 | patch_features = []
160 | for p in patches:
161 | patch_features.append(_niqe_extract_subband_feats(p))
162 | patch_features = np.array(patch_features)
163 |
164 | return patch_features
165 |
166 |
167 | def _get_patches_generic(img, patch_size, is_train, stride):
168 | h, w = np.shape(img)
169 | if h < patch_size or w < patch_size:
170 | print("Input image is too small")
171 | exit(0)
172 |
173 | # ensure that the patch divides evenly into img
174 | hoffset = (h % patch_size)
175 | woffset = (w % patch_size)
176 |
177 | if hoffset > 0:
178 | img = img[:-hoffset, :]
179 | if woffset > 0:
180 | img = img[:, :-woffset]
181 |
182 | img = img.astype(np.float32)
183 | # img2 = scipy.misc.imresize(img, 0.5, interp='bicubic', mode='F')
184 | img2 = cv2.resize(img, (0, 0), fx=0.5, fy=0.5)
185 |
186 | mscn1, var, mu = compute_image_mscn_transform(img)
187 | mscn1 = mscn1.astype(np.float32)
188 |
189 | mscn2, _, _ = compute_image_mscn_transform(img2)
190 | mscn2 = mscn2.astype(np.float32)
191 |
192 | feats_lvl1 = extract_on_patches(mscn1, patch_size)
193 | feats_lvl2 = extract_on_patches(mscn2, patch_size/2)
194 |
195 | feats = np.hstack((feats_lvl1, feats_lvl2)) # feats_lvl3))
196 |
197 | return feats
198 |
199 |
200 | def niqe(inputImgData):
201 |
202 | patch_size = 96
203 | module_path = dirname(__file__)
204 |
205 | # TODO: memoize
206 | params = scipy.io.loadmat(
207 | join(module_path, 'niqe_image_params.mat'))
208 | pop_mu = np.ravel(params["pop_mu"])
209 | pop_cov = params["pop_cov"]
210 |
211 | if inputImgData.ndim == 3:
212 | inputImgData = cv2.cvtColor(inputImgData, cv2.COLOR_BGR2GRAY)
213 | M, N = inputImgData.shape
214 |
215 | # assert C == 1, "niqe called with videos containing %d channels. Please supply only the luminance channel" % (C,)
216 | assert M > (patch_size*2+1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters"
217 | assert N > (patch_size*2+1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters"
218 |
219 | feats = get_patches_test_features(inputImgData, patch_size)
220 | sample_mu = np.mean(feats, axis=0)
221 | sample_cov = np.cov(feats.T)
222 |
223 | X = sample_mu - pop_mu
224 | covmat = ((pop_cov+sample_cov)/2.0)
225 | pinvmat = scipy.linalg.pinv(covmat)
226 | niqe_score = np.sqrt(np.dot(np.dot(X, pinvmat), X))
227 |
228 | return niqe_score
229 |
--------------------------------------------------------------------------------
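NIQE is a no-reference metric (no ground truth needed), so the trainer can score raw network outputs directly. A typical standalone call, assuming an image of at least 193×193 pixels (the file path below is hypothetical):

```python
import cv2
from utils.niqe import niqe

img = cv2.imread('output/sample_sr.png')  # hypothetical path; BGR uint8
score = niqe(img)                         # 3-channel input is converted to gray internally
print(f'NIQE: {score:.3f}')               # lower is better
```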
/PyTorch version/utils/niqe_image_params.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/PyTorch version/utils/niqe_image_params.mat
--------------------------------------------------------------------------------
/PyTorch version/utils/spatial_trans.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F  # used by FSP below
3 |
4 | def spatial_similarity(fm):
5 | fm = fm.view(fm.size(0), fm.size(1), -1)
6 | norm_fm = fm / (torch.sqrt(torch.sum(torch.pow(fm,2), 1)).unsqueeze(1).expand(fm.shape) + 0.0000001)
7 | s = norm_fm.transpose(1,2).bmm(norm_fm)
8 | s = s.unsqueeze(1)
9 | return s
10 |
11 | def channel_similarity(fm):
12 | fm = fm.view(fm.size(0), fm.size(1), -1)
13 | norm_fm = fm / (torch.sqrt(torch.sum(torch.pow(fm,2), 2)).unsqueeze(2).expand(fm.shape) + 0.0000001)
14 | s = norm_fm.bmm(norm_fm.transpose(1,2))
15 | s = s.unsqueeze(1)
16 | return s
17 |
18 | def batch_similarity(fm):
19 | fm = fm.view(fm.size(0), -1)
20 | Q = torch.mm(fm, fm.transpose(0,1))
21 | normalized_Q = Q / torch.norm(Q,2,dim=1).unsqueeze(1).expand(Q.shape)
22 | return normalized_Q
23 |
24 | def FSP(fm1, fm2):
25 | if fm1.size(2) > fm2.size(2):
26 | fm1 = F.adaptive_avg_pool2d(fm1, (fm2.size(2), fm2.size(3)))
27 |
28 | fm1 = fm1.view(fm1.size(0), fm1.size(1), -1)
29 | fm2 = fm2.view(fm2.size(0), fm2.size(1), -1).transpose(1,2)
30 |
31 | fsp = torch.bmm(fm1, fm2) / fm1.size(2)
32 |
33 | return fsp
34 |
35 | def AT(fm):
36 | eps=1e-6
37 | am = torch.pow(torch.abs(fm), 2)
38 | am = torch.sum(am, dim=1, keepdim=True)
39 | norm = torch.norm(am, dim=(2,3), keepdim=True)
40 | am = torch.div(am, norm+eps)
41 | return am
--------------------------------------------------------------------------------
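Each helper maps a (B, C, H, W) feature map to a normalized similarity matrix used for feature-level distillation: `spatial_similarity` compares spatial positions, `channel_similarity` compares channels, and `batch_similarity` compares samples. A quick shape check with random inputs (purely illustrative):

```python
import torch
from utils.spatial_trans import spatial_similarity, channel_similarity, batch_similarity

fm = torch.randn(2, 16, 8, 8)          # (B, C, H, W)
print(spatial_similarity(fm).shape)    # torch.Size([2, 1, 64, 64])
print(channel_similarity(fm).shape)    # torch.Size([2, 1, 16, 16])
print(batch_similarity(fm).shape)      # torch.Size([2, 2])
```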
/PyTorch version/utils/ssim.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import signal
3 |
4 | def trans2Y(img):
5 | img_r = img[:, 0, :, :]
6 | img_g = img[:, 1, :, :]
7 | img_b = img[:, 2, :, :]
8 | img_y = 0.256789 * img_r + 0.504129 * img_g + 0.097906 * img_b + 16
9 | return img_y
10 |
11 | def matlab_style_gauss2D(shape=(3,3),sigma=0.5):
12 | """
13 | 2D gaussian mask - should give the same result as MATLAB's fspecial('gaussian',[shape],[sigma])
14 | Acknowledgement : https://stackoverflow.com/questions/17190649/how-to-obtain-a-gaussian-filter-in-python (Author@ali_m)
15 | """
16 | m,n = [(ss-1.)/2. for ss in shape]
17 | y,x = np.ogrid[-m:m+1,-n:n+1]
18 | h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )
19 | h[ h < np.finfo(h.dtype).eps*h.max() ] = 0
20 | sumh = h.sum()
21 | if sumh != 0:
22 | h /= sumh
23 | return h
24 |
25 | def calc_ssim(X, Y, scale, dataset=None, sigma=1.5, K1=0.01, K2=0.03, R=255):
26 | '''
27 | X : y channel (i.e., luminance) of transformed YCbCr space of X
28 | Y : y channel (i.e., luminance) of transformed YCbCr space of Y
29 | Please follow the setting of psnr_ssim.m in EDSR (Enhanced Deep Residual Networks for Single Image Super-Resolution CVPRW2017).
30 | Official Link : https://github.com/LimBee/NTIRE2017/tree/db34606c2844e89317aac8728a2de562ef1f8aba
31 | The authors of EDSR use MATLAB's ssim as the evaluation tool,
32 | thus this function is the same as ssim.m in MATLAB with C(3) == C(2)/2.
33 | '''
34 | gaussian_filter = matlab_style_gauss2D((11, 11), sigma)
35 |
36 | X = trans2Y(X).squeeze()
37 | Y = trans2Y(Y).squeeze()
38 | X = X.cpu().numpy().astype(np.float64)
39 | Y = Y.cpu().numpy().astype(np.float64)
40 |
41 | shave = scale
42 | if dataset and not dataset.dataset.benchmark:
43 | shave = scale + 6
44 | X = X[shave:-shave, shave:-shave]
45 | Y = Y[shave:-shave, shave:-shave]
46 |
47 | window = gaussian_filter / np.sum(np.sum(gaussian_filter))
48 |
49 | window = np.fliplr(window)
50 | window = np.flipud(window)
51 |
52 | ux = signal.convolve2d(X, window, mode='valid', boundary='fill', fillvalue=0)
53 | uy = signal.convolve2d(Y, window, mode='valid', boundary='fill', fillvalue=0)
54 |
55 | uxx = signal.convolve2d(X * X, window, mode='valid', boundary='fill', fillvalue=0)
56 | uyy = signal.convolve2d(Y * Y, window, mode='valid', boundary='fill', fillvalue=0)
57 | uxy = signal.convolve2d(X * Y, window, mode='valid', boundary='fill', fillvalue=0)
58 |
59 | vx = uxx - ux * ux
60 | vy = uyy - uy * uy
61 | vxy = uxy - ux * uy
62 |
63 | C1 = (K1 * R) ** 2
64 | C2 = (K2 * R) ** 2
65 |
66 | A1, A2, B1, B2 = ((2 * ux * uy + C1, 2 * vxy + C2, ux ** 2 + uy ** 2 + C1, vx + vy + C2))
67 | D = B1 * B2
68 | S = (A1 * A2) / D
69 | mssim = S.mean()
70 |
71 | # window = gaussian_filter
72 | #
73 | # ux = signal.convolve2d(X, window, mode='same', boundary='symm')
74 | # uy = signal.convolve2d(Y, window, mode='same', boundary='symm')
75 | #
76 | # uxx = signal.convolve2d(X*X, window, mode='same', boundary='symm')
77 | # uyy = signal.convolve2d(Y*Y, window, mode='same', boundary='symm')
78 | # uxy = signal.convolve2d(X*Y, window, mode='same', boundary='symm')
79 | #
80 | # vx = uxx - ux * ux
81 | # vy = uyy - uy * uy
82 | # vxy = uxy - ux * uy
83 | #
84 | # C1 = (K1 * R) ** 2
85 | # C2 = (K2 * R) ** 2
86 | #
87 | # A1, A2, B1, B2 = ((2 * ux * uy + C1, 2 * vxy + C2, ux ** 2 + uy ** 2 + C1, vx + vy + C2))
88 | # D = B1 * B2
89 | # S = (A1 * A2) / D
90 | # mssim = S.mean()
91 |
92 | return mssim
--------------------------------------------------------------------------------
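`calc_ssim` expects 4-D RGB tensors in the [0, 255] range: it extracts the Y (luma) channel, shaves `scale` border pixels (`scale + 6` on non-benchmark datasets), and computes MATLAB-style SSIM with an 11×11 Gaussian window. A minimal call with random stand-in tensors:

```python
import torch
from utils.ssim import calc_ssim

sr = torch.rand(1, 3, 96, 96) * 255   # stand-in SR output, RGB in [0, 255]
hr = torch.rand(1, 3, 96, 96) * 255   # stand-in ground truth
print(calc_ssim(sr, hr, scale=4))     # SSIM on the shaved Y channel
```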
/README.md:
--------------------------------------------------------------------------------
1 | # CSD
2 | This is the official implementation (including [PyTorch version](https://github.com/Booooooooooo/CSD/tree/main/PyTorch%20version) and [MindSpore version](https://github.com/Booooooooooo/CSD/tree/main/MindSpore%20version)) of [Towards Compact Single Image Super-Resolution via Contrastive Self-distillation, IJCAI2021](https://arxiv.org/abs/2105.11683)
3 |
4 | ## Abstract
5 |
6 | Convolutional neural networks (CNNs) are highly successful for super-resolution (SR) but often require sophisticated architectures with heavy memory cost and computational overhead, which significantly restricts their practical deployment on resource-limited devices. In this paper, we propose a novel contrastive self-distillation (CSD) framework to simultaneously compress and accelerate various off-the-shelf SR models. In particular, a channel-splitting super-resolution network is first constructed from a target teacher network as a compact student network. Then, we propose a novel contrastive loss to improve the quality of SR images and PSNR/SSIM via explicit knowledge transfer. Extensive experiments demonstrate that the proposed CSD scheme effectively compresses and accelerates several standard SR models such as EDSR, RCAN and CARN.
7 |
8 | 
9 |
10 | ## Results
11 |
12 | 
13 |
14 | 
15 |
16 | 
17 |
18 | ## Citation
19 |
20 | If you find the code helpful in your research or work, please cite as:
21 |
22 | ```
23 | @misc{wang2021compact,
24 | title={Towards Compact Single Image Super-Resolution via Contrastive Self-distillation},
25 | author={Yanbo Wang and Shaohui Lin and Yanyun Qu and Haiyan Wu and Zhizhong Zhang and Yuan Xie and Angela Yao},
26 | year={2021},
27 | eprint={2105.11683},
28 | archivePrefix={arXiv},
29 | primaryClass={cs.CV}
30 | }
31 | ```
32 |
33 | ## Acknowledgements
34 |
35 | This code is built on [EDSR(PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch). For the training part of the MindSpore version we referred to [DBPN-MindSpore](https://gitee.com/amythist/DBPN-MindSpore/tree/master), [ModelZoo-RCAN](https://gitee.com/mindspore/models/tree/master/research/cv/RCAN) and the official [tutorial](https://www.mindspore.cn/tutorials/zh-CN/master/index.html). We thank the authors for sharing their codes.
36 |
37 |
--------------------------------------------------------------------------------
/images/debug.log:
--------------------------------------------------------------------------------
1 | [0526/113715:WARNING:dns_config_service_win.cc(674)] Failed to read DnsConfig.
--------------------------------------------------------------------------------
/images/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/images/model.png
--------------------------------------------------------------------------------
/images/psnr-speed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/images/psnr-speed.png
--------------------------------------------------------------------------------
/images/psnr-tradeoff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/images/psnr-tradeoff.png
--------------------------------------------------------------------------------
/images/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/images/table.png
--------------------------------------------------------------------------------
/images/tradeoff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/images/tradeoff.png
--------------------------------------------------------------------------------
/images/visual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Booooooooooo/CSD/f479c8fd96e7da24c123642043f0d42f4b827af9/images/visual.png
--------------------------------------------------------------------------------