├── .gitignore
├── README.md
├── images
│   ├── comp.png
│   └── repcam++.png
└── src
    ├── __init__.py
    ├── __pycache__
    │   ├── option.cpython-39.pyc
    │   ├── template.cpython-39.pyc
    │   ├── trainer_cafm.cpython-39.pyc
    │   └── utility.cpython-39.pyc
    ├── data
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-39.pyc
    │   │   ├── common.cpython-39.pyc
    │   │   ├── div2k.cpython-39.pyc
    │   │   └── srdata.cpython-39.pyc
    │   ├── benchmark.py
    │   ├── common.py
    │   ├── demo.py
    │   ├── div2k.py
    │   ├── div2kjpeg.py
    │   ├── srdata.py
    │   └── video.py
    ├── dataloader.py
    ├── loss
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   └── __init__.cpython-39.pyc
    │   ├── adversarial.py
    │   ├── discriminator.py
    │   └── vgg.py
    ├── main.py
    ├── model
    │   ├── .DS_Store
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-39.pyc
    │   │   ├── common.cpython-39.pyc
    │   │   ├── common_m0.cpython-39.pyc
    │   │   ├── edsr.cpython-39.pyc
    │   │   ├── edsr_m0.cpython-39.pyc
    │   │   ├── espcn.cpython-39.pyc
    │   │   └── espcn_m0.cpython-39.pyc
    │   ├── common.py
    │   ├── common_m0.py
    │   ├── edsr.py
    │   ├── edsr_m0 copy.py
    │   ├── edsr_m0.py
    │   ├── espcn.py
    │   ├── espcn_chunked.py
    │   ├── espcn_lf.py
    │   ├── espcn_m0.py
    │   ├── espcn_mdf.py
    │   ├── espcn_mdf_m0.py
    │   └── fix_patch_prompt.py
    ├── option.py
    ├── reparameter_edsr.py
    ├── reparameter_edsr_legacy.py
    ├── reparameter_espcn.py
    ├── template.py
    ├── train_bash_demo
    │   ├── demo_train_M1-n.sh
    │   └── demo_train_S1-n.sh
    ├── trainer_cafm.py
    ├── utility.py
    └── videotester.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.png
2 | *.txt
3 | **experiment**/
4 | **/vsd**/
5 | **/**__pycache__**/
6 | *.pyc
7 | lol*/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RepCaM++: Exploring Transparent Visual Prompt with Inference-time Re-parameterization for Neural Video Delivery
2 | ![Python 3.8](https://img.shields.io/badge/Python-3.8-blue)
3 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://www.computer.org/csdl/journal/tm/5555/01/10949820/25DZuw4IHTy)
4 |
5 |
6 |
7 | ## News!
8 | The extended version of RepCaM has been accepted by IEEE Transactions on Mobile Computing!
9 |
10 | * __For training RepCaM++ model__:
11 | ```
12 | cd src
13 | CUDA_VISIBLE_DEVICES=3 python main.py \
14 | --model EDSR \
15 | --scale 2 \
16 | --n_resblocks 16 \
17 | --patch_size 48 \
18 | --save edsr_sport_lr5e-5_chunked --patch_lr 5e-5 --reset --data_train DIV2K --data_test DIV2K --data_range 1-450/451-495 \
19 | --dir_data /dir/to/data \
20 | --batch_size 64 \
21 | --epoch 600 \
22 | --decay 300 \
23 | --segnum 9 --is45s --use_cafm --std 0.1
24 | ```
25 |
26 | * __For testing RepCaM++ model__:
27 | ```
28 | python reparameter_edsr.py --model_folder experiment/edsr_dance_lr5e-5_chunked/model --n_res_blocks 16
29 | CUDA_VISIBLE_DEVICES=1 python main.py --data_test DIV2K --scale 2 --model EDSR_M0 \
30 | --test_only --save_gt --save_results --save edsr_dance_lr5e-5_chunked \
31 | --pre_train /home/rongyu/2opt/experiment/edsr_dance_lr5e-5_chunked/model/model_rep.pt \
32 | --data_range 1-450 --is45s --dir_data /home/rongyu/dataset/vsd4k/dance_45s_1 \
33 | --segnum 9 --use_cafm \
34 | --patch_load /home/rongyu/2opt/experiment/edsr_dance_lr5e-5_chunked/model/patches \
35 | --n_resblocks 16
36 | ```
37 |
38 | # RepCaM: Re-parameterization Content-aware Modulation for Neural Video Delivery
39 | ![Python 3.8](https://img.shields.io/badge/Python-3.8-blue)
40 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://dl.acm.org/doi/pdf/10.1145/3592473.3592567)
41 |
42 |
43 |
44 | ## Introduction of dataset VSD4K and VSD4K-2023
45 | Our dataset VSD4K includes 6 popular categories: game, sport, dance, vlog, interview and city. Each category consists of various video lengths, including 15s, 30s, 45s, etc. For a specific category and video length, there are 3 scaling factors: x2, x3 and x4. Each folder contains the HR images and their corresponding LR images. Frames 1-n are training images and frames n-(n+n/10) are test images (we select 1 test image for every 10 training images). The VSD4K dataset can be obtained from [https://pan.baidu.com/s/14pcsC7taB4VAa3jvyw1kog] (password:u1qq) and google drive [https://drive.google.com/drive/folders/17fyX-bFc0IUp6LTIfTYU8R5_Ot79WKXC?usp=sharing]. The VSD4K-2023 dataset can be obtained from [https://pan.baidu.com/s/1mNJuKnCfYzd1q6PsyO1b8Q?pwd=d4a0] (password:d4a0)
46 |
47 | ```
48 | e.g.: game 15s
49 | dataroot_gt: VSD4K/game/game_15s_1/DIV2K_train_HR/00001.png
50 | dataroot_lqx2: VSD4K/game/game_15s_1/DIV2K_train_LR_bicubic/X2/00001_x2.png
51 | dataroot_lqx3: VSD4K/game/game_15s_1/DIV2K_train_LR_bicubic/X3/00001_x3.png
52 | dataroot_lqx4: VSD4K/game/game_15s_1/DIV2K_train_LR_bicubic/X4/00001_x4.png
53 | ```
54 |
55 | ## Dependencies
56 | * Python >= 3.6
57 | * Torch >= 1.0.0
58 | * opencv-python
59 | * numpy
60 | * skimage
61 | * imageio
62 | * matplotlib
63 | ## Quickstart
64 | M0 denotes the model without the RepCaM module, trained on the whole dataset. S{1-n} denotes n models, each trained on one of the n chunks of a video. M{1-n} denotes one model along with n RepCaM modules, trained on the whole dataset. __M{1-n} is our proposed method__.
65 |
66 |
67 | ### How to set data_range
68 | n is the total number of training frames in a video. We select one test image for every 10 training images. Thus, in VSD4K, 1-n is the training dataset and n-(n+n/10) is the test dataset. Generally, we set 5s as the length of one chunk, so a 15s video consists of 3 chunks, a 30s video of 6 chunks, etc.; the resulting ranges are listed in the table below.
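As a convenience (this helper is not part of the repo), the ranges in the table below can also be generated programmatically. The sketch assumes VSD4K's layout of 150 training frames plus 15 test frames per 5s chunk:

```python
# Hypothetical helper, for illustration only: reproduces the data_range
# strings of the table below (150 train + 15 test frames per 5s chunk).
def data_ranges(n_chunks, train_per_chunk=150, test_per_chunk=15):
    n_train = n_chunks * train_per_chunk          # e.g. 1350 for 45s
    m0 = f"1-{n_train}/{n_train + 1}-{n_train + n_chunks * test_per_chunk}"
    s = []
    for i in range(n_chunks):                     # S1 ... Sn
        tr_lo = i * train_per_chunk + 1
        tr_hi = (i + 1) * train_per_chunk
        te_lo = n_train + i * test_per_chunk + 1
        te_hi = n_train + (i + 1) * test_per_chunk
        s.append(f"{tr_lo}-{tr_hi}/{te_lo}-{te_hi}")
    return m0, s

print(data_ranges(3))   # 15s: ('1-450/451-495', ['1-150/451-465', ...])
```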
| Video length(train images + test images) | chunks | M0/M{1-n} | S1 | S2 | S3 | S4 | S5 | S6 | S7 | S8 | S9 |
| :---: | :---: | :---: | :----: | :---: | :---: | :---: | :---: | :----: | :---: | :---: | :---: |
| 15s(450+45) | 3 | 1-450/451-495 | 1-150/451-465 | 151-300/466-480 | 301-450/481-495 | - | - | - | - | - | - |
| 30s(900+95) | 6 | 1-900/901-990 | 1-150/901-915 | 151-300/916-930 | 301-450/931-945 | 451-600/946-960 | 601-750/961-975 | 751-900/976-990 | - | - | - |
| 45s(1350+135) | 9 | 1-1350/1351-1485 | 1-150/1351-1365 | 151-300/1366-1380 | 301-450/1381-1395 | 451-600/1396-1410 | 601-750/1411-1425 | 751-900/1426-1440 | 901-1050/1441-1455 | 1051-1200/1456-1470 | 1201-1350/1471-1485 |
74 |
75 | ### Train(version with TVP)
76 |
77 | See `example_TVP.txt`.
78 |
79 | ### Train(version without TVP)
80 | For simplicity, we only demonstrate how to train 'game_15s' with our method.
81 |
82 | * __For M{1-n} model__:
83 | ```
84 | CUDA_VISIBLE_DEVICES=3 python main.py --model {EDSR/ESPCN/VDSRR/SRCNN/RCAN} --scale {scale factor} --patch_size {patch size} --save {name of the trained model} --reset --data_train DIV2K --data_test DIV2K --data_range {train_range}/{test_range} --dir_data {path of data} --batch_size {batch size} --epoch {epoch} --decay {decay} --segnum {number of chunks} --{is15s/is30s/is45s}
85 | ```
86 | ```
87 | e.g.
88 | CUDA_VISIBLE_DEVICES=3 python main.py --model EDSR --scale 2 --patch_size 48 --save trainm1_n --reset --data_train DIV2K --data_test DIV2K --data_range 1-450/451-495 --dir_data /home/datasets/VSD4K/game/game_15s_1 --batch_size 64 --epoch 500 --decay 300 --segnum 3 --is15s
89 | ```
90 |
91 | You can apply our method to your own images. Place your HR images under YOURS/DIV2K_train_HR/, with names starting from 00001.png.
92 | Place your corresponding LR images under YOURS/DIV2K_train_LR_bicubic/X2, with names starting from 00001_x2.png.
93 | ```
94 | e.g.:
95 | dataroot_gt: YOURS/DIV2K_train_HR/00001.png
96 | dataroot_lqx2: YOURS/DIV2K_train_LR_bicubic/X2/00001_x2.png
97 | dataroot_lqx3: YOURS/DIV2K_train_LR_bicubic/X3/00001_x3.png
98 | dataroot_lqx4: YOURS/DIV2K_train_LR_bicubic/X4/00001_x4.png
99 | ```
100 | * The running command is as follows:
101 | ```
102 | CUDA_VISIBLE_DEVICES=3 python main.py --model {EDSR/ESPCN/VDSRR/SRCNN/RCAN} --scale {scale factor} --patch_size {patch size} --save {name of the trained model} --reset --data_train DIV2K --data_test DIV2K --data_range {train_range}/{test_range} --dir_data {path of data} --batch_size {batch size} --epoch {epoch} --decay {decay} --segnum {number of chunks} --{is15s/is30s/is45s}
103 | ```
104 |
105 | * For example:
106 | ```
107 | e.g.
108 | CUDA_VISIBLE_DEVICES=3 python main.py --model EDSR --scale 2 --patch_size 48 --save trainm1_n --reset --data_train DIV2K --data_test DIV2K --data_range 1-450/451-495 --dir_data /home/datasets/VSD4K/game/game_15s_1 --batch_size 64 --epoch 500 --decay 300 --segnum 3 --is15s
109 | ```
110 |
111 | ### Reparameterization
112 | ```
113 | e.g.
114 | CUDA_VISIBLE_DEVICES=3 python reparameter_{edsr/espcn}.py
115 | ```
116 |
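What these scripts do, conceptually: each trained ContentAwareFM (see `src/model/common.py`) is a depthwise 1x1 convolution `t` with a learned scalar `gamma`, applied as `t(y) * gamma + y` right after a regular convolution. Since both operations are linear, the CaFM of a given segment can be folded into the preceding convolution offline, so the delivered model is a plain network with no runtime overhead. A minimal sketch of the algebra for a single conv/CaFM pair (`fold_cafm` is a hypothetical helper, not the repo's actual script):

```python
import torch

# Illustrative sketch only, not reparameter_edsr.py itself: fold one
# ContentAwareFM (depthwise 1x1 conv `transformer` with scalar `gamma`,
# applied as transformer(y) * gamma + y) into the conv that produced y.
@torch.no_grad()
def fold_cafm(conv, cafm):
    t_w = cafm.transformer.weight.view(-1)      # [C] depthwise 1x1 weights
    t_b = cafm.transformer.bias                 # [C]
    g = cafm.gamma                              # learned scalar
    scale = 1 + g * t_w                         # per-output-channel scale
    conv.weight.mul_(scale.view(-1, 1, 1, 1))   # w'[c] = w[c] * (1 + g * t_w[c])
    conv.bias.mul_(scale).add_(g * t_b)         # b'[c] = b[c] * (1 + g * t_w[c]) + g * t_b[c]
    return conv                                 # plain conv, CaFM absorbed
```

The actual `reparameter_edsr.py`/`reparameter_espcn.py` presumably repeat this for every block and every segment's CaFM, writing out the merged `model_rep.pt` that the EDSR_M0/ESPCN_M0 architectures load at test time.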
117 | ### Test
118 | For simplicity, we only demonstrate how to run the 'game' category of 15s. All pretrained models (15s, 30s, 45s) of the game category can be found at this link [https://pan.baidu.com/s/1P18FULL7CIK1FAa2xW56AA] (password:bjv1) and google drive link [https://drive.google.com/drive/folders/1_N64A75iwgbweDBk7dUUDX0SJffnK5-l?usp=sharing].
119 |
120 | * __For M{1-n} model__:
121 | ```
122 | CUDA_VISIBLE_DEVICES=3 python main.py --data_test DIV2K --scale {scale factor} --model {EDSR/ESPCN/VDSRR/SRCNN/RCAN} --test_only --pre_train {path to pretrained model} --data_range {train_range} --{is15s/is30s/is45s} --dir_data {path of data} --segnum 3
123 | ```
124 | ```
125 | e.g.:
126 | CUDA_VISIBLE_DEVICES=3 python main.py --data_test DIV2K --scale 4 --model EDSR_M0 --test_only --pre_train /home/CaFM-pytorch/experiment/edsr_x2_p48_game_15s_1_seg1-3_batch64_k1_g64/model/model_rep.pt --data_range 1-150 --is15s --dir_data /home/datasets/VSD4K/game/game_15s_1 --segnum 3
127 | ```
128 | ## Citation
129 | Please cite our work if you find it useful.
130 | ```bibtex
131 | @article{zhang2025repcam++,
132 |   title={RepCaM++: Exploring Transparent Visual Prompt With Inference-Time Re-Parameterization for Neural Video Delivery},
133 |   author={Zhang, Rongyu and Duan, Xize and Liu, Jiaming and Du, Li and Du, Yuan and Wang, Dan and Zhang, Shanghang and Wang, Fangxin},
134 |   journal={IEEE Transactions on Mobile Computing},
135 |   year={2025},
136 |   publisher={IEEE}
137 | }
138 | ```
139 | ```bibtex
140 | @inproceedings{zhang2023repcam,
141 |   title={RepCaM: Re-parameterization Content-aware Modulation for Neural Video Delivery},
142 |   author={Zhang, Rongyu and Du, Lixuan and Liu, Jiaming and Song, Congcong and Wang, Fangxin and Li, Xiaoqi and Lu, Ming and Guo, Yandong and Zhang, Shanghang},
143 |   booktitle={Proceedings of the 33rd Workshop on Network and Operating System Support for Digital Audio and Video},
144 |   pages={1--7},
145 |   year={2023}
146 | }
147 | ```
148 |
149 | ## Acknowledgment
150 |
151 | AdaFM proposed a closely related method for continual modulation of restoration levels. While they aimed to handle arbitrary restoration levels between a start and an end level, our goal is to compress the models of different chunks for video delivery. The reader is encouraged to review their work for more details. Please also consider citing AdaFM if you use the code.
[https://github.com/hejingwenhejingwen/AdaFM] 152 | -------------------------------------------------------------------------------- /images/comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/images/comp.png -------------------------------------------------------------------------------- /images/repcam++.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/images/repcam++.png -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/__init__.py -------------------------------------------------------------------------------- /src/__pycache__/option.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/__pycache__/option.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/template.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/__pycache__/template.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/trainer_cafm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/__pycache__/trainer_cafm.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/utility.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/__pycache__/utility.cpython-39.pyc -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | #from dataloader import MSDataLoader 3 | from torch.utils.data import dataloader 4 | from torch.utils.data import ConcatDataset 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | 9 | # This is a simple wrapper function for ConcatDataset 10 | class MyConcatDataset(ConcatDataset): 11 | def __init__(self, datasets): 12 | super(MyConcatDataset, self).__init__(datasets) 13 | self.train = datasets[0].train 14 | print("=====================") 15 | 16 | def set_scale(self, idx_scale): 17 | for d in self.datasets: 18 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale) 19 | 20 | class Data: 21 | def __init__(self, args): 22 | self.loader_train = None 23 | if not args.test_only: 24 | datasets = [] 25 | for d in args.data_train: 26 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 27 | m = import_module('data.' 
+ module_name.lower())  # please, writing one more explicit if would not kill you
28 |                 datasets.append(getattr(m, module_name)(args, name=d))
29 |             if not args.chunked:
30 |                 self.loader_train = dataloader.DataLoader(
31 |                     MyConcatDataset(datasets),
32 |                     batch_size=args.batch_size,
33 |                     shuffle=False,
34 |                     pin_memory=not args.cpu,
35 |                     num_workers=args.n_threads,
36 |                 )
37 |             else:
38 |                 self.loader_train = dataloader.DataLoader(
39 |                     MyConcatDataset(datasets),
40 |                     batch_size=args.chunk_size,
41 |                     shuffle=False,
42 |                     pin_memory=not args.cpu,
43 |                     num_workers=args.n_threads,
44 |                 )
45 |             print(len(self.loader_train))
46 |
47 |
48 |
49 |         self.loader_test = []
50 |         for d in args.data_test:
51 |             if d in ['Set5', 'Set14', 'B100', 'Urban100']:
52 |                 m = import_module('data.benchmark')
53 |                 testset = getattr(m, 'Benchmark')(args, train=False, name=d)
54 |             else:
55 |                 module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
56 |                 m = import_module('data.' + module_name.lower())
57 |                 testset = getattr(m, module_name)(args, train=False, name=d)
58 |
59 |             self.loader_test.append(
60 |                 dataloader.DataLoader(
61 |                     testset,
62 |                     batch_size=1,
63 |                     shuffle=False,
64 |                     pin_memory=not args.cpu,
65 |                     num_workers=args.n_threads,
66 |                 )
67 |             )
68 |
--------------------------------------------------------------------------------
/src/data/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/data/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/src/data/__pycache__/common.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/data/__pycache__/common.cpython-39.pyc
--------------------------------------------------------------------------------
/src/data/__pycache__/div2k.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/data/__pycache__/div2k.cpython-39.pyc
--------------------------------------------------------------------------------
/src/data/__pycache__/srdata.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/data/__pycache__/srdata.cpython-39.pyc
--------------------------------------------------------------------------------
/src/data/benchmark.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data import common
4 | from data import srdata
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.utils.data as data
10 |
11 | class Benchmark(srdata.SRData):
12 |     def __init__(self, args, name='', train=True, benchmark=True):
13 |         super(Benchmark, self).__init__(
14 |             args, name=name, train=train, benchmark=True
15 |         )
16 |
17 |     def _set_filesystem(self, dir_data):
18 |         self.apath = os.path.join(dir_data, 'benchmark', self.name)
19 |         self.dir_hr = os.path.join(self.apath, 'HR')
20 |         if self.input_large:
21 |             self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
22 |         else:
23 |             self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
24 |         self.ext = ('', '.png')
25 |
26 |
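A quick usage sketch of the wiring above (hypothetical, not a file in the repo): `data.Data` routes Set5/Set14/B100/Urban100 to `Benchmark` and everything else to the module named after the dataset. The real `option.py` defines many more fields than this minimal namespace assumes:

```python
from types import SimpleNamespace
import data

# Hypothetical minimal args for a benchmark-only test run; dir_data is
# assumed to contain benchmark/Set5/{HR, LR_bicubic/X2}.
args = SimpleNamespace(
    test_only=True, data_test=['Set5'], dir_data='/path/to/datasets',
    model='EDSR', scale=[2], ext='img', n_colors=3, rgb_range=255,
    cpu=True, n_threads=0,
)
loaders = data.Data(args)                           # 'Set5' -> data.benchmark.Benchmark
lr, hr, name = next(iter(loaders.loader_test[0]))   # one (LR, HR, filename) triple
```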
-------------------------------------------------------------------------------- /src/data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import skimage.color as sc 5 | 6 | import torch 7 | 8 | def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False, psnr_index=None, data_partion=0): 9 | 10 | if not input_large: 11 | p = scale if multi else 1 12 | tp = p * patch_size 13 | ip = tp // scale 14 | else: 15 | tp = patch_size 16 | ip = patch_size 17 | 18 | if psnr_index is not None: 19 | n_patch = int(psnr_index.shape[0]*data_partion) 20 | index = random.randrange(0, n_patch + 1) 21 | ix = int(psnr_index[index][0]) 22 | iy = int(psnr_index[index][1]) 23 | else: 24 | ih, iw = args[0].shape[:2] 25 | ix = random.randrange(0, iw - ip + 1) 26 | iy = random.randrange(0, ih - ip + 1) 27 | 28 | if not input_large: 29 | tx, ty = scale * ix, scale * iy 30 | else: 31 | tx, ty = ix, iy 32 | 33 | ret = [ 34 | args[0][iy:iy + ip, ix:ix + ip, :], 35 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] 36 | ] 37 | 38 | return ret 39 | 40 | def set_channel(*args, n_channels=3): 41 | def _set_channel(img): 42 | if img.ndim == 2: 43 | img = np.expand_dims(img, axis=2) 44 | 45 | c = img.shape[2] 46 | if n_channels == 1 and c == 3: 47 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 48 | elif n_channels == 3 and c == 1: 49 | img = np.concatenate([img] * n_channels, 2) 50 | 51 | return img 52 | 53 | return [_set_channel(a) for a in args] 54 | 55 | def np2Tensor(*args, rgb_range=255): 56 | def _np2Tensor(img): 57 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 58 | tensor = torch.from_numpy(np_transpose).float() 59 | tensor.mul_(rgb_range / 255) 60 | 61 | return tensor 62 | 63 | return [_np2Tensor(a) for a in args] 64 | 65 | def augment(*args, hflip=True, rot=True): 66 | hflip = hflip and random.random() < 0.5 67 | vflip = rot and random.random() < 0.5 68 | rot90 = rot and random.random() < 0.5 69 | 70 | def _augment(img): 71 | if hflip: img = img[:, ::-1, :] 72 | if vflip: img = img[::-1, :, :] 73 | if rot90: img = img.transpose(1, 0, 2) 74 | 75 | return img 76 | 77 | return [_augment(a) for a in args] 78 | 79 | -------------------------------------------------------------------------------- /src/data/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import numpy as np 6 | import imageio 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Demo(data.Dataset): 12 | def __init__(self, args, name='Demo', train=False, benchmark=False): 13 | self.args = args 14 | self.name = name 15 | self.scale = args.scale 16 | self.idx_scale = 0 17 | self.train = False 18 | self.benchmark = benchmark 19 | 20 | self.filelist = [] 21 | for f in os.listdir(args.dir_demo): 22 | if f.find('.png') >= 0 or f.find('.jp') >= 0: 23 | self.filelist.append(os.path.join(args.dir_demo, f)) 24 | self.filelist.sort() 25 | 26 | def __getitem__(self, idx): 27 | filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0] 28 | lr = imageio.imread(self.filelist[idx]) 29 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 30 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 31 | 32 | return lr_t, -1, filename 33 | 34 | def __len__(self): 35 | return len(self.filelist) 36 | 37 | def set_scale(self, idx_scale): 38 | self.idx_scale = idx_scale 39 | 40 | 
-------------------------------------------------------------------------------- /src/data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DIV2K(srdata.SRData): 5 | def __init__(self, args, name='ITW2K', train=True, benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')] 7 | if train: 8 | data_range = data_range[0] 9 | else: 10 | if args.test_only and len(data_range) == 1: 11 | data_range = data_range[0] 12 | else: 13 | data_range = data_range[1] 14 | 15 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 16 | super(DIV2K, self).__init__( 17 | args, name=name, train=train, benchmark=benchmark 18 | ) 19 | self.dir_data = args.dir_data 20 | 21 | def _scan(self): 22 | names_hr, names_lr = super(DIV2K, self)._scan() 23 | names_hr = names_hr[self.begin - 1:self.end] 24 | names_lr = [n[self.begin - 1:self.end] for n in names_lr] 25 | 26 | return names_hr, names_lr 27 | 28 | def _set_filesystem(self, dir_data): 29 | super(DIV2K, self)._set_filesystem(dir_data) 30 | self.apath = dir_data 31 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 32 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') 33 | 34 | if self.input_large: self.dir_lr += 'L' 35 | 36 | -------------------------------------------------------------------------------- /src/data/div2kjpeg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | from data import div2k 4 | 5 | class DIV2KJPEG(div2k.DIV2K): 6 | def __init__(self, args, name='', train=True, benchmark=False): 7 | self.q_factor = int(name.replace('DIV2K-Q', '')) 8 | super(DIV2KJPEG, self).__init__( 9 | args, name=name, train=train, benchmark=benchmark 10 | ) 11 | 12 | def _set_filesystem(self, dir_data): 13 | self.apath = os.path.join(dir_data, 'DIV2K') 14 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 15 | self.dir_lr = os.path.join( 16 | self.apath, 'DIV2K_Q{}'.format(self.q_factor) 17 | ) 18 | if self.input_large: self.dir_lr += 'L' 19 | self.ext = ('.png', '.jpg') 20 | 21 | -------------------------------------------------------------------------------- /src/data/srdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import pickle 5 | 6 | from data import common 7 | 8 | import numpy as np 9 | import imageio 10 | import torch 11 | import torch.utils.data as data 12 | 13 | class SRData(data.Dataset): 14 | def __init__(self, args, name='', train=True, benchmark=False): 15 | self.args = args 16 | self.name = name 17 | self.train = train 18 | self.split = 'train' if train else 'test' 19 | self.do_eval = True 20 | self.benchmark = benchmark 21 | self.input_large = (args.model == 'VDSR') 22 | self.scale = args.scale 23 | self.idx_scale = 0 24 | 25 | self._set_filesystem(args.dir_data) 26 | if args.ext.find('img') < 0: 27 | path_bin = os.path.join(self.apath, 'bin') 28 | os.makedirs(path_bin, exist_ok=True) 29 | 30 | list_hr, list_lr = self._scan() 31 | if args.ext.find('img') >= 0 or benchmark: 32 | self.images_hr, self.images_lr = list_hr, list_lr 33 | elif args.ext.find('sep') >= 0: 34 | os.makedirs( 35 | self.dir_hr.replace(self.apath, path_bin), 36 | exist_ok=True 37 | ) 38 | for s in self.scale: 39 | os.makedirs( 40 | os.path.join( 41 | self.dir_lr.replace(self.apath, path_bin), 42 | 'X{}'.format(s) 43 | ), 44 | exist_ok=True 45 | ) 46 | 
47 |             self.images_hr, self.images_lr = [], [[] for _ in self.scale]
48 |             for h in list_hr:
49 |                 b = h.replace(self.apath, path_bin)
50 |                 b = b.replace(self.ext[0], '.pt')
51 |                 self.images_hr.append(b)
52 |                 self._check_and_load(args.ext, h, b, verbose=True)
53 |             for i, ll in enumerate(list_lr):
54 |                 for l in ll:
55 |                     b = l.replace(self.apath, path_bin)
56 |                     b = b.replace(self.ext[1], '.pt')
57 |                     self.images_lr[i].append(b)
58 |                     self._check_and_load(args.ext, l, b, verbose=True)
59 |         if train:
60 |             n_patches = args.batch_size * args.test_every
61 |             n_images = len(args.data_train) * len(self.images_hr)
62 |             if n_images == 0:
63 |                 self.repeat = 0
64 |             else:
65 |                 self.repeat = max(n_patches // n_images, 1)
66 |
67 |     # The functions below are used to prepare images
68 |     def _scan(self):
69 |         names_hr = sorted(
70 |             glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
71 |         )
72 |         names_lr = [[] for _ in self.scale]
73 |         for f in names_hr:
74 |             filename, _ = os.path.splitext(os.path.basename(f))
75 |             for si, s in enumerate(self.scale):
76 |                 names_lr[si].append(os.path.join(
77 |                     self.dir_lr, 'X{}/{}x{}{}'.format(
78 |                         s, filename, s, self.ext[1]
79 |                     )
80 |                 ))
81 |
82 |         return names_hr, names_lr
83 |
84 |     def _set_filesystem(self, dir_data):
85 |         self.apath = os.path.join(dir_data, self.name)
86 |         self.dir_hr = os.path.join(self.apath, 'HR')
87 |         self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
88 |         if self.input_large: self.dir_lr += 'L'
89 |         self.ext = ('.png', '.png')
90 |
91 |     def _check_and_load(self, ext, img, f, verbose=True):
92 |         if not os.path.isfile(f) or ext.find('reset') >= 0:
93 |             if verbose:
94 |                 print('\rMaking a binary: {}'.format(f), end='')
95 |                 print()
96 |             with open(f, 'wb') as _f:
97 |                 pickle.dump(imageio.imread(img), _f)
98 |
99 |     def __getitem__(self, idx):
100 |         lr, hr, filename = self._load_file(idx)
101 |         pair = self.get_patch(lr, hr)
102 |         pair = common.set_channel(*pair, n_channels=self.args.n_colors)
103 |         pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
104 |         return pair_t[0], pair_t[1], filename
105 |
106 |     def __len__(self):
107 |         if self.train:
108 |             return len(self.images_hr) * self.repeat
109 |         else:
110 |             return len(self.images_hr)
111 |
112 |     def _get_index(self, idx):
113 |         if self.train:
114 |             return idx % len(self.images_hr)
115 |         else:
116 |             return idx
117 |
118 |     def _load_file(self, idx):
119 |         idx = self._get_index(idx)
120 |         f_hr = self.images_hr[idx]
121 |         f_lr = self.images_lr[self.idx_scale][idx]
122 |
123 |         filename, _ = os.path.splitext(os.path.basename(f_hr))
124 |         if self.args.ext == 'img' or self.benchmark:
125 |             hr = imageio.imread(f_hr)
126 |             lr = imageio.imread(f_lr)
127 |         elif self.args.ext.find('sep') >= 0:
128 |             with open(f_hr, 'rb') as _f:
129 |                 hr = pickle.load(_f)
130 |             with open(f_lr, 'rb') as _f:
131 |                 lr = pickle.load(_f)
132 |
133 |         return lr, hr, filename
134 |
135 |     def get_patch(self, lr, hr):
136 |         scale = self.scale[self.idx_scale]
137 |         if self.train:
138 |             lr, hr = common.get_patch(
139 |                 lr, hr,
140 |                 patch_size=self.args.patch_size,
141 |                 scale=scale,
142 |                 multi=(len(self.scale) > 1),
143 |                 input_large=self.input_large
144 |             )
145 |             if not self.args.no_augment: lr, hr = common.augment(lr, hr)
146 |         else:
147 |             ih, iw = lr.shape[:2]
148 |             hr = hr[0:ih * scale, 0:iw * scale]
149 |
150 |         return lr, hr
151 |
152 |     def set_scale(self, idx_scale):
153 |         if not self.input_large:
154 |             self.idx_scale = idx_scale
155 |         else:
156 |             self.idx_scale = random.randint(0, len(self.scale) - 1)
157 |
158 |
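For context (an illustrative sketch, not part of the repo): with `--ext sep`, SRData caches each decoded image as a pickled numpy array under `<dataset>/bin`, so later epochs skip PNG decoding. The round-trip is just:

```python
import pickle
import imageio

# Write once: decode the PNG and cache the raw array (what _check_and_load does).
with open('00001.pt', 'wb') as f:   # '.pt' here is a pickle, not a torch checkpoint
    pickle.dump(imageio.imread('00001.png'), f)

# Read on every subsequent access (what _load_file does for ext='sep').
with open('00001.pt', 'rb') as f:
    hr = pickle.load(f)             # numpy array, H x W x C
```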
--------------------------------------------------------------------------------
/src/data/video.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data import common
4 |
5 | import cv2
6 | import numpy as np
7 | import imageio
8 |
9 | import torch
10 | import torch.utils.data as data
11 |
12 | class Video(data.Dataset):
13 |     def __init__(self, args, name='Video', train=False, benchmark=False):
14 |         self.args = args
15 |         self.name = name
16 |         self.scale = args.scale
17 |         self.idx_scale = 0
18 |         self.train = False
19 |         self.do_eval = False
20 |         self.benchmark = benchmark
21 |
22 |         self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
23 |         self.vidcap = cv2.VideoCapture(args.dir_demo)
24 |         self.n_frames = 0
25 |         self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
26 |
27 |     def __getitem__(self, idx):
28 |         success, lr = self.vidcap.read()
29 |         if success:
30 |             self.n_frames += 1
31 |             lr, = common.set_channel(lr, n_channels=self.args.n_colors)
32 |             lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
33 |
34 |             return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
35 |         else:
36 |             self.vidcap.release()
37 |             return None
38 |
39 |     def __len__(self):
40 |         return self.total_frames
41 |
42 |     def set_scale(self, idx_scale):
43 |         self.idx_scale = idx_scale
44 |
45 |
--------------------------------------------------------------------------------
/src/dataloader.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import random
3 | import sys
4 | import torch
5 | import torch.multiprocessing as multiprocessing
6 | from torch.utils.data import DataLoader
7 | from torch.utils.data import SequentialSampler
8 | from torch.utils.data import RandomSampler
9 | from torch.utils.data import BatchSampler
10 | from torch.utils.data import _utils
11 | from torch.utils.data.dataloader import _DataLoaderIter
12 |
13 | from torch.utils.data._utils import collate
14 | from torch.utils.data._utils import signal_handling
15 | from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
16 | from torch.utils.data._utils import ExceptionWrapper
17 | from torch.utils.data._utils import IS_WINDOWS
18 | from torch.utils.data._utils.worker import ManagerWatchdog
19 |
20 | from torch._six import queue
21 |
22 | def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
23 |     try:
24 |         collate._use_shared_memory = True
25 |         signal_handling._set_worker_signal_handlers()
26 |
27 |         torch.set_num_threads(1)
28 |         random.seed(seed)
29 |         torch.manual_seed(seed)
30 |
31 |         data_queue.cancel_join_thread()
32 |
33 |         if init_fn is not None:
34 |             init_fn(worker_id)
35 |
36 |         watchdog = ManagerWatchdog()
37 |
38 |         while watchdog.is_alive():
39 |             try:
40 |                 r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
41 |             except queue.Empty:
42 |                 continue
43 |
44 |             if r is None:
45 |                 assert done_event.is_set()
46 |                 return
47 |             elif done_event.is_set():
48 |                 continue
49 |
50 |             idx, batch_indices = r
51 |             try:
52 |                 idx_scale = 0
53 |                 if len(scale) > 1 and dataset.train:
54 |                     idx_scale = random.randrange(0, len(scale))
55 |                     dataset.set_scale(idx_scale)
56 |
57 |                 samples = collate_fn([dataset[i] for i in batch_indices])
58 |                 samples.append(idx_scale)
59 |             except Exception:
60 |                 data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
61 |             else:
62 |                 data_queue.put((idx, samples))
63 |                 del samples
64 |
65 |     except KeyboardInterrupt:
66 |         pass
67 |
68 | class
_MSDataLoaderIter(_DataLoaderIter): 69 | 70 | def __init__(self, loader): 71 | self.dataset = loader.dataset 72 | self.scale = loader.scale 73 | self.collate_fn = loader.collate_fn 74 | self.batch_sampler = loader.batch_sampler 75 | self.num_workers = loader.num_workers 76 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 77 | self.timeout = loader.timeout 78 | 79 | self.sample_iter = iter(self.batch_sampler) 80 | 81 | base_seed = torch.LongTensor(1).random_().item() 82 | 83 | if self.num_workers > 0: 84 | self.worker_init_fn = loader.worker_init_fn 85 | self.worker_queue_idx = 0 86 | self.worker_result_queue = multiprocessing.Queue() 87 | self.batches_outstanding = 0 88 | self.worker_pids_set = False 89 | self.shutdown = False 90 | self.send_idx = 0 91 | self.rcvd_idx = 0 92 | self.reorder_dict = {} 93 | self.done_event = multiprocessing.Event() 94 | 95 | base_seed = torch.LongTensor(1).random_()[0] 96 | 97 | self.index_queues = [] 98 | self.workers = [] 99 | for i in range(self.num_workers): 100 | index_queue = multiprocessing.Queue() 101 | index_queue.cancel_join_thread() 102 | w = multiprocessing.Process( 103 | target=_ms_loop, 104 | args=( 105 | self.dataset, 106 | index_queue, 107 | self.worker_result_queue, 108 | self.done_event, 109 | self.collate_fn, 110 | self.scale, 111 | base_seed + i, 112 | self.worker_init_fn, 113 | i 114 | ) 115 | ) 116 | w.daemon = True 117 | w.start() 118 | self.index_queues.append(index_queue) 119 | self.workers.append(w) 120 | 121 | if self.pin_memory: 122 | self.data_queue = queue.Queue() 123 | pin_memory_thread = threading.Thread( 124 | target=_utils.pin_memory._pin_memory_loop, 125 | args=( 126 | self.worker_result_queue, 127 | self.data_queue, 128 | torch.cuda.current_device(), 129 | self.done_event 130 | ) 131 | ) 132 | pin_memory_thread.daemon = True 133 | pin_memory_thread.start() 134 | self.pin_memory_thread = pin_memory_thread 135 | else: 136 | self.data_queue = self.worker_result_queue 137 | 138 | _utils.signal_handling._set_worker_pids( 139 | id(self), tuple(w.pid for w in self.workers) 140 | ) 141 | _utils.signal_handling._set_SIGCHLD_handler() 142 | self.worker_pids_set = True 143 | 144 | for _ in range(2 * self.num_workers): 145 | self._put_indices() 146 | 147 | 148 | class MSDataLoader(DataLoader): 149 | 150 | def __init__(self, cfg, *args, **kwargs): 151 | super(MSDataLoader, self).__init__( 152 | *args, **kwargs, num_workers=cfg.n_threads 153 | ) 154 | self.scale = cfg.scale 155 | 156 | def __iter__(self): 157 | return _MSDataLoaderIter(self) 158 | 159 | -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | from option import args 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | class Loss(nn.modules.loss._Loss): 16 | def __init__(self, args, ckp): 17 | super(Loss, self).__init__() 18 | print('Preparing loss function:') 19 | 20 | self.n_GPUs = args.n_GPUs 21 | self.loss = [] 22 | self.loss_module = nn.ModuleList() 23 | for loss in args.loss.split('+'): 24 | weight, loss_type = loss.split('*') 25 | if loss_type == 'MSE': 26 | loss_function = nn.MSELoss() 27 | elif loss_type == 'L1': 28 | loss_function = nn.L1Loss() 29 | elif loss_type == 'L12': 30 | loss_function = 
nn.L1Loss() 31 | elif loss_type.find('VGG') >= 0: 32 | module = import_module('loss.vgg') 33 | loss_function = getattr(module, 'VGG')( 34 | loss_type[3:], 35 | rgb_range=args.rgb_range 36 | ) 37 | elif loss_type.find('GAN') >= 0: 38 | module = import_module('loss.adversarial') 39 | loss_function = getattr(module, 'Adversarial')( 40 | args, 41 | loss_type 42 | ) 43 | 44 | self.loss.append({ 45 | 'type': loss_type, 46 | 'weight': float(weight), 47 | 'function': loss_function} 48 | ) 49 | if loss_type.find('GAN') >= 0: 50 | self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) 51 | 52 | if len(self.loss) > 1: 53 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) 54 | 55 | for l in self.loss: 56 | if l['function'] is not None: 57 | print('{:.3f} * {}'.format(l['weight'], l['type'])) 58 | self.loss_module.append(l['function']) 59 | 60 | self.log = torch.Tensor() 61 | 62 | device = torch.device('cpu' if args.cpu else 'cuda') 63 | self.loss_module.to(device) 64 | if args.precision == 'half': self.loss_module.half() 65 | if not args.cpu and args.n_GPUs > 1: 66 | self.loss_module = nn.DataParallel( 67 | self.loss_module, range(args.n_GPUs) 68 | ) 69 | 70 | if args.load != '': self.load(ckp.dir, cpu=args.cpu) 71 | 72 | if args.tcloss_v1: 73 | def forward(self, sr, hr, hr_r, sr_r, sr_r_g): 74 | losses = [] 75 | for i, l in enumerate(self.loss): 76 | if l['function'] is not None: 77 | if l['type'] == 'L1': 78 | loss = l['function'](sr, hr) 79 | else: 80 | loss = l['function'](sr_r, hr_r) 81 | loss1 = l['function'](sr_r_g, sr) 82 | loss = 0.5*loss + 0.5*loss1 83 | #loss = l['function'](sr, sr_r) 84 | effective_loss = l['weight'] * loss 85 | losses.append(effective_loss) 86 | self.log[-1, i] += effective_loss.item() 87 | elif l['type'] == 'DIS': 88 | self.log[-1, i] += self.loss[i - 1]['function'].loss 89 | 90 | loss_sum = sum(losses) 91 | if len(self.loss) > 1: 92 | self.log[-1, -1] += loss_sum.item() 93 | 94 | return loss_sum 95 | 96 | elif args.tcloss_v2: 97 | def forward(self, sr, hr, sr_r, sr_r_gt, sr_, sr_gt): 98 | losses = [] 99 | for i, l in enumerate(self.loss): 100 | if l['function'] is not None: 101 | if l['type'] == 'L1': 102 | loss = l['function'](sr, hr) 103 | else: 104 | loss = l['function'](sr_, sr_gt) 105 | loss1 = l['function'](sr_r, sr_r_gt) 106 | loss = 0.5*loss1+0.5*loss 107 | #loss = loss1 108 | effective_loss = l['weight'] * loss 109 | losses.append(effective_loss) 110 | self.log[-1, i] += effective_loss.item() 111 | elif l['type'] == 'DIS': 112 | self.log[-1, i] += self.loss[i - 1]['function'].loss 113 | 114 | loss_sum = sum(losses) 115 | if len(self.loss) > 1: 116 | self.log[-1, -1] += loss_sum.item() 117 | 118 | return loss_sum 119 | 120 | else: 121 | def forward(self, sr, hr): 122 | losses = [] 123 | for i, l in enumerate(self.loss): 124 | if l['function'] is not None: 125 | loss = l['function'](sr, hr) 126 | effective_loss = l['weight'] * loss 127 | losses.append(effective_loss) 128 | self.log[-1, i] += effective_loss.item() 129 | elif l['type'] == 'DIS': 130 | self.log[-1, i] += self.loss[i - 1]['function'].loss 131 | 132 | loss_sum = sum(losses) 133 | if len(self.loss) > 1: 134 | self.log[-1, -1] += loss_sum.item() 135 | 136 | return loss_sum 137 | 138 | def step(self): 139 | for l in self.get_loss_module(): 140 | if hasattr(l, 'scheduler'): 141 | l.scheduler.step() 142 | 143 | def start_log(self): 144 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) 145 | 146 | def end_log(self, n_batches): 147 | self.log[-1].div_(n_batches) 
148 | 149 | def display_loss(self, batch): 150 | n_samples = batch + 1 151 | log = [] 152 | for l, c in zip(self.loss, self.log[-1]): 153 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) 154 | 155 | return ''.join(log) 156 | 157 | def plot_loss(self, apath, epoch): 158 | axis = np.linspace(1, epoch, epoch) 159 | for i, l in enumerate(self.loss): 160 | label = '{} Loss'.format(l['type']) 161 | fig = plt.figure() 162 | plt.title(label) 163 | plt.plot(axis, self.log[:, i].numpy(), label=label) 164 | plt.legend() 165 | plt.xlabel('Epochs') 166 | plt.ylabel('Loss') 167 | plt.grid(True) 168 | plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type']))) 169 | plt.close(fig) 170 | 171 | def get_loss_module(self): 172 | if self.n_GPUs == 1: 173 | return self.loss_module 174 | else: 175 | return self.loss_module.module 176 | 177 | def save(self, apath): 178 | torch.save(self.state_dict(), os.path.join(apath, 'loss.pt')) 179 | torch.save(self.log, os.path.join(apath, 'loss_log.pt')) 180 | 181 | def load(self, apath, cpu=False): 182 | if cpu: 183 | kwargs = {'map_location': lambda storage, loc: storage} 184 | else: 185 | kwargs = {} 186 | 187 | self.load_state_dict(torch.load( 188 | os.path.join(apath, 'loss.pt'), 189 | **kwargs 190 | )) 191 | self.log = torch.load(os.path.join(apath, 'loss_log.pt')) 192 | for l in self.get_loss_module(): 193 | if hasattr(l, 'scheduler'): 194 | for _ in range(len(self.log)): l.scheduler.step() 195 | 196 | -------------------------------------------------------------------------------- /src/loss/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/loss/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/loss/adversarial.py: -------------------------------------------------------------------------------- 1 | import utility 2 | from types import SimpleNamespace 3 | 4 | from model import common 5 | from loss import discriminator 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | class Adversarial(nn.Module): 13 | def __init__(self, args, gan_type): 14 | super(Adversarial, self).__init__() 15 | self.gan_type = gan_type 16 | self.gan_k = args.gan_k 17 | self.dis = discriminator.Discriminator(args) 18 | if gan_type == 'WGAN_GP': 19 | # see https://arxiv.org/pdf/1704.00028.pdf pp.4 20 | optim_dict = { 21 | 'optimizer': 'ADAM', 22 | 'betas': (0, 0.9), 23 | 'epsilon': 1e-8, 24 | 'lr': 1e-5, 25 | 'weight_decay': args.weight_decay, 26 | 'decay': args.decay, 27 | 'gamma': args.gamma 28 | } 29 | optim_args = SimpleNamespace(**optim_dict) 30 | else: 31 | optim_args = args 32 | 33 | self.optimizer = utility.make_optimizer(optim_args, self.dis) 34 | 35 | def forward(self, fake, real): 36 | # updating discriminator... 
37 | self.loss = 0 38 | fake_detach = fake.detach() # do not backpropagate through G 39 | for _ in range(self.gan_k): 40 | self.optimizer.zero_grad() 41 | # d: B x 1 tensor 42 | d_fake = self.dis(fake_detach) 43 | d_real = self.dis(real) 44 | retain_graph = False 45 | if self.gan_type == 'GAN': 46 | loss_d = self.bce(d_real, d_fake) 47 | elif self.gan_type.find('WGAN') >= 0: 48 | loss_d = (d_fake - d_real).mean() 49 | if self.gan_type.find('GP') >= 0: 50 | epsilon = torch.rand_like(fake).view(-1, 1, 1, 1) 51 | hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon) 52 | hat.requires_grad = True 53 | d_hat = self.dis(hat) 54 | gradients = torch.autograd.grad( 55 | outputs=d_hat.sum(), inputs=hat, 56 | retain_graph=True, create_graph=True, only_inputs=True 57 | )[0] 58 | gradients = gradients.view(gradients.size(0), -1) 59 | gradient_norm = gradients.norm(2, dim=1) 60 | gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean() 61 | loss_d += gradient_penalty 62 | # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks 63 | elif self.gan_type == 'RGAN': 64 | better_real = d_real - d_fake.mean(dim=0, keepdim=True) 65 | better_fake = d_fake - d_real.mean(dim=0, keepdim=True) 66 | loss_d = self.bce(better_real, better_fake) 67 | retain_graph = True 68 | 69 | # Discriminator update 70 | self.loss += loss_d.item() 71 | loss_d.backward(retain_graph=retain_graph) 72 | self.optimizer.step() 73 | 74 | if self.gan_type == 'WGAN': 75 | for p in self.dis.parameters(): 76 | p.data.clamp_(-1, 1) 77 | 78 | self.loss /= self.gan_k 79 | 80 | # updating generator... 81 | d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is 82 | if self.gan_type == 'GAN': 83 | label_real = torch.ones_like(d_fake_bp) 84 | loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real) 85 | elif self.gan_type.find('WGAN') >= 0: 86 | loss_g = -d_fake_bp.mean() 87 | elif self.gan_type == 'RGAN': 88 | better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True) 89 | better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True) 90 | loss_g = self.bce(better_fake, better_real) 91 | 92 | # Generator loss 93 | return loss_g 94 | 95 | def state_dict(self, *args, **kwargs): 96 | state_discriminator = self.dis.state_dict(*args, **kwargs) 97 | state_optimizer = self.optimizer.state_dict() 98 | 99 | return dict(**state_discriminator, **state_optimizer) 100 | 101 | def bce(self, real, fake): 102 | label_real = torch.ones_like(real) 103 | label_fake = torch.zeros_like(fake) 104 | bce_real = F.binary_cross_entropy_with_logits(real, label_real) 105 | bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake) 106 | bce_loss = bce_real + bce_fake 107 | return bce_loss 108 | 109 | # Some references 110 | # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py 111 | # OR 112 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py 113 | -------------------------------------------------------------------------------- /src/loss/discriminator.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | class Discriminator(nn.Module): 6 | ''' 7 | output is not normalized 8 | ''' 9 | def __init__(self, args): 10 | super(Discriminator, self).__init__() 11 | 12 | in_channels = args.n_colors 13 | out_channels = 64 14 | depth = 7 15 | 16 | def _block(_in_channels, _out_channels, stride=1): 17 | return nn.Sequential( 18 | nn.Conv2d( 19 | _in_channels, 20 | _out_channels, 21 | 3, 22 | padding=1, 23 | stride=stride, 24 
| bias=False 25 | ), 26 | nn.BatchNorm2d(_out_channels), 27 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | ) 29 | 30 | m_features = [_block(in_channels, out_channels)] 31 | for i in range(depth): 32 | in_channels = out_channels 33 | if i % 2 == 1: 34 | stride = 1 35 | out_channels *= 2 36 | else: 37 | stride = 2 38 | m_features.append(_block(in_channels, out_channels, stride=stride)) 39 | 40 | patch_size = args.patch_size // (2**((depth + 1) // 2)) 41 | m_classifier = [ 42 | nn.Linear(out_channels * patch_size**2, 1024), 43 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 44 | nn.Linear(1024, 1) 45 | ] 46 | 47 | self.features = nn.Sequential(*m_features) 48 | self.classifier = nn.Sequential(*m_classifier) 49 | 50 | def forward(self, x): 51 | features = self.features(x) 52 | output = self.classifier(features.view(features.size(0), -1)) 53 | 54 | return output 55 | 56 | -------------------------------------------------------------------------------- /src/loss/vgg.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | class VGG(nn.Module): 9 | def __init__(self, conv_index, rgb_range=1): 10 | super(VGG, self).__init__() 11 | vgg_features = models.vgg19(pretrained=True).features 12 | modules = [m for m in vgg_features] 13 | if conv_index.find('22') >= 0: 14 | self.vgg = nn.Sequential(*modules[:8]) 15 | elif conv_index.find('54') >= 0: 16 | self.vgg = nn.Sequential(*modules[:35]) 17 | 18 | vgg_mean = (0.485, 0.456, 0.406) 19 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 20 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | def forward(self, sr, hr): 25 | def _forward(x): 26 | x = self.sub_mean(x) 27 | x = self.vgg(x) 28 | return x 29 | 30 | vgg_sr = _forward(sr) 31 | with torch.no_grad(): 32 | vgg_hr = _forward(hr.detach()) 33 | 34 | loss = F.mse_loss(vgg_sr, vgg_hr) 35 | 36 | return loss 37 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import torch.backends.cudnn as cudnn 5 | 6 | import utility 7 | import data 8 | import model 9 | import loss 10 | from option import args 11 | from trainer_cafm import Trainer_cafm 12 | 13 | import pdb 14 | torch.manual_seed(args.seed) 15 | checkpoint = utility.checkpoint(args) 16 | 17 | def get_model_size(model): 18 | param_count = sum(p.numel() for p in model.parameters()) 19 | 20 | param_size = param_count * 4 / (1024 * 1024) 21 | 22 | return { 23 | 'parameters': param_count, 24 | 'size_mb': param_size 25 | } 26 | 27 | def print_param_info(model): 28 | for name, param in model.named_parameters(): 29 | print(f"Parameter: {name}") 30 | print(f"Type: {param.dtype}") 31 | print(f"Shape: {param.shape}") 32 | print("-" * 50) 33 | 34 | def main(): 35 | global model 36 | if args.data_test == ['video']: 37 | from videotester import VideoTester 38 | model = model.Model(args, checkpoint) 39 | t = VideoTester(args, model, checkpoint) 40 | t.test() 41 | else: 42 | if checkpoint.ok: 43 | #pdb.set_trace() 44 | loader = data.Data(args) 45 | _model = model.Model(args, checkpoint) 46 | _loss = loss.Loss(args, checkpoint) if not args.test_only else None 47 | 
print(get_model_size(_model))
48 |             if args.cafm:
49 |                 t = Trainer_cafm(args, loader, _model, _loss, checkpoint)
50 |             else:
51 |                 # print("u have to enter --cafm in command")
52 |                 # assert(0)
53 |                 t = Trainer_cafm(args, loader, _model, _loss, checkpoint)
54 |             while not t.terminate():
55 |                 t.train()
56 |                 t.test()
57 |             checkpoint.done()
58 |
59 | if __name__ == '__main__':
60 |     # You can change the random seed yourself
61 |     random.seed(0)
62 |     np.random.seed(0)
63 |     torch.manual_seed(0)
64 |     torch.cuda.manual_seed(0)
65 |     torch.cuda.manual_seed_all(0)
66 |     cudnn.benchmark = False
67 |     cudnn.deterministic = True
68 |     main()
69 |
--------------------------------------------------------------------------------
/src/model/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/model/.DS_Store
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.parallel as P
7 | import torch.utils.model_zoo
8 |
9 | #from .fix_patch_prompt import FixedPatchPrompter_image, FixedPatchPrompter_feature
10 |
11 | class Model(nn.Module):
12 |     def __init__(self, args, ckp):
13 |         super(Model, self).__init__()
14 |         print('Making model...')
15 |
16 |         self.scale = args.scale
17 |         self.idx_scale = 0
18 |         self.input_large = (args.model == 'VDSR')
19 |         self.self_ensemble = args.self_ensemble
20 |         self.chop = args.chop
21 |         self.precision = args.precision
22 |         self.cpu = args.cpu
23 |         self.device = torch.device('cpu' if args.cpu else 'cuda')
24 |         self.n_GPUs = args.n_GPUs
25 |         self.save_models = args.save_models
26 |
27 |         module = import_module('model.' + args.model.lower())
28 |         self.model = module.make_model(args).to(self.device)
29 |         if args.precision == 'half':
30 |             self.model.half()
31 |
32 |         self.load(
33 |             ckp.get_path('model'),
34 |             pre_train=args.pre_train,
35 |             resume=args.resume,
36 |             cpu=args.cpu
37 |         )
38 |         print(self.model, file=ckp.log_file)
39 |
40 |
41 |     '''
42 |     def forward(self, x, idx_scale, seg_flag):
43 |     '''
44 |     def forward(self, x, idx_scale, num):
45 |         self.idx_scale = idx_scale
46 |         if hasattr(self.model, 'set_scale'):
47 |             self.model.set_scale(idx_scale)
48 |
49 |         if self.training:
50 |             if self.n_GPUs > 1:
51 |                 '''
52 |                 TODO: how to pass multiple arguments through data_parallel
53 |                 '''
54 |                 return P.data_parallel(self.model, x, range(self.n_GPUs))
55 |             else:
56 |                 '''
57 |                 return self.model(x, seg_flag)
58 |                 '''
59 |
60 |                 return self.model(x, num)
61 |         else:
62 |             if self.chop:
63 |                 forward_function = self.forward_chop
64 |             else:
65 |                 forward_function = self.model.forward
66 |
67 |             if self.self_ensemble:
68 |                 return self.forward_x8(num, x, forward_function=forward_function)
69 |             else:
70 |                 return forward_function(x, num)
71 |
72 |     def save(self, apath, epoch, is_best=False):
73 |         save_dirs = [os.path.join(apath, 'model_latest.pt')]
74 |
75 |         if is_best:
76 |             save_dirs.append(os.path.join(apath, 'model_best.pt'))
77 |         if self.save_models:
78 |             save_dirs.append(
79 |                 os.path.join(apath, 'model_{}.pt'.format(epoch))
80 |             )
81 |
82 |         for s in save_dirs:
83 |             torch.save(self.model.state_dict(), s)
84 |
85 |     def save_every(self, apath, epoch, is_best=False):
86 |         save_dirs = []
87 |
88 |         if is_best:
89 |             save_dirs.append(os.path.join(apath, 'model_best.pt'))
90 |         if self.save_models:
91 |             save_dirs.append(
92 |                 os.path.join(apath, 'model_{}.pt'.format(epoch))
93 |             )
94 |
95 |         for s in save_dirs:
96 |             torch.save(self.model.state_dict(), s)
97 |
98 |     def load(self, apath, pre_train='', resume=-1, cpu=False):
99 |         load_from = None
100 |         kwargs = {}
101 |         if cpu:
102 |             kwargs = {'map_location': lambda storage, loc: storage}
103 |
104 |         if resume == -1:
105 |             load_from = torch.load(
106 |                 os.path.join(apath, 'model_latest.pt'),
107 |                 **kwargs
108 |             )
109 |         elif resume == 0:
110 |             if pre_train == 'download':
111 |                 print('Download the model')
112 |                 dir_model = os.path.join('..', 'models')
113 |                 os.makedirs(dir_model, exist_ok=True)
114 |                 load_from = torch.utils.model_zoo.load_url(
115 |                     self.model.url,
116 |                     model_dir=dir_model,
117 |                     **kwargs
118 |                 )
119 |             elif pre_train:
120 |                 print('Load the model from {}'.format(pre_train))
121 |                 load_from = torch.load(pre_train, **kwargs)
122 |         else:
123 |             load_from = torch.load(
124 |                 os.path.join(apath, 'model_{}.pt'.format(resume)),
125 |                 **kwargs
126 |             )
127 |
128 |         if load_from:
129 |             self.model.load_state_dict(load_from, strict=False)
130 |
131 |     def forward_chop(self, *args, shave=10, min_size=160000):
132 |         scale = 1 if self.input_large else self.scale[self.idx_scale]
133 |         n_GPUs = min(self.n_GPUs, 4)
134 |         # height, width
135 |         h, w = args[0].size()[-2:]
136 |
137 |         top = slice(0, h//2 + shave)
138 |         bottom = slice(h - h//2 - shave, h)
139 |         left = slice(0, w//2 + shave)
140 |         right = slice(w - w//2 - shave, w)
141 |         x_chops = [torch.cat([
142 |             a[..., top, left],
143 |             a[..., top, right],
144 |             a[..., bottom, left],
145 |             a[..., bottom, right]
146 |         ]) for a in args]
147 |
148 |         y_chops = []
149 |         if h * w < 4 * min_size:
150 |             for i in range(0, 4, n_GPUs):
151 |                 x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
152 |                 y = P.data_parallel(self.model, *x, range(n_GPUs))
153 |                 if not isinstance(y, list): y =
[y] 154 | if not y_chops: 155 | y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y] 156 | else: 157 | for y_chop, _y in zip(y_chops, y): 158 | y_chop.extend(_y.chunk(n_GPUs, dim=0)) 159 | else: 160 | for p in zip(*x_chops): 161 | y = self.forward_chop(*p, shave=shave, min_size=min_size) 162 | if not isinstance(y, list): y = [y] 163 | if not y_chops: 164 | y_chops = [[_y] for _y in y] 165 | else: 166 | for y_chop, _y in zip(y_chops, y): y_chop.append(_y) 167 | 168 | h *= scale 169 | w *= scale 170 | top = slice(0, h//2) 171 | bottom = slice(h - h//2, h) 172 | bottom_r = slice(h//2 - h, None) 173 | left = slice(0, w//2) 174 | right = slice(w - w//2, w) 175 | right_r = slice(w//2 - w, None) 176 | 177 | # batch size, number of color channels 178 | b, c = y_chops[0][0].size()[:-2] 179 | y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops] 180 | for y_chop, _y in zip(y_chops, y): 181 | _y[..., top, left] = y_chop[0][..., top, left] 182 | _y[..., top, right] = y_chop[1][..., top, right_r] 183 | _y[..., bottom, left] = y_chop[2][..., bottom_r, left] 184 | _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r] 185 | 186 | if len(y) == 1: y = y[0] 187 | 188 | return y 189 | 190 | def forward_x8(self, num, *args, forward_function=None): 191 | def _transform(v, op): 192 | if self.precision != 'single': v = v.float() 193 | 194 | v2np = v.data.cpu().numpy() 195 | if op == 'v': 196 | tfnp = v2np[:, :, :, ::-1].copy() 197 | elif op == 'h': 198 | tfnp = v2np[:, :, ::-1, :].copy() 199 | elif op == 't': 200 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 201 | 202 | ret = torch.Tensor(tfnp).to(self.device) 203 | if self.precision == 'half': ret = ret.half() 204 | 205 | return ret 206 | 207 | list_x = [] 208 | for a in args: 209 | x = [a] 210 | for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x]) 211 | 212 | list_x.append(x) 213 | 214 | list_y = [] 215 | for x in zip(*list_x): 216 | y = forward_function(*x, num) 217 | if not isinstance(y, list): y = [y] 218 | if not list_y: 219 | list_y = [[_y] for _y in y] 220 | else: 221 | for _list_y, _y in zip(list_y, y): _list_y.append(_y) 222 | 223 | for _list_y in list_y: 224 | for i in range(len(_list_y)): 225 | if i > 3: 226 | _list_y[i] = _transform(_list_y[i], 't') 227 | if i % 4 > 1: 228 | _list_y[i] = _transform(_list_y[i], 'h') 229 | if (i % 4) % 2 == 1: 230 | _list_y[i] = _transform(_list_y[i], 'v') 231 | 232 | y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y] 233 | if len(y) == 1: y = y[0] 234 | 235 | return y 236 | -------------------------------------------------------------------------------- /src/model/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/model/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/common.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/model/__pycache__/common.cpython-39.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/common_m0.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/model/__pycache__/common_m0.cpython-39.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/edsr.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/model/__pycache__/edsr.cpython-39.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/edsr_m0.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/model/__pycache__/edsr_m0.cpython-39.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/espcn.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/model/__pycache__/espcn.cpython-39.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/espcn_m0.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyZry98/RepCaM-Pytorch/e2de1f5fbff77cbb105dd59c1ef3f12ebe0144a6/src/model/__pycache__/espcn_m0.cpython-39.pyc -------------------------------------------------------------------------------- /src/model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 9 | return nn.Conv2d( 10 | in_channels, out_channels, kernel_size, 11 | padding=(kernel_size//2), bias=bias) 12 | 13 | def set_padding_size(kernel_size, dilation): 14 | kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) 15 | padding = (kernel_size - 1) // 2 16 | return padding 17 | 18 | class MeanShift(nn.Conv2d): 19 | def __init__( 20 | self, rgb_range, 21 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 22 | 23 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 24 | std = torch.Tensor(rgb_std) 25 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 26 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 27 | for p in self.parameters(): 28 | p.requires_grad = False 29 | 30 | class BasicBlock(nn.Sequential): 31 | #hello ljm, i am syx 32 | def __init__( 33 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 34 | bn=True, act=nn.ReLU(True)): 35 | 36 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 37 | if bn: 38 | m.append(nn.BatchNorm2d(out_channels)) 39 | if act is not None: 40 | m.append(act) 41 | 42 | super(BasicBlock, self).__init__(*m) 43 | 44 | class BasicBlock_(nn.Module): 45 | def __init__( 46 | self, args, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 47 | bn=True, act=nn.ReLU(True)): 48 | self.use_cafm = args.use_cafm 49 | super(BasicBlock_, self).__init__() 50 | self.conv1 = conv(in_channels, out_channels, kernel_size, bias=bias) 51 | if args.cafm: 52 | if self.use_cafm: 53 | self.cafms = nn.ModuleList([ContentAwareFM(out_channels,1) for _ in range(args.segnum)]) 54 | 
self.act = act 55 | 56 | def forward(self, input, num): 57 | x = self.conv1(input) 58 | if self.use_cafm: 59 | x = self.cafms[num](x) 60 | x = self.act(x) 61 | return x 62 | 63 | class ContentAwareFM(nn.Module): 64 | # hello ckx 65 | def __init__(self, in_channel, kernel_size): 66 | 67 | super(ContentAwareFM, self).__init__() 68 | padding = set_padding_size(kernel_size, 1) 69 | self.transformer = nn.Conv2d(in_channel, in_channel, kernel_size, 70 | padding=padding, groups=in_channel) 71 | self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True) 72 | def forward(self, x): 73 | return self.transformer(x) * self.gamma + x 74 | 75 | class ResBlock(nn.Module): 76 | def __init__( 77 | self, conv, n_feats, kernel_size, args, 78 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 79 | 80 | super(ResBlock, self).__init__() 81 | self.res_scale = res_scale 82 | self.args = args 83 | if self.args.cafm: 84 | self.conv1 = conv(n_feats, n_feats, kernel_size, bias=bias) 85 | self.conv2 = conv(n_feats, n_feats, kernel_size, bias=bias) 86 | if self.args.use_cafm: 87 | if self.args.cafm: 88 | self.cafms1 = nn.ModuleList([ContentAwareFM(n_feats,1) for _ in range(args.segnum)]) 89 | self.cafms2 = nn.ModuleList([ContentAwareFM(n_feats,1) for _ in range(args.segnum)]) 90 | self.act = act 91 | 92 | else: 93 | m = [] 94 | for i in range(2): 95 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 96 | if bn: 97 | m.append(ContentAwareFM(n_feats,7)) # the filter size of cafm during finetune. 1 | 3 | 5 | 7 98 | if i == 0: 99 | m.append(act) 100 | self.body = nn.Sequential(*m) 101 | 102 | 103 | def forward(self, input, num): 104 | x = self.conv1(input) 105 | if self.args.cafm: 106 | if self.args.use_cafm: 107 | x = self.cafms1[num](x) 108 | x = self.conv2(self.act(x)) 109 | if self.args.use_cafm: 110 | x = self.cafms2[num](x) 111 | else: 112 | x = self.conv2(self.act(x)) 113 | res = x.mul(self.res_scale) 114 | res += input 115 | 116 | return res 117 | 118 | class ResBlock_rcan(nn.Module): 119 | def __init__( 120 | self, conv, n_feats, kernel_size, args, 121 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 122 | 123 | super(ResBlock_rcan, self).__init__() 124 | self.res_scale = res_scale 125 | self.args = args 126 | if self.args.cafm: 127 | self.conv1 = conv(n_feats, n_feats, kernel_size, bias=bias) 128 | self.conv2 = conv(n_feats, n_feats, kernel_size, bias=bias) 129 | if self.args.use_cafm: 130 | if self.args.cafm: 131 | self.cafms1 = nn.ModuleList([ContentAwareFM(n_feats,1) for _ in range(args.segnum)]) 132 | self.cafms2 = nn.ModuleList([ContentAwareFM(n_feats,1) for _ in range(args.segnum)]) 133 | self.act = act 134 | 135 | else: 136 | m = [] 137 | for i in range(2): 138 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 139 | if bn: 140 | m.append(ContentAwareFM(n_feats,7)) # the filter size of cafm during finetune. 
1 | 3 | 5 | 7 141 | if i == 0: 142 | m.append(act) 143 | self.body = nn.Sequential(*m) 144 | 145 | 146 | def forward(self, input, num): 147 | x = self.conv1(input) 148 | if self.args.cafm: 149 | if self.args.use_cafm: 150 | x = self.cafms1[num](x) 151 | x = self.conv2(self.act(x)) 152 | if self.args.use_cafm: 153 | x = self.cafms2[num](x) 154 | else: 155 | x = self.conv2(self.act(x)) 156 | res = x.mul(self.res_scale) 157 | 158 | return res 159 | 160 | #Repblock 161 | class RepBlock(nn.Module): 162 | def __init__( 163 | self, conv, n_feats, kernel_size, args, 164 | bias=True ): 165 | 166 | super(RepBlock, self).__init__() 167 | 168 | 169 | self.conv = conv(n_feats, n_feats, kernel_size, bias=bias) 170 | self.conv_0_0 = conv(n_feats, n_feats, 1, bias=bias) 171 | self.conv_0_1 = conv(n_feats, n_feats, 1, bias=bias) 172 | self.conv_0_2 = conv(n_feats, n_feats, kernel_size, bias=bias) 173 | 174 | self.conv_1_0 = conv(n_feats, n_feats, 1, bias=bias) 175 | self.conv_1_1 = conv(n_feats, n_feats, kernel_size, bias=bias) 176 | 177 | self.conv_2_0 = conv(n_feats, n_feats, kernel_size, bias=bias) 178 | #print('not original') 179 | 180 | def forward(self, x): 181 | out = self.conv_0_2(self.conv_0_1(self.conv_0_0(x))) 182 | out_res = self.conv_1_1(self.conv_1_0(x)) 183 | out_id = self.conv_2_0(x) 184 | return (out + out_res + out_id) 185 | 186 | class RepBlock_org(nn.Module): 187 | def __init__( 188 | self, conv, n_feats, kernel_size, args, 189 | bias=True ): 190 | 191 | super(RepBlock_org, self).__init__() 192 | 193 | 194 | #self.conv = conv(n_feats, n_feats, kernel_size, bias=bias) 195 | #self.conv_0_0 = conv(n_feats, n_feats, 1, bias=bias) 196 | #self.conv_0_1 = conv(n_feats, n_feats, 1, bias=bias) 197 | #self.conv_0_2 = conv(n_feats, n_feats, kernel_size, bias=bias) 198 | 199 | #self.conv_1_0 = conv(n_feats, n_feats, 1, bias=bias) 200 | #self.conv_1_1 = conv(n_feats, n_feats, kernel_size, bias=bias) 201 | 202 | self.conv_2_0 = conv(n_feats, n_feats, kernel_size, bias=bias) 203 | #print("orange") 204 | 205 | 206 | 207 | def forward(self, x): 208 | #out_long = self.conv_0_2(self.conv_0_1(self.conv_0_0(x))) 209 | #out_mid = self.conv_1_1(self.conv_1_0(x)) 210 | out_short = self.conv_2_0(x) 211 | return (out_short + 0) 212 | 213 | class ResBlock_org(nn.Module): 214 | def __init__( 215 | self, conv, n_feats, kernel_size, args, 216 | bias=True, bn=True, act=nn.ReLU(True), res_scale=1, org=False): 217 | 218 | super(ResBlock_org, self).__init__() 219 | m = [] 220 | for i in range(2): 221 | if org: 222 | m.append(RepBlock_org(conv, n_feats, kernel_size, args, bias=True)) 223 | else: 224 | m.append(RepBlock(conv, n_feats, kernel_size, args, bias=True)) 225 | # if bn: 226 | # m.append(ContentAwareFM(n_feats,7)) # the filter size of cafm during finetune. 1 | 3 | 5 | 7 227 | if i == 0: 228 | m.append(act) 229 | self.body = nn.Sequential(*m) 230 | self.res_scale = res_scale 231 | 232 | def forward(self, x): 233 | res = self.body(x).mul(self.res_scale) 234 | res += x 235 | return res 236 | 237 | class Upsampler(nn.Sequential): 238 | def __init__(self, conv, scale, n_feats, kernel_size, bn=False, act=False, bias=True): 239 | 240 | m = [] 241 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
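# ---------------------------------------------------------------------------
# [Editor's sketch, not part of the original file] RepBlock above sums three
# linear branches (1x1 -> 1x1 -> 3x3, 1x1 -> 3x3, and a plain 3x3), so after
# training they can be folded into a single 3x3 conv; this is the
# inference-time re-parameterization that reparameter_edsr.py /
# reparameter_espcn.py (not reproduced in this dump) presumably perform, and
# that RepBlock_m0 in common_m0.py mirrors with one conv. The helper names
# below are illustrative. The fusion assumes bias=True throughout, as in
# default_conv; it is exact in the image interior, while border pixels can
# deviate slightly because the inner 1x1 convs were zero-padded during
# training but their biases are folded into a constant here.
import torch
import torch.nn as nn

def fold_1x1_into_kxk(w1, b1, wk, bk):
    # conv_kxk(conv_1x1(x)) == conv'(x) with:
    #   weight'[o, i, :, :] = sum_c wk[o, c, :, :] * w1[c, i]
    #   bias'[o]            = bk[o] + sum_c (sum of wk[o, c] taps) * b1[c]
    w = torch.einsum('ockh,ci->oikh', wk, w1[:, :, 0, 0])
    b = bk + wk.sum(dim=(2, 3)) @ b1
    return w, b

@torch.no_grad()
def reparameterize_repblock(blk):
    # branch 0: conv_0_0 (1x1) -> conv_0_1 (1x1) -> conv_0_2 (3x3)
    w, b = fold_1x1_into_kxk(blk.conv_0_0.weight, blk.conv_0_0.bias,
                             blk.conv_0_1.weight, blk.conv_0_1.bias)
    w0, b0 = fold_1x1_into_kxk(w, b, blk.conv_0_2.weight, blk.conv_0_2.bias)
    # branch 1: conv_1_0 (1x1) -> conv_1_1 (3x3)
    w1, b1 = fold_1x1_into_kxk(blk.conv_1_0.weight, blk.conv_1_0.bias,
                               blk.conv_1_1.weight, blk.conv_1_1.bias)
    # branch 2: conv_2_0 is already a plain 3x3
    w2, b2 = blk.conv_2_0.weight, blk.conv_2_0.bias
    fused = nn.Conv2d(w2.size(1), w2.size(0), 3, padding=1)
    fused.weight.copy_(w0 + w1 + w2)
    fused.bias.copy_(b0 + b1 + b2)
    return fused
# Usage, e.g. fused = reparameterize_repblock(model.body[-1]): replaces the
# trained RepBlock with one conv, matching the single-conv RepBlock_m0 layout.
# ---------------------------------------------------------------------------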
242 | for _ in range(int(math.log(scale, 2))): 243 | m.append(conv(n_feats, 4 * n_feats, kernel_size, bias)) 244 | m.append(nn.PixelShuffle(2)) 245 | if bn: 246 | m.append(nn.BatchNorm2d(n_feats)) 247 | if act == 'relu': 248 | m.append(nn.ReLU(True)) 249 | elif act == 'prelu': 250 | m.append(nn.PReLU(n_feats)) 251 | 252 | elif scale == 3: 253 | m.append(conv(n_feats, 9 * n_feats, kernel_size, bias)) 254 | m.append(nn.PixelShuffle(3)) 255 | if bn: 256 | m.append(nn.BatchNorm2d(n_feats)) 257 | if act == 'relu': 258 | m.append(nn.ReLU(True)) 259 | elif act == 'prelu': 260 | m.append(nn.PReLU(n_feats)) 261 | else: 262 | raise NotImplementedError 263 | 264 | super(Upsampler, self).__init__(*m) 265 | 266 | -------------------------------------------------------------------------------- /src/model/common_m0.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 9 | return nn.Conv2d( 10 | in_channels, out_channels, kernel_size, 11 | padding=(kernel_size//2), bias=bias) 12 | 13 | def set_padding_size(kernel_size, dilation): 14 | kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) 15 | padding = (kernel_size - 1) // 2 16 | return padding 17 | 18 | class MeanShift(nn.Conv2d): 19 | def __init__( 20 | self, rgb_range, 21 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 22 | 23 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 24 | std = torch.Tensor(rgb_std) 25 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 26 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 27 | for p in self.parameters(): 28 | p.requires_grad = False 29 | 30 | class BasicBlock(nn.Sequential): 31 | #hello ljm, i am syx 32 | def __init__( 33 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 34 | bn=True, act=nn.ReLU(True)): 35 | 36 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 37 | if bn: 38 | m.append(nn.BatchNorm2d(out_channels)) 39 | if act is not None: 40 | m.append(act) 41 | 42 | super(BasicBlock, self).__init__(*m) 43 | 44 | class BasicBlock_(nn.Module): 45 | def __init__( 46 | self, args, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 47 | bn=True, act=nn.ReLU(True)): 48 | self.use_cafm = args.use_cafm 49 | super(BasicBlock_, self).__init__() 50 | self.conv1 = conv(in_channels, out_channels, kernel_size, bias=bias) 51 | if args.cafm: 52 | if self.use_cafm: 53 | self.cafms = nn.ModuleList([ContentAwareFM(out_channels,1) for _ in range(args.segnum)]) 54 | self.act = act 55 | 56 | def forward(self, input, num): 57 | x = self.conv1(input) 58 | if self.use_cafm: 59 | x = self.cafms[num](x) 60 | x = self.act(x) 61 | return x 62 | 63 | class ContentAwareFM(nn.Module): 64 | # hello ckx 65 | def __init__(self, in_channel, kernel_size): 66 | 67 | super(ContentAwareFM, self).__init__() 68 | padding = set_padding_size(kernel_size, 1) 69 | self.transformer = nn.Conv2d(in_channel, in_channel, kernel_size, 70 | padding=padding, groups=in_channel) 71 | self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True) 72 | def forward(self, x): 73 | return self.transformer(x) * self.gamma + x 74 | 75 | class ResBlock(nn.Module): 76 | def __init__( 77 | self, conv, n_feats, kernel_size, args, 78 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 79 | 80 | super(ResBlock, self).__init__() 81 | 
self.res_scale = res_scale 82 | self.args = args 83 | if self.args.cafm: 84 | self.conv1 = conv(n_feats, n_feats, kernel_size, bias=bias) 85 | self.conv2 = conv(n_feats, n_feats, kernel_size, bias=bias) 86 | if self.args.use_cafm: 87 | if self.args.cafm: 88 | self.cafms1 = nn.ModuleList([ContentAwareFM(n_feats,1) for _ in range(args.segnum)]) 89 | self.cafms2 = nn.ModuleList([ContentAwareFM(n_feats,1) for _ in range(args.segnum)]) 90 | self.act = act 91 | 92 | else: 93 | m = [] 94 | for i in range(2): 95 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 96 | if bn: 97 | m.append(ContentAwareFM(n_feats,7)) # the filter size of cafm during finetune. 1 | 3 | 5 | 7 98 | if i == 0: 99 | m.append(act) 100 | self.body = nn.Sequential(*m) 101 | 102 | 103 | def forward(self, input, num): 104 | x = self.conv1(input) 105 | if self.args.cafm: 106 | if self.args.use_cafm: 107 | x = self.cafms1[num](x) 108 | x = self.conv2(self.act(x)) 109 | if self.args.use_cafm: 110 | x = self.cafms2[num](x) 111 | else: 112 | x = self.conv2(self.act(x)) 113 | res = x.mul(self.res_scale) 114 | res += input 115 | 116 | return res 117 | 118 | class ResBlock_rcan(nn.Module): 119 | def __init__( 120 | self, conv, n_feats, kernel_size, args, 121 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 122 | 123 | super(ResBlock_rcan, self).__init__() 124 | self.res_scale = res_scale 125 | self.args = args 126 | if self.args.cafm: 127 | self.conv1 = conv(n_feats, n_feats, kernel_size, bias=bias) 128 | self.conv2 = conv(n_feats, n_feats, kernel_size, bias=bias) 129 | if self.args.use_cafm: 130 | if self.args.cafm: 131 | self.cafms1 = nn.ModuleList([ContentAwareFM(n_feats,1) for _ in range(args.segnum)]) 132 | self.cafms2 = nn.ModuleList([ContentAwareFM(n_feats,1) for _ in range(args.segnum)]) 133 | self.act = act 134 | 135 | else: 136 | m = [] 137 | for i in range(2): 138 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 139 | if bn: 140 | m.append(ContentAwareFM(n_feats,7)) # the filter size of cafm during finetune. 1 | 3 | 5 | 7 141 | if i == 0: 142 | m.append(act) 143 | self.body = nn.Sequential(*m) 144 | 145 | 146 | def forward(self, input, num): 147 | x = self.conv1(input) 148 | if self.args.cafm: 149 | if self.args.use_cafm: 150 | x = self.cafms1[num](x) 151 | x = self.conv2(self.act(x)) 152 | if self.args.use_cafm: 153 | x = self.cafms2[num](x) 154 | else: 155 | x = self.conv2(self.act(x)) 156 | res = x.mul(self.res_scale) 157 | 158 | return res 159 | 160 | class RepBlock_m0(nn.Module): 161 | def __init__( 162 | self, conv, n_feats, kernel_size, args, 163 | bias=True ): 164 | 165 | super(RepBlock_m0, self).__init__() 166 | 167 | self.conv = conv(n_feats, n_feats, kernel_size, bias=bias) 168 | 169 | 170 | def forward(self, x): 171 | out = self.conv(x) 172 | return out 173 | 174 | class ResBlock_org(nn.Module): 175 | def __init__( 176 | self, conv, n_feats, kernel_size, args, 177 | bias=True, bn=True, act=nn.ReLU(True), res_scale=1): 178 | 179 | super(ResBlock_org, self).__init__() 180 | m = [] 181 | for i in range(2): 182 | m.append(RepBlock_m0(conv, n_feats, kernel_size, args, bias=True)) 183 | # if bn: 184 | # m.append(ContentAwareFM(n_feats,7)) # the filter size of cafm during finetune. 
1 | 3 | 5 | 7 185 | if i == 0: 186 | m.append(act) 187 | 188 | self.body = nn.Sequential(*m) 189 | self.res_scale = res_scale 190 | 191 | def forward(self, x): 192 | res = self.body(x).mul(self.res_scale) 193 | res += x 194 | return res 195 | 196 | class Upsampler(nn.Sequential): 197 | def __init__(self, conv, scale, n_feats, kernel_size, bn=False, act=False, bias=True): 198 | 199 | m = [] 200 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 201 | for _ in range(int(math.log(scale, 2))): 202 | m.append(conv(n_feats, 4 * n_feats, kernel_size, bias)) 203 | m.append(nn.PixelShuffle(2)) 204 | if bn: 205 | m.append(nn.BatchNorm2d(n_feats)) 206 | if act == 'relu': 207 | m.append(nn.ReLU(True)) 208 | elif act == 'prelu': 209 | m.append(nn.PReLU(n_feats)) 210 | 211 | elif scale == 3: 212 | m.append(conv(n_feats, 9 * n_feats, kernel_size, bias)) 213 | m.append(nn.PixelShuffle(3)) 214 | if bn: 215 | m.append(nn.BatchNorm2d(n_feats)) 216 | if act == 'relu': 217 | m.append(nn.ReLU(True)) 218 | elif act == 'prelu': 219 | m.append(nn.PReLU(n_feats)) 220 | else: 221 | raise NotImplementedError 222 | 223 | super(Upsampler, self).__init__(*m) 224 | -------------------------------------------------------------------------------- /src/model/edsr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from model import common 6 | 7 | 8 | url = { 9 | 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt', 10 | 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt', 11 | 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt', 12 | 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt', 13 | 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt', 14 | 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt' 15 | } 16 | import pdb 17 | 18 | 19 | def make_model(args, parent=False): 20 | return EDSR(args) 21 | 22 | class EDSR(nn.Module): 23 | def __init__(self, args, conv=common.default_conv): 24 | super(EDSR, self).__init__() 25 | #args.n_resblocks *= 30 26 | n_resblocks = args.n_resblocks 27 | self.numbers = n_resblocks 28 | n_feats = args.n_feats 29 | kernel_size = 3 30 | scale = args.scale[0] 31 | self.cafm = args.cafm 32 | self.force_no_rep = args.no_rep 33 | self.n_resblocks = args.n_resblocks 34 | act = nn.ReLU(True) 35 | url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale) 36 | if url_name in url: 37 | self.url = url[url_name] 38 | else: 39 | self.url = None 40 | self.sub_mean = common.MeanShift(args.rgb_range) 41 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 42 | 43 | # define head module 44 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 45 | 46 | # define body module 47 | print(args.cafm, self.force_no_rep) 48 | if args.cafm and not self.force_no_rep: 49 | m_body = [common.ResBlock(conv, n_feats, kernel_size, args, bn=True, act=act, res_scale=args.res_scale) for _ in range(n_resblocks)] 50 | else: 51 | if self.force_no_rep: print("Forcing no re-parameterization") 52 | m_body = [common.ResBlock_org(conv, n_feats, kernel_size, args, act=act, res_scale=args.res_scale, org=self.force_no_rep) for _ in range(n_resblocks)] 53 | m_body.append(common.RepBlock(conv, n_feats, kernel_size, args, bias=True )) 54 | #print("oijoijoijoijoijoijoijoijoijoijoij", len(m_body))
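# [Editor's note, not in the original file] The body built above holds
# n_resblocks ResBlocks plus one trailing multi-branch RepBlock (the append is
# unconditional, so forward can index self.body[self.n_resblocks]). Its
# deploy-time counterpart in edsr_m0.py has the same layout with single-conv
# RepBlock_m0 blocks, which is why the fused checkpoint (model_rep.pt from
# reparameter_edsr.py) loads directly into EDSR_M0 for testing; see the fusion
# sketch added after common.py above.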
55 | # define tail module 56 | m_tail = [ 57 | common.Upsampler(conv, scale, n_feats, kernel_size, act=False), 58 | conv(n_feats, args.n_colors, kernel_size) 59 | ] 60 | 61 | self.head = nn.Sequential(*m_head) 62 | if args.cafm and not self.force_no_rep: 63 | self.body = nn.ModuleList(m_body) 64 | else: 65 | if self.force_no_rep: print("Forcing no re-parameterization") 66 | self.body = nn.Sequential(*m_body) 67 | self.tail = nn.Sequential(*m_tail) 68 | 69 | def forward(self, x, num): 70 | x = self.sub_mean(x) 71 | x = self.head(x) 72 | #cafm 73 | if self.cafm and not self.force_no_rep: 74 | res = x 75 | for i in range(self.numbers): 76 | res = self.body[i](res, num) 77 | res = self.body[self.n_resblocks](res) 78 | res += x 79 | #original 80 | else: 81 | #print("going original") 82 | res = self.body(x) 83 | res += x 84 | 85 | x = self.tail(res) 86 | x = self.add_mean(x) 87 | 88 | return x 89 | 90 | def load_state_dict(self, state_dict, strict=True): 91 | own_state = self.state_dict() 92 | for name, param in state_dict.items(): 93 | if name in own_state: 94 | if isinstance(param, nn.Parameter): 95 | param = param.data 96 | try: 97 | own_state[name].copy_(param) 98 | except Exception: 99 | if name.find('tail') == -1: 100 | raise RuntimeError('While copying the parameter named {}, ' 101 | 'whose dimensions in the model are {} and ' 102 | 'whose dimensions in the checkpoint are {}.' 103 | .format(name, own_state[name].size(), param.size())) 104 | elif strict: 105 | if name.find('tail') == -1: 106 | raise KeyError('unexpected key "{}" in state_dict' 107 | .format(name)) 108 | # CaFM change the weight model name 109 | else: 110 | #print(name) 111 | name = name.replace("2.weight","3.weight") if "weight" in name else name.replace("2.bias","3.bias") 112 | if isinstance(param, nn.Parameter): 113 | param = param.data 114 | own_state[name].copy_(param) 115 | 116 | -------------------------------------------------------------------------------- /src/model/edsr_m0.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from model import common_m0 6 | #from fix_patch_prompt import FixedPatchPrompter_image, FixedPatchPrompter_feature 7 | 8 | url = { 9 | 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt', 10 | 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt', 11 | 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt', 12 | 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt', 13 | 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt', 14 | 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt' 15 | } 16 | import pdb 17 | 18 | 19 | def make_model(args, parent=False): 20 | return EDSR(args) 21 | 22 | class EDSR(nn.Module): 23 | def __init__(self, args, conv=common_m0.default_conv): 24 | super(EDSR, self).__init__() 25 | 26 | n_resblocks = args.n_resblocks 27 | self.numbers = n_resblocks 28 | n_feats = args.n_feats 29 | kernel_size = 3 30 | scale = args.scale[0] 31 | self.cafm = args.cafm 32 | self.n_resblocks = args.n_resblocks 33 | act = nn.ReLU(True) 34 | url_name =
'r{}f{}x{}'.format(n_resblocks, n_feats, scale) 35 | if url_name in url: 36 | self.url = url[url_name] 37 | else: 38 | self.url = None 39 | self.sub_mean = common_m0.MeanShift(args.rgb_range) 40 | self.add_mean = common_m0.MeanShift(args.rgb_range, sign=1) 41 | 42 | # define head module 43 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 44 | 45 | # define body module 46 | if args.cafm: 47 | m_body = [common_m0.ResBlock(conv, n_feats, kernel_size, args, bn=True, act=act, res_scale=args.res_scale) for _ in range(n_resblocks)] 48 | else: 49 | m_body = [common_m0.ResBlock_org(conv, n_feats, kernel_size, args, act=act, res_scale=args.res_scale) for _ in range(n_resblocks)] 50 | m_body.append(common_m0.RepBlock_m0(conv, n_feats, kernel_size, args, bias=True )) 51 | # define tail module 52 | m_tail = [ 53 | common_m0.Upsampler(conv, scale, n_feats, kernel_size, act=False), 54 | conv(n_feats, args.n_colors, kernel_size) 55 | ] 56 | 57 | self.head = nn.Sequential(*m_head) 58 | if args.cafm: 59 | self.body = nn.ModuleList(m_body) 60 | else: 61 | self.body = nn.Sequential(*m_body) 62 | self.tail = nn.Sequential(*m_tail) 63 | 64 | def forward(self, x, num): 65 | x = self.sub_mean(x) 66 | x = self.head(x) 67 | #cafm 68 | if self.cafm: 69 | res = x 70 | for i in range(self.numbers): 71 | res = self.body[i](res, num) 72 | res = self.body[self.n_resblocks](res) 73 | res += x 74 | #original 75 | else: 76 | res = self.body(x) 77 | res += x 78 | 79 | x = self.tail(res) 80 | x = self.add_mean(x) 81 | 82 | return x 83 | 84 | def load_state_dict(self, state_dict, strict=True): 85 | own_state = self.state_dict() 86 | for name, param in state_dict.items(): 87 | if name in own_state: 88 | if isinstance(param, nn.Parameter): 89 | param = param.data 90 | try: 91 | own_state[name].copy_(param) 92 | except Exception: 93 | if name.find('tail') == -1: 94 | raise RuntimeError('While copying the parameter named {}, ' 95 | 'whose dimensions in the model are {} and ' 96 | 'whose dimensions in the checkpoint are {}.' 
97 | .format(name, own_state[name].size(), param.size())) 98 | elif strict: 99 | if name.find('tail') == -1: 100 | raise KeyError('unexpected key "{}" in state_dict' 101 | .format(name)) 102 | # CaFM change the weight model name 103 | else: 104 | #print(name) 105 | name = name.replace("2.weight","3.weight") if "weight" in name else name.replace("2.bias","3.bias") 106 | if isinstance(param, nn.Parameter): 107 | param = param.data 108 | own_state[name].copy_(param) -------------------------------------------------------------------------------- /src/model/espcn.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch import nn 3 | import torch.nn.init as init 4 | from .common import ContentAwareFM 5 | import torch 6 | 7 | def make_model(args, parent=False): 8 | return ESPCN(args) 9 | 10 | 11 | def set_padding_size(kernel_size, dilation): 12 | kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) 13 | padding = (kernel_size - 1) // 2 14 | return padding 15 | 16 | class ContentAwareFM(nn.Module): 17 | # hello ckx 18 | def __init__(self, in_channel, kernel_size): 19 | 20 | super(ContentAwareFM, self).__init__() 21 | padding = set_padding_size(kernel_size, 1) 22 | self.transformer = nn.Conv2d(in_channel, in_channel, kernel_size, 23 | padding=padding, groups=in_channel//2) 24 | self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True) 25 | def forward(self, x): 26 | return self.transformer(x) * self.gamma + x 27 | 28 | 29 | 30 | class ESPCN(nn.Module): 31 | def __init__(self, args): 32 | super(ESPCN, self).__init__() 33 | # self.act_func = nn.LeakyReLU(negative_slope=0.2) 34 | self.act_func = nn.ReLU(inplace=True) 35 | self.scale = int(args.scale[0]) # use scale[0] 36 | self.n_colors = args.n_colors 37 | self.cafm = args.cafm 38 | self.use_cafm = args.use_cafm 39 | self.segnum = args.segnum 40 | 41 | # conv1 42 | 43 | self.conv1 = nn.Conv2d(self.n_colors, 64, (3, 3), (1, 1), (1, 1)) 44 | 45 | self.conv1_0_0 = nn.Conv2d(self.n_colors, 64, (1, 1), (1, 1), (0, 0)) 46 | self.conv1_0_1 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 47 | self.conv1_0_2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 48 | 49 | self.conv1_1_0 = nn.Conv2d(self.n_colors, 64, (1, 1), (1, 1), (0, 0)) 50 | self.conv1_1_1 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 51 | 52 | self.conv1_2_0 = nn.Conv2d(self.n_colors, 64, (3, 3), (1, 1), (1, 1)) 53 | 54 | 55 | #conv2 56 | self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 57 | self.conv2_0_0 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 58 | self.conv2_0_1 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 59 | self.conv2_0_2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 60 | 61 | self.conv2_1_0 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 62 | self.conv2_1_1 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 63 | 64 | self.conv2_2_0 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 65 | 66 | self.conv2_3_0 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 67 | self.conv2_3_1 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 68 | self.conv2_3_2 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 69 | self.conv2_3_3 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 70 | 71 | #conv3 72 | self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) 73 | self.conv3_0_0 = nn.Conv2d(64, 32, (1, 1), (1, 1), (0, 0)) 74 | self.conv3_0_1 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 75 | self.conv3_0_2 = nn.Conv2d(32, 32, (3, 3), (1, 1), (1, 1)) 76 | 77 | self.conv3_1_0 = nn.Conv2d(64, 32, (1, 1), (1, 1), (0, 0)) 78 | self.conv3_1_1 = nn.Conv2d(32, 32, (3, 3), (1, 1), (1, 1)) 79 
| 80 | self.conv3_2_0 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) 81 | 82 | self.conv3_3_0 = nn.Conv2d(64, 32, (1, 1), (1, 1), (0, 0)) 83 | self.conv3_3_1 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 84 | self.conv3_3_2 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 85 | self.conv3_3_3 = nn.Conv2d(32, 32, (3, 3), (1, 1), (1, 1)) 86 | 87 | self.conv4 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 88 | self.conv4_0_0 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 89 | self.conv4_0_1 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 90 | self.conv4_0_2 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 91 | 92 | self.conv4_1_0 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 93 | self.conv4_1_1 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 94 | 95 | self.conv4_2_0 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 96 | 97 | self.pixel_shuffle = nn.PixelShuffle(self.scale) 98 | 99 | # self._initialize_weights() 100 | 101 | if self.cafm: 102 | if self.use_cafm: 103 | self.cafms1 = nn.ModuleList([ContentAwareFM(64,1) for _ in range(self.segnum)]) 104 | self.cafms2 = nn.ModuleList([ContentAwareFM(64,1) for _ in range(self.segnum)]) 105 | self.cafms3 = nn.ModuleList([ContentAwareFM(32,1) for _ in range(self.segnum)]) 106 | 107 | 108 | def forward(self, x, num): 109 | if self.cafm: 110 | out = self.act_func(self.conv1(x)) 111 | if self.use_cafm: 112 | out = self.cafms1[num](out) 113 | out = self.act_func(self.conv2(out)) 114 | if self.use_cafm: 115 | out = self.cafms2[num](out) 116 | out = self.act_func(self.conv3(out)) 117 | if self.use_cafm: 118 | out = self.cafms3[num](out) 119 | out = self.pixel_shuffle(self.conv4(out)) 120 | return out 121 | else: 122 | out0 = self.act_func(self.conv1_0_2(self.conv1_0_1(self.conv1_0_0(x))) + self.conv1_1_1(self.conv1_1_0(x)) + self.conv1_2_0(x)) 123 | out1 = self.act_func(self.conv2_0_2(self.conv2_0_1(self.conv2_0_0(out0))) + self.conv2_1_1(self.conv2_1_0(out0)) + self.conv2_2_0(out0)) 124 | # out1 = self.act_func(self.conv2_0_2(self.conv2_0_1(self.conv2_0_0(out0))) + self.conv2_1_1(self.conv2_1_0(out0)) + self.conv2_2_0(out0)) 125 | # out1 = self.act_func(self.conv2(out0)) 126 | 127 | out2 = self.act_func(self.conv3_0_2(self.conv3_0_1(self.conv3_0_0(out1))) + self.conv3_1_1(self.conv3_1_0(out1)) + self.conv3_2_0(out1)) 128 | # out2 = self.act_func(self.conv3_0_2(self.conv3_0_1(self.conv3_0_0(out1))) + self.conv3_1_1(self.conv3_1_0(out1)) + self.conv3_2_0(out1)) 129 | # out2 = self.act_func(self.conv3(out1)) 130 | out3 = self.pixel_shuffle(self.conv4_0_2(self.conv4_0_1(self.conv4_0_0(out2))) + self.conv4_1_1(self.conv4_1_0(out2)) + self.conv4_2_0(out2)) 131 | 132 | return out3 133 | 134 | 135 | def _initialize_weights(self): 136 | init.orthogonal_(self.conv1.weight, init.calculate_gain('relu')) 137 | init.orthogonal_(self.conv2.weight, init.calculate_gain('relu')) 138 | init.orthogonal_(self.conv3.weight, init.calculate_gain('relu')) 139 | init.orthogonal_(self.conv4.weight) 140 | 141 | 142 | def load_state_dict(self, state_dict, strict=True): 143 | own_state = self.state_dict() 144 | for name, param in state_dict.items(): 145 | if name in own_state: 146 | if isinstance(param, nn.Parameter): 147 | param = param.data 148 | own_state[name].copy_(param) -------------------------------------------------------------------------------- /src/model/espcn_chunked.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from torch import nn 4 | import 
torch.nn.init as init 5 | from .common import ContentAwareFM 6 | import torch 7 | from .fix_patch_prompt import FixedPatchPrompter_image, FixedPatchPrompter_feature 8 | 9 | def make_model(args, parent=False): 10 | return ESPCN(args) 11 | 12 | 13 | def set_padding_size(kernel_size, dilation): 14 | kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) 15 | padding = (kernel_size - 1) // 2 16 | return padding 17 | 18 | class ContentAwareFM(nn.Module): 19 | # hello ckx 20 | def __init__(self, in_channel, kernel_size): 21 | 22 | super(ContentAwareFM, self).__init__() 23 | padding = set_padding_size(kernel_size, 1) 24 | self.transformer = nn.Conv2d(in_channel, in_channel, kernel_size, 25 | padding=padding, groups=in_channel//2) 26 | self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True) 27 | def forward(self, x): 28 | return self.transformer(x) * self.gamma + x 29 | 30 | 31 | 32 | class ESPCN(nn.Module): 33 | def __init__(self, args): 34 | super(ESPCN, self).__init__() 35 | # self.act_func = nn.LeakyReLU(negative_slope=0.2) 36 | self.act_func = nn.ReLU(inplace=True) 37 | self.scale = int(args.scale[0]) # use scale[0] 38 | self.n_colors = args.n_colors 39 | self.cafm = args.cafm 40 | self.use_cafm = args.use_cafm 41 | self.segnum = args.segnum 42 | 43 | if args.is15s: 44 | l = 15 * 33 // args.chunk_size 45 | elif args.is30s: 46 | l = 3 * 15 * 33 // args.chunk_size 47 | else: 48 | l = 9 * 15 * 33 // args.chunk_size 49 | 50 | self.prompters = [FixedPatchPrompter_image(args.patch_size // (self.scale)) for _ in range(l)] 51 | self.chunk_size = args.chunk_size 52 | self.l = l 53 | # conv1 54 | 55 | self.conv1 = nn.Conv2d(self.n_colors, 64, (3, 3), (1, 1), (1, 1)) 56 | 57 | self.conv1_0_0 = nn.Conv2d(self.n_colors, 64, (1, 1), (1, 1), (0, 0)) 58 | self.conv1_0_1 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 59 | self.conv1_0_2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 60 | 61 | self.conv1_1_0 = nn.Conv2d(self.n_colors, 64, (1, 1), (1, 1), (0, 0)) 62 | self.conv1_1_1 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 63 | 64 | self.conv1_2_0 = nn.Conv2d(self.n_colors, 64, (3, 3), (1, 1), (1, 1)) 65 | 66 | 67 | #conv2 68 | self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 69 | self.conv2_0_0 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 70 | self.conv2_0_1 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 71 | self.conv2_0_2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 72 | 73 | self.conv2_1_0 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 74 | self.conv2_1_1 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 75 | 76 | self.conv2_2_0 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 77 | 78 | self.conv2_3_0 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 79 | self.conv2_3_1 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 80 | self.conv2_3_2 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 81 | self.conv2_3_3 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 82 | 83 | #conv3 84 | self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) 85 | self.conv3_0_0 = nn.Conv2d(64, 32, (1, 1), (1, 1), (0, 0)) 86 | self.conv3_0_1 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 87 | self.conv3_0_2 = nn.Conv2d(32, 32, (3, 3), (1, 1), (1, 1)) 88 | 89 | self.conv3_1_0 = nn.Conv2d(64, 32, (1, 1), (1, 1), (0, 0)) 90 | self.conv3_1_1 = nn.Conv2d(32, 32, (3, 3), (1, 1), (1, 1)) 91 | 92 | self.conv3_2_0 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) 93 | 94 | self.conv3_3_0 = nn.Conv2d(64, 32, (1, 1), (1, 1), (0, 0)) 95 | self.conv3_3_1 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 96 | self.conv3_3_2 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 97 | self.conv3_3_3 = 
nn.Conv2d(32, 32, (3, 3), (1, 1), (1, 1)) 98 | 99 | self.conv4 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 100 | self.conv4_0_0 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 101 | self.conv4_0_1 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 102 | self.conv4_0_2 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 103 | 104 | self.conv4_1_0 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 105 | self.conv4_1_1 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 106 | 107 | self.conv4_2_0 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 108 | 109 | self.pixel_shuffle = nn.PixelShuffle(self.scale) 110 | 111 | # self._initialize_weights() 112 | 113 | if self.cafm: 114 | if self.use_cafm: 115 | self.cafms1 = nn.ModuleList([ContentAwareFM(64,1) for _ in range(self.segnum)]) 116 | self.cafms2 = nn.ModuleList([ContentAwareFM(64,1) for _ in range(self.segnum)]) 117 | self.cafms3 = nn.ModuleList([ContentAwareFM(32,1) for _ in range(self.segnum)]) 118 | 119 | 120 | def forward(self, x, num): 121 | if self.cafm: 122 | out = self.act_func(self.conv1(x)) 123 | if self.use_cafm: 124 | out = self.cafms1[num](out) 125 | out = self.act_func(self.conv2(out)) 126 | if self.use_cafm: 127 | out = self.cafms2[num](out) 128 | out = self.act_func(self.conv3(out)) 129 | if self.use_cafm: 130 | out = self.cafms3[num](out) 131 | out = self.pixel_shuffle(self.conv4(out)) 132 | return out 133 | else: 134 | #print("Using the prompter") 135 | idxs = np.clip(num//self.chunk_size, a_min=None , a_max=self.l - 1) 136 | print(idxs) 137 | x = self.prompters[idxs](x) 138 | out0 = self.act_func(self.conv1_0_2(self.conv1_0_1(self.conv1_0_0(x))) + self.conv1_1_1(self.conv1_1_0(x)) + self.conv1_2_0(x)) 139 | out1 = self.act_func(self.conv2_0_2(self.conv2_0_1(self.conv2_0_0(out0))) + self.conv2_1_1(self.conv2_1_0(out0)) + self.conv2_2_0(out0)) 140 | # out1 = self.act_func(self.conv2_0_2(self.conv2_0_1(self.conv2_0_0(out0))) + self.conv2_1_1(self.conv2_1_0(out0)) + self.conv2_2_0(out0)) 141 | # out1 = self.act_func(self.conv2(out0)) 142 | 143 | out2 = self.act_func(self.conv3_0_2(self.conv3_0_1(self.conv3_0_0(out1))) + self.conv3_1_1(self.conv3_1_0(out1)) + self.conv3_2_0(out1)) 144 | # out2 = self.act_func(self.conv3_0_2(self.conv3_0_1(self.conv3_0_0(out1))) + self.conv3_1_1(self.conv3_1_0(out1)) + self.conv3_2_0(out1)) 145 | # out2 = self.act_func(self.conv3(out1)) 146 | out3 = self.pixel_shuffle(self.conv4_0_2(self.conv4_0_1(self.conv4_0_0(out2))) + self.conv4_1_1(self.conv4_1_0(out2)) + self.conv4_2_0(out2)) 147 | 148 | return out3 149 | 150 | 151 | def _initialize_weights(self): 152 | init.orthogonal_(self.conv1.weight, init.calculate_gain('relu')) 153 | init.orthogonal_(self.conv2.weight, init.calculate_gain('relu')) 154 | init.orthogonal_(self.conv3.weight, init.calculate_gain('relu')) 155 | init.orthogonal_(self.conv4.weight) 156 | 157 | 158 | def load_state_dict(self, state_dict, strict=True): 159 | own_state = self.state_dict() 160 | for name, param in state_dict.items(): 161 | if name in own_state: 162 | if isinstance(param, nn.Parameter): 163 | param = param.data 164 | own_state[name].copy_(param) -------------------------------------------------------------------------------- /src/model/espcn_lf.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch import nn 3 | import torch.nn.init as init 4 | from .common import ContentAwareFM 5 | import torch 6 | 7 | from .fix_patch_prompt import FixedPatchPrompter_feature_1 8 | 9 
| def make_model(args, parent=False): 10 | return ESPCN(args) 11 | 12 | 13 | def set_padding_size(kernel_size, dilation): 14 | kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) 15 | padding = (kernel_size - 1) // 2 16 | return padding 17 | 18 | class ContentAwareFM(nn.Module): 19 | # hello ckx 20 | def __init__(self, in_channel, kernel_size): 21 | 22 | super(ContentAwareFM, self).__init__() 23 | padding = set_padding_size(kernel_size, 1) 24 | self.transformer = nn.Conv2d(in_channel, in_channel, kernel_size, 25 | padding=padding, groups=in_channel//2) 26 | self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True) 27 | def forward(self, x): 28 | return self.transformer(x) * self.gamma + x 29 | 30 | 31 | 32 | class ESPCN(nn.Module): 33 | def __init__(self, args): 34 | super(ESPCN, self).__init__() 35 | # self.act_func = nn.LeakyReLU(negative_slope=0.2) 36 | self.act_func = nn.ReLU(inplace=True) 37 | self.scale = int(args.scale[0]) # use scale[0] 38 | self.n_colors = args.n_colors 39 | self.cafm = args.cafm 40 | self.use_cafm = args.use_cafm 41 | self.segnum = args.segnum 42 | #print(args.scale) 43 | self.prompter = FixedPatchPrompter_feature_1(args.patch_size // self.scale, std=args.std) 44 | # conv1 45 | 46 | self.conv1 = nn.Conv2d(self.n_colors, 64, (3, 3), (1, 1), (1, 1)) 47 | 48 | self.conv1_0_0 = nn.Conv2d(self.n_colors, 64, (1, 1), (1, 1), (0, 0)) 49 | self.conv1_0_1 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 50 | self.conv1_0_2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 51 | 52 | self.conv1_1_0 = nn.Conv2d(self.n_colors, 64, (1, 1), (1, 1), (0, 0)) 53 | self.conv1_1_1 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 54 | 55 | self.conv1_2_0 = nn.Conv2d(self.n_colors, 64, (3, 3), (1, 1), (1, 1)) 56 | 57 | 58 | #conv2 59 | self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 60 | self.conv2_0_0 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 61 | self.conv2_0_1 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 62 | self.conv2_0_2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 63 | 64 | self.conv2_1_0 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 65 | self.conv2_1_1 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 66 | 67 | self.conv2_2_0 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 68 | 69 | self.conv2_3_0 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 70 | self.conv2_3_1 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 71 | self.conv2_3_2 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 72 | self.conv2_3_3 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 73 | 74 | #conv3 75 | self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) 76 | self.conv3_0_0 = nn.Conv2d(64, 32, (1, 1), (1, 1), (0, 0)) 77 | self.conv3_0_1 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 78 | self.conv3_0_2 = nn.Conv2d(32, 32, (3, 3), (1, 1), (1, 1)) 79 | 80 | self.conv3_1_0 = nn.Conv2d(64, 32, (1, 1), (1, 1), (0, 0)) 81 | self.conv3_1_1 = nn.Conv2d(32, 32, (3, 3), (1, 1), (1, 1)) 82 | 83 | self.conv3_2_0 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) 84 | 85 | self.conv3_3_0 = nn.Conv2d(64, 32, (1, 1), (1, 1), (0, 0)) 86 | self.conv3_3_1 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 87 | self.conv3_3_2 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 88 | self.conv3_3_3 = nn.Conv2d(32, 32, (3, 3), (1, 1), (1, 1)) 89 | 90 | self.conv4 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 91 | self.conv4_0_0 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 92 | self.conv4_0_1 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 93 | self.conv4_0_2 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 94 | 95 | self.conv4_1_0 = nn.Conv2d(32, 32, (1, 
1), (1, 1), (0, 0)) 96 | self.conv4_1_1 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 97 | 98 | self.conv4_2_0 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 99 | 100 | self.pixel_shuffle = nn.PixelShuffle(self.scale) 101 | 102 | # self._initialize_weights() 103 | 104 | if self.cafm: 105 | if self.use_cafm: 106 | self.cafms1 = nn.ModuleList([ContentAwareFM(64,1) for _ in range(self.segnum)]) 107 | self.cafms2 = nn.ModuleList([ContentAwareFM(64,1) for _ in range(self.segnum)]) 108 | self.cafms3 = nn.ModuleList([ContentAwareFM(32,1) for _ in range(self.segnum)]) 109 | 110 | 111 | def forward(self, x, num): 112 | if self.cafm: 113 | out = self.act_func(self.conv1(x)) 114 | if self.use_cafm: 115 | out = self.cafms1[num](out) 116 | out = self.act_func(self.conv2(out)) 117 | if self.use_cafm: 118 | out = self.cafms2[num](out) 119 | out = self.act_func(self.conv3(out)) 120 | if self.use_cafm: 121 | out = self.cafms3[num](out) 122 | out = self.pixel_shuffle(self.conv4(out)) 123 | return out 124 | else: 125 | out0 = self.act_func(self.conv1_0_2(self.conv1_0_1(self.conv1_0_0(x))) + self.conv1_1_1(self.conv1_1_0(x)) + self.conv1_2_0(x)) 126 | #print(out0.shape) 127 | out0 = self.prompter(out0) 128 | out1 = self.act_func(self.conv2_0_2(self.conv2_0_1(self.conv2_0_0(out0))) + self.conv2_1_1(self.conv2_1_0(out0)) + self.conv2_2_0(out0)) 129 | # out1 = self.act_func(self.conv2_0_2(self.conv2_0_1(self.conv2_0_0(out0))) + self.conv2_1_1(self.conv2_1_0(out0)) + self.conv2_2_0(out0)) 130 | # out1 = self.act_func(self.conv2(out0)) 131 | 132 | out2 = self.act_func(self.conv3_0_2(self.conv3_0_1(self.conv3_0_0(out1))) + self.conv3_1_1(self.conv3_1_0(out1)) + self.conv3_2_0(out1)) 133 | # out2 = self.act_func(self.conv3_0_2(self.conv3_0_1(self.conv3_0_0(out1))) + self.conv3_1_1(self.conv3_1_0(out1)) + self.conv3_2_0(out1)) 134 | # out2 = self.act_func(self.conv3(out1)) 135 | out3 = self.pixel_shuffle(self.conv4_0_2(self.conv4_0_1(self.conv4_0_0(out2))) + self.conv4_1_1(self.conv4_1_0(out2)) + self.conv4_2_0(out2)) 136 | 137 | return out3 138 | 139 | 140 | def _initialize_weights(self): 141 | init.orthogonal_(self.conv1.weight, init.calculate_gain('relu')) 142 | init.orthogonal_(self.conv2.weight, init.calculate_gain('relu')) 143 | init.orthogonal_(self.conv3.weight, init.calculate_gain('relu')) 144 | init.orthogonal_(self.conv4.weight) 145 | 146 | 147 | def load_state_dict(self, state_dict, strict=True): 148 | own_state = self.state_dict() 149 | for name, param in state_dict.items(): 150 | if name in own_state: 151 | if isinstance(param, nn.Parameter): 152 | param = param.data 153 | own_state[name].copy_(param) -------------------------------------------------------------------------------- /src/model/espcn_m0.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch import nn 3 | import torch.nn.init as init 4 | from .common import ContentAwareFM 5 | import torch 6 | 7 | def make_model(args, parent=False): 8 | return ESPCN(args) 9 | 10 | 11 | def set_padding_size(kernel_size, dilation): 12 | kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) 13 | padding = (kernel_size - 1) // 2 14 | return padding 15 | 16 | class ContentAwareFM(nn.Module): 17 | # hello ckx 18 | def __init__(self, in_channel, kernel_size): 19 | 20 | super(ContentAwareFM, self).__init__() 21 | padding = set_padding_size(kernel_size, 1) 22 | self.transformer = nn.Conv2d(in_channel, in_channel, kernel_size, 23 | padding=padding, 
groups=in_channel//2) 24 | self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True) 25 | def forward(self, x): 26 | return self.transformer(x) * self.gamma + x 27 | 28 | 29 | 30 | class ESPCN(nn.Module): 31 | def __init__(self, args): 32 | super(ESPCN, self).__init__() 33 | # self.act_func = nn.LeakyReLU(negative_slope=0.2) 34 | self.act_func = nn.ReLU(inplace=True) 35 | self.scale = int(args.scale[0]) # use scale[0] 36 | self.n_colors = args.n_colors 37 | self.cafm = args.cafm 38 | self.use_cafm = args.use_cafm 39 | self.segnum = args.segnum 40 | 41 | self.conv1 = nn.Conv2d(self.n_colors, 64, (3, 3), (1, 1), (1, 1)) 42 | self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 43 | self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) 44 | self.conv4 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 45 | self.pixel_shuffle = nn.PixelShuffle(self.scale) 46 | 47 | # self._initialize_weights() 48 | 49 | if self.cafm: 50 | if self.use_cafm: 51 | self.cafms1 = nn.ModuleList([ContentAwareFM(64,1) for _ in range(self.segnum)]) 52 | self.cafms2 = nn.ModuleList([ContentAwareFM(64,1) for _ in range(self.segnum)]) 53 | self.cafms3 = nn.ModuleList([ContentAwareFM(32,1) for _ in range(self.segnum)]) 54 | 55 | 56 | def forward(self, x, num): 57 | if self.cafm: 58 | out = self.act_func(self.conv1(x)) 59 | if self.use_cafm: 60 | out = self.cafms1[num](out) 61 | out = self.act_func(self.conv2(out)) 62 | if self.use_cafm: 63 | out = self.cafms2[num](out) 64 | out = self.act_func(self.conv3(out)) 65 | if self.use_cafm: 66 | out = self.cafms3[num](out) 67 | out = self.pixel_shuffle(self.conv4(out)) 68 | return out 69 | else: 70 | out = self.act_func(self.conv1(x)) 71 | out = self.act_func(self.conv2(out)) 72 | out = self.act_func(self.conv3(out)) 73 | out = self.pixel_shuffle(self.conv4(out)) 74 | return out 75 | 76 | 77 | def _initialize_weights(self): 78 | init.orthogonal_(self.conv1.weight, init.calculate_gain('relu')) 79 | init.orthogonal_(self.conv2.weight, init.calculate_gain('relu')) 80 | init.orthogonal_(self.conv3.weight, init.calculate_gain('relu')) 81 | init.orthogonal_(self.conv4.weight) 82 | 83 | 84 | def load_state_dict(self, state_dict, strict=True): 85 | own_state = self.state_dict() 86 | for name, param in state_dict.items(): 87 | if name in own_state: 88 | if isinstance(param, nn.Parameter): 89 | param = param.data 90 | own_state[name].copy_(param) -------------------------------------------------------------------------------- /src/model/espcn_mdf.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch import nn 3 | import torch.nn.init as init 4 | from .common import ContentAwareFM 5 | import torch 6 | from .fix_patch_prompt import FixedPatchPrompter_image 7 | 8 | def make_model(args, parent=False): 9 | return ESPCN(args) 10 | 11 | 12 | def set_padding_size(kernel_size, dilation): 13 | kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) 14 | padding = (kernel_size - 1) // 2 15 | return padding 16 | 17 | class ContentAwareFM(nn.Module): 18 | # hello ckx 19 | def __init__(self, in_channel, kernel_size): 20 | 21 | super(ContentAwareFM, self).__init__() 22 | padding = set_padding_size(kernel_size, 1) 23 | self.transformer = nn.Conv2d(in_channel, in_channel, kernel_size, 24 | padding=padding, groups=in_channel//2) 25 | self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True) 26 | def forward(self, x): 27 | return self.transformer(x) * self.gamma + x 28 | 29 | 30 | 31 | class ESPCN(nn.Module): 32 | def 
__init__(self, args): 33 | super(ESPCN, self).__init__() 34 | # self.act_func = nn.LeakyReLU(negative_slope=0.2) 35 | self.act_func = nn.ReLU(inplace=True) 36 | self.scale = int(args.scale[0]) # use scale[0] 37 | self.n_colors = args.n_colors 38 | self.cafm = args.cafm 39 | self.use_cafm = args.use_cafm 40 | self.segnum = args.segnum 41 | 42 | self.prompter = FixedPatchPrompter_image(args.prompt_size, std=args.std) 43 | print("std: ", args.std) 44 | 45 | # conv1 46 | 47 | self.conv1 = nn.Conv2d(self.n_colors, 64, (3, 3), (1, 1), (1, 1)) 48 | 49 | self.conv1_0_0 = nn.Conv2d(self.n_colors, 64, (1, 1), (1, 1), (0, 0)) 50 | self.conv1_0_1 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 51 | self.conv1_0_2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 52 | 53 | self.conv1_1_0 = nn.Conv2d(self.n_colors, 64, (1, 1), (1, 1), (0, 0)) 54 | self.conv1_1_1 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 55 | 56 | self.conv1_2_0 = nn.Conv2d(self.n_colors, 64, (3, 3), (1, 1), (1, 1)) 57 | 58 | 59 | #conv2 60 | self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 61 | self.conv2_0_0 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 62 | self.conv2_0_1 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 63 | self.conv2_0_2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 64 | 65 | self.conv2_1_0 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 66 | self.conv2_1_1 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 67 | 68 | self.conv2_2_0 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 69 | 70 | self.conv2_3_0 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 71 | self.conv2_3_1 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 72 | self.conv2_3_2 = nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0)) 73 | self.conv2_3_3 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) 74 | 75 | #conv3 76 | self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) 77 | self.conv3_0_0 = nn.Conv2d(64, 32, (1, 1), (1, 1), (0, 0)) 78 | self.conv3_0_1 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 79 | self.conv3_0_2 = nn.Conv2d(32, 32, (3, 3), (1, 1), (1, 1)) 80 | 81 | self.conv3_1_0 = nn.Conv2d(64, 32, (1, 1), (1, 1), (0, 0)) 82 | self.conv3_1_1 = nn.Conv2d(32, 32, (3, 3), (1, 1), (1, 1)) 83 | 84 | self.conv3_2_0 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) 85 | 86 | self.conv3_3_0 = nn.Conv2d(64, 32, (1, 1), (1, 1), (0, 0)) 87 | self.conv3_3_1 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 88 | self.conv3_3_2 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 89 | self.conv3_3_3 = nn.Conv2d(32, 32, (3, 3), (1, 1), (1, 1)) 90 | 91 | self.conv4 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 92 | self.conv4_0_0 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 93 | self.conv4_0_1 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 94 | self.conv4_0_2 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 95 | 96 | self.conv4_1_0 = nn.Conv2d(32, 32, (1, 1), (1, 1), (0, 0)) 97 | self.conv4_1_1 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 98 | 99 | self.conv4_2_0 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 100 | 101 | self.pixel_shuffle = nn.PixelShuffle(self.scale) 102 | 103 | # self._initialize_weights() 104 | 105 | if self.cafm: 106 | if self.use_cafm: 107 | self.cafms1 = nn.ModuleList([ContentAwareFM(64,1) for _ in range(self.segnum)]) 108 | self.cafms2 = nn.ModuleList([ContentAwareFM(64,1) for _ in range(self.segnum)]) 109 | self.cafms3 = nn.ModuleList([ContentAwareFM(32,1) for _ in range(self.segnum)]) 110 | 111 | 112 | def forward(self, x, num): 113 | if self.cafm: 114 | out = self.act_func(self.conv1(x)) 115 | if self.use_cafm: 116 | out = self.cafms1[num](out) 
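# [Editor's note, not in the original file] num indexes the current video
# chunk: each chunk owns its own ContentAwareFM, a grouped 1x1 conv gated by a
# learnable gamma (y = gamma * conv(x) + x, with gamma initialized to zero so
# every CaFM starts as an identity). Since each CaFM is linear, it can in
# principle be folded into the preceding conv's weights per chunk at
# deployment, in the same spirit as the RepBlock fusion sketched in common.py.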
117 | out = self.act_func(self.conv2(out)) 118 | if self.use_cafm: 119 | out = self.cafms2[num](out) 120 | out = self.act_func(self.conv3(out)) 121 | if self.use_cafm: 122 | out = self.cafms3[num](out) 123 | out = self.pixel_shuffle(self.conv4(out)) 124 | return out 125 | else: 126 | #print("Using the prompter") 127 | x = self.prompter(x) 128 | out0 = self.act_func(self.conv1_0_2(self.conv1_0_1(self.conv1_0_0(x))) + self.conv1_1_1(self.conv1_1_0(x)) + self.conv1_2_0(x)) 129 | out1 = self.act_func(self.conv2_0_2(self.conv2_0_1(self.conv2_0_0(out0))) + self.conv2_1_1(self.conv2_1_0(out0)) + self.conv2_2_0(out0)) 130 | # out1 = self.act_func(self.conv2_0_2(self.conv2_0_1(self.conv2_0_0(out0))) + self.conv2_1_1(self.conv2_1_0(out0)) + self.conv2_2_0(out0)) 131 | # out1 = self.act_func(self.conv2(out0)) 132 | 133 | out2 = self.act_func(self.conv3_0_2(self.conv3_0_1(self.conv3_0_0(out1))) + self.conv3_1_1(self.conv3_1_0(out1)) + self.conv3_2_0(out1)) 134 | # out2 = self.act_func(self.conv3_0_2(self.conv3_0_1(self.conv3_0_0(out1))) + self.conv3_1_1(self.conv3_1_0(out1)) + self.conv3_2_0(out1)) 135 | # out2 = self.act_func(self.conv3(out1)) 136 | out3 = self.pixel_shuffle(self.conv4_0_2(self.conv4_0_1(self.conv4_0_0(out2))) + self.conv4_1_1(self.conv4_1_0(out2)) + self.conv4_2_0(out2)) 137 | 138 | return out3 139 | 140 | 141 | def _initialize_weights(self): 142 | init.orthogonal_(self.conv1.weight, init.calculate_gain('relu')) 143 | init.orthogonal_(self.conv2.weight, init.calculate_gain('relu')) 144 | init.orthogonal_(self.conv3.weight, init.calculate_gain('relu')) 145 | init.orthogonal_(self.conv4.weight) 146 | 147 | 148 | def load_state_dict(self, state_dict, strict=True): 149 | own_state = self.state_dict() 150 | for name, param in state_dict.items(): 151 | if name in own_state: 152 | if isinstance(param, nn.Parameter): 153 | param = param.data 154 | own_state[name].copy_(param) -------------------------------------------------------------------------------- /src/model/espcn_mdf_m0.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch import nn 3 | import torch.nn.init as init 4 | from .common import ContentAwareFM 5 | import torch 6 | from .fix_patch_prompt import FixedPatchPrompter_image 7 | 8 | def make_model(args, parent=False): 9 | return ESPCN(args) 10 | 11 | 12 | def set_padding_size(kernel_size, dilation): 13 | kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) 14 | padding = (kernel_size - 1) // 2 15 | return padding 16 | 17 | class ContentAwareFM(nn.Module): 18 | # hello ckx 19 | def __init__(self, in_channel, kernel_size): 20 | 21 | super(ContentAwareFM, self).__init__() 22 | padding = set_padding_size(kernel_size, 1) 23 | self.transformer = nn.Conv2d(in_channel, in_channel, kernel_size, 24 | padding=padding, groups=in_channel//2) 25 | self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True) 26 | def forward(self, x): 27 | return self.transformer(x) * self.gamma + x 28 | 29 | 30 | 31 | class ESPCN(nn.Module): 32 | def __init__(self, args): 33 | super(ESPCN, self).__init__() 34 | # self.act_func = nn.LeakyReLU(negative_slope=0.2) 35 | self.act_func = nn.ReLU(inplace=True) 36 | self.scale = int(args.scale[0]) # use scale[0] 37 | self.n_colors = args.n_colors 38 | self.cafm = args.cafm 39 | self.use_cafm = args.use_cafm 40 | self.segnum = args.segnum 41 | 42 | self.conv1 = nn.Conv2d(self.n_colors, 64, (3, 3), (1, 1), (1, 1)) 43 | self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1,
1)) 44 | self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) 45 | self.conv4 = nn.Conv2d(32, 3 * (self.scale ** 2), (3, 3), (1, 1), (1, 1)) 46 | self.pixel_shuffle = nn.PixelShuffle(self.scale) 47 | 48 | self.prompter = FixedPatchPrompter_image(args.patch_size // (self.scale) ) 49 | # self._initialize_weights() 50 | 51 | if self.cafm: 52 | if self.use_cafm: 53 | self.cafms1 = nn.ModuleList([ContentAwareFM(64,1) for _ in range(self.segnum)]) 54 | self.cafms2 = nn.ModuleList([ContentAwareFM(64,1) for _ in range(self.segnum)]) 55 | self.cafms3 = nn.ModuleList([ContentAwareFM(32,1) for _ in range(self.segnum)]) 56 | 57 | 58 | def forward(self, x, num): 59 | if self.cafm: 60 | out = self.act_func(self.conv1(x)) 61 | if self.use_cafm: 62 | out = self.cafms1[num](out) 63 | out = self.act_func(self.conv2(out)) 64 | if self.use_cafm: 65 | out = self.cafms2[num](out) 66 | out = self.act_func(self.conv3(out)) 67 | if self.use_cafm: 68 | out = self.cafms3[num](out) 69 | out = self.pixel_shuffle(self.conv4(out)) 70 | return out 71 | else: 72 | x = self.prompter(x) 73 | out = self.act_func(self.conv1(x)) 74 | out = self.act_func(self.conv2(out)) 75 | out = self.act_func(self.conv3(out)) 76 | out = self.pixel_shuffle(self.conv4(out)) 77 | return out 78 | 79 | 80 | def _initialize_weights(self): 81 | init.orthogonal_(self.conv1.weight, init.calculate_gain('relu')) 82 | init.orthogonal_(self.conv2.weight, init.calculate_gain('relu')) 83 | init.orthogonal_(self.conv3.weight, init.calculate_gain('relu')) 84 | init.orthogonal_(self.conv4.weight) 85 | 86 | 87 | def load_state_dict(self, state_dict, strict=True): 88 | own_state = self.state_dict() 89 | for name, param in state_dict.items(): 90 | if name in own_state: 91 | if isinstance(param, nn.Parameter): 92 | param = param.data 93 | own_state[name].copy_(param) -------------------------------------------------------------------------------- /src/model/fix_patch_prompt.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | #from mmcv.cnn.bricks import PLUGIN_LAYERS 3 | import torch 4 | 5 | 6 | class FixedPatchPrompter_image(nn.Module): 7 | def __init__(self, prompt_size = 12, std = 1): 8 | super(FixedPatchPrompter_image, self).__init__() 9 | self.psize = prompt_size 10 | self.patch = nn.Parameter(std * torch.randn([3, self.psize, self.psize])) 11 | 12 | def forward(self, x): 13 | isize = x.shape[-2:] 14 | prompt = torch.zeros([x.shape[0], 3, isize[0], isize[1]], device=x.device) 15 | prompt[:, :, :self.psize, :self.psize] = self.patch.unsqueeze(0) 16 | return x + prompt 17 | 18 | 19 | # # for feature level 20 | #@PLUGIN_LAYERS.register_module() 21 | class FixedPatchPrompter_feature_1(nn.Module): 22 | def __init__(self, prompt_size = 24, std = 1): 23 | super(FixedPatchPrompter_feature_1, self).__init__() 24 | self.psize = prompt_size 25 | self.patch = nn.Parameter(torch.randn([64, prompt_size, prompt_size])*std) #for feature size of espcn_lf 26 | 27 | def forward(self, x): 28 | tmp = torch.zeros_like(x) 29 | #print(tmp.shape, x.shape) 30 | tmp[:,:, :self.psize, :self.psize] = self.patch.unsqueeze(0) 31 | 32 | return x + tmp 33 | 34 | class FixedPatchPrompter_feature_default(nn.Module): 35 | def __init__(self, prompt_size, image_size): 36 | super(FixedPatchPrompter_feature_default, self).__init__() 37 | self.isize = image_size 38 | self.psize = prompt_size 39 | self.patch = nn.Parameter(torch.randn([2, 2048, self.psize, self.psize])) #2 is batchsize, 2048 is feature dimension 40 | 41 | def
forward(self, x): 42 | tmp = torch.zeros_like(x) 43 | tmp[:,:, :self.psize, :self.psize] = self.patch 44 | return x + tmp -------------------------------------------------------------------------------- /src/option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import template 3 | 4 | parser = argparse.ArgumentParser(description='EDSR and MDSR') 5 | 6 | parser.add_argument('--debug', action='store_true', 7 | help='Enables debug mode') 8 | parser.add_argument('--template', default='.', 9 | help='You can set various templates in option.py') 10 | 11 | # Hardware specifications 12 | parser.add_argument('--n_threads', type=int, default=6, 13 | help='number of threads for data loading') 14 | parser.add_argument('--cpu', action='store_true', 15 | help='use cpu only') 16 | parser.add_argument('--n_GPUs', type=int, default=1, 17 | help='number of GPUs') 18 | parser.add_argument('--seed', type=int, default=1, 19 | help='random seed') 20 | 21 | # Data specifications 22 | parser.add_argument('--dir_data', type=str, default='/home/dlx/CaFM-Pytorch-ICCV2021-main/src/game_45s_1/lol_45s_1', 23 | help='dataset directory') 24 | parser.add_argument('--dir_demo', type=str, default='../test/..', 25 | help='demo image directory') 26 | parser.add_argument('--data_train', type=str, default='DIV2K', 27 | help='train dataset name') 28 | parser.add_argument('--data_test', type=str, default='DIV2K', 29 | help='test dataset name') 30 | parser.add_argument('--data_range', type=str, default='1-800/801-810', 31 | help='train/test data range') 32 | parser.add_argument('--ext', type=str, default='sep', 33 | help='dataset file extension') 34 | parser.add_argument('--scale', type=str, default='4', 35 | help='super resolution scale') 36 | parser.add_argument('--patch_size', type=int, default=48, 37 | help='output patch size') 38 | parser.add_argument('--rgb_range', type=int, default=255, 39 | help='maximum value of RGB') 40 | parser.add_argument('--n_colors', type=int, default=3, 41 | help='number of color channels to use') 42 | parser.add_argument('--chop', action='store_true', 43 | help='enable memory-efficient forward') 44 | parser.add_argument('--no_augment', action='store_true', 45 | help='do not use data augmentation') 46 | 47 | # Model specifications 48 | parser.add_argument('--model', default='EDSR', 49 | help='model name') 50 | 51 | parser.add_argument('--act', type=str, default='relu', 52 | help='activation function') 53 | parser.add_argument('--pre_train', type=str, default='', 54 | help='pre-trained model directory') 55 | parser.add_argument('--extend', type=str, default='.', 56 | help='pre-trained model directory') 57 | parser.add_argument('--n_resblocks', type=int, default=2, 58 | help='number of residual blocks') 59 | parser.add_argument('--n_feats', type=int, default=64, 60 | help='number of feature maps') 61 | parser.add_argument('--res_scale', type=float, default=1, 62 | help='residual scaling') 63 | parser.add_argument('--shift_mean', default=True, 64 | help='subtract pixel mean from the input') 65 | parser.add_argument('--dilation', action='store_true', 66 | help='use dilated convolution') 67 | parser.add_argument('--precision', type=str, default='single', 68 | choices=('single', 'half'), 69 | help='FP precision for test (single | half)') 70 | 71 | # Option for Residual dense network (RDN) 72 | parser.add_argument('--G0', type=int, default=64, 73 | help='default number of filters. 
(Use in RDN)') 74 | parser.add_argument('--RDNkSize', type=int, default=3, 75 | help='default kernel size. (Use in RDN)') 76 | parser.add_argument('--RDNconfig', type=str, default='B', 77 | help='parameters config of RDN. (Use in RDN)') 78 | 79 | # Option for Residual channel attention network (RCAN) 80 | parser.add_argument('--n_resgroups', type=int, default=10, 81 | help='number of residual groups') 82 | parser.add_argument('--reduction', type=int, default=16, 83 | help='number of feature maps reduction') 84 | 85 | # Training specifications 86 | parser.add_argument('--reset', action='store_true', 87 | help='reset the training') 88 | parser.add_argument('--test_every', type=int, default=1000, 89 | help='do test per every N batches') 90 | parser.add_argument('--epochs', type=int, default=300, 91 | help='number of epochs to train') 92 | parser.add_argument('--batch_size', type=int, default=16, 93 | help='input batch size for training') 94 | parser.add_argument('--split_batch', type=int, default=1, 95 | help='split the batch into smaller chunks') 96 | parser.add_argument('--self_ensemble', action='store_true', 97 | help='use self-ensemble method for test') 98 | parser.add_argument('--test_only', action='store_true', 99 | help='set this option to test the model') 100 | parser.add_argument('--gan_k', type=int, default=1, 101 | help='k value for adversarial loss') 102 | 103 | # Optimization specifications 104 | parser.add_argument('--lr', type=float, default=1e-4, #5*1e-4 105 | help='learning rate') 106 | parser.add_argument('--decay', type=str, default='200', #8 107 | help='learning rate decay type') 108 | parser.add_argument('--gamma', type=float, default=0.5, #0.1 109 | help='learning rate decay factor for step decay') 110 | parser.add_argument('--optimizer', default='ADAM', 111 | choices=('SGD', 'ADAM', 'RMSprop'), 112 | help='optimizer to use (SGD | ADAM | RMSprop)') 113 | parser.add_argument('--momentum', type=float, default=0.9, 114 | help='SGD momentum') 115 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), 116 | help='ADAM beta') 117 | parser.add_argument('--epsilon', type=float, default=1e-8, 118 | help='ADAM epsilon for numerical stability') 119 | parser.add_argument('--weight_decay', type=float, default=0, 120 | help='weight decay') 121 | parser.add_argument('--gclip', type=float, default=0, 122 | help='gradient clipping threshold (0 = no clipping)') 123 | parser.add_argument('--patch_epsilon', type=float, default=1e-8, 124 | help='ADAM epsilon for numerical stability for the patch prompter') 125 | parser.add_argument('--patch_betas', type=tuple, default=(0.9, 0.999), 126 | help='ADAM beta for the patch prompter') 127 | parser.add_argument('--patch_lr', type=float, default=1e-4, #5*1e-4 128 | help='learning rate') 129 | 130 | # Loss specifications 131 | parser.add_argument('--loss', type=str, default='1*L1', #'0.5*L1+0.5*MSE' '1*L1' '0.5*L1+0.5*L12' 132 | help='loss function configuration') 133 | parser.add_argument('--skip_threshold', type=float, default='1e8', 134 | help='skipping batch that has large error') 135 | 136 | # Log specifications 137 | parser.add_argument('--save', type=str, default='test', 138 | help='file name to save') 139 | parser.add_argument('--load', type=str, default='', 140 | help='file name to load') 141 | parser.add_argument('--resume', type=int, default=0, 142 | help='resume from specific checkpoint') 143 | parser.add_argument('--save_models', action='store_true', 144 | help='save all intermediate models') 145 | 
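# NOTE: utility.make_optimizer builds the MultiStepLR milestones by splitting --decay on '-', so a multi-step schedule is written as e.g. --decay 200-400.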
parser.add_argument('--print_every', type=int, default=100, 146 | help='how many batches to wait before logging training status') 147 | parser.add_argument('--save_results', action='store_true', 148 | help='save output results') 149 | parser.add_argument('--save_gt', action='store_true', 150 | help='save low-resolution and high-resolution images together') 151 | 152 | # if add cafm 153 | parser.add_argument('--cafm', action='store_true', 154 | help='edsr + cafm') 155 | parser.add_argument('--cafm_espcn', action='store_true', 156 | help='espcn + cafm') 157 | parser.add_argument('--edsr_espcn', action='store_true', 158 | help='edsr + espcn, side tuning') 159 | parser.add_argument('--edsr_res', action='store_true', 160 | help='edsr only fine tune part resblock') 161 | 162 | parser.add_argument('--segnum', type=int, default=1, 163 | help='segnumber') 164 | 165 | parser.add_argument('--sidetuning', action='store_true', 166 | help='using sidetuning') 167 | 168 | parser.add_argument('--cafm_side', action='store_true', 169 | help='cafm + sidetuning') 170 | 171 | parser.add_argument('--data_partion', type=float, default=0.05, 172 | help='data_partion for data sampling') 173 | 174 | parser.add_argument('--tcloss_v1', action='store_true', 175 | help='tcloss_v1') 176 | 177 | parser.add_argument('--tcloss_v2', action='store_true', 178 | help='tcloss_v2') 179 | 180 | parser.add_argument('--tcloss_seg', type=int, default=0, 181 | help='which seg is selected to finetuning') 182 | 183 | parser.add_argument('--dvp', action='store_true', 184 | help='dvp: use SR as the GT') 185 | 186 | parser.add_argument('--use_cafm', action='store_true', 187 | help='using cafm block') 188 | 189 | parser.add_argument('--is45s', action='store_true', 190 | help='is 45s video') 191 | 192 | parser.add_argument('--is30s', action='store_true', 193 | help='is 30s video') 194 | 195 | parser.add_argument('--is15s', action='store_true', 196 | help='is 15s video') 197 | 198 | parser.add_argument('--finetune', action='store_true', 199 | help='using fintuning') 200 | 201 | parser.add_argument('--chunked', action='store_true', 202 | default=False,) 203 | 204 | parser.add_argument('--chunk_size', type=int, default=1,) 205 | 206 | parser.add_argument('--std', type=float, default=1) 207 | 208 | parser.add_argument('--prompt_size', type=int, default=12) 209 | 210 | parser.add_argument('--patch_load', type=str, default='') 211 | 212 | parser.add_argument('--no_rep', action='store_true', default=False,) 213 | 214 | args = parser.parse_args() 215 | template.set_template(args) 216 | 217 | args.scale = list(map(lambda x: int(x), args.scale.split('+'))) 218 | args.data_train = args.data_train.split('+') 219 | args.data_test = args.data_test.split('+') 220 | 221 | if args.epochs == 0: 222 | args.epochs = 1e8 223 | 224 | for arg in vars(args): 225 | if vars(args)[arg] == 'True': 226 | vars(args)[arg] = True 227 | elif vars(args)[arg] == 'False': 228 | vars(args)[arg] = False 229 | 230 | -------------------------------------------------------------------------------- /src/reparameter_edsr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import argparse 4 | import re 5 | import traceback 6 | import rich 7 | 8 | 9 | def transII_addbranch(kernels, biases): 10 | return sum(kernels), sum(biases) 11 | 12 | def transIII_1x1_kxk(k1, b1, k2, b2,groups): 13 | if groups == 1: 14 | k = F.conv2d(k2, k1.permute(1, 0, 2, 3)) 15 | b_hat = (k2 * b1.reshape(1, -1, 1, 
1)).sum((1, 2, 3)) 16 | elif groups == 10: 17 | k = k1 * k2 18 | b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3)) 19 | else: 20 | k_slices = [] 21 | b_slices = [] 22 | k1_T = k1.permute(1, 0, 2, 3) 23 | k1_group_width = k1.size(0) // groups 24 | k2_group_width = k2.size(0) // groups 25 | for g in range(groups): 26 | k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :] 27 | k2_slice = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :] 28 | k_slices.append(F.conv2d(k2_slice, k1_T_slice)) 29 | b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3))) 30 | k, b_hat = transIV_depthconcat(k_slices, b_slices) 31 | return k, b_hat + b2 32 | 33 | def transIV_depthconcat(kernels, biases): 34 | return torch.cat(kernels, dim=0), torch.cat(biases) 35 | 36 | def transVI_multiscale(kernel, target_kernel_size): 37 | H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2 38 | W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2 39 | return F.pad(kernel, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad]) 40 | 41 | def extract_numbers(s): 42 | match = re.match(r'body\.(\d+)\.body\.(\d+)\.conv_(\d+)', s) 43 | if match: 44 | return tuple(map(int, match.groups())) 45 | else: 46 | return (None, None, None) 47 | 48 | #reparameter 49 | def reparameter(model,list1, list2, list3): 50 | k_0_01, b_0_01 = transIII_1x1_kxk(model[list1[3]],model[list1[2]],model[list1[5]],model[list1[4]],groups=1) 51 | k_0_012, b_0_012 = transIII_1x1_kxk(model[list1[1]],model[list1[0]],k_0_01,b_0_01,groups=1) 52 | 53 | k_1_01, b_1_01 = transIII_1x1_kxk(model[list2[1]],model[list2[0]],model[list2[3]],model[list2[2]],groups=1) 54 | 55 | k_012, b_012 = transII_addbranch((k_0_012,k_1_01,model[list3[1]]),(b_0_012,b_1_01,model[list3[0]])) 56 | 57 | conv2_list = list1 + list2 + list3 58 | 59 | for i in conv2_list: 60 | # print(i) 61 | del model[i] 62 | 63 | # print(model.keys()) 64 | 65 | return k_012, b_012 66 | 67 | # body.0.body.0.conv_0 68 | def get_layer_name(layer_depth, conv_num, branch_num): 69 | return 'body.' + str(layer_depth) + '.body.' 
+ str(conv_num) + '.conv_' + str(branch_num) 70 | 71 | 72 | parser = argparse.ArgumentParser(description='PyTorch EDSR') 73 | parser.add_argument('--model_folder', type=str, default='experiment/model', help='model folder to use') 74 | parser.add_argument('--n_res_blocks', type=int, default=2) 75 | parser.add_argument('--m_branches', type=int, default=3) 76 | 77 | model_folder = parser.parse_args().model_folder 78 | n = parser.parse_args().n_res_blocks 79 | m = parser.parse_args().m_branches 80 | 81 | if m != 3: 82 | raise ValueError('Only 3 branches are supported') 83 | 84 | 85 | model_path = model_folder + '/model_best.pt' 86 | model_outpath = model_folder + '/model_rep.pt' 87 | 88 | model = torch.load(model_path) 89 | # print(model.keys()) 90 | 91 | res = [[] for i in range(n)] 92 | for resblock in res: 93 | for i in range(2): 94 | resblock.append([]) 95 | for branch in resblock: 96 | for i in range(m): 97 | branch.append([]) 98 | 99 | """res0_conv1_0_list = [] 100 | res0_conv1_1_list = [] 101 | res0_conv1_2_list = [] 102 | 103 | res0_conv2_0_list = [] 104 | res0_conv2_1_list = [] 105 | res0_conv2_2_list = [] 106 | 107 | res1_conv1_0_list = [] 108 | res1_conv1_1_list = [] 109 | res1_conv1_2_list = [] 110 | 111 | res1_conv2_0_list = [] 112 | res1_conv2_1_list = [] 113 | res1_conv2_2_list = []""" 114 | 115 | 116 | conv_0_list = [] 117 | conv_1_list = [] 118 | conv_2_list = [] 119 | 120 | def imap(j): 121 | return 0 if j == 0 else 2 122 | 123 | #rich.print(model.keys()) 124 | 125 | # Could be simplified. 126 | for k in model.keys(): 127 | flag = False 128 | for i in range(n): 129 | for j in range(2): 130 | for l in range(m): 131 | if get_layer_name(i, imap(j), l) in k: 132 | res[i][j][l].append(k) 133 | flag = True 134 | break 135 | if flag: 136 | break 137 | if flag: 138 | break 139 | if flag: 140 | continue 141 | 142 | if f"body.{n}.conv_0" in k: 143 | conv_0_list.append(k) 144 | elif f"body.{n}.conv_1" in k: 145 | conv_1_list.append(k) 146 | elif f"body.{n}.conv_2" in k: 147 | conv_2_list.append(k) 148 | 149 | """ 150 | for k in model.keys(): 151 | if "body.0.body.0.conv_0" in k: 152 | res0_conv1_0_list.append(k) 153 | elif "body.0.body.0.conv_1" in k: 154 | res0_conv1_1_list.append(k) 155 | elif "body.0.body.0.conv_2" in k: 156 | res0_conv1_2_list.append(k) 157 | elif "body.0.body.2.conv_0" in k: 158 | res0_conv2_0_list.append(k) 159 | elif "body.0.body.2.conv_1" in k: 160 | res0_conv2_1_list.append(k) 161 | elif "body.0.body.2.conv_2" in k: 162 | res0_conv2_2_list.append(k) 163 | 164 | elif "body.1.body.0.conv_0" in k: 165 | res1_conv1_0_list.append(k) 166 | elif "body.1.body.0.conv_1" in k: 167 | res1_conv1_1_list.append(k) 168 | elif "body.1.body.0.conv_2" in k: 169 | res1_conv1_2_list.append(k) 170 | elif "body.1.body.2.conv_0" in k: 171 | res1_conv2_0_list.append(k) 172 | elif "body.1.body.2.conv_1" in k: 173 | res1_conv2_1_list.append(k) 174 | elif "body.1.body.2.conv_2" in k: 175 | res1_conv2_2_list.append(k) 176 | elif "body.2.conv_0" in k: 177 | conv_0_list.append(k) 178 | elif "body.2.conv_1" in k: 179 | conv_1_list.append(k) 180 | elif "body.2.conv_2" in k: 181 | conv_2_list.append(k) 182 | 183 | else: 184 | continue""" 185 | 186 | for i in range(n): 187 | for j in range(2): 188 | for l in range(m): 189 | res[i][j][l].sort() 190 | 191 | """res0_conv1_0_list.sort() 192 | res0_conv1_1_list.sort() 193 | res0_conv1_2_list.sort() 194 | # print(res0_conv1_0_list) 195 | 196 | res0_conv2_0_list.sort() 197 | res0_conv2_1_list.sort() 198 | res0_conv2_2_list.sort() 199 | 200 | 201 | 
res1_conv1_0_list.sort() 202 | res1_conv1_1_list.sort() 203 | res1_conv1_2_list.sort() 204 | 205 | res1_conv2_0_list.sort() 206 | res1_conv2_1_list.sort() 207 | res1_conv2_2_list.sort()""" 208 | 209 | conv_0_list.sort() 210 | conv_1_list.sort() 211 | conv_2_list.sort() 212 | 213 | """ 214 | model['body.0.body.0.conv.weight'], model['body.0.body.0.conv.bias'] = reparameter(model, res0_conv1_0_list, res0_conv1_1_list, res0_conv1_2_list) 215 | model['body.0.body.2.conv.weight'], model['body.0.body.2.conv.bias'] = reparameter(model, res0_conv2_0_list, res0_conv2_1_list, res0_conv2_2_list) 216 | 217 | model['body.1.body.0.conv.weight'], model['body.1.body.0.conv.bias'] = reparameter(model, res1_conv1_0_list, res1_conv1_1_list, res1_conv1_2_list) 218 | model['body.1.body.2.conv.weight'], model['body.1.body.2.conv.bias'] = reparameter(model, res1_conv2_0_list, res1_conv2_1_list, res1_conv2_2_list) 219 | """ 220 | 221 | for i in range(n): 222 | try: 223 | model[f'body.{i}.body.0.conv.weight'], model[f'body.{i}.body.0.conv.bias'] = reparameter(model, res[i][0][0], res[i][0][1], res[i][0][2]) 224 | model[f'body.{i}.body.2.conv.weight'], model[f'body.{i}.body.2.conv.bias'] = reparameter(model, res[i][1][0], res[i][1][1], res[i][1][2]) 225 | except Exception as e: 226 | print(i) 227 | print(len(res[i][1][0]), len(res[i][1][1]), len(res[i][1][2])) 228 | print(res[i][1][0]) 229 | print("An error occured") 230 | traceback.print_exc() 231 | exit() 232 | 233 | 234 | 235 | try: 236 | model[f'body.{n}.conv.weight'], model[f'body.{n}.conv.bias'] = reparameter(model, conv_0_list, conv_1_list, conv_2_list) 237 | except: 238 | print(conv_0_list, conv_1_list, conv_2_list) 239 | traceback.print_exc() 240 | exit() 241 | 242 | # print(model.keys()) 243 | 244 | 245 | torch.save(model,model_outpath) -------------------------------------------------------------------------------- /src/reparameter_edsr_legacy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import argparse 4 | 5 | def transII_addbranch(kernels, biases): 6 | return sum(kernels), sum(biases) 7 | 8 | def transIII_1x1_kxk(k1, b1, k2, b2,groups): 9 | if groups == 1: 10 | k = F.conv2d(k2, k1.permute(1, 0, 2, 3)) 11 | b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3)) 12 | elif groups == 10: 13 | k = k1 * k2 14 | b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3)) 15 | else: 16 | k_slices = [] 17 | b_slices = [] 18 | k1_T = k1.permute(1, 0, 2, 3) 19 | k1_group_width = k1.size(0) // groups 20 | k2_group_width = k2.size(0) // groups 21 | for g in range(groups): 22 | k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :] 23 | k2_slice = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :] 24 | k_slices.append(F.conv2d(k2_slice, k1_T_slice)) 25 | b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3))) 26 | k, b_hat = transIV_depthconcat(k_slices, b_slices) 27 | return k, b_hat + b2 28 | 29 | def transIV_depthconcat(kernels, biases): 30 | return torch.cat(kernels, dim=0), torch.cat(biases) 31 | 32 | def transVI_multiscale(kernel, target_kernel_size): 33 | H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2 34 | W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2 35 | return F.pad(kernel, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad]) 36 | 37 | 38 | #reparameter 39 | def reparameter(model,list1, list2, list3): 40 | 41 | # print(list1) 42 | 43 | 44 | 45 | k_0_01, b_0_01 
= transIII_1x1_kxk(model[list1[3]],model[list1[2]],model[list1[5]],model[list1[4]],groups=1) 46 | k_0_012, b_0_012 = transIII_1x1_kxk(model[list1[1]],model[list1[0]],k_0_01,b_0_01,groups=1) 47 | 48 | k_1_01, b_1_01 = transIII_1x1_kxk(model[list2[1]],model[list2[0]],model[list2[3]],model[list2[2]],groups=1) 49 | 50 | k_012, b_012 = transII_addbranch((k_0_012,k_1_01,model[list3[1]]),(b_0_012,b_1_01,model[list3[0]])) 51 | 52 | conv2_list = list1 + list2 + list3 53 | 54 | for i in conv2_list: 55 | # print(i) 56 | del model[i] 57 | 58 | # print(model.keys()) 59 | 60 | return k_012, b_012 61 | 62 | parser = argparse.ArgumentParser(description='PyTorch EDSR') 63 | parser.add_argument('--model_folder', type=str, default='experiment/model', help='model folder to use') 64 | 65 | model_folder = parser.parse_args().model_folder 66 | model_path = model_folder + '/model_best.pt' 67 | model_outpath = model_folder + '/model_rep.pt' 68 | 69 | model = torch.load(model_path) 70 | # print(model.keys()) 71 | 72 | res0_conv1_0_list = [] 73 | res0_conv1_1_list = [] 74 | res0_conv1_2_list = [] 75 | 76 | res0_conv2_0_list = [] 77 | res0_conv2_1_list = [] 78 | res0_conv2_2_list = [] 79 | 80 | res1_conv1_0_list = [] 81 | res1_conv1_1_list = [] 82 | res1_conv1_2_list = [] 83 | 84 | res1_conv2_0_list = [] 85 | res1_conv2_1_list = [] 86 | res1_conv2_2_list = [] 87 | 88 | 89 | conv_0_list = [] 90 | conv_1_list = [] 91 | conv_2_list = [] 92 | 93 | 94 | for k in model.keys(): 95 | if "body.0.body.0.conv_0" in k: 96 | res0_conv1_0_list.append(k) 97 | elif "body.0.body.0.conv_1" in k: 98 | res0_conv1_1_list.append(k) 99 | elif "body.0.body.0.conv_2" in k: 100 | res0_conv1_2_list.append(k) 101 | elif "body.0.body.2.conv_0" in k: 102 | res0_conv2_0_list.append(k) 103 | elif "body.0.body.2.conv_1" in k: 104 | res0_conv2_1_list.append(k) 105 | elif "body.0.body.2.conv_2" in k: 106 | res0_conv2_2_list.append(k) 107 | 108 | elif "body.1.body.0.conv_0" in k: 109 | res1_conv1_0_list.append(k) 110 | elif "body.1.body.0.conv_1" in k: 111 | res1_conv1_1_list.append(k) 112 | elif "body.1.body.0.conv_2" in k: 113 | res1_conv1_2_list.append(k) 114 | elif "body.1.body.2.conv_0" in k: 115 | res1_conv2_0_list.append(k) 116 | elif "body.1.body.2.conv_1" in k: 117 | res1_conv2_1_list.append(k) 118 | elif "body.1.body.2.conv_2" in k: 119 | res1_conv2_2_list.append(k) 120 | elif "body.2.conv_0" in k: 121 | conv_0_list.append(k) 122 | elif "body.2.conv_1" in k: 123 | conv_1_list.append(k) 124 | elif "body.2.conv_2" in k: 125 | conv_2_list.append(k) 126 | 127 | else: 128 | continue 129 | 130 | 131 | res0_conv1_0_list.sort() 132 | res0_conv1_1_list.sort() 133 | res0_conv1_2_list.sort() 134 | # print(res0_conv1_0_list) 135 | 136 | res0_conv2_0_list.sort() 137 | res0_conv2_1_list.sort() 138 | res0_conv2_2_list.sort() 139 | 140 | 141 | res1_conv1_0_list.sort() 142 | res1_conv1_1_list.sort() 143 | res1_conv1_2_list.sort() 144 | 145 | res1_conv2_0_list.sort() 146 | res1_conv2_1_list.sort() 147 | res1_conv2_2_list.sort() 148 | 149 | conv_0_list.sort() 150 | conv_1_list.sort() 151 | conv_2_list.sort() 152 | 153 | 154 | model['body.0.body.0.conv.weight'], model['body.0.body.0.conv.bias'] = reparameter(model, res0_conv1_0_list, res0_conv1_1_list, res0_conv1_2_list) 155 | model['body.0.body.2.conv.weight'], model['body.0.body.2.conv.bias'] = reparameter(model, res0_conv2_0_list, res0_conv2_1_list, res0_conv2_2_list) 156 | 157 | model['body.1.body.0.conv.weight'], model['body.1.body.0.conv.bias'] = reparameter(model, res1_conv1_0_list, res1_conv1_1_list, 
res1_conv1_2_list) 158 | model['body.1.body.2.conv.weight'], model['body.1.body.2.conv.bias'] = reparameter(model, res1_conv2_0_list, res1_conv2_1_list, res1_conv2_2_list) 159 | 160 | print(len(res0_conv1_0_list), len(res0_conv1_1_list), len(res0_conv1_2_list)) 161 | print(res0_conv1_0_list, res0_conv1_1_list, res0_conv1_2_list) 162 | 163 | model['body.2.conv.weight'], model['body.2.conv.bias'] = reparameter(model, conv_0_list, conv_1_list, conv_2_list) 164 | 165 | # print(model.keys()) 166 | 167 | 168 | torch.save(model,model_outpath) -------------------------------------------------------------------------------- /src/reparameter_espcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import argparse 4 | 5 | def transII_addbranch(kernels, biases): 6 | return sum(kernels), sum(biases) 7 | 8 | def transIII_1x1_kxk(k1, b1, k2, b2,groups): 9 | if groups == 1: 10 | k = F.conv2d(k2, k1.permute(1, 0, 2, 3)) # 11 | b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3)) 12 | elif groups == 10: 13 | k = k1 * k2 14 | b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3)) 15 | else: 16 | k_slices = [] 17 | b_slices = [] 18 | k1_T = k1.permute(1, 0, 2, 3) 19 | k1_group_width = k1.size(0) // groups 20 | k2_group_width = k2.size(0) // groups 21 | for g in range(groups): 22 | k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :] 23 | k2_slice = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :] 24 | k_slices.append(F.conv2d(k2_slice, k1_T_slice)) 25 | b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3))) 26 | k, b_hat = transIV_depthconcat(k_slices, b_slices) 27 | return k, b_hat + b2 28 | 29 | def transIV_depthconcat(kernels, biases): 30 | return torch.cat(kernels, dim=0), torch.cat(biases) 31 | 32 | def transVI_multiscale(kernel, target_kernel_size): 33 | H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2 34 | W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2 35 | return F.pad(kernel, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad]) 36 | 37 | #reparameter 38 | def reparameter(model,list1, list2, list3): 39 | 40 | print(list1) 41 | 42 | 43 | 44 | k_0_01, b_0_01 = transIII_1x1_kxk(model[list1[3]],model[list1[2]],model[list1[5]],model[list1[4]],groups=1) 45 | k_0_012, b_0_012 = transIII_1x1_kxk(model[list1[1]],model[list1[0]],k_0_01,b_0_01,groups=1) 46 | 47 | k_1_01, b_1_01 = transIII_1x1_kxk(model[list2[1]],model[list2[0]],model[list2[3]],model[list2[2]],groups=1) 48 | 49 | k_012, b_012 = transII_addbranch((k_0_012,k_1_01,model[list3[1]]),(b_0_012,b_1_01,model[list3[0]])) 50 | 51 | conv2_list = list1 + list2 + list3 52 | 53 | for i in conv2_list: 54 | # print(i) 55 | del model[i] 56 | 57 | # print(model.keys()) 58 | 59 | return k_012, b_012 60 | 61 | parser = argparse.ArgumentParser(description='PyTorch EDSR') 62 | parser.add_argument('--model_folder', type=str, default='experiment/model', help='model folder to use') 63 | 64 | model_folder = parser.parse_args().model_folder 65 | model_path = model_folder + '/model_best.pt' 66 | model_outpath = model_folder + '/model_rep.pt' 67 | 68 | model = torch.load(model_path) 69 | # print(model.keys()) 70 | 71 | conv1_0_list = [] 72 | conv1_1_list = [] 73 | conv1_2_list = [] 74 | 75 | conv2_0_list = [] 76 | conv2_1_list = [] 77 | conv2_2_list = [] 78 | 79 | conv3_0_list = [] 80 | conv3_1_list = [] 81 | conv3_2_list = [] 82 | 83 | conv4_0_list = [] 84 | conv4_1_list = [] 
= []
85 | conv4_2_list = [] 86 | 87 | for k in model.keys(): 88 | if "conv1_0" in k: 89 | conv1_0_list.append(k) 90 | elif "conv1_1" in k: 91 | conv1_1_list.append(k) 92 | elif "conv1_2" in k: 93 | conv1_2_list.append(k) 94 | elif "conv2_0" in k: 95 | conv2_0_list.append(k) 96 | elif "conv2_1" in k: 97 | conv2_1_list.append(k) 98 | elif "conv2_2" in k: 99 | conv2_2_list.append(k) 100 | elif "conv3_0" in k: 101 | conv3_0_list.append(k) 102 | elif "conv3_1" in k: 103 | conv3_1_list.append(k) 104 | elif "conv3_2" in k: 105 | conv3_2_list.append(k) 106 | elif "conv4_0" in k: 107 | conv4_0_list.append(k) 108 | elif "conv4_1" in k: 109 | conv4_1_list.append(k) 110 | elif "conv4_2" in k: 111 | conv4_2_list.append(k) 112 | else: 113 | continue 114 | 115 | conv1_0_list.sort() 116 | conv1_1_list.sort() 117 | conv1_2_list.sort() 118 | conv2_0_list.sort() 119 | conv2_1_list.sort() 120 | conv2_2_list.sort() 121 | conv3_0_list.sort() 122 | conv3_1_list.sort() 123 | conv3_2_list.sort() 124 | conv4_0_list.sort() 125 | conv4_1_list.sort() 126 | conv4_2_list.sort() 127 | 128 | # print(conv2_0_list) 129 | # print(conv1_1_list) 130 | 131 | #conv1 reparameter 132 | model['conv1.weight'], model['conv1.bias'] = reparameter(model, conv1_0_list, conv1_1_list, conv1_2_list) 133 | 134 | #conv2 reparameter 135 | model['conv2.weight'], model['conv2.bias'] = reparameter(model, conv2_0_list, conv2_1_list, conv2_2_list) 136 | 137 | 138 | # conv3 reparameter 139 | model['conv3.weight'], model['conv3.bias'] = reparameter(model, conv3_0_list, conv3_1_list, conv3_2_list) 140 | 141 | # conv4 reparameter 142 | model['conv4.weight'], model['conv4.bias'] = reparameter(model, conv4_0_list, conv4_1_list, conv4_2_list) 143 | 144 | 145 | torch.save(model,model_outpath) -------------------------------------------------------------------------------- /src/template.py: -------------------------------------------------------------------------------- 1 | def set_template(args): 2 | # Set the templates here 3 | if args.template.find('jpeg') >= 0: 4 | args.data_train = 'DIV2K_jpeg' 5 | args.data_test = 'DIV2K_jpeg' 6 | args.epochs = 200 7 | args.decay = '100' 8 | 9 | 10 | if args.template.find('RCAN') >= 0: 11 | args.model = 'RCAN' 12 | args.n_resgroups = 10 13 | args.n_resblocks = 20 14 | args.n_feats = 64 15 | args.chop = True 16 | 17 | if args.template.find('EDSR') >= 0: 18 | args.model = 'EDSR' 19 | args.n_resblocks = 16 20 | args.n_feats = 64 21 | 22 | if args.template.find('VDSRR') >= 0: 23 | args.model = 'VDSRR' 24 | args.n_resblocks = 20 25 | args.n_feats = 64 26 | args.patch_size = 48 27 | args.lr = 1e-4 28 | 29 | if args.template.find('ESPCN') >= 0: 30 | args.model = 'ESPCN' 31 | args.n_feats = 64 32 | 33 | if args.template.find('SRCNN') >= 0: 34 | args.model = 'SRCNN' 35 | args.n_feats = 64 -------------------------------------------------------------------------------- /src/train_bash_demo/demo_train_M1-n.sh: -------------------------------------------------------------------------------- 1 | python main.py --model EDSR --scale 3 --patch_size 48 --save EDSR_X3_demo_game_15s_1_M1-n --reset --data_train DIV2K --data_test DIV2K --data_range 1-450/451-495 --cafm --dir_data /home/datasets/VSD4K/game/game_15s_1 --use_cafm --batch_size 64 --epoch 500 --decay 300 --is15s --segnum 3 2 | -------------------------------------------------------------------------------- /src/train_bash_demo/demo_train_S1-n.sh: -------------------------------------------------------------------------------- 1 | python main.py --model EDSR --scale 3 --patch_size 
48 --cafm --save EDSR_X3_demo_game_15s_1_M0 --reset --data_train DIV2K --data_test DIV2K --data_range 1-450/451-495 --is15s --dir_data /home/datasets/VSD4K/game/game_15s_1 2 | python main.py --model EDSR --scale 3 --patch_size 48 --cafm --save EDSR_X3_demo_game_15s_1_S1 --reset --data_train DIV2K --data_test DIV2K --data_range 1-150/451-465 --is15s --dir_data /home/datasets/VSD4K/game/game_15s_1 3 | python main.py --model EDSR --scale 3 --patch_size 48 --cafm --save EDSR_X3_demo_game_15s_1_S2 --reset --data_train DIV2K --data_test DIV2K --data_range 151-300/466-480 --is15s --dir_data /home/datasets/VSD4K/game/game_15s_1 4 | python main.py --model EDSR --scale 3 --patch_size 48 --cafm --save EDSR_X3_demo_game_15s_1_S3 --reset --data_train DIV2K --data_test DIV2K --data_range 301-450/481-495 --is15s --dir_data /home/datasets/VSD4K/game/game_15s_1 5 | -------------------------------------------------------------------------------- /src/trainer_cafm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from decimal import Decimal 4 | 5 | import utility 6 | 7 | import torch 8 | import torch.nn.utils as utils 9 | from tqdm import tqdm 10 | import sys 11 | import numpy as np 12 | import cv2 as cv 13 | import imageio 14 | from model.fix_patch_prompt import FixedPatchPrompter_image 15 | 16 | class Trainer_cafm(): 17 | def __init__(self, args, loader, my_model, my_loss, ckp: utility.checkpoint): 18 | self.args = args 19 | self.scale = args.scale 20 | 21 | self.ckp = ckp 22 | self.loader_train = loader.loader_train 23 | self.loader_test = loader.loader_test 24 | self.model = my_model 25 | self.loss = my_loss 26 | self.optimizer = utility.make_optimizer(args, self.model) 27 | 28 | self.patch = FixedPatchPrompter_image(prompt_size = args.prompt_size, std = args.std).cuda() 29 | self.patch_optimizer = utility.make_patch_optimizer(args, self.patch) 30 | 31 | if args.use_cafm: 32 | self.patch = [FixedPatchPrompter_image(prompt_size = args.prompt_size, std = args.std).cuda() 33 | for i in range(args.segnum + 1)] 34 | self.patch_optimizer = [utility.make_patch_optimizer(args, self.patch[i]) 35 | for i in range(args.segnum + 1)] 36 | 37 | if self.args.load != '': 38 | self.optimizer.load(ckp.dir, epoch=len(ckp.log)) 39 | 40 | if self.args.use_cafm: 41 | if args.patch_load != '': 42 | print() 43 | print(args.patch_load) 44 | path = args.patch_load 45 | for i in range(args.segnum + 1): 46 | patch_checkpoint_path = os.path.join(path, f'patch_{i}.pt') 47 | if os.path.exists(patch_checkpoint_path): 48 | checkpoint = torch.load(patch_checkpoint_path) 49 | self.patch[i].load_state_dict(checkpoint) 50 | #self.patch_optimizer[i].load_state_dict(checkpoint['optimizer']) 51 | print(f"patch{i} loaded") 52 | else: 53 | print(patch_checkpoint_path + " not exists") 54 | 55 | self.error_last = 1e8 56 | 57 | def train(self): 58 | 59 | self.loss.step() 60 | epoch = self.optimizer.get_last_epoch() + 1 61 | lr = self.optimizer.get_lr() 62 | plr = self.patch_optimizer[0].get_lr() 63 | 64 | self.ckp.write_log( 65 | '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)) 66 | ) 67 | self.ckp.write_log( 68 | '[Epoch {}]\tPatch learning rate: {:.2e}'.format(epoch, Decimal(plr)) 69 | ) 70 | self.loss.start_log() 71 | self.model.train() 72 | 73 | timer_data, timer_model = utility.timer(), utility.timer() 74 | # TEMP 75 | self.loader_train.dataset.set_scale(0) 76 | length = self.args.data_range.split('/')[0].split('-')[1] 77 | segnum = self.args.segnum 78 | for 
batch, (lr, hr, num,) in enumerate(self.loader_train): 79 | lr, hr = self.prepare(lr, hr) 80 | timer_data.hold() 81 | timer_model.tic() 82 | 83 | self.optimizer.zero_grad() 84 | if self.args.use_cafm: 85 | for opt in self.patch_optimizer: 86 | opt.zero_grad() 87 | #use_cafm 88 | if self.args.use_cafm: 89 | t_num = utility.make_trainChunk(num, length, segnum) 90 | numlist = [(t_num[i],i)for i in range(segnum)] 91 | #print(numlist) 92 | loss = 0 93 | for i in numlist: 94 | if len(i[0])!=0: 95 | pre_sr = self.patch[i[1]](lr[i[0]]) 96 | sr = self.model(pre_sr, 0, i[1]) 97 | loss += self.loss(sr, hr[i[0]])*len(i[0]) 98 | loss = loss/len(num) 99 | elif self.args.chunked: 100 | sr = self.model(lr, 0, np.array(num).astype(int)) 101 | loss = self.loss(sr, hr) 102 | #baseline 103 | else: 104 | sr = self.model(self.patch(lr), 0, num) 105 | loss = self.loss(sr, hr) 106 | 107 | loss.backward() 108 | if self.args.gclip > 0: 109 | utils.clip_grad_value_( 110 | self.model.parameters(), 111 | self.args.gclip 112 | ) 113 | self.optimizer.step() 114 | if self.args.use_cafm: 115 | for opt in self.patch_optimizer: 116 | opt.step() 117 | else: 118 | self.patch_optimizer.step() 119 | 120 | timer_model.hold() 121 | 122 | if (batch + 1) % self.args.print_every == 0: 123 | self.ckp.write_log('[{}/{}]\t{:.1f}\t{:.1f}+{:.1f}s'.format( 124 | (batch + 1) * self.args.batch_size, 125 | len(self.loader_train.dataset), 126 | #self.loss.display_loss(batch), 127 | loss.data, 128 | timer_model.release(), 129 | timer_data.release())) 130 | 131 | timer_data.tic() 132 | print(f"Allocated memory: {torch.cuda.memory_allocated()} bytes") 133 | print(f"Reserved memory: {torch.cuda.memory_reserved()} bytes") 134 | 135 | self.loss.end_log(len(self.loader_train)) 136 | self.error_last = self.loss.log[-1, -1] 137 | self.optimizer.schedule() 138 | if self.args.use_cafm: 139 | for opt in self.patch_optimizer: 140 | opt.schedule() 141 | else: 142 | self.patch_optimizer.schedule() 143 | 144 | def test(self): 145 | 146 | 147 | # low_hr_path = '/home/dlx/CaFM-Pytorch-ICCV2021-main/src/new_data/new_data/low_hr_7000/' 148 | # low_list = os.listdir(low_hr_path) 149 | # low_list.sort(key= lambda x:int(x[:-4])) 150 | # # print( low_list) 151 | 152 | torch.set_grad_enabled(False) 153 | 154 | epoch = self.optimizer.get_last_epoch() 155 | self.ckp.write_log('\nEvaluation:') 156 | self.ckp.add_log( 157 | torch.zeros(1, len(self.loader_test), len(self.scale)) 158 | ) 159 | self.model.eval() 160 | 161 | timer_test = utility.timer() 162 | if self.args.save_results: self.ckp.begin_background() 163 | for idx_data, d in enumerate(self.loader_test): 164 | for idx_scale, scale in enumerate(self.scale): 165 | d.dataset.set_scale(idx_scale) 166 | 167 | k = 0 168 | 169 | for lr, hr, filename in tqdm(d, ncols=80): 170 | # print(filename) 171 | lr, hr = self.prepare(lr, hr) 172 | filename[0] = filename[0].split('x')[0] 173 | if self.args.is45s: 174 | flag = utility.make_testChunk('45s', filename[0]) 175 | elif self.args.is15s: 176 | flag = utility.make_testChunk('15s', filename[0]) 177 | elif self.args.is30s: 178 | flag = utility.make_testChunk('30s', filename[0]) 179 | 180 | if self.args.use_cafm: 181 | sr = self.model(self.patch[flag](lr), idx_scale, flag) 182 | else: 183 | sr = self.model(self.patch(lr), idx_scale, flag) 184 | # print(sr.shape) 185 | 186 | 187 | 188 | # # low_hr = cv.imread(low_hr_path+ low_list[k]) 189 | 190 | # # low_hr = low_hr[:, :, ::-1] 191 | # # print(low_hr_path+ low_list[k]) 192 | # low_hr = cv.imread(low_hr_path + low_list[k],
flags=cv.IMREAD_COLOR) 193 | 194 | # low_hr = cv.cvtColor(low_hr,cv.COLOR_BGR2RGB) 195 | # print(low_hr_path+ low_list[k],filename) 196 | 197 | # k = k+1 198 | # device = hr.device 199 | # low_hr = torch.from_numpy(low_hr).to(device) 200 | # # print(low_hr[0,0,0]) 201 | # low_hr = low_hr.permute(2,0,1).unsqueeze(dim=0) 202 | 203 | # # print(low_hr.shape) 204 | 205 | 206 | 207 | sr = utility.quantize(sr, self.args.rgb_range) 208 | save_list = [sr] 209 | 210 | #utility.calc_lpips(sr, hr) 211 | #utility.calc_ssim(sr, hr) 212 | 213 | # save_list = [low_hr] 214 | 215 | self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr( 216 | sr, hr, scale, self.args.rgb_range, dataset=d 217 | ) 218 | if self.args.save_gt: 219 | save_list.extend([lr, hr]) 220 | 221 | if self.args.save_results: 222 | self.ckp.save_results(d, filename[0], save_list, scale) 223 | 224 | self.ckp.log[-1, idx_data, idx_scale] /= len(d) 225 | best = self.ckp.log.max(0) 226 | self.ckp.write_log( 227 | '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format( 228 | d.dataset.name, 229 | scale, 230 | self.ckp.log[-1, idx_data, idx_scale], 231 | best[0][idx_data, idx_scale], 232 | best[1][idx_data, idx_scale] + 1 233 | ) 234 | ) 235 | 236 | self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc())) 237 | self.ckp.write_log('Saving...') 238 | if not self.args.use_cafm: 239 | self.ckp.write_log(f"Mean of the patch prompt: {torch.mean(self.patch.patch)}, Std: {torch.std(self.patch.patch)}") 240 | else: 241 | for idx, patch in enumerate(self.patch): 242 | self.ckp.write_log(f"Mean of the patch prompt seg {idx}: {torch.mean(patch.patch)}, Std: {torch.std(patch.patch)}") 243 | 244 | 245 | if self.args.save_results: 246 | self.ckp.end_background() 247 | 248 | if not self.args.test_only: 249 | self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch)) 250 | #self.ckp.save_everyepoch(self, epoch, is_best=True) 251 | if best[1][0, 0] + 1 == epoch: 252 | for i, model in enumerate(self.patch): 253 | self.ckp.save_patch(model, epoch, is_best=True, idx=i) 254 | elif epoch % 40 == 0: 255 | for i, model in enumerate(self.patch): 256 | self.ckp.save_patch(model, epoch, is_best=True, idx=f"epo_{epoch}_{i}") 257 | 258 | 259 | self.ckp.write_log( 260 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True 261 | ) 262 | 263 | torch.set_grad_enabled(True) 264 | 265 | def prepare(self, *args): 266 | device = torch.device('cpu' if self.args.cpu else 'cuda') 267 | def _prepare(tensor): 268 | if self.args.precision == 'half': tensor = tensor.half() 269 | return tensor.to(device) 270 | 271 | return [_prepare(a) for a in args] 272 | 273 | def terminate(self): 274 | if self.args.test_only: 275 | self.test() 276 | return True 277 | else: 278 | epoch = self.optimizer.get_last_epoch() + 1 279 | return epoch >= self.args.epochs 280 | 281 | -------------------------------------------------------------------------------- /src/utility.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import datetime 5 | from multiprocessing import Process 6 | from multiprocessing import Queue 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt 11 | 12 | import numpy as np 13 | import imageio 14 | 15 | import torch 16 | import lpips 17 | import torch.optim as optim 18 | import torch.optim.lr_scheduler as lrs 19 | class timer(): 20 | def __init__(self): 21 | self.acc = 0 22 | self.tic() 23 | 24 | def tic(self): 25 | self.t0 = time.time() 26 | 27 | 
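# NOTE: trainer_cafm.py interleaves two of these timers: tic() marks a start point, hold() folds the elapsed span into self.acc, and release() returns the accumulated total and resets it, so data-loading time and model time can be measured in alternation.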
def toc(self, restart=False): 28 | diff = time.time() - self.t0 29 | if restart: self.t0 = time.time() 30 | return diff 31 | 32 | def hold(self): 33 | self.acc += self.toc() 34 | 35 | def release(self): 36 | ret = self.acc 37 | self.acc = 0 38 | 39 | return ret 40 | 41 | def reset(self): 42 | self.acc = 0 43 | 44 | class checkpoint(): 45 | def __init__(self, args): 46 | self.args = args 47 | self.ok = True 48 | self.log = torch.Tensor() 49 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') 50 | print("save:", args.save) 51 | print("load:", args.load) 52 | #exit() 53 | 54 | if not args.load: 55 | if not args.save: 56 | args.save = now 57 | self.dir = os.path.join('..', 'experiment', args.save) 58 | else: 59 | self.dir = os.path.join('..', 'experiment', args.load) 60 | if os.path.exists(self.dir): 61 | self.log = torch.load(self.get_path('psnr_log.pt')) 62 | print('Continue from epoch {}...'.format(len(self.log))) 63 | else: 64 | args.load = '' 65 | 66 | if args.reset: 67 | os.system('rm -rf ' + self.dir) 68 | args.load = '' 69 | 70 | os.makedirs(self.dir, exist_ok=True) 71 | os.makedirs(self.get_path('model'), exist_ok=True) 72 | for d in args.data_test: 73 | os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True) 74 | 75 | open_type = 'a' if os.path.exists(self.get_path('log.txt'))else 'w' 76 | self.log_file = open(self.get_path('log.txt'), open_type) 77 | with open(self.get_path('config.txt'), open_type) as f: 78 | f.write(now + '\n\n') 79 | for arg in vars(args): 80 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 81 | f.write('\n') 82 | 83 | self.n_processes = 8 84 | 85 | def get_path(self, *subdir): 86 | return os.path.join(self.dir, *subdir) 87 | 88 | def save(self, trainer, epoch, is_best=False): 89 | trainer.model.save(self.get_path('model'), epoch, is_best=is_best) 90 | trainer.loss.save(self.dir) 91 | trainer.loss.plot_loss(self.dir, epoch) 92 | 93 | self.plot_psnr(epoch) 94 | trainer.optimizer.save(self.dir) 95 | torch.save(self.log, self.get_path('psnr_log.pt')) 96 | 97 | def save_patch(self, patch, epoch, is_best=False, idx=0): 98 | path = self.get_path('model/patches/') 99 | if not os.path.exists(path): 100 | os.makedirs(path) 101 | torch.save(patch.state_dict(), self.get_path('model/patches/') + f"patch_{idx}.pt") 102 | #model.save(self.get_path('model/patches'), epoch, is_best=is_best, idx=idx) 103 | 104 | def save_everyepoch(self, trainer, epoch, is_best=False): 105 | save_path = self.get_path('model')+"/" + str(epoch) 106 | if not os.path.exists(save_path): 107 | os.makedirs(save_path) 108 | trainer.model.save_every(save_path, epoch, is_best=is_best) 109 | 110 | def add_log(self, log): 111 | self.log = torch.cat([self.log, log]) 112 | 113 | def write_log(self, log, refresh=False): 114 | print(log) 115 | self.log_file.write(str(log) + '\n') 116 | if refresh: 117 | self.log_file.close() 118 | self.log_file = open(self.get_path('log.txt'), 'a') 119 | 120 | def done(self): 121 | self.log_file.close() 122 | 123 | def plot_psnr(self, epoch): 124 | axis = np.linspace(1, epoch, epoch) 125 | for idx_data, d in enumerate(self.args.data_test): 126 | label = 'SR on {}'.format(d) 127 | fig = plt.figure() 128 | plt.title(label) 129 | for idx_scale, scale in enumerate(self.args.scale): 130 | plt.plot( 131 | axis, 132 | self.log[:, idx_data, idx_scale].numpy(), 133 | label='Scale {}'.format(scale) 134 | ) 135 | plt.legend() 136 | plt.xlabel('Epochs') 137 | plt.ylabel('PSNR') 138 | plt.grid(True) 139 | plt.savefig(self.get_path('test_{}.pdf'.format(d))) 140 | 
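# close the figure explicitly so per-epoch plotting does not accumulate matplotlib state over long runs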
plt.close(fig) 141 | 142 | def begin_background(self): 143 | self.queue = Queue() 144 | 145 | def bg_target(queue): 146 | while True: 147 | if not queue.empty(): 148 | filename, tensor = queue.get() 149 | if filename is None: break 150 | imageio.imwrite(filename, tensor.numpy()) 151 | 152 | self.process = [ 153 | Process(target=bg_target, args=(self.queue,)) \ 154 | for _ in range(self.n_processes) 155 | ] 156 | 157 | for p in self.process: p.start() 158 | 159 | def end_background(self): 160 | for _ in range(self.n_processes): self.queue.put((None, None)) 161 | while not self.queue.empty(): time.sleep(1) 162 | for p in self.process: p.join() 163 | 164 | def save_results(self, dataset, filename, save_list, scale): 165 | if self.args.save_results: 166 | filename = self.get_path( 167 | 'results-{}'.format(dataset.dataset.name), 168 | '{}_x{}_'.format(filename, scale) 169 | ) 170 | print(filename) 171 | 172 | postfix = ('SR', 'LR', 'HR') 173 | for v, p in zip(save_list, postfix): 174 | normalized = v[0].mul(255 / self.args.rgb_range) 175 | tensor_cpu = normalized.byte().permute(1, 2, 0).cpu() 176 | self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu)) 177 | 178 | def quantize(img, rgb_range): 179 | pixel_range = 255 / rgb_range 180 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) 181 | 182 | def calc_psnr(sr, hr, scale, rgb_range, dataset=None): 183 | if hr.nelement() == 1: return 0 184 | diff = (sr - hr) / rgb_range 185 | if dataset and dataset.dataset.benchmark: 186 | shave = scale 187 | if diff.size(1) > 1: 188 | gray_coeffs = [65.738, 129.057, 25.064] 189 | convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256 190 | diff = diff.mul(convert).sum(dim=1) 191 | else: 192 | shave = scale + 6 193 | 194 | valid = diff[..., shave:-shave, shave:-shave] 195 | mse = valid.pow(2).mean() 196 | if mse==0: 197 | return 1000 198 | return -10 * math.log10(mse) 199 | 200 | lpips_values = [] 201 | loss_fn = lpips.LPIPS(net='alex') 202 | def calc_lpips(sr, hr): 203 | global lpips_values 204 | sr.to(hr.device) 205 | 206 | for sr_img, hr_img in zip(sr, hr): 207 | lpips_value = loss_fn(sr_img.cpu(), hr_img.cpu()) 208 | lpips_values.append(lpips_value.item()) 209 | 210 | average_lpips = sum(lpips_values) / len(lpips_values) 211 | print("Average lpips:", average_lpips) 212 | return average_lpips 213 | 214 | import skimage.metrics 215 | ssim_values = [] 216 | def calc_ssim(sr, hr): 217 | global ssim_values 218 | for sr_img, hr_img in zip(sr, hr): 219 | sr_img_np = sr_img.permute(1, 2, 0).cpu().numpy() 220 | hr_img_np = hr_img.permute(1, 2, 0).cpu().numpy() 221 | ssim_value = skimage.metrics.structural_similarity(sr_img_np, hr_img_np, multichannel=True) 222 | ssim_values.append(ssim_value) 223 | 224 | average_ssim = sum(ssim_values) / len(ssim_values) 225 | print("Average SSIM:", average_ssim) 226 | return average_ssim 227 | 228 | def make_optimizer(args, target): 229 | ''' 230 | make optimizer and scheduler together 231 | ''' 232 | # optimizer 233 | if args.cafm: 234 | if args.finetune: 235 | trainable = [{'params':[ param for name, param in target.named_parameters() if 'transformer' in name or 'gamma' in name]}] 236 | else: 237 | trainable = filter(lambda x: x.requires_grad, target.parameters()) 238 | else: 239 | trainable = filter(lambda x: x.requires_grad, target.parameters()) 240 | 241 | kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay} 242 | 243 | if args.optimizer == 'SGD': 244 | optimizer_class = optim.SGD 245 | kwargs_optimizer['momentum'] = 
args.momentum 246 | elif args.optimizer == 'ADAM': 247 | optimizer_class = optim.Adam 248 | kwargs_optimizer['betas'] = args.betas 249 | kwargs_optimizer['eps'] = args.epsilon 250 | elif args.optimizer == 'RMSprop': 251 | optimizer_class = optim.RMSprop 252 | kwargs_optimizer['eps'] = args.epsilon 253 | 254 | # scheduler 255 | milestones = list(map(lambda x: int(x), args.decay.split('-'))) 256 | kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma} 257 | scheduler_class = lrs.MultiStepLR 258 | 259 | class CustomOptimizer(optimizer_class): 260 | def __init__(self, *args, **kwargs): 261 | super(CustomOptimizer, self).__init__(*args, **kwargs) 262 | 263 | def _register_scheduler(self, scheduler_class, **kwargs): 264 | self.scheduler = scheduler_class(self, **kwargs) 265 | 266 | def save(self, save_dir): 267 | torch.save(self.state_dict(), self.get_dir(save_dir)) 268 | 269 | def load(self, load_dir, epoch=1): 270 | self.load_state_dict(torch.load(self.get_dir(load_dir))) 271 | if epoch > 1: 272 | for _ in range(epoch): self.scheduler.step() 273 | 274 | def get_dir(self, dir_path): 275 | return os.path.join(dir_path, 'optimizer.pt') 276 | 277 | def schedule(self): 278 | self.scheduler.step() 279 | 280 | def get_lr(self): 281 | return self.scheduler.get_lr()[0] 282 | 283 | def get_last_epoch(self): 284 | return self.scheduler.last_epoch 285 | 286 | optimizer = CustomOptimizer(trainable, **kwargs_optimizer) 287 | optimizer._register_scheduler(scheduler_class, **kwargs_scheduler) 288 | return optimizer 289 | 290 | def make_patch_optimizer(args, target): 291 | ''' 292 | make optimizer and scheduler together 293 | ''' 294 | # optimizer 295 | if args.cafm: 296 | if args.finetune: 297 | trainable = [{'params':[ param for name, param in target.named_parameters() if 'transformer' in name or 'gamma' in name]}] 298 | else: 299 | trainable = filter(lambda x: x.requires_grad, target.parameters()) 300 | else: 301 | trainable = filter(lambda x: x.requires_grad, target.parameters()) 302 | 303 | kwargs_optimizer = {'lr': args.patch_lr, 'weight_decay': args.weight_decay} 304 | 305 | if args.optimizer == 'SGD': 306 | optimizer_class = optim.SGD 307 | kwargs_optimizer['momentum'] = args.momentum # option.py provides no separate patch momentum 308 | elif args.optimizer == 'ADAM': 309 | optimizer_class = optim.Adam 310 | kwargs_optimizer['betas'] = args.patch_betas 311 | kwargs_optimizer['eps'] = args.patch_epsilon 312 | elif args.optimizer == 'RMSprop': 313 | optimizer_class = optim.RMSprop 314 | kwargs_optimizer['eps'] = args.patch_epsilon 315 | 316 | # scheduler 317 | milestones = list(map(lambda x: int(x), args.decay.split('-'))) 318 | kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma} 319 | scheduler_class = lrs.MultiStepLR 320 | 321 | class CustomOptimizer(optimizer_class): 322 | def __init__(self, *args, **kwargs): 323 | super(CustomOptimizer, self).__init__(*args, **kwargs) 324 | 325 | def _register_scheduler(self, scheduler_class, **kwargs): 326 | self.scheduler = scheduler_class(self, **kwargs) 327 | 328 | def save(self, save_dir): 329 | torch.save(self.state_dict(), self.get_dir(save_dir)) 330 | 331 | def load(self, load_dir, epoch=1): 332 | self.load_state_dict(torch.load(self.get_dir(load_dir))) 333 | if epoch > 1: 334 | for _ in range(epoch): self.scheduler.step() 335 | 336 | def get_dir(self, dir_path): 337 | return os.path.join(dir_path, 'optimizer.pt') 338 | 339 | def schedule(self): 340 | self.scheduler.step() 341 | 342 | def get_lr(self): 343 | return self.scheduler.get_lr()[0] 344 | 345 | def
def make_trainChunk(num, length, segnum):
    '''assign each training frame index to one of segnum equal-length chunks'''
    chunk = int(length) // int(segnum)
    t_num = [[] for _ in range(segnum)]
    for i in range(segnum):
        s = i * chunk + 1
        e = (i + 1) * chunk
        t_num[i] = [j for j in range(len(num)) if s <= int(num[j]) <= e]
    return t_num


def make_testChunk(length, filename, segnum=3, fps=30):
    # hard-coded tables for the default 5s-per-chunk split; each range covers
    # the training frames of a chunk, then its held-out test frames, which are
    # numbered after all training frames
    if segnum == 3:
        if length == '45s':
            if int(filename) <= 150 or 1351 <= int(filename) <= 1365:
                flag = 0
            elif 151 <= int(filename) <= 300 or 1366 <= int(filename) <= 1380:
                flag = 1
            elif 301 <= int(filename) <= 450 or 1381 <= int(filename) <= 1395:
                flag = 2
            elif 451 <= int(filename) <= 600 or 1396 <= int(filename) <= 1410:
                flag = 3
            elif 601 <= int(filename) <= 750 or 1411 <= int(filename) <= 1425:
                flag = 4
            elif 751 <= int(filename) <= 900 or 1426 <= int(filename) <= 1440:
                flag = 5
            elif 901 <= int(filename) <= 1050 or 1441 <= int(filename) <= 1455:
                flag = 6
            elif 1051 <= int(filename) <= 1200 or 1456 <= int(filename) <= 1470:
                flag = 7
            elif 1201 <= int(filename) <= 1350 or 1471 <= int(filename) <= 1485:
                flag = 8
            else:
                flag = 9
            return flag
        elif length == '15s':
            if 1 <= int(filename) <= 150 or 451 <= int(filename) <= 465:
                flag = 0
            elif 151 <= int(filename) <= 300 or 466 <= int(filename) <= 480:
                flag = 1
            elif 301 <= int(filename) <= 450 or 481 <= int(filename) <= 495:
                flag = 2
            else:
                flag = 3
            return flag
        elif length == '30s':
            if int(filename) <= 150 or 901 <= int(filename) <= 915:
                flag = 0
            elif 151 <= int(filename) <= 300 or 916 <= int(filename) <= 930:
                flag = 1
            elif 301 <= int(filename) <= 450 or 931 <= int(filename) <= 945:
                flag = 2
            elif 451 <= int(filename) <= 600 or 946 <= int(filename) <= 960:
                flag = 3
            elif 601 <= int(filename) <= 750 or 961 <= int(filename) <= 975:
                flag = 4
            elif 751 <= int(filename) <= 900 or 976 <= int(filename) <= 990:
                flag = 5
            else:
                flag = 6
            return flag
    else:
        if length == '15s':
            seg_size = int(int(15 * fps) / segnum)
            test_seg_size = int(45 / segnum)
            flag = None
            for i in range(segnum):
                if i * seg_size <= int(filename) <= (i + 1) * seg_size or 450 + i * test_seg_size <= int(filename) <= 450 + (i + 1) * test_seg_size:
                    flag = i
                    break
            # compare against None explicitly: 'if not flag' would also
            # clobber a legitimate flag of 0
            if flag is None: flag = segnum
            return flag
        if length == '45s':
            seg_size = int(int(45 * fps) / segnum)
            test_seg_size = int(135 / segnum)
            flag = None
            for i in range(segnum):
                if i * seg_size <= int(filename) <= (i + 1) * seg_size or 1365 + i * test_seg_size <= int(filename) <= 1365 + (i + 1) * test_seg_size:
                    flag = i
                    break
            if flag is None: flag = segnum
            return flag
--------------------------------------------------------------------------------
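As a quick sanity check on the chunk helpers above, here is an editor's sketch of the mapping for a 15s clip at 30 fps; the frame numbering follows the VSD4K convention described in the README:

```
from utility import make_trainChunk, make_testChunk

# 450 training frames split into 3 chunks of 150 frames each
num = ['{:05d}'.format(i) for i in range(1, 451)]
chunks = make_trainChunk(num, length=450, segnum=3)
print([len(c) for c in chunks])        # [150, 150, 150]

# test frame 00466 of a 15s video belongs to the second chunk
print(make_testChunk('15s', '00466'))  # 1
```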
/src/videotester.py:
--------------------------------------------------------------------------------
import os
import math

import utility
from data import common

import torch
import cv2

from tqdm import tqdm

class VideoTester():
    def __init__(self, args, my_model, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.model = my_model

        self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))

    def test(self):
        torch.set_grad_enabled(False)

        self.ckp.write_log('\nEvaluation on video:')
        self.model.eval()

        timer_test = utility.timer()
        for idx_scale, scale in enumerate(self.scale):
            vidcap = cv2.VideoCapture(self.args.dir_demo)
            total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
            vidwri = cv2.VideoWriter(
                self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)),
                cv2.VideoWriter_fourcc(*'XVID'),  # type: ignore
                vidcap.get(cv2.CAP_PROP_FPS),
                (
                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                )
            )

            tqdm_test = tqdm(range(total_frames), ncols=80)
            for _ in tqdm_test:
                # cv2 yields frames as HxWxC uint8 arrays in BGR channel order
                success, lr = vidcap.read()
                if not success: break

                lr, = common.set_channel(lr, n_channels=self.args.n_colors)
                lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
                lr, = self.prepare(lr.unsqueeze(0))
                sr = self.model(lr, idx_scale)
                sr = utility.quantize(sr, self.args.rgb_range).squeeze(0)

                # map back to 8-bit HxWxC for the cv2 writer
                normalized = sr * 255 / self.args.rgb_range
                ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
                vidwri.write(ndarr)

            vidcap.release()
            vidwri.release()

        self.ckp.write_log(
            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )
        torch.set_grad_enabled(True)

    def prepare(self, *args):
        device = torch.device('cpu' if self.args.cpu else 'cuda')
        def _prepare(tensor):
            if self.args.precision == 'half': tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(a) for a in args]
--------------------------------------------------------------------------------
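For reference, an invocation sketch: assuming this fork keeps the upstream EDSR entry point, main.py constructs a VideoTester when `--data_test video` is passed. The paths and model choice below are placeholders:

```
cd src
python main.py --data_test video --model EDSR --scale 2 \
    --pre_train /dir/to/model.pt --dir_demo /dir/to/input.mp4 --test_only
```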