├── .gitignore ├── README.md ├── assets └── cover_adabm.png ├── environment.yaml └── src ├── data ├── __init__.py ├── benchmark.py ├── common.py ├── div2k.py ├── div2k_valid.py ├── srdata.py ├── test2k.py ├── test4k.py └── test8k.py ├── main.py ├── model ├── __init__.py ├── common.py ├── edsr.py ├── quantize.py ├── rdn.py └── srresnet.py ├── option.py ├── run.sh ├── trainer.py └── utility.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *.pyc 5 | 6 | /.vscode 7 | 8 | src/run_exp.sh 9 | experiment/ 10 | pretrained_model/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaBM 2 | This repository includes the official implementation of the paper [**AdaBM: On-the-Fly Adaptive Bit Mapping for Image Super-Resolution**](https://arxiv.org/abs/2404.03296) (CVPR2024). 3 | 4 | 5 | 6 | 7 | 8 | 11 | 12 |

13 | 14 |

15 | 16 | 17 | 18 | ## Requirements 19 | A suitable [conda](https://conda.io/) environment named `adabm` can be created and activated with: 20 | ``` 21 | conda env create -f environment.yaml 22 | conda activate adabm 23 | ``` 24 | 25 | ## Preparation 26 | ### Dataset 27 | * For training, we use LR images sampled from [DIV2K](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar). 28 | * For testing, we use [benchmark datasets](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) and the large-input datasets [Test2K,4K,8K](https://drive.google.com/drive/folders/18b3QKaDJdrd9y0KwtrWU2Vp9nHxvfTZH?usp=sharing). 29 | Test8K contains the images (index 1401-1500) from [DIV8K](https://competitions.codalab.org/competitions/22217#participate). Test2K/4K contain the images (index 1201-1300/1301-1400) from DIV8K, downsampled to 2K and 4K resolution, respectively. 30 | After downloading the datasets, the dataset directory should be organized as follows: 31 | 32 | ``` 33 | datasets 34 | -DIV2K 35 | - DIV2K_train_LR_bicubic # for training 36 | - DIV2K_train_HR 37 | - test2k # for testing 38 | - test4k 39 | - test8k 40 | -benchmark # for testing 41 | ``` 42 | 43 | ### Pretrained Models 44 | Please download the pretrained models from [here](https://drive.google.com/drive/folders/1GLuvwy3WWFG2H6iEA6-7tqRj_Jnzcn86?usp=drive_link) and place them in the `pretrained_model` directory. 45 | 46 | ## Usage 47 | 48 | ### How to train 49 | 50 | ``` 51 | sh run.sh edsr 0 6 8 # gpu_id a_bit w_bit 52 | sh run.sh edsr 0 4 4 # gpu_id a_bit w_bit 53 | ``` 54 | 55 | ### How to test 56 | 57 | ``` 58 | sh run.sh edsr_eval 0 6 8 # gpu_id a_bit w_bit 59 | sh run.sh edsr_eval 0 4 4 # gpu_id a_bit w_bit 60 | ``` 61 | 62 | > * set `--dir_data` to the dataset directory path. 63 | > * set `--pre_train` to the path of the saved model to be tested. 64 | > * the trained model is saved in the `experiment` directory. 65 | > * set `--test_own` to the path of your own image for testing. 66 | 67 | More running scripts can be found in `run.sh`. 68 | 69 | ## Comments 70 | Our implementation is based on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch). 71 | 72 | #### Coming Soon... 
73 | - [ ] parallel patch inference 74 | 75 | ## BibTeX 76 | If you found our implementation useful, please consider citing our paper: 77 | ``` 78 | @misc{hong2024adabm, 79 | title={AdaBM: On-the-Fly Adaptive Bit Mapping for Image Super-Resolution}, 80 | author={Cheeun Hong and Kyoung Mu Lee}, 81 | year={2024}, 82 | eprint={2404.03296}, 83 | archivePrefix={arXiv}, 84 | primaryClass={cs.CV} 85 | } 86 | ``` 87 | 88 | ## Contact 89 | > Email: [cheeun914@snu.ac.kr](cheeun914@snu.ac.kr) -------------------------------------------------------------------------------- /assets/cover_adabm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheeun/AdaBM/c855f5204d7e6f7c5652219aecef55ece4d36930/assets/cover_adabm.png -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: adabm 2 | channels: 3 | - pytorch 4 | - comet_ml 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - attrs=21.4.0=pyhd3eb1b0_0 10 | - blas=1.0=mkl 11 | - blinker=1.4=py_1 12 | - brotlipy=0.7.0=py36h27cfd23_1003 13 | - c-ares=1.18.1=h7f8727e_0 14 | - ca-certificates=2021.10.8=ha878542_0 15 | - certifi=2021.5.30=py36h06a4308_0 16 | - cffi=1.14.6=py36h400218f_0 17 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 18 | - comet-git-pure=0.19.16=py_0 19 | - comet_ml=3.2.0=py36 20 | - configobj=5.0.6=py36h06a4308_1 21 | - coverage=4.0.3=py36_1 22 | - cryptography=35.0.0=py36hd23ed53_0 23 | - cudatoolkit=10.1.243=h6bb024c_0 24 | - cycler=0.10.0=py36_0 25 | - dataclasses=0.8=pyh4f3eec9_6 26 | - dbus=1.13.14=hb2f20db_0 27 | - everett=1.0.2=py_1 28 | - expat=2.2.6=he6710b0_0 29 | - fontconfig=2.13.0=h9420a91_0 30 | - freetype=2.9.1=h8a8886c_1 31 | - fsspec=2022.1.0=pyhd8ed1ab_0 32 | - future=0.18.2=py36h5fab9bb_3 33 | - glib=2.63.1=h3eb4bd4_1 34 | - google-auth-oauthlib=0.4.1=py_1 35 | - gst-plugins-base=1.14.0=hbbd80ab_1 36 | - gstreamer=1.14.0=hb31296c_0 37 | - icu=58.2=he6710b0_3 38 | - imageio=2.8.0=py_0 39 | - importlib_metadata=4.8.1=hd3eb1b0_0 40 | - intel-openmp=2020.1=217 41 | - jpeg=9b=h024ee3a_2 42 | - jsonschema=3.2.0=pyhd3eb1b0_2 43 | - kiwisolver=1.2.0=py36hfd86e86_0 44 | - ld_impl_linux-64=2.33.1=h53a641e_7 45 | - libedit=3.1.20181209=hc058e9b_0 46 | - libffi=3.3=he6710b0_1 47 | - libgcc-ng=9.1.0=hdf63c60_0 48 | - libgfortran-ng=7.3.0=hdf63c60_0 49 | - libpng=1.6.37=hbc83047_0 50 | - libprotobuf=3.13.0.1=h8b12597_0 51 | - libstdcxx-ng=9.1.0=hdf63c60_0 52 | - libtiff=4.1.0=h2733197_0 53 | - libuuid=1.0.3=h1bed415_2 54 | - libuv=1.40.0=h7b6447c_0 55 | - libxcb=1.13=h1bed415_1 56 | - libxml2=2.9.9=hea5a465_1 57 | - matplotlib=3.1.3=py36_0 58 | - matplotlib-base=3.1.3=py36hef1b27d_0 59 | - mkl=2020.1=217 60 | - mkl-service=2.3.0=py36he904b0f_0 61 | - mkl_fft=1.0.15=py36ha843d7b_0 62 | - mkl_random=1.1.0=py36hd6b4f25_0 63 | - ncurses=6.2=he6710b0_1 64 | - netifaces=0.10.9=py36h27cfd23_1004 65 | - ninja=1.9.0=py36hfd86e86_0 66 | - nvidia-ml=7.352.0=pyhd3eb1b0_0 67 | - olefile=0.46=py36_0 68 | - onnx=1.8.0=py36h197aa4f_0 69 | - openssl=1.1.1m=h7f8727e_0 70 | - pcre=8.43=he6710b0_0 71 | - pillow=7.1.2=py36hb39fc2d_0 72 | - pip=20.0.2=py36_3 73 | - psutil=5.8.0=py36h27cfd23_1 74 | - pyasn1=0.4.8=py_0 75 | - pycparser=2.21=pyhd3eb1b0_0 76 | - pyjwt=2.3.0=pyhd8ed1ab_1 77 | - pyopenssl=22.0.0=pyhd3eb1b0_0 78 | - pyparsing=2.4.7=py_0 79 | - pyqt=5.9.2=py36h05f1152_2 80 | - pyrsistent=0.17.3=py36h7b6447c_0 81 | - 
pysocks=1.7.1=py36h06a4308_0 82 | - python=3.6.10=h7579374_2 83 | - python-dateutil=2.8.1=py_0 84 | - python_abi=3.6=1_cp36m 85 | - pytorch-lightning=1.5.9=pyhd8ed1ab_0 86 | - qt=5.9.7=h5867ecd_1 87 | - readline=8.0=h7b6447c_0 88 | - sip=4.19.8=py36hf484d3e_0 89 | - six=1.14.0=py36_0 90 | - sqlite=3.31.1=h62c20be_1 91 | - tk=8.6.8=hbc83047_0 92 | - torchaudio=0.7.0=py36 93 | - torchmetrics=0.7.1=pyhd8ed1ab_0 94 | - tornado=6.0.4=py36h7b6447c_1 95 | - tqdm=4.46.0=py_0 96 | - typing-extensions=3.7.4.3=0 97 | - typing_extensions=3.7.4.3=py_0 98 | - websocket-client=0.58.0=py36h06a4308_4 99 | - werkzeug=1.0.1=pyh9f0ad1d_0 100 | - wheel=0.34.2=py36_0 101 | - wrapt=1.12.1=py36h7b6447c_1 102 | - wurlitzer=1.0.3=py_2 103 | - xz=5.2.5=h7b6447c_0 104 | - yaml=0.2.5=h516909a_0 105 | - zlib=1.2.11=h7b6447c_3 106 | - zstd=1.3.7=h0b5b093_0 107 | - pip: 108 | - absl-py==0.10.0 109 | - addict==2.4.0 110 | - antlr4-python3-runtime==4.9.3 111 | - astunparse==1.6.3 112 | - backcall==0.2.0 113 | - basicsr==1.4.2 114 | - cachetools==4.1.1 115 | - chardet==3.0.4 116 | - coloredlogs==14.0 117 | - cupy-cuda102==8.4.0 118 | - decorator==4.4.2 119 | - fastrlock==0.5 120 | - filelock==3.4.1 121 | - gast==0.3.3 122 | - google-auth==1.21.2 123 | - google-pasta==0.2.0 124 | - grpcio==1.32.0 125 | - h5py==2.10.0 126 | - huggingface-hub==0.4.0 127 | - humanfriendly==9.0 128 | - idna==2.10 129 | - imagecodecs==2020.2.18 130 | - imagenet-stubs==0.0.7 131 | - importlib-metadata==1.7.0 132 | - install==1.3.4 133 | - ipython==7.16.3 134 | - ipython-genutils==0.2.0 135 | - jedi==0.17.2 136 | - joblib==1.1.0 137 | - jsonpatch==1.32 138 | - jsonpointer==2.3 139 | - keras-preprocessing==1.1.2 140 | - kornia==0.5.11 141 | - lmdb==1.2.1 142 | - lpips==0.1.4 143 | - markdown==3.2.2 144 | - networkx==2.4 145 | - numpy==1.18.1 146 | - oauthlib==3.1.0 147 | - onnxruntime==1.6.0 148 | - opencv-python==4.2.0.34 149 | - opt-einsum==3.3.0 150 | - ort-nightly==1.7.0.dev202102191 151 | - packaging==21.0 152 | - pandas==1.1.5 153 | - parso==0.7.1 154 | - pexpect==4.8.0 155 | - pickle5==0.0.12 156 | - pickleshare==0.7.5 157 | - progress==1.6 158 | - progressbar==2.5 159 | - prompt-toolkit==3.0.36 160 | - protobuf==3.13.0 161 | - ptyprocess==0.7.0 162 | - pyasn1-modules==0.2.8 163 | - pydeprecate==0.3.1 164 | - pydot==1.4.1 165 | - pygments==2.14.0 166 | - pyhocon==0.3.59 167 | - python-graphviz==0.14.2 168 | - pytorchcv==0.0.67 169 | - pytz==2022.2.1 170 | - pywavelets==1.1.1 171 | - pyyaml==5.4.1 172 | - pyzmq==23.2.1 173 | - requests==2.24.0 174 | - requests-oauthlib==1.3.0 175 | - rsa==4.6 176 | - scikit-image==0.17.2 177 | - scikit-learn==0.24.2 178 | - scipy==1.5.0 179 | - seaborn==0.11.2 180 | - setuptools==59.5.0 181 | - tb-nightly==2.11.0a20220816 182 | - tensorboard==2.3.0 183 | - tensorboard-data-server==0.6.1 184 | - tensorboard-plugin-wit==1.7.0 185 | - tensorboardx==2.1 186 | - tensorflow==2.3.0 187 | - tensorflow-estimator==2.3.0 188 | - termcolor==1.1.0 189 | - thop==0.0.31-2005241907 190 | - threadpoolctl==3.1.0 191 | - tifffile==2020.5.11 192 | - timm==0.6.12 193 | - torch==1.7.0 194 | - torch-dct==0.1.5 195 | - torchensemble==0.1.7 196 | - torchfile==0.1.0 197 | - torchsummary==1.5.1 198 | - torchvision==0.11.2 199 | - traitlets==4.3.3 200 | - urllib3==1.25.10 201 | - visdom==0.1.8.9 202 | - warmup-scheduler==0.3 203 | - wcwidth==0.2.6 204 | - yapf==0.32.0 205 | - zipp==3.1.0 206 | prefix: /home/cheeun914/anaconda3/envs/adabm 207 | -------------------------------------------------------------------------------- 
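Before the source listing, a minimal self-contained sketch (not part of the repository) of the asymmetric fake-quantization that `src/model/quantize.py` applies inside `QConv2d`: activations are clipped to a calibrated range `[lower_a, upper_a]`, mapped onto `2**b - 1` integer levels, rounded with a straight-through estimator, and mapped back. The function and tensor names below are illustrative only; in the repository the bounds are `nn.Parameter`s initialized from a few calibration images and the rounding is a custom `autograd.Function` with an identity backward.

```
import torch

def fake_quantize_activation(x, lower, upper, bits):
    # clip to the calibrated range (learnable parameters in QConv2d)
    x_c = torch.maximum(torch.minimum(x, upper), lower)
    # normalize to [0, 1] and spread over 2**bits - 1 integer levels
    x_n = (x_c - lower) / (upper - lower) * (2 ** bits - 1)
    # round; detach() mimics the straight-through estimator used by Round.apply
    x_int = x_n + (torch.round(x_n) - x_n).detach()
    # map the integer levels back to the original value range
    return x_int / (2 ** bits - 1) * (upper - lower) + lower

# e.g. quantize a feature map to 4 bits within the range [-1, 1]
x = torch.randn(1, 64, 24, 24)
x_q = fake_quantize_activation(x, torch.tensor(-1.0), torch.tensor(1.0), bits=4)
```

Weights are handled analogously but symmetrically, with a single clipping bound `upper_w` so that values are quantized over `[-upper_w, upper_w]`.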
/src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | from torch.utils.data import dataloader 3 | from torch.utils.data import ConcatDataset 4 | 5 | import random 6 | import torch.utils.data as data 7 | 8 | import numpy 9 | import torch 10 | 11 | class MyConcatDataset(ConcatDataset): 12 | def __init__(self, datasets): 13 | super(MyConcatDataset, self).__init__(datasets) 14 | self.train = datasets[0].train 15 | 16 | def set_scale(self, idx_scale): 17 | for d in self.datasets: 18 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale) 19 | 20 | class Data: 21 | def __init__(self, args): 22 | self.loader_train = None 23 | if not args.test_only: 24 | datasets = [] 25 | for d in args.data_train: 26 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 27 | m = import_module('data.' + module_name.lower()) 28 | datasets.append(getattr(m, module_name)(args, name=d)) 29 | 30 | train_dataset = MyConcatDataset(datasets) 31 | indices = random.sample(range(0, len(train_dataset)), args.num_data) 32 | # sampled over (train_dataset[0]=001.png~train_dataset[799]=800.png) 33 | 34 | print('Indices of {} sampled data...'.format(len(indices))) 35 | print(indices) 36 | 37 | train_dataset_sampled = data.Subset(train_dataset, indices) 38 | 39 | self.loader_train = dataloader.DataLoader( 40 | train_dataset_sampled, 41 | batch_size=args.batch_size_update, 42 | shuffle=True, 43 | pin_memory=not args.cpu, 44 | num_workers=args.n_threads, 45 | ) 46 | 47 | self.loader_test = [] 48 | for d in args.data_test: 49 | if d in ['Set5', 'Set14', 'B100', 'Urban100']: 50 | m = import_module('data.benchmark') 51 | testset = getattr(m, 'Benchmark')(args, train=False, name=d) 52 | else: 53 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 54 | m = import_module('data.' + module_name.lower()) 55 | testset = getattr(m, module_name)(args, train=False, name=d) 56 | 57 | self.loader_test.append( 58 | dataloader.DataLoader( 59 | testset, 60 | batch_size=1, 61 | shuffle=False, 62 | pin_memory=not args.cpu, 63 | num_workers=args.n_threads, 64 | ) 65 | ) 66 | 67 | self.loader_init = None 68 | if not args.test_only: 69 | datasets = [] 70 | for d in args.data_train: 71 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 72 | m = import_module('data.' 
+ module_name.lower()) 73 | datasets.append(getattr(m, module_name)(args, name=d)) 74 | 75 | self.loader_init = dataloader.DataLoader( 76 | train_dataset_sampled, 77 | batch_size=args.batch_size_calib, 78 | shuffle=True, 79 | pin_memory=not args.cpu, 80 | num_workers=args.n_threads, 81 | ) 82 | 83 | 84 | 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /src/data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Benchmark(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(Benchmark, self).__init__( 14 | args, name=name, train=train, benchmark=True 15 | ) 16 | def _set_filesystem(self, dir_data): 17 | self.apath = os.path.join(dir_data, 'benchmark', self.name) 18 | self.dir_hr = os.path.join(self.apath, 'HR') 19 | if self.input_large: 20 | self.dir_lr = os.path.join(self.apath, 'LR_bicubicL') 21 | else: 22 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 23 | self.ext = ('', '.png') 24 | 25 | 26 | -------------------------------------------------------------------------------- /src/data/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import skimage.color as sc 5 | 6 | def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): 7 | ih, iw = args[0].shape[:2] 8 | 9 | if not input_large: 10 | p = scale if multi else 1 11 | tp = p * patch_size 12 | ip = tp // scale 13 | else: 14 | tp = patch_size 15 | ip = patch_size 16 | 17 | ix = random.randrange(0, iw - ip + 1) 18 | iy = random.randrange(0, ih - ip + 1) 19 | # print(ix, iy) 20 | 21 | if not input_large: 22 | tx, ty = scale * ix, scale * iy 23 | else: 24 | tx, ty = ix, iy 25 | 26 | ret = [ 27 | args[0][iy:iy + ip, ix:ix + ip, :], 28 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] 29 | ] 30 | 31 | return ret 32 | 33 | def set_channel(*args, n_channels=3): 34 | def _set_channel(img): 35 | if img.ndim == 2: 36 | img = np.expand_dims(img, axis=2) 37 | 38 | c = img.shape[2] 39 | if n_channels == 1 and c == 3: 40 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 41 | elif n_channels == 3 and c == 1: 42 | img = np.concatenate([img] * n_channels, 2) 43 | 44 | return img 45 | 46 | return [_set_channel(a) for a in args] 47 | 48 | def np2Tensor(*args, rgb_range=255): 49 | def _np2Tensor(img): 50 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 51 | tensor = torch.from_numpy(np_transpose).float() 52 | tensor.mul_(rgb_range / 255) 53 | 54 | return tensor 55 | 56 | return [_np2Tensor(a) for a in args] 57 | 58 | def augment(*args, hflip=True, rot=True): 59 | hflip = hflip and random.random() < 0.5 60 | vflip = rot and random.random() < 0.5 61 | rot90 = rot and random.random() < 0.5 62 | 63 | def _augment(img): 64 | if hflip: img = img[:, ::-1, :] 65 | if vflip: img = img[::-1, :, :] 66 | if rot90: img = img.transpose(1, 0, 2) 67 | 68 | return img 69 | 70 | return [_augment(a) for a in args] -------------------------------------------------------------------------------- /src/data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DIV2K(srdata.SRData): 5 | def __init__(self, args, name='DIV2K', train=True, 
benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')] 7 | if train: 8 | data_range = data_range[0] 9 | else: 10 | if args.test_only and len(data_range) == 1: 11 | data_range = data_range[0] 12 | else: 13 | data_range = data_range[1] 14 | 15 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 16 | super(DIV2K, self).__init__( 17 | args, name=name, train=train, benchmark=benchmark 18 | ) 19 | 20 | def _scan(self): 21 | names_hr, names_lr = super(DIV2K, self)._scan() 22 | names_hr = names_hr[self.begin-1 : self.end] 23 | names_lr = [n[self.begin-1 : self.end] for n in names_lr] 24 | 25 | return names_hr, names_lr 26 | 27 | def _set_filesystem(self, dir_data): 28 | super(DIV2K, self)._set_filesystem(dir_data) 29 | # self.dir_hr = None 30 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 31 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') 32 | if self.input_large: self.dir_lr += 'L' 33 | 34 | -------------------------------------------------------------------------------- /src/data/div2k_valid.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class div2k_valid(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(div2k_valid, self).__init__( 14 | args, name=name, train=train, benchmark=True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'DIV2K' ) 19 | self.dir_hr = os.path.join(self.apath, 'DIV2K_valid_HR') 20 | self.dir_lr = os.path.join(self.apath, 'DIV2K_valid_LR_bicubic') 21 | self.ext = ('', '.png') 22 | -------------------------------------------------------------------------------- /src/data/srdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import pickle 5 | 6 | from data import common 7 | 8 | import numpy as np 9 | import imageio 10 | import torch 11 | import torch.utils.data as data 12 | 13 | class SRData(data.Dataset): 14 | def __init__(self, args, name='', train=True, benchmark=False): 15 | self.args = args 16 | self.name = name 17 | self.train = train 18 | self.split = 'train' if train else 'test' 19 | self.do_eval = True 20 | self.benchmark = benchmark 21 | self.input_large = False #(args.model == 'VDSR') 22 | self.scale = args.scale 23 | self.idx_scale = 0 24 | 25 | self._set_filesystem(args.dir_data) 26 | if args.ext.find('img') < 0: 27 | path_bin = os.path.join(self.apath, 'bin') 28 | os.makedirs(path_bin, exist_ok=True) 29 | 30 | list_hr, list_lr = self._scan() 31 | if args.ext.find('bin') >= 0: 32 | # Binary files are stored in 'bin' folder 33 | # If the binary file exists, load it. If not, make it. 
34 | list_hr, list_lr = self._scan() 35 | self.images_hr = self._check_and_load( 36 | args.ext, list_hr, self._name_hrbin() 37 | ) 38 | self.images_lr = [ 39 | self._check_and_load(args.ext, l, self._name_lrbin(s)) \ 40 | for s, l in zip(self.scale, list_lr) 41 | ] 42 | else: 43 | if args.ext.find('img') >= 0 or benchmark: 44 | self.images_hr, self.images_lr = list_hr, list_lr 45 | elif args.ext.find('sep') >= 0: 46 | os.makedirs( 47 | self.dir_hr.replace(self.apath, path_bin), 48 | exist_ok=True 49 | ) 50 | for s in self.scale: 51 | os.makedirs( 52 | os.path.join( 53 | self.dir_lr.replace(self.apath, path_bin), 54 | 'X{}'.format(s) 55 | ), 56 | exist_ok=True 57 | ) 58 | 59 | self.images_hr, self.images_lr = [], [[] for _ in self.scale] 60 | for h in list_hr: 61 | b = h.replace(self.apath, path_bin) 62 | b = b.replace(self.ext[0], '.pt') 63 | self.images_hr.append(b) 64 | self._check_and_load( 65 | args.ext, [h], b, verbose=True, load=False 66 | ) 67 | 68 | for i, ll in enumerate(list_lr): 69 | for l in ll: 70 | b = l.replace(self.apath, path_bin) 71 | b = b.replace(self.ext[1], '.pt') 72 | self.images_lr[i].append(b) 73 | self._check_and_load( 74 | args.ext, [l], b, verbose=True, load=False 75 | ) 76 | 77 | if train: 78 | n_patches = args.batch_size_update * args.test_every 79 | n_images = len(args.data_train) * len(self.images_hr) 80 | if args.num_data < 800: 81 | n_images = args.num_data 82 | 83 | if n_images == 0: 84 | self.repeat = 0 85 | else: 86 | self.repeat = max(n_patches // n_images, 1) 87 | # self.repeat == 1 for PTQ 88 | 89 | # Below functions as used to prepare images 90 | def _scan(self): 91 | names_hr = sorted( 92 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) 93 | ) 94 | names_lr = [[] for _ in self.scale] 95 | for f in names_hr: 96 | filename, _ = os.path.splitext(os.path.basename(f)) 97 | for si, s in enumerate(self.scale): 98 | if 'DIV2K' in self.dir_hr and 'test' not in self.dir_hr: 99 | names_lr[si].append(os.path.join( 100 | self.dir_lr, 'X{}/{}x{}{}'.format( 101 | s, filename, s, self.ext[1] 102 | ) 103 | )) 104 | # for test2k, test4k, test8k 105 | elif 'test' in self.dir_hr: 106 | names_lr[si].append(os.path.join( 107 | self.dir_lr, 'X{}/{}{}'.format( 108 | s, filename, self.ext[1] 109 | ) 110 | )) 111 | else: 112 | names_lr[si].append(os.path.join( 113 | self.dir_lr, 'X{}/{}{}'.format( 114 | s, filename, self.ext[1] 115 | ) 116 | )) 117 | return names_hr, names_lr 118 | 119 | def _set_filesystem(self, dir_data): 120 | self.apath = os.path.join(dir_data, self.name) 121 | self.dir_hr = os.path.join(self.apath, 'HR') 122 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 123 | if self.input_large: self.dir_lr += 'L' 124 | self.ext = ('.png', '.png') 125 | 126 | def _name_hrbin(self): 127 | return os.path.join( 128 | self.apath, 129 | 'bin', 130 | '{}_bin_HR.pt'.format(self.split) 131 | ) 132 | 133 | def _name_lrbin(self, scale): 134 | return os.path.join( 135 | self.apath, 136 | 'bin', 137 | '{}_bin_LR_X{}.pt'.format(self.split, scale) 138 | ) 139 | 140 | def _check_and_load(self, ext, l, f, verbose=True, load=True): 141 | if os.path.isfile(f) and ext.find('reset') < 0: 142 | if load: 143 | if verbose: print('Loading {}...'.format(f)) 144 | with open(f, 'rb') as _f: ret = pickle.load(_f) 145 | return ret 146 | else: 147 | return None 148 | else: 149 | if verbose: 150 | if ext.find('reset') >= 0: 151 | print('Making a new binary: {}'.format(f)) 152 | else: 153 | print('{} does not exist. 
Now making binary...'.format(f)) 154 | b = [{ 155 | 'name': os.path.splitext(os.path.basename(_l))[0], 156 | 'image': imageio.imread(_l) 157 | } for _l in l] 158 | with open(f, 'wb') as _f: pickle.dump(b, _f) 159 | 160 | return b 161 | 162 | def __getitem__(self, idx): 163 | lr, hr, filename = self._load_file(idx) 164 | pair = self.get_patch(lr, hr) 165 | pair = common.set_channel(*pair, n_channels=self.args.n_colors) 166 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) 167 | 168 | return pair_t[0], pair_t[1], filename 169 | 170 | def __len__(self): 171 | if self.train: 172 | return len(self.images_hr) * self.repeat 173 | else: 174 | return len(self.images_hr) 175 | 176 | def _get_index(self, idx): 177 | if self.train: 178 | return idx % len(self.images_hr) 179 | else: 180 | return idx 181 | 182 | def _load_file(self, idx): 183 | idx = self._get_index(idx) 184 | f_hr = self.images_hr[idx] 185 | f_lr = self.images_lr[self.idx_scale][idx] 186 | 187 | if self.args.ext.find('bin') >= 0: 188 | filename = f_hr['name'] 189 | hr = f_hr['image'] 190 | lr = f_lr['image'] 191 | else: 192 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 193 | if self.args.ext == 'img' or self.benchmark: 194 | hr = imageio.imread(f_hr) 195 | lr = imageio.imread(f_lr) 196 | elif self.args.ext.find('sep') >= 0: 197 | with open(f_hr, 'rb') as _f: hr = pickle.load(_f)[0]['image'] 198 | with open(f_lr, 'rb') as _f: lr = pickle.load(_f)[0]['image'] 199 | 200 | return lr, hr, filename 201 | 202 | def get_patch(self, lr, hr): 203 | scale = self.scale[self.idx_scale] 204 | if self.train: 205 | lr, hr = common.get_patch( 206 | lr, hr, 207 | patch_size=self.args.patch_size, 208 | scale=scale, 209 | multi=(len(self.scale) > 1), 210 | input_large=self.input_large 211 | ) 212 | if not self.args.no_augment: lr, hr = common.augment(lr, hr) 213 | else: 214 | ih, iw = lr.shape[:2] 215 | hr = hr[0:ih * scale, 0:iw * scale] 216 | 217 | return lr, hr 218 | 219 | def set_scale(self, idx_scale): 220 | if not self.input_large: 221 | self.idx_scale = idx_scale 222 | else: 223 | self.idx_scale = random.randint(0, len(self.scale) - 1) 224 | -------------------------------------------------------------------------------- /src/data/test2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class test2k(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(test2k, self).__init__( 14 | args, name=name, train=train, benchmark=True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'DIV2K', self.name) 19 | self.dir_hr = os.path.join(self.apath, 'HR') 20 | self.dir_lr = os.path.join(self.apath, 'LR') 21 | self.ext = ('', '.png') 22 | -------------------------------------------------------------------------------- /src/data/test4k.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class test4k(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(test4k, self).__init__( 14 | args, name=name, train=train, benchmark=True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = 
os.path.join(dir_data, 'DIV2K', self.name) 19 | self.dir_hr = os.path.join(self.apath, 'HR') 20 | self.dir_lr = os.path.join(self.apath, 'LR') 21 | self.ext = ('', '.png') 22 | -------------------------------------------------------------------------------- /src/data/test8k.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class test8k(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(test8k, self).__init__( 14 | args, name=name, train=train, benchmark=True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'DIV2K', self.name) 19 | self.dir_hr = os.path.join(self.apath, 'HR') 20 | self.dir_lr = os.path.join(self.apath, 'LR') 21 | self.ext = ('', '.png') 22 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import utility 4 | import data 5 | import model 6 | from option import args 7 | from trainer import Trainer 8 | 9 | import time 10 | import datetime 11 | 12 | import numpy 13 | import random 14 | 15 | torch.manual_seed(args.seed) 16 | torch.backends.cudnn.deterministic=True 17 | torch.backends.cudnn.benchmark=False 18 | numpy.random.seed(args.seed) 19 | random.seed(args.seed) 20 | torch.cuda.manual_seed(args.seed) 21 | torch.cuda.manual_seed_all(args.seed) 22 | checkpoint = utility.checkpoint(args) 23 | 24 | if checkpoint.ok: 25 | exp_start_time = time.time() 26 | _loader = data.Data(args) 27 | _model = model.Model(args, checkpoint) 28 | t = Trainer(args, _loader, _model, checkpoint) 29 | 30 | # t.test_teacher() 31 | while not t.terminate(): 32 | torch.manual_seed(args.seed) 33 | t.train() 34 | t.test() 35 | 36 | exp_end_time = time.time() 37 | exp_time_interval = exp_end_time - exp_start_time 38 | t_string = "Total Running Time is: " + str(datetime.timedelta(seconds=exp_time_interval)) + "\n" 39 | checkpoint.write_log('{}'.format(t_string)) 40 | checkpoint.done() -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel as P 7 | import torch.utils.model_zoo 8 | 9 | from model.quantize import * 10 | 11 | class Model(nn.Module): 12 | def __init__(self, args, ckp): 13 | super(Model, self).__init__() 14 | print('Making model...') 15 | self.args = args 16 | self.scale = args.scale 17 | self.idx_scale = 0 18 | self.input_large = (args.model == 'VDSR') 19 | self.self_ensemble = args.self_ensemble 20 | self.chop = args.chop 21 | self.precision = args.precision 22 | self.cpu = args.cpu 23 | self.device = torch.device('cpu' if args.cpu else 'cuda') 24 | self.n_GPUs = args.n_GPUs 25 | self.save_models = args.save_models 26 | 27 | module = import_module('model.' 
+ args.model.lower()) 28 | self.model = module.make_model(args).to(self.device) 29 | if args.precision == 'half': 30 | self.model.half() 31 | 32 | self.load( 33 | ckp.get_path('model'), 34 | pre_train=args.pre_train, 35 | resume=args.resume, 36 | cpu=args.cpu 37 | ) 38 | # print(self.model, file=ckp.log_file) 39 | 40 | def forward(self, x, idx_scale): 41 | self.idx_scale = idx_scale 42 | 43 | if hasattr(self.model, 'set_scale'): 44 | self.model.set_scale(idx_scale) 45 | 46 | if self.training: 47 | if self.n_GPUs > 1: 48 | return P.data_parallel(self.model, x, range(self.n_GPUs)) 49 | else: 50 | return self.model(x) 51 | else: 52 | if self.chop: 53 | forward_function = self.forward_chop 54 | else: 55 | forward_function = self.model.forward 56 | 57 | if self.self_ensemble: 58 | return self.forward_x8(x, forward_function=forward_function) 59 | else: 60 | return forward_function(x) 61 | 62 | def save(self, apath, epoch, is_best=False): 63 | save_dirs = [os.path.join(apath, 'checkpoint.pt')] 64 | 65 | if is_best: 66 | save_dirs.append(os.path.join(apath, 'checkpoint_bestpsnr.pt')) 67 | if self.save_models: 68 | save_dirs.append( 69 | os.path.join(apath, 'model_{}.pt'.format(epoch)) 70 | ) 71 | 72 | for s in save_dirs: 73 | torch.save(self.model.state_dict(), s) 74 | 75 | def load(self, apath, pre_train='', resume=-1, cpu=False): 76 | load_from = None 77 | kwargs = {} 78 | if cpu: 79 | kwargs = {'map_location': lambda storage, loc: storage} 80 | 81 | if resume == -1: 82 | load_from = torch.load( 83 | os.path.join(apath, 'model_latest.pt'), 84 | **kwargs 85 | ) 86 | elif resume == 0: 87 | if pre_train == 'download': 88 | print('Download the model') 89 | dir_model = os.path.join('..', 'models') 90 | os.makedirs(dir_model, exist_ok=True) 91 | load_from = torch.utils.model_zoo.load_url( 92 | self.model.url, 93 | model_dir=dir_model, 94 | **kwargs 95 | ) 96 | elif pre_train: 97 | print('Load the model from {}'.format(pre_train)) 98 | load_from = torch.load(pre_train, **kwargs) 99 | else: 100 | load_from = torch.load( 101 | os.path.join(apath, 'model_{}.pt'.format(resume)), 102 | **kwargs 103 | ) 104 | 105 | if load_from: 106 | load_from = load_from['state_dict'] if 'state_dict' in load_from else load_from 107 | self.model.load_state_dict(load_from, strict=False) 108 | # self.model.load_state_dict(load_from, strict=True) # for debugging 109 | 110 | def forward_chop(self, *args, shave=10, min_size=160000): 111 | scale = 1 if self.input_large else self.scale[self.idx_scale] 112 | n_GPUs = min(self.n_GPUs, 4) 113 | # height, width 114 | h, w = args[0].size()[-2:] 115 | 116 | top = slice(0, h//2 + shave) 117 | bottom = slice(h - h//2 - shave, h) 118 | left = slice(0, w//2 + shave) 119 | right = slice(w - w//2 - shave, w) 120 | x_chops = [torch.cat([ 121 | a[..., top, left], 122 | a[..., top, right], 123 | a[..., bottom, left], 124 | a[..., bottom, right] 125 | ]) for a in args] 126 | 127 | y_chops = [] 128 | if h * w < 4 * min_size: 129 | for i in range(0, 4, n_GPUs): 130 | x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops] 131 | y = P.data_parallel(self.model, *x, range(n_GPUs)) 132 | if not isinstance(y, list): y = [y] 133 | if not y_chops: 134 | y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y] 135 | else: 136 | for y_chop, _y in zip(y_chops, y): 137 | y_chop.extend(_y.chunk(n_GPUs, dim=0)) 138 | else: 139 | for p in zip(*x_chops): 140 | y = self.forward_chop(*p, shave=shave, min_size=min_size) 141 | if not isinstance(y, list): y = [y] 142 | if not y_chops: 143 | y_chops = [[_y] for _y in y] 
144 | else: 145 | for y_chop, _y in zip(y_chops, y): y_chop.append(_y) 146 | 147 | h *= scale 148 | w *= scale 149 | top = slice(0, h//2) 150 | bottom = slice(h - h//2, h) 151 | bottom_r = slice(h//2 - h, None) 152 | left = slice(0, w//2) 153 | right = slice(w - w//2, w) 154 | right_r = slice(w//2 - w, None) 155 | 156 | # batch size, number of color channels 157 | b, c = y_chops[0][0].size()[:-2] 158 | y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops] 159 | for y_chop, _y in zip(y_chops, y): 160 | _y[..., top, left] = y_chop[0][..., top, left] 161 | _y[..., top, right] = y_chop[1][..., top, right_r] 162 | _y[..., bottom, left] = y_chop[2][..., bottom_r, left] 163 | _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r] 164 | 165 | if len(y) == 1: y = y[0] 166 | 167 | return y 168 | 169 | def forward_x8(self, *args, forward_function=None): 170 | def _transform(v, op): 171 | if self.precision != 'single': v = v.float() 172 | 173 | v2np = v.data.cpu().numpy() 174 | if op == 'v': 175 | tfnp = v2np[:, :, :, ::-1].copy() 176 | elif op == 'h': 177 | tfnp = v2np[:, :, ::-1, :].copy() 178 | elif op == 't': 179 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 180 | 181 | ret = torch.Tensor(tfnp).to(self.device) 182 | if self.precision == 'half': ret = ret.half() 183 | 184 | return ret 185 | 186 | list_x = [] 187 | for a in args: 188 | x = [a] 189 | for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x]) 190 | 191 | list_x.append(x) 192 | 193 | list_y = [] 194 | for x in zip(*list_x): 195 | y = forward_function(*x) 196 | if not isinstance(y, list): y = [y] 197 | if not list_y: 198 | list_y = [[_y] for _y in y] 199 | else: 200 | for _list_y, _y in zip(list_y, y): _list_y.append(_y) 201 | 202 | for _list_y in list_y: 203 | for i in range(len(_list_y)): 204 | if i > 3: 205 | _list_y[i] = _transform(_list_y[i], 't') 206 | if i % 4 > 1: 207 | _list_y[i] = _transform(_list_y[i], 'h') 208 | if (i % 4) % 2 == 1: 209 | _list_y[i] = _transform(_list_y[i], 'v') 210 | 211 | y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y] 212 | if len(y) == 1: y = y[0] 213 | 214 | return y 215 | -------------------------------------------------------------------------------- /src/model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from model import quantize 8 | import numpy as np 9 | 10 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 11 | return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), bias=bias) 12 | 13 | class ShortCut(nn.Module): 14 | def __init__(self): 15 | super(ShortCut, self).__init__() 16 | 17 | def forward(self, input): 18 | return input 19 | 20 | class MeanShift(nn.Conv2d): 21 | def __init__( 22 | self, rgb_range, 23 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 24 | 25 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 26 | std = torch.Tensor(rgb_std) 27 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 28 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 29 | for p in self.parameters(): 30 | p.requires_grad = False 31 | 32 | class BasicBlock(nn.Sequential): 33 | def __init__( 34 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 35 | bn=True, act=nn.ReLU(True)): 36 | 37 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 38 | if bn: 39 | 
m.append(nn.BatchNorm2d(out_channels)) 40 | if act is not None: 41 | m.append(act) 42 | 43 | super(BasicBlock, self).__init__(*m) 44 | 45 | 46 | class ResBlock(nn.Module): 47 | def __init__(self, args, conv, n_feats, kernel_size, bias=True, bn=False, act='relu', res_scale=1): 48 | super(ResBlock, self).__init__() 49 | 50 | self.args = args 51 | 52 | self.conv1 = quantize.QConv2d(args, n_feats, n_feats, kernel_size, bias=bias, non_adaptive=False) 53 | self.conv2 = quantize.QConv2d(args, n_feats, n_feats, kernel_size, bias=bias, non_adaptive=False) 54 | 55 | if act == 'relu': 56 | self.act = nn.ReLU(True) 57 | elif act== 'prelu': 58 | self.act = nn.PReLU() 59 | 60 | self.res_scale = res_scale 61 | self.is_bn = bn 62 | if bn: 63 | self.bn1 = nn.BatchNorm2d(n_feats) 64 | self.bn2 = nn.BatchNorm2d(n_feats) 65 | self.shortcut = ShortCut() 66 | 67 | def forward(self, x): 68 | if self.args.imgwise: 69 | bit_img = x[3] 70 | bit = x[2] 71 | feat = x[1] 72 | x = x[0] 73 | 74 | residual = self.shortcut(x) 75 | 76 | if self.args.imgwise: 77 | out,bit = self.conv1([x, bit, bit_img]) 78 | else: 79 | out,bit = self.conv1([x, bit]) 80 | 81 | if self.is_bn: 82 | out = self.bn1(out) 83 | 84 | out1 = self.act(out) 85 | 86 | if self.args.imgwise: 87 | res, bit = self.conv2([out1, bit, bit_img]) 88 | else: 89 | res, bit = self.conv2([out1,bit]) 90 | 91 | if self.is_bn: 92 | res = self.bn2(res) 93 | 94 | res = res.mul(self.res_scale) 95 | res += residual 96 | 97 | if feat is None: 98 | feat = res / torch.norm(res, p=2) 99 | else: 100 | feat = torch.cat([feat, res / torch.norm(res, p=2)]) 101 | 102 | if self.args.imgwise: 103 | return [res, feat, bit, bit_img] 104 | else: 105 | return [res, feat, bit] 106 | 107 | 108 | class Upsampler(nn.Sequential): 109 | def __init__(self, args, conv, scale, n_feats, bn=False, act=False, bias=True, fq=False): 110 | m = [] 111 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
112 | for _ in range(int(math.log(scale, 2))): 113 | if fq: 114 | m.append(conv(args, n_feats, 4*n_feats, 3, bias=bias, non_adaptive=True, to_8bit=True)) 115 | else: 116 | m.append(conv(n_feats, 4 * n_feats, 3, bias=bias)) 117 | 118 | m.append(nn.PixelShuffle(2)) 119 | if bn: 120 | m.append(nn.BatchNorm2d(n_feats)) 121 | if act == 'relu': 122 | m.append(nn.ReLU(True)) 123 | elif act == 'prelu': 124 | m.append(nn.PReLU()) 125 | 126 | elif scale == 3: 127 | if fq: 128 | m.append(conv(args, n_feats, 9 * n_feats, 3, bias=bias, non_adaptive=True, to_8bit=True)) 129 | else: 130 | m.append(conv(n_feats, 9 * n_feats, 3, bias=bias)) 131 | 132 | m.append(nn.PixelShuffle(3)) 133 | if bn: 134 | m.append(nn.BatchNorm2d(n_feats)) 135 | if act == 'relu': 136 | m.append(nn.ReLU(True)) 137 | elif act == 'prelu': 138 | m.append(nn.PReLU()) 139 | else: 140 | raise NotImplementedError 141 | 142 | super(Upsampler, self).__init__(*m) -------------------------------------------------------------------------------- /src/model/edsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | from model import quantize 3 | 4 | import torch.nn as nn 5 | import torch 6 | 7 | import kornia as K 8 | import numpy as np 9 | 10 | import cv2 11 | import math 12 | 13 | def make_model(args, parent=False): 14 | return EDSR(args) 15 | 16 | class EDSR(nn.Module): 17 | def __init__(self, args, conv=common.default_conv): 18 | super(EDSR, self).__init__() 19 | n_feats = args.n_feats 20 | kernel_size = 3 21 | scale = args.scale[0] 22 | act = 'relu' 23 | 24 | self.sub_mean = common.MeanShift(args.rgb_range) 25 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 26 | self.fq = args.fq 27 | 28 | 29 | # Head module 30 | if args.fq: 31 | m_head = [quantize.QConv2d(args, args.n_colors, n_feats, kernel_size, bias=True, non_adaptive=True, to_8bit=True)] 32 | else: 33 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 34 | 35 | # Body module 36 | m_body = [ 37 | common.ResBlock( 38 | args, conv, n_feats, kernel_size, act=act, res_scale=args.res_scale 39 | ) for _ in range(args.n_resblocks) 40 | ] 41 | 42 | if args.fq: 43 | m_body.append(quantize.QConv2d(args, n_feats, n_feats, kernel_size, bias=True)) 44 | else: 45 | m_body.append(conv(n_feats, n_feats, kernel_size)) 46 | 47 | # Tail module 48 | if args.fq: 49 | m_tail = [ 50 | common.Upsampler(args, quantize.QConv2d, scale, n_feats, act=False, fq=args.fq), 51 | quantize.QConv2d(args, n_feats, args.n_colors, kernel_size, bias=True, non_adaptive=True, to_8bit=True) 52 | ] 53 | else: 54 | m_tail = [ 55 | common.Upsampler(args, conv, scale, n_feats, act=False), 56 | conv(n_feats, args.n_colors, kernel_size) 57 | ] 58 | 59 | self.head = nn.Sequential(*m_head) 60 | self.body = nn.Sequential(*m_body) 61 | self.tail = nn.Sequential(*m_tail) 62 | 63 | if args.imgwise: 64 | self.measure_l = nn.Parameter(torch.FloatTensor([128]).cuda()) 65 | self.measure_u = nn.Parameter(torch.FloatTensor([128]).cuda()) 66 | self.tanh = nn.Tanh() 67 | self.ema_epoch = 1 68 | self.init = False 69 | 70 | self.args = args 71 | 72 | 73 | def forward(self, x): 74 | if self.args.imgwise: 75 | image = x 76 | grads: torch.Tensor = K.filters.spatial_gradient(K.color.rgb_to_grayscale(image/255.), order=1) 77 | image_grad = torch.mean(torch.abs(grads.squeeze(1)),(1,2,3)) *1e+3 78 | 79 | if self.init: 80 | # print(image_grad) 81 | if self.ema_epoch == 1: 82 | measure_l = torch.quantile(image_grad.detach(), self.args.img_percentile/100.0) 83 | measure_u = 
torch.quantile(image_grad.detach(), 1-self.args.img_percentile/100.0) 84 | nn.init.constant_(self.measure_l, measure_l) 85 | nn.init.constant_(self.measure_u, measure_u) 86 | else: 87 | beta = self.args.ema_beta 88 | new_measure_l = self.measure_l * beta + torch.quantile(image_grad.detach(), self.args.img_percentile/100.0) * (1-beta) 89 | new_measure_u = self.measure_u * beta + torch.quantile(image_grad.detach(), 1-self.args.img_percentile/100.0) * (1-beta) 90 | nn.init.constant_(self.measure_l, new_measure_l.item()) 91 | nn.init.constant_(self.measure_u, new_measure_u.item()) 92 | 93 | self.ema_epoch += 1 94 | bit_img = torch.Tensor([0.0]).cuda() 95 | 96 | else: 97 | bit_img_soft = (image_grad - (self.measure_u + self.measure_l)/2) * (2/(self.measure_u - self.measure_l)) 98 | bit_img_soft = self.tanh(bit_img_soft) 99 | bit_img_hard = (image_grad < self.measure_l) * (-1.0) + (image_grad >= self.measure_l) * (image_grad <= self.measure_u) * (0.0) + (image_grad> self.measure_u) *(1.0) 100 | bit_img = bit_img_soft - bit_img_soft.detach() + bit_img_hard.detach() # the order matters 101 | bit_img = bit_img.view(bit_img.shape[0], 1, 1, 1) 102 | 103 | x = self.sub_mean(x) 104 | 105 | if self.args.fq: 106 | bit_fq = torch.zeros(x.shape[0]).cuda() 107 | x, bit_fq= self.head([x, bit_fq]) 108 | 109 | else: 110 | x = self.head(x) 111 | 112 | feat = None 113 | bit = torch.zeros(x.shape[0]).cuda() 114 | 115 | res = x 116 | 117 | if self.args.imgwise: 118 | res, feat, bit, bit_img = self.body[:-1]([res, feat, bit, bit_img]) 119 | else: 120 | res, feat, bit = self.body[:-1]([res, feat, bit]) 121 | 122 | 123 | if self.args.fq: 124 | if self.args.imgwise: 125 | res, bit = self.body[-1:]([res, bit, bit_img]) 126 | else: 127 | res, bit = self.body[-1:]([res, bit]) 128 | else: 129 | res = self.body[-1:](res) 130 | 131 | res += x 132 | 133 | if self.args.fq: 134 | res, bit_fq = self.tail[0][0]([res, bit_fq]) # conv 135 | res = self.tail[0][1](res) # ps 136 | if len(self.tail[0]) == 4: 137 | res, bit_fq = self.tail[0][2]([res, bit_fq]) # conv 138 | res = self.tail[0][3](res) # ps 139 | x, bit_fq = self.tail[-1]([res, bit_fq]) # conv 140 | else: 141 | x = self.tail(res) 142 | 143 | x = self.add_mean(x) 144 | 145 | 146 | return x, feat, bit 147 | 148 | 149 | -------------------------------------------------------------------------------- /src/model/quantize.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import time 10 | 11 | class Round(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, x_in): 14 | x_out = torch.round(x_in) 15 | return x_out 16 | @staticmethod 17 | def backward(ctx, g): 18 | return g, None 19 | 20 | class QConv2d(nn.Conv2d): 21 | def __init__(self, args, in_channels, out_channels, kernel_size, stride=1, 22 | padding=1, bias=False, dilation=1, groups=1, non_adaptive=False, to_8bit=False): 23 | super(QConv2d, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias, dilation=dilation, groups=groups) 24 | self.args = args 25 | if not bias: self.bias = None 26 | self.dilation = (dilation, dilation) 27 | 28 | # For quantizing activations 29 | self.lower_a = nn.Parameter(torch.FloatTensor([-128]).cuda()) 30 | self.upper_a = nn.Parameter(torch.FloatTensor([128]).cuda()) 31 | self.round_a = Round.apply 32 | 
self.a_bit = self.args.quantize_a 33 | 34 | # For quantizing weights 35 | self.upper_w = nn.Parameter(torch.FloatTensor([128]).cuda()) 36 | self.round_w = Round.apply 37 | self.w_bit = self.args.quantize_w 38 | 39 | self.non_adaptive = non_adaptive 40 | self.to_8bit = to_8bit 41 | 42 | if self.to_8bit: 43 | self.w_bit = 8.0 44 | self.a_bit = 8.0 45 | 46 | if not self.non_adaptive and self.args.layerwise: 47 | self.std_layer = [] 48 | self.measure_layer = nn.Parameter(torch.FloatTensor([0]).cuda()) 49 | self.tanh = nn.Tanh() 50 | 51 | self.ema_epoch = 1 52 | self.bac_epoch = 1 53 | self.init = False 54 | 55 | def init_qparams_a(self, x, quantizer=None): 56 | # Obtain statistics 57 | if quantizer == 'minmax': 58 | lower_a = torch.min(x).detach().cpu() 59 | upper_a = torch.max(x).detach().cpu() 60 | 61 | elif quantizer == 'percentile': 62 | try: 63 | lower_a = torch.quantile(x.reshape(-1), 1.0-self.args.percentile_alpha).detach().cpu() 64 | upper_a = torch.quantile(x.reshape(-1), self.args.percentile_alpha).detach().cpu() 65 | except: 66 | lower_a = np.percentile(x.reshape(-1).detach().cpu(), (1.0-self.args.percentile_alpha)*100.0) 67 | upper_a = np.percentile(x.reshape(-1).detach().cpu(), self.args.percentile_alpha*100.0) 68 | 69 | elif quantizer == 'omse': 70 | lower_a = torch.min(x).detach() 71 | upper_a = torch.max(x).detach() 72 | best_score = 1e+10 73 | for i in range(90): 74 | new_lower = lower_a * (1.0 - i*0.01) 75 | new_upper = upper_a * (1.0 - i*0.01) 76 | x_q = torch.clamp(x, min= new_lower, max=new_upper) 77 | x_q = (x_q - new_lower) / (new_upper - new_lower) 78 | x_q = torch.round(x_q * (2**self.args.quantize_a -1)) / (2**self.args.quantize_a -1) 79 | x_q = x_q * (new_upper - new_lower) + new_lower 80 | score = (x - x_q).abs().pow(2.0).mean() 81 | if score < best_score: 82 | best_score = score 83 | best_lower = new_lower 84 | best_upper = new_upper 85 | lower = best_lower.cpu() 86 | upper = best_upper.cpu() 87 | 88 | # Update q params 89 | if self.ema_epoch == 1: 90 | nn.init.constant_(self.lower_a, lower_a) 91 | nn.init.constant_(self.upper_a, upper_a) 92 | else: 93 | beta = self.args.ema_beta 94 | lower_a = lower_a * (1-beta) + self.lower_a * beta 95 | upper_a = upper_a * (1-beta) + self.upper_a * beta 96 | nn.init.constant_(self.lower_a, lower_a.item()) 97 | nn.init.constant_(self.upper_a, upper_a.item()) 98 | 99 | 100 | def init_qparams_w(self, w, quantizer=None): 101 | if quantizer == 'minmax': 102 | upper_w = torch.max(torch.abs(self.weight)).detach() 103 | elif quantizer == 'percentile': 104 | try: 105 | upper_w = torch.quantile(self.weight.reshape(-1), self.args.percentile_alpha).detach().cpu() 106 | except: 107 | upper_w = np.percentile(self.weight.reshape(-1).detach().cpu(), self.args.percentile_alpha*100.0) 108 | elif quantizer == 'omse': 109 | upper_w = torch.max(self.weight).detach() 110 | best_score_w = 1e+10 111 | for i in range(50): 112 | new_upper_w = upper_w * (1.0 - i*0.01) 113 | w_q = torch.clamp(self.weight, min=-new_upper_w, max=new_upper_w).detach() 114 | w_q = (w_q + new_upper_w) / (2*new_upper_w ) 115 | w_q = torch.round(w_q * (2**self.args.quantize_w -1)) / (2**self.args.quantize_w -1) 116 | w_q = w_q * (2*new_upper_w) - new_upper_w 117 | score = (self.weight - w_q).abs().pow(2.0).mean().detach().cpu() 118 | if score < best_score_w: 119 | best_score_w = score 120 | best_i = i 121 | upper = new_upper_w 122 | upper_w = upper.cpu() 123 | 124 | nn.init.constant_(self.upper_w, upper_w) 125 | 126 | def forward(self,x): 127 | if self.args.imgwise and not 
self.non_adaptive: 128 | bit_img = x[2] 129 | bit = x[1] 130 | x = x[0] 131 | 132 | if self.w_bit == 32: 133 | if self.init: 134 | # Initialize q params 135 | self.init_qparams_a(x, quantizer=self.args.quantizer) 136 | if self.ema_epoch == 1: 137 | self.init_qparams_w(self.weight, quantizer=self.args.quantizer_w) 138 | 139 | if not self.non_adaptive and self.args.layerwise: 140 | measure = lambda x: torch.mean(torch.std(x.detach(), dim=(1,2,3,))) 141 | measure_layer = measure(x) 142 | self.std_layer.append(measure_layer.detach().cpu().numpy()) 143 | 144 | self.ema_epoch += 1 145 | 146 | a_bit = torch.Tensor([32.0]).cuda() 147 | w = self.weight 148 | 149 | else: 150 | # Obtain bit-width 151 | if not self.non_adaptive and (self.args.imgwise or self.args.layerwise): 152 | a_bit = self.a_bit 153 | if self.args.imgwise: 154 | a_bit += bit_img 155 | if self.args.layerwise: 156 | bit_layer_hard = torch.round(torch.clamp(self.measure_layer, min=-1.0, max=1.0)) 157 | bit_layer_soft = self.tanh(self.measure_layer) 158 | bit_layer = bit_layer_soft - bit_layer_soft.detach() + bit_layer_hard.detach() 159 | a_bit += bit_layer 160 | else: 161 | a_bit = torch.tensor(self.a_bit).repeat(x.shape[0], 1, 1, 1).cuda() 162 | 163 | # Bit-aware Clipping 164 | if self.args.bac: 165 | do_bac = self.bac_epoch == 1 166 | # Do BaC after init phase ends 167 | if do_bac: 168 | self.bac_epoch += 1 169 | if self.training and not self.to_8bit: 170 | best_score = 1e+10 171 | lower_a = self.lower_a 172 | upper_a = self.upper_a 173 | 174 | for i in range(100): 175 | new_lower_a = self.lower_a * (1.0 - i*0.01) 176 | new_upper_a = self.upper_a * (1.0 - i*0.01) 177 | x_q_temp = torch.clamp(x.clone().detach(), min= new_lower_a, max=new_upper_a) 178 | x_q_temp = (x_q_temp - new_lower_a) / (new_upper_a - new_lower_a) 179 | if not self.non_adaptive and self.args.layerwise: 180 | x_q_temp = torch.round(x_q_temp * (2**(self.a_bit+bit_layer_hard) -1)) / (2**(self.a_bit+bit_layer_hard) -1) 181 | else: 182 | x_q_temp = torch.round(x_q_temp * (2**(self.a_bit) -1)) / (2**(self.a_bit) -1) 183 | 184 | x_q_temp = x_q_temp * (new_upper_a - new_lower_a) + new_lower_a 185 | score = (x.clone().detach() - x_q_temp).abs().pow(2.0).mean() 186 | if score < best_score: 187 | best_i = i 188 | best_score = score 189 | lower_a = new_lower_a 190 | upper_a = new_upper_a 191 | 192 | new_lower = self.lower_a * self.args.bac_beta + lower_a * (1-self.args.bac_beta) 193 | new_upper = self.upper_a * self.args.bac_beta + upper_a * (1-self.args.bac_beta) 194 | nn.init.constant_(self.lower_a, new_lower.item()) 195 | nn.init.constant_(self.upper_a, new_upper.item()) 196 | 197 | # Quantize activations 198 | x_c = torch.clamp(x, min=self.lower_a, max=self.upper_a) 199 | x_c2 = (x_c - self.lower_a) / (self.upper_a - self.lower_a) 200 | x_c2 = x_c2 * (2**a_bit -1) 201 | x_int = self.round_a(x_c2) 202 | x_int = x_int / (2**a_bit -1) 203 | x_q = x_int * (self.upper_a - self.lower_a) + self.lower_a 204 | x = x_q 205 | 206 | # Quantize weights 207 | w_c = torch.clamp(self.weight, min=-self.upper_w, max=self.upper_w) 208 | w_c2 = (w_c + self.upper_w) / (2*self.upper_w ) 209 | w_c2 = (w_c2) * (2**self.w_bit-1) 210 | w_int = self.round_w(w_c2) 211 | w_int = w_int / (2**self.w_bit-1) 212 | w_q = w_int * (2*self.upper_w) - self.upper_w 213 | w = w_q 214 | 215 | self.padding = (self.kernel_size[0]//2, self.kernel_size[1]//2) 216 | out = F.conv2d(x, w, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) 217 | bit += a_bit.view(-1) 
218 | 219 | return out, bit -------------------------------------------------------------------------------- /src/model/rdn.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | from model import quantize 3 | import torch 4 | import torch.nn as nn 5 | import kornia as K 6 | 7 | def make_model(args, parent=False): 8 | return RDN(args) 9 | 10 | class RDB_Conv(nn.Module): 11 | def __init__(self, args, inChannels, growRate, kSize=3): 12 | super(RDB_Conv, self).__init__() 13 | Cin = inChannels 14 | G = growRate 15 | bias = True 16 | self.args = args 17 | self.conv = nn.Sequential(*[ 18 | quantize.QConv2d(args, Cin, G, kSize, stride=1, padding=(kSize-1)//2, bias=True, dilation=1, groups=1, non_adaptive=False), 19 | nn.ReLU() 20 | ]) 21 | 22 | def forward(self, x): 23 | if self.args.imgwise: 24 | bit_img = x[2] 25 | bit = x[1] 26 | x = x[0] 27 | out = x 28 | 29 | if self.args.imgwise: 30 | out, bit = self.conv[0]([out, bit, bit_img]) 31 | else: 32 | out, bit = self.conv[0]([out, bit]) 33 | 34 | out = self.conv[1](out) 35 | out = torch.cat((x, out), 1) 36 | 37 | if self.args.imgwise: 38 | return [out, bit, bit_img] 39 | else: 40 | return [out, bit] 41 | 42 | 43 | class RDB(nn.Module): 44 | def __init__(self, args, growRate0, growRate, nConvLayers, kSize=3): 45 | super(RDB, self).__init__() 46 | G0 = growRate0 47 | G = growRate 48 | C = nConvLayers 49 | 50 | convs = [] 51 | for c in range(C): 52 | convs.append(RDB_Conv(args, G0 + c*G, G, kSize)) 53 | self.convs = nn.Sequential(*convs) 54 | 55 | # Local Feature Fusion 56 | self.LFF = quantize.QConv2d(args, G0 + C*G, G0, 1, padding=0, stride=1, bias=True, non_adaptive=True, to_8bit=False) # 1x1 conv is non-adaptively quantized 57 | self.args = args 58 | 59 | def forward(self, x): 60 | if self.args.imgwise: 61 | bit_img = x[3] 62 | bit = x[2] 63 | feat = x[1] 64 | x = x[0] 65 | 66 | out = x 67 | if self.args.imgwise: 68 | out, bit, bit_img = self.convs([out, bit, bit_img]) 69 | else: 70 | out, bit = self.convs([out, bit]) 71 | 72 | out, bit = self.LFF([out, bit]) 73 | out += x 74 | 75 | feat_ = out / torch.norm(out, p=2) / (out.shape[1]*out.shape[2]*out.shape[3]) 76 | if feat is None: 77 | feat = feat_ 78 | else: 79 | feat = torch.cat([feat, feat_]) 80 | 81 | if self.args.imgwise: 82 | return out, feat, bit, bit_img 83 | else: 84 | return out, feat, bit 85 | 86 | class RDN(nn.Module): 87 | def __init__(self, args): 88 | super(RDN, self).__init__() 89 | r = args.scale[0] 90 | G0 = args.G0 91 | kSize = args.RDNkSize 92 | 93 | self.args = args 94 | 95 | # number of RDB blocks, conv layers, out channels 96 | self.D, self.C, self.G = { 97 | 'A': (20, 6, 32), 98 | 'B': (16, 8, 64), 99 | }[args.RDNconfig] 100 | 101 | # Shallow feature extraction net 102 | if args.fq: 103 | self.SFENet1 = quantize.QConv2d(args, args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1, bias=True, non_adaptive=True, to_8bit=True) 104 | self.SFENet2 = quantize.QConv2d(args, G0, G0, kSize, padding=(kSize-1)//2, stride=1, bias=True, non_adaptive=True, to_8bit=True) 105 | else: 106 | self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1) 107 | self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 108 | 109 | # Redidual dense blocks and dense feature fusion 110 | self.RDBs = nn.ModuleList() 111 | for i in range(self.D): 112 | self.RDBs.append( 113 | RDB(args, growRate0 = G0, growRate = self.G, nConvLayers = self.C) 114 | ) 115 | 116 | # Global Feature Fusion 117 | self.GFF = 
nn.Sequential(*[ 118 | quantize.QConv2d(args, self.D * G0, G0, 1, padding=0, stride=1, bias=True), 119 | quantize.QConv2d(args, G0, G0, kSize, padding=(kSize-1)//2, stride=1, bias=True) 120 | ]) 121 | 122 | # Up-sampling net 123 | if args.fq: 124 | if r == 2 or r == 3: 125 | self.UPNet = nn.Sequential(*[ 126 | quantize.QConv2d(args, G0, self.G * r * r, kSize, padding=(kSize-1)//2, stride=1, bias=True, non_adaptive=True, to_8bit=True), 127 | nn.PixelShuffle(r), 128 | quantize.QConv2d(args, self.G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1, bias=True, non_adaptive=True, to_8bit=True) 129 | ]) 130 | elif r == 4: 131 | self.UPNet = nn.Sequential(*[ 132 | quantize.QConv2d(args, G0, self.G * 4, kSize, padding=(kSize-1)//2, stride=1, bias=True, non_adaptive=True, to_8bit=True), 133 | nn.PixelShuffle(2), 134 | quantize.QConv2d(args, self.G, self.G * 4, kSize, padding=(kSize-1)//2, stride=1, bias=True, non_adaptive=True, to_8bit=True), 135 | nn.PixelShuffle(2), 136 | quantize.QConv2d(args, self.G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1, bias=True, non_adaptive=True, to_8bit=True) 137 | ]) 138 | else: 139 | raise ValueError("scale must be 2 or 3 or 4.") 140 | 141 | else: 142 | if r == 2 or r == 3: 143 | self.UPNet = nn.Sequential(*[ 144 | nn.Conv2d(G0, self.G * r * r, kSize, padding=(kSize-1)//2, stride=1), 145 | nn.PixelShuffle(r), 146 | nn.Conv2d(self.G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 147 | ]) 148 | elif r == 4: 149 | self.UPNet = nn.Sequential(*[ 150 | nn.Conv2d(G0, self.G * 4, kSize, padding=(kSize-1)//2, stride=1), 151 | nn.PixelShuffle(2), 152 | nn.Conv2d(self.G, self.G * 4, kSize, padding=(kSize-1)//2, stride=1), 153 | nn.PixelShuffle(2), 154 | nn.Conv2d(self.G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 155 | ]) 156 | else: 157 | raise ValueError("scale must be 2 or 3 or 4.") 158 | 159 | if args.imgwise: 160 | self.measure_l = nn.Parameter(torch.FloatTensor([128]).cuda()) 161 | self.measure_u = nn.Parameter(torch.FloatTensor([128]).cuda()) 162 | self.tanh = nn.Tanh() 163 | self.ema_epoch = 1 164 | self.init = False 165 | 166 | def forward(self, x): 167 | if self.args.imgwise: 168 | image = x 169 | grads: torch.Tensor = K.filters.spatial_gradient(K.color.rgb_to_grayscale(image/255.), order=1) 170 | image_grad = torch.mean(torch.abs(grads.squeeze(1)),(1,2,3)) *1e+3 171 | 172 | if self.init: 173 | # print(image_grad) 174 | if self.ema_epoch == 1: 175 | measure_l = torch.quantile(image_grad.detach(), self.args.img_percentile/100.0) 176 | measure_u = torch.quantile(image_grad.detach(), 1-self.args.img_percentile/100.0) 177 | nn.init.constant_(self.measure_l, measure_l) 178 | nn.init.constant_(self.measure_u, measure_u) 179 | else: 180 | beta = self.args.ema_beta 181 | new_measure_l = self.measure_l * beta + torch.quantile(image_grad.detach(), self.args.img_percentile/100.0) * (1-beta) 182 | new_measure_u = self.measure_u * beta + torch.quantile(image_grad.detach(), 1-self.args.img_percentile/100.0) * (1-beta) 183 | nn.init.constant_(self.measure_l, new_measure_l.item()) 184 | nn.init.constant_(self.measure_u, new_measure_u.item()) 185 | 186 | self.ema_epoch += 1 187 | bit_img = torch.Tensor([0.0]).cuda() 188 | else: 189 | bit_img_soft = (image_grad - (self.measure_u + self.measure_l)/2) * (2/(self.measure_u - self.measure_l)) 190 | bit_img_soft = self.tanh(bit_img_soft) 191 | bit_img_hard = (image_grad < self.measure_l) * (-1.0) + (image_grad >= self.measure_l) * (image_grad <= self.measure_u) * (0.0) + (image_grad> self.measure_u) *(1.0) 
192 | bit_img = bit_img_soft - bit_img_soft.detach() + bit_img_hard.detach()# the order matters 193 | bit_img = bit_img.view(bit_img.shape[0], 1, 1, 1) 194 | 195 | feat = None; bit = torch.zeros(x.shape[0]).cuda(); bit_fq = torch.zeros(x.shape[0]).cuda() 196 | 197 | if self.args.fq: 198 | f__1, bit_fq = self.SFENet1([x, bit_fq]) 199 | x, bit_fq = self.SFENet2([f__1, bit_fq]) 200 | else: 201 | f__1 = self.SFENet1(x) 202 | x = self.SFENet2(f__1) 203 | 204 | RDBs_out = [] 205 | for i in range(self.D): 206 | if self.args.imgwise: 207 | x, feat, bit, bit_img = self.RDBs[i]([x, feat, bit, bit_img]) 208 | else: 209 | x, feat, bit = self.RDBs[i]([x, feat, bit]) 210 | RDBs_out.append(x) 211 | 212 | if self.args.imgwise: 213 | x, bit = self.GFF[0]([torch.cat(RDBs_out,1), bit, bit_img]) 214 | x, bit = self.GFF[1]([x, bit, bit_img]) 215 | else: 216 | x, bit = self.GFF[0]([torch.cat(RDBs_out,1), bit]) 217 | x, bit = self.GFF[1]([x, bit]) 218 | 219 | x += f__1 220 | 221 | if self.args.fq: 222 | out, bit_fq = self.UPNet[0]([x, bit_fq]) 223 | out = self.UPNet[1](out) 224 | out, bit_fq = self.UPNet[2]([out, bit_fq]) 225 | if len(self.UPNet) > 3: 226 | out = self.UPNet[3](out) 227 | out, bit_fq = self.UPNet[4]([out, bit_fq]) 228 | else: 229 | out = self.UPNet(x) 230 | 231 | return out, feat, bit 232 | -------------------------------------------------------------------------------- /src/model/srresnet.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | from model import quantize 6 | import kornia as K 7 | 8 | def make_model(args, parent=False): 9 | return SRResNet(args) 10 | 11 | class SRResNet(nn.Module): 12 | def __init__(self, args, conv=common.default_conv): 13 | super(SRResNet, self).__init__() 14 | 15 | n_resblocks = args.n_resblocks 16 | n_feats = args.n_feats 17 | kernel_size = 3 18 | scale = args.scale[0] 19 | 20 | # Head module 21 | self.fq = args.fq 22 | if args.fq: 23 | m_head = [quantize.QConv2d(args, args.n_colors, n_feats, kernel_size=9, bias=False, non_adaptive=True, to_8bit=True)] 24 | else: 25 | m_head = [conv(args.n_colors, n_feats, kernel_size=9, bias=False)] 26 | m_head.append(nn.PReLU()) 27 | 28 | # Body module 29 | act = 'prelu' 30 | m_body = [ 31 | common.ResBlock( 32 | args, conv, n_feats, kernel_size, bias=False, bn=True, act=act, res_scale=args.res_scale 33 | ) for _ in range(n_resblocks) 34 | ] 35 | 36 | if args.fq: 37 | m_body.append(quantize.QConv2d(args, n_feats, n_feats, kernel_size, bias=False)) 38 | else: 39 | m_body.append(conv(n_feats, n_feats, kernel_size, bias=False)) 40 | m_body.append( nn.BatchNorm2d(n_feats)) 41 | 42 | # Tail module 43 | if args.fq: 44 | m_tail = [ 45 | common.Upsampler(args, quantize.QConv2d, scale, n_feats, act=act, fq=args.fq, bias=False), 46 | quantize.QConv2d(args, n_feats, args.n_colors, kernel_size=9, bias=False, non_adaptive=True, to_8bit=True) 47 | ] 48 | else: 49 | m_tail = [ 50 | common.Upsampler(args, conv, scale, n_feats, act=act, bias=False), 51 | conv(n_feats, args.n_colors, kernel_size=9, bias=False) 52 | ] 53 | 54 | self.head = nn.Sequential(*m_head) 55 | self.body = nn.Sequential(*m_body) 56 | self.tail = nn.Sequential(*m_tail) 57 | 58 | if args.imgwise: 59 | self.measure_l = nn.Parameter(torch.FloatTensor([128]).cuda()) 60 | self.measure_u = nn.Parameter(torch.FloatTensor([128]).cuda()) 61 | self.tanh = nn.Tanh() 62 | self.ema_epoch = 1 63 | self.init = False 64 | 65 | self.args = args 66 | 67 | 68 | for m in 
self.modules(): 69 | if isinstance(m, nn.Conv2d): 70 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 71 | m.weight.data.normal_(0, math.sqrt(2. / n)) 72 | if m.bias is not None: 73 | m.bias.data.zero_() 74 | 75 | def forward(self, x): 76 | if self.args.imgwise: 77 | image = x 78 | grads: torch.Tensor = K.filters.spatial_gradient(K.color.rgb_to_grayscale(image/255.), order=1) 79 | image_grad = torch.mean(torch.abs(grads.squeeze(1)),(1,2,3)) *1e+3 80 | 81 | if self.init: 82 | # print(image_grad) 83 | if self.ema_epoch == 1: 84 | measure_l = torch.quantile(image_grad.detach(), self.args.img_percentile/100.0) 85 | measure_u = torch.quantile(image_grad.detach(), 1-self.args.img_percentile/100.0) 86 | nn.init.constant_(self.measure_l, measure_l) 87 | nn.init.constant_(self.measure_u, measure_u) 88 | else: 89 | beta = self.args.ema_beta 90 | new_measure_l = self.measure_l * beta + torch.quantile(image_grad.detach(), self.args.img_percentile/100.0) * (1-beta) 91 | new_measure_u = self.measure_u * beta + torch.quantile(image_grad.detach(), 1-self.args.img_percentile/100.0) * (1-beta) 92 | nn.init.constant_(self.measure_l, new_measure_l.item()) 93 | nn.init.constant_(self.measure_u, new_measure_u.item()) 94 | 95 | self.ema_epoch += 1 96 | bit_img = torch.Tensor([0.0]).cuda() 97 | 98 | else: 99 | bit_img_soft = (image_grad - (self.measure_u + self.measure_l)/2) * (2/(self.measure_u - self.measure_l)) 100 | bit_img_soft = self.tanh(bit_img_soft) 101 | bit_img_hard = (image_grad < self.measure_l) * (-1.0) + (image_grad >= self.measure_l) * (image_grad <= self.measure_u) * (0.0) + (image_grad> self.measure_u) *(1.0) 102 | bit_img = bit_img_soft - bit_img_soft.detach() + bit_img_hard.detach() # the order matters 103 | bit_img = bit_img.view(bit_img.shape[0], 1, 1, 1) 104 | 105 | bit_fq = torch.zeros(x.shape[0]).cuda() 106 | 107 | if self.fq: 108 | x, bit_fq= self.head[0]([x, bit_fq]) 109 | x = self.head[1:](x) 110 | else: 111 | x = self.head(x) 112 | 113 | feat = None 114 | bit = torch.zeros(x.shape[0]).cuda() 115 | 116 | res = x 117 | 118 | 119 | if self.args.imgwise: 120 | res, feat, bit, bit_img = self.body[:-2]([res, feat, bit, bit_img]) 121 | else: 122 | res, feat, bit = self.body[:-2]([res, feat, bit]) 123 | 124 | 125 | if self.fq: 126 | if self.args.imgwise: 127 | res, bit = self.body[-2]([res, bit, bit_img]) 128 | else: 129 | res, bit =self.body[-2]([res, bit]) 130 | 131 | res = self.body[-1](res) 132 | else: 133 | res = self.body[-2:](res) 134 | 135 | 136 | res+= x 137 | 138 | if self.fq: 139 | res, bit_fq = self.tail[0][0]([res, bit_fq]) # conv 140 | res = self.tail[0][1](res) # PS 141 | res = self.tail[0][2](res) # prelu 142 | res1 =res 143 | 144 | if len(self.tail[0]) > 2: 145 | res, bit_fq = self.tail[0][3]([res, bit_fq]) # conv 146 | res = self.tail[0][4](res) # PS 147 | res = self.tail[0][5](res) # prelu 148 | res2 = res 149 | x, bit_fq = self.tail[-1]([res, bit_fq]) # conv 150 | else: 151 | x = self.tail(res) 152 | 153 | return x, feat, bit 154 | 155 | -------------------------------------------------------------------------------- /src/option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser(description='EDSR and MDSR') 4 | 5 | 6 | parser.add_argument('--debug', action='store_true', 7 | help='Enables debug mode') 8 | parser.add_argument('--template', default='.', 9 | help='You can set various templates in option.py') 10 | 11 | # Hardware specifications 12 | 
parser.add_argument('--n_threads', type=int, default=6, 13 | help='number of threads for data loading') 14 | parser.add_argument('--cpu', action='store_true', 15 | help='use cpu only') 16 | parser.add_argument('--n_GPUs', type=int, default=1, 17 | help='number of GPUs') 18 | parser.add_argument('--seed', type=int, default=1, 19 | help='random seed') 20 | 21 | # Data specifications 22 | parser.add_argument('--dir_data', type=str, default='../dataset', 23 | help='dataset directory') 24 | parser.add_argument('--dir_demo', type=str, default='../test', 25 | help='demo image directory') 26 | parser.add_argument('--data_train', type=str, default='DIV2K', 27 | help='train dataset name') 28 | parser.add_argument('--data_test', type=str, default='DIV2K', 29 | help='test dataset name') 30 | parser.add_argument('--data_range', type=str, default='1-800/801-810', 31 | help='train/test data range') 32 | parser.add_argument('--ext', type=str, default='sep', 33 | help='dataset file extension') 34 | parser.add_argument('--scale', type=str, default='4', 35 | help='super resolution scale') 36 | parser.add_argument('--patch_size', type=int, default=192, 37 | help='output patch size') 38 | parser.add_argument('--rgb_range', type=int, default=255, 39 | help='maximum value of RGB') 40 | parser.add_argument('--n_colors', type=int, default=3, 41 | help='number of color channels to use') 42 | parser.add_argument('--chop', action='store_true', 43 | help='enable memory-efficient forward') 44 | parser.add_argument('--no_augment', action='store_true', 45 | help='do not use data augmentation') 46 | 47 | # Model specifications 48 | parser.add_argument('--model', default='EDSR', 49 | help='model name') 50 | parser.add_argument('--act', type=str, default='relu', 51 | help='activation function') 52 | parser.add_argument('--pre_train', type=str, default='', 53 | help='pre-trained model directory') 54 | parser.add_argument('--extend', type=str, default='.', 55 | help='pre-trained model directory') 56 | parser.add_argument('--n_resblocks', type=int, default=16, 57 | help='number of residual blocks') 58 | parser.add_argument('--n_feats', type=int, default=64, 59 | help='number of feature maps') 60 | parser.add_argument('--res_scale', type=float, default=1, 61 | help='residual scaling') 62 | parser.add_argument('--shift_mean', default=True, 63 | help='subtract pixel mean from the input') 64 | parser.add_argument('--dilation', action='store_true', 65 | help='use dilated convolution') 66 | parser.add_argument('--precision', type=str, default='single', 67 | choices=('single', 'half'), 68 | help='FP precision for test (single | half)') 69 | 70 | # Option for Residual dense network (RDN) 71 | parser.add_argument('--G0', type=int, default=64, 72 | help='default number of filters. (Use in RDN)') 73 | parser.add_argument('--RDNkSize', type=int, default=3, 74 | help='default kernel size. (Use in RDN)') 75 | parser.add_argument('--RDNconfig', type=str, default='B', 76 | help='parameters config of RDN. 
(Use in RDN)') 77 | 78 | 79 | # Training specifications 80 | parser.add_argument('--reset', action='store_true', 81 | help='reset the training') 82 | parser.add_argument('--test_every', type=int, default=1000, 83 | help='do test per every N batches') 84 | # this is same as iters per batch 85 | parser.add_argument('--epochs', type=int, default=300, 86 | help='number of epochs to train') 87 | parser.add_argument('--batch_size', type=int, default=16, 88 | help='input batch size for training') 89 | parser.add_argument('--split_batch', type=int, default=1, 90 | help='split the batch into smaller chunks') 91 | parser.add_argument('--self_ensemble', action='store_true', 92 | help='use self-ensemble method for test') 93 | parser.add_argument('--test_only', action='store_true', 94 | help='set this option to test the model') 95 | parser.add_argument('--gan_k', type=int, default=1, 96 | help='k value for adversarial loss') 97 | 98 | # Optimization specifications 99 | parser.add_argument('--step', type=int, default='1', 100 | help='learning rate step size') 101 | parser.add_argument('--gamma', type=float, default=0.9, 102 | help='learning rate decay factor for step decay') 103 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), 104 | help='ADAM beta') 105 | parser.add_argument('--epsilon', type=float, default=1e-8, 106 | help='ADAM epsilon for numerical stability') 107 | 108 | # Log specifications 109 | parser.add_argument('--save', type=str, default='test', 110 | help='file name to save') 111 | parser.add_argument('--load', type=str, default='', 112 | help='file name to load') 113 | parser.add_argument('--resume', type=int, default=0, 114 | help='resume from specific checkpoint') 115 | parser.add_argument('--save_models', action='store_true', 116 | help='save all intermediate models') 117 | parser.add_argument('--print_every', type=int, default=100, 118 | help='how many batches to wait before logging training status') 119 | parser.add_argument('--save_results', action='store_true', 120 | help='save output results') 121 | parser.add_argument('--save_gt', action='store_true', 122 | help='save low-resolution and high-resolution images together') 123 | 124 | parser.add_argument('--quantize_a', type=float, default=32, help='activation_bit') 125 | parser.add_argument('--quantize_w', type=float, default=32, help='weight_bit') 126 | 127 | parser.add_argument('--batch_size_calib', type=int, default=16, help='input batch size for calib') 128 | parser.add_argument('--batch_size_update', type=int, default=2, help='input batch size for update') 129 | parser.add_argument('--num_data', type=int, default=800, help='number of data for PTQ') 130 | 131 | parser.add_argument('--fq', action='store_true', help='fully quantize') 132 | parser.add_argument('--imgwise', action='store_true', help='imgwise') 133 | parser.add_argument('--layerwise', action='store_true', help='layerwise') 134 | parser.add_argument('--bac', action='store_true', help='bac') 135 | 136 | parser.add_argument('--lr_w', type=float, default=0.001, help='learning rate') 137 | parser.add_argument('--lr_a', type=float, default=0.05, help='learning rate') 138 | parser.add_argument('--lr_measure_layer', type=float, default=0.005, help='learning rate') 139 | parser.add_argument('--lr_measure_img', type=float, default=0.01, help='learning rate') 140 | parser.add_argument('--w_bitloss', type=float, default=50.0, help='weight for bit loss') 141 | parser.add_argument('--w_sktloss', type=float, default=10.0, help='weight for skt loss') 142 | 143 | 
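# NOTE: --quantize_a / --quantize_w set the base activation / weight bit-widths. With --imgwise and
# --layerwise, the adaptive QConv2d layers additionally shift the bit-width by an image-wise and a
# layer-wise offset, each taking values in {-1, 0, +1} (see rdn.py / srresnet.py). --fq ("fully
# quantize") also quantizes the first and last convolutions, which are kept non-adaptive at 8 bits.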
parser.add_argument('--test_patch', action='store_true', help='test patch') 144 | parser.add_argument('--test_patch_size', type=int, default=96, help='test patch size') 145 | parser.add_argument('--test_step_size', type=int, default=90, help='test step size') 146 | parser.add_argument('--test_own', type=str, default=None, help='directory for own test image') 147 | parser.add_argument('--n_parallel', type=int, default=1, help='number of patches for parallel processing') 148 | 149 | parser.add_argument('--quantizer', default='minmax', choices=('minmax', 'percentile', 'omse'), help='quantizer to use') 150 | parser.add_argument('--quantizer_w', default='minmax', choices=('minmax', 'percentile', 'omse'), help='quantizer to use') 151 | parser.add_argument('--percentile_alpha', type=float, default=0.99, help='used when quantizer is percentile') 152 | 153 | parser.add_argument('--ema_beta', type=float, default=0.9, help='beta for EMA') 154 | parser.add_argument('--bac_beta', type=float, default=0.5, help='beta for EMA in BaC') 155 | 156 | parser.add_argument('--img_percentile', type=float, default=10.0, help='clip percentile for u,l that ranges from 0~100') 157 | parser.add_argument('--layer_percentile', type=float, default=30.0, help='clip percentile for u,l that ranges from 0~100') 158 | 159 | args = parser.parse_args() 160 | 161 | args.scale = list(map(lambda x: int(x), args.scale.split('+'))) 162 | args.data_train = args.data_train.split('+') 163 | args.data_test = args.data_test.split('+') 164 | 165 | if args.epochs == 0: 166 | args.epochs = 1e8 167 | 168 | for arg in vars(args): 169 | if vars(args)[arg] == 'True': 170 | vars(args)[arg] = True 171 | elif vars(args)[arg] == 'False': 172 | vars(args)[arg] = False 173 | 174 | -------------------------------------------------------------------------------- /src/run.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=/mnt/disk1/cheeun914/datasets/ 2 | scale=4 3 | 4 | edsr() { 5 | CUDA_VISIBLE_DEVICES=$1 python main.py \ 6 | --model EDSR --scale $scale \ 7 | --n_feats 64 --n_resblocks 16 --res_scale 1.0 \ 8 | --pre_train ../pretrained_model/edsr_baseline_x$scale.pt \ 9 | --epochs 10 --test_every 50 --print_every 10 \ 10 | --batch_size_update 2 --batch_size_calib 16 --num_data 100 --patch_size 384 \ 11 | --data_test Set5 --dir_data $DATA_DIR \ 12 | --quantize_a $2 --quantize_w $3 \ 13 | --quantizer 'minmax' --ema_beta 0.9 --quantizer_w 'omse' \ 14 | --lr_w 0.01 --lr_a 0.01 --lr_measure_img 0.1 --lr_measure_layer 0.01 \ 15 | --w_bitloss 50.0 --w_sktloss 10.0 --imgwise --layerwise --bac \ 16 | --img_percentile 10.0 --layer_percentile 30.0 \ 17 | --save edsrbaseline_x$scale/w$3a$2-adabm-nonfq \ 18 | --seed 1 \ 19 | # 20 | } 21 | 22 | edsr_fq() { 23 | CUDA_VISIBLE_DEVICES=$1 python main.py \ 24 | --model EDSR --scale $scale \ 25 | --n_feats 64 --n_resblocks 16 --res_scale 1.0 \ 26 | --pre_train ../pretrained_model/edsr_baseline_x$scale.pt \ 27 | --epochs 10 --test_every 50 --print_every 10 \ 28 | --batch_size_update 2 --batch_size_calib 16 --num_data 100 --patch_size 384 \ 29 | --data_test Set5+Urban100 --dir_data $DATA_DIR \ 30 | --quantize_a $2 --quantize_w $3 \ 31 | --quantizer 'minmax' --ema_beta 0.9 --quantizer_w 'omse' \ 32 | --lr_w 0.01 --lr_a 0.01 --lr_measure_img 0.1 --lr_measure_layer 0.01 \ 33 | --w_bitloss 50.0 --w_sktloss 10.0 --imgwise --layerwise --bac \ 34 | --img_percentile 10.0 --layer_percentile 30.0 \ 35 | --save edsrbaseline_x$scale/w$3a$2-adabm-fq \ 36 | --seed 1 \ 37 | --fq \ 38 | # 39 
| } 40 | 41 | edsr_eval() { 42 | CUDA_VISIBLE_DEVICES=$1 python main.py \ 43 | --model EDSR --scale $scale \ 44 | --n_feats 64 --n_resblocks 16 --res_scale 1.0 \ 45 | --data_test Urban100+test2k+test4k --dir_data $DATA_DIR \ 46 | --quantize_a $2 --quantize_w $3 \ 47 | --imgwise --layerwise \ 48 | --test_only \ 49 | --save edsrbaseline_x$scale/w$3a$2-adabm-nonfq \ 50 | --test_patch --test_patch_size 96 --test_step_size 96 \ 51 | --pre_train ../pretrained_model/edsr_baseline_x$scale-w$3a$2-adabm-nonfq.pt \ 52 | # --pre_train ../experiment/edsrbaseline_x$scale/w$3a$2-adabm-nonfq/model/checkpoint.pt \ 53 | # --save_results \ 54 | } 55 | 56 | edsr_fq_eval() { 57 | CUDA_VISIBLE_DEVICES=$1 python main.py \ 58 | --model EDSR --scale $scale \ 59 | --n_feats 64 --n_resblocks 16 --res_scale 1.0 \ 60 | --data_test Set5+Set14+B100+Urban100 --dir_data $DATA_DIR \ 61 | --quantize_a $2 --quantize_w $3 \ 62 | --imgwise --layerwise \ 63 | --fq \ 64 | --test_only \ 65 | --save edsrbaseline_x$scale/w$3a$2-adabm-fq-test \ 66 | --pre_train ../pretrained_model/edsr_baseline_x$scale-w$3a$2-adabm-fq.pt \ 67 | # --pre_train ../experiment/edsrbaseline_x$scale/w$3a$2-adabm-fq/model/checkpoint.pt \ 68 | # --save_results \ 69 | } 70 | 71 | edsr_fq_eval_own() { 72 | CUDA_VISIBLE_DEVICES=$1 python main.py \ 73 | --model EDSR --scale $scale \ 74 | --n_feats 64 --n_resblocks 16 --res_scale 1.0 \ 75 | --data_test Set5 --dir_data $DATA_DIR \ 76 | --quantize_a $2 --quantize_w $3 \ 77 | --imgwise --layerwise \ 78 | --test_only \ 79 | --save edsrbaseline_x$scale/w$3a$2-adabm-fq-test \ 80 | --fq \ 81 | --test_patch --test_patch_size 96 --test_step_size 96 \ 82 | --test_own '/dir/to/own/test/img' \ 83 | --pre_train ../pretrained_model/edsr_baseline_x$scale-w$3a$2-adabm-fq.pt \ 84 | # --pre_train ../experiment/edsrbaseline_x$scale/w$3a$2-adabm-fq/model/checkpoint.pt \ 85 | # --save_results \ 86 | } 87 | 88 | rdn_fq(){ 89 | CUDA_VISIBLE_DEVICES=$1 python main.py \ 90 | --model RDN --scale $scale \ 91 | --pre_train ../pretrained_model/rdn_baseline_x$scale.pt \ 92 | --epochs 10 --test_every 50 --print_every 10 \ 93 | --batch_size_update 2 --batch_size_calib 16 --num_data 100 --patch_size 288 \ 94 | --data_test Set5 --dir_data $DATA_DIR \ 95 | --quantize_a $2 --quantize_w $3 \ 96 | --quantizer 'minmax' --ema_beta 0.9 --quantizer_w 'omse' \ 97 | --lr_w 0.01 --lr_a 0.01 --lr_measure_img 0.1 --lr_measure_layer 0.01 \ 98 | --w_bitloss 50.0 --imgwise --layerwise --bac \ 99 | --img_percentile 10.0 --layer_percentile 30.0 \ 100 | --save rdn_x$scale/w$3a$2-adabm-fq \ 101 | --seed 1 \ 102 | --fq \ 103 | # 104 | } 105 | 106 | rdn_fq_eval(){ 107 | CUDA_VISIBLE_DEVICES=$1 python main.py \ 108 | --model RDN --scale $scale \ 109 | --data_test Set5+Set14+B100+Urban100 --dir_data $DATA_DIR \ 110 | --quantize_a $2 --quantize_w $3 \ 111 | --imgwise --layerwise \ 112 | --test_only \ 113 | --fq \ 114 | --save rdn_x$scale/w$3a$2-adabm-fq-test \ 115 | --pre_train ../pretrained_model/rdn_x$scale-w$3a$2-adabm-fq.pt \ 116 | # --pre_train ../experiment/rdn_x$scale/w$3a$2-adabm-fq/model/checkpoint.pt \ 117 | # --save_results \ 118 | } 119 | 120 | srresnet() { 121 | CUDA_VISIBLE_DEVICES=$1 python main.py \ 122 | --model SRResNet --scale $scale \ 123 | --n_feats 64 --n_resblocks 16 --res_scale 1.0 \ 124 | --pre_train ../pretrained_model/bnsrresnet_x$scale.pt \ 125 | --epochs 10 --test_every 50 --print_every 10 \ 126 | --batch_size_update 2 --batch_size_calib 16 --num_data 100 --patch_size 384 \ 127 | --data_test Set5 --dir_data $DATA_DIR \ 128 | --quantize_a $2 
--quantize_w $3 \ 129 | --quantizer 'minmax' --ema_beta 0.9 --quantizer_w 'omse' \ 130 | --lr_w 0.01 --lr_a 0.01 --lr_measure_img 0.1 --lr_measure_layer 0.01 \ 131 | --w_bitloss 50.0 --imgwise --layerwise --bac \ 132 | --img_percentile 10.0 --layer_percentile 30.0 \ 133 | --save srresnet_x$scale/w$3a$2-adabm-nonfq \ 134 | --seed 1 \ 135 | # 136 | } 137 | 138 | srresnet_eval() { 139 | CUDA_VISIBLE_DEVICES=$1 python main.py \ 140 | --model SRResNet --scale $scale \ 141 | --n_feats 64 --n_resblocks 16 --res_scale 1.0 \ 142 | --data_test Urban100 --dir_data $DATA_DIR \ 143 | --quantize_a $2 --quantize_w $3 \ 144 | --imgwise --layerwise \ 145 | --test_only \ 146 | --save srresnet_x$scale/w$3a$2-adabm-nonfq-test \ 147 | --test_patch --test_patch_size 96 --test_step_size 96 \ 148 | --pre_train ../pretrained_model/srresnet_x$scale-w$3a$2-adabm-nonfq.pt \ 149 | # --pre_train ../experiment/srresnet_x$scale/w$3a$2-adabm-nonfq/model/checkpoint.pt \ 150 | # --save_results \ 151 | } 152 | 153 | "$@" 154 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import math 5 | import time 6 | import datetime 7 | import shutil 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn.utils as utils 13 | from torch.utils.data import Dataset 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | import torch.optim.lr_scheduler as lrs 17 | from torchvision.utils import save_image 18 | 19 | from decimal import Decimal 20 | from tqdm import tqdm 21 | import cv2 22 | 23 | import utility 24 | from model.quantize import QConv2d 25 | import kornia as K 26 | 27 | class Trainer(): 28 | def __init__(self, args, loader, my_model, ckp): 29 | self.args = args 30 | self.scale = args.scale 31 | self.ckp = ckp 32 | self.loader_init = loader.loader_init 33 | self.loader_train = loader.loader_train 34 | self.loader_test = loader.loader_test 35 | 36 | self.model = my_model 37 | self.epoch = 0 38 | 39 | shutil.copyfile('./trainer.py', os.path.join(self.ckp.dir, 'trainer.py')) 40 | shutil.copyfile('./model/quantize.py', os.path.join(self.ckp.dir, 'quantize.py')) 41 | 42 | quant_params_a = [v for k, v in self.model.model.named_parameters() if '_a' in k] 43 | quant_params_w = [v for k, v in self.model.model.named_parameters() if '_w' in k] 44 | 45 | if args.layerwise or args.imgwise: 46 | quant_params_measure= [] 47 | if args.layerwise: 48 | quant_params_measure_layer = [v for k, v in self.model.model.named_parameters() if 'measure_layer' in k] 49 | quant_params_measure.append({'params': quant_params_measure_layer, 'lr': args.lr_measure_layer}) 50 | if args.imgwise: 51 | quant_params_measure_image = [v for k, v in self.model.model.named_parameters() if 'measure' in k and 'measure_layer' not in k] 52 | quant_params_measure.append({'params': quant_params_measure_image, 'lr': args.lr_measure_img}) 53 | 54 | self.optimizer_measure = torch.optim.Adam(quant_params_measure, betas=args.betas, eps=args.epsilon) 55 | self.scheduler_measure = lrs.StepLR(self.optimizer_measure, step_size=args.step, gamma=args.gamma) 56 | 57 | self.optimizer_a = torch.optim.Adam(quant_params_a, lr=args.lr_a, betas=args.betas, eps=args.epsilon) 58 | self.optimizer_w = torch.optim.Adam(quant_params_w, lr=args.lr_w, betas=args.betas, eps=args.epsilon) 59 | self.scheduler_a = lrs.StepLR(self.optimizer_a, step_size=args.step, gamma=args.gamma) 60 | 
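# NOTE: quantization parameters are grouped into three sets with separate Adam optimizers --
# weight-quantizer parameters ('_w'), activation-quantizer parameters ('_a'), and the image-/layer-wise
# bit-mapping measures ('measure*'). get_stage_optimizer_scheduler() below cycles through these groups,
# updating one group per epoch (w -> a -> measure when imgwise/layerwise are enabled, otherwise w -> a).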
self.scheduler_w = lrs.StepLR(self.optimizer_w, step_size=args.step, gamma=args.gamma) 61 | 62 | self.skt_losses = utility.AverageMeter() 63 | self.pix_losses = utility.AverageMeter() 64 | self.bit_losses = utility.AverageMeter() 65 | 66 | self.num_quant_modules = 0 67 | for n, m in self.model.named_modules(): 68 | if isinstance(m, QConv2d): 69 | if not m.to_8bit: # 8-bit (first or last) modules are excluded for the bit count 70 | self.num_quant_modules +=1 71 | 72 | # for initialization 73 | if not args.test_only: 74 | for n, m in self.model.named_modules(): 75 | if isinstance(m, QConv2d): 76 | setattr(m, 'w_bit', 32.0) 77 | setattr(m, 'a_bit', 32.0) 78 | setattr(m, 'init', True) 79 | if args.imgwise: 80 | setattr(self.model.model, 'init', True) 81 | 82 | def get_stage_optimizer_scheduler(self): 83 | if self.args.imgwise or self.args.layerwise: 84 | # w -> a -> measure 85 | if (self.epoch-1) % 3 == 0: 86 | param_name = '_w' 87 | optimizer = self.optimizer_w 88 | scheduler = self.scheduler_w 89 | elif (self.epoch-1) % 3 == 1: 90 | param_name = '_a' 91 | optimizer = self.optimizer_a 92 | scheduler = self.scheduler_a 93 | else: 94 | param_name = 'measure' 95 | optimizer = self.optimizer_measure 96 | scheduler = self.scheduler_measure 97 | else: 98 | if (self.epoch-1) % 2 == 0: 99 | param_name = '_w' 100 | optimizer = self.optimizer_w 101 | scheduler = self.scheduler_w 102 | else: 103 | param_name = '_a' 104 | optimizer = self.optimizer_a 105 | scheduler = self.scheduler_a 106 | 107 | return param_name, optimizer, scheduler 108 | 109 | def set_bit(self, teacher=False): 110 | for n, m in self.model.named_modules(): 111 | if isinstance(m, QConv2d): 112 | if teacher: 113 | setattr(m, 'w_bit', 32.0) 114 | setattr(m, 'a_bit', 32.0) 115 | elif m.non_adaptive: 116 | if m.to_8bit: 117 | setattr(m, 'w_bit', 8.0) 118 | setattr(m, 'a_bit', 8.0) 119 | else: 120 | setattr(m, 'w_bit', self.args.quantize_w) 121 | setattr(m, 'a_bit', self.args.quantize_a) 122 | else: 123 | setattr(m, 'w_bit', self.args.quantize_w) 124 | setattr(m, 'a_bit', self.args.quantize_a) 125 | 126 | setattr(m, 'init', False) 127 | 128 | if self.args.imgwise: 129 | setattr(self.model.model, 'init', False) 130 | 131 | 132 | def train(self): 133 | if self.epoch > 0: 134 | param_name, optimizer, scheduler = self.get_stage_optimizer_scheduler() 135 | epoch_update = 'Update param ' + param_name 136 | lr = optimizer.state_dict()['param_groups'][0]['lr'] 137 | self.ckp.write_log( 138 | '\n[Epoch {}]\t {}\t Learning rate for param: {:.2e}'.format( 139 | self.epoch, 140 | epoch_update, 141 | Decimal(lr)) 142 | ) 143 | 144 | self.model.train() 145 | 146 | timer_data, timer_model = utility.timer(), utility.timer() 147 | start_time = time.time() 148 | 149 | if self.epoch == 0: 150 | # Initialize Q parameters using freezed FP model 151 | params = self.model.named_parameters() 152 | for name1, params1 in params: 153 | params1.requires_grad=False 154 | 155 | for batch, (lr, _, idx_scale,) in enumerate(self.loader_init): 156 | lr, = self.prepare(lr) 157 | 158 | timer_data.hold() 159 | timer_model.tic() 160 | torch.cuda.empty_cache() 161 | with torch.no_grad(): 162 | sr_temp, feat_temp, bit_temp = self.model(lr, idx_scale) 163 | display_bit = bit_temp.mean() / self.num_quant_modules 164 | 165 | if self.args.imgwise or self.args.layerwise: 166 | self.ckp.write_log('[{}/{}] [bit:{:.2f}] \t{:.1f}+{:.1f}s'.format( 167 | (batch + 1) * self.loader_init.batch_size, 168 | len(self.loader_init.dataset), 169 | display_bit, 170 | timer_model.release(), 171 | 
timer_data.release(), 172 | )) 173 | 174 | if self.args.layerwise: 175 | measure_layer_list = [] 176 | for n, m in self.model.named_modules(): 177 | if isinstance(m, QConv2d) and not m.non_adaptive: 178 | if hasattr(m, 'std_layer') and len(m.std_layer) > 0: 179 | measure_layer_list.append(np.mean(m.std_layer)) 180 | mu = np.mean(measure_layer_list) 181 | sig = np.std(measure_layer_list) 182 | 183 | lower = np.percentile(measure_layer_list, self.args.layer_percentile) 184 | upper = np.percentile(measure_layer_list, 100.0-self.args.layer_percentile) 185 | 186 | for n, m in self.model.named_modules(): 187 | if isinstance(m, QConv2d): 188 | if not m.non_adaptive: 189 | normalized_measure = (np.mean(m.std_layer) < lower) * (-1.0) + (np.mean(m.std_layer) > upper) * (1.0) 190 | torch.nn.init.constant_(m.measure_layer, normalized_measure) 191 | m.std_layer.clear() 192 | 193 | print('Calibration done!') 194 | # self.set_bit(teacher=False) 195 | 196 | if self.args.imgwise: 197 | print('image-lower:{:.3f}, image-upper:{:.3f}'.format( 198 | self.model.model.measure_l.data.item(), 199 | self.model.model.measure_u.data.item(), 200 | )) 201 | if self.args.layerwise: 202 | bit_layer_list = [] 203 | for n, m in self.model.named_modules(): 204 | if isinstance(m, QConv2d): 205 | if not m.non_adaptive: 206 | bit_layer_list.append(int(m.measure_layer.data.item())) 207 | print(bit_layer_list) 208 | print(np.mean(bit_layer_list)) 209 | else: 210 | # Update quantization parameters 211 | for k, v in self.model.named_parameters(): 212 | if param_name in k: 213 | v.requires_grad=True 214 | else: 215 | v.requires_grad=False 216 | 217 | self.bit_losses.reset() 218 | self.pix_losses.reset() 219 | self.skt_losses.reset() 220 | 221 | for batch, (lr, _, idx_scale,) in enumerate(self.loader_train): 222 | lr, = self.prepare(lr) 223 | timer_data.hold() 224 | timer_model.tic() 225 | 226 | optimizer.zero_grad() 227 | 228 | with torch.no_grad(): 229 | self.set_bit(teacher=True) 230 | sr_t, feat_t, bit_t = self.model(lr, idx_scale) 231 | 232 | self.set_bit(teacher=False) 233 | sr, feat, bit = self.model(lr, idx_scale) 234 | 235 | loss = 0.0 236 | pix_loss = F.l1_loss(sr, sr_t) 237 | 238 | loss += pix_loss 239 | 240 | skt_loss = 0.0 241 | for block in range(len(feat)): 242 | skt_loss += self.args.w_sktloss * torch.norm(feat[block]-feat_t[block], p=2) / sr.shape[0] / len(feat) 243 | loss += skt_loss 244 | 245 | if self.args.layerwise or self.args.imgwise: 246 | average_bit = bit.mean() / self.num_quant_modules 247 | bit_loss = self.args.w_bitloss* torch.max(average_bit-(self.args.quantize_a), torch.zeros_like(average_bit)) 248 | if param_name == 'measure': 249 | loss += bit_loss 250 | 251 | loss.backward() 252 | 253 | self.pix_losses.update(pix_loss.item(), lr.size(0)) 254 | display_pix_loss = f'L_pix: {self.pix_losses.avg: .3f}' 255 | self.skt_losses.update(skt_loss.item(), lr.size(0)) 256 | display_skt_loss = f'L_skt: {self.skt_losses.avg: .3f}' 257 | 258 | if self.args.layerwise or self.args.imgwise: 259 | self.bit_losses.update(bit_loss.item(), lr.size(0)) 260 | display_bit_loss = f'L_bit: {self.bit_losses.avg: .3f}' 261 | 262 | optimizer.step() 263 | timer_model.hold() 264 | 265 | if (batch + 1) % self.args.print_every == 0: 266 | display_bit = bit.mean() / self.num_quant_modules 267 | if self.args.layerwise or self.args.imgwise: 268 | self.ckp.write_log('[{}/{}]\t{} \t{} \t{} [bit:{:.2f}] \t{:.1f}+{:.1f}s'.format( 269 | (batch + 1) * self.loader_train.batch_size, 270 | len(self.loader_train.dataset), 271 | display_pix_loss, 
272 | display_skt_loss, 273 | display_bit_loss, 274 | display_bit, 275 | timer_model.release(), 276 | timer_data.release(), 277 | )) 278 | else: 279 | self.ckp.write_log('[{}/{}]\t{} \t{} [bit:{:.2f}] \t{:.1f}+{:.1f}s'.format( 280 | (batch + 1) * self.loader_train.batch_size, 281 | len(self.loader_train.dataset), 282 | display_pix_loss, 283 | display_skt_loss, 284 | display_bit, 285 | timer_model.release(), 286 | timer_data.release(), 287 | )) 288 | timer_data.tic() 289 | 290 | scheduler.step() 291 | 292 | self.epoch += 1 293 | 294 | end_time = time.time() 295 | time_interval = end_time - start_time 296 | t_string = "Epoch Running Time is: " + str(datetime.timedelta(seconds=time_interval)) + "\n" 297 | self.ckp.write_log('{}'.format(t_string)) 298 | 299 | def patch_inference(self, model, lr, idx_scale): 300 | patch_idx = 0 301 | tot_bit_image = 0 302 | if self.args.n_parallel!=1: 303 | lr_list, num_h, num_w, h, w = utility.crop_parallel(lr, self.args.test_patch_size, self.args.test_step_size) 304 | sr_list = torch.Tensor().cuda() 305 | for lr_sub_index in range(len(lr_list)// self.args.n_parallel + 1): 306 | torch.cuda.empty_cache() 307 | with torch.no_grad(): 308 | sr_sub, feat, bit = self.model(lr_list[lr_sub_index* self.args.n_parallel: (lr_sub_index+1)*self.args.n_parallel], idx_scale) 309 | sr_sub = utility.quantize(sr_sub, self.args.rgb_range) 310 | sr_list = torch.cat([sr_list, sr_sub]) 311 | average_bit = bit.mean() / self.num_quant_modules 312 | tot_bit_image += average_bit 313 | patch_idx += 1 314 | sr = utility.combine(sr_list, num_h, num_w, h, w, self.args.test_patch_size, self.args.test_step_size, self.scale[0]) 315 | else: 316 | lr_list, num_h, num_w, h, w = utility.crop(lr, self.args.test_patch_size, self.args.test_step_size) 317 | sr_list = [] 318 | for lr_sub_img in lr_list: 319 | torch.cuda.empty_cache() 320 | with torch.no_grad(): 321 | sr_sub, feat, bit = self.model(lr_sub_img, idx_scale) 322 | sr_sub = utility.quantize(sr_sub, self.args.rgb_range) 323 | sr_list.append(sr_sub) 324 | average_bit = bit.mean() / self.num_quant_modules 325 | tot_bit_image += average_bit 326 | patch_idx += 1 327 | sr = utility.combine(sr_list, num_h, num_w, h, w, self.args.test_patch_size, self.args.test_step_size, self.scale[0]) 328 | 329 | bit = tot_bit_image / patch_idx 330 | 331 | return sr, feat, bit 332 | 333 | def test(self): 334 | torch.set_grad_enabled(False) 335 | 336 | # if True: 337 | if self.epoch > 1 or self.args.test_only: 338 | self.ckp.write_log('\nEvaluation:') 339 | self.ckp.add_log( 340 | torch.zeros(1, len(self.loader_test), len(self.scale)) 341 | ) 342 | self.model.eval() 343 | timer_test = utility.timer() 344 | 345 | if self.epoch == 2 or self.args.test_only: 346 | ################### Num of Params, Storage Size #################### 347 | n_params = 0 348 | n_params_q = 0 349 | for k, v in self.model.named_parameters(): 350 | nn = np.prod(v.size()) 351 | n_params += nn 352 | 353 | if 'weight' in k: 354 | name_split = k.split(".") 355 | del name_split[-1] 356 | module_temp = self.model 357 | for n in name_split: 358 | module_temp = getattr(module_temp, n) 359 | if isinstance(module_temp, QConv2d): 360 | n_params_q += nn * module_temp.w_bit / 32.0 361 | # print(k, module_temp.w_bit) 362 | else: 363 | n_params_q += nn 364 | else: 365 | n_params_q += nn 366 | 367 | self.ckp.write_log('Parameters: {:.3f}K'.format(n_params/(10**3))) 368 | self.ckp.write_log('Model Size: {:.3f}K'.format(n_params_q/(10**3))) 369 | 370 | if self.args.save_results: 371 | 
self.ckp.begin_background() 372 | 373 | ############################## TEST FOR OWN ############################# 374 | if self.args.test_own is not None: 375 | test_img = cv2.imread(self.args.test_own) 376 | lr = torch.tensor(test_img).permute(2,0,1).float().cuda() 377 | lr = torch.flip(lr, (0,)) # for color 378 | lr = lr.unsqueeze(0) 379 | 380 | tot_bit = 0 381 | for idx_scale, scale in enumerate(self.scale): 382 | if self.args.test_patch: 383 | sr, feat, bit = self.patch_inference(self.model, lr, idx_scale) 384 | img_bit = bit 385 | else: 386 | with torch.no_grad(): 387 | sr, feat, bit = self.model(lr, idx_scale) 388 | img_bit = bit.mean() / self.num_quant_modules 389 | 390 | sr = utility.quantize(sr, self.args.rgb_range) 391 | save_list = [sr] 392 | 393 | 394 | filename = self.args.test_own.split('/')[-1].split('.')[0] 395 | if self.args.save_results: 396 | save_name = '{}_x{}_{:.2f}bit'.format(filename, scale, img_bit) 397 | self.ckp.save_results('test_own', save_name, save_list) 398 | 399 | self.ckp.write_log('[{} x{}] Average Bit: {:.2f} '.format(filename, scale, img_bit)) 400 | 401 | ############################## TEST FOR TEST SET ############################# 402 | if self.args.test_own is None: 403 | for idx_data, d in enumerate(self.loader_test): 404 | for idx_scale, scale in enumerate(self.scale): 405 | d.dataset.set_scale(idx_scale) 406 | tot_ssim =0 407 | tot_bit =0 408 | i=0 409 | bitops =0 410 | for lr, hr, filename in tqdm(d, ncols=80): 411 | i+=1 412 | lr, hr = self.prepare(lr, hr) 413 | 414 | if self.args.test_patch: 415 | sr, feat, bit = self.patch_inference(self.model, lr, idx_scale) 416 | if self.args.n_parallel!=1: hr = hr[:, :, :lr.shape[2]*self.scale[0], :lr.shape[2]*self.scale[0]] 417 | img_bit = bit.item() 418 | else: 419 | with torch.no_grad(): 420 | sr, feat, bit = self.model(lr, idx_scale) 421 | img_bit = bit.mean().item() / self.num_quant_modules 422 | 423 | 424 | sr = utility.quantize(sr, self.args.rgb_range) 425 | save_list = [sr] 426 | 427 | psnr, ssim = utility.calc_psnr(sr, hr, scale, self.args.rgb_range, dataset=d) 428 | 429 | self.ckp.ssim_log[-1, idx_data, idx_scale] += ssim 430 | self.ckp.log[-1, idx_data, idx_scale] += psnr 431 | self.ckp.bit_log[-1, idx_data, idx_scale] += img_bit 432 | 433 | if self.args.save_gt: 434 | save_list.extend([lr, hr]) 435 | 436 | if self.args.save_results: 437 | save_name = '{}_x{}_{:.2f}dB_{:.2f}bit'.format(filename[0], scale, psnr, img_bit) 438 | self.ckp.save_results(d, save_name, save_list, scale) 439 | 440 | self.ckp.log[-1, idx_data, idx_scale] /= len(d) 441 | self.ckp.ssim_log[-1, idx_data, idx_scale] /= len(d) 442 | self.ckp.bit_log[-1, idx_data, idx_scale] /= len(d) 443 | 444 | best = self.ckp.log.max(0) 445 | self.ckp.write_log( 446 | '[{} x{}]\tPSNR: {:.3f} \t SSIM: {:.4f} \tBit: {:.2f} \t(Best: {:.3f} @epoch {})'.format( 447 | d.dataset.name, 448 | scale, 449 | self.ckp.log[-1, idx_data, idx_scale], 450 | self.ckp.ssim_log[-1, idx_data, idx_scale], 451 | self.ckp.bit_log[-1, idx_data, idx_scale], 452 | best[0][idx_data, idx_scale], 453 | best[1][idx_data, idx_scale] + 1 454 | ) 455 | ) 456 | 457 | if self.args.save_results: 458 | self.ckp.end_background() 459 | 460 | # save models 461 | if not self.args.test_only: 462 | self.ckp.save(self, self.epoch, is_best=(best[1][0, 0] + 1 == self.epoch -1)) 463 | 464 | torch.set_grad_enabled(True) 465 | 466 | def test_teacher(self): 467 | torch.set_grad_enabled(False) 468 | self.model.eval() 469 | self.ckp.write_log('Teacher Evaluation') 470 | 471 | 
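# NOTE: the "teacher" evaluated here is the same network run at full precision -- set_bit(teacher=True)
# forces every QConv2d to 32-bit weights and activations -- and it is the reference whose outputs and
# features the quantized model is distilled towards during training.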
############################## Num of Params #################### 472 | n_params = 0 473 | for k, v in self.model.named_parameters(): 474 | if '_a' not in k and '_w' not in k and 'measure' not in k: # for teacher model 475 | n_params += np.prod(v.size()) 476 | self.ckp.write_log('Parameters: {:.3f}K'.format(n_params/(10**3))) 477 | 478 | if self.args.save_results: 479 | self.ckp.begin_background() 480 | 481 | ############################## TEST FOR OWN ############################# 482 | if self.args.test_own is not None: 483 | test_img = cv2.imread(self.args.test_own) 484 | lr = torch.tensor(test_img).permute(2,0,1).float().cuda() 485 | lr = torch.flip(lr, (0,)) # for color 486 | lr = lr.unsqueeze(0) 487 | 488 | tot_bit = 0 489 | for idx_scale, scale in enumerate(self.scale): 490 | self.set_bit(teacher=True) 491 | if self.args.test_patch: 492 | sr, feat, bit = self.patch_inference(self.model, lr, idx_scale) 493 | img_bit = bit 494 | else: 495 | with torch.no_grad(): 496 | sr, feat, bit = self.model(lr, idx_scale) 497 | img_bit = bit.mean() / self.num_quant_modules 498 | self.set_bit(teacher=False) 499 | 500 | sr = utility.quantize(sr, self.args.rgb_range) 501 | save_list = [sr] 502 | if self.args.save_results: 503 | filename = self.args.test_own.split('/')[-1].split('.')[0] 504 | save_name = '{}_x{}_{:.2f}bit'.format(filename, scale, img_bit) 505 | self.ckp.save_results('test_own', save_name, save_list) 506 | 507 | ############################## TEST FOR TEST SET ############################# 508 | if self.args.test_own is None: 509 | for idx_data, d in enumerate(self.loader_test): 510 | for idx_scale, scale in enumerate(self.scale): 511 | d.dataset.set_scale(idx_scale) 512 | tot_ssim =0 513 | tot_bit =0 514 | tot_psnr =0.0 515 | i=0 516 | for lr, hr, filename in tqdm(d, ncols=80): 517 | i+=1 518 | lr, hr = self.prepare(lr, hr) 519 | self.set_bit(teacher=True) 520 | if self.args.test_patch: 521 | sr, feat, bit = self.patch_inference(self.model, lr, idx_scale) 522 | img_bit = bit 523 | else: 524 | with torch.no_grad(): 525 | sr, feat, bit = self.model(lr, idx_scale) 526 | img_bit = bit.mean() / self.num_quant_modules 527 | self.set_bit(teacher=False) 528 | 529 | sr = utility.quantize(sr, self.args.rgb_range) 530 | save_list = [sr] 531 | psnr, ssim = utility.calc_psnr(sr, hr, scale, self.args.rgb_range, dataset=d) 532 | 533 | tot_bit += img_bit 534 | tot_psnr += psnr 535 | tot_ssim += ssim 536 | 537 | if self.args.save_gt: 538 | save_list.extend([lr, hr]) 539 | 540 | if self.args.save_results: 541 | save_name = '{}_x{}_{:.2f}dB'.format(filename[0], scale, cur_psnr) 542 | self.ckp.save_results(d, save_name, save_list) 543 | 544 | tot_psnr /= len(d) 545 | tot_ssim /= len(d) 546 | tot_bit /= len(d) 547 | 548 | self.ckp.write_log( 549 | '[{} x{}]\tPSNR: {:.3f} \t SSIM: {:.4f} \tBit: {:.2f}'.format( 550 | d.dataset.name, 551 | scale, 552 | tot_psnr, 553 | tot_ssim, 554 | tot_bit.item(), 555 | ) 556 | ) 557 | 558 | if self.args.save_results: 559 | self.ckp.end_background() 560 | 561 | torch.set_grad_enabled(True) 562 | 563 | 564 | 565 | def prepare(self, *args): 566 | device = torch.device('cpu' if self.args.cpu else 'cuda') 567 | def _prepare(tensor): 568 | if self.args.precision == 'half': tensor = tensor.half() 569 | return tensor.to(device) 570 | 571 | return [_prepare(a) for a in args] 572 | 573 | def terminate(self): 574 | if self.args.test_only: 575 | self.test() 576 | return True 577 | else: 578 | # return self.epoch >= self.args.epochs 579 | return self.epoch > self.args.epochs 580 | 
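For quick reference, the objective assembled in `Trainer.train()` above can be condensed into a standalone sketch. The helper below is illustrative (the name `adabm_loss` and the arguments `avg_bit` / `target_bit` are ours, not the repository's); it mirrors the pixel-distillation, feature ("skt") distillation, and bit-budget terms, where the trainer applies the bit term only while the `measure` parameters are being updated:

```
import torch
import torch.nn.functional as F

def adabm_loss(sr, sr_t, feat, feat_t, avg_bit, target_bit, w_skt=10.0, w_bit=50.0):
    # sr / sr_t: quantized and full-precision (teacher) outputs, shape (B, C, H, W)
    # feat / feat_t: lists of intermediate feature maps from the quantized model and the teacher
    # avg_bit: average assigned bit-width over the adaptive layers (a scalar tensor)
    # target_bit: the bit budget (args.quantize_a); only overshoot is penalized

    # pixel-level distillation against the full-precision teacher
    pix_loss = F.l1_loss(sr, sr_t)

    # feature ("skt") distillation, normalized by batch size and number of blocks
    skt_loss = sum(torch.norm(f - f_t, p=2) for f, f_t in zip(feat, feat_t))
    skt_loss = w_skt * skt_loss / sr.shape[0] / max(len(feat), 1)

    # bit regularizer: penalize exceeding the average-bit budget
    bit_loss = w_bit * torch.clamp(avg_bit - target_bit, min=0.0)

    return pix_loss + skt_loss + bit_loss
```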
-------------------------------------------------------------------------------- /src/utility.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import datetime 5 | from multiprocessing import Process 6 | from multiprocessing import Queue 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt 11 | 12 | import numpy as np 13 | import imageio 14 | 15 | import torch 16 | import torch.optim as optim 17 | import torch.optim.lr_scheduler as lrs 18 | 19 | from pathlib import Path 20 | import shutil 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | import logging 24 | import coloredlogs 25 | import cv2 26 | import functools 27 | from torchvision.utils import make_grid 28 | from decimal import Decimal 29 | from math import exp 30 | 31 | 32 | class AverageMeter(object): 33 | def __init__(self): 34 | self.val = 0 35 | self.avg = 0 36 | self.sum = 0 37 | self.count = 0 38 | 39 | def reset(self): 40 | self.val = 0 41 | self.avg = 0 42 | self.sum = 0 43 | self.count = 0 44 | 45 | def update(self, val, n=1): 46 | self.val = val 47 | self.sum += val * n 48 | self.count += n 49 | if self.count > 0: 50 | self.avg = self.sum / self.count 51 | 52 | def accumulate(self, val, n=1): 53 | self.sum += val 54 | self.count += n 55 | if self.count > 0: 56 | self.avg = self.sum / self.count 57 | 58 | class timer(): 59 | def __init__(self): 60 | self.acc = 0 61 | self.tic() 62 | 63 | def tic(self): 64 | self.t0 = time.time() 65 | 66 | def toc(self, restart=False): 67 | diff = time.time() - self.t0 68 | if restart: self.t0 = time.time() 69 | return diff 70 | 71 | def hold(self): 72 | self.acc += self.toc() 73 | 74 | def release(self): 75 | ret = self.acc 76 | self.acc = 0 77 | 78 | return ret 79 | 80 | def reset(self): 81 | self.acc = 0 82 | 83 | class checkpoint(): 84 | def __init__(self, args): 85 | self.args = args 86 | self.ok = True 87 | self.log = torch.Tensor() 88 | self.ssim_log = torch.Tensor() 89 | self.bit_log = torch.Tensor() 90 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') 91 | 92 | if not args.load: 93 | if not args.save: 94 | args.save = now 95 | # self.dir = os.path.join('..', 'experiment', '{}_sd{}'.format(args.save, args.seed)) 96 | self.dir = os.path.join('..', 'experiment', '{}'.format(args.save)) 97 | else: 98 | self.dir = os.path.join('..', 'experiment', args.load) 99 | if os.path.exists(self.dir): 100 | self.log = torch.load(self.get_path('psnr_log.pt')) 101 | print('Continue from epoch {}...'.format(len(self.log))) 102 | else: 103 | args.load = '' 104 | 105 | if args.reset: 106 | os.system('rm -rf ' + self.dir) 107 | args.load = '' 108 | 109 | os.makedirs(self.dir, exist_ok=True) 110 | os.makedirs(self.get_path('model'), exist_ok=True) 111 | for d in args.data_test: 112 | os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True) 113 | if args.test_own is not None: 114 | os.makedirs(self.get_path('results-test_own'), exist_ok=True) 115 | 116 | open_type = 'a' if os.path.exists(self.get_path('log.txt'))else 'w' 117 | self.log_file = open(self.get_path('log.txt'), open_type) 118 | with open(self.get_path('config.txt'), open_type) as f: 119 | f.write(now + '\n\n') 120 | for arg in vars(args): 121 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 122 | f.write('\n') 123 | 124 | self.n_processes = 8 125 | 126 | def get_path(self, *subdir): 127 | return os.path.join(self.dir, *subdir) 128 | 129 | def save(self, trainer, epoch, 
is_best=False): 130 | trainer.model.save(self.get_path('model'), epoch, is_best=is_best) 131 | 132 | # self.plot_bit(epoch) 133 | # self.plot_psnr(epoch) 134 | 135 | def add_log(self, log): 136 | self.log = torch.cat([self.log, log]) 137 | self.ssim_log = torch.cat([self.ssim_log, log]) 138 | self.bit_log = torch.cat([self.bit_log, log]) 139 | 140 | def write_log(self, log, refresh=False): 141 | print(log) 142 | self.log_file.write(log + '\n') 143 | if refresh: 144 | self.log_file.close() 145 | self.log_file = open(self.get_path('log.txt'), 'a') 146 | 147 | def done(self): 148 | self.log_file.close() 149 | 150 | def plot_psnr(self, epoch): 151 | # axis = np.linspace(1, epoch, epoch) 152 | axis = np.linspace(1, epoch-1, epoch-1) 153 | for idx_data, d in enumerate(self.args.data_test): 154 | label = 'SR on {}'.format(d) 155 | fig = plt.figure() 156 | plt.title(label) 157 | for idx_scale, scale in enumerate(self.args.scale): 158 | plt.plot( 159 | axis, 160 | self.log[:, idx_data, idx_scale].numpy(), 161 | label='Scale {}'.format(scale) 162 | ) 163 | plt.legend() 164 | plt.xlabel('Epochs') 165 | plt.ylabel('PSNR') 166 | plt.grid(True) 167 | plt.savefig(self.get_path('test_{}.png'.format(d))) 168 | 169 | plt.close(fig) 170 | 171 | def plot_bit(self, epoch): 172 | # axis = np.linspace(1, epoch, epoch) 173 | axis = np.linspace(1, epoch-1, epoch-1) 174 | for idx_data, d in enumerate(self.args.data_test): 175 | label = 'SR on {}'.format(d) 176 | fig = plt.figure() 177 | plt.title(label) 178 | for idx_scale, scale in enumerate(self.args.scale): 179 | plt.plot( 180 | axis, 181 | self.bit_log[:, idx_data, idx_scale].numpy(), 182 | label='Scale {}'.format(scale) 183 | ) 184 | plt.legend() 185 | plt.xlabel('Epochs') 186 | plt.ylabel('Average Bit') 187 | plt.grid(True) 188 | plt.savefig(self.get_path('test_{}_bit.png'.format(d))) 189 | 190 | plt.close(fig) 191 | 192 | def begin_background(self): 193 | self.queue = Queue() 194 | 195 | def bg_target(queue): 196 | while True: 197 | if not queue.empty(): 198 | filename, tensor = queue.get() 199 | if filename is None: break 200 | imageio.imwrite(filename, tensor.numpy()) 201 | 202 | self.process = [ 203 | Process(target=bg_target, args=(self.queue,)) \ 204 | for _ in range(self.n_processes) 205 | ] 206 | 207 | for p in self.process: p.start() 208 | 209 | def end_background(self): 210 | for _ in range(self.n_processes): self.queue.put((None, None)) 211 | while not self.queue.empty(): time.sleep(1) 212 | for p in self.process: p.join() 213 | 214 | def save_results(self, dataset, filename, save_list): 215 | if self.args.save_results: 216 | if isinstance(dataset, str): 217 | filename = self.get_path( 218 | 'results-{}'.format(dataset), 219 | '{}'.format(filename) 220 | ) 221 | else: 222 | filename = self.get_path( 223 | 'results-{}'.format(dataset.dataset.name), 224 | '{}'.format(filename) 225 | ) 226 | 227 | postfix = ('', 'LR', 'HR') 228 | for v, p in zip(save_list, postfix): 229 | normalized = v[0].mul(255 / self.args.rgb_range) 230 | tensor_cpu = normalized.byte().permute(1, 2, 0).cpu() 231 | self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu)) 232 | 233 | def quantize(img, rgb_range): 234 | pixel_range = 255 / rgb_range 235 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) 236 | 237 | def gaussian(window_size, sigma): 238 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 239 | return gauss/gauss.sum() 240 | 241 | def create_window_3d(window_size, channel=1): 242 | _1D_window = 
gaussian(window_size, 1.5).unsqueeze(1) 243 | _2D_window = _1D_window.mm(_1D_window.t()) 244 | _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) 245 | window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().cuda() 246 | return window 247 | 248 | def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 249 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 250 | if val_range is None: 251 | if torch.max(img1) > 128: 252 | max_val = 255 253 | else: 254 | max_val = 1 255 | 256 | if torch.min(img1) < -0.5: 257 | min_val = -1 258 | else: 259 | min_val = 0 260 | L = max_val - min_val 261 | else: 262 | L = val_range 263 | 264 | padd = 0 265 | (_, _, height, width) = img1.size() 266 | if window is None: 267 | real_size = min(window_size, height, width) 268 | window = create_window_3d(real_size, channel=1).to(img1.device) 269 | # Channel is set to 1 since we consider color images as volumetric images 270 | 271 | img1 = img1.unsqueeze(1) 272 | img2 = img2.unsqueeze(1) 273 | 274 | mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1) 275 | mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1) 276 | 277 | mu1_sq = mu1.pow(2) 278 | mu2_sq = mu2.pow(2) 279 | mu1_mu2 = mu1 * mu2 280 | 281 | sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq 282 | sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq 283 | sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2 284 | 285 | C1 = (0.01 * L) ** 2 286 | C2 = (0.03 * L) ** 2 287 | 288 | v1 = 2.0 * sigma12 + C2 289 | v2 = sigma1_sq + sigma2_sq + C2 290 | cs = torch.mean(v1 / v2) # contrast sensitivity 291 | 292 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 293 | 294 | if size_average: 295 | ret = ssim_map.mean() 296 | else: 297 | ret = ssim_map.mean(1).mean(1).mean(1) 298 | 299 | if full: 300 | return ret, cs 301 | return ret 302 | 303 | def calc_psnr(sr, hr, scale, rgb_range, dataset=None): 304 | if hr.nelement() == 1: return 0 305 | 306 | diff = (sr - hr) / rgb_range 307 | if dataset and dataset.dataset.benchmark: 308 | shave = scale 309 | if diff.size(1) > 1: 310 | gray_coeffs = [65.738, 129.057, 25.064] 311 | convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256 312 | diff = diff.mul(convert).sum(dim=1) 313 | sr = sr.mul(convert).sum(dim=1) 314 | hr = hr.mul(convert).sum(dim=1) 315 | else: 316 | shave = scale + 6 317 | 318 | valid = diff[..., shave:-shave, shave:-shave] 319 | mse = valid.pow(2).mean() 320 | 321 | sr = sr[..., shave:-shave, shave:-shave] 322 | hr = hr[..., shave:-shave, shave:-shave] 323 | 324 | ssim_out = ssim_matlab(sr.unsqueeze(0), hr.unsqueeze(0), val_range=255).item() 325 | 326 | return -10 * math.log10(mse), ssim_out 327 | 328 | 329 | def crop(img, crop_sz, step): 330 | b, c, h, w = img.shape 331 | h_space = np.arange(0, max(h - crop_sz,0) + 1, step) 332 | w_space = np.arange(0, max(w - crop_sz,0) + 1, step) 333 | num_h = 0 334 | lr_list=[] 335 | for x in h_space: 336 | num_h += 1 337 | num_w = 0 338 | for y in w_space: 339 | num_w += 1 340 | # remaining borders are NOT clipped 341 | x_end = x + crop_sz if x != h_space[-1] else h 342 | y_end = y + crop_sz if y != w_space[-1] else w 343 | 344 | crop_img = 
img[:,:, x : x_end, y : y_end] 345 | lr_list.append(crop_img) 346 | 347 | return lr_list, num_h, num_w, h, w 348 | 349 | def crop_parallel(img, crop_sz, step): 350 | # remaining borders are clipped 351 | b, c, h, w = img.shape 352 | h_space = np.arange(0, h - crop_sz + 1, step) 353 | w_space = np.arange(0, w - crop_sz + 1, step) 354 | index = 0 355 | num_h = 0 356 | lr_list=torch.Tensor().to(img.device) 357 | for x in h_space: 358 | num_h += 1 359 | num_w = 0 360 | for y in w_space: 361 | num_w += 1 362 | index += 1 363 | crop_img = img[:, :, x:x + crop_sz, y:y + crop_sz] 364 | lr_list = torch.cat([lr_list, crop_img]) 365 | new_h=x + crop_sz # new height after crop 366 | new_w=y + crop_sz # new width after crop 367 | return lr_list, num_h, num_w, new_h, new_w 368 | 369 | 370 | def combine(sr_list, num_h, num_w, h, w, patch_size, step, scale): 371 | index=0 372 | sr_img = torch.zeros((1, 3, h*scale, w*scale)).cuda() 373 | step = step * scale 374 | patch_size = patch_size * scale 375 | 376 | for x in range(num_h): 377 | for y in range(num_w): 378 | x_patch_size = sr_list[index].shape[2] 379 | y_patch_size = sr_list[index].shape[3] 380 | sr_img[:, :, x*step : x*step+x_patch_size, y*step : y*step+y_patch_size] += sr_list[index] 381 | index += 1 382 | 383 | # mean the overlap region 384 | for x in range(1, num_h): 385 | sr_img[:, :, x*step : x*step+ (patch_size - step), :]/=2 386 | for y in range(1, num_w): 387 | sr_img[:, :, :, y*step : y*step+ (patch_size - step)]/=2 388 | 389 | return sr_img 390 | 391 | --------------------------------------------------------------------------------
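Finally, the `--test_patch` evaluation path in `trainer.py` ties `utility.crop`, the quantized model, and `utility.combine` together. The sketch below condenses that flow; it assumes the repository's `model(lr, idx_scale)` call returning `(sr, feat, bit)`, and the function name and defaults (patch size and step of 96, as in `run.sh`) are illustrative:

```
import torch
from utility import crop, combine, quantize  # helpers defined in src/utility.py

def tiled_inference(model, lr, scale, patch_size=96, step=96, rgb_range=255, idx_scale=0):
    # split the LR image into overlapping patches (remaining borders are kept, not clipped)
    lr_patches, num_h, num_w, h, w = crop(lr, patch_size, step)

    sr_patches, bits = [], []
    with torch.no_grad():
        for lr_patch in lr_patches:
            sr_patch, _, bit = model(lr_patch, idx_scale)      # model returns (sr, feat, bit)
            sr_patches.append(quantize(sr_patch, rgb_range))   # clamp/round to the valid pixel range
            bits.append(bit.mean())                            # raw per-patch bit statistic; the trainer
                                                               # further divides by the number of
                                                               # quantized modules

    # stitch the SR patches back together; overlapping strips are averaged inside combine()
    sr = combine(sr_patches, num_h, num_w, h, w, patch_size, step, scale)
    return sr, torch.stack(bits).mean()
```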