├── README.md ├── __pycache__ ├── checkpoint.cpython-36.pyc ├── checkpoint.cpython-39.pyc ├── loss.cpython-36.pyc ├── loss.cpython-39.pyc ├── option.cpython-36.pyc ├── option.cpython-39.pyc ├── trainer.cpython-36.pyc ├── trainer.cpython-39.pyc ├── utility.cpython-35.pyc ├── utility.cpython-36.pyc └── utility.cpython-39.pyc ├── checkpoint.py ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-39.pyc │ ├── benchmark.cpython-36.pyc │ ├── benchmark.cpython-39.pyc │ ├── common.cpython-36.pyc │ ├── common.cpython-39.pyc │ ├── df2k.cpython-36.pyc │ ├── div2k.cpython-36.pyc │ ├── div2k.cpython-39.pyc │ ├── srdata.cpython-36.pyc │ └── srdata.cpython-39.pyc ├── benchmark.py ├── common.py ├── div2k.py └── srdata.py ├── experiments └── CFIN │ └── model │ ├── model_best_x2.pt │ ├── model_best_x3.pt │ └── model_best_x4.pt ├── img └── compare.png ├── loss.py ├── main.py ├── model ├── CFIN.py ├── CFINx2.py ├── CFINx3.py ├── CFINx4.py ├── MultiAdd.py ├── __init__.py ├── __pycache__ │ ├── CFIN.cpython-39.pyc │ ├── CFINx2.cpython-39.pyc │ ├── CFINx3.cpython-39.pyc │ ├── CFINx4.cpython-39.pyc │ ├── MSDNN_LW1.cpython-36.pyc │ ├── MsDNN.cpython-36.pyc │ ├── MultiAdd.cpython-36.pyc │ ├── MultiAdd.cpython-39.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-39.pyc │ ├── common.cpython-36.pyc │ ├── common.cpython-39.pyc │ ├── msfin3.cpython-36.pyc │ ├── no_transmy.cpython-39.pyc │ ├── rebuttal_updown1.cpython-39.pyc │ ├── transmy.cpython-36.pyc │ ├── transmy3.cpython-36.pyc │ ├── transmy5.cpython-36.pyc │ ├── transmy5_0.cpython-36.pyc │ ├── transmy_j.cpython-36.pyc │ ├── transmy_j.cpython-39.pyc │ └── transmy_jx4.cpython-39.pyc └── common.py ├── option.py ├── test_summary.py ├── trainer.py └── utility.py /README.md: -------------------------------------------------------------------------------- 1 | ## CFIN-Pytorch 2 | This repository is an official PyTorch implementation of our paper "Cross-receptive Focused Inference Network for Lightweight Image Super-Resolution". 3 | 4 | --- 5 | 6 | ### CFIN: Cross-receptive Focused Inference Network for Lightweight Image Super-Resolution. (IEEE TRANSACTIONS ON MULTIMEDIA, 2023) 7 | 8 | > [[Paper(arxiv)](https://arxiv.org/abs/2207.02796)]   [[Paper(IEEE)](https://ieeexplore.ieee.org/document/10114600)]   [[Code](https://github.com/24wenjie-li/CFIN)]   9 | 10 |
13 | 14 | --- 15 | 16 | ## Prerequisites: 17 | ``` 18 | 1. Python >= 3.6 19 | 2. PyTorch >= 1.2 20 | 3. numpy 21 | 4. skimage 22 | 5. imageio 23 | 6. tqdm 24 | 7. timm 25 | 8. einops 26 | ``` 27 | 28 | ## Dataset 29 | We used only the DIV2K dataset to train our model. To speed up data reading during training, we converted the images in the dataset from png to npy format. 30 | Please download the DIV2K_decoded dataset in npy format from here.[Quark Netdisk] 31 | 32 | The test set contains five datasets: Set5, Set14, B100, Urban100, and Manga109. The benchmark can be downloaded from here.[Baidu Netdisk][Password:8888] 33 | 34 | Extract the files and place them in the location specified by args.data_dir in option.py. 35 | 36 | The code and datasets need to satisfy the following structure: 37 | ``` 38 | ├── CFIN # Train / Test Code 39 | ├── dataset # all datasets for this code 40 | | └── DIV2K_decoded # train datasets with npy format 41 | | | └── DIV2K_train_HR 42 | | | └── DIV2K_train_LR_bicubic 43 | | └── benchmark # test datasets with png format 44 | | | └── Set5 45 | | | └── Set14 46 | | | └── B100 47 | | | └── Urban100 48 | | | └── Manga109 49 | ───────────────── 50 | ``` 51 | 52 | 53 | ## Results 54 | All our SR results can be downloaded from here.[Baidu Netdisk][Password:8888] 55 | 56 | All pretrained models can be found in experiments/CFIN/model/. 57 | 58 | ## Training 59 | Note: You need to manually import the name of the model to be trained/tested in line 37 of model/__init__.py. 60 | ``` 61 | # CFIN x2 62 | python main.py --scale 2 --model CFINx2 --patch_size 96 --save experiments/CFINx2 63 | 64 | # CFIN x3 65 | python main.py --scale 3 --model CFINx3 --patch_size 144 --save experiments/CFINx3 66 | 67 | # CFIN x4 68 | python main.py --scale 4 --model CFINx4 --patch_size 192 --save experiments/CFINx4 69 | ``` 70 | **Things you need to know:** 71 | Since the training/testing code we provide is the initial code, the module names in the code are not consistent with those in the paper. For this reason, we provide model/CFIN.py to help readers better understand the paper. 72 | 73 | ## Testing 74 | Note: Since the PSNR/SSIM values in our paper are obtained with a Matlab program, the values obtained with the Python code may deviate by about 0.01 dB in PSNR. The following PSNR/SSIM values are evaluated on Matlab R2017a, and the evaluation code can be found here. (You need to modify the test path!) 
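If you just want a quick check in Python before running the Matlab scripts, the snippet below shows one common way SR papers compute the Y-channel PSNR. This is only a minimal sketch (the helpers `rgb_to_y` and `psnr_y` are illustrative and not part of this repository), and its numbers may still differ slightly from the Matlab results.
```
import numpy as np

def rgb_to_y(img):
    # ITU-R BT.601 luma for uint8 RGB images in [0, 255]
    img = img.astype(np.float64)
    return 16.0 + (65.481 * img[..., 0] + 128.553 * img[..., 1] + 24.966 * img[..., 2]) / 255.0

def psnr_y(sr, hr, scale):
    # crop a `scale`-pixel border, then compute PSNR on the Y channel only
    sr_y = rgb_to_y(sr)[scale:-scale, scale:-scale]
    hr_y = rgb_to_y(hr)[scale:-scale, scale:-scale]
    mse = np.mean((sr_y - hr_y) ** 2)
    return 10.0 * np.log10(255.0 ** 2 / mse)
```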
75 | ``` 76 | # CFIN x2 77 | python main.py --scale 2 --model CFINx2 --save test_results/CFINx2 --pre_train experiments/CFIN/model/model_best_x2.pt --test_only --save_results --data_test Set5 78 | 79 | # CFIN x3 80 | python main.py --scale 3 --model CFINx3 --save test_results/CFINx3 --pre_train experiments/CFIN/model/model_best_x3.pt --test_only --save_results --data_test Set5 81 | 82 | # CFIN x4 83 | python main.py --scale 4 --model CFINx4 --save test_results/CFINx4 --pre_train experiments/CFIN/model/model_best_x4.pt --test_only --save_results --data_test Set5 84 | 85 | # CFIN+ x2 with self-ensemble strategy 86 | python main.py --scale 2 --model CFINx2 --save test_results/CFINx2 --pre_train experiments/CFIN/model/model_best_x2.pt --test_only --save_results --chop --self_ensemble --data_test Set5 87 | 88 | # CFIN+ x3 with self-ensemble strategy 89 | python main.py --scale 3 --model CFINx3 --save test_results/CFINx3 --pre_train experiments/CFIN/model/model_best_x3.pt --test_only --save_results --chop --self_ensemble --data_test Set5 90 | 91 | # CFIN+ x4 with self-ensemble strategy 92 | python main.py --scale 4 --model CFINx4 --save test_results/CFINx4 --pre_train experiments/CFIN/model/model_best_x4.pt --test_only --save_results --chop --self_ensemble --data_test Set5 93 | ``` 94 | 95 | ## Test Params and Multi-adds 96 | Note: You need to install torchsummaryX! 97 | ``` 98 | # Default CFINx4 99 | python test_summary.py 100 | ``` 101 | 102 | ## Performance 103 | Our CFIN is trained on RGB, but as in previous works, we only report PSNR/SSIM on the Y channel. 104 | 105 | Model|Scale|Params|Multi-adds|Set5|Set14|B100|Urban100|Manga109 106 | --|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--: 107 | CFIN |x2|675K|116.9G|38.14/0.9610|33.80/0.9199|32.26/0.9006|32.48/0.9311|38.97/0.9777 108 | CFIN |x3|681K|53.5G|34.65/0.9289|30.45/0.8443|29.18/0.8071|28.49/0.8583|33.89/0.9464 109 | CFIN |x4|699K|31.2G|32.49/0.8985|28.74/0.7849|27.68/0.7396|26.39/0.7946|30.73/0.9124 110 | 111 | ## Some extra questions 112 | You can download the supplementary materials addressing the reviewers' questions from here.[Baidu Netdisk][Password:8888] 113 | 114 | ## Acknowledgements 115 | This code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) and [DRN](https://github.com/guoyongcs/DRN). We thank the authors for sharing their codes. 
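A note on the `--self_ensemble` flag used in the CFIN+ commands above: it enables the standard x8 geometric self-ensemble at test time. The sketch below only illustrates the idea (it is our own minimal example, not necessarily the exact code used in this repository): each of the eight flip/transpose variants of the input is super-resolved, the transform is undone on the output, and the results are averaged.
```
import torch

def x8_self_ensemble(model, lr):
    # x8 self-ensemble: run the model on the 8 flip/transpose variants of the
    # NCHW input, undo each transform on the output, and average the results.
    def apply_ops(x, ops):
        for op in ops:
            if op == 'v':
                x = x.flip(-1)               # horizontal flip
            elif op == 'h':
                x = x.flip(-2)               # vertical flip
            elif op == 't':
                x = x.transpose(-2, -1)      # swap H and W
        return x

    outputs = []
    for ops in ['', 'v', 'h', 'vh', 't', 'tv', 'th', 'tvh']:
        with torch.no_grad():
            sr = model(apply_ops(lr, ops))
        outputs.append(apply_ops(sr, ops[::-1]))  # each op is its own inverse
    return torch.stack(outputs).mean(dim=0)
```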
116 | 117 | 118 | ## :clipboard: Citation 119 | 120 | ``` 121 | @article{li2023cross, 122 | title={Cross-receptive focused inference network for lightweight image super-resolution}, 123 | author={Li, Wenjie and Li, Juncheng and Gao, Guangwei and Deng, Weihong and Zhou, Jiantao and Yang, Jian and Qi, Guo-Jun}, 124 | journal={IEEE Transactions on Multimedia}, 125 | year={2023}, 126 | publisher={IEEE} 127 | } 128 | ``` 129 | 130 | ## :e-mail: Contact 131 | 132 | If you have any question, please email `lewj2408@gmail.com` 133 | -------------------------------------------------------------------------------- /__pycache__/checkpoint.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/checkpoint.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/checkpoint.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/checkpoint.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/loss.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/option.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/option.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/option.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/option.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/trainer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/trainer.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/utility.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/utility.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/utility.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/utility.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/utility.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/__pycache__/utility.cpython-39.pyc -------------------------------------------------------------------------------- /checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import datetime 4 | import numpy as np 5 | import imageio 6 | import matplotlib 7 | matplotlib.use('Agg') 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | class Checkpoint(): 12 | def __init__(self, opt): 13 | self.opt = opt 14 | self.ok = True 15 | self.log = torch.Tensor() 16 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') 17 | 18 | if opt.save == '.': opt.save = '../experiment/EXP/' + now 19 | self.dir = opt.save 20 | 21 | def _make_dir(path): 22 | if not os.path.exists(path): os.makedirs(path) 23 | 24 | _make_dir(self.dir) 25 | _make_dir(self.dir + '/model') 26 | _make_dir(self.dir + '/results') 27 | 28 | open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w' 29 | self.log_file = open(self.dir + '/log.txt', open_type) 30 | with open(self.dir + '/config.txt', open_type) as f: 31 | f.write(now + '\n\n') 32 | for arg in vars(opt): 33 | f.write('{}: {}\n'.format(arg, getattr(opt, arg))) 34 | f.write('\n') 35 | 36 | def save(self, trainer, epoch, is_best=False): 37 | trainer.model.save(self.dir, is_best=is_best) 38 | trainer.loss.save(self.dir) 39 | trainer.loss.plot_loss(self.dir, epoch) 40 | 41 | self.plot_psnr(epoch) 42 | torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt')) 43 | torch.save( 44 | trainer.optimizer.state_dict(), 45 | os.path.join(self.dir, 'optimizer.pt') 46 | ) 47 | # dual_optimizers = {} 48 | # for i in range(len(trainer.dual_optimizers)): 49 | # dual_optimizers[i] = trainer.dual_optimizers[i] 50 | # torch.save( 51 | # dual_optimizers, 52 | # os.path.join(self.dir, 'dual_optimizers.pt') 53 | # ) 54 | 55 | def add_log(self, log): 56 | self.log = torch.cat([self.log, log]) 57 | 58 | def write_log(self, log, refresh=False): 59 | print(log) 60 | self.log_file.write(log + '\n') 61 | if refresh: 62 | self.log_file.close() 63 | self.log_file = open(self.dir + '/log.txt', 'a') 64 | 65 | def done(self): 66 | self.log_file.close() 67 | 68 | def plot_psnr(self, epoch): 69 | axis = np.linspace(1, epoch, epoch) 70 | label = 'SR on {}'.format(self.opt.data_test) 71 | fig = plt.figure() 72 | plt.title(label) 73 | for idx_scale, scale in enumerate([self.opt.scale[0]]): 74 | plt.plot( 75 | axis, 76 | self.log[:, idx_scale].numpy(), 77 | label='Scale {}'.format(scale) 78 | ) 79 | plt.legend() 80 | plt.xlabel('Epochs') 81 | plt.ylabel('PSNR') 82 | plt.grid(True) 83 | plt.savefig('{}/test_{}.pdf'.format(self.dir, self.opt.data_test)) 84 | plt.close(fig) 85 | 86 | def save_results_nopostfix(self, filename, sr, scale): 87 | apath = '{}/results/{}/x{}'.format(self.dir, self.opt.data_test, scale) 88 | if not os.path.exists(apath): 89 | os.makedirs(apath) 90 | filename = os.path.join(apath, filename) 91 | 92 | normalized = sr[0].data.mul(255 / self.opt.rgb_range) 93 | ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() 94 | 
imageio.imwrite('{}.png'.format(filename), ndarr) -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | from torch.utils.data import DataLoader 3 | 4 | 5 | class Data: 6 | def __init__(self, args): 7 | self.loader_train = None 8 | if not args.test_only: 9 | module_train = import_module('data.' + args.data_train.lower()) 10 | trainset = getattr(module_train, args.data_train)(args) 11 | self.loader_train = DataLoader( 12 | trainset, 13 | batch_size=args.batch_size, 14 | num_workers=args.n_threads, 15 | shuffle=True, 16 | pin_memory=not args.cpu 17 | ) 18 | 19 | if args.data_test in ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109', 'Look']: 20 | module_test = import_module('data.benchmark') 21 | testset = getattr(module_test, 'Benchmark')(args, name=args.data_test, train=False) 22 | else: 23 | module_test = import_module('data.' + args.data_test.lower()) 24 | testset = getattr(module_test, args.data_test)(args, train=False) 25 | 26 | self.loader_test = DataLoader( 27 | testset, 28 | batch_size=1, 29 | num_workers=1, 30 | shuffle=False, 31 | pin_memory=not args.cpu 32 | ) 33 | 34 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /data/__pycache__/benchmark.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/benchmark.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/benchmark.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/benchmark.cpython-39.pyc -------------------------------------------------------------------------------- /data/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/common.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/common.cpython-39.pyc -------------------------------------------------------------------------------- /data/__pycache__/df2k.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/df2k.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/div2k.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/div2k.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/div2k.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/div2k.cpython-39.pyc -------------------------------------------------------------------------------- /data/__pycache__/srdata.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/srdata.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/srdata.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/data/__pycache__/srdata.cpython-39.pyc -------------------------------------------------------------------------------- /data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | 5 | class Benchmark(srdata.SRData1): 6 | def __init__(self, args, name='', train=True, benchmark=True): 7 | super(Benchmark, self).__init__( 8 | args, name=name, train=train, benchmark=True 9 | ) 10 | 11 | def _set_filesystem(self, data_dir): 12 | 13 | self.apath = os.path.join(data_dir, 'benchmark', self.name) 14 | self.dir_hr = os.path.join(self.apath, 'HR') 15 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 16 | self.ext = ('', '.png') 17 | 18 | -------------------------------------------------------------------------------- /data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import skimage.color as sc 4 | import torch 5 | 6 | 7 | def get_patch(*args, patch_size=96, scale=[2], multi_scale=False): 8 | th, tw = args[-1].shape[:2] # target images size 9 | 10 | tp = patch_size # patch size of target hr image 11 | ip = [patch_size // s for s in scale] # patch size of lr images 12 | 13 | # tx and ty are the top and left coordinate of the patch 14 | tx = random.randrange(0, tw - tp + 1) 15 | ty = random.randrange(0, th - tp + 1) 16 | tx, ty = tx- tx % scale[0], ty - ty % scale[0] 17 | ix, iy = [ tx // s for s in scale], [ty // s for s in scale] 18 | 19 | lr = [args[0][i][iy[i]:iy[i] + ip[i], ix[i]:ix[i] + ip[i], :] for i in range(len(scale))] 20 | hr = args[-1][ty:ty + tp, tx:tx + tp, :] 21 | 22 | return [lr, hr] 23 | 24 | def set_channel(*args, n_channels=3): 25 | def _set_channel(img): 26 | if img.ndim == 2: 27 | img = 
np.expand_dims(img, axis=2) 28 | 29 | c = img.shape[2] 30 | if n_channels == 1 and c == 3: 31 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 32 | elif n_channels == 3 and c == 1: 33 | img = np.concatenate([img] * n_channels, 2) 34 | 35 | return img 36 | 37 | return [_set_channel(a) for a in args[0]], _set_channel(args[-1]) 38 | 39 | 40 | def np2Tensor(*args, rgb_range=255): 41 | def _np2Tensor(img): 42 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 43 | tensor = torch.from_numpy(np_transpose).float() 44 | tensor.mul_(rgb_range / 255) 45 | 46 | return tensor 47 | 48 | return [_np2Tensor(a) for a in args[0]], _np2Tensor(args[1]) 49 | 50 | 51 | def augment(*args, hflip=True, rot=True): 52 | hflip = hflip and random.random() < 0.5 53 | vflip = rot and random.random() < 0.5 54 | rot90 = rot and random.random() < 0.5 55 | 56 | def _augment(img): 57 | if hflip: img = img[:, ::-1, :] 58 | if vflip: img = img[::-1, :, :] 59 | if rot90: img = img.transpose(1, 0, 2) 60 | 61 | return img 62 | 63 | return [_augment(a) for a in args[0]], _augment(args[-1]) 64 | 65 | -------------------------------------------------------------------------------- /data/div2k.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | from data import srdata 5 | 6 | 7 | class DIV2K(srdata.SRData): 8 | def __init__(self, args, name='DIV2K_decoded', train=True, benchmark=False): 9 | super(DIV2K, self).__init__( 10 | args, name=name, train=train, benchmark=benchmark 11 | ) 12 | 13 | def _set_filesystem(self, data_dir): 14 | super(DIV2K, self)._set_filesystem(data_dir) 15 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 16 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') 17 | 18 | 19 | -------------------------------------------------------------------------------- /data/srdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from data import common 4 | import numpy as np 5 | import imageio 6 | import torch.utils.data as data 7 | import cv2 8 | 9 | def default_loader(path): 10 | return cv2.imread(path, cv2.IMREAD_UNCHANGED)[:, :, [2, 1, 0]] 11 | 12 | def npy_loader(path): 13 | return np.load(path, allow_pickle=True) 14 | 15 | class SRData(data.Dataset): 16 | def __init__(self, args, name='', train=True, benchmark=False): 17 | self.args = args 18 | self.name = name 19 | self.train = train 20 | self.split = 'train' if train else 'test' 21 | self.do_eval = True 22 | self.benchmark = benchmark 23 | self.scale = args.scale.copy() 24 | self.scale.reverse() 25 | self.ext1 = args.ext 26 | 27 | self._set_filesystem(args.data_dir) 28 | self._get_imgs_path(args) 29 | self._set_dataset_length() 30 | 31 | def __getitem__(self, idx): 32 | lr, hr, filename = self._load_file(idx) 33 | 34 | lr, hr = self.get_patch(lr, hr) 35 | lr, hr = common.set_channel(lr, hr, n_channels=self.args.n_colors) 36 | 37 | lr_tensor, hr_tensor = common.np2Tensor( 38 | lr, hr, rgb_range=self.args.rgb_range 39 | ) 40 | 41 | return lr_tensor, hr_tensor, filename 42 | 43 | def __len__(self): 44 | return self.dataset_length 45 | 46 | def _get_imgs_path(self, args): 47 | list_hr, list_lr = self._scan() 48 | self.images_hr, self.images_lr = list_hr, list_lr 49 | 50 | def _set_dataset_length(self): 51 | if self.train: 52 | self.dataset_length = self.args.test_every * self.args.batch_size 53 | repeat = self.dataset_length // len(self.images_hr) 54 | self.random_border = len(self.images_hr) * repeat 55 
| else: 56 | self.dataset_length = len(self.images_hr) 57 | 58 | def _scan(self): 59 | names_hr = sorted( 60 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[1])) 61 | ) 62 | names_lr = [[] for _ in self.scale] 63 | for f in names_hr: 64 | filename, _ = os.path.splitext(os.path.basename(f)) 65 | for si, s in enumerate(self.scale): 66 | names_lr[si].append(os.path.join( 67 | self.dir_lr, 'X{}/{}x{}{}'.format( 68 | s, filename, s, self.ext[1] 69 | ) 70 | )) 71 | 72 | return names_hr, names_lr 73 | 74 | def _set_filesystem(self, data_dir): 75 | self.apath = os.path.join(data_dir, self.name) 76 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 77 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') 78 | self.ext = ('.png', '.npy') 79 | 80 | def _get_index(self, idx): 81 | if self.train: 82 | if idx < self.random_border: 83 | return idx % len(self.images_hr) 84 | else: 85 | return np.random.randint(len(self.images_hr)) 86 | else: 87 | return idx 88 | 89 | def _load_file(self, idx): 90 | idx = self._get_index(idx) 91 | f_hr = self.images_hr[idx] 92 | f_lr = [self.images_lr[idx_scale][idx] for idx_scale in range(len(self.scale))] 93 | 94 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 95 | if self.ext1 == '.npy': 96 | lr = [npy_loader(f_lr[idx_scale]) for idx_scale in range(len(self.scale))] 97 | hr = npy_loader(f_hr) 98 | if not self.train: 99 | hr = default_loader(f_hr) 100 | lr = [default_loader(f_lr[idx_scale]) for idx_scale in range(len(self.scale))] 101 | return lr, hr, filename 102 | 103 | def get_patch(self, lr, hr): 104 | scale = self.scale 105 | multi_scale = len(self.scale) > 1 106 | if self.train: 107 | lr, hr = common.get_patch( 108 | lr, 109 | hr, 110 | patch_size=self.args.patch_size, 111 | scale=scale, 112 | multi_scale=multi_scale 113 | ) 114 | if not self.args.no_augment: 115 | lr, hr = common.augment(lr, hr) 116 | else: 117 | if isinstance(lr, list): 118 | ih, iw = lr[0].shape[:2] 119 | else: 120 | ih, iw = lr.shape[:2] 121 | hr = hr[0:ih * scale[0], 0:iw * scale[0]] 122 | 123 | return lr, hr 124 | 125 | ############################################################################## 126 | class SRData1(data.Dataset): 127 | def __init__(self, args, name='', train=True, benchmark=False): 128 | self.args = args 129 | self.name = name 130 | self.train = train 131 | self.split = 'train' if train else 'test' 132 | self.do_eval = True 133 | self.benchmark = benchmark 134 | self.scale = args.scale.copy() 135 | self.scale.reverse() 136 | self.ext1 = args.ext 137 | 138 | self._set_filesystem(args.data_dir) 139 | self._get_imgs_path(args) 140 | self._set_dataset_length() 141 | 142 | def __getitem__(self, idx): 143 | lr, hr, filename = self._load_file(idx) 144 | 145 | lr, hr = self.get_patch(lr, hr) 146 | lr, hr = common.set_channel(lr, hr, n_channels=self.args.n_colors) 147 | 148 | lr_tensor, hr_tensor = common.np2Tensor( 149 | lr, hr, rgb_range=self.args.rgb_range 150 | ) 151 | 152 | return lr_tensor, hr_tensor, filename 153 | 154 | def __len__(self): 155 | return self.dataset_length 156 | 157 | def _get_imgs_path(self, args): 158 | list_hr, list_lr = self._scan() 159 | self.images_hr, self.images_lr = list_hr, list_lr 160 | 161 | def _set_dataset_length(self): 162 | if self.train: 163 | self.dataset_length = self.args.test_every * self.args.batch_size 164 | repeat = self.dataset_length // len(self.images_hr) 165 | self.random_border = len(self.images_hr) * repeat 166 | else: 167 | self.dataset_length = len(self.images_hr) 168 | 169 | def _scan(self): 
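        # Pair each HR image with its LR counterparts: for an HR file 'name.png'
        # and scale s, the corresponding LR file is expected at
        # '<dir_lr>/X{s}/{name}x{s}.png', e.g. benchmark/Set5/LR_bicubic/X2/babyx2.png.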
170 | names_hr = sorted( 171 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[1])) 172 | ) 173 | 174 | names_lr = [[] for _ in self.scale] 175 | for f in names_hr: 176 | filename, _ = os.path.splitext(os.path.basename(f)) 177 | for si, s in enumerate(self.scale): 178 | names_lr[si].append(os.path.join( 179 | self.dir_lr, 'X{}/{}x{}{}'.format( 180 | s, filename, s, self.ext[1] 181 | ) 182 | )) 183 | 184 | return names_hr, names_lr 185 | 186 | def _set_filesystem(self, data_dir): 187 | self.apath = os.path.join(data_dir, 'DIV2K') 188 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 189 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') 190 | self.ext = ('.png', '.png') 191 | 192 | def _get_index(self, idx): 193 | if self.train: 194 | if idx < self.random_border: 195 | return idx % len(self.images_hr) 196 | else: 197 | return np.random.randint(len(self.images_hr)) 198 | else: 199 | return idx 200 | 201 | def _load_file(self, idx): 202 | idx = self._get_index(idx) 203 | f_hr = self.images_hr[idx] 204 | f_lr = [self.images_lr[idx_scale][idx] for idx_scale in range(len(self.scale))] 205 | 206 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 207 | hr = imageio.imread(f_hr) 208 | lr = [imageio.imread(f_lr[idx_scale]) for idx_scale in range(len(self.scale))] 209 | return lr, hr, filename 210 | 211 | def get_patch(self, lr, hr): 212 | scale = self.scale 213 | multi_scale = len(self.scale) > 1 214 | if self.train: 215 | lr, hr = common.get_patch( 216 | lr, 217 | hr, 218 | patch_size=self.args.patch_size, 219 | scale=scale, 220 | multi_scale=multi_scale 221 | ) 222 | if not self.args.no_augment: 223 | lr, hr = common.augment(lr, hr) 224 | else: 225 | if isinstance(lr, list): 226 | ih, iw = lr[0].shape[:2] 227 | else: 228 | ih, iw = lr.shape[:2] 229 | hr = hr[0:ih * scale[0], 0:iw * scale[0]] 230 | 231 | return lr, hr 232 | 233 | -------------------------------------------------------------------------------- /experiments/CFIN/model/model_best_x2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/experiments/CFIN/model/model_best_x2.pt -------------------------------------------------------------------------------- /experiments/CFIN/model/model_best_x3.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/experiments/CFIN/model/model_best_x3.pt -------------------------------------------------------------------------------- /experiments/CFIN/model/model_best_x4.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/experiments/CFIN/model/model_best_x4.pt -------------------------------------------------------------------------------- /img/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/img/compare.png -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class Loss(nn.modules.loss._Loss): 11 
| def __init__(self, args, ckp): 12 | super(Loss, self).__init__() 13 | 14 | self.loss = [] 15 | self.loss_module = nn.ModuleList() 16 | for loss in args.loss.split('+'): 17 | weight, loss_type = loss.split('*') 18 | 19 | if loss_type == 'MSE': 20 | loss_function = nn.MSELoss() 21 | elif loss_type == 'L1': 22 | loss_function = nn.L1Loss(reduction='mean') 23 | else: 24 | assert False, f"Unsupported loss type: {loss_type:s}" 25 | 26 | self.loss.append({ 27 | 'type': loss_type, 28 | 'weight': float(weight), 29 | 'function': loss_function} 30 | ) 31 | 32 | if len(self.loss) > 1: 33 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) 34 | 35 | for l in self.loss: 36 | if l['function'] is not None: 37 | print('{:.3f} * {}'.format(l['weight'], l['type'])) 38 | self.loss_module.append(l['function']) 39 | 40 | self.log = torch.Tensor() 41 | 42 | def forward(self, sr, hr): 43 | losses = [] 44 | for i, l in enumerate(self.loss): 45 | if l['function'] is not None: 46 | loss = l['function'](sr, hr) 47 | effective_loss = l['weight'] * loss 48 | losses.append(effective_loss) 49 | self.log[-1, i] += effective_loss.item() 50 | 51 | loss_sum = sum(losses) 52 | if len(self.loss) > 1: 53 | self.log[-1, -1] += loss_sum.item() 54 | 55 | return loss_sum 56 | 57 | def start_log(self): 58 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) 59 | 60 | def end_log(self, n_batches): 61 | self.log[-1].div_(n_batches) 62 | 63 | def display_loss(self, batch): 64 | n_samples = batch + 1 65 | log = [] 66 | for l, c in zip(self.loss, self.log[-1]): 67 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) 68 | 69 | return ''.join(log) 70 | 71 | def plot_loss(self, apath, epoch): 72 | axis = np.linspace(1, epoch, epoch) 73 | for i, l in enumerate(self.loss): 74 | label = '{} Loss'.format(l['type']) 75 | fig = plt.figure() 76 | plt.title(label) 77 | plt.plot(axis, self.log[:, i].numpy(), label=label) 78 | plt.legend() 79 | plt.xlabel('Epochs') 80 | plt.ylabel('Loss') 81 | plt.grid(True) 82 | plt.savefig('{}/loss_{}.pdf'.format(apath, l['type'])) 83 | plt.close(fig) 84 | 85 | def save(self, apath): 86 | torch.save(self.log, os.path.join(apath, 'loss_log.pt')) 87 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import utility 2 | import data 3 | import model 4 | import loss 5 | from option import args 6 | from checkpoint import Checkpoint 7 | from trainer import Trainer 8 | 9 | utility.set_seed(args.seed) 10 | checkpoint = Checkpoint(args) 11 | 12 | if checkpoint.ok: 13 | loader = data.Data(args) 14 | model = model.Model(args, checkpoint) 15 | loss = loss.Loss(args, checkpoint) if not args.test_only else None 16 | t = Trainer(args, loader, model, loss, checkpoint) 17 | while not t.terminate(): 18 | t.train() 19 | t.test() 20 | checkpoint.done() 21 | 22 | -------------------------------------------------------------------------------- /model/CFIN.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | from model import common 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | import torch.nn.functional as F 7 | from pdb import set_trace as stx 8 | import numbers 9 | from einops import rearrange 10 | from torch.nn.parameter import Parameter 11 | from torch.autograd import Variable 12 | #from IPython import embed 13 | 14 | def make_model(args, 
parent=False): 15 | return MODEL(args) 16 | 17 | class IGP(nn.Module): 18 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 19 | super().__init__() 20 | out_features = out_features or in_features 21 | hidden_features = hidden_features or in_features 22 | self.fc1 = CGM(in_features, hidden_features, kernel_size=1, padding=0) 23 | self.act = act_layer() 24 | self.fc2 = CGM(hidden_features, out_features, kernel_size=3, padding=1) 25 | self.drop = nn.Dropout(drop) 26 | 27 | def forward(self, x): 28 | x = self.fc1(x) 29 | x = self.act(x) 30 | x = self.drop(x) 31 | x = self.fc2(x) 32 | x = self.drop(x) 33 | return x 34 | 35 | class CGM(nn.Conv2d): 36 | def __init__(self, in_channels=64, out_channels=64, kernel_size=1, padding=0, stride=1, dilation=1, groups=1, 37 | bias=True): 38 | super(CGM, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 39 | 40 | self.weight_conv = Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.001, requires_grad=True) 41 | self.bias_conv = Parameter(torch.Tensor(out_channels)) 42 | nn.init.kaiming_normal_(self.weight_conv) 43 | 44 | self.stride = stride 45 | self.padding = padding 46 | self.dilation = dilation 47 | self.groups = groups 48 | 49 | if kernel_size == 0: 50 | self.ind = True 51 | else: 52 | self.ind = False 53 | self.oc = out_channels 54 | self.ks = kernel_size 55 | 56 | # target spatial size of the pooling layer 57 | ws = kernel_size 58 | self.avg_pool = nn.AdaptiveMaxPool2d((ws, ws)) 59 | 60 | # the dimension of latent representation 61 | self.num_lat = int((kernel_size * kernel_size) / 2 + 1) 62 | 63 | # the context encoding module 64 | self.ce = nn.Linear(ws * ws, self.num_lat, False) 65 | 66 | self.act = nn.ReLU() 67 | 68 | # the number of groups in the channel interaction module 69 | if in_channels // 8: 70 | self.g = 8 71 | else: 72 | self.g = in_channels 73 | 74 | # the channel interacting module 75 | self.ci = nn.Linear(self.g, out_channels // (in_channels // self.g), bias=False) 76 | 77 | # the gate decoding module (spatial interaction) 78 | self.gd = nn.Linear(self.num_lat, kernel_size * kernel_size, False) 79 | self.gd2 = nn.Linear(self.num_lat, kernel_size * kernel_size, False) 80 | 81 | # used to prepare the input feature map to patches 82 | self.unfold = nn.Unfold(kernel_size, dilation, padding, stride) 83 | 84 | # sigmoid function 85 | self.sig = nn.Sigmoid() 86 | 87 | def forward(self, x): 88 | if self.ind: 89 | return F.conv2d(x, self.weight_conv, self.bias_conv, self.stride, self.padding, self.dilation, self.groups) 90 | else: 91 | b, c, h, w = x.size() # x: batch x n_feat(=64) x h_patch x w_patch 92 | weight = self.weight_conv 93 | 94 | # allocate global information 95 | gl = self.avg_pool(x).view(b, c, -1) # gl: batch x n_feat x 3 x 3 -> batch x n_feat x 9 96 | 97 | # context-encoding module 98 | out = self.ce(gl) # out: batch x n_feat x 5 99 | 100 | # use different bn for following two branches 101 | ce2 = out # ce2: batch x n_feat x 5 102 | out = self.act(out) # out: batch x n_feat x 5 (just batch normalization) 103 | 104 | # gate decoding branch 1 (spatial interaction) 105 | out = self.gd(out) # out: batch x n_feat x 9 (5 --> 9 = 3x3) 106 | 107 | # channel interacting module 108 | if self.g > 3: 109 | oc = self.ci(self.act(ce2.view(b, c // self.g, self.g, -1).transpose(2, 3))).transpose(2,3).contiguous() 110 | else: 111 | oc = self.ci(self.act(ce2.transpose(2, 1))).transpose(2, 1).contiguous() 112 | oc = 
oc.view(b, self.oc, -1) 113 | oc = self.act(oc) # oc: batch x n_feat x 5 (after grouped linear layer) 114 | 115 | # gate decoding branch 2 (spatial interaction) 116 | oc = self.gd2(oc) # oc: batch x n_feat x 9 (5 --> 9 = 3x3) 117 | 118 | # produce gate (equation (4) in the CRAN paper) 119 | out = self.sig(out.view(b, 1, c, self.ks, self.ks) + oc.view(b, self.oc, 1, self.ks, self.ks)) 120 | # out: batch x out_channel x in_channel x kernel_size x kernel_size (same dimension as conv2d weight) 121 | 122 | # unfolding input feature map to patches 123 | x_un = self.unfold(x) 124 | b, _, l = x_un.size() 125 | out = (out * weight.unsqueeze(0))#.to(device) 126 | out = out.view(b, self.oc, -1) 127 | 128 | # currently only handle square input and output 129 | return torch.matmul(out, x_un).view(b, self.oc, h, w) 130 | 131 | class CGA(nn.Module): 132 | def __init__(self, dim, num_heads=8, kernel_size=5, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 133 | super().__init__() 134 | self.num_heads = num_heads 135 | head_dim = dim // num_heads 136 | self.weight = nn.Parameter(torch.randn(num_heads, dim//num_heads, dim//num_heads) * 0.001, requires_grad=True) 137 | self.to_qkv = CGM(dim, dim*3) 138 | 139 | def forward(self, x, k1=None, v1=None, return_x=False): 140 | weight = self.weight 141 | b,c,h,w = x.shape 142 | 143 | qkv = self.to_qkv(x) 144 | q, k, v = qkv.chunk(3, dim=1) 145 | 146 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 147 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 148 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 149 | 150 | if k1 is None: 151 | k = k 152 | v = v 153 | else: 154 | k = k1 + k 155 | v = v1 + v 156 | q = torch.nn.functional.normalize(q, dim=-1) 157 | k = torch.nn.functional.normalize(k, dim=-1) 158 | 159 | attn = (q @ k.transpose(-2, -1)) * weight 160 | attn = attn.softmax(dim=-1) 161 | x = (attn @ v) 162 | x = rearrange(x, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 163 | 164 | if return_x: 165 | return x 166 | else: 167 | return x, k, v 168 | 169 | 170 | class WithBias_LayerNorm(nn.Module): 171 | def __init__(self, normalized_shape): 172 | super(WithBias_LayerNorm, self).__init__() 173 | if isinstance(normalized_shape, numbers.Integral): 174 | normalized_shape = (normalized_shape,) 175 | normalized_shape = torch.Size(normalized_shape) 176 | 177 | assert len(normalized_shape) == 1 178 | 179 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 180 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 181 | self.normalized_shape = normalized_shape 182 | 183 | def forward(self, x): 184 | mu = x.mean(-1, keepdim=True) 185 | sigma = x.var(-1, keepdim=True, unbiased=False) 186 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias 187 | 188 | 189 | def to_3d(x): 190 | return rearrange(x, 'b c h w -> b (h w) c') 191 | 192 | def to_4d(x,h,w): 193 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) 194 | 195 | class LayerNorm(nn.Module): 196 | def __init__(self, dim): 197 | super(LayerNorm, self).__init__() 198 | self.body = WithBias_LayerNorm(dim) 199 | 200 | def forward(self, x): 201 | h, w = x.shape[-2:] 202 | return to_4d(self.body(to_3d(x)), h, w) 203 | 204 | 205 | class CFGT(nn.Module): 206 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 207 | drop_path=0., act_layer=nn.GELU): 208 | super().__init__() 209 | self.norm1 = LayerNorm(dim) 210 | kernel_size1 = 1 211 | 
padding1 = 0 212 | kernel_size2 = 3 213 | padding2 = 1 214 | self.attn = CGA(dim, num_heads, kernel_size1, padding1) 215 | self.attn1 = CGA(dim, num_heads, kernel_size2, padding2) 216 | 217 | self.norm2 = LayerNorm(dim) 218 | self.norm3 = LayerNorm(dim) 219 | mlp_hidden_dim = int(dim*1) 220 | self.mlp = IGP(in_features=dim, hidden_features=mlp_hidden_dim) 221 | 222 | def forward(self, x): 223 | res = x 224 | x, k1, v1 = self.attn(x) 225 | x = res + self.norm1(x) 226 | x = x + self.norm2(self.attn1(x, k1, v1, return_x=True)) 227 | x = x + self.norm3(self.mlp(x)) 228 | return x 229 | 230 | 231 | class Scale(nn.Module): 232 | def __init__(self, init_value=1e-3): 233 | super().__init__() 234 | self.scale = nn.Parameter(torch.FloatTensor([init_value])) 235 | 236 | def forward(self, input): 237 | return input * self.scale 238 | 239 | 240 | def activation(act_type, inplace=False, neg_slope=0.05, n_prelu=1): 241 | act_type = act_type.lower() 242 | if act_type == 'relu': 243 | layer = nn.ReLU() 244 | elif act_type == 'lrelu': 245 | layer = nn.LeakyReLU(neg_slope) 246 | elif act_type == 'prelu': 247 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 248 | else: 249 | raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type)) 250 | return layer 251 | 252 | 253 | class eca_layer(nn.Module): 254 | def __init__(self, channel, k_size): 255 | super(eca_layer, self).__init__() 256 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 257 | self.k_size = k_size 258 | self.conv = nn.Conv1d(channel, channel, kernel_size=k_size, bias=False, groups=channel) 259 | self.sigmoid = nn.Sigmoid() 260 | 261 | 262 | def forward(self, x): 263 | b, c, _, _ = x.size() 264 | y = self.avg_pool(x) 265 | y = nn.functional.unfold(y.transpose(-1, -3), kernel_size=(1, self.k_size), padding=(0, (self.k_size - 1) // 2)) 266 | y = self.conv(y.transpose(-1, -2)).unsqueeze(-1) 267 | y = self.sigmoid(y) 268 | x = x * y.expand_as(x) 269 | return x 270 | 271 | 272 | class MaskPredictor(nn.Module): 273 | def __init__(self,in_channels, wn=lambda x: torch.nn.utils.weight_norm(x)): 274 | super(MaskPredictor,self).__init__() 275 | self.spatial_mask=nn.Conv2d(in_channels=in_channels,out_channels=3,kernel_size=1,bias=False) 276 | 277 | def forward(self,x): 278 | spa_mask=self.spatial_mask(x) 279 | spa_mask=F.gumbel_softmax(spa_mask,tau=1,hard=True,dim=1) 280 | return spa_mask 281 | 282 | 283 | class RIFU(nn.Module): 284 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)): 285 | super(RIFU, self).__init__() 286 | self.CA = eca_layer(n_feats, k_size=3) 287 | self.MaskPredictor = MaskPredictor(n_feats*8//8) 288 | 289 | self.k = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)), 290 | nn.LeakyReLU(0.05), 291 | ) 292 | 293 | self.k1 = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)), 294 | nn.LeakyReLU(0.05), 295 | ) 296 | 297 | self.res_scale = Scale(1) 298 | self.x_scale = Scale(1) 299 | 300 | def forward(self, x): 301 | res = x 302 | x = self.k(x) 303 | 304 | MaskPredictor = self.MaskPredictor(x) 305 | mask = (MaskPredictor[:,1,...]).unsqueeze(1) 306 | x = x * (mask.expand_as(x)) 307 | 308 | x1 = self.k1(x) 309 | x2 = self.CA(x1) 310 | out = self.x_scale(x2) + self.res_scale(res) 311 | 312 | return out 313 | 314 | 315 | class CIAM(nn.Module): 316 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)): 317 | super(CIAM, self).__init__() 318 | pooling_r = 2 319 | med_feats = n_feats 
// 1 320 | self.k1 = nn.Sequential(nn.ConvTranspose2d(n_feats, n_feats*4//3, kernel_size=pooling_r, stride=pooling_r, padding=0, groups=1, bias=True), 321 | nn.LeakyReLU(0.05), 322 | nn.Conv2d(n_feats*4//3, n_feats, kernel_size=1, stride=2, padding=0, groups=1), 323 | ) 324 | 325 | self.sig = nn.Sigmoid() 326 | 327 | self.k3 = RIFU(n_feats) 328 | 329 | self.k4 = RIFU(n_feats) 330 | 331 | self.k5 = RIFU(n_feats) 332 | 333 | self.res_scale = Scale(1) 334 | self.x_scale = Scale(1) 335 | 336 | def forward(self, x): 337 | identity = x 338 | _, _, H, W = identity.shape 339 | x1_1 = self.k3(x) 340 | x1 = self.k4(x1_1) 341 | 342 | 343 | x1_s = self.sig(self.k1(x) + x) 344 | x1 = self.k5(x1_s * x1) 345 | 346 | out = self.res_scale(x1) + self.x_scale(identity) 347 | 348 | return out 349 | 350 | 351 | class FCUUp(nn.Module): 352 | def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU, 353 | norm_layer=nn.BatchNorm2d, wn=lambda x: torch.nn.utils.weight_norm(x)): 354 | super(FCUUp, self).__init__() 355 | self.up_stride = up_stride 356 | self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0) 357 | self.act = act_layer() 358 | 359 | def forward(self, x_t): 360 | x_r = self.act(self.conv_project(x_t)) 361 | 362 | return x_r 363 | 364 | class FCUDown(nn.Module): 365 | def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU, 366 | norm_layer=nn.LayerNorm, wn=lambda x: torch.nn.utils.weight_norm(x)): 367 | super(FCUDown, self).__init__() 368 | self.conv_project = wn(nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)) 369 | 370 | def forward(self, x): 371 | x = self.conv_project(x) 372 | 373 | return x 374 | 375 | 376 | class CTGroup(nn.Module): 377 | def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1, norm_layer=nn.BatchNorm2d, drop_block=None, drop_path=None): 378 | super(CTGroup, self).__init__() 379 | 380 | expansion = 1 381 | med_planes = outplanes // expansion 382 | embed_dim = 144 383 | num_heads = 8 384 | mlp_ratio = 1.0 385 | 386 | self.rb_search1 = CIAM(med_planes) 387 | self.rb_search2 = CIAM(med_planes) 388 | self.rb_search3 = CIAM(med_planes) 389 | self.rb_search4 = CIAM(med_planes) 390 | 391 | self.trans_block = CFGT( 392 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 393 | drop=0., attn_drop=0., drop_path=0.) 394 | 395 | self.trans_block1 = CFGT( 396 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 397 | drop=0., attn_drop=0., drop_path=0.) 398 | 399 | self.trans_block2 = CFGT( 400 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 401 | drop=0., attn_drop=0., drop_path=0.) 402 | 403 | self.trans_block3 = CFGT( 404 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 405 | drop=0., attn_drop=0., drop_path=0.) 406 | 407 | self.trans_block4 = CFGT( 408 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 409 | drop=0., attn_drop=0., drop_path=0.) 410 | 411 | self.trans_block5 = CFGT( 412 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 413 | drop=0., attn_drop=0., drop_path=0.) 414 | 415 | self.trans_block6 = CFGT( 416 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 417 | drop=0., attn_drop=0., drop_path=0.) 
418 | 419 | self.trans_block7 = CFGT( 420 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 421 | drop=0., attn_drop=0., drop_path=0.) 422 | 423 | self.expand_block = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 424 | self.squeeze_block = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 425 | self.expand_block1 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 426 | self.squeeze_block1 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 427 | self.expand_block2 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 428 | self.squeeze_block2 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 429 | self.expand_block3 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 430 | self.squeeze_block3 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 431 | 432 | self.res_scale = Scale(1) 433 | self.x_scale = Scale(1) 434 | self.num_rbs = 1 435 | 436 | self.res_conv = res_conv 437 | self.drop_block = drop_block 438 | self.drop_path = drop_path 439 | 440 | def zero_init_last_bn(self): 441 | nn.init.zeros_(self.bn3.weight) 442 | 443 | def forward(self, x): 444 | residual = x 445 | 446 | x = self.squeeze_block(self.trans_block(self.expand_block(self.rb_search1(x)))) + x # 1 CT block 447 | 448 | x = self.squeeze_block(self.trans_block1(self.expand_block(self.rb_search1(x)))) + x # 2 CT block 449 | 450 | x = self.squeeze_block1(self.trans_block2(self.expand_block1(self.rb_search2(x)))) + x # 3 CT block 451 | 452 | x = self.squeeze_block1(self.trans_block3(self.expand_block1(self.rb_search2(x)))) + x # 4 CT block 453 | 454 | x = self.squeeze_block2(self.trans_block4(self.expand_block2(self.rb_search3(x)))) + x # 5 CT block 455 | 456 | x = self.squeeze_block2(self.trans_block5(self.expand_block2(self.rb_search3(x)))) + x # 6 CT block 457 | 458 | x = self.squeeze_block3(self.trans_block6(self.expand_block3(self.rb_search4(x)))) + x # 7 CT block 459 | 460 | x = self.squeeze_block3(self.trans_block7(self.expand_block3(self.rb_search4(x)))) + x # 8 CT block 461 | 462 | x = self.x_scale(x) + self.res_scale(residual) 463 | 464 | return x 465 | 466 | 467 | class ConvTransBlock(nn.Module): 468 | """ 469 | Basic module for ConvTransformer, keep feature maps for CNN block and patch embeddings for transformer encoder block 470 | """ 471 | 472 | def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads, mlp_ratio, 473 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., 474 | last_fusion=False, num_med_block=0, groups=1): 475 | super(ConvTransBlock, self).__init__() 476 | expansion = 1 477 | self.cnn_block = CTGroup(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=1, groups=groups) 478 | 479 | self.dw_stride = dw_stride 480 | self.embed_dim = embed_dim 481 | self.num_med_block = num_med_block 482 | self.last_fusion = last_fusion 483 | self.res_scale = Scale(1) 484 | self.x_scale = Scale(1) 485 | 486 | def forward(self, x): 487 | x = self.cnn_block(x) 488 | 489 | return x 490 | 491 | 492 | class MODEL(nn.Module): 493 | def __init__(self, args, norm_layer=nn.LayerNorm, patch_size=1, window_size=8, num_heads=8, mlp_ratio=1., 494 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., num_med_block=0, drop_path_rate=0., 495 | patch_norm=True): 496 | super(MODEL, self).__init__() 497 | scale = args.scale 498 | n_feats = 48 499 | n_colors = 3 500 | embed_dim = 64 501 | 502 | 
self.patch_norm = patch_norm 503 | self.num_features = embed_dim 504 | rgb_mean = (0.4488, 0.4371, 0.4040) 505 | rgb_std = (1.0, 1.0, 1.0) 506 | self.sub_mean = common.MeanShift(255, rgb_mean, rgb_std) 507 | self.add_mean = common.MeanShift(255, rgb_mean, rgb_std, 1) 508 | #self.conv_first_trans = nn.Conv2d(n_colors, embed_dim, 3, 1, 1) 509 | self.conv_first_cnn = nn.Conv2d(n_colors, n_feats, 3, 1, 1) 510 | 511 | self.trans_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 8)] # stochastic depth decay rule 512 | 513 | # 2~final Stage 514 | init_stage = 2 515 | fin_stage = 3 516 | stage_1_channel = n_feats 517 | trans_dw_stride = patch_size 518 | for i in range(init_stage, fin_stage): 519 | if i%2==0: 520 | m = i 521 | else: 522 | m = i-1 523 | self.add_module('conv_trans_' + str(m), 524 | ConvTransBlock( 525 | stage_1_channel, stage_1_channel, res_conv=True, stride=1, dw_stride=trans_dw_stride, 526 | embed_dim=embed_dim, 527 | num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 528 | drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, 529 | drop_path_rate=self.trans_dpr[i - 1], 530 | num_med_block=num_med_block 531 | ) 532 | ) 533 | 534 | self.fin_stage = fin_stage 535 | self.dw_stride = trans_dw_stride 536 | 537 | self.conv_after_body = nn.Conv2d(n_feats, n_feats, 3, 1, 1) 538 | 539 | m = [] 540 | m.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1)) 541 | m.append(nn.PixelShuffle(scale[0])) 542 | self.UP1 = nn.Sequential(*m) 543 | 544 | self.conv_stright = nn.Conv2d(n_colors, n_feats, 3, 1, 1) 545 | up_body = [] 546 | up_body.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1)) 547 | up_body.append(nn.PixelShuffle(scale[0])) 548 | self.UP2 = nn.Sequential(*up_body) 549 | 550 | self.apply(self._init_weights) 551 | 552 | def _init_weights(self, m): 553 | if isinstance(m, nn.Linear): 554 | trunc_normal_(m.weight, std=.02) 555 | if isinstance(m, nn.Linear) and m.bias is not None: 556 | nn.init.constant_(m.bias, 0) 557 | elif isinstance(m, nn.LayerNorm): 558 | nn.init.constant_(m.bias, 0) 559 | nn.init.constant_(m.weight, 1.0) 560 | 561 | def forward(self, x): 562 | (H, W) = (x.shape[2], x.shape[3]) 563 | residual = x 564 | x = self.sub_mean(x) 565 | x = self.conv_first_cnn(x) 566 | 567 | for i in range(2, self.fin_stage): 568 | if i%2==0: 569 | m = i 570 | else: 571 | m = i-1 572 | x = eval('self.conv_trans_' + str(m))(x) 573 | 574 | x = self.conv_after_body(x) 575 | y1 = self.UP1(x) 576 | y2 = self.UP2(self.conv_stright(residual)) 577 | output = self.add_mean(y1 + y2) 578 | 579 | return output -------------------------------------------------------------------------------- /model/CFINx2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | from model import common 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | import torch.nn.functional as F 7 | from pdb import set_trace as stx 8 | import numbers 9 | from einops import rearrange 10 | from torch.nn.parameter import Parameter 11 | from torch.autograd import Variable 12 | #from IPython import embed 13 | 14 | def make_model(args, parent=False): 15 | return MODEL(args) 16 | 17 | class Mlp(nn.Module): 18 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 19 | super().__init__() 20 | out_features = out_features or in_features 21 | hidden_features = hidden_features or in_features 22 | self.fc1 = Conv2d_CG(in_features, 
hidden_features, kernel_size=1, padding=0) 23 | self.act = act_layer() 24 | self.fc2 = Conv2d_CG(hidden_features, out_features, kernel_size=3, padding=1) 25 | self.drop = nn.Dropout(drop) 26 | 27 | def forward(self, x): 28 | x = self.fc1(x) 29 | x = self.act(x) 30 | x = self.drop(x) 31 | x = self.fc2(x) 32 | x = self.drop(x) 33 | return x 34 | 35 | class Conv2d_CG(nn.Conv2d): 36 | def __init__(self, in_channels=64, out_channels=64, kernel_size=1, padding=0, stride=1, dilation=1, groups=1, 37 | bias=True): 38 | super(Conv2d_CG, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 39 | 40 | self.weight_conv = Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.001, requires_grad=True) 41 | self.bias_conv = Parameter(torch.Tensor(out_channels)) 42 | nn.init.kaiming_normal_(self.weight_conv) 43 | 44 | self.stride = stride 45 | self.padding = padding 46 | self.dilation = dilation 47 | self.groups = groups 48 | 49 | if kernel_size == 0: 50 | self.ind = True 51 | else: 52 | self.ind = False 53 | self.oc = out_channels 54 | self.ks = kernel_size 55 | 56 | # target spatial size of the pooling layer 57 | ws = kernel_size 58 | self.avg_pool = nn.AdaptiveMaxPool2d((ws, ws)) 59 | 60 | # the dimension of latent representation 61 | self.num_lat = int((kernel_size * kernel_size) / 2 + 1) 62 | 63 | # the context encoding module 64 | self.ce = nn.Linear(ws * ws, self.num_lat, False) 65 | 66 | self.act = nn.ReLU() 67 | 68 | # the number of groups in the channel interaction module 69 | if in_channels // 8: 70 | self.g = 8 71 | else: 72 | self.g = in_channels 73 | 74 | # the channel interacting module 75 | self.ci = nn.Linear(self.g, out_channels // (in_channels // self.g), bias=False) 76 | 77 | # the gate decoding module (spatial interaction) 78 | self.gd = nn.Linear(self.num_lat, kernel_size * kernel_size, False) 79 | self.gd2 = nn.Linear(self.num_lat, kernel_size * kernel_size, False) 80 | 81 | # used to prepare the input feature map to patches 82 | self.unfold = nn.Unfold(kernel_size, dilation, padding, stride) 83 | 84 | # sigmoid function 85 | self.sig = nn.Sigmoid() 86 | 87 | def forward(self, x): 88 | if self.ind: 89 | return F.conv2d(x, self.weight_conv, self.bias_conv, self.stride, self.padding, self.dilation, self.groups) 90 | else: 91 | b, c, h, w = x.size() # x: batch x n_feat(=64) x h_patch x w_patch 92 | weight = self.weight_conv 93 | 94 | # allocate global information 95 | gl = self.avg_pool(x).view(b, c, -1) # gl: batch x n_feat x 3 x 3 -> batch x n_feat x 9 96 | 97 | # context-encoding module 98 | out = self.ce(gl) # out: batch x n_feat x 5 99 | 100 | # use different bn for following two branches 101 | ce2 = out # ce2: batch x n_feat x 5 102 | out = self.act(out) # out: batch x n_feat x 5 (just batch normalization) 103 | 104 | # gate decoding branch 1 (spatial interaction) 105 | out = self.gd(out) # out: batch x n_feat x 9 (5 --> 9 = 3x3) 106 | 107 | # channel interacting module 108 | if self.g > 3: 109 | oc = self.ci(self.act(ce2.view(b, c // self.g, self.g, -1).transpose(2, 3))).transpose(2,3).contiguous() 110 | else: 111 | oc = self.ci(self.act(ce2.transpose(2, 1))).transpose(2, 1).contiguous() 112 | oc = oc.view(b, self.oc, -1) 113 | oc = self.act(oc) # oc: batch x n_feat x 5 (after grouped linear layer) 114 | 115 | # gate decoding branch 2 (spatial interaction) 116 | oc = self.gd2(oc) # oc: batch x n_feat x 9 (5 --> 9 = 3x3) 117 | 118 | # produce gate (equation (4) in the CRAN paper) 119 | out = self.sig(out.view(b, 1, 
c, self.ks, self.ks) + oc.view(b, self.oc, 1, self.ks, self.ks)) 120 | # out: batch x out_channel x in_channel x kernel_size x kernel_size (same dimension as conv2d weight) 121 | 122 | # unfolding input feature map to patches 123 | x_un = self.unfold(x) 124 | b, _, l = x_un.size() 125 | out = (out * weight.unsqueeze(0))#.to(device) 126 | out = out.view(b, self.oc, -1) 127 | 128 | # currently only handle square input and output 129 | return torch.matmul(out, x_un).view(b, self.oc, h, w) 130 | 131 | class ConvAttention(nn.Module): 132 | def __init__(self, dim, num_heads=8, kernel_size=5, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 133 | super().__init__() 134 | self.num_heads = num_heads 135 | head_dim = dim // num_heads 136 | self.weight = nn.Parameter(torch.randn(num_heads, dim//num_heads, dim//num_heads) * 0.001, requires_grad=True) 137 | self.to_qkv = Conv2d_CG(dim, dim*3) 138 | 139 | def forward(self, x, k1=None, v1=None, return_x=False): 140 | weight = self.weight 141 | b,c,h,w = x.shape 142 | 143 | qkv = self.to_qkv(x) 144 | q, k, v = qkv.chunk(3, dim=1) 145 | 146 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 147 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 148 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 149 | 150 | if k1 is None: 151 | k = k 152 | v = v 153 | else: 154 | k = k1 + k 155 | v = v1 + v 156 | q = torch.nn.functional.normalize(q, dim=-1) 157 | k = torch.nn.functional.normalize(k, dim=-1) 158 | 159 | attn = (q @ k.transpose(-2, -1)) * weight 160 | attn = attn.softmax(dim=-1) 161 | x = (attn @ v) 162 | x = rearrange(x, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 163 | 164 | if return_x: 165 | return x 166 | else: 167 | return x, k, v 168 | 169 | 170 | class WithBias_LayerNorm(nn.Module): 171 | def __init__(self, normalized_shape): 172 | super(WithBias_LayerNorm, self).__init__() 173 | if isinstance(normalized_shape, numbers.Integral): 174 | normalized_shape = (normalized_shape,) 175 | normalized_shape = torch.Size(normalized_shape) 176 | 177 | assert len(normalized_shape) == 1 178 | 179 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 180 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 181 | self.normalized_shape = normalized_shape 182 | 183 | def forward(self, x): 184 | mu = x.mean(-1, keepdim=True) 185 | sigma = x.var(-1, keepdim=True, unbiased=False) 186 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias 187 | 188 | 189 | def to_3d(x): 190 | return rearrange(x, 'b c h w -> b (h w) c') 191 | 192 | def to_4d(x,h,w): 193 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) 194 | 195 | class LayerNorm(nn.Module): 196 | def __init__(self, dim): 197 | super(LayerNorm, self).__init__() 198 | self.body = WithBias_LayerNorm(dim) 199 | 200 | def forward(self, x): 201 | h, w = x.shape[-2:] 202 | return to_4d(self.body(to_3d(x)), h, w) 203 | 204 | 205 | class Block(nn.Module): 206 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 207 | drop_path=0., act_layer=nn.GELU): 208 | super().__init__() 209 | self.norm1 = LayerNorm(dim) 210 | kernel_size1 = 1 211 | padding1 = 0 212 | kernel_size2 = 3 213 | padding2 = 1 214 | self.attn = ConvAttention(dim, num_heads, kernel_size1, padding1) 215 | self.attn1 = ConvAttention(dim, num_heads, kernel_size2, padding2) 216 | 217 | self.norm2 = LayerNorm(dim) 218 | self.norm3 = LayerNorm(dim) 219 | mlp_hidden_dim = int(dim*1) 220 
| self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim) 221 | 222 | def forward(self, x): 223 | res = x 224 | x, k1, v1 = self.attn(x) 225 | x = res + self.norm1(x) 226 | x = x + self.norm2(self.attn1(x, k1, v1, return_x=True)) 227 | x = x + self.norm3(self.mlp(x)) 228 | return x 229 | 230 | 231 | class Scale(nn.Module): 232 | def __init__(self, init_value=1e-3): 233 | super().__init__() 234 | self.scale = nn.Parameter(torch.FloatTensor([init_value])) 235 | 236 | def forward(self, input): 237 | return input * self.scale 238 | 239 | 240 | def activation(act_type, inplace=False, neg_slope=0.05, n_prelu=1): 241 | act_type = act_type.lower() 242 | if act_type == 'relu': 243 | layer = nn.ReLU() 244 | elif act_type == 'lrelu': 245 | layer = nn.LeakyReLU(neg_slope) 246 | elif act_type == 'prelu': 247 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 248 | else: 249 | raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type)) 250 | return layer 251 | 252 | 253 | class eca_layer(nn.Module): 254 | def __init__(self, channel, k_size): 255 | super(eca_layer, self).__init__() 256 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 257 | self.k_size = k_size 258 | self.conv = nn.Conv1d(channel, channel, kernel_size=k_size, bias=False, groups=channel) 259 | self.sigmoid = nn.Sigmoid() 260 | 261 | 262 | def forward(self, x): 263 | b, c, _, _ = x.size() 264 | y = self.avg_pool(x) 265 | y = nn.functional.unfold(y.transpose(-1, -3), kernel_size=(1, self.k_size), padding=(0, (self.k_size - 1) // 2)) 266 | y = self.conv(y.transpose(-1, -2)).unsqueeze(-1) 267 | y = self.sigmoid(y) 268 | x = x * y.expand_as(x) 269 | return x 270 | 271 | 272 | class MaskPredictor(nn.Module): 273 | def __init__(self,in_channels, wn=lambda x: torch.nn.utils.weight_norm(x)): 274 | super(MaskPredictor,self).__init__() 275 | self.spatial_mask=nn.Conv2d(in_channels=in_channels,out_channels=3,kernel_size=1,bias=False) 276 | 277 | def forward(self,x): 278 | spa_mask=self.spatial_mask(x) 279 | spa_mask=F.gumbel_softmax(spa_mask,tau=1,hard=True,dim=1) 280 | return spa_mask 281 | 282 | 283 | class RB(nn.Module): 284 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)): 285 | super(RB, self).__init__() 286 | self.CA = eca_layer(n_feats, k_size=3) 287 | self.MaskPredictor = MaskPredictor(n_feats*8//8) 288 | 289 | self.k = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)), 290 | nn.LeakyReLU(0.05), 291 | ) 292 | 293 | self.k1 = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)), 294 | nn.LeakyReLU(0.05), 295 | ) 296 | 297 | self.res_scale = Scale(1) 298 | self.x_scale = Scale(1) 299 | 300 | def forward(self, x): 301 | res = x 302 | x = self.k(x) 303 | 304 | MaskPredictor = self.MaskPredictor(x) 305 | mask = (MaskPredictor[:,1,...]).unsqueeze(1) 306 | x = x * (mask.expand_as(x)) 307 | 308 | x1 = self.k1(x) 309 | x2 = self.CA(x1) 310 | out = self.x_scale(x2) + self.res_scale(res) 311 | 312 | return out 313 | 314 | 315 | class SCConv(nn.Module): 316 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)): 317 | super(SCConv, self).__init__() 318 | pooling_r = 2 319 | med_feats = n_feats // 1 320 | self.k1 = nn.Sequential(nn.ConvTranspose2d(n_feats, n_feats*3//2, kernel_size=pooling_r, stride=pooling_r, padding=0, groups=1, bias=True), 321 | nn.LeakyReLU(0.05), 322 | nn.Conv2d(n_feats*3//2, n_feats, kernel_size=1, stride=2, padding=0, groups=1), 323 | ) 324 | 325 | self.sig = 
nn.Sigmoid() 326 | 327 | self.k3 = RB(n_feats) 328 | 329 | self.k4 = RB(n_feats) 330 | 331 | self.k5 = RB(n_feats) 332 | 333 | self.res_scale = Scale(1) 334 | self.x_scale = Scale(1) 335 | 336 | def forward(self, x): 337 | identity = x 338 | _, _, H, W = identity.shape 339 | x1_1 = self.k3(x) 340 | x1 = self.k4(x1_1) 341 | 342 | 343 | x1_s = self.sig(self.k1(x) + x) 344 | x1 = self.k5(x1_s * x1) 345 | 346 | out = self.res_scale(x1) + self.x_scale(identity) 347 | 348 | return out 349 | 350 | 351 | class FCUUp(nn.Module): 352 | def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU, 353 | norm_layer=nn.BatchNorm2d, wn=lambda x: torch.nn.utils.weight_norm(x)): 354 | super(FCUUp, self).__init__() 355 | self.up_stride = up_stride 356 | self.conv_project = wn(nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)) 357 | self.act = act_layer() 358 | 359 | def forward(self, x_t): 360 | x_r = self.act(self.conv_project(x_t)) 361 | 362 | return x_r 363 | 364 | class FCUDown(nn.Module): 365 | def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU, 366 | norm_layer=nn.LayerNorm, wn=lambda x: torch.nn.utils.weight_norm(x)): 367 | super(FCUDown, self).__init__() 368 | self.conv_project = wn(nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)) 369 | 370 | def forward(self, x): 371 | x = self.conv_project(x) 372 | 373 | return x 374 | 375 | 376 | class ConvBlock(nn.Module): 377 | def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1, norm_layer=nn.BatchNorm2d, drop_block=None, drop_path=None): 378 | super(ConvBlock, self).__init__() 379 | 380 | expansion = 1 381 | med_planes = outplanes // expansion 382 | embed_dim = 144 383 | num_heads = 8 384 | mlp_ratio = 1.0 385 | 386 | self.rb_search1 = SCConv(med_planes) 387 | self.rb_search2 = SCConv(med_planes) 388 | self.rb_search3 = SCConv(med_planes) 389 | self.rb_search4 = SCConv(med_planes) 390 | 391 | self.trans_block = Block( 392 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 393 | drop=0., attn_drop=0., drop_path=0.) 394 | 395 | self.trans_block1 = Block( 396 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 397 | drop=0., attn_drop=0., drop_path=0.) 398 | 399 | self.trans_block2 = Block( 400 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 401 | drop=0., attn_drop=0., drop_path=0.) 402 | 403 | self.trans_block3 = Block( 404 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 405 | drop=0., attn_drop=0., drop_path=0.) 406 | 407 | self.trans_block4 = Block( 408 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 409 | drop=0., attn_drop=0., drop_path=0.) 410 | 411 | self.trans_block5 = Block( 412 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 413 | drop=0., attn_drop=0., drop_path=0.) 414 | 415 | self.trans_block6 = Block( 416 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 417 | drop=0., attn_drop=0., drop_path=0.) 418 | 419 | self.trans_block7 = Block( 420 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 421 | drop=0., attn_drop=0., drop_path=0.) 
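# The eight transformer Blocks above all run at the hard-coded embed_dim = 144 with no dropout or
# drop-path; only the head count changes. With num_heads = 8 the sequence is
# 8 -> 12 -> 16 -> 12 -> 8 -> 12 -> 16 -> 12 heads (8*3//2 = 12, 8*2 = 16), and each value divides
# 144, so the per-head dimension stays an integer (18, 12, or 9). Block itself fixes
# mlp_hidden_dim = int(dim*1), so the mlp_ratio argument is effectively unused.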
422 | 423 | self.expand_block = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 424 | self.squeeze_block = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 425 | self.expand_block1 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 426 | self.squeeze_block1 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 427 | self.expand_block2 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 428 | self.squeeze_block2 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 429 | self.expand_block3 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 430 | self.squeeze_block3 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 431 | 432 | self.res_scale = Scale(1) 433 | self.x_scale = Scale(1) 434 | self.num_rbs = 1 435 | 436 | self.res_conv = res_conv 437 | self.drop_block = drop_block 438 | self.drop_path = drop_path 439 | 440 | def zero_init_last_bn(self): 441 | nn.init.zeros_(self.bn3.weight) 442 | 443 | def forward(self, x): 444 | residual = x 445 | 446 | x = self.squeeze_block(self.trans_block(self.expand_block(self.rb_search1(x)))) + x 447 | 448 | x = self.squeeze_block(self.trans_block1(self.expand_block(self.rb_search1(x)))) + x 449 | 450 | x = self.squeeze_block1(self.trans_block2(self.expand_block1(self.rb_search2(x)))) + x 451 | 452 | x = self.squeeze_block1(self.trans_block3(self.expand_block1(self.rb_search2(x)))) + x 453 | 454 | x = self.squeeze_block2(self.trans_block4(self.expand_block2(self.rb_search3(x)))) + x 455 | 456 | x = self.squeeze_block2(self.trans_block5(self.expand_block2(self.rb_search3(x)))) + x 457 | 458 | x = self.squeeze_block3(self.trans_block6(self.expand_block3(self.rb_search4(x)))) + x 459 | 460 | x = self.squeeze_block3(self.trans_block7(self.expand_block3(self.rb_search4(x)))) + x 461 | 462 | x = self.x_scale(x) + self.res_scale(residual) 463 | 464 | return x 465 | 466 | 467 | class ConvTransBlock(nn.Module): 468 | 469 | def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads, mlp_ratio, 470 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., 471 | last_fusion=False, num_med_block=0, groups=1): 472 | super(ConvTransBlock, self).__init__() 473 | expansion = 1 474 | self.cnn_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=1, groups=groups) 475 | 476 | self.dw_stride = dw_stride 477 | self.embed_dim = embed_dim 478 | self.num_med_block = num_med_block 479 | self.last_fusion = last_fusion 480 | self.res_scale = Scale(1) 481 | self.x_scale = Scale(1) 482 | 483 | def forward(self, x): 484 | x = self.cnn_block(x) 485 | 486 | return x 487 | 488 | 489 | class MODEL(nn.Module): 490 | def __init__(self, args, norm_layer=nn.LayerNorm, patch_size=1, window_size=8, num_heads=8, mlp_ratio=1., 491 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., num_med_block=0, drop_path_rate=0., 492 | patch_norm=True): 493 | super(MODEL, self).__init__() 494 | scale = args.scale 495 | n_feats = 48 496 | n_colors = 3 497 | embed_dim = 64 498 | 499 | self.patch_norm = patch_norm 500 | self.num_features = embed_dim 501 | rgb_mean = (0.4488, 0.4371, 0.4040) 502 | rgb_std = (1.0, 1.0, 1.0) 503 | self.sub_mean = common.MeanShift(255, rgb_mean, rgb_std) 504 | self.add_mean = common.MeanShift(255, rgb_mean, rgb_std, 1) 505 | #self.conv_first_trans = nn.Conv2d(n_colors, embed_dim, 3, 1, 1) 506 | self.conv_first_cnn = nn.Conv2d(n_colors, n_feats, 3, 1, 1) 507 | 508 | self.trans_dpr = 
[x.item() for x in torch.linspace(0, drop_path_rate, 8)] # stochastic depth decay rule 509 | 510 | # 2~final Stage 511 | init_stage = 2 512 | fin_stage = 3 513 | stage_1_channel = n_feats 514 | trans_dw_stride = patch_size 515 | for i in range(init_stage, fin_stage): 516 | if i%2==0: 517 | m = i 518 | else: 519 | m = i-1 520 | self.add_module('conv_trans_' + str(m), 521 | ConvTransBlock( 522 | stage_1_channel, stage_1_channel, res_conv=True, stride=1, dw_stride=trans_dw_stride, 523 | embed_dim=embed_dim, 524 | num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 525 | drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, 526 | drop_path_rate=self.trans_dpr[i - 1], 527 | num_med_block=num_med_block 528 | ) 529 | ) 530 | 531 | self.fin_stage = fin_stage 532 | self.dw_stride = trans_dw_stride 533 | 534 | self.conv_after_body = nn.Conv2d(n_feats, n_feats, 3, 1, 1) 535 | 536 | m = [] 537 | m.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1)) 538 | m.append(nn.PixelShuffle(scale[0])) 539 | self.UP1 = nn.Sequential(*m) 540 | 541 | self.conv_stright = nn.Conv2d(n_colors, n_feats, 3, 1, 1) 542 | up_body = [] 543 | up_body.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1)) 544 | up_body.append(nn.PixelShuffle(scale[0])) 545 | self.UP2 = nn.Sequential(*up_body) 546 | 547 | self.apply(self._init_weights) 548 | 549 | def _init_weights(self, m): 550 | if isinstance(m, nn.Linear): 551 | trunc_normal_(m.weight, std=.02) 552 | if isinstance(m, nn.Linear) and m.bias is not None: 553 | nn.init.constant_(m.bias, 0) 554 | elif isinstance(m, nn.LayerNorm): 555 | nn.init.constant_(m.bias, 0) 556 | nn.init.constant_(m.weight, 1.0) 557 | 558 | def forward(self, x): 559 | (H, W) = (x.shape[2], x.shape[3]) 560 | residual = x 561 | x = self.sub_mean(x) 562 | x = self.conv_first_cnn(x) 563 | 564 | for i in range(2, self.fin_stage): 565 | if i%2==0: 566 | m = i 567 | else: 568 | m = i-1 569 | x = eval('self.conv_trans_' + str(m))(x) 570 | 571 | x = self.conv_after_body(x) 572 | y1 = self.UP1(x) 573 | y2 = self.UP2(self.conv_stright(residual)) 574 | output = self.add_mean(y1 + y2) 575 | 576 | return output -------------------------------------------------------------------------------- /model/CFINx3.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | from model import common 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | import torch.nn.functional as F 7 | from pdb import set_trace as stx 8 | import numbers 9 | from einops import rearrange 10 | from torch.nn.parameter import Parameter 11 | from torch.autograd import Variable 12 | #from IPython import embed 13 | 14 | def make_model(args, parent=False): 15 | return MODEL(args) 16 | 17 | class Mlp(nn.Module): 18 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 19 | super().__init__() 20 | out_features = out_features or in_features 21 | hidden_features = hidden_features or in_features 22 | self.fc1 = Conv2d_CG(in_features, hidden_features, kernel_size=1, padding=0) 23 | self.act = act_layer() 24 | self.fc2 = Conv2d_CG(hidden_features, out_features, kernel_size=3, padding=1) 25 | self.drop = nn.Dropout(drop) 26 | 27 | def forward(self, x): 28 | x = self.fc1(x) 29 | x = self.act(x) 30 | x = self.drop(x) 31 | x = self.fc2(x) 32 | x = self.drop(x) 33 | return x 34 | 35 | class PatchEmbed(nn.Module): 36 | def __init__(self, img_size, patch_size=4, 
in_chans=32, embed_dim=32, norm_layer=None): 37 | super().__init__() 38 | img_size = to_2tuple(img_size) 39 | patch_size = to_2tuple(patch_size) 40 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 41 | self.img_size = img_size 42 | self.patch_size = patch_size 43 | self.patches_resolution = patches_resolution 44 | self.num_patches = patches_resolution[0] * patches_resolution[1] 45 | 46 | self.embed_dim = embed_dim 47 | self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size) 48 | 49 | if norm_layer is not None: ## norm_layer=None 50 | self.norm = norm_layer(embed_dim) 51 | else: 52 | self.norm = None 53 | 54 | def forward(self, x): 55 | B, C, H, W = x.shape 56 | x = self.proj(x) 57 | return x 58 | 59 | class Conv2d_CG(nn.Conv2d): 60 | def __init__(self, in_channels=64, out_channels=64, kernel_size=1, padding=0, stride=1, dilation=1, groups=1, 61 | bias=True): 62 | super(Conv2d_CG, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 63 | 64 | self.weight_conv = Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.001, requires_grad=True) 65 | self.bias_conv = Parameter(torch.Tensor(out_channels)) 66 | nn.init.kaiming_normal_(self.weight_conv) 67 | 68 | self.stride = stride 69 | self.padding = padding 70 | self.dilation = dilation 71 | self.groups = groups 72 | 73 | if kernel_size == 0: 74 | self.ind = True 75 | else: 76 | self.ind = False 77 | self.oc = out_channels 78 | self.ks = kernel_size 79 | 80 | # target spatial size of the pooling layer 81 | ws = kernel_size 82 | self.avg_pool = nn.AdaptiveMaxPool2d((ws, ws)) 83 | 84 | # the dimension of latent representation 85 | self.num_lat = int((kernel_size * kernel_size) / 2 + 1) 86 | 87 | # the context encoding module 88 | self.ce = nn.Linear(ws * ws, self.num_lat, False) 89 | 90 | self.act = nn.ReLU() 91 | 92 | # the number of groups in the channel interaction module 93 | if in_channels // 8: 94 | self.g = 8 95 | else: 96 | self.g = in_channels 97 | 98 | # the channel interacting module 99 | self.ci = nn.Linear(self.g, out_channels // (in_channels // self.g), bias=False) 100 | 101 | # the gate decoding module (spatial interaction) 102 | self.gd = nn.Linear(self.num_lat, kernel_size * kernel_size, False) 103 | self.gd2 = nn.Linear(self.num_lat, kernel_size * kernel_size, False) 104 | 105 | # used to prepare the input feature map to patches 106 | self.unfold = nn.Unfold(kernel_size, dilation, padding, stride) 107 | 108 | # sigmoid function 109 | self.sig = nn.Sigmoid() 110 | 111 | def forward(self, x): 112 | if self.ind: 113 | return F.conv2d(x, self.weight_conv, self.bias_conv, self.stride, self.padding, self.dilation, self.groups) 114 | else: 115 | b, c, h, w = x.size() # x: batch x n_feat(=64) x h_patch x w_patch 116 | weight = self.weight_conv 117 | 118 | # allocate global information 119 | gl = self.avg_pool(x).view(b, c, -1) # gl: batch x n_feat x 3 x 3 -> batch x n_feat x 9 120 | 121 | # context-encoding module 122 | out = self.ce(gl) # out: batch x n_feat x 5 123 | 124 | # use different bn for following two branches 125 | ce2 = out # ce2: batch x n_feat x 5 126 | out = self.act(out) # out: batch x n_feat x 5 (just batch normalization) 127 | 128 | # gate decoding branch 1 (spatial interaction) 129 | out = self.gd(out) # out: batch x n_feat x 9 (5 --> 9 = 3x3) 130 | 131 | # channel interacting module 132 | if self.g > 3: 133 | oc = self.ci(self.act(ce2.view(b, c // self.g, self.g, -1).transpose(2, 
3))).transpose(2,3).contiguous() 134 | else: 135 | oc = self.ci(self.act(ce2.transpose(2, 1))).transpose(2, 1).contiguous() 136 | oc = oc.view(b, self.oc, -1) 137 | oc = self.act(oc) # oc: batch x n_feat x 5 (after grouped linear layer) 138 | 139 | # gate decoding branch 2 (spatial interaction) 140 | oc = self.gd2(oc) # oc: batch x n_feat x 9 (5 --> 9 = 3x3) 141 | 142 | # produce gate (equation (4) in the CRAN paper) 143 | out = self.sig(out.view(b, 1, c, self.ks, self.ks) + oc.view(b, self.oc, 1, self.ks, self.ks)) 144 | # out: batch x out_channel x in_channel x kernel_size x kernel_size (same dimension as conv2d weight) 145 | 146 | # unfolding input feature map to patches 147 | x_un = self.unfold(x) 148 | b, _, l = x_un.size() 149 | out = (out * weight.unsqueeze(0))#.to(device) 150 | out = out.view(b, self.oc, -1) 151 | 152 | # currently only handle square input and output 153 | return torch.matmul(out, x_un).view(b, self.oc, h, w) 154 | 155 | class ConvAttention(nn.Module): 156 | def __init__(self, dim, num_heads=8, kernel_size=5, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 157 | super().__init__() 158 | self.num_heads = num_heads 159 | head_dim = dim // num_heads 160 | self.weight = nn.Parameter(torch.randn(num_heads, dim//num_heads, dim//num_heads) * 0.001, requires_grad=True) 161 | self.to_qkv = Conv2d_CG(dim, dim*3) 162 | 163 | def forward(self, x, k1=None, v1=None, return_x=False): 164 | weight = self.weight 165 | b,c,h,w = x.shape 166 | 167 | qkv = self.to_qkv(x) 168 | q, k, v = qkv.chunk(3, dim=1) 169 | 170 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 171 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 172 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 173 | 174 | if k1 is None: 175 | k = k 176 | v = v 177 | else: 178 | k = k1 + k 179 | v = v1 + v 180 | q = torch.nn.functional.normalize(q, dim=-1) 181 | k = torch.nn.functional.normalize(k, dim=-1) 182 | 183 | attn = (q @ k.transpose(-2, -1)) * weight 184 | attn = attn.softmax(dim=-1) 185 | x = (attn @ v) 186 | x = rearrange(x, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 187 | 188 | if return_x: 189 | return x 190 | else: 191 | return x, k, v 192 | 193 | 194 | class WithBias_LayerNorm(nn.Module): 195 | def __init__(self, normalized_shape): 196 | super(WithBias_LayerNorm, self).__init__() 197 | if isinstance(normalized_shape, numbers.Integral): 198 | normalized_shape = (normalized_shape,) 199 | normalized_shape = torch.Size(normalized_shape) 200 | 201 | assert len(normalized_shape) == 1 202 | 203 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 204 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 205 | self.normalized_shape = normalized_shape 206 | 207 | def forward(self, x): 208 | mu = x.mean(-1, keepdim=True) 209 | sigma = x.var(-1, keepdim=True, unbiased=False) 210 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias 211 | 212 | 213 | def to_3d(x): 214 | return rearrange(x, 'b c h w -> b (h w) c') 215 | 216 | def to_4d(x,h,w): 217 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) 218 | 219 | class LayerNorm(nn.Module): 220 | def __init__(self, dim): 221 | super(LayerNorm, self).__init__() 222 | self.body = WithBias_LayerNorm(dim) 223 | 224 | def forward(self, x): 225 | h, w = x.shape[-2:] 226 | return to_4d(self.body(to_3d(x)), h, w) 227 | 228 | 229 | class Block(nn.Module): 230 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, 
drop=0., attn_drop=0., 231 | drop_path=0., act_layer=nn.GELU): 232 | super().__init__() 233 | self.norm1 = LayerNorm(dim) 234 | kernel_size1 = 1 235 | padding1 = 0 236 | kernel_size2 = 3 237 | padding2 = 1 238 | self.attn = ConvAttention(dim, num_heads, kernel_size1, padding1) 239 | self.attn1 = ConvAttention(dim, num_heads, kernel_size2, padding2) 240 | 241 | self.norm2 = LayerNorm(dim) 242 | self.norm3 = LayerNorm(dim) 243 | mlp_hidden_dim = int(dim*1) 244 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim) 245 | 246 | def forward(self, x): 247 | res = x 248 | x, k1, v1 = self.attn(x) 249 | x = res + self.norm1(x) 250 | x = x + self.norm2(self.attn1(x, k1, v1, return_x=True)) 251 | x = x + self.norm3(self.mlp(x)) 252 | return x 253 | 254 | 255 | class Scale(nn.Module): 256 | def __init__(self, init_value=1e-3): 257 | super().__init__() 258 | self.scale = nn.Parameter(torch.FloatTensor([init_value])) 259 | 260 | def forward(self, input): 261 | return input * self.scale 262 | 263 | 264 | def activation(act_type, inplace=False, neg_slope=0.05, n_prelu=1): 265 | act_type = act_type.lower() 266 | if act_type == 'relu': 267 | layer = nn.ReLU() 268 | elif act_type == 'lrelu': 269 | layer = nn.LeakyReLU(neg_slope) 270 | elif act_type == 'prelu': 271 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 272 | else: 273 | raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type)) 274 | return layer 275 | 276 | 277 | class eca_layer(nn.Module): 278 | def __init__(self, channel, k_size): 279 | super(eca_layer, self).__init__() 280 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 281 | self.k_size = k_size 282 | self.conv = nn.Conv1d(channel, channel, kernel_size=k_size, bias=False, groups=channel) 283 | self.sigmoid = nn.Sigmoid() 284 | 285 | 286 | def forward(self, x): 287 | b, c, _, _ = x.size() 288 | y = self.avg_pool(x) 289 | y = nn.functional.unfold(y.transpose(-1, -3), kernel_size=(1, self.k_size), padding=(0, (self.k_size - 1) // 2)) 290 | y = self.conv(y.transpose(-1, -2)).unsqueeze(-1) 291 | y = self.sigmoid(y) 292 | x = x * y.expand_as(x) 293 | return x 294 | 295 | 296 | class MaskPredictor(nn.Module): 297 | def __init__(self,in_channels, wn=lambda x: torch.nn.utils.weight_norm(x)): 298 | super(MaskPredictor,self).__init__() 299 | self.spatial_mask=nn.Conv2d(in_channels=in_channels,out_channels=3,kernel_size=1,bias=False) 300 | 301 | def forward(self,x): 302 | spa_mask=self.spatial_mask(x) 303 | spa_mask=F.gumbel_softmax(spa_mask,tau=1,hard=True,dim=1) 304 | return spa_mask 305 | 306 | 307 | class RB(nn.Module): 308 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)): 309 | super(RB, self).__init__() 310 | self.CA = eca_layer(n_feats, k_size=3) 311 | self.MaskPredictor = MaskPredictor(n_feats*8//8) 312 | 313 | self.k = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)), 314 | nn.LeakyReLU(0.05), 315 | ) 316 | 317 | self.k1 = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)), 318 | nn.LeakyReLU(0.05), 319 | ) 320 | 321 | self.res_scale = Scale(1) 322 | self.x_scale = Scale(1) 323 | 324 | def forward(self, x): 325 | res = x 326 | x = self.k(x) 327 | 328 | MaskPredictor = self.MaskPredictor(x) 329 | mask = (MaskPredictor[:,1,...]).unsqueeze(1) 330 | x = x * (mask.expand_as(x)) 331 | 332 | x1 = self.k1(x) 333 | x2 = self.CA(x1) 334 | out = self.x_scale(x2) + self.res_scale(res) 335 | 336 | return out 337 | 338 | 339 | class 
SCConv(nn.Module): 340 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)): 341 | super(SCConv, self).__init__() 342 | pooling_r = 2 343 | med_feats = n_feats // 1 344 | 345 | self.k1 = nn.Sequential(nn.ConvTranspose2d(n_feats, n_feats*4//3, kernel_size=pooling_r, stride=pooling_r, padding=0, groups=1, bias=True), 346 | nn.LeakyReLU(0.05), 347 | nn.Conv2d(n_feats*4//3, n_feats, kernel_size=1, stride=2, padding=0, groups=1), 348 | ) 349 | 350 | self.sig = nn.Sigmoid() 351 | 352 | self.k3 = RB(n_feats) 353 | 354 | self.k4 = RB(n_feats) 355 | 356 | self.k5 = RB(n_feats) 357 | 358 | self.res_scale = Scale(1) 359 | self.x_scale = Scale(1) 360 | 361 | def forward(self, x): 362 | identity = x 363 | _, _, H, W = identity.shape 364 | x1_1 = self.k3(x) 365 | x1 = self.k4(x1_1) 366 | 367 | 368 | x1_s = self.sig(self.k1(x) + x) 369 | x1 = self.k5(x1_s * x1) 370 | 371 | out = self.res_scale(x1) + self.x_scale(identity) 372 | 373 | return out 374 | 375 | 376 | class FCUUp(nn.Module): 377 | def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU, 378 | norm_layer=nn.BatchNorm2d, wn=lambda x: torch.nn.utils.weight_norm(x)): 379 | super(FCUUp, self).__init__() 380 | self.up_stride = up_stride 381 | self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0) 382 | self.act = act_layer() 383 | 384 | def forward(self, x_t): 385 | x_r = self.act(self.conv_project(x_t)) 386 | 387 | return x_r 388 | 389 | class FCUDown(nn.Module): 390 | def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU, 391 | norm_layer=nn.LayerNorm, wn=lambda x: torch.nn.utils.weight_norm(x)): 392 | super(FCUDown, self).__init__() 393 | self.conv_project = wn(nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)) 394 | 395 | def forward(self, x): 396 | x = self.conv_project(x) 397 | 398 | return x 399 | 400 | 401 | class ConvBlock(nn.Module): 402 | def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1, norm_layer=nn.BatchNorm2d, drop_block=None, drop_path=None): 403 | super(ConvBlock, self).__init__() 404 | 405 | expansion = 1 406 | med_planes = outplanes // expansion 407 | embed_dim = 144 408 | num_heads = 8 409 | mlp_ratio = 1.0 410 | 411 | self.rb_search1 = SCConv(med_planes) 412 | self.rb_search2 = SCConv(med_planes) 413 | self.rb_search3 = SCConv(med_planes) 414 | self.rb_search4 = SCConv(med_planes) 415 | 416 | self.trans_block = Block( 417 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 418 | drop=0., attn_drop=0., drop_path=0.) 419 | 420 | self.trans_block1 = Block( 421 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 422 | drop=0., attn_drop=0., drop_path=0.) 423 | 424 | self.trans_block2 = Block( 425 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 426 | drop=0., attn_drop=0., drop_path=0.) 427 | 428 | self.trans_block3 = Block( 429 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 430 | drop=0., attn_drop=0., drop_path=0.) 431 | 432 | self.trans_block4 = Block( 433 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 434 | drop=0., attn_drop=0., drop_path=0.) 435 | 436 | self.trans_block5 = Block( 437 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 438 | drop=0., attn_drop=0., drop_path=0.) 
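# In this x3 variant the transformer Blocks simply alternate between num_heads = 8 and
# num_heads*3//2 = 12 heads; the two remaining Blocks defined just below keep the same alternation.
# The x2 and x4 variants additionally use 16-head (num_heads*2) Blocks in positions 2 and 6.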
439 | 440 | self.trans_block6 = Block( 441 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 442 | drop=0., attn_drop=0., drop_path=0.) 443 | 444 | self.trans_block7 = Block( 445 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 446 | drop=0., attn_drop=0., drop_path=0.) 447 | 448 | self.expand_block = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 449 | self.squeeze_block = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 450 | self.expand_block1 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 451 | self.squeeze_block1 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 452 | self.expand_block2 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 453 | self.squeeze_block2 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 454 | self.expand_block3 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 455 | self.squeeze_block3 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 456 | 457 | self.res_scale = Scale(1) 458 | self.x_scale = Scale(1) 459 | self.num_rbs = 1 460 | 461 | self.res_conv = res_conv 462 | self.drop_block = drop_block 463 | self.drop_path = drop_path 464 | 465 | def zero_init_last_bn(self): 466 | nn.init.zeros_(self.bn3.weight) 467 | 468 | def forward(self, x): 469 | residual = x 470 | 471 | x = self.squeeze_block(self.trans_block(self.expand_block(self.rb_search1(x)))) + x 472 | 473 | x = self.squeeze_block(self.trans_block1(self.expand_block(self.rb_search1(x)))) + x 474 | 475 | x = self.squeeze_block1(self.trans_block2(self.expand_block1(self.rb_search2(x)))) + x 476 | 477 | x = self.squeeze_block1(self.trans_block3(self.expand_block1(self.rb_search2(x)))) + x 478 | 479 | x = self.squeeze_block2(self.trans_block4(self.expand_block2(self.rb_search3(x)))) + x 480 | 481 | x = self.squeeze_block2(self.trans_block5(self.expand_block2(self.rb_search3(x)))) + x 482 | 483 | x = self.squeeze_block3(self.trans_block6(self.expand_block3(self.rb_search4(x)))) + x 484 | 485 | x = self.squeeze_block3(self.trans_block7(self.expand_block3(self.rb_search4(x)))) + x 486 | 487 | x = self.x_scale(x) + self.res_scale(residual) 488 | 489 | return x 490 | 491 | 492 | class ConvTransBlock(nn.Module): 493 | """ 494 | Basic module for ConvTransformer, keep feature maps for CNN block and patch embeddings for transformer encoder block 495 | """ 496 | 497 | def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads, mlp_ratio, 498 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., 499 | last_fusion=False, num_med_block=0, groups=1): 500 | super(ConvTransBlock, self).__init__() 501 | expansion = 1 502 | self.cnn_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=1, groups=groups) 503 | 504 | self.dw_stride = dw_stride 505 | self.embed_dim = embed_dim 506 | self.num_med_block = num_med_block 507 | self.last_fusion = last_fusion 508 | self.res_scale = Scale(1) 509 | self.x_scale = Scale(1) 510 | 511 | def forward(self, x): 512 | x = self.cnn_block(x) 513 | 514 | return x 515 | 516 | 517 | class MODEL(nn.Module): 518 | def __init__(self, args, norm_layer=nn.LayerNorm, patch_size=1, window_size=8, num_heads=8, mlp_ratio=1., 519 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., num_med_block=0, drop_path_rate=0., 520 | patch_norm=True): 521 | super(MODEL, self).__init__() 522 | scale = args.scale 523 | 
n_feats = 48 524 | n_colors = 3 525 | embed_dim = 64 526 | 527 | self.patch_norm = patch_norm 528 | self.num_features = embed_dim 529 | rgb_mean = (0.4488, 0.4371, 0.4040) 530 | rgb_std = (1.0, 1.0, 1.0) 531 | self.sub_mean = common.MeanShift(255, rgb_mean, rgb_std) 532 | self.add_mean = common.MeanShift(255, rgb_mean, rgb_std, 1) 533 | self.conv_first_cnn = nn.Conv2d(n_colors, n_feats, 3, 1, 1) 534 | 535 | self.trans_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 8)] # stochastic depth decay rule 536 | 537 | init_stage = 2 538 | fin_stage = 3 539 | stage_1_channel = n_feats 540 | trans_dw_stride = patch_size 541 | for i in range(init_stage, fin_stage): 542 | if i%2==0: 543 | m = i 544 | else: 545 | m = i-1 546 | self.add_module('conv_trans_' + str(m), 547 | ConvTransBlock( 548 | stage_1_channel, stage_1_channel, res_conv=True, stride=1, dw_stride=trans_dw_stride, 549 | embed_dim=embed_dim, 550 | num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 551 | drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, 552 | drop_path_rate=self.trans_dpr[i - 1], 553 | num_med_block=num_med_block 554 | ) 555 | ) 556 | 557 | self.fin_stage = fin_stage 558 | self.dw_stride = trans_dw_stride 559 | 560 | self.conv_after_body = nn.Conv2d(n_feats, n_feats, 3, 1, 1) 561 | 562 | m = [] 563 | m.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1)) 564 | m.append(nn.PixelShuffle(scale[0])) 565 | self.UP1 = nn.Sequential(*m) 566 | 567 | self.conv_stright = nn.Conv2d(n_colors, n_feats, 3, 1, 1) 568 | up_body = [] 569 | up_body.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1)) 570 | up_body.append(nn.PixelShuffle(scale[0])) 571 | self.UP2 = nn.Sequential(*up_body) 572 | 573 | self.apply(self._init_weights) 574 | 575 | def _init_weights(self, m): 576 | if isinstance(m, nn.Linear): 577 | trunc_normal_(m.weight, std=.02) 578 | if isinstance(m, nn.Linear) and m.bias is not None: 579 | nn.init.constant_(m.bias, 0) 580 | elif isinstance(m, nn.LayerNorm): 581 | nn.init.constant_(m.bias, 0) 582 | nn.init.constant_(m.weight, 1.0) 583 | 584 | def forward(self, x): 585 | (H, W) = (x.shape[2], x.shape[3]) 586 | residual = x 587 | x = self.sub_mean(x) 588 | x = self.conv_first_cnn(x) 589 | 590 | for i in range(2, self.fin_stage): 591 | if i%2==0: 592 | m = i 593 | else: 594 | m = i-1 595 | x = eval('self.conv_trans_' + str(m))(x) 596 | 597 | x = self.conv_after_body(x) 598 | y1 = self.UP1(x) 599 | y2 = self.UP2(self.conv_stright(residual)) 600 | output = self.add_mean(y1 + y2) 601 | 602 | return output -------------------------------------------------------------------------------- /model/CFINx4.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | from model import common 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | import torch.nn.functional as F 7 | from pdb import set_trace as stx 8 | import numbers 9 | from einops import rearrange 10 | from torch.nn.parameter import Parameter 11 | from torch.autograd import Variable 12 | #from IPython import embed 13 | 14 | def make_model(args, parent=False): 15 | return MODEL(args) 16 | 17 | class Mlp(nn.Module): 18 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 19 | super().__init__() 20 | out_features = out_features or in_features 21 | hidden_features = hidden_features or in_features 22 | self.fc1 = Conv2d_CG(in_features, hidden_features, 
kernel_size=1, padding=0) 23 | self.act = act_layer() 24 | self.fc2 = Conv2d_CG(hidden_features, out_features, kernel_size=3, padding=1) 25 | self.drop = nn.Dropout(drop) 26 | 27 | def forward(self, x): 28 | x = self.fc1(x) 29 | x = self.act(x) 30 | x = self.drop(x) 31 | x = self.fc2(x) 32 | x = self.drop(x) 33 | return x 34 | 35 | class Conv2d_CG(nn.Conv2d): 36 | def __init__(self, in_channels=64, out_channels=64, kernel_size=1, padding=0, stride=1, dilation=1, groups=1, 37 | bias=True): 38 | super(Conv2d_CG, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 39 | 40 | self.weight_conv = Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.001, requires_grad=True) 41 | self.bias_conv = Parameter(torch.Tensor(out_channels)) 42 | nn.init.kaiming_normal_(self.weight_conv) 43 | 44 | self.stride = stride 45 | self.padding = padding 46 | self.dilation = dilation 47 | self.groups = groups 48 | 49 | if kernel_size == 0: 50 | self.ind = True 51 | else: 52 | self.ind = False 53 | self.oc = out_channels 54 | self.ks = kernel_size 55 | 56 | # target spatial size of the pooling layer 57 | ws = kernel_size 58 | self.avg_pool = nn.AdaptiveMaxPool2d((ws, ws)) 59 | 60 | # the dimension of latent representation 61 | self.num_lat = int((kernel_size * kernel_size) / 2 + 1) 62 | 63 | # the context encoding module 64 | self.ce = nn.Linear(ws * ws, self.num_lat, False) 65 | 66 | self.act = nn.ReLU() 67 | 68 | # the number of groups in the channel interaction module 69 | if in_channels // 8: 70 | self.g = 8 71 | else: 72 | self.g = in_channels 73 | 74 | # the channel interacting module 75 | self.ci = nn.Linear(self.g, out_channels // (in_channels // self.g), bias=False) 76 | 77 | # the gate decoding module (spatial interaction) 78 | self.gd = nn.Linear(self.num_lat, kernel_size * kernel_size, False) 79 | self.gd2 = nn.Linear(self.num_lat, kernel_size * kernel_size, False) 80 | 81 | # used to prepare the input feature map to patches 82 | self.unfold = nn.Unfold(kernel_size, dilation, padding, stride) 83 | 84 | # sigmoid function 85 | self.sig = nn.Sigmoid() 86 | 87 | def forward(self, x): 88 | if self.ind: 89 | return F.conv2d(x, self.weight_conv, self.bias_conv, self.stride, self.padding, self.dilation, self.groups) 90 | else: 91 | b, c, h, w = x.size() # x: batch x n_feat(=64) x h_patch x w_patch 92 | weight = self.weight_conv 93 | 94 | # allocate global information 95 | gl = self.avg_pool(x).view(b, c, -1) # gl: batch x n_feat x 3 x 3 -> batch x n_feat x 9 96 | 97 | # context-encoding module 98 | out = self.ce(gl) # out: batch x n_feat x 5 99 | 100 | # use different bn for following two branches 101 | ce2 = out # ce2: batch x n_feat x 5 102 | out = self.act(out) # out: batch x n_feat x 5 (just batch normalization) 103 | 104 | # gate decoding branch 1 (spatial interaction) 105 | out = self.gd(out) # out: batch x n_feat x 9 (5 --> 9 = 3x3) 106 | 107 | # channel interacting module 108 | if self.g > 3: 109 | oc = self.ci(self.act(ce2.view(b, c // self.g, self.g, -1).transpose(2, 3))).transpose(2,3).contiguous() 110 | else: 111 | oc = self.ci(self.act(ce2.transpose(2, 1))).transpose(2, 1).contiguous() 112 | oc = oc.view(b, self.oc, -1) 113 | oc = self.act(oc) # oc: batch x n_feat x 5 (after grouped linear layer) 114 | 115 | # gate decoding branch 2 (spatial interaction) 116 | oc = self.gd2(oc) # oc: batch x n_feat x 9 (5 --> 9 = 3x3) 117 | 118 | # produce gate (equation (4) in the CRAN paper) 119 | out = self.sig(out.view(b, 1, c, self.ks, 
self.ks) + oc.view(b, self.oc, 1, self.ks, self.ks)) 120 | # out: batch x out_channel x in_channel x kernel_size x kernel_size (same dimension as conv2d weight) 121 | 122 | # unfolding input feature map to patches 123 | x_un = self.unfold(x) 124 | b, _, l = x_un.size() 125 | out = (out * weight.unsqueeze(0))#.to(device) 126 | out = out.view(b, self.oc, -1) 127 | 128 | # currently only handle square input and output 129 | return torch.matmul(out, x_un).view(b, self.oc, h, w) 130 | 131 | class ConvAttention(nn.Module): 132 | def __init__(self, dim, num_heads=8, kernel_size=5, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 133 | super().__init__() 134 | self.num_heads = num_heads 135 | head_dim = dim // num_heads 136 | self.weight = nn.Parameter(torch.randn(num_heads, dim//num_heads, dim//num_heads) * 0.001, requires_grad=True) 137 | self.to_qkv = Conv2d_CG(dim, dim*3) 138 | 139 | def forward(self, x, k1=None, v1=None, return_x=False): 140 | weight = self.weight 141 | b,c,h,w = x.shape 142 | 143 | qkv = self.to_qkv(x) 144 | q, k, v = qkv.chunk(3, dim=1) 145 | 146 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 147 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 148 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 149 | 150 | if k1 is None: 151 | k = k 152 | v = v 153 | else: 154 | k = k1 + k 155 | v = v1 + v 156 | q = torch.nn.functional.normalize(q, dim=-1) 157 | k = torch.nn.functional.normalize(k, dim=-1) 158 | 159 | attn = (q @ k.transpose(-2, -1)) * weight 160 | attn = attn.softmax(dim=-1) 161 | x = (attn @ v) 162 | x = rearrange(x, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 163 | 164 | if return_x: 165 | return x 166 | else: 167 | return x, k, v 168 | 169 | 170 | class WithBias_LayerNorm(nn.Module): 171 | def __init__(self, normalized_shape): 172 | super(WithBias_LayerNorm, self).__init__() 173 | if isinstance(normalized_shape, numbers.Integral): 174 | normalized_shape = (normalized_shape,) 175 | normalized_shape = torch.Size(normalized_shape) 176 | 177 | assert len(normalized_shape) == 1 178 | 179 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 180 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 181 | self.normalized_shape = normalized_shape 182 | 183 | def forward(self, x): 184 | mu = x.mean(-1, keepdim=True) 185 | sigma = x.var(-1, keepdim=True, unbiased=False) 186 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias 187 | 188 | 189 | def to_3d(x): 190 | return rearrange(x, 'b c h w -> b (h w) c') 191 | 192 | def to_4d(x,h,w): 193 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) 194 | 195 | class LayerNorm(nn.Module): 196 | def __init__(self, dim): 197 | super(LayerNorm, self).__init__() 198 | self.body = WithBias_LayerNorm(dim) 199 | 200 | def forward(self, x): 201 | h, w = x.shape[-2:] 202 | return to_4d(self.body(to_3d(x)), h, w) 203 | 204 | 205 | class Block(nn.Module): 206 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 207 | drop_path=0., act_layer=nn.GELU): 208 | super().__init__() 209 | self.norm1 = LayerNorm(dim) 210 | kernel_size1 = 1 211 | padding1 = 0 212 | kernel_size2 = 3 213 | padding2 = 1 214 | self.attn = ConvAttention(dim, num_heads, kernel_size1, padding1) 215 | self.attn1 = ConvAttention(dim, num_heads, kernel_size2, padding2) 216 | 217 | self.norm2 = LayerNorm(dim) 218 | self.norm3 = LayerNorm(dim) 219 | mlp_hidden_dim = int(dim*1) 220 | self.mlp = 
Mlp(in_features=dim, hidden_features=mlp_hidden_dim) 221 | 222 | def forward(self, x): 223 | res = x 224 | x, k1, v1 = self.attn(x) 225 | x = res + self.norm1(x) 226 | x = x + self.norm2(self.attn1(x, k1, v1, return_x=True)) 227 | x = x + self.norm3(self.mlp(x)) 228 | return x 229 | 230 | 231 | class Scale(nn.Module): 232 | def __init__(self, init_value=1e-3): 233 | super().__init__() 234 | self.scale = nn.Parameter(torch.FloatTensor([init_value])) 235 | 236 | def forward(self, input): 237 | return input * self.scale 238 | 239 | 240 | def activation(act_type, inplace=False, neg_slope=0.05, n_prelu=1): 241 | act_type = act_type.lower() 242 | if act_type == 'relu': 243 | layer = nn.ReLU() 244 | elif act_type == 'lrelu': 245 | layer = nn.LeakyReLU(neg_slope) 246 | elif act_type == 'prelu': 247 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 248 | else: 249 | raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type)) 250 | return layer 251 | 252 | 253 | class eca_layer(nn.Module): 254 | def __init__(self, channel, k_size): 255 | super(eca_layer, self).__init__() 256 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 257 | self.k_size = k_size 258 | self.conv = nn.Conv1d(channel, channel, kernel_size=k_size, bias=False, groups=channel) 259 | self.sigmoid = nn.Sigmoid() 260 | 261 | 262 | def forward(self, x): 263 | b, c, _, _ = x.size() 264 | y = self.avg_pool(x) 265 | y = nn.functional.unfold(y.transpose(-1, -3), kernel_size=(1, self.k_size), padding=(0, (self.k_size - 1) // 2)) 266 | y = self.conv(y.transpose(-1, -2)).unsqueeze(-1) 267 | y = self.sigmoid(y) 268 | x = x * y.expand_as(x) 269 | return x 270 | 271 | 272 | class MaskPredictor(nn.Module): 273 | def __init__(self,in_channels, wn=lambda x: torch.nn.utils.weight_norm(x)): 274 | super(MaskPredictor,self).__init__() 275 | self.spatial_mask=nn.Conv2d(in_channels=in_channels,out_channels=3,kernel_size=1,bias=False) 276 | 277 | def forward(self,x): 278 | spa_mask=self.spatial_mask(x) 279 | spa_mask=F.gumbel_softmax(spa_mask,tau=1,hard=True,dim=1) 280 | return spa_mask 281 | 282 | 283 | class RB(nn.Module): 284 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)): 285 | super(RB, self).__init__() 286 | self.CA = eca_layer(n_feats, k_size=3) 287 | self.MaskPredictor = MaskPredictor(n_feats*8//8) 288 | 289 | self.k = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)), 290 | nn.LeakyReLU(0.05), 291 | ) 292 | 293 | self.k1 = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)), 294 | nn.LeakyReLU(0.05), 295 | ) 296 | 297 | self.res_scale = Scale(1) 298 | self.x_scale = Scale(1) 299 | 300 | def forward(self, x): 301 | res = x 302 | x = self.k(x) 303 | 304 | MaskPredictor = self.MaskPredictor(x) 305 | mask = (MaskPredictor[:,1,...]).unsqueeze(1) 306 | x = x * (mask.expand_as(x)) 307 | 308 | x1 = self.k1(x) 309 | x2 = self.CA(x1) 310 | out = self.x_scale(x2) + self.res_scale(res) 311 | 312 | return out 313 | 314 | 315 | class SCConv(nn.Module): 316 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)): 317 | super(SCConv, self).__init__() 318 | pooling_r = 2 319 | med_feats = n_feats // 1 320 | self.k1 = nn.Sequential(nn.ConvTranspose2d(n_feats, n_feats*4//3, kernel_size=pooling_r, stride=pooling_r, padding=0, groups=1, bias=True), 321 | nn.LeakyReLU(0.05), 322 | nn.Conv2d(n_feats*4//3, n_feats, kernel_size=1, stride=2, padding=0, groups=1), 323 | ) 324 | 325 | self.sig = nn.Sigmoid() 
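# Self-calibration branch: k1 first doubles the spatial resolution with a stride-2 ConvTranspose2d
# (n_feats -> n_feats*4//3 channels), then the 1x1 conv with stride 2 restores both the channel
# count and the original resolution. In forward() below, sigmoid(k1(x) + x) gates the features
# produced by the k3 -> k4 RB chain before they pass through k5.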
326 | 327 | self.k3 = RB(n_feats) 328 | 329 | self.k4 = RB(n_feats) 330 | 331 | self.k5 = RB(n_feats) 332 | 333 | self.res_scale = Scale(1) 334 | self.x_scale = Scale(1) 335 | 336 | def forward(self, x): 337 | identity = x 338 | _, _, H, W = identity.shape 339 | x1_1 = self.k3(x) 340 | x1 = self.k4(x1_1) 341 | 342 | 343 | x1_s = self.sig(self.k1(x) + x) 344 | x1 = self.k5(x1_s * x1) 345 | 346 | out = self.res_scale(x1) + self.x_scale(identity) 347 | 348 | return out 349 | 350 | 351 | class FCUUp(nn.Module): 352 | def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU, 353 | norm_layer=nn.BatchNorm2d, wn=lambda x: torch.nn.utils.weight_norm(x)): 354 | super(FCUUp, self).__init__() 355 | self.up_stride = up_stride 356 | self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0) 357 | self.act = act_layer() 358 | 359 | def forward(self, x_t): 360 | x_r = self.act(self.conv_project(x_t)) 361 | 362 | return x_r 363 | 364 | class FCUDown(nn.Module): 365 | def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU, 366 | norm_layer=nn.LayerNorm, wn=lambda x: torch.nn.utils.weight_norm(x)): 367 | super(FCUDown, self).__init__() 368 | self.conv_project = wn(nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)) 369 | 370 | def forward(self, x): 371 | x = self.conv_project(x) 372 | 373 | return x 374 | 375 | 376 | class ConvBlock(nn.Module): 377 | def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1, norm_layer=nn.BatchNorm2d, drop_block=None, drop_path=None): 378 | super(ConvBlock, self).__init__() 379 | 380 | expansion = 1 381 | med_planes = outplanes // expansion 382 | embed_dim = 144 383 | num_heads = 8 384 | mlp_ratio = 1.0 385 | 386 | self.rb_search1 = SCConv(med_planes) 387 | self.rb_search2 = SCConv(med_planes) 388 | self.rb_search3 = SCConv(med_planes) 389 | self.rb_search4 = SCConv(med_planes) 390 | 391 | self.trans_block = Block( 392 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 393 | drop=0., attn_drop=0., drop_path=0.) 394 | 395 | self.trans_block1 = Block( 396 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 397 | drop=0., attn_drop=0., drop_path=0.) 398 | 399 | self.trans_block2 = Block( 400 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 401 | drop=0., attn_drop=0., drop_path=0.) 402 | 403 | self.trans_block3 = Block( 404 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 405 | drop=0., attn_drop=0., drop_path=0.) 406 | 407 | self.trans_block4 = Block( 408 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 409 | drop=0., attn_drop=0., drop_path=0.) 410 | 411 | self.trans_block5 = Block( 412 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 413 | drop=0., attn_drop=0., drop_path=0.) 414 | 415 | self.trans_block6 = Block( 416 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 417 | drop=0., attn_drop=0., drop_path=0.) 418 | 419 | self.trans_block7 = Block( 420 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 421 | drop=0., attn_drop=0., drop_path=0.) 
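# The CNN branch (rb_search1..4, built from SCConv/RB) keeps med_planes channels (= outplanes,
# since expansion = 1; 48 when MODEL builds this block with n_feats = 48), while the transformer
# Blocks above run at the hard-coded embed_dim = 144. The FCUUp/FCUDown 1x1 projections defined
# next expand and squeeze features between those two widths around every Block; the embed_dim = 64
# that MODEL passes into ConvTransBlock is only stored there and never reaches ConvBlock.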
422 | 423 | self.expand_block = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 424 | self.squeeze_block = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 425 | self.expand_block1 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 426 | self.squeeze_block1 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 427 | self.expand_block2 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 428 | self.squeeze_block2 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 429 | self.expand_block3 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 430 | self.squeeze_block3 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 431 | 432 | self.res_scale = Scale(1) 433 | self.x_scale = Scale(1) 434 | self.num_rbs = 1 435 | 436 | self.res_conv = res_conv 437 | self.drop_block = drop_block 438 | self.drop_path = drop_path 439 | 440 | def zero_init_last_bn(self): 441 | nn.init.zeros_(self.bn3.weight) 442 | 443 | def forward(self, x): 444 | residual = x 445 | 446 | x = self.squeeze_block(self.trans_block(self.expand_block(self.rb_search1(x)))) + x 447 | 448 | x = self.squeeze_block(self.trans_block1(self.expand_block(self.rb_search1(x)))) + x 449 | 450 | x = self.squeeze_block1(self.trans_block2(self.expand_block1(self.rb_search2(x)))) + x 451 | 452 | x = self.squeeze_block1(self.trans_block3(self.expand_block1(self.rb_search2(x)))) + x 453 | 454 | x = self.squeeze_block2(self.trans_block4(self.expand_block2(self.rb_search3(x)))) + x 455 | 456 | x = self.squeeze_block2(self.trans_block5(self.expand_block2(self.rb_search3(x)))) + x 457 | 458 | x = self.squeeze_block3(self.trans_block6(self.expand_block3(self.rb_search4(x)))) + x 459 | 460 | x = self.squeeze_block3(self.trans_block7(self.expand_block3(self.rb_search4(x)))) + x 461 | 462 | x = self.x_scale(x) + self.res_scale(residual) 463 | 464 | return x 465 | 466 | 467 | class ConvTransBlock(nn.Module): 468 | """ 469 | Basic module for ConvTransformer, keep feature maps for CNN block and patch embeddings for transformer encoder block 470 | """ 471 | 472 | def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads, mlp_ratio, 473 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., 474 | last_fusion=False, num_med_block=0, groups=1): 475 | super(ConvTransBlock, self).__init__() 476 | expansion = 1 477 | self.cnn_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=1, groups=groups) 478 | 479 | self.dw_stride = dw_stride 480 | self.embed_dim = embed_dim 481 | self.num_med_block = num_med_block 482 | self.last_fusion = last_fusion 483 | self.res_scale = Scale(1) 484 | self.x_scale = Scale(1) 485 | 486 | def forward(self, x): 487 | x = self.cnn_block(x) 488 | 489 | return x 490 | 491 | 492 | class MODEL(nn.Module): 493 | def __init__(self, args, norm_layer=nn.LayerNorm, patch_size=1, window_size=8, num_heads=8, mlp_ratio=1., 494 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., num_med_block=0, drop_path_rate=0., 495 | patch_norm=True): 496 | super(MODEL, self).__init__() 497 | scale = args.scale 498 | n_feats = 48 499 | n_colors = 3 500 | embed_dim = 64 501 | 502 | self.patch_norm = patch_norm 503 | self.num_features = embed_dim 504 | rgb_mean = (0.4488, 0.4371, 0.4040) 505 | rgb_std = (1.0, 1.0, 1.0) 506 | self.sub_mean = common.MeanShift(255, rgb_mean, rgb_std) 507 | self.add_mean = common.MeanShift(255, rgb_mean, rgb_std, 1) 508 | #self.conv_first_trans 
= nn.Conv2d(n_colors, embed_dim, 3, 1, 1) 509 | self.conv_first_cnn = nn.Conv2d(n_colors, n_feats, 3, 1, 1) 510 | 511 | self.trans_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 8)] # stochastic depth decay rule 512 | 513 | # 2~final Stage 514 | init_stage = 2 515 | fin_stage = 3 516 | stage_1_channel = n_feats 517 | trans_dw_stride = patch_size 518 | for i in range(init_stage, fin_stage): 519 | if i%2==0: 520 | m = i 521 | else: 522 | m = i-1 523 | self.add_module('conv_trans_' + str(m), 524 | ConvTransBlock( 525 | stage_1_channel, stage_1_channel, res_conv=True, stride=1, dw_stride=trans_dw_stride, 526 | embed_dim=embed_dim, 527 | num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 528 | drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, 529 | drop_path_rate=self.trans_dpr[i - 1], 530 | num_med_block=num_med_block 531 | ) 532 | ) 533 | 534 | self.fin_stage = fin_stage 535 | self.dw_stride = trans_dw_stride 536 | 537 | self.conv_after_body = nn.Conv2d(n_feats, n_feats, 3, 1, 1) 538 | 539 | m = [] 540 | m.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1)) 541 | m.append(nn.PixelShuffle(scale[0])) 542 | self.UP1 = nn.Sequential(*m) 543 | 544 | self.conv_stright = nn.Conv2d(n_colors, n_feats, 3, 1, 1) 545 | up_body = [] 546 | up_body.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1)) 547 | up_body.append(nn.PixelShuffle(scale[0])) 548 | self.UP2 = nn.Sequential(*up_body) 549 | 550 | self.apply(self._init_weights) 551 | 552 | def _init_weights(self, m): 553 | if isinstance(m, nn.Linear): 554 | trunc_normal_(m.weight, std=.02) 555 | if isinstance(m, nn.Linear) and m.bias is not None: 556 | nn.init.constant_(m.bias, 0) 557 | elif isinstance(m, nn.LayerNorm): 558 | nn.init.constant_(m.bias, 0) 559 | nn.init.constant_(m.weight, 1.0) 560 | 561 | def forward(self, x): 562 | (H, W) = (x.shape[2], x.shape[3]) 563 | residual = x 564 | x = self.sub_mean(x) 565 | x = self.conv_first_cnn(x) 566 | 567 | for i in range(2, self.fin_stage): 568 | if i%2==0: 569 | m = i 570 | else: 571 | m = i-1 572 | x = eval('self.conv_trans_' + str(m))(x) 573 | 574 | x = self.conv_after_body(x) 575 | y1 = self.UP1(x) 576 | y2 = self.UP2(self.conv_stright(residual)) 577 | output = self.add_mean(y1 + y2) 578 | 579 | return output -------------------------------------------------------------------------------- /model/MultiAdd.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | from model import common 5 | from timm.models.layers import trunc_normal_ 6 | import torch.nn.functional as F 7 | from pdb import set_trace as stx 8 | import numbers 9 | from einops import rearrange 10 | from torch.nn.parameter import Parameter 11 | from torch.autograd import Variable 12 | from IPython import embed 13 | import cv2 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | class Mlp(nn.Module): 19 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 20 | super().__init__() 21 | out_features = out_features or in_features 22 | hidden_features = hidden_features or in_features 23 | self.fc1 = Conv2d_CG(in_features, hidden_features, kernel_size=1, padding=0) 24 | self.act = act_layer() 25 | self.fc2 = Conv2d_CG(hidden_features, out_features, kernel_size=3, padding=1) 26 | self.drop = nn.Dropout(drop) 27 | 28 | def forward(self, x): 29 | x = self.fc1(x) 30 | x = self.act(x) 31 | x = self.drop(x) 32 | x = self.fc2(x) 
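# fc2 is a 3x3 Conv2d_CG (defined further down in this file), so this feed-forward block mixes
# spatial context as well as channels before the final dropout.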
33 | x = self.drop(x) 34 | return x 35 | 36 | class PatchEmbed(nn.Module): 37 | def __init__(self, img_size, patch_size=4, in_chans=32, embed_dim=32, norm_layer=None): 38 | super().__init__() 39 | img_size = to_2tuple(img_size) 40 | patch_size = to_2tuple(patch_size) 41 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 42 | self.img_size = img_size 43 | self.patch_size = patch_size 44 | self.patches_resolution = patches_resolution 45 | self.num_patches = patches_resolution[0] * patches_resolution[1] 46 | 47 | self.embed_dim = embed_dim 48 | self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size) 49 | 50 | if norm_layer is not None: ## norm_layer=None 51 | self.norm = norm_layer(embed_dim) 52 | else: 53 | self.norm = None 54 | 55 | def forward(self, x): 56 | B, C, H, W = x.shape 57 | x = self.proj(x) 58 | return x 59 | 60 | class Conv2d_CG(nn.Conv2d): 61 | def __init__(self, in_channels=64, out_channels=64, kernel_size=1, padding=0, stride=1, dilation=1, groups=1, 62 | bias=True): 63 | super(Conv2d_CG, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 64 | 65 | # weight & bias for content-gated-convolution 66 | self.weight_conv = Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.001, requires_grad=True) 67 | #Parameter(torch.zeros(out_channels, in_channels, kernel_size, kernel_size), 68 | # requires_grad=True).cuda() 69 | self.bias_conv = Parameter(torch.Tensor(out_channels))#Parameter(torch.zeros(out_channels), requires_grad=True) 70 | nn.init.kaiming_normal_(self.weight_conv) 71 | 72 | self.stride = stride 73 | self.padding = padding 74 | self.dilation = dilation 75 | self.groups = groups 76 | 77 | # for convolutional layers with a kernel size of 1, just use traditional convolution 78 | if kernel_size == 0: 79 | self.ind = True 80 | else: 81 | self.ind = False 82 | self.oc = out_channels 83 | self.ks = kernel_size 84 | 85 | # target spatial size of the pooling layer 86 | ws = kernel_size 87 | self.avg_pool = nn.AdaptiveMaxPool2d((ws, ws)) 88 | 89 | # the dimension of latent representation 90 | self.num_lat = int((kernel_size * kernel_size) / 2 + 1) 91 | 92 | # the context encoding module 93 | self.ce = nn.Linear(ws * ws, self.num_lat, False) 94 | #self.ce_bn = nn.BatchNorm1d(in_channels) 95 | #self.ci_bn2 = nn.BatchNorm1d(in_channels) 96 | 97 | self.act = nn.ReLU() 98 | 99 | # the number of groups in the channel interaction module 100 | if in_channels // 8: 101 | self.g = 8 102 | else: 103 | self.g = in_channels 104 | 105 | # the channel interacting module 106 | self.ci = nn.Linear(self.g, out_channels // (in_channels // self.g), bias=False) 107 | #self.ci_bn = nn.BatchNorm1d(out_channels) 108 | 109 | # the gate decoding module (spatial interaction) 110 | self.gd = nn.Linear(self.num_lat, kernel_size * kernel_size, False) 111 | self.gd2 = nn.Linear(self.num_lat, kernel_size * kernel_size, False) 112 | 113 | # used to prepare the input feature map to patches 114 | self.unfold = nn.Unfold(kernel_size, dilation, padding, stride) 115 | 116 | # sigmoid function 117 | self.sig = nn.Sigmoid() 118 | 119 | def forward(self, x): 120 | # for convolutional layers with a kernel size of 1, just use the traditional convolution 121 | if self.ind: 122 | return F.conv2d(x, self.weight_conv, self.bias_conv, self.stride, self.padding, self.dilation, self.groups) 123 | else: 124 | b, c, h, w = x.size() # x: batch x n_feat(=64) x h_patch x w_patch 125 | weight = self.weight_conv 
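            # Content-gated convolution: pool the input down to a kernel_size x kernel_size grid,
            # encode it into a compact latent (self.ce), and decode two gates from that latent --
            # a spatial gate (self.gd) and a channel-interaction gate (self.ci followed by self.gd2).
            # Their sigmoid-fused sum modulates the static kernel `weight`, and the gated kernel is
            # applied to the unfolded input patches with a batched matrix multiplication.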
126 | 127 | # allocate global information 128 | gl = self.avg_pool(x).view(b, c, -1) # gl: batch x n_feat x 3 x 3 -> batch x n_feat x 9 129 | 130 | # context-encoding module 131 | out = self.ce(gl) # out: batch x n_feat x 5 132 | 133 | # use different bn for following two branches 134 | ce2 = out # ce2: batch x n_feat x 5 135 | #out = self.ce_bn(out) 136 | out = self.act(out) # out: batch x n_feat x 5 (just batch normalization) 137 | 138 | # gate decoding branch 1 (spatial interaction) 139 | out = self.gd(out) # out: batch x n_feat x 9 (5 --> 9 = 3x3) 140 | 141 | # channel interacting module 142 | if self.g > 3: 143 | oc = self.ci(self.act(ce2.view(b, c // self.g, self.g, -1).transpose(2, 3))).transpose(2,3).contiguous() 144 | else: 145 | oc = self.ci(self.act(ce2.transpose(2, 1))).transpose(2, 1).contiguous() 146 | oc = oc.view(b, self.oc, -1) 147 | #oc = self.ci_bn(oc) 148 | oc = self.act(oc) # oc: batch x n_feat x 5 (after grouped linear layer) 149 | 150 | # gate decoding branch 2 (spatial interaction) 151 | oc = self.gd2(oc) # oc: batch x n_feat x 9 (5 --> 9 = 3x3) 152 | 153 | # produce gate (equation (4) in the CRAN paper) 154 | out = self.sig(out.view(b, 1, c, self.ks, self.ks) + oc.view(b, self.oc, 1, self.ks, self.ks)) 155 | # out: batch x out_channel x in_channel x kernel_size x kernel_size (same dimension as conv2d weight) 156 | 157 | # unfolding input feature map to patches 158 | x_un = self.unfold(x) 159 | b, _, l = x_un.size() 160 | #embed() 161 | # gating 162 | out = (out * weight.unsqueeze(0))#.to(device) 163 | out = out.view(b, self.oc, -1) 164 | 165 | # currently only handle square input and output 166 | return torch.matmul(out, x_un).view(b, self.oc, h, w) 167 | 168 | class ConvAttention(nn.Module): 169 | def __init__(self, dim, num_heads=8, kernel_size=5, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 170 | super().__init__() 171 | self.num_heads = num_heads 172 | head_dim = dim // num_heads 173 | self.weight = nn.Parameter(torch.randn(num_heads, dim//num_heads, dim//num_heads) * 0.001, requires_grad=True) 174 | self.to_qkv = Conv2d_CG(dim, dim*3) 175 | 176 | def forward(self, x, k1=None, v1=None, return_x=False): 177 | weight = self.weight 178 | b,c,h,w = x.shape 179 | 180 | qkv = self.to_qkv(x) 181 | q, k, v = qkv.chunk(3, dim=1) 182 | 183 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 184 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 185 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 186 | 187 | if k1 is None: 188 | k = k 189 | v = v 190 | else: 191 | k = k1 + k 192 | v = v1 + v 193 | q = torch.nn.functional.normalize(q, dim=-1) 194 | k = torch.nn.functional.normalize(k, dim=-1) 195 | 196 | attn = (q @ k.transpose(-2, -1)) * weight 197 | attn = attn.softmax(dim=-1) 198 | x = (attn @ v) 199 | x = rearrange(x, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 200 | 201 | if return_x: 202 | return x 203 | else: 204 | return x, k, v 205 | 206 | 207 | class WithBias_LayerNorm(nn.Module): 208 | def __init__(self, normalized_shape): 209 | super(WithBias_LayerNorm, self).__init__() 210 | if isinstance(normalized_shape, numbers.Integral): 211 | normalized_shape = (normalized_shape,) 212 | normalized_shape = torch.Size(normalized_shape) 213 | 214 | assert len(normalized_shape) == 1 215 | 216 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 217 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 218 | self.normalized_shape = normalized_shape 219 | 
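    # Channel-wise LayerNorm with learnable affine parameters, computed over the last dimension:
    # y = (x - mean(x)) / sqrt(var(x) + 1e-5) * weight + bias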
220 | def forward(self, x): 221 | mu = x.mean(-1, keepdim=True) 222 | sigma = x.var(-1, keepdim=True, unbiased=False) 223 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias 224 | 225 | 226 | def to_3d(x): 227 | return rearrange(x, 'b c h w -> b (h w) c') 228 | 229 | def to_4d(x,h,w): 230 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) 231 | 232 | class LayerNorm(nn.Module): 233 | def __init__(self, dim): 234 | super(LayerNorm, self).__init__() 235 | self.body = WithBias_LayerNorm(dim) 236 | 237 | def forward(self, x): 238 | h, w = x.shape[-2:] 239 | return to_4d(self.body(to_3d(x)), h, w) 240 | 241 | 242 | class Block(nn.Module): 243 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 244 | drop_path=0., act_layer=nn.GELU): 245 | super().__init__() 246 | self.norm1 = LayerNorm(dim) 247 | kernel_size1 = 1 248 | padding1 = 0 249 | kernel_size2 = 3 250 | padding2 = 1 251 | #self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 252 | self.attn = ConvAttention(dim, num_heads, kernel_size1, padding1) 253 | self.attn1 = ConvAttention(dim, num_heads, kernel_size2, padding2) 254 | 255 | self.norm2 = LayerNorm(dim) 256 | self.norm3 = LayerNorm(dim) 257 | mlp_hidden_dim = int(dim*1) 258 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim) 259 | 260 | def forward(self, x): 261 | res = x 262 | x, k1, v1 = self.attn(x) 263 | x = res + self.norm1(x) 264 | x = x + self.norm2(self.attn1(x, k1, v1, return_x=True)) 265 | x = x + self.norm3(self.mlp(x)) 266 | return x 267 | 268 | 269 | class Scale(nn.Module): 270 | def __init__(self, init_value=1e-3): 271 | super().__init__() 272 | self.scale = nn.Parameter(torch.FloatTensor([init_value])) 273 | 274 | def forward(self, input): 275 | return input * self.scale 276 | 277 | 278 | def activation(act_type, inplace=False, neg_slope=0.05, n_prelu=1): 279 | act_type = act_type.lower() 280 | if act_type == 'relu': 281 | layer = nn.ReLU() 282 | elif act_type == 'lrelu': 283 | layer = nn.LeakyReLU(neg_slope) 284 | elif act_type == 'prelu': 285 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 286 | else: 287 | raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type)) 288 | return layer 289 | 290 | 291 | class eca_layer(nn.Module): 292 | def __init__(self, channel, k_size): 293 | super(eca_layer, self).__init__() 294 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 295 | self.k_size = k_size 296 | self.conv = nn.Conv1d(channel, channel, kernel_size=k_size, bias=False, groups=channel) 297 | self.sigmoid = nn.Sigmoid() 298 | 299 | 300 | def forward(self, x): 301 | b, c, _, _ = x.size() 302 | y = self.avg_pool(x) 303 | y = nn.functional.unfold(y.transpose(-1, -3), kernel_size=(1, self.k_size), padding=(0, (self.k_size - 1) // 2)) 304 | y = self.conv(y.transpose(-1, -2)).unsqueeze(-1) 305 | y = self.sigmoid(y) 306 | x = x * y.expand_as(x) 307 | return x 308 | 309 | 310 | class MaskPredictor(nn.Module): 311 | def __init__(self,in_channels, wn=lambda x: torch.nn.utils.weight_norm(x)): 312 | super(MaskPredictor,self).__init__() 313 | self.spatial_mask=nn.Conv2d(in_channels=in_channels,out_channels=3,kernel_size=1,bias=False) 314 | 315 | def forward(self,x): 316 | spa_mask=self.spatial_mask(x) 317 | spa_mask=F.gumbel_softmax(spa_mask,tau=1,hard=True,dim=1) 318 | return spa_mask 319 | 320 | 321 | class RB(nn.Module): 322 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)): 323 
| super(RB, self).__init__() 324 | self.CA = eca_layer(n_feats, k_size=3) 325 | self.MaskPredictor = MaskPredictor(n_feats*8//8) 326 | 327 | self.k = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)), 328 | nn.LeakyReLU(0.05), 329 | ) 330 | 331 | self.k1 = nn.Sequential(wn(nn.Conv2d(n_feats*8//8, n_feats*8//8, kernel_size=3, padding=1, stride=1, groups=1)), 332 | nn.LeakyReLU(0.05), 333 | ) 334 | 335 | self.res_scale = Scale(1) 336 | self.x_scale = Scale(1) 337 | 338 | def forward(self, x): 339 | res = x 340 | x = self.k(x) 341 | 342 | MaskPredictor = self.MaskPredictor(x) 343 | mask = (MaskPredictor[:,1,...]).unsqueeze(1) 344 | x = x * (mask.expand_as(x)) 345 | 346 | x1 = self.k1(x) 347 | x2 = self.CA(x1) 348 | out = self.x_scale(x2) + self.res_scale(res) 349 | 350 | return out 351 | 352 | 353 | class SCConv(nn.Module): 354 | def __init__(self, n_feats, wn=lambda x: torch.nn.utils.weight_norm(x)): 355 | super(SCConv, self).__init__() 356 | pooling_r = 2 357 | med_feats = n_feats // 1 358 | #self.k = nn.Conv2d(n_feats, n_feats, kernel_size=1, stride=2, padding=0, groups=1) 359 | 360 | self.k1 = nn.Sequential(nn.ConvTranspose2d(n_feats, n_feats*4//3, kernel_size=pooling_r, stride=pooling_r, padding=0, groups=1, bias=True), 361 | nn.LeakyReLU(0.05), 362 | nn.Conv2d(n_feats*4//3, n_feats, kernel_size=1, stride=2, padding=0, groups=1), 363 | ) 364 | 365 | self.sig = nn.Sigmoid() 366 | 367 | self.k3 = RB(n_feats) 368 | 369 | self.k4 = RB(n_feats) 370 | 371 | self.k5 = RB(n_feats) 372 | 373 | self.res_scale = Scale(1) 374 | self.x_scale = Scale(1) 375 | 376 | def forward(self, x): 377 | identity = x 378 | _, _, H, W = identity.shape 379 | x1_1 = self.k3(x) 380 | x1 = self.k4(x1_1) 381 | 382 | 383 | x1_s = self.sig(self.k1(x) + x) 384 | x1 = self.k5(x1_s * x1) 385 | 386 | out = self.res_scale(x1) + self.x_scale(identity) 387 | 388 | return out 389 | 390 | 391 | class FCUUp(nn.Module): 392 | def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU, 393 | norm_layer=nn.BatchNorm2d, wn=lambda x: torch.nn.utils.weight_norm(x)): 394 | super(FCUUp, self).__init__() 395 | self.up_stride = up_stride 396 | self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0) 397 | self.act = act_layer() 398 | 399 | def forward(self, x_t): 400 | x_r = self.act(self.conv_project(x_t)) 401 | 402 | return x_r 403 | 404 | class FCUDown(nn.Module): 405 | def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU, 406 | norm_layer=nn.LayerNorm, wn=lambda x: torch.nn.utils.weight_norm(x)): 407 | super(FCUDown, self).__init__() 408 | self.conv_project = wn(nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)) 409 | 410 | def forward(self, x): 411 | x = self.conv_project(x) 412 | 413 | return x 414 | 415 | 416 | class ConvBlock(nn.Module): 417 | def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1, norm_layer=nn.BatchNorm2d, drop_block=None, drop_path=None): 418 | super(ConvBlock, self).__init__() 419 | 420 | expansion = 1 421 | med_planes = outplanes // expansion 422 | embed_dim = 144 423 | num_heads = 8 424 | mlp_ratio = 1.0 425 | 426 | self.rb_search1 = SCConv(med_planes) 427 | self.rb_search2 = SCConv(med_planes) 428 | self.rb_search3 = SCConv(med_planes) 429 | self.rb_search4 = SCConv(med_planes) 430 | 431 | self.trans_block = Block( 432 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 433 | drop=0., attn_drop=0., 
drop_path=0.) 434 | 435 | self.trans_block1 = Block( 436 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 437 | drop=0., attn_drop=0., drop_path=0.) 438 | 439 | self.trans_block2 = Block( 440 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 441 | drop=0., attn_drop=0., drop_path=0.) 442 | 443 | self.trans_block3 = Block( 444 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 445 | drop=0., attn_drop=0., drop_path=0.) 446 | 447 | self.trans_block4 = Block( 448 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 449 | drop=0., attn_drop=0., drop_path=0.) 450 | 451 | self.trans_block5 = Block( 452 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 453 | drop=0., attn_drop=0., drop_path=0.) 454 | 455 | self.trans_block6 = Block( 456 | dim=embed_dim, num_heads=num_heads*2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 457 | drop=0., attn_drop=0., drop_path=0.) 458 | 459 | self.trans_block7 = Block( 460 | dim=embed_dim, num_heads=num_heads*3//2, mlp_ratio=mlp_ratio, qkv_bias=False, qk_scale=None, 461 | drop=0., attn_drop=0., drop_path=0.) 462 | 463 | self.expand_block = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 464 | self.squeeze_block = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 465 | self.expand_block1 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 466 | self.squeeze_block1 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 467 | self.expand_block2 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 468 | self.squeeze_block2 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 469 | self.expand_block3 = FCUUp(inplanes=med_planes, outplanes=embed_dim, up_stride=1) 470 | self.squeeze_block3 = FCUDown(inplanes=embed_dim, outplanes=med_planes, dw_stride=1) 471 | 472 | self.res_scale = Scale(1) 473 | self.x_scale = Scale(1) 474 | self.num_rbs = 1 475 | 476 | self.res_conv = res_conv 477 | self.drop_block = drop_block 478 | self.drop_path = drop_path 479 | 480 | def zero_init_last_bn(self): 481 | nn.init.zeros_(self.bn3.weight) 482 | 483 | def forward(self, x): 484 | residual = x 485 | 486 | x = self.rb_search1(x) 487 | x = self.squeeze_block(self.trans_block(self.expand_block(x))) + x 488 | 489 | x = self.rb_search1(x) 490 | x = self.squeeze_block(self.trans_block1(self.expand_block(x))) + x 491 | 492 | x = self.rb_search2(x) 493 | x = self.squeeze_block1(self.trans_block2(self.expand_block1(x))) + x 494 | 495 | x = self.rb_search2(x) 496 | x = self.squeeze_block1(self.trans_block3(self.expand_block1(x))) + x 497 | 498 | x = self.rb_search3(x) 499 | x = self.squeeze_block2(self.trans_block4(self.expand_block2(x))) + x 500 | 501 | x = self.rb_search3(x) 502 | x = self.squeeze_block2(self.trans_block5(self.expand_block2(x))) + x 503 | 504 | x = self.rb_search4(x) 505 | x = self.squeeze_block3(self.trans_block6(self.expand_block3(x))) + x 506 | 507 | x = self.rb_search4(x) 508 | x = self.squeeze_block3(self.trans_block7(self.expand_block3(x))) + x 509 | 510 | x = self.x_scale(x) + self.res_scale(residual) 511 | 512 | return x 513 | 514 | 515 | class ConvTransBlock(nn.Module): 516 | """ 517 | Basic module for ConvTransformer, keep feature maps for CNN block and patch embeddings for transformer encoder block 518 | """ 519 | 520 | def __init__(self, inplanes, outplanes, res_conv, stride, 
dw_stride, embed_dim, num_heads, mlp_ratio, 521 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., 522 | last_fusion=False, num_med_block=0, groups=1): 523 | super(ConvTransBlock, self).__init__() 524 | expansion = 1 525 | self.cnn_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=1, groups=groups) 526 | 527 | #self.fusion_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=1, groups=groups) 528 | 529 | self.dw_stride = dw_stride 530 | self.embed_dim = embed_dim 531 | self.num_med_block = num_med_block 532 | self.last_fusion = last_fusion 533 | self.res_scale = Scale(1) 534 | self.x_scale = Scale(1) 535 | 536 | def forward(self, x): 537 | x = self.cnn_block(x) 538 | 539 | #x = self.fusion_block(x) 540 | 541 | return x 542 | 543 | 544 | class MODEL(nn.Module): 545 | def __init__(self, norm_layer=nn.LayerNorm, patch_size=1, window_size=8, num_heads=8, mlp_ratio=1., 546 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., num_med_block=0, drop_path_rate=0., 547 | patch_norm=True): 548 | super(MODEL, self).__init__() 549 | scale = [4] 550 | n_feats = 48 551 | n_colors = 3 552 | embed_dim = 64 553 | #height = (1024 // scale // window_size + 1) * window_size 554 | #width = (720 // scale // window_size + 1) * window_size 555 | #img_size = (height, width) 556 | 557 | self.patch_norm = patch_norm 558 | self.num_features = embed_dim 559 | rgb_mean = (0.4488, 0.4371, 0.4040) 560 | rgb_std = (1.0, 1.0, 1.0) 561 | self.sub_mean = common.MeanShift(255, rgb_mean, rgb_std) 562 | self.add_mean = common.MeanShift(255, rgb_mean, rgb_std, 1) 563 | #self.conv_first_trans = nn.Conv2d(n_colors, embed_dim, 3, 1, 1) 564 | self.conv_first_cnn = nn.Conv2d(n_colors, n_feats, 3, 1, 1) 565 | 566 | self.trans_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 8)] # stochastic depth decay rule 567 | 568 | # 2~final Stage 569 | init_stage = 2 570 | fin_stage = 3 571 | stage_1_channel = n_feats 572 | trans_dw_stride = patch_size 573 | for i in range(init_stage, fin_stage): 574 | if i%2==0: 575 | m = i 576 | else: 577 | m = i-1 578 | self.add_module('conv_trans_' + str(m), 579 | ConvTransBlock( 580 | stage_1_channel, stage_1_channel, res_conv=True, stride=1, dw_stride=trans_dw_stride, 581 | embed_dim=embed_dim, 582 | num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 583 | drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, 584 | drop_path_rate=self.trans_dpr[i - 1], 585 | num_med_block=num_med_block 586 | ) 587 | ) 588 | 589 | self.fin_stage = fin_stage 590 | self.dw_stride = trans_dw_stride 591 | 592 | self.conv_after_body = nn.Conv2d(n_feats, n_feats, 3, 1, 1) 593 | 594 | m = [] 595 | m.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1)) 596 | m.append(nn.PixelShuffle(scale[0])) 597 | self.UP1 = nn.Sequential(*m) 598 | 599 | self.conv_stright = nn.Conv2d(n_colors, n_feats, 3, 1, 1) 600 | up_body = [] 601 | up_body.append(nn.Conv2d(n_feats, (scale[0] ** 2) * n_colors, 3, 1, 1)) 602 | up_body.append(nn.PixelShuffle(scale[0])) 603 | self.UP2 = nn.Sequential(*up_body) 604 | 605 | self.apply(self._init_weights) 606 | 607 | def _init_weights(self, m): 608 | if isinstance(m, nn.Linear): 609 | trunc_normal_(m.weight, std=.02) 610 | if isinstance(m, nn.Linear) and m.bias is not None: 611 | nn.init.constant_(m.bias, 0) 612 | elif isinstance(m, nn.LayerNorm): 613 | nn.init.constant_(m.bias, 0) 614 | nn.init.constant_(m.weight, 1.0) 615 | 616 | def forward(self, x): 617 | 
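        # Forward pass: subtract the dataset mean, lift to n_feats channels with a 3x3 conv,
        # run the ConvTransBlock stage(s), refine with conv_after_body, then fuse two
        # PixelShuffle upsampling branches -- UP1 on the deep features and UP2 on a shallow
        # projection of the input -- before adding the mean back.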
(H, W) = (x.shape[2], x.shape[3]) 618 | residual = x 619 | x = self.sub_mean(x) 620 | x = self.conv_first_cnn(x) 621 | #x_t = x 622 | 623 | for i in range(2, self.fin_stage): 624 | if i%2==0: 625 | m = i 626 | else: 627 | m = i-1 628 | x = eval('self.conv_trans_' + str(m))(x) 629 | 630 | x = self.conv_after_body(x) 631 | y1 = self.UP1(x) 632 | y2 = self.UP2(self.conv_stright(residual)) 633 | output = self.add_mean(y1 + y2) 634 | 635 | return output -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import model.CFINx2 6 | import model.CFINx3 7 | import model.CFINx4 8 | 9 | def dataparallel(model, gpu_list): 10 | ngpus = len(gpu_list) 11 | assert ngpus != 0, "only support gpu mode" 12 | assert torch.cuda.device_count() >= ngpus, "Invalid Number of GPUs" 13 | assert isinstance(model, list), "Invalid Type of Dual model" 14 | for i in range(len(model)): 15 | if ngpus >= 2: 16 | model[i] = nn.DataParallel(model[i], gpu_list).cuda() 17 | else: 18 | model[i] = model[i].cuda() 19 | return model 20 | 21 | 22 | class Model(nn.Module): 23 | def __init__(self, opt, ckp): 24 | super(Model, self).__init__() 25 | print('Making model...') 26 | self.opt = opt 27 | self.scale = opt.scale 28 | self.idx_scale = 0 29 | self.self_ensemble = opt.self_ensemble 30 | self.cpu = opt.cpu 31 | self.device = torch.device('cpu' if opt.cpu else 'cuda') 32 | self.n_GPUs = opt.n_GPUs 33 | 34 | self.chop = opt.chop 35 | self.precision = opt.precision 36 | 37 | self.model = CFINx4.make_model(opt).to(self.device) 38 | 39 | if not opt.cpu and opt.n_GPUs > 1: 40 | self.model = nn.DataParallel(self.model, range(opt.n_GPUs)) 41 | 42 | self.load(opt.pre_train,cpu=opt.cpu) 43 | 44 | # compute parameter 45 | # num_parameter = self.count_parameters(self.model) 46 | # ckp.write_log(f"The number of parameters is {num_parameter / 1000 ** 2:.2f}M") 47 | # ckp.write_log(f"The number of parameters is {num_parameter:.2f}") 48 | 49 | def forward(self, x, idx_scale=0): 50 | self.idx_scale = idx_scale 51 | target = self.get_model() 52 | if hasattr(target, 'set_scale'): 53 | target.set_scale(idx_scale) 54 | 55 | if self.self_ensemble and not self.training: 56 | if self.chop: 57 | forward_function = self.forward_chop 58 | else: 59 | forward_function = self.model.forward 60 | 61 | return self.forward_x8(x, forward_function) 62 | elif self.chop and not self.training: 63 | return self.forward_chop(x) 64 | else: 65 | return self.model(x) 66 | 67 | def get_model(self): 68 | if self.n_GPUs == 1: 69 | return self.model 70 | else: 71 | return self.model.module 72 | 73 | 74 | def state_dict(self, **kwargs): 75 | target = self.get_model() 76 | return target.state_dict(**kwargs) 77 | 78 | def count_parameters(self, model): 79 | if self.opt.n_GPUs > 1: 80 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 81 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 82 | 83 | def save(self, path, is_best=False): 84 | target = self.get_model() 85 | torch.save( 86 | target.state_dict(), 87 | os.path.join(path, 'model', 'model_latest.pt') 88 | ) 89 | if is_best: 90 | torch.save( 91 | target.state_dict(), 92 | os.path.join(path, 'model', 'model_best.pt') 93 | ) 94 | 95 | def load(self, pre_train='.',cpu=False): 96 | if cpu: 97 | kwargs = {'map_location': lambda storage, loc: storage} 98 | else: 99 | kwargs = {} 100 | #### 
load primal model #### 101 | if pre_train != '.': 102 | print('Loading model from {}'.format(pre_train)) 103 | self.get_model().load_state_dict( 104 | torch.load(pre_train, **kwargs), 105 | strict=False 106 | ) 107 | 108 | def forward_chop(self, x, shave=10, min_size=160000): 109 | scale = self.scale[self.idx_scale] 110 | n_GPUs = min(self.n_GPUs, 4) 111 | b, c, h, w = x.size() 112 | h_half, w_half = h // 2, w // 2 113 | h_size, w_size = h_half + shave, w_half + shave 114 | lr_list = [ 115 | x[:, :, 0:h_size, 0:w_size], 116 | x[:, :, 0:h_size, (w - w_size):w], 117 | x[:, :, (h - h_size):h, 0:w_size], 118 | x[:, :, (h - h_size):h, (w - w_size):w]] 119 | 120 | if w_size * h_size < min_size: 121 | sr_list = [] 122 | for i in range(0, 4, n_GPUs): 123 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 124 | sr_batch = self.model(lr_batch) 125 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 126 | else: 127 | sr_list = [ 128 | self.forward_chop(patch, shave=shave, min_size=min_size) \ 129 | for patch in lr_list 130 | ] 131 | 132 | h, w = scale * h, scale * w 133 | h_half, w_half = scale * h_half, scale * w_half 134 | h_size, w_size = scale * h_size, scale * w_size 135 | shave *= scale 136 | 137 | output = x.new(b, c, h, w) 138 | output[:, :, 0:h_half, 0:w_half] \ 139 | = sr_list[0][:, :, 0:h_half, 0:w_half] 140 | output[:, :, 0:h_half, w_half:w] \ 141 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 142 | output[:, :, h_half:h, 0:w_half] \ 143 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 144 | output[:, :, h_half:h, w_half:w] \ 145 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 146 | 147 | return output 148 | 149 | def forward_x8(self, x, forward_function): 150 | def _transform(v, op): 151 | if self.precision != 'single': v = v.float() 152 | 153 | v2np = v.data.cpu().numpy() 154 | if op == 'v': 155 | tfnp = v2np[:, :, :, ::-1].copy() 156 | elif op == 'h': 157 | tfnp = v2np[:, :, ::-1, :].copy() 158 | elif op == 't': 159 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 160 | 161 | ret = torch.Tensor(tfnp).to(self.device) 162 | if self.precision == 'half': ret = ret.half() 163 | 164 | return ret 165 | 166 | lr_list = [x] 167 | for tf in 'v', 'h', 't': 168 | lr_list.extend([_transform(t, tf) for t in lr_list]) 169 | 170 | sr_list = [forward_function(aug) for aug in lr_list] 171 | for i in range(len(sr_list)): 172 | if i > 3: 173 | sr_list[i] = _transform(sr_list[i], 't') 174 | if i % 4 > 1: 175 | sr_list[i] = _transform(sr_list[i], 'h') 176 | if (i % 4) % 2 == 1: 177 | sr_list[i] = _transform(sr_list[i], 'v') 178 | 179 | output_cat = torch.cat(sr_list, dim=0) 180 | output = output_cat.mean(dim=0, keepdim=True) 181 | 182 | return output -------------------------------------------------------------------------------- /model/__pycache__/CFIN.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/CFIN.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/CFINx2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/CFINx2.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/CFINx3.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/CFINx3.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/CFINx4.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/CFINx4.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/MSDNN_LW1.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/MSDNN_LW1.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/MsDNN.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/MsDNN.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/MultiAdd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/MultiAdd.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/MultiAdd.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/MultiAdd.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/common.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/common.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/msfin3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/msfin3.cpython-36.pyc 
-------------------------------------------------------------------------------- /model/__pycache__/no_transmy.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/no_transmy.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/rebuttal_updown1.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/rebuttal_updown1.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/transmy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/transmy.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/transmy3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/transmy3.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/transmy5.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/transmy5.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/transmy5_0.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/transmy5_0.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/transmy_j.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/transmy_j.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/transmy_j.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/transmy_j.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/transmy_jx4.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/__pycache__/transmy_jx4.cpython-39.pyc -------------------------------------------------------------------------------- /model/common.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVIPLab/CFIN/3a57b1c904a7cbafd6a59013cc871bdd18fcfd66/model/common.py -------------------------------------------------------------------------------- /option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import utility 3 
| import numpy as np 4 | 5 | parser = argparse.ArgumentParser(description='DRN') 6 | 7 | parser.add_argument('--n_threads', type=int, default=8, 8 | help='number of threads for data loading') 9 | parser.add_argument('--cpu', action='store_true', 10 | help='use cpu only') 11 | parser.add_argument('--n_GPUs', type=int, default=1, 12 | help='number of GPUs') 13 | parser.add_argument('--seed', type=int, default=1, 14 | help='random seed') 15 | parser.add_argument('--data_dir', type=str, default='/home2/wenjieli/datasets', 16 | help='dataset directory') 17 | 18 | parser.add_argument('--data_train', type=str, default='DIV2K', 19 | help='train dataset name') 20 | parser.add_argument('--data_test', type=str, default='Set5', 21 | help='test dataset name') 22 | parser.add_argument('--data_range', type=str, default='1-800/896-900', 23 | help='train/test data range') 24 | parser.add_argument('--scale', type=int, default=2, 25 | help='super resolution scale') 26 | parser.add_argument('--patch_size', type=int, default=192, #192,144,96 27 | help='output patch size') 28 | parser.add_argument('--rgb_range', type=int, default=255, 29 | help='maximum value of RGB') 30 | parser.add_argument('--n_colors', type=int, default=3, 31 | help='number of color channels to use') 32 | parser.add_argument('--no_augment', action='store_true', 33 | help='do not use data augmentation') 34 | parser.add_argument('--model', help='model name: DRN-S | DRN-L', required=True) 35 | parser.add_argument('--pre_train', type=str, default='.', 36 | help='pre-trained model directory') 37 | parser.add_argument('--pre_train_dual', type=str, default='.', 38 | help='pre-trained dual model directory') 39 | parser.add_argument('--n_blocks', type=int, default=30, 40 | help='number of residual blocks, 16|30|40|80') 41 | parser.add_argument('--n_feats', type=int, default=20, 42 | help='number of feature maps') 43 | 44 | parser.add_argument('--chop', action='store_true', 45 | help='enable memory-efficient forward') 46 | parser.add_argument('--precision', type=str, default='single', 47 | choices=('single', 'half'), 48 | help='FP precision for test (single | half)') 49 | 50 | parser.add_argument('--num_steps', type=int, default=10, 51 | help='number of RCAB') 52 | parser.add_argument('--negval', type=float, default=0.2, 53 | help='Negative value parameter for Leaky ReLU') 54 | parser.add_argument('--test_every', type=int, default=1000, 55 | help='do test per every N batches') 56 | parser.add_argument('--epochs', type=int, default=1000, 57 | help='number of epochs to train') 58 | parser.add_argument('--batch_size', type=int, default=16, 59 | help='input batch size for training') 60 | parser.add_argument('--self_ensemble', action='store_true', 61 | help='use self-ensemble method for test') 62 | parser.add_argument('--test_only', action='store_true', 63 | help='set this option to test the model') 64 | parser.add_argument('--lr', type=float, default=5e-4, 65 | help='learning rate') 66 | parser.add_argument('--eta_min', type=float, default=1e-7, 67 | help='eta_min lr') 68 | 69 | parser.add_argument('--beta1', type=float, default=0.9, 70 | help='ADAM beta1') 71 | parser.add_argument('--beta2', type=float, default=0.999, 72 | help='ADAM beta2') 73 | parser.add_argument('--epsilon', type=float, default=1e-8, 74 | help='ADAM epsilon for numerical stability') 75 | parser.add_argument('--weight_decay', type=float, default=0, 76 | help='weight decay') 77 | parser.add_argument('--loss', type=str, default='1*L1', 78 | help='loss function configuration, L1|MSE') 
79 | parser.add_argument('--skip_threshold', type=float, default='1e6', 80 | help='skipping batch that has large error') 81 | parser.add_argument('--dual_weight', type=float, default=0.1, 82 | help='the weight of dual loss') 83 | parser.add_argument('--save', type=str, default='./experiment/test/', 84 | help='file name to save') 85 | parser.add_argument('--print_every', type=int, default=100, 86 | help='how many batches to wait before logging training status') 87 | parser.add_argument('--save_results', action='store_true', 88 | help='save output results') 89 | 90 | parser.add_argument('--load', type=str, default='.', 91 | help='file name to load') 92 | parser.add_argument('--resume', type=int, default=0, 93 | help='resume from specific checkpoint') 94 | parser.add_argument("--ext", type=str, default='.npy') 95 | parser.add_argument("--n_train", type=int, default=800, 96 | help="number of training set") 97 | parser.add_argument("--cuda", action="store_true", default=True, 98 | help="use cuda") 99 | 100 | 101 | args = parser.parse_args() 102 | 103 | utility.init_model(args) 104 | 105 | args.scale = [args.scale] 106 | 107 | for arg in vars(args): 108 | if vars(args)[arg] == 'True': 109 | vars(args)[arg] = True 110 | elif vars(args)[arg] == 'False': 111 | vars(args)[arg] = False 112 | 113 | -------------------------------------------------------------------------------- /test_summary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchsummaryX import summary 3 | import model.MultiAdd as MultiAdd 4 | model = MultiAdd.MODEL() 5 | 6 | summary(model, torch.zeros((1, 3, 320, 180)))#HR:1280 x 720 7 | 8 | # input LR x2, HR size is 720p 9 | # summary(model, torch.zeros((1, 3, 640, 360))) 10 | 11 | # input LR x3, HR size is 720p 12 | # summary(model, torch.zeros((1, 3, 426, 240))) 13 | 14 | # input LR x4, HR size is 720p 15 | # summary(model, torch.zeros((1, 3, 320, 180))) -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utility 3 | from decimal import Decimal 4 | from tqdm import tqdm 5 | 6 | class Trainer(): 7 | def __init__(self, opt, loader, my_model, my_loss, ckp): 8 | self.opt = opt 9 | self.scale = opt.scale 10 | self.ckp = ckp 11 | self.loader_train = loader.loader_train 12 | self.loader_test = loader.loader_test 13 | self.model = my_model 14 | self.loss = my_loss 15 | self.optimizer = utility.make_optimizer(opt, self.model) 16 | self.scheduler = utility.make_scheduler(opt, self.optimizer) 17 | self.error_last = 1e8 18 | 19 | def train(self): 20 | epoch = self.scheduler.last_epoch + 1 21 | lr = self.scheduler.get_lr()[0] 22 | 23 | self.ckp.write_log( 24 | '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)) 25 | ) 26 | self.loss.start_log() 27 | self.model.train() 28 | timer_data, timer_model = utility.timer(), utility.timer() 29 | for batch, (lr, hr, _) in enumerate(self.loader_train): 30 | lr, hr = self.prepare(lr, hr)#GPU 31 | timer_data.hold() 32 | timer_model.tic() 33 | 34 | self.optimizer.zero_grad() 35 | 36 | sr = self.model(lr[0]) 37 | loss = self.loss(sr, hr) 38 | 39 | if loss.item() < self.opt.skip_threshold * self.error_last: 40 | loss.backward() 41 | self.optimizer.step() 42 | else: 43 | print('Skip this batch {}! 
(Loss: {})'.format( 44 | batch + 1, loss.item() 45 | )) 46 | 47 | timer_model.hold() 48 | 49 | if (batch + 1) % self.opt.print_every == 0: 50 | self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format( 51 | (batch + 1) * self.opt.batch_size, 52 | len(self.loader_train.dataset), 53 | self.loss.display_loss(batch), 54 | timer_model.release(), 55 | timer_data.release())) 56 | 57 | timer_data.tic() 58 | 59 | self.loss.end_log(len(self.loader_train)) 60 | self.error_last = self.loss.log[-1, -1] 61 | self.step() 62 | 63 | def test(self): 64 | epoch = self.scheduler.last_epoch 65 | self.ckp.write_log('\nEvaluation:') 66 | self.ckp.add_log(torch.zeros(1, 1)) 67 | self.model.eval() 68 | 69 | timer_test = utility.timer() 70 | with torch.no_grad(): 71 | scale = max(self.scale) 72 | for si, s in enumerate([scale]): 73 | eval_psnr = 0 74 | tqdm_test = tqdm(self.loader_test, ncols=80) 75 | for _, (lr, hr, filename) in enumerate(tqdm_test): 76 | filename = filename[0] 77 | no_eval = (hr.nelement() == 1) 78 | if not no_eval: 79 | lr, hr = self.prepare(lr, hr) 80 | else: 81 | lr, = self.prepare(lr) 82 | 83 | sr = self.model(lr[0]) 84 | if isinstance(sr, list): sr = sr[-1] 85 | 86 | sr = utility.quantize(sr, self.opt.rgb_range) 87 | 88 | if not no_eval: 89 | eval_psnr += utility.calc_psnr( 90 | sr, hr, s, self.opt.rgb_range, 91 | benchmark=self.loader_test.dataset.benchmark 92 | ) 93 | 94 | # save test results 95 | if self.opt.save_results: 96 | self.ckp.save_results_nopostfix(filename, sr, s) 97 | 98 | self.ckp.log[-1, si] = eval_psnr / len(self.loader_test) 99 | best = self.ckp.log.max(0) 100 | self.ckp.write_log( 101 | '[{} x{}]\tPSNR: {:.2f} (Best: {:.2f} @epoch {})'.format( 102 | self.opt.data_test, s, 103 | self.ckp.log[-1, si], 104 | best[0][si], 105 | best[1][si] + 1 106 | ) 107 | ) 108 | 109 | self.ckp.write_log( 110 | 'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True 111 | ) 112 | if not self.opt.test_only: 113 | self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch)) 114 | 115 | def step(self): 116 | self.scheduler.step() 117 | 118 | def prepare(self, *args): 119 | device = torch.device('cpu' if self.opt.cpu else 'cuda') 120 | 121 | if len(args)>1: 122 | return [a.to(device) for a in args[0]], args[-1].to(device) 123 | return [a.to(device) for a in args[0]], 124 | 125 | def terminate(self): 126 | if self.opt.test_only: 127 | self.test() 128 | return True 129 | else: 130 | epoch = self.scheduler.last_epoch 131 | return epoch >= self.opt.epochs 132 | -------------------------------------------------------------------------------- /utility.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.optim as optim 7 | import torch.optim.lr_scheduler as lrs 8 | 9 | 10 | def set_seed(seed): 11 | random.seed(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | if torch.cuda.device_count() == 1: 15 | torch.cuda.manual_seed(seed) 16 | else: 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | 20 | class timer(): 21 | def __init__(self): 22 | self.acc = 0 23 | self.tic() 24 | 25 | def tic(self): 26 | self.t0 = time.time() 27 | 28 | def toc(self): 29 | return time.time() - self.t0 30 | 31 | def hold(self): 32 | self.acc += self.toc() 33 | 34 | def release(self): 35 | ret = self.acc 36 | self.acc = 0 37 | 38 | return ret 39 | 40 | def reset(self): 41 | self.acc = 0 42 | 43 | 44 | def quantize(img, rgb_range): 45 | pixel_range = 255 / rgb_range 46 | 
return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) 47 | 48 | 49 | def calc_psnr(sr, hr, scale, rgb_range, benchmark=False): 50 | if sr.size(-2) > hr.size(-2) or sr.size(-1) > hr.size(-1): 51 | print("the dimention of sr image is not equal to hr's! ") 52 | sr = sr[:,:,:hr.size(-2),:hr.size(-1)] 53 | diff = (sr - hr).data.div(rgb_range) 54 | 55 | if benchmark: 56 | shave = scale 57 | if diff.size(1) > 1: 58 | convert = diff.new(1, 3, 1, 1) 59 | convert[0, 0, 0, 0] = 65.738 60 | convert[0, 1, 0, 0] = 129.057 61 | convert[0, 2, 0, 0] = 25.064 62 | diff.mul_(convert).div_(256) 63 | diff = diff.sum(dim=1, keepdim=True) 64 | else: 65 | shave = scale + 6 66 | 67 | valid = diff[:, :, shave:-shave, shave:-shave] 68 | mse = valid.pow(2).mean() 69 | 70 | return -10 * math.log10(mse) 71 | 72 | 73 | def make_optimizer(opt, my_model): 74 | trainable = filter(lambda x: x.requires_grad, my_model.parameters()) 75 | optimizer_function = optim.Adam 76 | kwargs = { 77 | 'betas': (opt.beta1, opt.beta2), 78 | 'eps': opt.epsilon 79 | } 80 | kwargs['lr'] = opt.lr 81 | kwargs['weight_decay'] = opt.weight_decay 82 | 83 | return optimizer_function(trainable, **kwargs) 84 | 85 | 86 | def make_scheduler(opt, my_optimizer): 87 | scheduler = lrs.CosineAnnealingLR( 88 | my_optimizer, 89 | float(opt.epochs), 90 | eta_min=opt.eta_min 91 | ) 92 | 93 | return scheduler 94 | 95 | 96 | 97 | def init_model(args): 98 | ''' 99 | if args.model.find('MSFIN3') >= 0: 100 | if args.scale == 4: 101 | args.num_steps = 1 102 | args.n_feats = 24 103 | args.patch_size = 192 104 | elif args.scale == 8: 105 | args.n_blocks = 30 106 | args.n_feats = 8 107 | else: 108 | print('Use defaults n_blocks and n_feats.') 109 | # args.dual = True 110 | ''' 111 | if args.model.find('TRANSMY5') >= 0: 112 | if args.scale == 4: 113 | args.num_steps = 1 114 | args.n_feats = 32 115 | args.patch_size = 192 116 | elif args.scale == 8: 117 | args.n_blocks = 30 118 | args.n_feats = 8 119 | else: 120 | print('Use defaults n_blocks and n_feats.') 121 | 122 | --------------------------------------------------------------------------------
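Below is a minimal sketch of how the evaluation helpers in utility.py can be exercised on dummy tensors. It assumes the snippet is run from the repository root so that `import utility` resolves; the 4x scale, patch size, and noise level are illustrative values only, not part of the training or testing pipeline.
```
# Sanity check for quantize() and calc_psnr() from utility.py on synthetic data.
import torch
import utility

torch.manual_seed(0)

scale, rgb_range = 4, 255
hr = torch.rand(1, 3, 128, 128) * rgb_range            # fake ground-truth patch in [0, 255]
sr = (hr + torch.randn_like(hr)).clamp(0, rgb_range)   # fake SR output with small noise

sr = utility.quantize(sr, rgb_range)                   # snap to valid pixel values
psnr = utility.calc_psnr(sr, hr, scale, rgb_range, benchmark=True)
print(f'PSNR: {psnr:.2f} dB')
```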