├── README.md
├── __pycache__
│   ├── checkpoint.cpython-36.pyc
│   ├── checkpoint.cpython-39.pyc
│   ├── loss.cpython-36.pyc
│   ├── loss.cpython-39.pyc
│   ├── option.cpython-36.pyc
│   ├── option.cpython-39.pyc
│   ├── trainer.cpython-36.pyc
│   ├── trainer.cpython-39.pyc
│   ├── utility.cpython-35.pyc
│   ├── utility.cpython-36.pyc
│   └── utility.cpython-39.pyc
├── checkpoint.py
├── data
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── benchmark.cpython-36.pyc
│   │   ├── benchmark.cpython-39.pyc
│   │   ├── common.cpython-36.pyc
│   │   ├── common.cpython-39.pyc
│   │   ├── df2k.cpython-36.pyc
│   │   ├── div2k.cpython-36.pyc
│   │   ├── div2k.cpython-39.pyc
│   │   ├── srdata.cpython-36.pyc
│   │   └── srdata.cpython-39.pyc
│   ├── benchmark.py
│   ├── common.py
│   ├── div2k.py
│   └── srdata.py
├── experiments
│   └── CFIN
│       └── model
│           ├── model_best_x2.pt
│           ├── model_best_x3.pt
│           └── model_best_x4.pt
├── img
│   └── compare.png
├── loss.py
├── main.py
├── model
│   ├── CFIN.py
│   ├── CFINx2.py
│   ├── CFINx3.py
│   ├── CFINx4.py
│   ├── MultiAdd.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── CFIN.cpython-39.pyc
│   │   ├── CFINx2.cpython-39.pyc
│   │   ├── CFINx3.cpython-39.pyc
│   │   ├── CFINx4.cpython-39.pyc
│   │   ├── MSDNN_LW1.cpython-36.pyc
│   │   ├── MsDNN.cpython-36.pyc
│   │   ├── MultiAdd.cpython-36.pyc
│   │   ├── MultiAdd.cpython-39.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── common.cpython-36.pyc
│   │   ├── common.cpython-39.pyc
│   │   ├── msfin3.cpython-36.pyc
│   │   ├── no_transmy.cpython-39.pyc
│   │   ├── rebuttal_updown1.cpython-39.pyc
│   │   ├── transmy.cpython-36.pyc
│   │   ├── transmy3.cpython-36.pyc
│   │   ├── transmy5.cpython-36.pyc
│   │   ├── transmy5_0.cpython-36.pyc
│   │   ├── transmy_j.cpython-36.pyc
│   │   ├── transmy_j.cpython-39.pyc
│   │   └── transmy_jx4.cpython-39.pyc
│   └── common.py
├── option.py
├── test_summary.py
├── trainer.py
└── utility.py
/README.md:
--------------------------------------------------------------------------------
1 | ## CFIN-Pytorch
2 | This repository is an official PyTorch implementation of our paper "Cross-receptive Focused Inference Network for Lightweight Image Super-Resolution".
3 |
4 | ---
5 |
6 | ### CFIN: Cross-receptive Focused Inference Network for Lightweight Image Super-Resolution (IEEE Transactions on Multimedia, 2023)
7 |
8 | > [[Paper(arxiv)](https://arxiv.org/abs/2207.02796)] [[Paper(IEEE)](https://ieeexplore.ieee.org/document/10114600)] [[Code](https://github.com/24wenjie-li/CFIN)]
9 |
10 |
11 |
12 |
13 |
14 | ---
15 |
16 | ## Prerequisites:
17 | ```
18 | 1. Python >= 3.6
19 | 2. PyTorch >= 1.2
20 | 3. numpy
21 | 4. skimage
22 | 5. imageio
23 | 6. tqdm
24 | 7. timm
25 | 8. einops
26 | ```
27 |
28 | ## Dataset
29 | We used only the DIV2K dataset to train our model. To speed up data reading during training, we converted the images in the dataset from png to npy format.
30 | Please download the DIV2K_decoded dataset in npy format from here. [Quark Netdisk]
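
If you prefer to convert the PNG images to npy yourself instead of downloading DIV2K_decoded, a minimal sketch is shown below. It is not part of this repository, and the paths are assumptions that follow the directory layout described later in this README (repeat the call for each DIV2K_train_LR_bicubic/X{scale} sub-directory).
```python
# Hedged sketch: convert a directory of PNG images to .npy files for faster loading.
# The source/target paths below are illustrative, not shipped with this code.
import glob
import os
import imageio
import numpy as np

def png_dir_to_npy(src_dir, dst_dir):
    os.makedirs(dst_dir, exist_ok=True)
    for png_path in sorted(glob.glob(os.path.join(src_dir, '*.png'))):
        img = imageio.imread(png_path)  # H x W x C uint8 array
        name = os.path.splitext(os.path.basename(png_path))[0]
        np.save(os.path.join(dst_dir, name + '.npy'), img)

# Example (hypothetical paths):
# png_dir_to_npy('dataset/DIV2K/DIV2K_train_HR', 'dataset/DIV2K_decoded/DIV2K_train_HR')
```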
31 |
32 | The test set contains five benchmark datasets: Set5, Set14, B100, Urban100, and Manga109. The benchmark can be downloaded from here. [Baidu Netdisk][Password:8888]
33 |
34 | Extract the files and place them at the location specified by args.data_dir in option.py.
35 |
36 | The code and datasets need to be organized in the following structure:
37 | ```
38 | ├── CFIN # Train / Test Code
39 | ├── dataset # all datasets for this code
40 | | └── DIV2K_decoded # train datasets with npy format
41 | | | └── DIV2K_train_HR
42 | | | └── DIV2K_train_LR_bicubic
43 | | └── benchmark # test datasets with png format
44 | | | └── Set5
45 | | | └── Set14
46 | | | └── B100
47 | | | └── Urban100
48 | | | └── Manga109
49 | ─────────────────
50 | ```
51 |
52 |
53 | ## Results
54 | All of our SR results can be downloaded from here. [Baidu Netdisk][Password:8888]
55 |
56 | All pretrained models can be found in experiments/CFIN/model/.
57 |
58 | ## Training
59 | Note: You need to manually import the name of the model to be trained/tested at line 37 of model/__init__.py (a hypothetical example is sketched after the commands below).
60 | ```
61 | # CFIN x2
62 | python main.py --scale 2 --model CFINx2 --patch_size 96 --save experiments/CFINx2
63 |
64 | # CFIN x3
65 | python main.py --scale 3 --model CFINx3 --patch_size 144 --save experiments/CFINx3
66 |
67 | # CFIN x4
68 | python main.py --scale 4 --model CFINx4 --patch_size 192 --save experiments/CFINx4
69 | ```
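
For reference, the manual import mentioned in the note above usually amounts to a single line; a hypothetical example (the actual line 37 of model/__init__.py may look different) is:
```python
# Hypothetical content of the manual import in model/__init__.py;
# switch it to CFINx3 / CFINx4 (or CFIN) to match the --model argument you pass.
from model import CFINx2
```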
70 | **Something you need to know:**
71 | Since the training/testing code we provide is the original code, the module names in the code are not consistent with those in the paper. For this reason, we provide model/CFIN.py to help readers better relate the code to the paper.
72 |
73 | ## Testing
74 | Note: Since the PSNR/SSIM values in our paper are obtained with a MATLAB program, the values obtained with the Python code may differ by up to about 0.01 dB in PSNR. The following PSNR/SSIM values are evaluated with MATLAB R2017a, and the evaluation code can be referred to here. (You need to modify the test path!) A Python approximation of the Y-channel PSNR is sketched after the Performance table below.
75 | ```
76 | # CFIN x2
77 | python main.py --scale 2 --model CFINx2 --save test_results/CFINx2 --pre_train experiments/CFIN/model/model_best_x2.pt --test_only --save_results --data_test Set5
78 |
79 | # CFIN x3
80 | python main.py --scale 3 --model CFINx3 --save test_results/CFINx3 --pre_train experiments/CFIN/model/model_best_x3.pt --test_only --save_results --data_test Set5
81 |
82 | # CFIN x4
83 | python main.py --scale 4 --model CFINx4 --save test_results/CFINx4 --pre_train experiments/CFIN/model/model_best_x4.pt --test_only --save_results --data_test Set5
84 |
85 | # CFIN+ x2 with self-ensemble strategy
86 | python main.py --scale 2 --model CFINx2 --save test_results/CFINx2 --pre_train experiments/CFIN/model/model_best_x2.pt --test_only --save_results --chop --self_ensemble --data_test Set5
87 |
88 | # CFIN+ x3 with self-ensemble strategy
89 | python main.py --scale 3 --model CFINx3 --save test_results/CFINx3 --pre_train experiments/CFIN/model/model_best_x3.pt --test_only --save_results --chop --self_ensemble --data_test Set5
90 |
91 | # CFIN+ x4 with self-ensemble strategy
92 | python main.py --scale 4 --model CFINx4 --save test_results/CFINx4 --pre_train experiments/CFIN/model/model_best_x4.pt --test_only --save_results --chop --self_ensemble --data_test Set5
93 | ```
94 |
95 | ## Test Params and Multi-Adds
96 | Note: You need to install torchsummaryX!
97 | ```
98 | # Default CFINx4
99 | python test_summary.py
100 | ```
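
For orientation, the snippet below is a minimal, self-contained sketch of how torchsummaryX reports Params and Mult-Adds. It uses a toy convolutional model rather than CFIN itself so that it runs without this repository's option parsing; test_summary.py builds CFINx4 instead.
```python
# Hedged sketch: counting parameters and multiply-adds with torchsummaryX.
import torch
import torch.nn as nn
from torchsummaryX import summary

# Toy x4 upscaling network (illustrative only, not CFIN).
toy = nn.Sequential(
    nn.Conv2d(3, 48, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(48, 3 * 4 * 4, 3, padding=1),
    nn.PixelShuffle(4),
)

# Summary over one 3 x 64 x 64 LR input; prints per-layer Params and Mult-Adds.
summary(toy, torch.zeros(1, 3, 64, 64))
```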
101 |
102 | ## Performance
103 | Our CFIN is trained on RGB images, but, following previous work, we report PSNR/SSIM only on the Y channel.
104 |
105 | Model|Scale|Params|Multi-adds|Set5|Set14|B100|Urban100|Manga109
106 | --|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:
107 | CFIN |x2|675K|116.9G|38.14/0.9610|33.80/0.9199|32.26/0.9006|32.48/0.9311|38.97/0.9777
108 | CFIN |x3|681K|53.5G|34.65/0.9289|30.45/0.8443|29.18/0.8071|28.49/0.8583|33.89/0.9464
109 | CFIN |x4|699K|31.2G|32.49/0.8985|28.74/0.7849|27.68/0.7396|26.39/0.7946|30.73/0.9124
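
To reproduce these numbers approximately in Python, the sketch below computes PSNR on the Y channel the way most SR papers do (BT.601 luma conversion, border cropping by the scale factor). The MATLAB script referenced in the Testing section remains the authoritative evaluation and may differ in rounding details.
```python
# Hedged sketch of Y-channel PSNR; an approximation of, not a substitute for,
# the MATLAB evaluation used for the table above.
import numpy as np

def rgb_to_y(img):
    # BT.601 luma, matching MATLAB's rgb2ycbcr convention for uint8 inputs.
    img = img.astype(np.float64)
    return 16.0 + (65.481 * img[..., 0] + 128.553 * img[..., 1] + 24.966 * img[..., 2]) / 255.0

def psnr_y(sr, hr, scale):
    y_sr, y_hr = rgb_to_y(sr), rgb_to_y(hr)
    # Crop `scale` border pixels, as is common practice in SR evaluation.
    y_sr = y_sr[scale:-scale, scale:-scale]
    y_hr = y_hr[scale:-scale, scale:-scale]
    mse = np.mean((y_sr - y_hr) ** 2)
    return 10.0 * np.log10(255.0 ** 2 / mse)
```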
110 |
111 | ## Some extra questions
112 | You can download the supplementary materials addressing the issues raised by reviewers from here. [Baidu Netdisk][Password:8888]
113 |
114 | ## Acknowledgements
115 | This code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) and [DRN](https://github.com/guoyongcs/DRN). We thank the authors for sharing their codes.
116 |
117 |
118 | ## :clipboard: Citation
119 |
120 | ```
121 | @article{li2023cross,
122 | title={Cross-receptive focused inference network for lightweight image super-resolution},
123 | author={Li, Wenjie and Li, Juncheng and Gao, Guangwei and Deng, Weihong and Zhou, Jiantao and Yang, Jian and Qi, Guo-Jun},
124 | journal={IEEE Transactions on Multimedia},
125 | year={2023},
126 | publisher={IEEE}
127 | }
128 | ```
129 |
130 | ## :e-mail: Contact
131 |
132 | If you have any questions, please email `lewj2408@gmail.com`.
133 |
--------------------------------------------------------------------------------
/__pycache__/checkpoint.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/checkpoint.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/checkpoint.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/checkpoint.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/loss.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/loss.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/loss.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/option.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/option.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/option.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/option.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/trainer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/trainer.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/trainer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/trainer.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/utility.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/utility.cpython-35.pyc
--------------------------------------------------------------------------------
/__pycache__/utility.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/utility.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/utility.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/utility.cpython-39.pyc
--------------------------------------------------------------------------------
/checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import datetime
4 | import numpy as np
5 | import imageio
6 | import matplotlib
7 | matplotlib.use('Agg')
8 | import matplotlib.pyplot as plt
9 |
10 |
11 | class Checkpoint():
12 | def __init__(self, opt):
13 | self.opt = opt
14 | self.ok = True
15 | self.log = torch.Tensor()
16 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
17 |
18 | if opt.save == '.': opt.save = '../experiment/EXP/' + now
19 | self.dir = opt.save
20 |
21 | def _make_dir(path):
22 | if not os.path.exists(path): os.makedirs(path)
23 |
24 | _make_dir(self.dir)
25 | _make_dir(self.dir + '/model')
26 | _make_dir(self.dir + '/results')
27 |
28 | open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
29 | self.log_file = open(self.dir + '/log.txt', open_type)
30 | with open(self.dir + '/config.txt', open_type) as f:
31 | f.write(now + '\n\n')
32 | for arg in vars(opt):
33 | f.write('{}: {}\n'.format(arg, getattr(opt, arg)))
34 | f.write('\n')
35 |
36 | def save(self, trainer, epoch, is_best=False):
37 | trainer.model.save(self.dir, is_best=is_best)
38 | trainer.loss.save(self.dir)
39 | trainer.loss.plot_loss(self.dir, epoch)
40 |
41 | self.plot_psnr(epoch)
42 | torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
43 | torch.save(
44 | trainer.optimizer.state_dict(),
45 | os.path.join(self.dir, 'optimizer.pt')
46 | )
47 | # dual_optimizers = {}
48 | # for i in range(len(trainer.dual_optimizers)):
49 | # dual_optimizers[i] = trainer.dual_optimizers[i]
50 | # torch.save(
51 | # dual_optimizers,
52 | # os.path.join(self.dir, 'dual_optimizers.pt')
53 | # )
54 |
55 | def add_log(self, log):
56 | self.log = torch.cat([self.log, log])
57 |
58 | def write_log(self, log, refresh=False):
59 | print(log)
60 | self.log_file.write(log + '\n')
61 | if refresh:
62 | self.log_file.close()
63 | self.log_file = open(self.dir + '/log.txt', 'a')
64 |
65 | def done(self):
66 | self.log_file.close()
67 |
68 | def plot_psnr(self, epoch):
69 | axis = np.linspace(1, epoch, epoch)
70 | label = 'SR on {}'.format(self.opt.data_test)
71 | fig = plt.figure()
72 | plt.title(label)
73 | for idx_scale, scale in enumerate([self.opt.scale[0]]):
74 | plt.plot(
75 | axis,
76 | self.log[:, idx_scale].numpy(),
77 | label='Scale {}'.format(scale)
78 | )
79 | plt.legend()
80 | plt.xlabel('Epochs')
81 | plt.ylabel('PSNR')
82 | plt.grid(True)
83 | plt.savefig('{}/test_{}.pdf'.format(self.dir, self.opt.data_test))
84 | plt.close(fig)
85 |
86 | def save_results_nopostfix(self, filename, sr, scale):
87 | apath = '{}/results/{}/x{}'.format(self.dir, self.opt.data_test, scale)
88 | if not os.path.exists(apath):
89 | os.makedirs(apath)
90 | filename = os.path.join(apath, filename)
91 |
92 | normalized = sr[0].data.mul(255 / self.opt.rgb_range)
93 | ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
94 | imageio.imwrite('{}.png'.format(filename), ndarr)
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib import import_module
2 | from torch.utils.data import DataLoader
3 |
4 |
5 | class Data:
6 | def __init__(self, args):
7 | self.loader_train = None
8 | if not args.test_only:
9 | module_train = import_module('data.' + args.data_train.lower())
10 | trainset = getattr(module_train, args.data_train)(args)
11 | self.loader_train = DataLoader(
12 | trainset,
13 | batch_size=args.batch_size,
14 | num_workers=args.n_threads,
15 | shuffle=True,
16 | pin_memory=not args.cpu
17 | )
18 |
19 | if args.data_test in ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109', 'Look']:
20 | module_test = import_module('data.benchmark')
21 | testset = getattr(module_test, 'Benchmark')(args, name=args.data_test, train=False)
22 | else:
23 | module_test = import_module('data.' + args.data_test.lower())
24 | testset = getattr(module_test, args.data_test)(args, train=False)
25 |
26 | self.loader_test = DataLoader(
27 | testset,
28 | batch_size=1,
29 | num_workers=1,
30 | shuffle=False,
31 | pin_memory=not args.cpu
32 | )
33 |
34 |
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/data/__pycache__/benchmark.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/benchmark.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/benchmark.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/benchmark.cpython-39.pyc
--------------------------------------------------------------------------------
/data/__pycache__/common.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/common.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/common.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/common.cpython-39.pyc
--------------------------------------------------------------------------------
/data/__pycache__/df2k.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/df2k.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/div2k.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/div2k.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/div2k.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/div2k.cpython-39.pyc
--------------------------------------------------------------------------------
/data/__pycache__/srdata.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/srdata.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/srdata.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/srdata.cpython-39.pyc
--------------------------------------------------------------------------------
/data/benchmark.py:
--------------------------------------------------------------------------------
1 | import os
2 | from data import srdata
3 |
4 |
5 | class Benchmark(srdata.SRData1):
6 | def __init__(self, args, name='', train=True, benchmark=True):
7 | super(Benchmark, self).__init__(
8 | args, name=name, train=train, benchmark=True
9 | )
10 |
11 | def _set_filesystem(self, data_dir):
12 |
13 | self.apath = os.path.join(data_dir, 'benchmark', self.name)
14 | self.dir_hr = os.path.join(self.apath, 'HR')
15 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
16 | self.ext = ('', '.png')
17 |
18 |
--------------------------------------------------------------------------------
/data/common.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import skimage.color as sc
4 | import torch
5 |
6 |
7 | def get_patch(*args, patch_size=96, scale=[2], multi_scale=False):
8 | th, tw = args[-1].shape[:2] # target images size
9 |
10 | tp = patch_size # patch size of target hr image
11 | ip = [patch_size // s for s in scale] # patch size of lr images
12 |
13 | # tx and ty are the top and left coordinate of the patch
14 | tx = random.randrange(0, tw - tp + 1)
15 | ty = random.randrange(0, th - tp + 1)
16 | tx, ty = tx- tx % scale[0], ty - ty % scale[0]
17 | ix, iy = [ tx // s for s in scale], [ty // s for s in scale]
18 |
19 | lr = [args[0][i][iy[i]:iy[i] + ip[i], ix[i]:ix[i] + ip[i], :] for i in range(len(scale))]
20 | hr = args[-1][ty:ty + tp, tx:tx + tp, :]
21 |
22 | return [lr, hr]
23 |
24 | def set_channel(*args, n_channels=3):
25 | def _set_channel(img):
26 | if img.ndim == 2:
27 | img = np.expand_dims(img, axis=2)
28 |
29 | c = img.shape[2]
30 | if n_channels == 1 and c == 3:
31 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
32 | elif n_channels == 3 and c == 1:
33 | img = np.concatenate([img] * n_channels, 2)
34 |
35 | return img
36 |
37 | return [_set_channel(a) for a in args[0]], _set_channel(args[-1])
38 |
39 |
40 | def np2Tensor(*args, rgb_range=255):
41 | def _np2Tensor(img):
42 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
43 | tensor = torch.from_numpy(np_transpose).float()
44 | tensor.mul_(rgb_range / 255)
45 |
46 | return tensor
47 |
48 | return [_np2Tensor(a) for a in args[0]], _np2Tensor(args[1])
49 |
50 |
51 | def augment(*args, hflip=True, rot=True):
52 | hflip = hflip and random.random() < 0.5
53 | vflip = rot and random.random() < 0.5
54 | rot90 = rot and random.random() < 0.5
55 |
56 | def _augment(img):
57 | if hflip: img = img[:, ::-1, :]
58 | if vflip: img = img[::-1, :, :]
59 | if rot90: img = img.transpose(1, 0, 2)
60 |
61 | return img
62 |
63 | return [_augment(a) for a in args[0]], _augment(args[-1])
64 |
65 |
--------------------------------------------------------------------------------
/data/div2k.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import os
4 | from data import srdata
5 |
6 |
7 | class DIV2K(srdata.SRData):
8 | def __init__(self, args, name='DIV2K_decoded', train=True, benchmark=False):
9 | super(DIV2K, self).__init__(
10 | args, name=name, train=train, benchmark=benchmark
11 | )
12 |
13 | def _set_filesystem(self, data_dir):
14 | super(DIV2K, self)._set_filesystem(data_dir)
15 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
16 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
17 |
18 |
19 |
--------------------------------------------------------------------------------
/data/srdata.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | from data import common
4 | import numpy as np
5 | import imageio
6 | import torch.utils.data as data
7 | import cv2
8 |
9 | def default_loader(path):
10 | return cv2.imread(path, cv2.IMREAD_UNCHANGED)[:, :, [2, 1, 0]]
11 |
12 | def npy_loader(path):
13 | return np.load(path, allow_pickle=True)
14 |
15 | class SRData(data.Dataset):
16 | def __init__(self, args, name='', train=True, benchmark=False):
17 | self.args = args
18 | self.name = name
19 | self.train = train
20 | self.split = 'train' if train else 'test'
21 | self.do_eval = True
22 | self.benchmark = benchmark
23 | self.scale = args.scale.copy()
24 | self.scale.reverse()
25 | self.ext1 = args.ext
26 |
27 | self._set_filesystem(args.data_dir)
28 | self._get_imgs_path(args)
29 | self._set_dataset_length()
30 |
31 | def __getitem__(self, idx):
32 | lr, hr, filename = self._load_file(idx)
33 |
34 | lr, hr = self.get_patch(lr, hr)
35 | lr, hr = common.set_channel(lr, hr, n_channels=self.args.n_colors)
36 |
37 | lr_tensor, hr_tensor = common.np2Tensor(
38 | lr, hr, rgb_range=self.args.rgb_range
39 | )
40 |
41 | return lr_tensor, hr_tensor, filename
42 |
43 | def __len__(self):
44 | return self.dataset_length
45 |
46 | def _get_imgs_path(self, args):
47 | list_hr, list_lr = self._scan()
48 | self.images_hr, self.images_lr = list_hr, list_lr
49 |
50 | def _set_dataset_length(self):
51 | if self.train:
52 | self.dataset_length = self.args.test_every * self.args.batch_size
53 | repeat = self.dataset_length // len(self.images_hr)
54 | self.random_border = len(self.images_hr) * repeat
55 | else:
56 | self.dataset_length = len(self.images_hr)
57 |
58 | def _scan(self):
59 | names_hr = sorted(
60 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[1]))
61 | )
62 | names_lr = [[] for _ in self.scale]
63 | for f in names_hr:
64 | filename, _ = os.path.splitext(os.path.basename(f))
65 | for si, s in enumerate(self.scale):
66 | names_lr[si].append(os.path.join(
67 | self.dir_lr, 'X{}/{}x{}{}'.format(
68 | s, filename, s, self.ext[1]
69 | )
70 | ))
71 |
72 | return names_hr, names_lr
73 |
74 | def _set_filesystem(self, data_dir):
75 | self.apath = os.path.join(data_dir, self.name)
76 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
77 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
78 | self.ext = ('.png', '.npy')
79 |
80 | def _get_index(self, idx):
81 | if self.train:
82 | if idx < self.random_border:
83 | return idx % len(self.images_hr)
84 | else:
85 | return np.random.randint(len(self.images_hr))
86 | else:
87 | return idx
88 |
89 | def _load_file(self, idx):
90 | idx = self._get_index(idx)
91 | f_hr = self.images_hr[idx]
92 | f_lr = [self.images_lr[idx_scale][idx] for idx_scale in range(len(self.scale))]
93 |
94 | filename, _ = os.path.splitext(os.path.basename(f_hr))
95 | if self.ext1 == '.npy':
96 | lr = [npy_loader(f_lr[idx_scale]) for idx_scale in range(len(self.scale))]
97 | hr = npy_loader(f_hr)
98 | if not self.train:
99 | hr = default_loader(f_hr)
100 | lr = [default_loader(f_lr[idx_scale]) for idx_scale in range(len(self.scale))]
101 | return lr, hr, filename
102 |
103 | def get_patch(self, lr, hr):
104 | scale = self.scale
105 | multi_scale = len(self.scale) > 1
106 | if self.train:
107 | lr, hr = common.get_patch(
108 | lr,
109 | hr,
110 | patch_size=self.args.patch_size,
111 | scale=scale,
112 | multi_scale=multi_scale
113 | )
114 | if not self.args.no_augment:
115 | lr, hr = common.augment(lr, hr)
116 | else:
117 | if isinstance(lr, list):
118 | ih, iw = lr[0].shape[:2]
119 | else:
120 | ih, iw = lr.shape[:2]
121 | hr = hr[0:ih * scale[0], 0:iw * scale[0]]
122 |
123 | return lr, hr
124 |
125 | ##############################################################################
126 | class SRData1(data.Dataset):
127 | def __init__(self, args, name='', train=True, benchmark=False):
128 | self.args = args
129 | self.name = name
130 | self.train = train
131 | self.split = 'train' if train else 'test'
132 | self.do_eval = True
133 | self.benchmark = benchmark
134 | self.scale = args.scale.copy()
135 | self.scale.reverse()
136 | self.ext1 = args.ext
137 |
138 | self._set_filesystem(args.data_dir)
139 | self._get_imgs_path(args)
140 | self._set_dataset_length()
141 |
142 | def __getitem__(self, idx):
143 | lr, hr, filename = self._load_file(idx)
144 |
145 | lr, hr = self.get_patch(lr, hr)
146 | lr, hr = common.set_channel(lr, hr, n_channels=self.args.n_colors)
147 |
148 | lr_tensor, hr_tensor = common.np2Tensor(
149 | lr, hr, rgb_range=self.args.rgb_range
150 | )
151 |
152 | return lr_tensor, hr_tensor, filename
153 |
154 | def __len__(self):
155 | return self.dataset_length
156 |
157 | def _get_imgs_path(self, args):
158 | list_hr, list_lr = self._scan()
159 | self.images_hr, self.images_lr = list_hr, list_lr
160 |
161 | def _set_dataset_length(self):
162 | if self.train:
163 | self.dataset_length = self.args.test_every * self.args.batch_size
164 | repeat = self.dataset_length // len(self.images_hr)
165 | self.random_border = len(self.images_hr) * repeat
166 | else:
167 | self.dataset_length = len(self.images_hr)
168 |
169 | def _scan(self):
170 | names_hr = sorted(
171 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[1]))
172 | )
173 |
174 | names_lr = [[] for _ in self.scale]
175 | for f in names_hr:
176 | filename, _ = os.path.splitext(os.path.basename(f))
177 | for si, s in enumerate(self.scale):
178 | names_lr[si].append(os.path.join(
179 | self.dir_lr, 'X{}/{}x{}{}'.format(
180 | s, filename, s, self.ext[1]
181 | )
182 | ))
183 |
184 | return names_hr, names_lr
185 |
186 | def _set_filesystem(self, data_dir):
187 | self.apath = os.path.join(data_dir, 'DIV2K')
188 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
189 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
190 | self.ext = ('.png', '.png')
191 |
192 | def _get_index(self, idx):
193 | if self.train:
194 | if idx < self.random_border:
195 | return idx % len(self.images_hr)
196 | else:
197 | return np.random.randint(len(self.images_hr))
198 | else:
199 | return idx
200 |
201 | def _load_file(self, idx):
202 | idx = self._get_index(idx)
203 | f_hr = self.images_hr[idx]
204 | f_lr = [self.images_lr[idx_scale][idx] for idx_scale in range(len(self.scale))]
205 |
206 | filename, _ = os.path.splitext(os.path.basename(f_hr))
207 | hr = imageio.imread(f_hr)
208 | lr = [imageio.imread(f_lr[idx_scale]) for idx_scale in range(len(self.scale))]
209 | return lr, hr, filename
210 |
211 | def get_patch(self, lr, hr):
212 | scale = self.scale
213 | multi_scale = len(self.scale) > 1
214 | if self.train:
215 | lr, hr = common.get_patch(
216 | lr,
217 | hr,
218 | patch_size=self.args.patch_size,
219 | scale=scale,
220 | multi_scale=multi_scale
221 | )
222 | if not self.args.no_augment:
223 | lr, hr = common.augment(lr, hr)
224 | else:
225 | if isinstance(lr, list):
226 | ih, iw = lr[0].shape[:2]
227 | else:
228 | ih, iw = lr.shape[:2]
229 | hr = hr[0:ih * scale[0], 0:iw * scale[0]]
230 |
231 | return lr, hr
232 |
233 |
--------------------------------------------------------------------------------
/experiments/CFIN/model/model_best_x2.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/experiments/CFIN/model/model_best_x2.pt
--------------------------------------------------------------------------------
/experiments/CFIN/model/model_best_x3.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/experiments/CFIN/model/model_best_x3.pt
--------------------------------------------------------------------------------
/experiments/CFIN/model/model_best_x4.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/experiments/CFIN/model/model_best_x4.pt
--------------------------------------------------------------------------------
/img/compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/img/compare.png
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import os
2 | import matplotlib
3 | matplotlib.use('Agg')
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | class Loss(nn.modules.loss._Loss):
11 | def __init__(self, args, ckp):
12 | super(Loss, self).__init__()
13 |
14 | self.loss = []
15 | self.loss_module = nn.ModuleList()
16 | for loss in args.loss.split('+'):
17 | weight, loss_type = loss.split('*')
18 |
19 | if loss_type == 'MSE':
20 | loss_function = nn.MSELoss()
21 | elif loss_type == 'L1':
22 | loss_function = nn.L1Loss(reduction='mean')
23 | else:
24 | assert False, f"Unsupported loss type: {loss_type:s}"
25 |
26 | self.loss.append({
27 | 'type': loss_type,
28 | 'weight': float(weight),
29 | 'function': loss_function}
30 | )
31 |
32 | if len(self.loss) > 1:
33 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
34 |
35 | for l in self.loss:
36 | if l['function'] is not None:
37 | print('{:.3f} * {}'.format(l['weight'], l['type']))
38 | self.loss_module.append(l['function'])
39 |
40 | self.log = torch.Tensor()
41 |
42 | def forward(self, sr, hr):
43 | losses = []
44 | for i, l in enumerate(self.loss):
45 | if l['function'] is not None:
46 | loss = l['function'](sr, hr)
47 | effective_loss = l['weight'] * loss
48 | losses.append(effective_loss)
49 | self.log[-1, i] += effective_loss.item()
50 |
51 | loss_sum = sum(losses)
52 | if len(self.loss) > 1:
53 | self.log[-1, -1] += loss_sum.item()
54 |
55 | return loss_sum
56 |
57 | def start_log(self):
58 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))
59 |
60 | def end_log(self, n_batches):
61 | self.log[-1].div_(n_batches)
62 |
63 | def display_loss(self, batch):
64 | n_samples = batch + 1
65 | log = []
66 | for l, c in zip(self.loss, self.log[-1]):
67 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
68 |
69 | return ''.join(log)
70 |
71 | def plot_loss(self, apath, epoch):
72 | axis = np.linspace(1, epoch, epoch)
73 | for i, l in enumerate(self.loss):
74 | label = '{} Loss'.format(l['type'])
75 | fig = plt.figure()
76 | plt.title(label)
77 | plt.plot(axis, self.log[:, i].numpy(), label=label)
78 | plt.legend()
79 | plt.xlabel('Epochs')
80 | plt.ylabel('Loss')
81 | plt.grid(True)
82 | plt.savefig('{}/loss_{}.pdf'.format(apath, l['type']))
83 | plt.close(fig)
84 |
85 | def save(self, apath):
86 | torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
87 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import utility
2 | import data
3 | import model
4 | import loss
5 | from option import args
6 | from checkpoint import Checkpoint
7 | from trainer import Trainer
8 |
9 | utility.set_seed(args.seed)
10 | checkpoint = Checkpoint(args)
11 |
12 | if checkpoint.ok:
13 | loader = data.Data(args)
14 | model = model.Model(args, checkpoint)
15 | loss = loss.Loss(args, checkpoint) if not args.test_only else None
16 | t = Trainer(args, loader, model, loss, checkpoint)
17 | while not t.terminate():
18 | t.train()
19 | t.test()
20 | checkpoint.done()
21 |
22 |
--------------------------------------------------------------------------------
/model/CFIN.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import math
4 | from model import common
5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6 | import torch.nn.functional as F
7 | from pdb import set_trace as stx
8 | import numbers
9 | from einops import rearrange
10 | from torch.nn.parameter import Parameter
11 | from torch.autograd import Variable
12 | #from IPython import embed
13 |
14 | def make_model(args, parent=False):
15 | return MODEL(args)
16 |
17 | class IGP(nn.Module):
18 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
19 | super().__init__()
20 | out_features = out_features or in_features
21 | hidden_features = hidden_features or in_features
22 | self.fc1 = CGM(in_features, hidden_features, kernel_size=1, padding=0)
23 | self.act = act_layer()
24 | self.fc2 = CGM(hidden_features, out_features, kernel_size=3, padding=1)
25 | self.drop = nn.Dropout(drop)
26 |
27 | def forward(self, x):
28 | x = self.fc1(x)
29 | x = self.act(x)
30 | x = self.drop(x)
31 | x = self.fc2(x)
32 | x = self.drop(x)
33 | return x
34 |
35 | class CGM(nn.Conv2d):
36 | def __init__(self, in_channels=64, out_channels=64, kernel_size=1, padding=0, stride=1, dilation=1, groups=1,
37 | bias=True):
38 | super(CGM, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
39 |
40 | self.weight_conv = Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.001, requires_grad=True)
41 | self.bias_conv = Parameter(torch.Tensor(out_channels))
42 | nn.init.kaiming_normal_(self.weight_conv)
43 |
44 | self.stride = stride
45 | self.padding = padding
46 | self.dilation = dilation
47 | self.groups = groups
48 |
49 | if kernel_size == 0:
50 | self.ind = True
51 | else:
52 | self.ind = False
53 | self.oc = out_channels
54 | self.ks = kernel_size
55 |
56 | # target spatial size of the pooling layer
57 | ws = kernel_size
58 | self.avg_pool = nn.AdaptiveMaxPool2d((ws, ws))
59 |
60 | # the dimension of latent representation
61 | self.num_lat = int((kernel_size * kernel_size) / 2 + 1)
62 |
63 | # the context encoding module
64 | self.ce = nn.Linear(ws * ws, self.num_lat, False)
65 |
66 | self.act = nn.ReLU()
67 |
68 | # the number of groups in the channel interaction module
69 | if in_channels // 8:
70 | self.g = 8
71 | else:
72 | self.g = in_channels
73 |
74 | # the channel interacting module
75 | self.ci = nn.Linear(self.g, out_channels // (in_channels // self.g), bias=False)
76 |
77 | # the gate decoding module (spatial interaction)
78 | self.gd = nn.Linear(self.num_lat, kernel_size * kernel_size, False)
79 | self.gd2 = nn.Linear(self.num_lat, kernel_size * kernel_size, False)
80 |
81 | # used to prepare the input feature map to patches
82 | self.unfold = nn.Unfold(kernel_size, dilation, padding, stride)
83 |
84 | # sigmoid function
85 | self.sig = nn.Sigmoid()
86 |
87 | def forward(self, x):
88 | if self.ind:
89 | return F.conv2d(x, self.weight_conv, self.bias_conv, self.stride, self.padding, self.dilation, self.groups)
90 | else:
91 | b, c, h, w = x.size() # x: batch x n_feat(=64) x h_patch x w_patch
92 | weight = self.weight_conv
93 |
94 | # allocate global information
95 | gl = self.avg_pool(x).view(b, c, -1) # gl: batch x n_feat x 3 x 3 -> batch x n_feat x 9
96 |
97 | # context-encoding module
98 | out = self.ce(gl) # out: batch x n_feat x 5
99 |
100 | # use different bn for following two branches
101 | ce2 = out # ce2: batch x n_feat x 5
102 | out = self.act(out) # out: batch x n_feat x 5 (just batch normalization)
103 |
104 | # gate decoding branch 1 (spatial interaction)
105 | out = self.gd(out) # out: batch x n_feat x 9 (5 --> 9 = 3x3)
106 |
107 | # channel interacting module
108 | if self.g > 3:
109 | oc = self.ci(self.act(ce2.view(b, c // self.g, self.g, -1).transpose(2, 3))).transpose(2,3).contiguous()
110 | else:
111 | oc = self.ci(self.act(ce2.transpose(2, 1))).transpose(2, 1).contiguous()
112 | oc = oc.view(b, self.oc, -1)
113 | oc = self.act(oc) # oc: batch x n_feat x 5 (after grouped linear layer)
114 |
115 | # gate decoding branch 2 (spatial interaction)
116 | oc = self.gd2(oc) # oc: batch x n_feat x 9 (5 --> 9 = 3x3)
117 |
118 | # produce gate (equation (4) in the CRAN paper)
119 | out = self.sig(out.view(b, 1, c, self.ks, self.ks) + oc.view(b, self.oc, 1, self.ks, self.ks))
120 | # out: batch x out_channel x in_channel x kernel_size x kernel_size (same dimension as conv2d weight)
121 |
122 | # unfolding input feature map to patches
123 | x_un = self.unfold(x)
124 | b, _, l = x_un.size()
125 | out = (out * weight.unsqueeze(0))#.to(device)
126 | out = out.view(b, self.oc, -1)
127 |
128 | # currently only handle square input and output
129 | return torch.matmul(out, x_un).view(b, self.oc, h, w)
130 |
131 | class CGA(nn.Module):
132 | def __init__(self, dim, num_heads=8, kernel_size=5, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
133 | super().__init__()
134 | self.num_heads = num_heads
135 | head_dim = dim // num_heads
136 | self.weight = nn.Parameter(torch.randn(num_heads, dim//num_heads, dim//num_heads) * 0.001, requires_grad=True)
137 | self.to_qkv = CGM(dim, dim*3)
138 |
139 | def forward(self, x, k1=None, v1=None, return_x=False):
140 | weight = self.weight
141 | b,c,h,w = x.shape
142 |
143 | qkv = self.to_qkv(x)
144 | q, k, v = qkv.chunk(3, dim=1)
145 |
146 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
147 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
148 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
149 |
150 | if k1 is None:
151 | k = k
152 | v = v
153 | else:
154 | k = k1 + k
155 | v = v1 + v
156 | q = torch.nn.functional.normalize(q, dim=-1)
157 | k = torch.nn.functional.normalize(k, dim=-1)
158 |
159 | attn = (q @ k.transpose(-2, -1)) * weight
160 | attn = attn.softmax(dim=-1)
161 | x = (attn @ v)
162 | x = rearrange(x, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
163 |
164 | if return_x:
165 | return x
166 | else:
167 | return x, k, v
168 |
169 |
170 | class WithBias_LayerNorm(nn.Module):
171 | def __init__(self, normalized_shape):
172 | super(WithBias_LayerNorm, self).__init__()
173 | if isinstance(normalized_shape, numbers.Integral):
174 | normalized_shape = (normalized_shape,)
175 | normalized_shape = torch.Size(normalized_shape)
176 |
177 | assert len(normalized_shape) == 1
178 |
179 | self.weight = nn.Parameter(torch.ones(normalized_shape))
180 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
181 | self.normalized_shape = normalized_shape
182 |
183 | def forward(self, x):
184 | mu = x.mean(-1, keepdim=True)
185 | sigma = x.var(-1, keepdim=True, unbiased=False)
186 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
187 |
188 |
189 | def to_3d(x):
190 | return rearrange(x, 'b c h w -> b (h w) c')
191 |
192 | def to_4d(x,h,w):
193 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
194 |
195 | class LayerNorm(nn.Module):
196 | def __init__(self, dim):
197 | super(LayerNorm, self).__init__()
198 | self.body = WithBias_LayerNorm(dim)
199 |
200 | def forward(self, x):
201 | h, w = x.shape[-2:]
202 | return to_4d(self.body(to_3d(x)), h, w)
203 |
204 |
205 | class CFGT(nn.Module):
206 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
207 | drop_path=0., act_layer=nn.GELU):
208 | super().__init__()
209 | self.norm1 = LayerNorm(dim)
210 | kernel_size1 = 1
211 | padding1 = 0
212 | kernel_size2 = 3
213 | padding2 = 1
214 | self.attn = CGA(dim, num_heads, kernel_size1, padding1)
215 | self.attn1 = CGA(dim, num_heads, kernel_size2, padding2)
216 |
217 | self.norm2 = LayerNorm(dim)
218 | self.norm3 = LayerNorm(dim)
219 | mlp_hidden_dim = int(dim*1)
220 | self.mlp = IGP(in_features=dim, hidden_features=mlp_hidden_dim)
221 |
222 | def forward(self, x):
223 | res = x
224 | x, k1, v1 = self.attn(x)
225 | x = res + self.norm1(x)
226 | x = x + self.norm2(self.attn1(x, k1, v1, return_x=True))
227 | x = x + self.norm3(self.mlp(x))
228 | return x
229 |
230 |
231 | class Scale(nn.Module):
232 | def __init__(self, init_value=1e-3):
233 | super().__init__()
234 | self.scale = nn.Parameter(torch.FloatTensor([init_value]))
235 |
236 | def forward(self, input):
237 | return input * self.scale
238 |
239 |
240 | def activation(act_type, inplace=False, neg_slope=0.05, n_prelu=1):
241 | act_type = act_type.lower()
242 | if act_type == 'relu':
243 | layer = nn.ReLU()
244 | elif act_type == 'lrelu':
245 | layer = nn.LeakyReLU(neg_slope)
246 | elif act_type == 'prelu':
247 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
248 | else:
249 | raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
250 | return layer
251 |
252 |
253 | class eca_layer(nn.Module):
254 | def __init__(self, channel, k_size):
255 | super(eca_layer, self).__init__()
256 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
257 | self.k_size = k_size
258 | self.conv = nn.Conv1d(channel, channel, kernel_size=k_size, bias=False, groups=channel)
259 | self.sigmoid = nn.Sigmoid()
260 |
261 |
262 | def forward(self, x):
263 | b, c, _, _ = x.size()
264 | y = self.avg_pool(x)
265 | y = nn.functional.unfold(y.transpose(-1, -3), kernel_size=(1, self.k_size), padding=(0, (self.k_size - 1) // 2))
266 | y = self.conv(y.transpose(-1, -2)).unsqueeze(-1)
267 | y = self.sigmoid(y)
268 | x = x * y.expand_as(x)
269 | return x
270 |
271 |
272 | class MaskPredictor(nn.Module):
273 | def __init__(self,in_channels, wn=lambda x: torch.nn.utils.weight_norm(x)):
274 | super(MaskPredictor,self).__init__()
275 | self.spatial_mask=nn.Conv2d(in_channels=in_channels,out_channels=3,kernel_size=1,bias=False)
276 |
277 | def forward(self,x):
278 | spa_mask=self.spatial_mask(x)
279 | spa_mask=F.gumbel_softmax(spa_mask,tau=1,hard=True,dim=1)
280 | return spa_mask
281 |
282 |
283 | class RIFU(nn.Module):
284 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)):
285 | super(RIFU, self).__init__()
286 | self.CA = eca_layer(n_feats, k_size=3)
287 | self.MaskPredictor = MaskPredictor(n_feats*8//8)
288 |
289 | self.k = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)),
290 | nn.LeakyReLU(0.05),
291 | )
292 |
293 | self.k1 = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)),
294 | nn.LeakyReLU(0.05),
295 | )
296 |
297 | self.res_scale = Scale(1)
298 | self.x_scale = Scale(1)
299 |
300 | def forward(self, x):
301 | res = x
302 | x = self.k(x)
303 |
304 | MaskPredictor = self.MaskPredictor(x)
305 | mask = (MaskPredictor[:,1,...]).unsqueeze(1)
306 | x = x * (mask.expand_as(x))
307 |
308 | x1 = self.k1(x)
309 | x2 = self.CA(x1)
310 | out = self.x_scale(x2) + self.res_scale(res)
311 |
312 | return out
313 |
314 |
315 | class CIAM(nn.Module):
316 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)):
317 | super(CIAM, self).__init__()
318 | pooling_r = 2
319 | med_feats = n_feats // 1
320 | self.k1 = nn.Sequential(nn.ConvTranspose2d(n_feats, n_feats*4//3, kernel_size=pooling_r, stride=pooling_r, padding=0, groups=1, bias=True),
321 | nn.LeakyReLU(0.05),
322 | nn.Conv2d(n_feats*4//3, n_feats, kernel_size=1, stride=2, padding=0, groups=1),
323 | )
324 |
325 | self.sig = nn.Sigmoid()
326 |
327 | self.k3 = RIFU(n_feats)
328 |
329 | self.k4 = RIFU(n_feats)
330 |
331 | self.k5 = RIFU(n_feats)
332 |
333 | self.res_scale = Scale(1)
334 | self.x_scale = Scale(1)
335 |
336 | def forward(self, x):
337 | identity = x
338 | _, _, H, W = identity.shape
339 | x1_1 = self.k3(x)
340 | x1 = self.k4(x1_1)
341 |
342 |
343 | x1_s = self.sig(self.k1(x) + x)
344 | x1 = self.k5(x1_s * x1)
345 |
346 | out = self.res_scale(x1) + self.x_scale(identity)
347 |
348 | return out
349 |
350 |
351 | class FCUUp(nn.Module):
352 | def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU,
353 | norm_layer=nn.BatchNorm2d, wn=lambda x: torch.nn.utils.weight_norm(x)):
354 | super(FCUUp, self).__init__()
355 | self.up_stride = up_stride
356 | self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)
357 | self.act = act_layer()
358 |
359 | def forward(self, x_t):
360 | x_r = self.act(self.conv_project(x_t))
361 |
362 | return x_r
363 |
364 | class FCUDown(nn.Module):
365 | def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU,
366 | norm_layer=nn.LayerNorm, wn=lambda x: torch.nn.utils.weight_norm(x)):
367 | super(FCUDown, self).__init__()
368 | self.conv_project = wn(nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0))
369 |
370 | def forward(self, x):
371 | x = self.conv_project(x)
372 |
373 | return x
374 |
375 |
376 | class CTGroup(nn.Module):
377 | def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1, norm_layer=nn.BatchNorm2d, drop_block=None, drop_path=None):
378 | super(CTGroup, self).__init__()
379 |
380 | expansion = 1
381 | med_planes = outplanes // expansion
382 | embed_dim = 144
383 | num_heads = 8
384 | mlp_ratio = 1.0
385 |
386 | self.rb_search1 = CIAM(med_planes)
387 | self.rb_search2 = CIAM(med_planes)
388 | self.rb_search3 = CIAM(med_planes)
389 | self.rb_search4 = CIAM(med_planes)
390 |
391 | self.trans_block = CFGT(
392 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
393 | drop=0., attn_drop=0., drop_path=0.)
394 |
395 | self.trans_block1 = CFGT(
396 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
397 | drop=0., attn_drop=0., drop_path=0.)
398 |
399 | self.trans_block2 = CFGT(
400 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
401 | drop=0., attn_drop=0., drop_path=0.)
402 |
403 | self.trans_block3 = CFGT(
404 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
405 | drop=0., attn_drop=0., drop_path=0.)
406 |
407 | self.trans_block4 = CFGT(
408 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
409 | drop=0., attn_drop=0., drop_path=0.)
410 |
411 | self.trans_block5 = CFGT(
412 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
413 | drop=0., attn_drop=0., drop_path=0.)
414 |
415 | self.trans_block6 = CFGT(
416 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
417 | drop=0., attn_drop=0., drop_path=0.)
418 |
419 | self.trans_block7 = CFGT(
420 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
421 | drop=0., attn_drop=0., drop_path=0.)
422 |
423 | self.expand_block = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
424 | self.squeeze_block = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
425 | self.expand_block1 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
426 | self.squeeze_block1 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
427 | self.expand_block2 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
428 | self.squeeze_block2 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
429 | self.expand_block3 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
430 | self.squeeze_block3 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
431 |
432 | self.res_scale = Scale(1)
433 | self.x_scale = Scale(1)
434 | self.num_rbs = 1
435 |
436 | self.res_conv = res_conv
437 | self.drop_block = drop_block
438 | self.drop_path = drop_path
439 |
440 | def zero_init_last_bn(self):
441 | nn.init.zeros_(self.bn3.weight)
442 |
443 | def forward(self, x):
444 | residual = x
445 |
446 | x = self.squeeze_block(self.trans_block(self.expand_block(self.rb_search1(x)))) + x # 1 CT block
447 |
448 | x = self.squeeze_block(self.trans_block1(self.expand_block(self.rb_search1(x)))) + x # 2 CT block
449 |
450 | x = self.squeeze_block1(self.trans_block2(self.expand_block1(self.rb_search2(x)))) + x # 3 CT block
451 |
452 | x = self.squeeze_block1(self.trans_block3(self.expand_block1(self.rb_search2(x)))) + x # 4 CT block
453 |
454 | x = self.squeeze_block2(self.trans_block4(self.expand_block2(self.rb_search3(x)))) + x # 5 CT block
455 |
456 | x = self.squeeze_block2(self.trans_block5(self.expand_block2(self.rb_search3(x)))) + x # 6 CT block
457 |
458 | x = self.squeeze_block3(self.trans_block6(self.expand_block3(self.rb_search4(x)))) + x # 7 CT block
459 |
460 | x = self.squeeze_block3(self.trans_block7(self.expand_block3(self.rb_search4(x)))) + x # 8 CT block
461 |
462 | x = self.x_scale(x) + self.res_scale(residual)
463 |
464 | return x
465 |
466 |
467 | class ConvTransBlock(nn.Module):
468 | """
469 | Basic module for ConvTransformer, keep feature maps for CNN block and patch embeddings for transformer encoder block
470 | """
471 |
472 | def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads, mlp_ratio,
473 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
474 | last_fusion=False, num_med_block=0, groups=1):
475 | super(ConvTransBlock, self).__init__()
476 | expansion = 1
477 | self.cnn_block = CTGroup(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=1, groups=groups)
478 |
479 | self.dw_stride = dw_stride
480 | self.embed_dim = embed_dim
481 | self.num_med_block = num_med_block
482 | self.last_fusion = last_fusion
483 | self.res_scale = Scale(1)
484 | self.x_scale = Scale(1)
485 |
486 | def forward(self, x):
487 | x = self.cnn_block(x)
488 |
489 | return x
490 |
491 |
492 | class MODEL(nn.Module):
493 | def __init__(self, args, norm_layer=nn.LayerNorm, patch_size=1, window_size=8, num_heads=8, mlp_ratio=1.,
494 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., num_med_block=0, drop_path_rate=0.,
495 | patch_norm=True):
496 | super(MODEL, self).__init__()
497 | scale = args.scale
498 | n_feats = 48
499 | n_colors = 3
500 | embed_dim = 64
501 |
502 | self.patch_norm = patch_norm
503 | self.num_features = embed_dim
504 | rgb_mean = (0.4488, 0.4371, 0.4040)
505 | rgb_std = (1.0, 1.0, 1.0)
506 | self.sub_mean = common.MeanShift(255, rgb_mean, rgb_std)
507 | self.add_mean = common.MeanShift(255, rgb_mean, rgb_std, 1)
508 | #self.conv_first_trans = nn.Conv2d(n_colors, embed_dim, 3, 1, 1)
509 | self.conv_first_cnn = nn.Conv2d(n_colors, n_feats, 3, 1, 1)
510 |
511 | self.trans_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 8)] # stochastic depth decay rule
512 |
513 | # 2~final Stage
514 | init_stage = 2
515 | fin_stage = 3
516 | stage_1_channel = n_feats
517 | trans_dw_stride = patch_size
518 | for i in range(init_stage, fin_stage):
519 | if i%2==0:
520 | m = i
521 | else:
522 | m = i-1
523 | self.add_module('conv_trans_' + str(m),
524 | ConvTransBlock(
525 | stage_1_channel, stage_1_channel, res_conv=True, stride=1, dw_stride=trans_dw_stride,
526 | embed_dim=embed_dim,
527 | num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
528 | drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
529 | drop_path_rate=self.trans_dpr[i - 1],
530 | num_med_block=num_med_block
531 | )
532 | )
533 |
534 | self.fin_stage = fin_stage
535 | self.dw_stride = trans_dw_stride
536 |
537 | self.conv_after_body = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
538 |
539 | m = []
540 | m.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1))
541 | m.append(nn.PixelShuffle(scale[0]))
542 | self.UP1 = nn.Sequential(*m)
543 |
544 | self.conv_stright = nn.Conv2d(n_colors, n_feats, 3, 1, 1)
545 | up_body = []
546 | up_body.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1))
547 | up_body.append(nn.PixelShuffle(scale[0]))
548 | self.UP2 = nn.Sequential(*up_body)
549 |
550 | self.apply(self._init_weights)
551 |
552 | def _init_weights(self, m):
553 | if isinstance(m, nn.Linear):
554 | trunc_normal_(m.weight, std=.02)
555 | if isinstance(m, nn.Linear) and m.bias is not None:
556 | nn.init.constant_(m.bias, 0)
557 | elif isinstance(m, nn.LayerNorm):
558 | nn.init.constant_(m.bias, 0)
559 | nn.init.constant_(m.weight, 1.0)
560 |
561 | def forward(self, x):
562 | (H, W) = (x.shape[2], x.shape[3])
563 | residual = x
564 | x = self.sub_mean(x)
565 | x = self.conv_first_cnn(x)
566 |
567 | for i in range(2, self.fin_stage):
568 | if i%2==0:
569 | m = i
570 | else:
571 | m = i-1
572 | x = eval('self.conv_trans_' + str(m))(x)
573 |
574 | x = self.conv_after_body(x)
575 | y1 = self.UP1(x)
576 | y2 = self.UP2(self.conv_stright(residual))
577 | output = self.add_mean(y1 + y2)
578 |
579 | return output
--------------------------------------------------------------------------------
/model/CFINx2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import math
4 | from model import common
5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6 | import torch.nn.functional as F
7 | from pdb import set_trace as stx
8 | import numbers
9 | from einops import rearrange
10 | from torch.nn.parameter import Parameter
11 | from torch.autograd import Variable
12 | #from IPython import embed
13 |
14 | def make_model(args, parent=False):
15 | return MODEL(args)
16 |
17 | class Mlp(nn.Module):
18 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
19 | super().__init__()
20 | out_features = out_features or in_features
21 | hidden_features = hidden_features or in_features
22 | self.fc1 = Conv2d_CG(in_features, hidden_features, kernel_size=1, padding=0)
23 | self.act = act_layer()
24 | self.fc2 = Conv2d_CG(hidden_features, out_features, kernel_size=3, padding=1)
25 | self.drop = nn.Dropout(drop)
26 |
27 | def forward(self, x):
28 | x = self.fc1(x)
29 | x = self.act(x)
30 | x = self.drop(x)
31 | x = self.fc2(x)
32 | x = self.drop(x)
33 | return x
34 |
35 | class Conv2d_CG(nn.Conv2d):
36 | def __init__(self, in_channels=64, out_channels=64, kernel_size=1, padding=0, stride=1, dilation=1, groups=1,
37 | bias=True):
38 | super(Conv2d_CG, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
39 |
40 | self.weight_conv = Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.001, requires_grad=True)
41 | self.bias_conv = Parameter(torch.Tensor(out_channels))
42 | nn.init.kaiming_normal_(self.weight_conv)
43 |
44 | self.stride = stride
45 | self.padding = padding
46 | self.dilation = dilation
47 | self.groups = groups
48 |
49 | if kernel_size == 0:
50 | self.ind = True
51 | else:
52 | self.ind = False
53 | self.oc = out_channels
54 | self.ks = kernel_size
55 |
56 | # target spatial size of the pooling layer
57 | ws = kernel_size
58 | self.avg_pool = nn.AdaptiveMaxPool2d((ws, ws))
59 |
60 | # the dimension of latent representation
61 | self.num_lat = int((kernel_size * kernel_size) / 2 + 1)
62 |
63 | # the context encoding module
64 | self.ce = nn.Linear(ws * ws, self.num_lat, False)
65 |
66 | self.act = nn.ReLU()
67 |
68 | # the number of groups in the channel interaction module
69 | if in_channels // 8:
70 | self.g = 8
71 | else:
72 | self.g = in_channels
73 |
74 | # the channel interacting module
75 | self.ci = nn.Linear(self.g, out_channels // (in_channels // self.g), bias=False)
76 |
77 | # the gate decoding module (spatial interaction)
78 | self.gd = nn.Linear(self.num_lat, kernel_size * kernel_size, False)
79 | self.gd2 = nn.Linear(self.num_lat, kernel_size * kernel_size, False)
80 |
81 | # used to prepare the input feature map to patches
82 | self.unfold = nn.Unfold(kernel_size, dilation, padding, stride)
83 |
84 | # sigmoid function
85 | self.sig = nn.Sigmoid()
86 |
87 | def forward(self, x):
88 | if self.ind:
89 | return F.conv2d(x, self.weight_conv, self.bias_conv, self.stride, self.padding, self.dilation, self.groups)
90 | else:
91 | b, c, h, w = x.size() # x: batch x n_feat(=64) x h_patch x w_patch
92 | weight = self.weight_conv
93 |
94 | # allocate global information
95 | gl = self.avg_pool(x).view(b, c, -1) # gl: batch x n_feat x 3 x 3 -> batch x n_feat x 9
96 |
97 | # context-encoding module
98 | out = self.ce(gl) # out: batch x n_feat x 5
99 |
100 |             # keep a copy of the encoding for the two gate-decoding branches below
101 |             ce2 = out                     # ce2: batch x n_feat x 5
102 |             out = self.act(out)           # out: batch x n_feat x 5 (ReLU, not batch norm)
103 |
104 | # gate decoding branch 1 (spatial interaction)
105 | out = self.gd(out) # out: batch x n_feat x 9 (5 --> 9 = 3x3)
106 |
107 | # channel interacting module
108 | if self.g > 3:
109 | oc = self.ci(self.act(ce2.view(b, c // self.g, self.g, -1).transpose(2, 3))).transpose(2,3).contiguous()
110 | else:
111 | oc = self.ci(self.act(ce2.transpose(2, 1))).transpose(2, 1).contiguous()
112 | oc = oc.view(b, self.oc, -1)
113 | oc = self.act(oc) # oc: batch x n_feat x 5 (after grouped linear layer)
114 |
115 | # gate decoding branch 2 (spatial interaction)
116 | oc = self.gd2(oc) # oc: batch x n_feat x 9 (5 --> 9 = 3x3)
117 |
118 | # produce gate (equation (4) in the CRAN paper)
119 | out = self.sig(out.view(b, 1, c, self.ks, self.ks) + oc.view(b, self.oc, 1, self.ks, self.ks))
120 | # out: batch x out_channel x in_channel x kernel_size x kernel_size (same dimension as conv2d weight)
121 |
122 | # unfolding input feature map to patches
123 | x_un = self.unfold(x)
124 | b, _, l = x_un.size()
125 | out = (out * weight.unsqueeze(0))#.to(device)
126 | out = out.view(b, self.oc, -1)
127 |
128 | # currently only handle square input and output
129 | return torch.matmul(out, x_un).view(b, self.oc, h, w)
130 |
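# A small, self-contained sketch of the unfold + matmul formulation used in
# Conv2d_CG.forward above: with the predicted gate replaced by all-ones it reduces to a
# plain convolution. Toy sizes; the helper name is illustrative and not used elsewhere.
def _gated_conv_sanity_check(b=2, c_in=4, c_out=6, k=3, h=8, w=8):
    x = torch.randn(b, c_in, h, w)
    weight = torch.randn(c_out, c_in, k, k)
    gate = torch.ones(b, c_out, c_in, k, k)            # Conv2d_CG predicts this per sample
    x_un = F.unfold(x, k, padding=k // 2)              # (b, c_in*k*k, h*w)
    dyn_w = (gate * weight.unsqueeze(0)).view(b, c_out, -1)
    y = torch.matmul(dyn_w, x_un).view(b, c_out, h, w)
    return torch.allclose(y, F.conv2d(x, weight, padding=k // 2), atol=1e-4)   # True
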
131 | class ConvAttention(nn.Module):
132 | def __init__(self, dim, num_heads=8, kernel_size=5, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
133 | super().__init__()
134 | self.num_heads = num_heads
135 | head_dim = dim // num_heads
136 | self.weight = nn.Parameter(torch.randn(num_heads, dim//num_heads, dim//num_heads) * 0.001, requires_grad=True)
137 | self.to_qkv = Conv2d_CG(dim, dim*3)
138 |
139 | def forward(self, x, k1=None, v1=None, return_x=False):
140 | weight = self.weight
141 | b,c,h,w = x.shape
142 |
143 | qkv = self.to_qkv(x)
144 | q, k, v = qkv.chunk(3, dim=1)
145 |
146 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
147 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
148 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
149 |
150 | if k1 is None:
151 | k = k
152 | v = v
153 | else:
154 | k = k1 + k
155 | v = v1 + v
156 | q = torch.nn.functional.normalize(q, dim=-1)
157 | k = torch.nn.functional.normalize(k, dim=-1)
158 |
159 | attn = (q @ k.transpose(-2, -1)) * weight
160 | attn = attn.softmax(dim=-1)
161 | x = (attn @ v)
162 | x = rearrange(x, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
163 |
164 | if return_x:
165 | return x
166 | else:
167 | return x, k, v
168 |
169 |
170 | class WithBias_LayerNorm(nn.Module):
171 | def __init__(self, normalized_shape):
172 | super(WithBias_LayerNorm, self).__init__()
173 | if isinstance(normalized_shape, numbers.Integral):
174 | normalized_shape = (normalized_shape,)
175 | normalized_shape = torch.Size(normalized_shape)
176 |
177 | assert len(normalized_shape) == 1
178 |
179 | self.weight = nn.Parameter(torch.ones(normalized_shape))
180 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
181 | self.normalized_shape = normalized_shape
182 |
183 | def forward(self, x):
184 | mu = x.mean(-1, keepdim=True)
185 | sigma = x.var(-1, keepdim=True, unbiased=False)
186 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
187 |
188 |
189 | def to_3d(x):
190 | return rearrange(x, 'b c h w -> b (h w) c')
191 |
192 | def to_4d(x,h,w):
193 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
194 |
195 | class LayerNorm(nn.Module):
196 | def __init__(self, dim):
197 | super(LayerNorm, self).__init__()
198 | self.body = WithBias_LayerNorm(dim)
199 |
200 | def forward(self, x):
201 | h, w = x.shape[-2:]
202 | return to_4d(self.body(to_3d(x)), h, w)
203 |
204 |
205 | class Block(nn.Module):
206 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
207 | drop_path=0., act_layer=nn.GELU):
208 | super().__init__()
209 | self.norm1 = LayerNorm(dim)
210 | kernel_size1 = 1
211 | padding1 = 0
212 | kernel_size2 = 3
213 | padding2 = 1
214 | self.attn = ConvAttention(dim, num_heads, kernel_size1, padding1)
215 | self.attn1 = ConvAttention(dim, num_heads, kernel_size2, padding2)
216 |
217 | self.norm2 = LayerNorm(dim)
218 | self.norm3 = LayerNorm(dim)
219 | mlp_hidden_dim = int(dim*1)
220 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim)
221 |
222 | def forward(self, x):
223 | res = x
224 | x, k1, v1 = self.attn(x)
225 | x = res + self.norm1(x)
226 | x = x + self.norm2(self.attn1(x, k1, v1, return_x=True))
227 | x = x + self.norm3(self.mlp(x))
228 | return x
229 |
230 |
231 | class Scale(nn.Module):
232 | def __init__(self, init_value=1e-3):
233 | super().__init__()
234 | self.scale = nn.Parameter(torch.FloatTensor([init_value]))
235 |
236 | def forward(self, input):
237 | return input * self.scale
238 |
239 |
240 | def activation(act_type, inplace=False, neg_slope=0.05, n_prelu=1):
241 | act_type = act_type.lower()
242 | if act_type == 'relu':
243 | layer = nn.ReLU()
244 | elif act_type == 'lrelu':
245 | layer = nn.LeakyReLU(neg_slope)
246 | elif act_type == 'prelu':
247 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
248 | else:
249 | raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
250 | return layer
251 |
252 |
253 | class eca_layer(nn.Module):
254 | def __init__(self, channel, k_size):
255 | super(eca_layer, self).__init__()
256 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
257 | self.k_size = k_size
258 | self.conv = nn.Conv1d(channel, channel, kernel_size=k_size, bias=False, groups=channel)
259 | self.sigmoid = nn.Sigmoid()
260 |
261 |
262 | def forward(self, x):
263 | b, c, _, _ = x.size()
264 | y = self.avg_pool(x)
265 | y = nn.functional.unfold(y.transpose(-1, -3), kernel_size=(1, self.k_size), padding=(0, (self.k_size - 1) // 2))
266 | y = self.conv(y.transpose(-1, -2)).unsqueeze(-1)
267 | y = self.sigmoid(y)
268 | x = x * y.expand_as(x)
269 | return x
270 |
271 |
272 | class MaskPredictor(nn.Module):
273 | def __init__(self,in_channels, wn=lambda x: torch.nn.utils.weight_norm(x)):
274 | super(MaskPredictor,self).__init__()
275 | self.spatial_mask=nn.Conv2d(in_channels=in_channels,out_channels=3,kernel_size=1,bias=False)
276 |
277 | def forward(self,x):
278 | spa_mask=self.spatial_mask(x)
279 | spa_mask=F.gumbel_softmax(spa_mask,tau=1,hard=True,dim=1)
280 | return spa_mask
281 |
282 |
283 | class RB(nn.Module):
284 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)):
285 | super(RB, self).__init__()
286 | self.CA = eca_layer(n_feats, k_size=3)
287 | self.MaskPredictor = MaskPredictor(n_feats*8//8)
288 |
289 | self.k = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)),
290 | nn.LeakyReLU(0.05),
291 | )
292 |
293 | self.k1 = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)),
294 | nn.LeakyReLU(0.05),
295 | )
296 |
297 | self.res_scale = Scale(1)
298 | self.x_scale = Scale(1)
299 |
300 | def forward(self, x):
301 | res = x
302 | x = self.k(x)
303 |
304 | MaskPredictor = self.MaskPredictor(x)
305 | mask = (MaskPredictor[:,1,...]).unsqueeze(1)
306 | x = x * (mask.expand_as(x))
307 |
308 | x1 = self.k1(x)
309 | x2 = self.CA(x1)
310 | out = self.x_scale(x2) + self.res_scale(res)
311 |
312 | return out
313 |
314 |
315 | class SCConv(nn.Module):
316 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)):
317 | super(SCConv, self).__init__()
318 | pooling_r = 2
319 | med_feats = n_feats // 1
320 | self.k1 = nn.Sequential(nn.ConvTranspose2d(n_feats, n_feats*3//2, kernel_size=pooling_r, stride=pooling_r, padding=0, groups=1, bias=True),
321 | nn.LeakyReLU(0.05),
322 | nn.Conv2d(n_feats*3//2, n_feats, kernel_size=1, stride=2, padding=0, groups=1),
323 | )
324 |
325 | self.sig = nn.Sigmoid()
326 |
327 | self.k3 = RB(n_feats)
328 |
329 | self.k4 = RB(n_feats)
330 |
331 | self.k5 = RB(n_feats)
332 |
333 | self.res_scale = Scale(1)
334 | self.x_scale = Scale(1)
335 |
336 | def forward(self, x):
337 | identity = x
338 | _, _, H, W = identity.shape
339 | x1_1 = self.k3(x)
340 | x1 = self.k4(x1_1)
341 |
342 |
343 | x1_s = self.sig(self.k1(x) + x)
344 | x1 = self.k5(x1_s * x1)
345 |
346 | out = self.res_scale(x1) + self.x_scale(identity)
347 |
348 | return out
349 |
350 |
351 | class FCUUp(nn.Module):
352 | def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU,
353 | norm_layer=nn.BatchNorm2d, wn=lambda x: torch.nn.utils.weight_norm(x)):
354 | super(FCUUp, self).__init__()
355 | self.up_stride = up_stride
356 | self.conv_project = wn(nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0))
357 | self.act = act_layer()
358 |
359 | def forward(self, x_t):
360 | x_r = self.act(self.conv_project(x_t))
361 |
362 | return x_r
363 |
364 | class FCUDown(nn.Module):
365 | def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU,
366 | norm_layer=nn.LayerNorm, wn=lambda x: torch.nn.utils.weight_norm(x)):
367 | super(FCUDown, self).__init__()
368 | self.conv_project = wn(nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0))
369 |
370 | def forward(self, x):
371 | x = self.conv_project(x)
372 |
373 | return x
374 |
375 |
376 | class ConvBlock(nn.Module):
377 | def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1, norm_layer=nn.BatchNorm2d, drop_block=None, drop_path=None):
378 | super(ConvBlock, self).__init__()
379 |
380 | expansion = 1
381 | med_planes = outplanes // expansion
382 | embed_dim = 144
383 | num_heads = 8
384 | mlp_ratio = 1.0
385 |
386 | self.rb_search1 = SCConv(med_planes)
387 | self.rb_search2 = SCConv(med_planes)
388 | self.rb_search3 = SCConv(med_planes)
389 | self.rb_search4 = SCConv(med_planes)
390 |
391 | self.trans_block = Block(
392 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
393 | drop=0., attn_drop=0., drop_path=0.)
394 |
395 | self.trans_block1 = Block(
396 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
397 | drop=0., attn_drop=0., drop_path=0.)
398 |
399 | self.trans_block2 = Block(
400 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
401 | drop=0., attn_drop=0., drop_path=0.)
402 |
403 | self.trans_block3 = Block(
404 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
405 | drop=0., attn_drop=0., drop_path=0.)
406 |
407 | self.trans_block4 = Block(
408 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
409 | drop=0., attn_drop=0., drop_path=0.)
410 |
411 | self.trans_block5 = Block(
412 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
413 | drop=0., attn_drop=0., drop_path=0.)
414 |
415 | self.trans_block6 = Block(
416 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
417 | drop=0., attn_drop=0., drop_path=0.)
418 |
419 | self.trans_block7 = Block(
420 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
421 | drop=0., attn_drop=0., drop_path=0.)
422 |
423 | self.expand_block = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
424 | self.squeeze_block = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
425 | self.expand_block1 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
426 | self.squeeze_block1 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
427 | self.expand_block2 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
428 | self.squeeze_block2 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
429 | self.expand_block3 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
430 | self.squeeze_block3 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
431 |
432 | self.res_scale = Scale(1)
433 | self.x_scale = Scale(1)
434 | self.num_rbs = 1
435 |
436 | self.res_conv = res_conv
437 | self.drop_block = drop_block
438 | self.drop_path = drop_path
439 |
440 | def zero_init_last_bn(self):
441 | nn.init.zeros_(self.bn3.weight)
442 |
443 | def forward(self, x):
444 | residual = x
445 |
446 | x = self.squeeze_block(self.trans_block(self.expand_block(self.rb_search1(x)))) + x
447 |
448 | x = self.squeeze_block(self.trans_block1(self.expand_block(self.rb_search1(x)))) + x
449 |
450 | x = self.squeeze_block1(self.trans_block2(self.expand_block1(self.rb_search2(x)))) + x
451 |
452 | x = self.squeeze_block1(self.trans_block3(self.expand_block1(self.rb_search2(x)))) + x
453 |
454 | x = self.squeeze_block2(self.trans_block4(self.expand_block2(self.rb_search3(x)))) + x
455 |
456 | x = self.squeeze_block2(self.trans_block5(self.expand_block2(self.rb_search3(x)))) + x
457 |
458 | x = self.squeeze_block3(self.trans_block6(self.expand_block3(self.rb_search4(x)))) + x
459 |
460 | x = self.squeeze_block3(self.trans_block7(self.expand_block3(self.rb_search4(x)))) + x
461 |
462 | x = self.x_scale(x) + self.res_scale(residual)
463 |
464 | return x
465 |
466 |
467 | class ConvTransBlock(nn.Module):
468 |
469 | def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads, mlp_ratio,
470 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
471 | last_fusion=False, num_med_block=0, groups=1):
472 | super(ConvTransBlock, self).__init__()
473 | expansion = 1
474 | self.cnn_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=1, groups=groups)
475 |
476 | self.dw_stride = dw_stride
477 | self.embed_dim = embed_dim
478 | self.num_med_block = num_med_block
479 | self.last_fusion = last_fusion
480 | self.res_scale = Scale(1)
481 | self.x_scale = Scale(1)
482 |
483 | def forward(self, x):
484 | x = self.cnn_block(x)
485 |
486 | return x
487 |
488 |
489 | class MODEL(nn.Module):
490 | def __init__(self, args, norm_layer=nn.LayerNorm, patch_size=1, window_size=8, num_heads=8, mlp_ratio=1.,
491 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., num_med_block=0, drop_path_rate=0.,
492 | patch_norm=True):
493 | super(MODEL, self).__init__()
494 | scale = args.scale
495 | n_feats = 48
496 | n_colors = 3
497 | embed_dim = 64
498 |
499 | self.patch_norm = patch_norm
500 | self.num_features = embed_dim
501 | rgb_mean = (0.4488, 0.4371, 0.4040)
502 | rgb_std = (1.0, 1.0, 1.0)
503 | self.sub_mean = common.MeanShift(255, rgb_mean, rgb_std)
504 | self.add_mean = common.MeanShift(255, rgb_mean, rgb_std, 1)
505 | #self.conv_first_trans = nn.Conv2d(n_colors, embed_dim, 3, 1, 1)
506 | self.conv_first_cnn = nn.Conv2d(n_colors, n_feats, 3, 1, 1)
507 |
508 | self.trans_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 8)] # stochastic depth decay rule
509 |
510 |         # build stages 2 .. fin_stage-1 (a single ConvTransBlock stage here)
511 | init_stage = 2
512 | fin_stage = 3
513 | stage_1_channel = n_feats
514 | trans_dw_stride = patch_size
515 | for i in range(init_stage, fin_stage):
516 | if i%2==0:
517 | m = i
518 | else:
519 | m = i-1
520 | self.add_module('conv_trans_' + str(m),
521 | ConvTransBlock(
522 | stage_1_channel, stage_1_channel, res_conv=True, stride=1, dw_stride=trans_dw_stride,
523 | embed_dim=embed_dim,
524 | num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
525 | drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
526 | drop_path_rate=self.trans_dpr[i - 1],
527 | num_med_block=num_med_block
528 | )
529 | )
530 |
531 | self.fin_stage = fin_stage
532 | self.dw_stride = trans_dw_stride
533 |
534 | self.conv_after_body = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
535 |
536 | m = []
537 | m.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1))
538 | m.append(nn.PixelShuffle(scale[0]))
539 | self.UP1 = nn.Sequential(*m)
540 |
541 | self.conv_stright = nn.Conv2d(n_colors, n_feats, 3, 1, 1)
542 | up_body = []
543 | up_body.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1))
544 | up_body.append(nn.PixelShuffle(scale[0]))
545 | self.UP2 = nn.Sequential(*up_body)
546 |
547 | self.apply(self._init_weights)
548 |
549 | def _init_weights(self, m):
550 | if isinstance(m, nn.Linear):
551 | trunc_normal_(m.weight, std=.02)
552 | if isinstance(m, nn.Linear) and m.bias is not None:
553 | nn.init.constant_(m.bias, 0)
554 | elif isinstance(m, nn.LayerNorm):
555 | nn.init.constant_(m.bias, 0)
556 | nn.init.constant_(m.weight, 1.0)
557 |
558 | def forward(self, x):
559 | (H, W) = (x.shape[2], x.shape[3])
560 | residual = x
561 | x = self.sub_mean(x)
562 | x = self.conv_first_cnn(x)
563 |
564 | for i in range(2, self.fin_stage):
565 | if i%2==0:
566 | m = i
567 | else:
568 | m = i-1
569 |             x = getattr(self, 'conv_trans_' + str(m))(x)
570 |
571 | x = self.conv_after_body(x)
572 | y1 = self.UP1(x)
573 | y2 = self.UP2(self.conv_stright(residual))
574 | output = self.add_mean(y1 + y2)
575 |
576 | return output
--------------------------------------------------------------------------------
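
A short aside before the x3 variant: inside `RB` above, `MaskPredictor` emits three logits per pixel, `F.gumbel_softmax(..., hard=True, dim=1)` turns them into a hard one-hot choice, and channel 1 of that map gates the features. A minimal sketch with toy sizes (not repository code):

```
import torch
import torch.nn.functional as F

feat = torch.randn(1, 48, 8, 8)                               # features entering RB.forward
logits = torch.nn.Conv2d(48, 3, kernel_size=1, bias=False)(feat)
one_hot = F.gumbel_softmax(logits, tau=1, hard=True, dim=1)   # (1, 3, 8, 8), one-hot per pixel
mask = one_hot[:, 1, ...].unsqueeze(1)                        # (1, 1, 8, 8), values in {0, 1}
gated = feat * mask.expand_as(feat)                           # keeps only the selected pixels
print(int(mask.sum().item()), "of", 8 * 8, "pixels kept")     # roughly a third on average
```

Because `hard=True` uses the straight-through estimator, the mask is binary in the forward pass while gradients still flow through the underlying soft probabilities.
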
/model/CFINx3.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import math
4 | from model import common
5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6 | import torch.nn.functional as F
7 | from pdb import set_trace as stx
8 | import numbers
9 | from einops import rearrange
10 | from torch.nn.parameter import Parameter
11 | from torch.autograd import Variable
12 | #from IPython import embed
13 |
14 | def make_model(args, parent=False):
15 | return MODEL(args)
16 |
17 | class Mlp(nn.Module):
18 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
19 | super().__init__()
20 | out_features = out_features or in_features
21 | hidden_features = hidden_features or in_features
22 | self.fc1 = Conv2d_CG(in_features, hidden_features, kernel_size=1, padding=0)
23 | self.act = act_layer()
24 | self.fc2 = Conv2d_CG(hidden_features, out_features, kernel_size=3, padding=1)
25 | self.drop = nn.Dropout(drop)
26 |
27 | def forward(self, x):
28 | x = self.fc1(x)
29 | x = self.act(x)
30 | x = self.drop(x)
31 | x = self.fc2(x)
32 | x = self.drop(x)
33 | return x
34 |
35 | class PatchEmbed(nn.Module):
36 | def __init__(self, img_size, patch_size=4, in_chans=32, embed_dim=32, norm_layer=None):
37 | super().__init__()
38 | img_size = to_2tuple(img_size)
39 | patch_size = to_2tuple(patch_size)
40 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
41 | self.img_size = img_size
42 | self.patch_size = patch_size
43 | self.patches_resolution = patches_resolution
44 | self.num_patches = patches_resolution[0] * patches_resolution[1]
45 |
46 | self.embed_dim = embed_dim
47 | self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
48 |
49 | if norm_layer is not None: ## norm_layer=None
50 | self.norm = norm_layer(embed_dim)
51 | else:
52 | self.norm = None
53 |
54 | def forward(self, x):
55 | B, C, H, W = x.shape
56 | x = self.proj(x)
57 | return x
58 |
59 | class Conv2d_CG(nn.Conv2d):
60 | def __init__(self, in_channels=64, out_channels=64, kernel_size=1, padding=0, stride=1, dilation=1, groups=1,
61 | bias=True):
62 | super(Conv2d_CG, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
63 |
64 | self.weight_conv = Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.001, requires_grad=True)
65 | self.bias_conv = Parameter(torch.Tensor(out_channels))
66 | nn.init.kaiming_normal_(self.weight_conv)
67 |
68 | self.stride = stride
69 | self.padding = padding
70 | self.dilation = dilation
71 | self.groups = groups
72 |
73 | if kernel_size == 0:
74 | self.ind = True
75 | else:
76 | self.ind = False
77 | self.oc = out_channels
78 | self.ks = kernel_size
79 |
80 | # target spatial size of the pooling layer
81 | ws = kernel_size
82 | self.avg_pool = nn.AdaptiveMaxPool2d((ws, ws))
83 |
84 | # the dimension of latent representation
85 | self.num_lat = int((kernel_size * kernel_size) / 2 + 1)
86 |
87 | # the context encoding module
88 | self.ce = nn.Linear(ws * ws, self.num_lat, False)
89 |
90 | self.act = nn.ReLU()
91 |
92 | # the number of groups in the channel interaction module
93 | if in_channels // 8:
94 | self.g = 8
95 | else:
96 | self.g = in_channels
97 |
98 | # the channel interacting module
99 | self.ci = nn.Linear(self.g, out_channels // (in_channels // self.g), bias=False)
100 |
101 | # the gate decoding module (spatial interaction)
102 | self.gd = nn.Linear(self.num_lat, kernel_size * kernel_size, False)
103 | self.gd2 = nn.Linear(self.num_lat, kernel_size * kernel_size, False)
104 |
105 | # used to prepare the input feature map to patches
106 | self.unfold = nn.Unfold(kernel_size, dilation, padding, stride)
107 |
108 | # sigmoid function
109 | self.sig = nn.Sigmoid()
110 |
111 | def forward(self, x):
112 | if self.ind:
113 | return F.conv2d(x, self.weight_conv, self.bias_conv, self.stride, self.padding, self.dilation, self.groups)
114 | else:
115 | b, c, h, w = x.size() # x: batch x n_feat(=64) x h_patch x w_patch
116 | weight = self.weight_conv
117 |
118 | # allocate global information
119 | gl = self.avg_pool(x).view(b, c, -1) # gl: batch x n_feat x 3 x 3 -> batch x n_feat x 9
120 |
121 | # context-encoding module
122 | out = self.ce(gl) # out: batch x n_feat x 5
123 |
124 |             # keep a copy of the encoding for the two gate-decoding branches below
125 |             ce2 = out                     # ce2: batch x n_feat x 5
126 |             out = self.act(out)           # out: batch x n_feat x 5 (ReLU, not batch norm)
127 |
128 | # gate decoding branch 1 (spatial interaction)
129 | out = self.gd(out) # out: batch x n_feat x 9 (5 --> 9 = 3x3)
130 |
131 | # channel interacting module
132 | if self.g > 3:
133 | oc = self.ci(self.act(ce2.view(b, c // self.g, self.g, -1).transpose(2, 3))).transpose(2,3).contiguous()
134 | else:
135 | oc = self.ci(self.act(ce2.transpose(2, 1))).transpose(2, 1).contiguous()
136 | oc = oc.view(b, self.oc, -1)
137 | oc = self.act(oc) # oc: batch x n_feat x 5 (after grouped linear layer)
138 |
139 | # gate decoding branch 2 (spatial interaction)
140 | oc = self.gd2(oc) # oc: batch x n_feat x 9 (5 --> 9 = 3x3)
141 |
142 | # produce gate (equation (4) in the CRAN paper)
143 | out = self.sig(out.view(b, 1, c, self.ks, self.ks) + oc.view(b, self.oc, 1, self.ks, self.ks))
144 | # out: batch x out_channel x in_channel x kernel_size x kernel_size (same dimension as conv2d weight)
145 |
146 | # unfolding input feature map to patches
147 | x_un = self.unfold(x)
148 | b, _, l = x_un.size()
149 | out = (out * weight.unsqueeze(0))#.to(device)
150 | out = out.view(b, self.oc, -1)
151 |
152 | # currently only handle square input and output
153 | return torch.matmul(out, x_un).view(b, self.oc, h, w)
154 |
155 | class ConvAttention(nn.Module):
156 | def __init__(self, dim, num_heads=8, kernel_size=5, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
157 | super().__init__()
158 | self.num_heads = num_heads
159 | head_dim = dim // num_heads
160 | self.weight = nn.Parameter(torch.randn(num_heads, dim//num_heads, dim//num_heads) * 0.001, requires_grad=True)
161 | self.to_qkv = Conv2d_CG(dim, dim*3)
162 |
163 | def forward(self, x, k1=None, v1=None, return_x=False):
164 | weight = self.weight
165 | b,c,h,w = x.shape
166 |
167 | qkv = self.to_qkv(x)
168 | q, k, v = qkv.chunk(3, dim=1)
169 |
170 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
171 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
172 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
173 |
174 | if k1 is None:
175 | k = k
176 | v = v
177 | else:
178 | k = k1 + k
179 | v = v1 + v
180 | q = torch.nn.functional.normalize(q, dim=-1)
181 | k = torch.nn.functional.normalize(k, dim=-1)
182 |
183 | attn = (q @ k.transpose(-2, -1)) * weight
184 | attn = attn.softmax(dim=-1)
185 | x = (attn @ v)
186 | x = rearrange(x, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
187 |
188 | if return_x:
189 | return x
190 | else:
191 | return x, k, v
192 |
193 |
194 | class WithBias_LayerNorm(nn.Module):
195 | def __init__(self, normalized_shape):
196 | super(WithBias_LayerNorm, self).__init__()
197 | if isinstance(normalized_shape, numbers.Integral):
198 | normalized_shape = (normalized_shape,)
199 | normalized_shape = torch.Size(normalized_shape)
200 |
201 | assert len(normalized_shape) == 1
202 |
203 | self.weight = nn.Parameter(torch.ones(normalized_shape))
204 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
205 | self.normalized_shape = normalized_shape
206 |
207 | def forward(self, x):
208 | mu = x.mean(-1, keepdim=True)
209 | sigma = x.var(-1, keepdim=True, unbiased=False)
210 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
211 |
212 |
213 | def to_3d(x):
214 | return rearrange(x, 'b c h w -> b (h w) c')
215 |
216 | def to_4d(x,h,w):
217 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
218 |
219 | class LayerNorm(nn.Module):
220 | def __init__(self, dim):
221 | super(LayerNorm, self).__init__()
222 | self.body = WithBias_LayerNorm(dim)
223 |
224 | def forward(self, x):
225 | h, w = x.shape[-2:]
226 | return to_4d(self.body(to_3d(x)), h, w)
227 |
228 |
229 | class Block(nn.Module):
230 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
231 | drop_path=0., act_layer=nn.GELU):
232 | super().__init__()
233 | self.norm1 = LayerNorm(dim)
234 | kernel_size1 = 1
235 | padding1 = 0
236 | kernel_size2 = 3
237 | padding2 = 1
238 | self.attn = ConvAttention(dim, num_heads, kernel_size1, padding1)
239 | self.attn1 = ConvAttention(dim, num_heads, kernel_size2, padding2)
240 |
241 | self.norm2 = LayerNorm(dim)
242 | self.norm3 = LayerNorm(dim)
243 | mlp_hidden_dim = int(dim*1)
244 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim)
245 |
246 | def forward(self, x):
247 | res = x
248 | x, k1, v1 = self.attn(x)
249 | x = res + self.norm1(x)
250 | x = x + self.norm2(self.attn1(x, k1, v1, return_x=True))
251 | x = x + self.norm3(self.mlp(x))
252 | return x
253 |
254 |
255 | class Scale(nn.Module):
256 | def __init__(self, init_value=1e-3):
257 | super().__init__()
258 | self.scale = nn.Parameter(torch.FloatTensor([init_value]))
259 |
260 | def forward(self, input):
261 | return input * self.scale
262 |
263 |
264 | def activation(act_type, inplace=False, neg_slope=0.05, n_prelu=1):
265 | act_type = act_type.lower()
266 | if act_type == 'relu':
267 | layer = nn.ReLU()
268 | elif act_type == 'lrelu':
269 | layer = nn.LeakyReLU(neg_slope)
270 | elif act_type == 'prelu':
271 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
272 | else:
273 | raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
274 | return layer
275 |
276 |
277 | class eca_layer(nn.Module):
278 | def __init__(self, channel, k_size):
279 | super(eca_layer, self).__init__()
280 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
281 | self.k_size = k_size
282 | self.conv = nn.Conv1d(channel, channel, kernel_size=k_size, bias=False, groups=channel)
283 | self.sigmoid = nn.Sigmoid()
284 |
285 |
286 | def forward(self, x):
287 | b, c, _, _ = x.size()
288 | y = self.avg_pool(x)
289 | y = nn.functional.unfold(y.transpose(-1, -3), kernel_size=(1, self.k_size), padding=(0, (self.k_size - 1) // 2))
290 | y = self.conv(y.transpose(-1, -2)).unsqueeze(-1)
291 | y = self.sigmoid(y)
292 | x = x * y.expand_as(x)
293 | return x
294 |
295 |
296 | class MaskPredictor(nn.Module):
297 | def __init__(self,in_channels, wn=lambda x: torch.nn.utils.weight_norm(x)):
298 | super(MaskPredictor,self).__init__()
299 | self.spatial_mask=nn.Conv2d(in_channels=in_channels,out_channels=3,kernel_size=1,bias=False)
300 |
301 | def forward(self,x):
302 | spa_mask=self.spatial_mask(x)
303 | spa_mask=F.gumbel_softmax(spa_mask,tau=1,hard=True,dim=1)
304 | return spa_mask
305 |
306 |
307 | class RB(nn.Module):
308 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)):
309 | super(RB, self).__init__()
310 | self.CA = eca_layer(n_feats, k_size=3)
311 | self.MaskPredictor = MaskPredictor(n_feats*8//8)
312 |
313 | self.k = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)),
314 | nn.LeakyReLU(0.05),
315 | )
316 |
317 | self.k1 = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)),
318 | nn.LeakyReLU(0.05),
319 | )
320 |
321 | self.res_scale = Scale(1)
322 | self.x_scale = Scale(1)
323 |
324 | def forward(self, x):
325 | res = x
326 | x = self.k(x)
327 |
328 | MaskPredictor = self.MaskPredictor(x)
329 | mask = (MaskPredictor[:,1,...]).unsqueeze(1)
330 | x = x * (mask.expand_as(x))
331 |
332 | x1 = self.k1(x)
333 | x2 = self.CA(x1)
334 | out = self.x_scale(x2) + self.res_scale(res)
335 |
336 | return out
337 |
338 |
339 | class SCConv(nn.Module):
340 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)):
341 | super(SCConv, self).__init__()
342 | pooling_r = 2
343 | med_feats = n_feats // 1
344 |
345 | self.k1 = nn.Sequential(nn.ConvTranspose2d(n_feats, n_feats*4//3, kernel_size=pooling_r, stride=pooling_r, padding=0, groups=1, bias=True),
346 | nn.LeakyReLU(0.05),
347 | nn.Conv2d(n_feats*4//3, n_feats, kernel_size=1, stride=2, padding=0, groups=1),
348 | )
349 |
350 | self.sig = nn.Sigmoid()
351 |
352 | self.k3 = RB(n_feats)
353 |
354 | self.k4 = RB(n_feats)
355 |
356 | self.k5 = RB(n_feats)
357 |
358 | self.res_scale = Scale(1)
359 | self.x_scale = Scale(1)
360 |
361 | def forward(self, x):
362 | identity = x
363 | _, _, H, W = identity.shape
364 | x1_1 = self.k3(x)
365 | x1 = self.k4(x1_1)
366 |
367 |
368 | x1_s = self.sig(self.k1(x) + x)
369 | x1 = self.k5(x1_s * x1)
370 |
371 | out = self.res_scale(x1) + self.x_scale(identity)
372 |
373 | return out
374 |
375 |
376 | class FCUUp(nn.Module):
377 | def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU,
378 | norm_layer=nn.BatchNorm2d, wn=lambda x: torch.nn.utils.weight_norm(x)):
379 | super(FCUUp, self).__init__()
380 | self.up_stride = up_stride
381 | self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)
382 | self.act = act_layer()
383 |
384 | def forward(self, x_t):
385 | x_r = self.act(self.conv_project(x_t))
386 |
387 | return x_r
388 |
389 | class FCUDown(nn.Module):
390 | def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU,
391 | norm_layer=nn.LayerNorm, wn=lambda x: torch.nn.utils.weight_norm(x)):
392 | super(FCUDown, self).__init__()
393 | self.conv_project = wn(nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0))
394 |
395 | def forward(self, x):
396 | x = self.conv_project(x)
397 |
398 | return x
399 |
400 |
401 | class ConvBlock(nn.Module):
402 | def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1, norm_layer=nn.BatchNorm2d, drop_block=None, drop_path=None):
403 | super(ConvBlock, self).__init__()
404 |
405 | expansion = 1
406 | med_planes = outplanes // expansion
407 | embed_dim = 144
408 | num_heads = 8
409 | mlp_ratio = 1.0
410 |
411 | self.rb_search1 = SCConv(med_planes)
412 | self.rb_search2 = SCConv(med_planes)
413 | self.rb_search3 = SCConv(med_planes)
414 | self.rb_search4 = SCConv(med_planes)
415 |
416 | self.trans_block = Block(
417 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
418 | drop=0., attn_drop=0., drop_path=0.)
419 |
420 | self.trans_block1 = Block(
421 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
422 | drop=0., attn_drop=0., drop_path=0.)
423 |
424 | self.trans_block2 = Block(
425 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
426 | drop=0., attn_drop=0., drop_path=0.)
427 |
428 | self.trans_block3 = Block(
429 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
430 | drop=0., attn_drop=0., drop_path=0.)
431 |
432 | self.trans_block4 = Block(
433 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
434 | drop=0., attn_drop=0., drop_path=0.)
435 |
436 | self.trans_block5 = Block(
437 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
438 | drop=0., attn_drop=0., drop_path=0.)
439 |
440 | self.trans_block6 = Block(
441 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
442 | drop=0., attn_drop=0., drop_path=0.)
443 |
444 | self.trans_block7 = Block(
445 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
446 | drop=0., attn_drop=0., drop_path=0.)
447 |
448 | self.expand_block = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
449 | self.squeeze_block = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
450 | self.expand_block1 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
451 | self.squeeze_block1 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
452 | self.expand_block2 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
453 | self.squeeze_block2 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
454 | self.expand_block3 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
455 | self.squeeze_block3 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
456 |
457 | self.res_scale = Scale(1)
458 | self.x_scale = Scale(1)
459 | self.num_rbs = 1
460 |
461 | self.res_conv = res_conv
462 | self.drop_block = drop_block
463 | self.drop_path = drop_path
464 |
465 | def zero_init_last_bn(self):
466 | nn.init.zeros_(self.bn3.weight)
467 |
468 | def forward(self, x):
469 | residual = x
470 |
471 | x = self.squeeze_block(self.trans_block(self.expand_block(self.rb_search1(x)))) + x
472 |
473 | x = self.squeeze_block(self.trans_block1(self.expand_block(self.rb_search1(x)))) + x
474 |
475 | x = self.squeeze_block1(self.trans_block2(self.expand_block1(self.rb_search2(x)))) + x
476 |
477 | x = self.squeeze_block1(self.trans_block3(self.expand_block1(self.rb_search2(x)))) + x
478 |
479 | x = self.squeeze_block2(self.trans_block4(self.expand_block2(self.rb_search3(x)))) + x
480 |
481 | x = self.squeeze_block2(self.trans_block5(self.expand_block2(self.rb_search3(x)))) + x
482 |
483 | x = self.squeeze_block3(self.trans_block6(self.expand_block3(self.rb_search4(x)))) + x
484 |
485 | x = self.squeeze_block3(self.trans_block7(self.expand_block3(self.rb_search4(x)))) + x
486 |
487 | x = self.x_scale(x) + self.res_scale(residual)
488 |
489 | return x
490 |
491 |
492 | class ConvTransBlock(nn.Module):
493 | """
494 |     Basic module for ConvTransformer: keeps feature maps for the CNN block and patch embeddings for the transformer encoder block.
495 | """
496 |
497 | def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads, mlp_ratio,
498 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
499 | last_fusion=False, num_med_block=0, groups=1):
500 | super(ConvTransBlock, self).__init__()
501 | expansion = 1
502 | self.cnn_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=1, groups=groups)
503 |
504 | self.dw_stride = dw_stride
505 | self.embed_dim = embed_dim
506 | self.num_med_block = num_med_block
507 | self.last_fusion = last_fusion
508 | self.res_scale = Scale(1)
509 | self.x_scale = Scale(1)
510 |
511 | def forward(self, x):
512 | x = self.cnn_block(x)
513 |
514 | return x
515 |
516 |
517 | class MODEL(nn.Module):
518 | def __init__(self, args, norm_layer=nn.LayerNorm, patch_size=1, window_size=8, num_heads=8, mlp_ratio=1.,
519 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., num_med_block=0, drop_path_rate=0.,
520 | patch_norm=True):
521 | super(MODEL, self).__init__()
522 | scale = args.scale
523 | n_feats = 48
524 | n_colors = 3
525 | embed_dim = 64
526 |
527 | self.patch_norm = patch_norm
528 | self.num_features = embed_dim
529 | rgb_mean = (0.4488, 0.4371, 0.4040)
530 | rgb_std = (1.0, 1.0, 1.0)
531 | self.sub_mean = common.MeanShift(255, rgb_mean, rgb_std)
532 | self.add_mean = common.MeanShift(255, rgb_mean, rgb_std, 1)
533 | self.conv_first_cnn = nn.Conv2d(n_colors, n_feats, 3, 1, 1)
534 |
535 | self.trans_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 8)] # stochastic depth decay rule
536 |
537 | init_stage = 2
538 | fin_stage = 3
539 | stage_1_channel = n_feats
540 | trans_dw_stride = patch_size
541 | for i in range(init_stage, fin_stage):
542 | if i%2==0:
543 | m = i
544 | else:
545 | m = i-1
546 | self.add_module('conv_trans_' + str(m),
547 | ConvTransBlock(
548 | stage_1_channel, stage_1_channel, res_conv=True, stride=1, dw_stride=trans_dw_stride,
549 | embed_dim=embed_dim,
550 | num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
551 | drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
552 | drop_path_rate=self.trans_dpr[i - 1],
553 | num_med_block=num_med_block
554 | )
555 | )
556 |
557 | self.fin_stage = fin_stage
558 | self.dw_stride = trans_dw_stride
559 |
560 | self.conv_after_body = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
561 |
562 | m = []
563 | m.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1))
564 | m.append(nn.PixelShuffle(scale[0]))
565 | self.UP1 = nn.Sequential(*m)
566 |
567 | self.conv_stright = nn.Conv2d(n_colors, n_feats, 3, 1, 1)
568 | up_body = []
569 | up_body.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1))
570 | up_body.append(nn.PixelShuffle(scale[0]))
571 | self.UP2 = nn.Sequential(*up_body)
572 |
573 | self.apply(self._init_weights)
574 |
575 | def _init_weights(self, m):
576 | if isinstance(m, nn.Linear):
577 | trunc_normal_(m.weight, std=.02)
578 | if isinstance(m, nn.Linear) and m.bias is not None:
579 | nn.init.constant_(m.bias, 0)
580 | elif isinstance(m, nn.LayerNorm):
581 | nn.init.constant_(m.bias, 0)
582 | nn.init.constant_(m.weight, 1.0)
583 |
584 | def forward(self, x):
585 | (H, W) = (x.shape[2], x.shape[3])
586 | residual = x
587 | x = self.sub_mean(x)
588 | x = self.conv_first_cnn(x)
589 |
590 | for i in range(2, self.fin_stage):
591 | if i%2==0:
592 | m = i
593 | else:
594 | m = i-1
595 |             x = getattr(self, 'conv_trans_' + str(m))(x)
596 |
597 | x = self.conv_after_body(x)
598 | y1 = self.UP1(x)
599 | y2 = self.UP2(self.conv_stright(residual))
600 | output = self.add_mean(y1 + y2)
601 |
602 | return output
--------------------------------------------------------------------------------
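
One more aside before the x4 variant: `ConvAttention` in these files attends across channels rather than pixels. After `rearrange`, `q`, `k`, and `v` have shape `(b, head, c, h*w)`; `q` and `k` are L2-normalized along the spatial axis, and the resulting `c x c` map is scaled by a small learned per-head weight. A shape-only sketch with illustrative sizes (not repository code):

```
import torch
import torch.nn.functional as F
from einops import rearrange

b, heads, dim, h, w = 1, 4, 32, 16, 16
qkv = torch.randn(b, dim * 3, h, w)                   # stands in for Conv2d_CG(dim, dim*3)(x)
q, k, v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=heads)

q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
scale = torch.randn(heads, dim // heads, dim // heads) * 0.001   # plays the role of self.weight
attn = ((q @ k.transpose(-2, -1)) * scale).softmax(dim=-1)       # (b, head, c, c)
out = rearrange(attn @ v, 'b head c (h w) -> b (head c) h w', h=h, w=w)
print(out.shape)                                      # torch.Size([1, 32, 16, 16])
```

Since the attention map is `c x c` rather than `(h*w) x (h*w)`, its cost grows with the channel count instead of the image size, which keeps the transformer branch affordable at SR feature resolutions.
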
/model/CFINx4.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import math
4 | from model import common
5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6 | import torch.nn.functional as F
7 | from pdb import set_trace as stx
8 | import numbers
9 | from einops import rearrange
10 | from torch.nn.parameter import Parameter
11 | from torch.autograd import Variable
12 | #from IPython import embed
13 |
14 | def make_model(args, parent=False):
15 | return MODEL(args)
16 |
17 | class Mlp(nn.Module):
18 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
19 | super().__init__()
20 | out_features = out_features or in_features
21 | hidden_features = hidden_features or in_features
22 | self.fc1 = Conv2d_CG(in_features, hidden_features, kernel_size=1, padding=0)
23 | self.act = act_layer()
24 | self.fc2 = Conv2d_CG(hidden_features, out_features, kernel_size=3, padding=1)
25 | self.drop = nn.Dropout(drop)
26 |
27 | def forward(self, x):
28 | x = self.fc1(x)
29 | x = self.act(x)
30 | x = self.drop(x)
31 | x = self.fc2(x)
32 | x = self.drop(x)
33 | return x
34 |
35 | class Conv2d_CG(nn.Conv2d):
36 | def __init__(self, in_channels=64, out_channels=64, kernel_size=1, padding=0, stride=1, dilation=1, groups=1,
37 | bias=True):
38 | super(Conv2d_CG, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
39 |
40 | self.weight_conv = Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.001, requires_grad=True)
41 | self.bias_conv = Parameter(torch.Tensor(out_channels))
42 | nn.init.kaiming_normal_(self.weight_conv)
43 |
44 | self.stride = stride
45 | self.padding = padding
46 | self.dilation = dilation
47 | self.groups = groups
48 |
49 | if kernel_size == 0:
50 | self.ind = True
51 | else:
52 | self.ind = False
53 | self.oc = out_channels
54 | self.ks = kernel_size
55 |
56 | # target spatial size of the pooling layer
57 | ws = kernel_size
58 | self.avg_pool = nn.AdaptiveMaxPool2d((ws, ws))
59 |
60 | # the dimension of latent representation
61 | self.num_lat = int((kernel_size * kernel_size) / 2 + 1)
62 |
63 | # the context encoding module
64 | self.ce = nn.Linear(ws * ws, self.num_lat, False)
65 |
66 | self.act = nn.ReLU()
67 |
68 | # the number of groups in the channel interaction module
69 | if in_channels // 8:
70 | self.g = 8
71 | else:
72 | self.g = in_channels
73 |
74 | # the channel interacting module
75 | self.ci = nn.Linear(self.g, out_channels // (in_channels // self.g), bias=False)
76 |
77 | # the gate decoding module (spatial interaction)
78 | self.gd = nn.Linear(self.num_lat, kernel_size * kernel_size, False)
79 | self.gd2 = nn.Linear(self.num_lat, kernel_size * kernel_size, False)
80 |
81 | # used to prepare the input feature map to patches
82 | self.unfold = nn.Unfold(kernel_size, dilation, padding, stride)
83 |
84 | # sigmoid function
85 | self.sig = nn.Sigmoid()
86 |
87 | def forward(self, x):
88 | if self.ind:
89 | return F.conv2d(x, self.weight_conv, self.bias_conv, self.stride, self.padding, self.dilation, self.groups)
90 | else:
91 | b, c, h, w = x.size() # x: batch x n_feat(=64) x h_patch x w_patch
92 | weight = self.weight_conv
93 |
94 | # allocate global information
95 | gl = self.avg_pool(x).view(b, c, -1) # gl: batch x n_feat x 3 x 3 -> batch x n_feat x 9
96 |
97 | # context-encoding module
98 | out = self.ce(gl) # out: batch x n_feat x 5
99 |
100 |             # keep a copy of the encoding for the two gate-decoding branches below
101 |             ce2 = out                     # ce2: batch x n_feat x 5
102 |             out = self.act(out)           # out: batch x n_feat x 5 (ReLU, not batch norm)
103 |
104 | # gate decoding branch 1 (spatial interaction)
105 | out = self.gd(out) # out: batch x n_feat x 9 (5 --> 9 = 3x3)
106 |
107 | # channel interacting module
108 | if self.g > 3:
109 | oc = self.ci(self.act(ce2.view(b, c // self.g, self.g, -1).transpose(2, 3))).transpose(2,3).contiguous()
110 | else:
111 | oc = self.ci(self.act(ce2.transpose(2, 1))).transpose(2, 1).contiguous()
112 | oc = oc.view(b, self.oc, -1)
113 | oc = self.act(oc) # oc: batch x n_feat x 5 (after grouped linear layer)
114 |
115 | # gate decoding branch 2 (spatial interaction)
116 | oc = self.gd2(oc) # oc: batch x n_feat x 9 (5 --> 9 = 3x3)
117 |
118 | # produce gate (equation (4) in the CRAN paper)
119 | out = self.sig(out.view(b, 1, c, self.ks, self.ks) + oc.view(b, self.oc, 1, self.ks, self.ks))
120 | # out: batch x out_channel x in_channel x kernel_size x kernel_size (same dimension as conv2d weight)
121 |
122 | # unfolding input feature map to patches
123 | x_un = self.unfold(x)
124 | b, _, l = x_un.size()
125 | out = (out * weight.unsqueeze(0))#.to(device)
126 | out = out.view(b, self.oc, -1)
127 |
128 | # currently only handle square input and output
129 | return torch.matmul(out, x_un).view(b, self.oc, h, w)
130 |
131 | class ConvAttention(nn.Module):
132 | def __init__(self, dim, num_heads=8, kernel_size=5, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
133 | super().__init__()
134 | self.num_heads = num_heads
135 | head_dim = dim // num_heads
136 | self.weight = nn.Parameter(torch.randn(num_heads, dim//num_heads, dim//num_heads) * 0.001, requires_grad=True)
137 | self.to_qkv = Conv2d_CG(dim, dim*3)
138 |
139 | def forward(self, x, k1=None, v1=None, return_x=False):
140 | weight = self.weight
141 | b,c,h,w = x.shape
142 |
143 | qkv = self.to_qkv(x)
144 | q, k, v = qkv.chunk(3, dim=1)
145 |
146 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
147 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
148 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
149 |
150 | if k1 is None:
151 | k = k
152 | v = v
153 | else:
154 | k = k1 + k
155 | v = v1 + v
156 | q = torch.nn.functional.normalize(q, dim=-1)
157 | k = torch.nn.functional.normalize(k, dim=-1)
158 |
159 | attn = (q @ k.transpose(-2, -1)) * weight
160 | attn = attn.softmax(dim=-1)
161 | x = (attn @ v)
162 | x = rearrange(x, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
163 |
164 | if return_x:
165 | return x
166 | else:
167 | return x, k, v
168 |
169 |
170 | class WithBias_LayerNorm(nn.Module):
171 | def __init__(self, normalized_shape):
172 | super(WithBias_LayerNorm, self).__init__()
173 | if isinstance(normalized_shape, numbers.Integral):
174 | normalized_shape = (normalized_shape,)
175 | normalized_shape = torch.Size(normalized_shape)
176 |
177 | assert len(normalized_shape) == 1
178 |
179 | self.weight = nn.Parameter(torch.ones(normalized_shape))
180 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
181 | self.normalized_shape = normalized_shape
182 |
183 | def forward(self, x):
184 | mu = x.mean(-1, keepdim=True)
185 | sigma = x.var(-1, keepdim=True, unbiased=False)
186 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
187 |
188 |
189 | def to_3d(x):
190 | return rearrange(x, 'b c h w -> b (h w) c')
191 |
192 | def to_4d(x,h,w):
193 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
194 |
195 | class LayerNorm(nn.Module):
196 | def __init__(self, dim):
197 | super(LayerNorm, self).__init__()
198 | self.body = WithBias_LayerNorm(dim)
199 |
200 | def forward(self, x):
201 | h, w = x.shape[-2:]
202 | return to_4d(self.body(to_3d(x)), h, w)
203 |
204 |
205 | class Block(nn.Module):
206 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
207 | drop_path=0., act_layer=nn.GELU):
208 | super().__init__()
209 | self.norm1 = LayerNorm(dim)
210 | kernel_size1 = 1
211 | padding1 = 0
212 | kernel_size2 = 3
213 | padding2 = 1
214 | self.attn = ConvAttention(dim, num_heads, kernel_size1, padding1)
215 | self.attn1 = ConvAttention(dim, num_heads, kernel_size2, padding2)
216 |
217 | self.norm2 = LayerNorm(dim)
218 | self.norm3 = LayerNorm(dim)
219 | mlp_hidden_dim = int(dim*1)
220 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim)
221 |
222 | def forward(self, x):
223 | res = x
224 | x, k1, v1 = self.attn(x)
225 | x = res + self.norm1(x)
226 | x = x + self.norm2(self.attn1(x, k1, v1, return_x=True))
227 | x = x + self.norm3(self.mlp(x))
228 | return x
229 |
230 |
231 | class Scale(nn.Module):
232 | def __init__(self, init_value=1e-3):
233 | super().__init__()
234 | self.scale = nn.Parameter(torch.FloatTensor([init_value]))
235 |
236 | def forward(self, input):
237 | return input * self.scale
238 |
239 |
240 | def activation(act_type, inplace=False, neg_slope=0.05, n_prelu=1):
241 | act_type = act_type.lower()
242 | if act_type == 'relu':
243 | layer = nn.ReLU()
244 | elif act_type == 'lrelu':
245 | layer = nn.LeakyReLU(neg_slope)
246 | elif act_type == 'prelu':
247 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
248 | else:
249 | raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
250 | return layer
251 |
252 |
253 | class eca_layer(nn.Module):
254 | def __init__(self, channel, k_size):
255 | super(eca_layer, self).__init__()
256 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
257 | self.k_size = k_size
258 | self.conv = nn.Conv1d(channel, channel, kernel_size=k_size, bias=False, groups=channel)
259 | self.sigmoid = nn.Sigmoid()
260 |
261 |
262 | def forward(self, x):
263 | b, c, _, _ = x.size()
264 | y = self.avg_pool(x)
265 | y = nn.functional.unfold(y.transpose(-1, -3), kernel_size=(1, self.k_size), padding=(0, (self.k_size - 1) // 2))
266 | y = self.conv(y.transpose(-1, -2)).unsqueeze(-1)
267 | y = self.sigmoid(y)
268 | x = x * y.expand_as(x)
269 | return x
270 |
271 |
272 | class MaskPredictor(nn.Module):
273 | def __init__(self,in_channels, wn=lambda x: torch.nn.utils.weight_norm(x)):
274 | super(MaskPredictor,self).__init__()
275 | self.spatial_mask=nn.Conv2d(in_channels=in_channels,out_channels=3,kernel_size=1,bias=False)
276 |
277 | def forward(self,x):
278 | spa_mask=self.spatial_mask(x)
279 | spa_mask=F.gumbel_softmax(spa_mask,tau=1,hard=True,dim=1)
280 | return spa_mask
281 |
282 |
283 | class RB(nn.Module):
284 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)):
285 | super(RB, self).__init__()
286 | self.CA = eca_layer(n_feats, k_size=3)
287 | self.MaskPredictor = MaskPredictor(n_feats*8//8)
288 |
289 | self.k = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)),
290 | nn.LeakyReLU(0.05),
291 | )
292 |
293 | self.k1 = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)),
294 | nn.LeakyReLU(0.05),
295 | )
296 |
297 | self.res_scale = Scale(1)
298 | self.x_scale = Scale(1)
299 |
300 | def forward(self, x):
301 | res = x
302 | x = self.k(x)
303 |
304 | MaskPredictor = self.MaskPredictor(x)
305 | mask = (MaskPredictor[:,1,...]).unsqueeze(1)
306 | x = x * (mask.expand_as(x))
307 |
308 | x1 = self.k1(x)
309 | x2 = self.CA(x1)
310 | out = self.x_scale(x2) + self.res_scale(res)
311 |
312 | return out
313 |
314 |
315 | class SCConv(nn.Module):
316 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)):
317 | super(SCConv, self).__init__()
318 | pooling_r = 2
319 | med_feats = n_feats // 1
320 | self.k1 = nn.Sequential(nn.ConvTranspose2d(n_feats, n_feats*4//3, kernel_size=pooling_r, stride=pooling_r, padding=0, groups=1, bias=True),
321 | nn.LeakyReLU(0.05),
322 | nn.Conv2d(n_feats*4//3, n_feats, kernel_size=1, stride=2, padding=0, groups=1),
323 | )
324 |
325 | self.sig = nn.Sigmoid()
326 |
327 | self.k3 = RB(n_feats)
328 |
329 | self.k4 = RB(n_feats)
330 |
331 | self.k5 = RB(n_feats)
332 |
333 | self.res_scale = Scale(1)
334 | self.x_scale = Scale(1)
335 |
336 | def forward(self, x):
337 | identity = x
338 | _, _, H, W = identity.shape
339 | x1_1 = self.k3(x)
340 | x1 = self.k4(x1_1)
341 |
342 |
343 | x1_s = self.sig(self.k1(x) + x)
344 | x1 = self.k5(x1_s * x1)
345 |
346 | out = self.res_scale(x1) + self.x_scale(identity)
347 |
348 | return out
349 |
350 |
351 | class FCUUp(nn.Module):
352 | def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU,
353 | norm_layer=nn.BatchNorm2d, wn=lambda x: torch.nn.utils.weight_norm(x)):
354 | super(FCUUp, self).__init__()
355 | self.up_stride = up_stride
356 | self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)
357 | self.act = act_layer()
358 |
359 | def forward(self, x_t):
360 | x_r = self.act(self.conv_project(x_t))
361 |
362 | return x_r
363 |
364 | class FCUDown(nn.Module):
365 | def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU,
366 | norm_layer=nn.LayerNorm, wn=lambda x: torch.nn.utils.weight_norm(x)):
367 | super(FCUDown, self).__init__()
368 | self.conv_project = wn(nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0))
369 |
370 | def forward(self, x):
371 | x = self.conv_project(x)
372 |
373 | return x
374 |
375 |
376 | class ConvBlock(nn.Module):
377 | def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1, norm_layer=nn.BatchNorm2d, drop_block=None, drop_path=None):
378 | super(ConvBlock, self).__init__()
379 |
380 | expansion = 1
381 | med_planes = outplanes // expansion
382 | embed_dim = 144
383 | num_heads = 8
384 | mlp_ratio = 1.0
385 |
386 | self.rb_search1 = SCConv(med_planes)
387 | self.rb_search2 = SCConv(med_planes)
388 | self.rb_search3 = SCConv(med_planes)
389 | self.rb_search4 = SCConv(med_planes)
390 |
391 | self.trans_block = Block(
392 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
393 | drop=0., attn_drop=0., drop_path=0.)
394 |
395 | self.trans_block1 = Block(
396 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
397 | drop=0., attn_drop=0., drop_path=0.)
398 |
399 | self.trans_block2 = Block(
400 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
401 | drop=0., attn_drop=0., drop_path=0.)
402 |
403 | self.trans_block3 = Block(
404 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
405 | drop=0., attn_drop=0., drop_path=0.)
406 |
407 | self.trans_block4 = Block(
408 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
409 | drop=0., attn_drop=0., drop_path=0.)
410 |
411 | self.trans_block5 = Block(
412 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
413 | drop=0., attn_drop=0., drop_path=0.)
414 |
415 | self.trans_block6 = Block(
416 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
417 | drop=0., attn_drop=0., drop_path=0.)
418 |
419 | self.trans_block7 = Block(
420 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
421 | drop=0., attn_drop=0., drop_path=0.)
422 |
423 | self.expand_block = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
424 | self.squeeze_block = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
425 | self.expand_block1 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
426 | self.squeeze_block1 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
427 | self.expand_block2 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
428 | self.squeeze_block2 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
429 | self.expand_block3 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
430 | self.squeeze_block3 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
431 |
432 | self.res_scale = Scale(1)
433 | self.x_scale = Scale(1)
434 | self.num_rbs = 1
435 |
436 | self.res_conv = res_conv
437 | self.drop_block = drop_block
438 | self.drop_path = drop_path
439 |
440 | def zero_init_last_bn(self):
441 |         nn.init.zeros_(self.bn3.weight)  # note: self.bn3 is never defined in this block, so this helper is unused
442 |
443 | def forward(self, x):
444 | residual = x
445 |
446 | x = self.squeeze_block(self.trans_block(self.expand_block(self.rb_search1(x)))) + x
447 |
448 | x = self.squeeze_block(self.trans_block1(self.expand_block(self.rb_search1(x)))) + x
449 |
450 | x = self.squeeze_block1(self.trans_block2(self.expand_block1(self.rb_search2(x)))) + x
451 |
452 | x = self.squeeze_block1(self.trans_block3(self.expand_block1(self.rb_search2(x)))) + x
453 |
454 | x = self.squeeze_block2(self.trans_block4(self.expand_block2(self.rb_search3(x)))) + x
455 |
456 | x = self.squeeze_block2(self.trans_block5(self.expand_block2(self.rb_search3(x)))) + x
457 |
458 | x = self.squeeze_block3(self.trans_block6(self.expand_block3(self.rb_search4(x)))) + x
459 |
460 | x = self.squeeze_block3(self.trans_block7(self.expand_block3(self.rb_search4(x)))) + x
461 |
462 | x = self.x_scale(x) + self.res_scale(residual)
463 |
464 | return x
465 |
466 |
467 | class ConvTransBlock(nn.Module):
468 | """
469 |     Basic module for the ConvTransformer: keeps feature maps for the CNN block and patch embeddings for the transformer encoder block.
470 | """
471 |
472 | def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads, mlp_ratio,
473 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
474 | last_fusion=False, num_med_block=0, groups=1):
475 | super(ConvTransBlock, self).__init__()
476 | expansion = 1
477 | self.cnn_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=1, groups=groups)
478 |
479 | self.dw_stride = dw_stride
480 | self.embed_dim = embed_dim
481 | self.num_med_block = num_med_block
482 | self.last_fusion = last_fusion
483 | self.res_scale = Scale(1)
484 | self.x_scale = Scale(1)
485 |
486 | def forward(self, x):
487 | x = self.cnn_block(x)
488 |
489 | return x
490 |
491 |
492 | class MODEL(nn.Module):
493 | def __init__(self, args, norm_layer=nn.LayerNorm, patch_size=1, window_size=8, num_heads=8, mlp_ratio=1.,
494 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., num_med_block=0, drop_path_rate=0.,
495 | patch_norm=True):
496 | super(MODEL, self).__init__()
497 | scale = args.scale
498 | n_feats = 48
499 | n_colors = 3
500 | embed_dim = 64
501 |
502 | self.patch_norm = patch_norm
503 | self.num_features = embed_dim
504 | rgb_mean = (0.4488, 0.4371, 0.4040)
505 | rgb_std = (1.0, 1.0, 1.0)
506 | self.sub_mean = common.MeanShift(255, rgb_mean, rgb_std)
507 | self.add_mean = common.MeanShift(255, rgb_mean, rgb_std, 1)
508 | #self.conv_first_trans = nn.Conv2d(n_colors, embed_dim, 3, 1, 1)
509 | self.conv_first_cnn = nn.Conv2d(n_colors, n_feats, 3, 1, 1)
510 |
511 | self.trans_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 8)] # stochastic depth decay rule
512 |
513 | # 2~final Stage
514 | init_stage = 2
515 | fin_stage = 3
516 | stage_1_channel = n_feats
517 | trans_dw_stride = patch_size
518 | for i in range(init_stage, fin_stage):
519 | if i%2==0:
520 | m = i
521 | else:
522 | m = i-1
523 | self.add_module('conv_trans_' + str(m),
524 | ConvTransBlock(
525 | stage_1_channel, stage_1_channel, res_conv=True, stride=1, dw_stride=trans_dw_stride,
526 | embed_dim=embed_dim,
527 | num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
528 | drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
529 | drop_path_rate=self.trans_dpr[i - 1],
530 | num_med_block=num_med_block
531 | )
532 | )
533 |
534 | self.fin_stage = fin_stage
535 | self.dw_stride = trans_dw_stride
536 |
537 | self.conv_after_body = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
538 |
539 | m = []
540 | m.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1))
541 | m.append(nn.PixelShuffle(scale[0]))
542 | self.UP1 = nn.Sequential(*m)
543 |
544 | self.conv_stright = nn.Conv2d(n_colors, n_feats, 3, 1, 1)
545 | up_body = []
546 | up_body.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1))
547 | up_body.append(nn.PixelShuffle(scale[0]))
548 | self.UP2 = nn.Sequential(*up_body)
549 |
550 | self.apply(self._init_weights)
551 |
552 | def _init_weights(self, m):
553 | if isinstance(m, nn.Linear):
554 | trunc_normal_(m.weight, std=.02)
555 | if isinstance(m, nn.Linear) and m.bias is not None:
556 | nn.init.constant_(m.bias, 0)
557 | elif isinstance(m, nn.LayerNorm):
558 | nn.init.constant_(m.bias, 0)
559 | nn.init.constant_(m.weight, 1.0)
560 |
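   | # Reconstruction: UP1 pixel-shuffles the deep features while UP2 pixel-shuffles a
   | # shallow convolution of the original input (global residual branch); their sum is
   | # de-normalized by add_mean.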
561 | def forward(self, x):
562 | (H, W) = (x.shape[2], x.shape[3])
563 | residual = x
564 | x = self.sub_mean(x)
565 | x = self.conv_first_cnn(x)
566 |
567 | for i in range(2, self.fin_stage):
568 | if i%2==0:
569 | m = i
570 | else:
571 | m = i-1
572 |             x = getattr(self, 'conv_trans_' + str(m))(x)
573 |
574 | x = self.conv_after_body(x)
575 | y1 = self.UP1(x)
576 | y2 = self.UP2(self.conv_stright(residual))
577 | output = self.add_mean(y1 + y2)
578 |
579 | return output
--------------------------------------------------------------------------------
/model/MultiAdd.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import math
4 | from model import common
5 | from timm.models.layers import trunc_normal_, to_2tuple
6 | import torch.nn.functional as F
7 | from pdb import set_trace as stx
8 | import numbers
9 | from einops import rearrange
10 | from torch.nn.parameter import Parameter
11 | from torch.autograd import Variable
12 | from IPython import embed
13 | import cv2
14 | import numpy as np
15 | import matplotlib.pyplot as plt
16 |
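   | # Note: this module mirrors the CFIN model definition and is the network imported by
   | # test_summary.py (torchsummaryX) to report parameters and Multi-Adds; the upscaling
   | # factor is hard-coded inside MODEL below (scale = [4]).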
17 |
18 | class Mlp(nn.Module):
19 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
20 | super().__init__()
21 | out_features = out_features or in_features
22 | hidden_features = hidden_features or in_features
23 | self.fc1 = Conv2d_CG(in_features, hidden_features, kernel_size=1, padding=0)
24 | self.act = act_layer()
25 | self.fc2 = Conv2d_CG(hidden_features, out_features, kernel_size=3, padding=1)
26 | self.drop = nn.Dropout(drop)
27 |
28 | def forward(self, x):
29 | x = self.fc1(x)
30 | x = self.act(x)
31 | x = self.drop(x)
32 | x = self.fc2(x)
33 | x = self.drop(x)
34 | return x
35 |
36 | class PatchEmbed(nn.Module):
37 | def __init__(self, img_size, patch_size=4, in_chans=32, embed_dim=32, norm_layer=None):
38 | super().__init__()
39 | img_size = to_2tuple(img_size)
40 | patch_size = to_2tuple(patch_size)
41 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
42 | self.img_size = img_size
43 | self.patch_size = patch_size
44 | self.patches_resolution = patches_resolution
45 | self.num_patches = patches_resolution[0] * patches_resolution[1]
46 |
47 | self.embed_dim = embed_dim
48 | self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
49 |
50 | if norm_layer is not None: ## norm_layer=None
51 | self.norm = norm_layer(embed_dim)
52 | else:
53 | self.norm = None
54 |
55 | def forward(self, x):
56 | B, C, H, W = x.shape
57 | x = self.proj(x)
58 | return x
59 |
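   | # Content-gated convolution (after the CRAN paper referenced below): a learned static
   | # kernel is modulated per sample by a gate built from pooled spatial context and a
   | # grouped channel-interaction branch, then applied to unfolded patches via matmul.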
60 | class Conv2d_CG(nn.Conv2d):
61 | def __init__(self, in_channels=64, out_channels=64, kernel_size=1, padding=0, stride=1, dilation=1, groups=1,
62 | bias=True):
63 | super(Conv2d_CG, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
64 |
65 | # weight & bias for content-gated-convolution
66 | self.weight_conv = Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.001, requires_grad=True)
67 | #Parameter(torch.zeros(out_channels, in_channels, kernel_size, kernel_size),
68 | # requires_grad=True).cuda()
69 |         self.bias_conv = Parameter(torch.zeros(out_channels))  # zero-initialized; torch.Tensor(out_channels) would leave the bias uninitialized
70 | nn.init.kaiming_normal_(self.weight_conv)
71 |
72 | self.stride = stride
73 | self.padding = padding
74 | self.dilation = dilation
75 | self.groups = groups
76 |
77 |         # fall back to a plain convolution only when kernel_size == 0 (never in practice), so the content-gated path is also used for 1x1 convolutions
78 | if kernel_size == 0:
79 | self.ind = True
80 | else:
81 | self.ind = False
82 | self.oc = out_channels
83 | self.ks = kernel_size
84 |
85 | # target spatial size of the pooling layer
86 | ws = kernel_size
87 | self.avg_pool = nn.AdaptiveMaxPool2d((ws, ws))
88 |
89 | # the dimension of latent representation
90 | self.num_lat = int((kernel_size * kernel_size) / 2 + 1)
91 |
92 | # the context encoding module
93 | self.ce = nn.Linear(ws * ws, self.num_lat, False)
94 | #self.ce_bn = nn.BatchNorm1d(in_channels)
95 | #self.ci_bn2 = nn.BatchNorm1d(in_channels)
96 |
97 | self.act = nn.ReLU()
98 |
99 | # the number of groups in the channel interaction module
100 | if in_channels // 8:
101 | self.g = 8
102 | else:
103 | self.g = in_channels
104 |
105 | # the channel interacting module
106 | self.ci = nn.Linear(self.g, out_channels // (in_channels // self.g), bias=False)
107 | #self.ci_bn = nn.BatchNorm1d(out_channels)
108 |
109 | # the gate decoding module (spatial interaction)
110 | self.gd = nn.Linear(self.num_lat, kernel_size * kernel_size, False)
111 | self.gd2 = nn.Linear(self.num_lat, kernel_size * kernel_size, False)
112 |
113 | # used to prepare the input feature map to patches
114 | self.unfold = nn.Unfold(kernel_size, dilation, padding, stride)
115 |
116 | # sigmoid function
117 | self.sig = nn.Sigmoid()
118 |
119 | def forward(self, x):
120 |         # plain-convolution fallback (self.ind is only True for kernel_size == 0)
121 | if self.ind:
122 | return F.conv2d(x, self.weight_conv, self.bias_conv, self.stride, self.padding, self.dilation, self.groups)
123 | else:
124 | b, c, h, w = x.size() # x: batch x n_feat(=64) x h_patch x w_patch
125 | weight = self.weight_conv
126 |
127 | # allocate global information
128 | gl = self.avg_pool(x).view(b, c, -1) # gl: batch x n_feat x 3 x 3 -> batch x n_feat x 9
129 |
130 | # context-encoding module
131 | out = self.ce(gl) # out: batch x n_feat x 5
132 |
133 |             # split into two branches (the per-branch BatchNorm used in CRAN is disabled here)
134 |             ce2 = out # ce2: batch x n_feat x 5
135 |             #out = self.ce_bn(out)
136 |             out = self.act(out) # out: batch x n_feat x 5 (ReLU activation)
137 |
138 | # gate decoding branch 1 (spatial interaction)
139 | out = self.gd(out) # out: batch x n_feat x 9 (5 --> 9 = 3x3)
140 |
141 | # channel interacting module
142 | if self.g > 3:
143 | oc = self.ci(self.act(ce2.view(b, c // self.g, self.g, -1).transpose(2, 3))).transpose(2,3).contiguous()
144 | else:
145 | oc = self.ci(self.act(ce2.transpose(2, 1))).transpose(2, 1).contiguous()
146 | oc = oc.view(b, self.oc, -1)
147 | #oc = self.ci_bn(oc)
148 | oc = self.act(oc) # oc: batch x n_feat x 5 (after grouped linear layer)
149 |
150 | # gate decoding branch 2 (spatial interaction)
151 | oc = self.gd2(oc) # oc: batch x n_feat x 9 (5 --> 9 = 3x3)
152 |
153 | # produce gate (equation (4) in the CRAN paper)
154 | out = self.sig(out.view(b, 1, c, self.ks, self.ks) + oc.view(b, self.oc, 1, self.ks, self.ks))
155 | # out: batch x out_channel x in_channel x kernel_size x kernel_size (same dimension as conv2d weight)
156 |
157 | # unfolding input feature map to patches
158 | x_un = self.unfold(x)
159 | b, _, l = x_un.size()
160 | #embed()
161 | # gating
162 | out = (out * weight.unsqueeze(0))#.to(device)
163 | out = out.view(b, self.oc, -1)
164 |
165 | # currently only handle square input and output
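   | # the matmul over unfolded patches is equivalent to a convolution whose kernel has
   | # been modulated by the per-sample gate computed above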
166 | return torch.matmul(out, x_un).view(b, self.oc, h, w)
167 |
168 | class ConvAttention(nn.Module):
169 | def __init__(self, dim, num_heads=8, kernel_size=5, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
170 | super().__init__()
171 | self.num_heads = num_heads
172 | head_dim = dim // num_heads
173 | self.weight = nn.Parameter(torch.randn(num_heads, dim//num_heads, dim//num_heads) * 0.001, requires_grad=True)
174 | self.to_qkv = Conv2d_CG(dim, dim*3)
175 |
176 | def forward(self, x, k1=None, v1=None, return_x=False):
177 | weight = self.weight
178 | b,c,h,w = x.shape
179 |
180 | qkv = self.to_qkv(x)
181 | q, k, v = qkv.chunk(3, dim=1)
182 |
183 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
184 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
185 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
186 |
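   | # k1/v1 are the key/value produced by a preceding attention stage; when supplied they
   | # are fused with the current key/value before the attention map is formed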
187 | if k1 is None:
188 | k = k
189 | v = v
190 | else:
191 | k = k1 + k
192 | v = v1 + v
193 | q = torch.nn.functional.normalize(q, dim=-1)
194 | k = torch.nn.functional.normalize(k, dim=-1)
195 |
196 | attn = (q @ k.transpose(-2, -1)) * weight
197 | attn = attn.softmax(dim=-1)
198 | x = (attn @ v)
199 | x = rearrange(x, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
200 |
201 | if return_x:
202 | return x
203 | else:
204 | return x, k, v
205 |
206 |
207 | class WithBias_LayerNorm(nn.Module):
208 | def __init__(self, normalized_shape):
209 | super(WithBias_LayerNorm, self).__init__()
210 | if isinstance(normalized_shape, numbers.Integral):
211 | normalized_shape = (normalized_shape,)
212 | normalized_shape = torch.Size(normalized_shape)
213 |
214 | assert len(normalized_shape) == 1
215 |
216 | self.weight = nn.Parameter(torch.ones(normalized_shape))
217 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
218 | self.normalized_shape = normalized_shape
219 |
220 | def forward(self, x):
221 | mu = x.mean(-1, keepdim=True)
222 | sigma = x.var(-1, keepdim=True, unbiased=False)
223 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
224 |
225 |
226 | def to_3d(x):
227 | return rearrange(x, 'b c h w -> b (h w) c')
228 |
229 | def to_4d(x,h,w):
230 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
231 |
232 | class LayerNorm(nn.Module):
233 | def __init__(self, dim):
234 | super(LayerNorm, self).__init__()
235 | self.body = WithBias_LayerNorm(dim)
236 |
237 | def forward(self, x):
238 | h, w = x.shape[-2:]
239 | return to_4d(self.body(to_3d(x)), h, w)
240 |
241 |
242 | class Block(nn.Module):
243 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
244 | drop_path=0., act_layer=nn.GELU):
245 | super().__init__()
246 | self.norm1 = LayerNorm(dim)
247 | kernel_size1 = 1
248 | padding1 = 0
249 | kernel_size2 = 3
250 | padding2 = 1
251 | #self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
252 | self.attn = ConvAttention(dim, num_heads, kernel_size1, padding1)
253 | self.attn1 = ConvAttention(dim, num_heads, kernel_size2, padding2)
254 |
255 | self.norm2 = LayerNorm(dim)
256 | self.norm3 = LayerNorm(dim)
257 | mlp_hidden_dim = int(dim*1)
258 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim)
259 |
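   | # The first ConvAttention also returns its key/value, which are reused by the second
   | # ConvAttention; a convolutional MLP follows. Each stage uses a post-norm residual.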
260 | def forward(self, x):
261 | res = x
262 | x, k1, v1 = self.attn(x)
263 | x = res + self.norm1(x)
264 | x = x + self.norm2(self.attn1(x, k1, v1, return_x=True))
265 | x = x + self.norm3(self.mlp(x))
266 | return x
267 |
268 |
269 | class Scale(nn.Module):
270 | def __init__(self, init_value=1e-3):
271 | super().__init__()
272 | self.scale = nn.Parameter(torch.FloatTensor([init_value]))
273 |
274 | def forward(self, input):
275 | return input * self.scale
276 |
277 |
278 | def activation(act_type, inplace=False, neg_slope=0.05, n_prelu=1):
279 | act_type = act_type.lower()
280 | if act_type == 'relu':
281 | layer = nn.ReLU()
282 | elif act_type == 'lrelu':
283 | layer = nn.LeakyReLU(neg_slope)
284 | elif act_type == 'prelu':
285 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
286 | else:
287 | raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
288 | return layer
289 |
290 |
291 | class eca_layer(nn.Module):
292 | def __init__(self, channel, k_size):
293 | super(eca_layer, self).__init__()
294 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
295 | self.k_size = k_size
296 | self.conv = nn.Conv1d(channel, channel, kernel_size=k_size, bias=False, groups=channel)
297 | self.sigmoid = nn.Sigmoid()
298 |
299 |
300 | def forward(self, x):
301 | b, c, _, _ = x.size()
302 | y = self.avg_pool(x)
303 | y = nn.functional.unfold(y.transpose(-1, -3), kernel_size=(1, self.k_size), padding=(0, (self.k_size - 1) // 2))
304 | y = self.conv(y.transpose(-1, -2)).unsqueeze(-1)
305 | y = self.sigmoid(y)
306 | x = x * y.expand_as(x)
307 | return x
308 |
309 |
310 | class MaskPredictor(nn.Module):
311 | def __init__(self,in_channels, wn=lambda x: torch.nn.utils.weight_norm(x)):
312 | super(MaskPredictor,self).__init__()
313 | self.spatial_mask=nn.Conv2d(in_channels=in_channels,out_channels=3,kernel_size=1,bias=False)
314 |
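   | # Gumbel-softmax with hard=True yields a (straight-through differentiable) one-hot map
   | # over the 3 output channels; RB below keeps channel 1 as a binary spatial mask.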
315 | def forward(self,x):
316 | spa_mask=self.spatial_mask(x)
317 | spa_mask=F.gumbel_softmax(spa_mask,tau=1,hard=True,dim=1)
318 | return spa_mask
319 |
320 |
321 | class RB(nn.Module):
322 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)):
323 | super(RB, self).__init__()
324 | self.CA = eca_layer(n_feats, k_size=3)
325 | self.MaskPredictor = MaskPredictor(n_feats*8//8)
326 |
327 | self.k = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)),
328 | nn.LeakyReLU(0.05),
329 | )
330 |
331 | self.k1 = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)),
332 | nn.LeakyReLU(0.05),
333 | )
334 |
335 | self.res_scale = Scale(1)
336 | self.x_scale = Scale(1)
337 |
338 | def forward(self, x):
339 | res = x
340 | x = self.k(x)
341 |
342 | MaskPredictor = self.MaskPredictor(x)
343 | mask = (MaskPredictor[:,1,...]).unsqueeze(1)
344 | x = x * (mask.expand_as(x))
345 |
346 | x1 = self.k1(x)
347 | x2 = self.CA(x1)
348 | out = self.x_scale(x2) + self.res_scale(res)
349 |
350 | return out
351 |
352 |
353 | class SCConv(nn.Module):
354 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)):
355 | super(SCConv, self).__init__()
356 | pooling_r = 2
357 | med_feats = n_feats // 1
358 | #self.k = nn.Conv2d(n_feats, n_feats, kernel_size=1, stride=2, padding=0, groups=1)
359 |
360 | self.k1 = nn.Sequential(nn.ConvTranspose2d(n_feats, n_feats*4//3, kernel_size=pooling_r, stride=pooling_r, padding=0, groups=1, bias=True),
361 | nn.LeakyReLU(0.05),
362 | nn.Conv2d(n_feats*4//3, n_feats, kernel_size=1, stride=2, padding=0, groups=1),
363 | )
364 |
365 | self.sig = nn.Sigmoid()
366 |
367 | self.k3 = RB(n_feats)
368 |
369 | self.k4 = RB(n_feats)
370 |
371 | self.k5 = RB(n_feats)
372 |
373 | self.res_scale = Scale(1)
374 | self.x_scale = Scale(1)
375 |
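   | # Two paths: k3/k4 refine the features with residual blocks, while k1 (transposed-conv
   | # up, strided 1x1 conv down) plus the identity forms a sigmoid gate; k5 refines the
   | # gated features before the final scaled residual sum.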
376 | def forward(self, x):
377 | identity = x
378 | _, _, H, W = identity.shape
379 | x1_1 = self.k3(x)
380 | x1 = self.k4(x1_1)
381 |
382 |
383 | x1_s = self.sig(self.k1(x) + x)
384 | x1 = self.k5(x1_s * x1)
385 |
386 | out = self.res_scale(x1) + self.x_scale(identity)
387 |
388 | return out
389 |
390 |
391 | class FCUUp(nn.Module):
392 | def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU,
393 | norm_layer=nn.BatchNorm2d, wn=lambda x: torch.nn.utils.weight_norm(x)):
394 | super(FCUUp, self).__init__()
395 | self.up_stride = up_stride
396 | self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)
397 | self.act = act_layer()
398 |
399 | def forward(self, x_t):
400 | x_r = self.act(self.conv_project(x_t))
401 |
402 | return x_r
403 |
404 | class FCUDown(nn.Module):
405 | def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU,
406 | norm_layer=nn.LayerNorm, wn=lambda x: torch.nn.utils.weight_norm(x)):
407 | super(FCUDown, self).__init__()
408 | self.conv_project = wn(nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0))
409 |
410 | def forward(self, x):
411 | x = self.conv_project(x)
412 |
413 | return x
414 |
415 |
416 | class ConvBlock(nn.Module):
417 | def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1, norm_layer=nn.BatchNorm2d, drop_block=None, drop_path=None):
418 | super(ConvBlock, self).__init__()
419 |
420 | expansion = 1
421 | med_planes = outplanes // expansion
422 | embed_dim = 144
423 | num_heads = 8
424 | mlp_ratio = 1.0
425 |
426 | self.rb_search1 = SCConv(med_planes)
427 | self.rb_search2 = SCConv(med_planes)
428 | self.rb_search3 = SCConv(med_planes)
429 | self.rb_search4 = SCConv(med_planes)
430 |
431 | self.trans_block = Block(
432 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
433 | drop=0., attn_drop=0., drop_path=0.)
434 |
435 | self.trans_block1 = Block(
436 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
437 | drop=0., attn_drop=0., drop_path=0.)
438 |
439 | self.trans_block2 = Block(
440 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
441 | drop=0., attn_drop=0., drop_path=0.)
442 |
443 | self.trans_block3 = Block(
444 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
445 | drop=0., attn_drop=0., drop_path=0.)
446 |
447 | self.trans_block4 = Block(
448 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
449 | drop=0., attn_drop=0., drop_path=0.)
450 |
451 | self.trans_block5 = Block(
452 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
453 | drop=0., attn_drop=0., drop_path=0.)
454 |
455 | self.trans_block6 = Block(
456 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
457 | drop=0., attn_drop=0., drop_path=0.)
458 |
459 | self.trans_block7 = Block(
460 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None,
461 | drop=0., attn_drop=0., drop_path=0.)
462 |
463 | self.expand_block = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
464 | self.squeeze_block = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
465 | self.expand_block1 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
466 | self.squeeze_block1 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
467 | self.expand_block2 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
468 | self.squeeze_block2 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
469 | self.expand_block3 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1)
470 | self.squeeze_block3 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1)
471 |
472 | self.res_scale = Scale(1)
473 | self.x_scale = Scale(1)
474 | self.num_rbs = 1
475 |
476 | self.res_conv = res_conv
477 | self.drop_block = drop_block
478 | self.drop_path = drop_path
479 |
480 | def zero_init_last_bn(self):
481 |         nn.init.zeros_(self.bn3.weight)  # note: self.bn3 is never defined in this block, so this helper is unused
482 |
483 | def forward(self, x):
484 | residual = x
485 |
486 | x = self.rb_search1(x)
487 | x = self.squeeze_block(self.trans_block(self.expand_block(x))) + x
488 |
489 | x = self.rb_search1(x)
490 | x = self.squeeze_block(self.trans_block1(self.expand_block(x))) + x
491 |
492 | x = self.rb_search2(x)
493 | x = self.squeeze_block1(self.trans_block2(self.expand_block1(x))) + x
494 |
495 | x = self.rb_search2(x)
496 | x = self.squeeze_block1(self.trans_block3(self.expand_block1(x))) + x
497 |
498 | x = self.rb_search3(x)
499 | x = self.squeeze_block2(self.trans_block4(self.expand_block2(x))) + x
500 |
501 | x = self.rb_search3(x)
502 | x = self.squeeze_block2(self.trans_block5(self.expand_block2(x))) + x
503 |
504 | x = self.rb_search4(x)
505 | x = self.squeeze_block3(self.trans_block6(self.expand_block3(x))) + x
506 |
507 | x = self.rb_search4(x)
508 | x = self.squeeze_block3(self.trans_block7(self.expand_block3(x))) + x
509 |
510 | x = self.x_scale(x) + self.res_scale(residual)
511 |
512 | return x
513 |
514 |
515 | class ConvTransBlock(nn.Module):
516 | """
517 |     Basic module for the ConvTransformer: keeps feature maps for the CNN block and patch embeddings for the transformer encoder block.
518 | """
519 |
520 | def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads, mlp_ratio,
521 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
522 | last_fusion=False, num_med_block=0, groups=1):
523 | super(ConvTransBlock, self).__init__()
524 | expansion = 1
525 | self.cnn_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=1, groups=groups)
526 |
527 | #self.fusion_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=1, groups=groups)
528 |
529 | self.dw_stride = dw_stride
530 | self.embed_dim = embed_dim
531 | self.num_med_block = num_med_block
532 | self.last_fusion = last_fusion
533 | self.res_scale = Scale(1)
534 | self.x_scale = Scale(1)
535 |
536 | def forward(self, x):
537 | x = self.cnn_block(x)
538 |
539 | #x = self.fusion_block(x)
540 |
541 | return x
542 |
543 |
544 | class MODEL(nn.Module):
545 | def __init__(self, norm_layer=nn.LayerNorm, patch_size=1, window_size=8, num_heads=8, mlp_ratio=1.,
546 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., num_med_block=0, drop_path_rate=0.,
547 | patch_norm=True):
548 | super(MODEL, self).__init__()
549 | scale = [4]
550 | n_feats = 48
551 | n_colors = 3
552 | embed_dim = 64
553 | #height = (1024 // scale // window_size + 1) * window_size
554 | #width = (720 // scale // window_size + 1) * window_size
555 | #img_size = (height, width)
556 |
557 | self.patch_norm = patch_norm
558 | self.num_features = embed_dim
559 | rgb_mean = (0.4488, 0.4371, 0.4040)
560 | rgb_std = (1.0, 1.0, 1.0)
561 | self.sub_mean = common.MeanShift(255, rgb_mean, rgb_std)
562 | self.add_mean = common.MeanShift(255, rgb_mean, rgb_std, 1)
563 | #self.conv_first_trans = nn.Conv2d(n_colors, embed_dim, 3, 1, 1)
564 | self.conv_first_cnn = nn.Conv2d(n_colors, n_feats, 3, 1, 1)
565 |
566 | self.trans_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 8)] # stochastic depth decay rule
567 |
568 | # 2~final Stage
569 | init_stage = 2
570 | fin_stage = 3
571 | stage_1_channel = n_feats
572 | trans_dw_stride = patch_size
573 | for i in range(init_stage, fin_stage):
574 | if i%2==0:
575 | m = i
576 | else:
577 | m = i-1
578 | self.add_module('conv_trans_' + str(m),
579 | ConvTransBlock(
580 | stage_1_channel, stage_1_channel, res_conv=True, stride=1, dw_stride=trans_dw_stride,
581 | embed_dim=embed_dim,
582 | num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
583 | drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
584 | drop_path_rate=self.trans_dpr[i - 1],
585 | num_med_block=num_med_block
586 | )
587 | )
588 |
589 | self.fin_stage = fin_stage
590 | self.dw_stride = trans_dw_stride
591 |
592 | self.conv_after_body = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
593 |
594 | m = []
595 | m.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1))
596 | m.append(nn.PixelShuffle(scale[0]))
597 | self.UP1 = nn.Sequential(*m)
598 |
599 | self.conv_stright = nn.Conv2d(n_colors, n_feats, 3, 1, 1)
600 | up_body = []
601 | up_body.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1))
602 | up_body.append(nn.PixelShuffle(scale[0]))
603 | self.UP2 = nn.Sequential(*up_body)
604 |
605 | self.apply(self._init_weights)
606 |
607 | def _init_weights(self, m):
608 | if isinstance(m, nn.Linear):
609 | trunc_normal_(m.weight, std=.02)
610 | if isinstance(m, nn.Linear) and m.bias is not None:
611 | nn.init.constant_(m.bias, 0)
612 | elif isinstance(m, nn.LayerNorm):
613 | nn.init.constant_(m.bias, 0)
614 | nn.init.constant_(m.weight, 1.0)
615 |
616 | def forward(self, x):
617 | (H, W) = (x.shape[2], x.shape[3])
618 | residual = x
619 | x = self.sub_mean(x)
620 | x = self.conv_first_cnn(x)
621 | #x_t = x
622 |
623 | for i in range(2, self.fin_stage):
624 | if i%2==0:
625 | m = i
626 | else:
627 | m = i-1
628 |             x = getattr(self, 'conv_trans_' + str(m))(x)
629 |
630 | x = self.conv_after_body(x)
631 | y1 = self.UP1(x)
632 | y2 = self.UP2(self.conv_stright(residual))
633 | output = self.add_mean(y1 + y2)
634 |
635 | return output
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import model.CFINx2 as CFINx2
6 | import model.CFINx3 as CFINx3
7 | import model.CFINx4 as CFINx4
8 |
9 | def dataparallel(model, gpu_list):
10 | ngpus = len(gpu_list)
11 | assert ngpus != 0, "only support gpu mode"
12 | assert torch.cuda.device_count() >= ngpus, "Invalid Number of GPUs"
13 | assert isinstance(model, list), "Invalid Type of Dual model"
14 | for i in range(len(model)):
15 | if ngpus >= 2:
16 | model[i] = nn.DataParallel(model[i], gpu_list).cuda()
17 | else:
18 | model[i] = model[i].cuda()
19 | return model
20 |
21 |
22 | class Model(nn.Module):
23 | def __init__(self, opt, ckp):
24 | super(Model, self).__init__()
25 | print('Making model...')
26 | self.opt = opt
27 | self.scale = opt.scale
28 | self.idx_scale = 0
29 | self.self_ensemble = opt.self_ensemble
30 | self.cpu = opt.cpu
31 | self.device = torch.device('cpu' if opt.cpu else 'cuda')
32 | self.n_GPUs = opt.n_GPUs
33 |
34 | self.chop = opt.chop
35 | self.precision = opt.precision
36 |
37 |         self.model = CFINx4.make_model(opt).to(self.device)  # switch to CFINx2 / CFINx3 to match the target scale
38 |
39 | if not opt.cpu and opt.n_GPUs > 1:
40 | self.model = nn.DataParallel(self.model, range(opt.n_GPUs))
41 |
42 | self.load(opt.pre_train,cpu=opt.cpu)
43 |
44 | # compute parameter
45 | # num_parameter = self.count_parameters(self.model)
46 | # ckp.write_log(f"The number of parameters is {num_parameter / 1000 ** 2:.2f}M")
47 | # ckp.write_log(f"The number of parameters is {num_parameter:.2f}")
48 |
49 | def forward(self, x, idx_scale=0):
50 | self.idx_scale = idx_scale
51 | target = self.get_model()
52 | if hasattr(target, 'set_scale'):
53 | target.set_scale(idx_scale)
54 |
55 | if self.self_ensemble and not self.training:
56 | if self.chop:
57 | forward_function = self.forward_chop
58 | else:
59 | forward_function = self.model.forward
60 |
61 | return self.forward_x8(x, forward_function)
62 | elif self.chop and not self.training:
63 | return self.forward_chop(x)
64 | else:
65 | return self.model(x)
66 |
67 | def get_model(self):
68 | if self.n_GPUs == 1:
69 | return self.model
70 | else:
71 | return self.model.module
72 |
73 |
74 | def state_dict(self, **kwargs):
75 | target = self.get_model()
76 | return target.state_dict(**kwargs)
77 |
78 | def count_parameters(self, model):
79 | if self.opt.n_GPUs > 1:
80 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
81 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
82 |
83 | def save(self, path, is_best=False):
84 | target = self.get_model()
85 | torch.save(
86 | target.state_dict(),
87 | os.path.join(path, 'model', 'model_latest.pt')
88 | )
89 | if is_best:
90 | torch.save(
91 | target.state_dict(),
92 | os.path.join(path, 'model', 'model_best.pt')
93 | )
94 |
95 | def load(self, pre_train='.',cpu=False):
96 | if cpu:
97 | kwargs = {'map_location': lambda storage, loc: storage}
98 | else:
99 | kwargs = {}
100 | #### load primal model ####
101 | if pre_train != '.':
102 | print('Loading model from {}'.format(pre_train))
103 | self.get_model().load_state_dict(
104 | torch.load(pre_train, **kwargs),
105 | strict=False
106 | )
107 |
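   | # Memory-efficient inference: split the input into four quadrants overlapping by
   | # `shave` pixels, recurse while a quadrant still exceeds `min_size` pixels, super-
   | # resolve the tiles, and stitch the non-overlapping central regions back together.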
108 | def forward_chop(self, x, shave=10, min_size=160000):
109 | scale = self.scale[self.idx_scale]
110 | n_GPUs = min(self.n_GPUs, 4)
111 | b, c, h, w = x.size()
112 | h_half, w_half = h // 2, w // 2
113 | h_size, w_size = h_half + shave, w_half + shave
114 | lr_list = [
115 | x[:, :, 0:h_size, 0:w_size],
116 | x[:, :, 0:h_size, (w - w_size):w],
117 | x[:, :, (h - h_size):h, 0:w_size],
118 | x[:, :, (h - h_size):h, (w - w_size):w]]
119 |
120 | if w_size * h_size < min_size:
121 | sr_list = []
122 | for i in range(0, 4, n_GPUs):
123 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
124 | sr_batch = self.model(lr_batch)
125 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
126 | else:
127 | sr_list = [
128 | self.forward_chop(patch, shave=shave, min_size=min_size) \
129 | for patch in lr_list
130 | ]
131 |
132 | h, w = scale * h, scale * w
133 | h_half, w_half = scale * h_half, scale * w_half
134 | h_size, w_size = scale * h_size, scale * w_size
135 | shave *= scale
136 |
137 | output = x.new(b, c, h, w)
138 | output[:, :, 0:h_half, 0:w_half] \
139 | = sr_list[0][:, :, 0:h_half, 0:w_half]
140 | output[:, :, 0:h_half, w_half:w] \
141 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
142 | output[:, :, h_half:h, 0:w_half] \
143 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
144 | output[:, :, h_half:h, w_half:w] \
145 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
146 |
147 | return output
148 |
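   | # Self-ensemble (x8): build the eight flip/transpose variants of the input, super-
   | # resolve each, undo the corresponding transform on every output, and average them.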
149 | def forward_x8(self, x, forward_function):
150 | def _transform(v, op):
151 | if self.precision != 'single': v = v.float()
152 |
153 | v2np = v.data.cpu().numpy()
154 | if op == 'v':
155 | tfnp = v2np[:, :, :, ::-1].copy()
156 | elif op == 'h':
157 | tfnp = v2np[:, :, ::-1, :].copy()
158 | elif op == 't':
159 | tfnp = v2np.transpose((0, 1, 3, 2)).copy()
160 |
161 | ret = torch.Tensor(tfnp).to(self.device)
162 | if self.precision == 'half': ret = ret.half()
163 |
164 | return ret
165 |
166 | lr_list = [x]
167 | for tf in 'v', 'h', 't':
168 | lr_list.extend([_transform(t, tf) for t in lr_list])
169 |
170 | sr_list = [forward_function(aug) for aug in lr_list]
171 | for i in range(len(sr_list)):
172 | if i > 3:
173 | sr_list[i] = _transform(sr_list[i], 't')
174 | if i % 4 > 1:
175 | sr_list[i] = _transform(sr_list[i], 'h')
176 | if (i % 4) % 2 == 1:
177 | sr_list[i] = _transform(sr_list[i], 'v')
178 |
179 | output_cat = torch.cat(sr_list, dim=0)
180 | output = output_cat.mean(dim=0, keepdim=True)
181 |
182 | return output
--------------------------------------------------------------------------------
/model/__pycache__/CFIN.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/CFIN.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/CFINx2.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/CFINx2.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/CFINx3.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/CFINx3.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/CFINx4.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/CFINx4.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/MSDNN_LW1.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/MSDNN_LW1.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/MsDNN.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/MsDNN.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/MultiAdd.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/MultiAdd.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/MultiAdd.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/MultiAdd.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/common.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/common.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/common.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/common.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/msfin3.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/msfin3.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/no_transmy.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/no_transmy.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/rebuttal_updown1.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/rebuttal_updown1.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/transmy.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/transmy.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/transmy3.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/transmy3.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/transmy5.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/transmy5.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/transmy5_0.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/transmy5_0.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/transmy_j.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/transmy_j.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/transmy_j.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/transmy_j.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/transmy_jx4.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/transmy_jx4.cpython-39.pyc
--------------------------------------------------------------------------------
/model/common.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/common.py
--------------------------------------------------------------------------------
/option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import utility
3 | import numpy as np
4 |
5 | parser = argparse.ArgumentParser(description='CFIN')
6 |
7 | parser.add_argument('--n_threads', type=int, default=8,
8 | help='number of threads for data loading')
9 | parser.add_argument('--cpu', action='store_true',
10 | help='use cpu only')
11 | parser.add_argument('--n_GPUs', type=int, default=1,
12 | help='number of GPUs')
13 | parser.add_argument('--seed', type=int, default=1,
14 | help='random seed')
15 | parser.add_argument('--data_dir', type=str, default='/home2/wenjieli/datasets',
16 | help='dataset directory')
17 |
18 | parser.add_argument('--data_train', type=str, default='DIV2K',
19 | help='train dataset name')
20 | parser.add_argument('--data_test', type=str, default='Set5',
21 | help='test dataset name')
22 | parser.add_argument('--data_range', type=str, default='1-800/896-900',
23 | help='train/test data range')
24 | parser.add_argument('--scale', type=int, default=2,
25 | help='super resolution scale')
26 | parser.add_argument('--patch_size', type=int, default=192, #192,144,96
27 | help='output patch size')
28 | parser.add_argument('--rgb_range', type=int, default=255,
29 | help='maximum value of RGB')
30 | parser.add_argument('--n_colors', type=int, default=3,
31 | help='number of color channels to use')
32 | parser.add_argument('--no_augment', action='store_true',
33 | help='do not use data augmentation')
34 | parser.add_argument('--model', help='model name: CFIN', required=True)
35 | parser.add_argument('--pre_train', type=str, default='.',
36 | help='pre-trained model directory')
37 | parser.add_argument('--pre_train_dual', type=str, default='.',
38 | help='pre-trained dual model directory')
39 | parser.add_argument('--n_blocks', type=int, default=30,
40 | help='number of residual blocks, 16|30|40|80')
41 | parser.add_argument('--n_feats', type=int, default=20,
42 | help='number of feature maps')
43 |
44 | parser.add_argument('--chop', action='store_true',
45 | help='enable memory-efficient forward')
46 | parser.add_argument('--precision', type=str, default='single',
47 | choices=('single', 'half'),
48 | help='FP precision for test (single | half)')
49 |
50 | parser.add_argument('--num_steps', type=int, default=10,
51 | help='number of RCAB')
52 | parser.add_argument('--negval', type=float, default=0.2,
53 | help='Negative value parameter for Leaky ReLU')
54 | parser.add_argument('--test_every', type=int, default=1000,
55 | help='run validation every N training batches')
56 | parser.add_argument('--epochs', type=int, default=1000,
57 | help='number of epochs to train')
58 | parser.add_argument('--batch_size', type=int, default=16,
59 | help='input batch size for training')
60 | parser.add_argument('--self_ensemble', action='store_true',
61 | help='use self-ensemble method for test')
62 | parser.add_argument('--test_only', action='store_true',
63 | help='set this option to test the model')
64 | parser.add_argument('--lr', type=float, default=5e-4,
65 | help='learning rate')
66 | parser.add_argument('--eta_min', type=float, default=1e-7,
67 | help='eta_min lr')
68 |
69 | parser.add_argument('--beta1', type=float, default=0.9,
70 | help='ADAM beta1')
71 | parser.add_argument('--beta2', type=float, default=0.999,
72 | help='ADAM beta2')
73 | parser.add_argument('--epsilon', type=float, default=1e-8,
74 | help='ADAM epsilon for numerical stability')
75 | parser.add_argument('--weight_decay', type=float, default=0,
76 | help='weight decay')
77 | parser.add_argument('--loss', type=str, default='1*L1',
78 | help='loss function configuration, L1|MSE')
79 | parser.add_argument('--skip_threshold', type=float, default=1e6,
80 | help='skip a batch whose loss is abnormally large')
81 | parser.add_argument('--dual_weight', type=float, default=0.1,
82 | help='the weight of dual loss')
83 | parser.add_argument('--save', type=str, default='./experiment/test/',
84 | help='file name to save')
85 | parser.add_argument('--print_every', type=int, default=100,
86 | help='how many batches to wait before logging training status')
87 | parser.add_argument('--save_results', action='store_true',
88 | help='save output results')
89 |
90 | parser.add_argument('--load', type=str, default='.',
91 | help='file name to load')
92 | parser.add_argument('--resume', type=int, default=0,
93 | help='resume from specific checkpoint')
94 | parser.add_argument("--ext", type=str, default='.npy')
95 | parser.add_argument("--n_train", type=int, default=800,
96 | help="number of training set")
97 | parser.add_argument("--cuda", action="store_true", default=True,
98 | help="use cuda")
99 |
100 |
101 | args = parser.parse_args()
102 |
103 | utility.init_model(args)
104 |
105 | args.scale = [args.scale]
106 |
107 | for arg in vars(args):
108 | if vars(args)[arg] == 'True':
109 | vars(args)[arg] = True
110 | elif vars(args)[arg] == 'False':
111 | vars(args)[arg] = False
112 |
113 |
--------------------------------------------------------------------------------
/test_summary.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchsummaryX import summary
3 | import model.MultiAdd as MultiAdd
4 | model = MultiAdd.MODEL()
5 |
6 | summary(model, torch.zeros((1, 3, 320, 180)))  # x4 LR input; the corresponding HR output is 1280 x 720
7 |
8 | # input LR x2, HR size is 720p
9 | # summary(model, torch.zeros((1, 3, 640, 360)))
10 |
11 | # input LR x3, HR size is 720p
12 | # summary(model, torch.zeros((1, 3, 426, 240)))
13 |
14 | # input LR x4, HR size is 720p
15 | # summary(model, torch.zeros((1, 3, 320, 180)))
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import utility
3 | from decimal import Decimal
4 | from tqdm import tqdm
5 |
6 | class Trainer():
7 | def __init__(self, opt, loader, my_model, my_loss, ckp):
8 | self.opt = opt
9 | self.scale = opt.scale
10 | self.ckp = ckp
11 | self.loader_train = loader.loader_train
12 | self.loader_test = loader.loader_test
13 | self.model = my_model
14 | self.loss = my_loss
15 | self.optimizer = utility.make_optimizer(opt, self.model)
16 | self.scheduler = utility.make_scheduler(opt, self.optimizer)
17 | self.error_last = 1e8
18 |
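   | # One training epoch: iterate the training loader and skip the parameter update for any
   | # batch whose loss exceeds skip_threshold times the previous epoch's error.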
19 | def train(self):
20 | epoch = self.scheduler.last_epoch + 1
21 | lr = self.scheduler.get_lr()[0]
22 |
23 | self.ckp.write_log(
24 | '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
25 | )
26 | self.loss.start_log()
27 | self.model.train()
28 | timer_data, timer_model = utility.timer(), utility.timer()
29 | for batch, (lr, hr, _) in enumerate(self.loader_train):
30 | lr, hr = self.prepare(lr, hr)#GPU
31 | timer_data.hold()
32 | timer_model.tic()
33 |
34 | self.optimizer.zero_grad()
35 |
36 | sr = self.model(lr[0])
37 | loss = self.loss(sr, hr)
38 |
39 | if loss.item() < self.opt.skip_threshold * self.error_last:
40 | loss.backward()
41 | self.optimizer.step()
42 | else:
43 | print('Skip this batch {}! (Loss: {})'.format(
44 | batch + 1, loss.item()
45 | ))
46 |
47 | timer_model.hold()
48 |
49 | if (batch + 1) % self.opt.print_every == 0:
50 | self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
51 | (batch + 1) * self.opt.batch_size,
52 | len(self.loader_train.dataset),
53 | self.loss.display_loss(batch),
54 | timer_model.release(),
55 | timer_data.release()))
56 |
57 | timer_data.tic()
58 |
59 | self.loss.end_log(len(self.loader_train))
60 | self.error_last = self.loss.log[-1, -1]
61 | self.step()
62 |
63 | def test(self):
64 | epoch = self.scheduler.last_epoch
65 | self.ckp.write_log('\nEvaluation:')
66 | self.ckp.add_log(torch.zeros(1, 1))
67 | self.model.eval()
68 |
69 | timer_test = utility.timer()
70 | with torch.no_grad():
71 | scale = max(self.scale)
72 | for si, s in enumerate([scale]):
73 | eval_psnr = 0
74 | tqdm_test = tqdm(self.loader_test, ncols=80)
75 | for _, (lr, hr, filename) in enumerate(tqdm_test):
76 | filename = filename[0]
77 | no_eval = (hr.nelement() == 1)
78 | if not no_eval:
79 | lr, hr = self.prepare(lr, hr)
80 | else:
81 | lr, = self.prepare(lr)
82 |
83 | sr = self.model(lr[0])
84 | if isinstance(sr, list): sr = sr[-1]
85 |
86 | sr = utility.quantize(sr, self.opt.rgb_range)
87 |
88 | if not no_eval:
89 | eval_psnr += utility.calc_psnr(
90 | sr, hr, s, self.opt.rgb_range,
91 | benchmark=self.loader_test.dataset.benchmark
92 | )
93 |
94 | # save test results
95 | if self.opt.save_results:
96 | self.ckp.save_results_nopostfix(filename, sr, s)
97 |
98 | self.ckp.log[-1, si] = eval_psnr / len(self.loader_test)
99 | best = self.ckp.log.max(0)
100 | self.ckp.write_log(
101 | '[{} x{}]\tPSNR: {:.2f} (Best: {:.2f} @epoch {})'.format(
102 | self.opt.data_test, s,
103 | self.ckp.log[-1, si],
104 | best[0][si],
105 | best[1][si] + 1
106 | )
107 | )
108 |
109 | self.ckp.write_log(
110 | 'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
111 | )
112 | if not self.opt.test_only:
113 | self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch))
114 |
115 | def step(self):
116 | self.scheduler.step()
117 |
118 | def prepare(self, *args):
119 | device = torch.device('cpu' if self.opt.cpu else 'cuda')
120 |
121 | if len(args)>1:
122 | return [a.to(device) for a in args[0]], args[-1].to(device)
123 | return [a.to(device) for a in args[0]],
124 |
125 | def terminate(self):
126 | if self.opt.test_only:
127 | self.test()
128 | return True
129 | else:
130 | epoch = self.scheduler.last_epoch
131 | return epoch >= self.opt.epochs
132 |
--------------------------------------------------------------------------------
/utility.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import random
4 | import numpy as np
5 | import torch
6 | import torch.optim as optim
7 | import torch.optim.lr_scheduler as lrs
8 |
9 |
10 | def set_seed(seed):
11 | random.seed(seed)
12 | np.random.seed(seed)
13 | torch.manual_seed(seed)
14 | if torch.cuda.device_count() == 1:
15 | torch.cuda.manual_seed(seed)
16 | else:
17 | torch.cuda.manual_seed_all(seed)
18 |
19 |
20 | class timer():
21 | def __init__(self):
22 | self.acc = 0
23 | self.tic()
24 |
25 | def tic(self):
26 | self.t0 = time.time()
27 |
28 | def toc(self):
29 | return time.time() - self.t0
30 |
31 | def hold(self):
32 | self.acc += self.toc()
33 |
34 | def release(self):
35 | ret = self.acc
36 | self.acc = 0
37 |
38 | return ret
39 |
40 | def reset(self):
41 | self.acc = 0
42 |
43 |
44 | def quantize(img, rgb_range):
45 | pixel_range = 255 / rgb_range
46 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
47 |
48 |
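   | # For benchmark sets the PSNR is computed on the Y channel (65.738/129.057/25.064 are
   | # the usual BT.601 RGB->Y coefficients, divided by 256 below) with `scale` border
   | # pixels cropped; otherwise scale+6 pixels are cropped from the RGB difference.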
49 | def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
50 | if sr.size(-2) > hr.size(-2) or sr.size(-1) > hr.size(-1):
51 | print("the dimention of sr image is not equal to hr's! ")
52 | sr = sr[:,:,:hr.size(-2),:hr.size(-1)]
53 | diff = (sr - hr).data.div(rgb_range)
54 |
55 | if benchmark:
56 | shave = scale
57 | if diff.size(1) > 1:
58 | convert = diff.new(1, 3, 1, 1)
59 | convert[0, 0, 0, 0] = 65.738
60 | convert[0, 1, 0, 0] = 129.057
61 | convert[0, 2, 0, 0] = 25.064
62 | diff.mul_(convert).div_(256)
63 | diff = diff.sum(dim=1, keepdim=True)
64 | else:
65 | shave = scale + 6
66 |
67 | valid = diff[:, :, shave:-shave, shave:-shave]
68 | mse = valid.pow(2).mean()
69 |
70 | return -10 * math.log10(mse)
71 |
72 |
73 | def make_optimizer(opt, my_model):
74 | trainable = filter(lambda x: x.requires_grad, my_model.parameters())
75 | optimizer_function = optim.Adam
76 | kwargs = {
77 | 'betas': (opt.beta1, opt.beta2),
78 | 'eps': opt.epsilon
79 | }
80 | kwargs['lr'] = opt.lr
81 | kwargs['weight_decay'] = opt.weight_decay
82 |
83 | return optimizer_function(trainable, **kwargs)
84 |
85 |
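   | # Cosine annealing from args.lr down to args.eta_min over args.epochs epochs;
   | # Trainer.step() advances the scheduler once per epoch.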
86 | def make_scheduler(opt, my_optimizer):
87 | scheduler = lrs.CosineAnnealingLR(
88 | my_optimizer,
89 | float(opt.epochs),
90 | eta_min=opt.eta_min
91 | )
92 |
93 | return scheduler
94 |
95 |
96 |
97 | def init_model(args):
98 | '''
99 | if args.model.find('MSFIN3') >= 0:
100 | if args.scale == 4:
101 | args.num_steps = 1
102 | args.n_feats = 24
103 | args.patch_size = 192
104 | elif args.scale == 8:
105 | args.n_blocks = 30
106 | args.n_feats = 8
107 | else:
108 | print('Use defaults n_blocks and n_feats.')
109 | # args.dual = True
110 | '''
111 | if args.model.find('TRANSMY5') >= 0:
112 | if args.scale == 4:
113 | args.num_steps = 1
114 | args.n_feats = 32
115 | args.patch_size = 192
116 | elif args.scale == 8:
117 | args.n_blocks = 30
118 | args.n_feats = 8
119 | else:
120 | print('Use defaults n_blocks and n_feats.')
121 |
122 |
--------------------------------------------------------------------------------