├── .gitignore
├── README.md
├── assets
│   └── cover_adabm.png
├── environment.yaml
└── src
    ├── data
    │   ├── __init__.py
    │   ├── benchmark.py
    │   ├── common.py
    │   ├── div2k.py
    │   ├── div2k_valid.py
    │   ├── srdata.py
    │   ├── test2k.py
    │   ├── test4k.py
    │   └── test8k.py
    ├── main.py
    ├── model
    │   ├── __init__.py
    │   ├── common.py
    │   ├── edsr.py
    │   ├── quantize.py
    │   ├── rdn.py
    │   └── srresnet.py
    ├── option.py
    ├── run.sh
    ├── trainer.py
    └── utility.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *.pyc
5 |
6 | /.vscode
7 |
8 | src/run_exp.sh
9 | experiment/
10 | pretrained_model/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdaBM
2 | This repository includes the official implementation of the paper [**AdaBM: On-the-Fly Adaptive Bit Mapping for Image Super-Resolution**](https://arxiv.org/abs/2404.03296) (CVPR2024).
3 |
 4 | ![AdaBM cover](assets/cover_adabm.png)
 5 | 
18 | ## Requirements
19 | A suitable [conda](https://conda.io/) environment named `adabm` can be created and activated with:
20 | ```
21 | conda env create -f environment.yaml
22 | conda activate adabm
23 | ```
24 |
25 | ## Preparation
26 | ### Dataset
27 | * For training, we use LR images sampled from [DIV2K](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar).
28 | * For testing, we use [benchmark datasets](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) and the large-input datasets [Test2K,4K,8K](https://drive.google.com/drive/folders/18b3QKaDJdrd9y0KwtrWU2Vp9nHxvfTZH?usp=sharing).
29 | Test8K contains the images (index 1401-1500) from [DIV8K](https://competitions.codalab.org/competitions/22217#participate). Test2K/4K contain the images (index 1201-1300/1301-1400) from DIV8K, downsampled to 2K and 4K resolution, respectively.
30 | After downloading the datasets, the dataset directory should be organized as follows:
31 |
32 | ```
33 | datasets
34 | - DIV2K
35 |     - DIV2K_train_LR_bicubic # for training
36 |     - DIV2K_train_HR
37 |     - test2k # for testing
38 |     - test4k
39 |     - test8k
40 | - benchmark # for testing
41 | ```
42 |
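For reference, the loaders in `src/data` resolve file paths under `--dir_data` roughly as in the sketch below (the layout follows `src/data/div2k.py`, `src/data/test2k.py`, and `src/data/benchmark.py`; the test2k filename is hypothetical):

```
import os

dir_data = "/path/to/datasets"  # value passed via --dir_data (placeholder)
scale = 4

# DIV2K training pair (src/data/div2k.py + src/data/srdata.py)
hr = os.path.join(dir_data, "DIV2K", "DIV2K_train_HR", "0001.png")
lr = os.path.join(dir_data, "DIV2K", "DIV2K_train_LR_bicubic",
                  "X{}".format(scale), "0001x{}.png".format(scale))

# Large-input test pair (src/data/test2k.py); the index/filename is hypothetical
hr_test = os.path.join(dir_data, "DIV2K", "test2k", "HR", "1201.png")
lr_test = os.path.join(dir_data, "DIV2K", "test2k", "LR", "X{}".format(scale), "1201.png")

# Benchmark HR directory (src/data/benchmark.py)
hr_bench_dir = os.path.join(dir_data, "benchmark", "Set5", "HR")
```
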
43 | ### Pretrained Models
44 | Please download the pretrained models from [here](https://drive.google.com/drive/folders/1GLuvwy3WWFG2H6iEA6-7tqRj_Jnzcn86?usp=drive_link) and place them in `pretrained_model`.
45 |
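The checkpoints are consumed through the `--pre_train` option; a minimal sketch of the loading step (mirroring `load()` in `src/model/__init__.py`; the checkpoint filename below is hypothetical) is:

```
import torch

def load_pretrained(net, path):
    # Mirrors src/model/__init__.py: unwrap an optional 'state_dict' key and load
    # non-strictly, since AdaBM adds clipping/bit parameters on top of the backbone.
    ckpt = torch.load(path, map_location="cpu")
    state = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    net.load_state_dict(state, strict=False)
    return net

# e.g. net = load_pretrained(net, "pretrained_model/edsr_x4_a4w4.pt")  # hypothetical name
```
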
46 | ## Usage
47 |
48 | ### How to train
49 |
50 | ```
51 | sh run.sh edsr 0 6 8 # gpu_id a_bit w_bit
52 | sh run.sh edsr 0 4 4 # gpu_id a_bit w_bit
53 | ```
54 |
55 | ### How to test
56 |
57 | ```
58 | sh run.sh edsr_eval 0 6 8 # gpu_id a_bit w_bit
59 | sh run.sh edsr_eval 0 4 4 # gpu_id a_bit w_bit
60 | ```
61 |
62 | > * Set `--dir_data` to the directory path of the datasets.
63 | > * Set `--pre_train` to the path of the saved model to be tested.
64 | > * The trained model is saved in the `experiment` directory.
65 | > * Set `--test_own` to the path of your own image for testing.
66 |
67 | More running scripts can be found in `run.sh`.
68 |
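The `a_bit` / `w_bit` values are presumably forwarded by `run.sh` to `--quantize_a` / `--quantize_w`, which control how each adaptive `QConv2d` fake-quantizes its activations and weights. A stripped-down sketch of the uniform fake-quantization step used in `src/model/quantize.py`, assuming a fixed clipping range:

```
import torch

def fake_quantize(x, lower, upper, n_bits):
    # Uniform fake-quantization as in QConv2d.forward (src/model/quantize.py):
    # clamp to the clipping range, round on a (2^b - 1)-level grid, then rescale.
    levels = 2 ** n_bits - 1
    x_c = torch.clamp(x, min=lower, max=upper)
    x_int = torch.round((x_c - lower) / (upper - lower) * levels)
    return x_int / levels * (upper - lower) + lower

x = torch.randn(1, 64, 8, 8)
x_q = fake_quantize(x, lower=x.min().item(), upper=x.max().item(), n_bits=4)
```
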
69 | ## Comments
70 | Our implementation is based on [EDSR(PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch).
71 |
72 | #### Coming Soon...
73 | - [ ] parallel patch inference
74 |
75 | ## BibTeX
76 | If you find our implementation useful, please consider citing our paper:
77 | ```
78 | @misc{hong2024adabm,
79 | title={AdaBM: On-the-Fly Adaptive Bit Mapping for Image Super-Resolution},
80 | author={Cheeun Hong and Kyoung Mu Lee},
81 | year={2024},
82 | eprint={2404.03296},
83 | archivePrefix={arXiv},
84 | primaryClass={cs.CV}
85 | }
86 | ```
87 |
88 | ## Contact
89 | > Email: [cheeun914@snu.ac.kr](mailto:cheeun914@snu.ac.kr)
--------------------------------------------------------------------------------
/assets/cover_adabm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cheeun/AdaBM/c855f5204d7e6f7c5652219aecef55ece4d36930/assets/cover_adabm.png
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: adabm
2 | channels:
3 | - pytorch
4 | - comet_ml
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - attrs=21.4.0=pyhd3eb1b0_0
10 | - blas=1.0=mkl
11 | - blinker=1.4=py_1
12 | - brotlipy=0.7.0=py36h27cfd23_1003
13 | - c-ares=1.18.1=h7f8727e_0
14 | - ca-certificates=2021.10.8=ha878542_0
15 | - certifi=2021.5.30=py36h06a4308_0
16 | - cffi=1.14.6=py36h400218f_0
17 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
18 | - comet-git-pure=0.19.16=py_0
19 | - comet_ml=3.2.0=py36
20 | - configobj=5.0.6=py36h06a4308_1
21 | - coverage=4.0.3=py36_1
22 | - cryptography=35.0.0=py36hd23ed53_0
23 | - cudatoolkit=10.1.243=h6bb024c_0
24 | - cycler=0.10.0=py36_0
25 | - dataclasses=0.8=pyh4f3eec9_6
26 | - dbus=1.13.14=hb2f20db_0
27 | - everett=1.0.2=py_1
28 | - expat=2.2.6=he6710b0_0
29 | - fontconfig=2.13.0=h9420a91_0
30 | - freetype=2.9.1=h8a8886c_1
31 | - fsspec=2022.1.0=pyhd8ed1ab_0
32 | - future=0.18.2=py36h5fab9bb_3
33 | - glib=2.63.1=h3eb4bd4_1
34 | - google-auth-oauthlib=0.4.1=py_1
35 | - gst-plugins-base=1.14.0=hbbd80ab_1
36 | - gstreamer=1.14.0=hb31296c_0
37 | - icu=58.2=he6710b0_3
38 | - imageio=2.8.0=py_0
39 | - importlib_metadata=4.8.1=hd3eb1b0_0
40 | - intel-openmp=2020.1=217
41 | - jpeg=9b=h024ee3a_2
42 | - jsonschema=3.2.0=pyhd3eb1b0_2
43 | - kiwisolver=1.2.0=py36hfd86e86_0
44 | - ld_impl_linux-64=2.33.1=h53a641e_7
45 | - libedit=3.1.20181209=hc058e9b_0
46 | - libffi=3.3=he6710b0_1
47 | - libgcc-ng=9.1.0=hdf63c60_0
48 | - libgfortran-ng=7.3.0=hdf63c60_0
49 | - libpng=1.6.37=hbc83047_0
50 | - libprotobuf=3.13.0.1=h8b12597_0
51 | - libstdcxx-ng=9.1.0=hdf63c60_0
52 | - libtiff=4.1.0=h2733197_0
53 | - libuuid=1.0.3=h1bed415_2
54 | - libuv=1.40.0=h7b6447c_0
55 | - libxcb=1.13=h1bed415_1
56 | - libxml2=2.9.9=hea5a465_1
57 | - matplotlib=3.1.3=py36_0
58 | - matplotlib-base=3.1.3=py36hef1b27d_0
59 | - mkl=2020.1=217
60 | - mkl-service=2.3.0=py36he904b0f_0
61 | - mkl_fft=1.0.15=py36ha843d7b_0
62 | - mkl_random=1.1.0=py36hd6b4f25_0
63 | - ncurses=6.2=he6710b0_1
64 | - netifaces=0.10.9=py36h27cfd23_1004
65 | - ninja=1.9.0=py36hfd86e86_0
66 | - nvidia-ml=7.352.0=pyhd3eb1b0_0
67 | - olefile=0.46=py36_0
68 | - onnx=1.8.0=py36h197aa4f_0
69 | - openssl=1.1.1m=h7f8727e_0
70 | - pcre=8.43=he6710b0_0
71 | - pillow=7.1.2=py36hb39fc2d_0
72 | - pip=20.0.2=py36_3
73 | - psutil=5.8.0=py36h27cfd23_1
74 | - pyasn1=0.4.8=py_0
75 | - pycparser=2.21=pyhd3eb1b0_0
76 | - pyjwt=2.3.0=pyhd8ed1ab_1
77 | - pyopenssl=22.0.0=pyhd3eb1b0_0
78 | - pyparsing=2.4.7=py_0
79 | - pyqt=5.9.2=py36h05f1152_2
80 | - pyrsistent=0.17.3=py36h7b6447c_0
81 | - pysocks=1.7.1=py36h06a4308_0
82 | - python=3.6.10=h7579374_2
83 | - python-dateutil=2.8.1=py_0
84 | - python_abi=3.6=1_cp36m
85 | - pytorch-lightning=1.5.9=pyhd8ed1ab_0
86 | - qt=5.9.7=h5867ecd_1
87 | - readline=8.0=h7b6447c_0
88 | - sip=4.19.8=py36hf484d3e_0
89 | - six=1.14.0=py36_0
90 | - sqlite=3.31.1=h62c20be_1
91 | - tk=8.6.8=hbc83047_0
92 | - torchaudio=0.7.0=py36
93 | - torchmetrics=0.7.1=pyhd8ed1ab_0
94 | - tornado=6.0.4=py36h7b6447c_1
95 | - tqdm=4.46.0=py_0
96 | - typing-extensions=3.7.4.3=0
97 | - typing_extensions=3.7.4.3=py_0
98 | - websocket-client=0.58.0=py36h06a4308_4
99 | - werkzeug=1.0.1=pyh9f0ad1d_0
100 | - wheel=0.34.2=py36_0
101 | - wrapt=1.12.1=py36h7b6447c_1
102 | - wurlitzer=1.0.3=py_2
103 | - xz=5.2.5=h7b6447c_0
104 | - yaml=0.2.5=h516909a_0
105 | - zlib=1.2.11=h7b6447c_3
106 | - zstd=1.3.7=h0b5b093_0
107 | - pip:
108 | - absl-py==0.10.0
109 | - addict==2.4.0
110 | - antlr4-python3-runtime==4.9.3
111 | - astunparse==1.6.3
112 | - backcall==0.2.0
113 | - basicsr==1.4.2
114 | - cachetools==4.1.1
115 | - chardet==3.0.4
116 | - coloredlogs==14.0
117 | - cupy-cuda102==8.4.0
118 | - decorator==4.4.2
119 | - fastrlock==0.5
120 | - filelock==3.4.1
121 | - gast==0.3.3
122 | - google-auth==1.21.2
123 | - google-pasta==0.2.0
124 | - grpcio==1.32.0
125 | - h5py==2.10.0
126 | - huggingface-hub==0.4.0
127 | - humanfriendly==9.0
128 | - idna==2.10
129 | - imagecodecs==2020.2.18
130 | - imagenet-stubs==0.0.7
131 | - importlib-metadata==1.7.0
132 | - install==1.3.4
133 | - ipython==7.16.3
134 | - ipython-genutils==0.2.0
135 | - jedi==0.17.2
136 | - joblib==1.1.0
137 | - jsonpatch==1.32
138 | - jsonpointer==2.3
139 | - keras-preprocessing==1.1.2
140 | - kornia==0.5.11
141 | - lmdb==1.2.1
142 | - lpips==0.1.4
143 | - markdown==3.2.2
144 | - networkx==2.4
145 | - numpy==1.18.1
146 | - oauthlib==3.1.0
147 | - onnxruntime==1.6.0
148 | - opencv-python==4.2.0.34
149 | - opt-einsum==3.3.0
150 | - ort-nightly==1.7.0.dev202102191
151 | - packaging==21.0
152 | - pandas==1.1.5
153 | - parso==0.7.1
154 | - pexpect==4.8.0
155 | - pickle5==0.0.12
156 | - pickleshare==0.7.5
157 | - progress==1.6
158 | - progressbar==2.5
159 | - prompt-toolkit==3.0.36
160 | - protobuf==3.13.0
161 | - ptyprocess==0.7.0
162 | - pyasn1-modules==0.2.8
163 | - pydeprecate==0.3.1
164 | - pydot==1.4.1
165 | - pygments==2.14.0
166 | - pyhocon==0.3.59
167 | - python-graphviz==0.14.2
168 | - pytorchcv==0.0.67
169 | - pytz==2022.2.1
170 | - pywavelets==1.1.1
171 | - pyyaml==5.4.1
172 | - pyzmq==23.2.1
173 | - requests==2.24.0
174 | - requests-oauthlib==1.3.0
175 | - rsa==4.6
176 | - scikit-image==0.17.2
177 | - scikit-learn==0.24.2
178 | - scipy==1.5.0
179 | - seaborn==0.11.2
180 | - setuptools==59.5.0
181 | - tb-nightly==2.11.0a20220816
182 | - tensorboard==2.3.0
183 | - tensorboard-data-server==0.6.1
184 | - tensorboard-plugin-wit==1.7.0
185 | - tensorboardx==2.1
186 | - tensorflow==2.3.0
187 | - tensorflow-estimator==2.3.0
188 | - termcolor==1.1.0
189 | - thop==0.0.31-2005241907
190 | - threadpoolctl==3.1.0
191 | - tifffile==2020.5.11
192 | - timm==0.6.12
193 | - torch==1.7.0
194 | - torch-dct==0.1.5
195 | - torchensemble==0.1.7
196 | - torchfile==0.1.0
197 | - torchsummary==1.5.1
198 | - torchvision==0.11.2
199 | - traitlets==4.3.3
200 | - urllib3==1.25.10
201 | - visdom==0.1.8.9
202 | - warmup-scheduler==0.3
203 | - wcwidth==0.2.6
204 | - yapf==0.32.0
205 | - zipp==3.1.0
206 | prefix: /home/cheeun914/anaconda3/envs/adabm
207 |
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib import import_module
2 | from torch.utils.data import dataloader
3 | from torch.utils.data import ConcatDataset
4 |
5 | import random
6 | import torch.utils.data as data
7 |
8 | import numpy
9 | import torch
10 |
11 | class MyConcatDataset(ConcatDataset):
12 | def __init__(self, datasets):
13 | super(MyConcatDataset, self).__init__(datasets)
14 | self.train = datasets[0].train
15 |
16 | def set_scale(self, idx_scale):
17 | for d in self.datasets:
18 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale)
19 |
20 | class Data:
21 | def __init__(self, args):
22 | self.loader_train = None
23 | if not args.test_only:
24 | datasets = []
25 | for d in args.data_train:
26 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
27 | m = import_module('data.' + module_name.lower())
28 | datasets.append(getattr(m, module_name)(args, name=d))
29 |
30 | train_dataset = MyConcatDataset(datasets)
31 | indices = random.sample(range(0, len(train_dataset)), args.num_data)
32 | # sampled over (train_dataset[0]=001.png~train_dataset[799]=800.png)
33 |
34 | print('Indices of {} sampled data...'.format(len(indices)))
35 | print(indices)
36 |
37 | train_dataset_sampled = data.Subset(train_dataset, indices)
38 |
39 | self.loader_train = dataloader.DataLoader(
40 | train_dataset_sampled,
41 | batch_size=args.batch_size_update,
42 | shuffle=True,
43 | pin_memory=not args.cpu,
44 | num_workers=args.n_threads,
45 | )
46 |
47 | self.loader_test = []
48 | for d in args.data_test:
49 | if d in ['Set5', 'Set14', 'B100', 'Urban100']:
50 | m = import_module('data.benchmark')
51 | testset = getattr(m, 'Benchmark')(args, train=False, name=d)
52 | else:
53 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
54 | m = import_module('data.' + module_name.lower())
55 | testset = getattr(m, module_name)(args, train=False, name=d)
56 |
57 | self.loader_test.append(
58 | dataloader.DataLoader(
59 | testset,
60 | batch_size=1,
61 | shuffle=False,
62 | pin_memory=not args.cpu,
63 | num_workers=args.n_threads,
64 | )
65 | )
66 |
67 | self.loader_init = None
68 | if not args.test_only:
69 | datasets = []
70 | for d in args.data_train:
71 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
72 | m = import_module('data.' + module_name.lower())
73 | datasets.append(getattr(m, module_name)(args, name=d))
74 |
75 | self.loader_init = dataloader.DataLoader(
76 | train_dataset_sampled,
77 | batch_size=args.batch_size_calib,
78 | shuffle=True,
79 | pin_memory=not args.cpu,
80 | num_workers=args.n_threads,
81 | )
82 |
83 |
84 |
85 |
86 |
87 |
88 |
--------------------------------------------------------------------------------
/src/data/benchmark.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data import common
4 | from data import srdata
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.utils.data as data
10 |
11 | class Benchmark(srdata.SRData):
12 | def __init__(self, args, name='', train=True, benchmark=True):
13 | super(Benchmark, self).__init__(
14 | args, name=name, train=train, benchmark=True
15 | )
16 | def _set_filesystem(self, dir_data):
17 | self.apath = os.path.join(dir_data, 'benchmark', self.name)
18 | self.dir_hr = os.path.join(self.apath, 'HR')
19 | if self.input_large:
20 | self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
21 | else:
22 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
23 | self.ext = ('', '.png')
24 |
25 |
26 |
--------------------------------------------------------------------------------
/src/data/common.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | import skimage.color as sc
5 |
6 | def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
7 | ih, iw = args[0].shape[:2]
8 |
9 | if not input_large:
10 | p = scale if multi else 1
11 | tp = p * patch_size
12 | ip = tp // scale
13 | else:
14 | tp = patch_size
15 | ip = patch_size
16 |
17 | ix = random.randrange(0, iw - ip + 1)
18 | iy = random.randrange(0, ih - ip + 1)
19 | # print(ix, iy)
20 |
21 | if not input_large:
22 | tx, ty = scale * ix, scale * iy
23 | else:
24 | tx, ty = ix, iy
25 |
26 | ret = [
27 | args[0][iy:iy + ip, ix:ix + ip, :],
28 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
29 | ]
30 |
31 | return ret
32 |
33 | def set_channel(*args, n_channels=3):
34 | def _set_channel(img):
35 | if img.ndim == 2:
36 | img = np.expand_dims(img, axis=2)
37 |
38 | c = img.shape[2]
39 | if n_channels == 1 and c == 3:
40 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
41 | elif n_channels == 3 and c == 1:
42 | img = np.concatenate([img] * n_channels, 2)
43 |
44 | return img
45 |
46 | return [_set_channel(a) for a in args]
47 |
48 | def np2Tensor(*args, rgb_range=255):
49 | def _np2Tensor(img):
50 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
51 | tensor = torch.from_numpy(np_transpose).float()
52 | tensor.mul_(rgb_range / 255)
53 |
54 | return tensor
55 |
56 | return [_np2Tensor(a) for a in args]
57 |
58 | def augment(*args, hflip=True, rot=True):
59 | hflip = hflip and random.random() < 0.5
60 | vflip = rot and random.random() < 0.5
61 | rot90 = rot and random.random() < 0.5
62 |
63 | def _augment(img):
64 | if hflip: img = img[:, ::-1, :]
65 | if vflip: img = img[::-1, :, :]
66 | if rot90: img = img.transpose(1, 0, 2)
67 |
68 | return img
69 |
70 | return [_augment(a) for a in args]
--------------------------------------------------------------------------------
/src/data/div2k.py:
--------------------------------------------------------------------------------
1 | import os
2 | from data import srdata
3 |
4 | class DIV2K(srdata.SRData):
5 | def __init__(self, args, name='DIV2K', train=True, benchmark=False):
6 | data_range = [r.split('-') for r in args.data_range.split('/')]
7 | if train:
8 | data_range = data_range[0]
9 | else:
10 | if args.test_only and len(data_range) == 1:
11 | data_range = data_range[0]
12 | else:
13 | data_range = data_range[1]
14 |
15 | self.begin, self.end = list(map(lambda x: int(x), data_range))
16 | super(DIV2K, self).__init__(
17 | args, name=name, train=train, benchmark=benchmark
18 | )
19 |
20 | def _scan(self):
21 | names_hr, names_lr = super(DIV2K, self)._scan()
22 | names_hr = names_hr[self.begin-1 : self.end]
23 | names_lr = [n[self.begin-1 : self.end] for n in names_lr]
24 |
25 | return names_hr, names_lr
26 |
27 | def _set_filesystem(self, dir_data):
28 | super(DIV2K, self)._set_filesystem(dir_data)
29 | # self.dir_hr = None
30 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
31 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
32 | if self.input_large: self.dir_lr += 'L'
33 |
34 |
--------------------------------------------------------------------------------
/src/data/div2k_valid.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data import common
4 | from data import srdata
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.utils.data as data
10 |
11 | class div2k_valid(srdata.SRData):
12 | def __init__(self, args, name='', train=True, benchmark=True):
13 | super(div2k_valid, self).__init__(
14 | args, name=name, train=train, benchmark=True
15 | )
16 |
17 | def _set_filesystem(self, dir_data):
18 | self.apath = os.path.join(dir_data, 'DIV2K' )
19 | self.dir_hr = os.path.join(self.apath, 'DIV2K_valid_HR')
20 | self.dir_lr = os.path.join(self.apath, 'DIV2K_valid_LR_bicubic')
21 | self.ext = ('', '.png')
22 |
--------------------------------------------------------------------------------
/src/data/srdata.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import random
4 | import pickle
5 |
6 | from data import common
7 |
8 | import numpy as np
9 | import imageio
10 | import torch
11 | import torch.utils.data as data
12 |
13 | class SRData(data.Dataset):
14 | def __init__(self, args, name='', train=True, benchmark=False):
15 | self.args = args
16 | self.name = name
17 | self.train = train
18 | self.split = 'train' if train else 'test'
19 | self.do_eval = True
20 | self.benchmark = benchmark
21 | self.input_large = False #(args.model == 'VDSR')
22 | self.scale = args.scale
23 | self.idx_scale = 0
24 |
25 | self._set_filesystem(args.dir_data)
26 | if args.ext.find('img') < 0:
27 | path_bin = os.path.join(self.apath, 'bin')
28 | os.makedirs(path_bin, exist_ok=True)
29 |
30 | list_hr, list_lr = self._scan()
31 | if args.ext.find('bin') >= 0:
32 | # Binary files are stored in 'bin' folder
33 | # If the binary file exists, load it. If not, make it.
34 | list_hr, list_lr = self._scan()
35 | self.images_hr = self._check_and_load(
36 | args.ext, list_hr, self._name_hrbin()
37 | )
38 | self.images_lr = [
39 | self._check_and_load(args.ext, l, self._name_lrbin(s)) \
40 | for s, l in zip(self.scale, list_lr)
41 | ]
42 | else:
43 | if args.ext.find('img') >= 0 or benchmark:
44 | self.images_hr, self.images_lr = list_hr, list_lr
45 | elif args.ext.find('sep') >= 0:
46 | os.makedirs(
47 | self.dir_hr.replace(self.apath, path_bin),
48 | exist_ok=True
49 | )
50 | for s in self.scale:
51 | os.makedirs(
52 | os.path.join(
53 | self.dir_lr.replace(self.apath, path_bin),
54 | 'X{}'.format(s)
55 | ),
56 | exist_ok=True
57 | )
58 |
59 | self.images_hr, self.images_lr = [], [[] for _ in self.scale]
60 | for h in list_hr:
61 | b = h.replace(self.apath, path_bin)
62 | b = b.replace(self.ext[0], '.pt')
63 | self.images_hr.append(b)
64 | self._check_and_load(
65 | args.ext, [h], b, verbose=True, load=False
66 | )
67 |
68 | for i, ll in enumerate(list_lr):
69 | for l in ll:
70 | b = l.replace(self.apath, path_bin)
71 | b = b.replace(self.ext[1], '.pt')
72 | self.images_lr[i].append(b)
73 | self._check_and_load(
74 | args.ext, [l], b, verbose=True, load=False
75 | )
76 |
77 | if train:
78 | n_patches = args.batch_size_update * args.test_every
79 | n_images = len(args.data_train) * len(self.images_hr)
80 | if args.num_data < 800:
81 | n_images = args.num_data
82 |
83 | if n_images == 0:
84 | self.repeat = 0
85 | else:
86 | self.repeat = max(n_patches // n_images, 1)
87 | # self.repeat == 1 for PTQ
88 |
89 |     # Below functions are used to prepare images
90 | def _scan(self):
91 | names_hr = sorted(
92 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
93 | )
94 | names_lr = [[] for _ in self.scale]
95 | for f in names_hr:
96 | filename, _ = os.path.splitext(os.path.basename(f))
97 | for si, s in enumerate(self.scale):
98 | if 'DIV2K' in self.dir_hr and 'test' not in self.dir_hr:
99 | names_lr[si].append(os.path.join(
100 | self.dir_lr, 'X{}/{}x{}{}'.format(
101 | s, filename, s, self.ext[1]
102 | )
103 | ))
104 | # for test2k, test4k, test8k
105 | elif 'test' in self.dir_hr:
106 | names_lr[si].append(os.path.join(
107 | self.dir_lr, 'X{}/{}{}'.format(
108 | s, filename, self.ext[1]
109 | )
110 | ))
111 | else:
112 | names_lr[si].append(os.path.join(
113 | self.dir_lr, 'X{}/{}{}'.format(
114 | s, filename, self.ext[1]
115 | )
116 | ))
117 | return names_hr, names_lr
118 |
119 | def _set_filesystem(self, dir_data):
120 | self.apath = os.path.join(dir_data, self.name)
121 | self.dir_hr = os.path.join(self.apath, 'HR')
122 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
123 | if self.input_large: self.dir_lr += 'L'
124 | self.ext = ('.png', '.png')
125 |
126 | def _name_hrbin(self):
127 | return os.path.join(
128 | self.apath,
129 | 'bin',
130 | '{}_bin_HR.pt'.format(self.split)
131 | )
132 |
133 | def _name_lrbin(self, scale):
134 | return os.path.join(
135 | self.apath,
136 | 'bin',
137 | '{}_bin_LR_X{}.pt'.format(self.split, scale)
138 | )
139 |
140 | def _check_and_load(self, ext, l, f, verbose=True, load=True):
141 | if os.path.isfile(f) and ext.find('reset') < 0:
142 | if load:
143 | if verbose: print('Loading {}...'.format(f))
144 | with open(f, 'rb') as _f: ret = pickle.load(_f)
145 | return ret
146 | else:
147 | return None
148 | else:
149 | if verbose:
150 | if ext.find('reset') >= 0:
151 | print('Making a new binary: {}'.format(f))
152 | else:
153 | print('{} does not exist. Now making binary...'.format(f))
154 | b = [{
155 | 'name': os.path.splitext(os.path.basename(_l))[0],
156 | 'image': imageio.imread(_l)
157 | } for _l in l]
158 | with open(f, 'wb') as _f: pickle.dump(b, _f)
159 |
160 | return b
161 |
162 | def __getitem__(self, idx):
163 | lr, hr, filename = self._load_file(idx)
164 | pair = self.get_patch(lr, hr)
165 | pair = common.set_channel(*pair, n_channels=self.args.n_colors)
166 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
167 |
168 | return pair_t[0], pair_t[1], filename
169 |
170 | def __len__(self):
171 | if self.train:
172 | return len(self.images_hr) * self.repeat
173 | else:
174 | return len(self.images_hr)
175 |
176 | def _get_index(self, idx):
177 | if self.train:
178 | return idx % len(self.images_hr)
179 | else:
180 | return idx
181 |
182 | def _load_file(self, idx):
183 | idx = self._get_index(idx)
184 | f_hr = self.images_hr[idx]
185 | f_lr = self.images_lr[self.idx_scale][idx]
186 |
187 | if self.args.ext.find('bin') >= 0:
188 | filename = f_hr['name']
189 | hr = f_hr['image']
190 | lr = f_lr['image']
191 | else:
192 | filename, _ = os.path.splitext(os.path.basename(f_hr))
193 | if self.args.ext == 'img' or self.benchmark:
194 | hr = imageio.imread(f_hr)
195 | lr = imageio.imread(f_lr)
196 | elif self.args.ext.find('sep') >= 0:
197 | with open(f_hr, 'rb') as _f: hr = pickle.load(_f)[0]['image']
198 | with open(f_lr, 'rb') as _f: lr = pickle.load(_f)[0]['image']
199 |
200 | return lr, hr, filename
201 |
202 | def get_patch(self, lr, hr):
203 | scale = self.scale[self.idx_scale]
204 | if self.train:
205 | lr, hr = common.get_patch(
206 | lr, hr,
207 | patch_size=self.args.patch_size,
208 | scale=scale,
209 | multi=(len(self.scale) > 1),
210 | input_large=self.input_large
211 | )
212 | if not self.args.no_augment: lr, hr = common.augment(lr, hr)
213 | else:
214 | ih, iw = lr.shape[:2]
215 | hr = hr[0:ih * scale, 0:iw * scale]
216 |
217 | return lr, hr
218 |
219 | def set_scale(self, idx_scale):
220 | if not self.input_large:
221 | self.idx_scale = idx_scale
222 | else:
223 | self.idx_scale = random.randint(0, len(self.scale) - 1)
224 |
--------------------------------------------------------------------------------
/src/data/test2k.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data import common
4 | from data import srdata
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.utils.data as data
10 |
11 | class test2k(srdata.SRData):
12 | def __init__(self, args, name='', train=True, benchmark=True):
13 | super(test2k, self).__init__(
14 | args, name=name, train=train, benchmark=True
15 | )
16 |
17 | def _set_filesystem(self, dir_data):
18 | self.apath = os.path.join(dir_data, 'DIV2K', self.name)
19 | self.dir_hr = os.path.join(self.apath, 'HR')
20 | self.dir_lr = os.path.join(self.apath, 'LR')
21 | self.ext = ('', '.png')
22 |
--------------------------------------------------------------------------------
/src/data/test4k.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data import common
4 | from data import srdata
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.utils.data as data
10 |
11 | class test4k(srdata.SRData):
12 | def __init__(self, args, name='', train=True, benchmark=True):
13 | super(test4k, self).__init__(
14 | args, name=name, train=train, benchmark=True
15 | )
16 |
17 | def _set_filesystem(self, dir_data):
18 | self.apath = os.path.join(dir_data, 'DIV2K', self.name)
19 | self.dir_hr = os.path.join(self.apath, 'HR')
20 | self.dir_lr = os.path.join(self.apath, 'LR')
21 | self.ext = ('', '.png')
22 |
--------------------------------------------------------------------------------
/src/data/test8k.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data import common
4 | from data import srdata
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.utils.data as data
10 |
11 | class test8k(srdata.SRData):
12 | def __init__(self, args, name='', train=True, benchmark=True):
13 | super(test8k, self).__init__(
14 | args, name=name, train=train, benchmark=True
15 | )
16 |
17 | def _set_filesystem(self, dir_data):
18 | self.apath = os.path.join(dir_data, 'DIV2K', self.name)
19 | self.dir_hr = os.path.join(self.apath, 'HR')
20 | self.dir_lr = os.path.join(self.apath, 'LR')
21 | self.ext = ('', '.png')
22 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import utility
4 | import data
5 | import model
6 | from option import args
7 | from trainer import Trainer
8 |
9 | import time
10 | import datetime
11 |
12 | import numpy
13 | import random
14 |
15 | torch.manual_seed(args.seed)
16 | torch.backends.cudnn.deterministic=True
17 | torch.backends.cudnn.benchmark=False
18 | numpy.random.seed(args.seed)
19 | random.seed(args.seed)
20 | torch.cuda.manual_seed(args.seed)
21 | torch.cuda.manual_seed_all(args.seed)
22 | checkpoint = utility.checkpoint(args)
23 |
24 | if checkpoint.ok:
25 | exp_start_time = time.time()
26 | _loader = data.Data(args)
27 | _model = model.Model(args, checkpoint)
28 | t = Trainer(args, _loader, _model, checkpoint)
29 |
30 | # t.test_teacher()
31 | while not t.terminate():
32 | torch.manual_seed(args.seed)
33 | t.train()
34 | t.test()
35 |
36 | exp_end_time = time.time()
37 | exp_time_interval = exp_end_time - exp_start_time
38 | t_string = "Total Running Time is: " + str(datetime.timedelta(seconds=exp_time_interval)) + "\n"
39 | checkpoint.write_log('{}'.format(t_string))
40 | checkpoint.done()
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.parallel as P
7 | import torch.utils.model_zoo
8 |
9 | from model.quantize import *
10 |
11 | class Model(nn.Module):
12 | def __init__(self, args, ckp):
13 | super(Model, self).__init__()
14 | print('Making model...')
15 | self.args = args
16 | self.scale = args.scale
17 | self.idx_scale = 0
18 | self.input_large = (args.model == 'VDSR')
19 | self.self_ensemble = args.self_ensemble
20 | self.chop = args.chop
21 | self.precision = args.precision
22 | self.cpu = args.cpu
23 | self.device = torch.device('cpu' if args.cpu else 'cuda')
24 | self.n_GPUs = args.n_GPUs
25 | self.save_models = args.save_models
26 |
27 | module = import_module('model.' + args.model.lower())
28 | self.model = module.make_model(args).to(self.device)
29 | if args.precision == 'half':
30 | self.model.half()
31 |
32 | self.load(
33 | ckp.get_path('model'),
34 | pre_train=args.pre_train,
35 | resume=args.resume,
36 | cpu=args.cpu
37 | )
38 | # print(self.model, file=ckp.log_file)
39 |
40 | def forward(self, x, idx_scale):
41 | self.idx_scale = idx_scale
42 |
43 | if hasattr(self.model, 'set_scale'):
44 | self.model.set_scale(idx_scale)
45 |
46 | if self.training:
47 | if self.n_GPUs > 1:
48 | return P.data_parallel(self.model, x, range(self.n_GPUs))
49 | else:
50 | return self.model(x)
51 | else:
52 | if self.chop:
53 | forward_function = self.forward_chop
54 | else:
55 | forward_function = self.model.forward
56 |
57 | if self.self_ensemble:
58 | return self.forward_x8(x, forward_function=forward_function)
59 | else:
60 | return forward_function(x)
61 |
62 | def save(self, apath, epoch, is_best=False):
63 | save_dirs = [os.path.join(apath, 'checkpoint.pt')]
64 |
65 | if is_best:
66 | save_dirs.append(os.path.join(apath, 'checkpoint_bestpsnr.pt'))
67 | if self.save_models:
68 | save_dirs.append(
69 | os.path.join(apath, 'model_{}.pt'.format(epoch))
70 | )
71 |
72 | for s in save_dirs:
73 | torch.save(self.model.state_dict(), s)
74 |
75 | def load(self, apath, pre_train='', resume=-1, cpu=False):
76 | load_from = None
77 | kwargs = {}
78 | if cpu:
79 | kwargs = {'map_location': lambda storage, loc: storage}
80 |
81 | if resume == -1:
82 | load_from = torch.load(
83 | os.path.join(apath, 'model_latest.pt'),
84 | **kwargs
85 | )
86 | elif resume == 0:
87 | if pre_train == 'download':
88 | print('Download the model')
89 | dir_model = os.path.join('..', 'models')
90 | os.makedirs(dir_model, exist_ok=True)
91 | load_from = torch.utils.model_zoo.load_url(
92 | self.model.url,
93 | model_dir=dir_model,
94 | **kwargs
95 | )
96 | elif pre_train:
97 | print('Load the model from {}'.format(pre_train))
98 | load_from = torch.load(pre_train, **kwargs)
99 | else:
100 | load_from = torch.load(
101 | os.path.join(apath, 'model_{}.pt'.format(resume)),
102 | **kwargs
103 | )
104 |
105 | if load_from:
106 | load_from = load_from['state_dict'] if 'state_dict' in load_from else load_from
107 | self.model.load_state_dict(load_from, strict=False)
108 | # self.model.load_state_dict(load_from, strict=True) # for debugging
109 |
110 | def forward_chop(self, *args, shave=10, min_size=160000):
111 | scale = 1 if self.input_large else self.scale[self.idx_scale]
112 | n_GPUs = min(self.n_GPUs, 4)
113 | # height, width
114 | h, w = args[0].size()[-2:]
115 |
116 | top = slice(0, h//2 + shave)
117 | bottom = slice(h - h//2 - shave, h)
118 | left = slice(0, w//2 + shave)
119 | right = slice(w - w//2 - shave, w)
120 | x_chops = [torch.cat([
121 | a[..., top, left],
122 | a[..., top, right],
123 | a[..., bottom, left],
124 | a[..., bottom, right]
125 | ]) for a in args]
126 |
127 | y_chops = []
128 | if h * w < 4 * min_size:
129 | for i in range(0, 4, n_GPUs):
130 | x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
131 | y = P.data_parallel(self.model, *x, range(n_GPUs))
132 | if not isinstance(y, list): y = [y]
133 | if not y_chops:
134 | y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
135 | else:
136 | for y_chop, _y in zip(y_chops, y):
137 | y_chop.extend(_y.chunk(n_GPUs, dim=0))
138 | else:
139 | for p in zip(*x_chops):
140 | y = self.forward_chop(*p, shave=shave, min_size=min_size)
141 | if not isinstance(y, list): y = [y]
142 | if not y_chops:
143 | y_chops = [[_y] for _y in y]
144 | else:
145 | for y_chop, _y in zip(y_chops, y): y_chop.append(_y)
146 |
147 | h *= scale
148 | w *= scale
149 | top = slice(0, h//2)
150 | bottom = slice(h - h//2, h)
151 | bottom_r = slice(h//2 - h, None)
152 | left = slice(0, w//2)
153 | right = slice(w - w//2, w)
154 | right_r = slice(w//2 - w, None)
155 |
156 | # batch size, number of color channels
157 | b, c = y_chops[0][0].size()[:-2]
158 | y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops]
159 | for y_chop, _y in zip(y_chops, y):
160 | _y[..., top, left] = y_chop[0][..., top, left]
161 | _y[..., top, right] = y_chop[1][..., top, right_r]
162 | _y[..., bottom, left] = y_chop[2][..., bottom_r, left]
163 | _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]
164 |
165 | if len(y) == 1: y = y[0]
166 |
167 | return y
168 |
169 | def forward_x8(self, *args, forward_function=None):
170 | def _transform(v, op):
171 | if self.precision != 'single': v = v.float()
172 |
173 | v2np = v.data.cpu().numpy()
174 | if op == 'v':
175 | tfnp = v2np[:, :, :, ::-1].copy()
176 | elif op == 'h':
177 | tfnp = v2np[:, :, ::-1, :].copy()
178 | elif op == 't':
179 | tfnp = v2np.transpose((0, 1, 3, 2)).copy()
180 |
181 | ret = torch.Tensor(tfnp).to(self.device)
182 | if self.precision == 'half': ret = ret.half()
183 |
184 | return ret
185 |
186 | list_x = []
187 | for a in args:
188 | x = [a]
189 | for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x])
190 |
191 | list_x.append(x)
192 |
193 | list_y = []
194 | for x in zip(*list_x):
195 | y = forward_function(*x)
196 | if not isinstance(y, list): y = [y]
197 | if not list_y:
198 | list_y = [[_y] for _y in y]
199 | else:
200 | for _list_y, _y in zip(list_y, y): _list_y.append(_y)
201 |
202 | for _list_y in list_y:
203 | for i in range(len(_list_y)):
204 | if i > 3:
205 | _list_y[i] = _transform(_list_y[i], 't')
206 | if i % 4 > 1:
207 | _list_y[i] = _transform(_list_y[i], 'h')
208 | if (i % 4) % 2 == 1:
209 | _list_y[i] = _transform(_list_y[i], 'v')
210 |
211 | y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y]
212 | if len(y) == 1: y = y[0]
213 |
214 | return y
215 |
--------------------------------------------------------------------------------
/src/model/common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from model import quantize
8 | import numpy as np
9 |
10 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
11 | return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), bias=bias)
12 |
13 | class ShortCut(nn.Module):
14 | def __init__(self):
15 | super(ShortCut, self).__init__()
16 |
17 | def forward(self, input):
18 | return input
19 |
20 | class MeanShift(nn.Conv2d):
21 | def __init__(
22 | self, rgb_range,
23 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
24 |
25 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
26 | std = torch.Tensor(rgb_std)
27 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
28 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
29 | for p in self.parameters():
30 | p.requires_grad = False
31 |
32 | class BasicBlock(nn.Sequential):
33 | def __init__(
34 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
35 | bn=True, act=nn.ReLU(True)):
36 |
37 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
38 | if bn:
39 | m.append(nn.BatchNorm2d(out_channels))
40 | if act is not None:
41 | m.append(act)
42 |
43 | super(BasicBlock, self).__init__(*m)
44 |
45 |
46 | class ResBlock(nn.Module):
47 | def __init__(self, args, conv, n_feats, kernel_size, bias=True, bn=False, act='relu', res_scale=1):
48 | super(ResBlock, self).__init__()
49 |
50 | self.args = args
51 |
52 | self.conv1 = quantize.QConv2d(args, n_feats, n_feats, kernel_size, bias=bias, non_adaptive=False)
53 | self.conv2 = quantize.QConv2d(args, n_feats, n_feats, kernel_size, bias=bias, non_adaptive=False)
54 |
55 | if act == 'relu':
56 | self.act = nn.ReLU(True)
57 | elif act== 'prelu':
58 | self.act = nn.PReLU()
59 |
60 | self.res_scale = res_scale
61 | self.is_bn = bn
62 | if bn:
63 | self.bn1 = nn.BatchNorm2d(n_feats)
64 | self.bn2 = nn.BatchNorm2d(n_feats)
65 | self.shortcut = ShortCut()
66 |
67 | def forward(self, x):
68 |         if self.args.imgwise:
69 |             bit_img = x[3]
70 |         bit = x[2]
71 |         feat = x[1]
72 |         x = x[0]
73 |
74 | residual = self.shortcut(x)
75 |
76 | if self.args.imgwise:
77 | out,bit = self.conv1([x, bit, bit_img])
78 | else:
79 | out,bit = self.conv1([x, bit])
80 |
81 | if self.is_bn:
82 | out = self.bn1(out)
83 |
84 | out1 = self.act(out)
85 |
86 | if self.args.imgwise:
87 | res, bit = self.conv2([out1, bit, bit_img])
88 | else:
89 | res, bit = self.conv2([out1,bit])
90 |
91 | if self.is_bn:
92 | res = self.bn2(res)
93 |
94 | res = res.mul(self.res_scale)
95 | res += residual
96 |
97 | if feat is None:
98 | feat = res / torch.norm(res, p=2)
99 | else:
100 | feat = torch.cat([feat, res / torch.norm(res, p=2)])
101 |
102 | if self.args.imgwise:
103 | return [res, feat, bit, bit_img]
104 | else:
105 | return [res, feat, bit]
106 |
107 |
108 | class Upsampler(nn.Sequential):
109 | def __init__(self, args, conv, scale, n_feats, bn=False, act=False, bias=True, fq=False):
110 | m = []
111 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
112 | for _ in range(int(math.log(scale, 2))):
113 | if fq:
114 | m.append(conv(args, n_feats, 4*n_feats, 3, bias=bias, non_adaptive=True, to_8bit=True))
115 | else:
116 | m.append(conv(n_feats, 4 * n_feats, 3, bias=bias))
117 |
118 | m.append(nn.PixelShuffle(2))
119 | if bn:
120 | m.append(nn.BatchNorm2d(n_feats))
121 | if act == 'relu':
122 | m.append(nn.ReLU(True))
123 | elif act == 'prelu':
124 | m.append(nn.PReLU())
125 |
126 | elif scale == 3:
127 | if fq:
128 | m.append(conv(args, n_feats, 9 * n_feats, 3, bias=bias, non_adaptive=True, to_8bit=True))
129 | else:
130 | m.append(conv(n_feats, 9 * n_feats, 3, bias=bias))
131 |
132 | m.append(nn.PixelShuffle(3))
133 | if bn:
134 | m.append(nn.BatchNorm2d(n_feats))
135 | if act == 'relu':
136 | m.append(nn.ReLU(True))
137 | elif act == 'prelu':
138 | m.append(nn.PReLU())
139 | else:
140 | raise NotImplementedError
141 |
142 | super(Upsampler, self).__init__(*m)
--------------------------------------------------------------------------------
/src/model/edsr.py:
--------------------------------------------------------------------------------
1 | from model import common
2 | from model import quantize
3 |
4 | import torch.nn as nn
5 | import torch
6 |
7 | import kornia as K
8 | import numpy as np
9 |
10 | import cv2
11 | import math
12 |
13 | def make_model(args, parent=False):
14 | return EDSR(args)
15 |
16 | class EDSR(nn.Module):
17 | def __init__(self, args, conv=common.default_conv):
18 | super(EDSR, self).__init__()
19 | n_feats = args.n_feats
20 | kernel_size = 3
21 | scale = args.scale[0]
22 | act = 'relu'
23 |
24 | self.sub_mean = common.MeanShift(args.rgb_range)
25 | self.add_mean = common.MeanShift(args.rgb_range, sign=1)
26 | self.fq = args.fq
27 |
28 |
29 | # Head module
30 | if args.fq:
31 | m_head = [quantize.QConv2d(args, args.n_colors, n_feats, kernel_size, bias=True, non_adaptive=True, to_8bit=True)]
32 | else:
33 | m_head = [conv(args.n_colors, n_feats, kernel_size)]
34 |
35 | # Body module
36 | m_body = [
37 | common.ResBlock(
38 | args, conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
39 | ) for _ in range(args.n_resblocks)
40 | ]
41 |
42 | if args.fq:
43 | m_body.append(quantize.QConv2d(args, n_feats, n_feats, kernel_size, bias=True))
44 | else:
45 | m_body.append(conv(n_feats, n_feats, kernel_size))
46 |
47 | # Tail module
48 | if args.fq:
49 | m_tail = [
50 | common.Upsampler(args, quantize.QConv2d, scale, n_feats, act=False, fq=args.fq),
51 | quantize.QConv2d(args, n_feats, args.n_colors, kernel_size, bias=True, non_adaptive=True, to_8bit=True)
52 | ]
53 | else:
54 | m_tail = [
55 | common.Upsampler(args, conv, scale, n_feats, act=False),
56 | conv(n_feats, args.n_colors, kernel_size)
57 | ]
58 |
59 | self.head = nn.Sequential(*m_head)
60 | self.body = nn.Sequential(*m_body)
61 | self.tail = nn.Sequential(*m_tail)
62 |
63 | if args.imgwise:
64 | self.measure_l = nn.Parameter(torch.FloatTensor([128]).cuda())
65 | self.measure_u = nn.Parameter(torch.FloatTensor([128]).cuda())
66 | self.tanh = nn.Tanh()
67 | self.ema_epoch = 1
68 | self.init = False
69 |
70 | self.args = args
71 |
72 |
73 | def forward(self, x):
74 | if self.args.imgwise:
75 | image = x
76 | grads: torch.Tensor = K.filters.spatial_gradient(K.color.rgb_to_grayscale(image/255.), order=1)
77 | image_grad = torch.mean(torch.abs(grads.squeeze(1)),(1,2,3)) *1e+3
78 |
79 | if self.init:
80 | # print(image_grad)
81 | if self.ema_epoch == 1:
82 | measure_l = torch.quantile(image_grad.detach(), self.args.img_percentile/100.0)
83 | measure_u = torch.quantile(image_grad.detach(), 1-self.args.img_percentile/100.0)
84 | nn.init.constant_(self.measure_l, measure_l)
85 | nn.init.constant_(self.measure_u, measure_u)
86 | else:
87 | beta = self.args.ema_beta
88 | new_measure_l = self.measure_l * beta + torch.quantile(image_grad.detach(), self.args.img_percentile/100.0) * (1-beta)
89 | new_measure_u = self.measure_u * beta + torch.quantile(image_grad.detach(), 1-self.args.img_percentile/100.0) * (1-beta)
90 | nn.init.constant_(self.measure_l, new_measure_l.item())
91 | nn.init.constant_(self.measure_u, new_measure_u.item())
92 |
93 | self.ema_epoch += 1
94 | bit_img = torch.Tensor([0.0]).cuda()
95 |
96 | else:
97 | bit_img_soft = (image_grad - (self.measure_u + self.measure_l)/2) * (2/(self.measure_u - self.measure_l))
98 | bit_img_soft = self.tanh(bit_img_soft)
99 | bit_img_hard = (image_grad < self.measure_l) * (-1.0) + (image_grad >= self.measure_l) * (image_grad <= self.measure_u) * (0.0) + (image_grad> self.measure_u) *(1.0)
100 | bit_img = bit_img_soft - bit_img_soft.detach() + bit_img_hard.detach() # the order matters
101 | bit_img = bit_img.view(bit_img.shape[0], 1, 1, 1)
102 |
103 | x = self.sub_mean(x)
104 |
105 | if self.args.fq:
106 | bit_fq = torch.zeros(x.shape[0]).cuda()
107 | x, bit_fq= self.head([x, bit_fq])
108 |
109 | else:
110 | x = self.head(x)
111 |
112 | feat = None
113 | bit = torch.zeros(x.shape[0]).cuda()
114 |
115 | res = x
116 |
117 | if self.args.imgwise:
118 | res, feat, bit, bit_img = self.body[:-1]([res, feat, bit, bit_img])
119 | else:
120 | res, feat, bit = self.body[:-1]([res, feat, bit])
121 |
122 |
123 | if self.args.fq:
124 | if self.args.imgwise:
125 | res, bit = self.body[-1:]([res, bit, bit_img])
126 | else:
127 | res, bit = self.body[-1:]([res, bit])
128 | else:
129 | res = self.body[-1:](res)
130 |
131 | res += x
132 |
133 | if self.args.fq:
134 | res, bit_fq = self.tail[0][0]([res, bit_fq]) # conv
135 | res = self.tail[0][1](res) # ps
136 | if len(self.tail[0]) == 4:
137 | res, bit_fq = self.tail[0][2]([res, bit_fq]) # conv
138 | res = self.tail[0][3](res) # ps
139 | x, bit_fq = self.tail[-1]([res, bit_fq]) # conv
140 | else:
141 | x = self.tail(res)
142 |
143 | x = self.add_mean(x)
144 |
145 |
146 | return x, feat, bit
147 |
148 |
149 |
--------------------------------------------------------------------------------
/src/model/quantize.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import time
10 |
11 | class Round(torch.autograd.Function):
12 | @staticmethod
13 | def forward(ctx, x_in):
14 | x_out = torch.round(x_in)
15 | return x_out
16 | @staticmethod
17 | def backward(ctx, g):
18 | return g, None
19 |
20 | class QConv2d(nn.Conv2d):
21 | def __init__(self, args, in_channels, out_channels, kernel_size, stride=1,
22 | padding=1, bias=False, dilation=1, groups=1, non_adaptive=False, to_8bit=False):
23 | super(QConv2d, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias, dilation=dilation, groups=groups)
24 | self.args = args
25 | if not bias: self.bias = None
26 | self.dilation = (dilation, dilation)
27 |
28 | # For quantizing activations
29 | self.lower_a = nn.Parameter(torch.FloatTensor([-128]).cuda())
30 | self.upper_a = nn.Parameter(torch.FloatTensor([128]).cuda())
31 | self.round_a = Round.apply
32 | self.a_bit = self.args.quantize_a
33 |
34 | # For quantizing weights
35 | self.upper_w = nn.Parameter(torch.FloatTensor([128]).cuda())
36 | self.round_w = Round.apply
37 | self.w_bit = self.args.quantize_w
38 |
39 | self.non_adaptive = non_adaptive
40 | self.to_8bit = to_8bit
41 |
42 | if self.to_8bit:
43 | self.w_bit = 8.0
44 | self.a_bit = 8.0
45 |
46 | if not self.non_adaptive and self.args.layerwise:
47 | self.std_layer = []
48 | self.measure_layer = nn.Parameter(torch.FloatTensor([0]).cuda())
49 | self.tanh = nn.Tanh()
50 |
51 | self.ema_epoch = 1
52 | self.bac_epoch = 1
53 | self.init = False
54 |
55 | def init_qparams_a(self, x, quantizer=None):
56 | # Obtain statistics
57 | if quantizer == 'minmax':
58 | lower_a = torch.min(x).detach().cpu()
59 | upper_a = torch.max(x).detach().cpu()
60 |
61 | elif quantizer == 'percentile':
62 | try:
63 | lower_a = torch.quantile(x.reshape(-1), 1.0-self.args.percentile_alpha).detach().cpu()
64 | upper_a = torch.quantile(x.reshape(-1), self.args.percentile_alpha).detach().cpu()
65 | except:
66 | lower_a = np.percentile(x.reshape(-1).detach().cpu(), (1.0-self.args.percentile_alpha)*100.0)
67 | upper_a = np.percentile(x.reshape(-1).detach().cpu(), self.args.percentile_alpha*100.0)
68 |
69 | elif quantizer == 'omse':
70 | lower_a = torch.min(x).detach()
71 | upper_a = torch.max(x).detach()
72 | best_score = 1e+10
73 | for i in range(90):
74 | new_lower = lower_a * (1.0 - i*0.01)
75 | new_upper = upper_a * (1.0 - i*0.01)
76 | x_q = torch.clamp(x, min= new_lower, max=new_upper)
77 | x_q = (x_q - new_lower) / (new_upper - new_lower)
78 | x_q = torch.round(x_q * (2**self.args.quantize_a -1)) / (2**self.args.quantize_a -1)
79 | x_q = x_q * (new_upper - new_lower) + new_lower
80 | score = (x - x_q).abs().pow(2.0).mean()
81 | if score < best_score:
82 | best_score = score
83 | best_lower = new_lower
84 | best_upper = new_upper
85 |             lower_a = best_lower.cpu()
86 |             upper_a = best_upper.cpu()
87 |
88 | # Update q params
89 | if self.ema_epoch == 1:
90 | nn.init.constant_(self.lower_a, lower_a)
91 | nn.init.constant_(self.upper_a, upper_a)
92 | else:
93 | beta = self.args.ema_beta
94 | lower_a = lower_a * (1-beta) + self.lower_a * beta
95 | upper_a = upper_a * (1-beta) + self.upper_a * beta
96 | nn.init.constant_(self.lower_a, lower_a.item())
97 | nn.init.constant_(self.upper_a, upper_a.item())
98 |
99 |
100 | def init_qparams_w(self, w, quantizer=None):
101 | if quantizer == 'minmax':
102 | upper_w = torch.max(torch.abs(self.weight)).detach()
103 | elif quantizer == 'percentile':
104 | try:
105 | upper_w = torch.quantile(self.weight.reshape(-1), self.args.percentile_alpha).detach().cpu()
106 | except:
107 | upper_w = np.percentile(self.weight.reshape(-1).detach().cpu(), self.args.percentile_alpha*100.0)
108 | elif quantizer == 'omse':
109 | upper_w = torch.max(self.weight).detach()
110 | best_score_w = 1e+10
111 | for i in range(50):
112 | new_upper_w = upper_w * (1.0 - i*0.01)
113 | w_q = torch.clamp(self.weight, min=-new_upper_w, max=new_upper_w).detach()
114 | w_q = (w_q + new_upper_w) / (2*new_upper_w )
115 | w_q = torch.round(w_q * (2**self.args.quantize_w -1)) / (2**self.args.quantize_w -1)
116 | w_q = w_q * (2*new_upper_w) - new_upper_w
117 | score = (self.weight - w_q).abs().pow(2.0).mean().detach().cpu()
118 | if score < best_score_w:
119 | best_score_w = score
120 | best_i = i
121 | upper = new_upper_w
122 | upper_w = upper.cpu()
123 |
124 | nn.init.constant_(self.upper_w, upper_w)
125 |
126 | def forward(self,x):
127 | if self.args.imgwise and not self.non_adaptive:
128 | bit_img = x[2]
129 | bit = x[1]
130 | x = x[0]
131 |
132 | if self.w_bit == 32:
133 | if self.init:
134 | # Initialize q params
135 | self.init_qparams_a(x, quantizer=self.args.quantizer)
136 | if self.ema_epoch == 1:
137 | self.init_qparams_w(self.weight, quantizer=self.args.quantizer_w)
138 |
139 | if not self.non_adaptive and self.args.layerwise:
140 | measure = lambda x: torch.mean(torch.std(x.detach(), dim=(1,2,3,)))
141 | measure_layer = measure(x)
142 | self.std_layer.append(measure_layer.detach().cpu().numpy())
143 |
144 | self.ema_epoch += 1
145 |
146 | a_bit = torch.Tensor([32.0]).cuda()
147 | w = self.weight
148 |
149 | else:
150 | # Obtain bit-width
151 | if not self.non_adaptive and (self.args.imgwise or self.args.layerwise):
152 | a_bit = self.a_bit
153 | if self.args.imgwise:
154 | a_bit += bit_img
155 | if self.args.layerwise:
156 | bit_layer_hard = torch.round(torch.clamp(self.measure_layer, min=-1.0, max=1.0))
157 | bit_layer_soft = self.tanh(self.measure_layer)
158 | bit_layer = bit_layer_soft - bit_layer_soft.detach() + bit_layer_hard.detach()
159 | a_bit += bit_layer
160 | else:
161 | a_bit = torch.tensor(self.a_bit).repeat(x.shape[0], 1, 1, 1).cuda()
162 |
163 | # Bit-aware Clipping
164 | if self.args.bac:
165 | do_bac = self.bac_epoch == 1
166 | # Do BaC after init phase ends
167 | if do_bac:
168 | self.bac_epoch += 1
169 | if self.training and not self.to_8bit:
170 | best_score = 1e+10
171 | lower_a = self.lower_a
172 | upper_a = self.upper_a
173 |
174 | for i in range(100):
175 | new_lower_a = self.lower_a * (1.0 - i*0.01)
176 | new_upper_a = self.upper_a * (1.0 - i*0.01)
177 | x_q_temp = torch.clamp(x.clone().detach(), min= new_lower_a, max=new_upper_a)
178 | x_q_temp = (x_q_temp - new_lower_a) / (new_upper_a - new_lower_a)
179 | if not self.non_adaptive and self.args.layerwise:
180 | x_q_temp = torch.round(x_q_temp * (2**(self.a_bit+bit_layer_hard) -1)) / (2**(self.a_bit+bit_layer_hard) -1)
181 | else:
182 | x_q_temp = torch.round(x_q_temp * (2**(self.a_bit) -1)) / (2**(self.a_bit) -1)
183 |
184 | x_q_temp = x_q_temp * (new_upper_a - new_lower_a) + new_lower_a
185 | score = (x.clone().detach() - x_q_temp).abs().pow(2.0).mean()
186 | if score < best_score:
187 | best_i = i
188 | best_score = score
189 | lower_a = new_lower_a
190 | upper_a = new_upper_a
191 |
192 | new_lower = self.lower_a * self.args.bac_beta + lower_a * (1-self.args.bac_beta)
193 | new_upper = self.upper_a * self.args.bac_beta + upper_a * (1-self.args.bac_beta)
194 | nn.init.constant_(self.lower_a, new_lower.item())
195 | nn.init.constant_(self.upper_a, new_upper.item())
196 |
197 | # Quantize activations
198 | x_c = torch.clamp(x, min=self.lower_a, max=self.upper_a)
199 | x_c2 = (x_c - self.lower_a) / (self.upper_a - self.lower_a)
200 | x_c2 = x_c2 * (2**a_bit -1)
201 | x_int = self.round_a(x_c2)
202 | x_int = x_int / (2**a_bit -1)
203 | x_q = x_int * (self.upper_a - self.lower_a) + self.lower_a
204 | x = x_q
205 |
206 | # Quantize weights
207 | w_c = torch.clamp(self.weight, min=-self.upper_w, max=self.upper_w)
208 | w_c2 = (w_c + self.upper_w) / (2*self.upper_w )
209 | w_c2 = (w_c2) * (2**self.w_bit-1)
210 | w_int = self.round_w(w_c2)
211 | w_int = w_int / (2**self.w_bit-1)
212 | w_q = w_int * (2*self.upper_w) - self.upper_w
213 | w = w_q
214 |
215 | self.padding = (self.kernel_size[0]//2, self.kernel_size[1]//2)
216 | out = F.conv2d(x, w, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
217 | bit += a_bit.view(-1)
218 |
219 | return out, bit
--------------------------------------------------------------------------------
/src/model/rdn.py:
--------------------------------------------------------------------------------
1 | from model import common
2 | from model import quantize
3 | import torch
4 | import torch.nn as nn
5 | import kornia as K
6 |
7 | def make_model(args, parent=False):
8 | return RDN(args)
9 |
10 | class RDB_Conv(nn.Module):
11 | def __init__(self, args, inChannels, growRate, kSize=3):
12 | super(RDB_Conv, self).__init__()
13 | Cin = inChannels
14 | G = growRate
15 | bias = True
16 | self.args = args
17 | self.conv = nn.Sequential(*[
18 | quantize.QConv2d(args, Cin, G, kSize, stride=1, padding=(kSize-1)//2, bias=True, dilation=1, groups=1, non_adaptive=False),
19 | nn.ReLU()
20 | ])
21 |
22 | def forward(self, x):
23 | if self.args.imgwise:
24 | bit_img = x[2]
25 | bit = x[1]
26 | x = x[0]
27 | out = x
28 |
29 | if self.args.imgwise:
30 | out, bit = self.conv[0]([out, bit, bit_img])
31 | else:
32 | out, bit = self.conv[0]([out, bit])
33 |
34 | out = self.conv[1](out)
35 | out = torch.cat((x, out), 1)
36 |
37 | if self.args.imgwise:
38 | return [out, bit, bit_img]
39 | else:
40 | return [out, bit]
41 |
42 |
43 | class RDB(nn.Module):
44 | def __init__(self, args, growRate0, growRate, nConvLayers, kSize=3):
45 | super(RDB, self).__init__()
46 | G0 = growRate0
47 | G = growRate
48 | C = nConvLayers
49 |
50 | convs = []
51 | for c in range(C):
52 | convs.append(RDB_Conv(args, G0 + c*G, G, kSize))
53 | self.convs = nn.Sequential(*convs)
54 |
55 | # Local Feature Fusion
56 | self.LFF = quantize.QConv2d(args, G0 + C*G, G0, 1, padding=0, stride=1, bias=True, non_adaptive=True, to_8bit=False) # 1x1 conv is non-adaptively quantized
57 | self.args = args
58 |
59 | def forward(self, x):
60 | if self.args.imgwise:
61 | bit_img = x[3]
62 | bit = x[2]
63 | feat = x[1]
64 | x = x[0]
65 |
66 | out = x
67 | if self.args.imgwise:
68 | out, bit, bit_img = self.convs([out, bit, bit_img])
69 | else:
70 | out, bit = self.convs([out, bit])
71 |
72 | out, bit = self.LFF([out, bit])
73 | out += x
74 |
75 | feat_ = out / torch.norm(out, p=2) / (out.shape[1]*out.shape[2]*out.shape[3])
76 | if feat is None:
77 | feat = feat_
78 | else:
79 | feat = torch.cat([feat, feat_])
80 |
81 | if self.args.imgwise:
82 | return out, feat, bit, bit_img
83 | else:
84 | return out, feat, bit
85 |
86 | class RDN(nn.Module):
87 | def __init__(self, args):
88 | super(RDN, self).__init__()
89 | r = args.scale[0]
90 | G0 = args.G0
91 | kSize = args.RDNkSize
92 |
93 | self.args = args
94 |
95 | # number of RDB blocks, conv layers, out channels
96 | self.D, self.C, self.G = {
97 | 'A': (20, 6, 32),
98 | 'B': (16, 8, 64),
99 | }[args.RDNconfig]
100 |
101 | # Shallow feature extraction net
102 | if args.fq:
103 | self.SFENet1 = quantize.QConv2d(args, args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1, bias=True, non_adaptive=True, to_8bit=True)
104 | self.SFENet2 = quantize.QConv2d(args, G0, G0, kSize, padding=(kSize-1)//2, stride=1, bias=True, non_adaptive=True, to_8bit=True)
105 | else:
106 | self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)
107 | self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
108 |
109 |         # Residual dense blocks and dense feature fusion
110 | self.RDBs = nn.ModuleList()
111 | for i in range(self.D):
112 | self.RDBs.append(
113 | RDB(args, growRate0 = G0, growRate = self.G, nConvLayers = self.C)
114 | )
115 |
116 | # Global Feature Fusion
117 | self.GFF = nn.Sequential(*[
118 | quantize.QConv2d(args, self.D * G0, G0, 1, padding=0, stride=1, bias=True),
119 | quantize.QConv2d(args, G0, G0, kSize, padding=(kSize-1)//2, stride=1, bias=True)
120 | ])
121 |
122 | # Up-sampling net
123 | if args.fq:
124 | if r == 2 or r == 3:
125 | self.UPNet = nn.Sequential(*[
126 | quantize.QConv2d(args, G0, self.G * r * r, kSize, padding=(kSize-1)//2, stride=1, bias=True, non_adaptive=True, to_8bit=True),
127 | nn.PixelShuffle(r),
128 | quantize.QConv2d(args, self.G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1, bias=True, non_adaptive=True, to_8bit=True)
129 | ])
130 | elif r == 4:
131 | self.UPNet = nn.Sequential(*[
132 | quantize.QConv2d(args, G0, self.G * 4, kSize, padding=(kSize-1)//2, stride=1, bias=True, non_adaptive=True, to_8bit=True),
133 | nn.PixelShuffle(2),
134 | quantize.QConv2d(args, self.G, self.G * 4, kSize, padding=(kSize-1)//2, stride=1, bias=True, non_adaptive=True, to_8bit=True),
135 | nn.PixelShuffle(2),
136 | quantize.QConv2d(args, self.G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1, bias=True, non_adaptive=True, to_8bit=True)
137 | ])
138 | else:
139 | raise ValueError("scale must be 2 or 3 or 4.")
140 |
141 | else:
142 | if r == 2 or r == 3:
143 | self.UPNet = nn.Sequential(*[
144 | nn.Conv2d(G0, self.G * r * r, kSize, padding=(kSize-1)//2, stride=1),
145 | nn.PixelShuffle(r),
146 | nn.Conv2d(self.G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)
147 | ])
148 | elif r == 4:
149 | self.UPNet = nn.Sequential(*[
150 | nn.Conv2d(G0, self.G * 4, kSize, padding=(kSize-1)//2, stride=1),
151 | nn.PixelShuffle(2),
152 | nn.Conv2d(self.G, self.G * 4, kSize, padding=(kSize-1)//2, stride=1),
153 | nn.PixelShuffle(2),
154 | nn.Conv2d(self.G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)
155 | ])
156 | else:
157 | raise ValueError("scale must be 2 or 3 or 4.")
158 |
159 | if args.imgwise:
160 | self.measure_l = nn.Parameter(torch.FloatTensor([128]).cuda())
161 | self.measure_u = nn.Parameter(torch.FloatTensor([128]).cuda())
162 | self.tanh = nn.Tanh()
163 | self.ema_epoch = 1
164 | self.init = False
165 |
166 | def forward(self, x):
167 | if self.args.imgwise:
168 | image = x
169 | grads: torch.Tensor = K.filters.spatial_gradient(K.color.rgb_to_grayscale(image/255.), order=1)
170 | image_grad = torch.mean(torch.abs(grads.squeeze(1)),(1,2,3)) *1e+3
171 |
172 | if self.init:
173 | # print(image_grad)
174 | if self.ema_epoch == 1:
175 | measure_l = torch.quantile(image_grad.detach(), self.args.img_percentile/100.0)
176 | measure_u = torch.quantile(image_grad.detach(), 1-self.args.img_percentile/100.0)
177 | nn.init.constant_(self.measure_l, measure_l)
178 | nn.init.constant_(self.measure_u, measure_u)
179 | else:
180 | beta = self.args.ema_beta
181 | new_measure_l = self.measure_l * beta + torch.quantile(image_grad.detach(), self.args.img_percentile/100.0) * (1-beta)
182 | new_measure_u = self.measure_u * beta + torch.quantile(image_grad.detach(), 1-self.args.img_percentile/100.0) * (1-beta)
183 | nn.init.constant_(self.measure_l, new_measure_l.item())
184 | nn.init.constant_(self.measure_u, new_measure_u.item())
185 |
186 | self.ema_epoch += 1
187 | bit_img = torch.Tensor([0.0]).cuda()
188 | else:
189 | bit_img_soft = (image_grad - (self.measure_u + self.measure_l)/2) * (2/(self.measure_u - self.measure_l))
190 | bit_img_soft = self.tanh(bit_img_soft)
191 | bit_img_hard = (image_grad < self.measure_l) * (-1.0) + (image_grad >= self.measure_l) * (image_grad <= self.measure_u) * (0.0) + (image_grad > self.measure_u) * (1.0)
192 | bit_img = bit_img_soft - bit_img_soft.detach() + bit_img_hard.detach()  # straight-through estimator: hard value in the forward pass, soft (tanh) gradient in the backward pass; the order matters
193 | bit_img = bit_img.view(bit_img.shape[0], 1, 1, 1)
194 |
195 | feat = None; bit = torch.zeros(x.shape[0]).cuda(); bit_fq = torch.zeros(x.shape[0]).cuda()
196 |
197 | if self.args.fq:
198 | f__1, bit_fq = self.SFENet1([x, bit_fq])
199 | x, bit_fq = self.SFENet2([f__1, bit_fq])
200 | else:
201 | f__1 = self.SFENet1(x)
202 | x = self.SFENet2(f__1)
203 |
204 | RDBs_out = []
205 | for i in range(self.D):
206 | if self.args.imgwise:
207 | x, feat, bit, bit_img = self.RDBs[i]([x, feat, bit, bit_img])
208 | else:
209 | x, feat, bit = self.RDBs[i]([x, feat, bit])
210 | RDBs_out.append(x)
211 |
212 | if self.args.imgwise:
213 | x, bit = self.GFF[0]([torch.cat(RDBs_out,1), bit, bit_img])
214 | x, bit = self.GFF[1]([x, bit, bit_img])
215 | else:
216 | x, bit = self.GFF[0]([torch.cat(RDBs_out,1), bit])
217 | x, bit = self.GFF[1]([x, bit])
218 |
219 | x += f__1
220 |
221 | if self.args.fq:
222 | out, bit_fq = self.UPNet[0]([x, bit_fq])
223 | out = self.UPNet[1](out)
224 | out, bit_fq = self.UPNet[2]([out, bit_fq])
225 | if len(self.UPNet) > 3:
226 | out = self.UPNet[3](out)
227 | out, bit_fq = self.UPNet[4]([out, bit_fq])
228 | else:
229 | out = self.UPNet(x)
230 |
231 | return out, feat, bit
232 |
--------------------------------------------------------------------------------
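
The image-wise branch of `RDN.forward` (and the identical block in `SRResNet.forward` below) turns the mean spatial-gradient magnitude of each input into a bit offset in {-1, 0, +1}: during calibration it sets `measure_l`/`measure_u` to EMA-smoothed quantiles of that statistic, and afterwards it combines a hard threshold with a tanh-based soft value through a straight-through estimator so the thresholds remain trainable. Below is a minimal, torch-only sketch of that decision; the finite-difference gradient and the free-standing `measure_l`/`measure_u` tensors are illustrative stand-ins for the kornia-based gradient and the module's parameters, not the repository's exact code.

```
import torch

def image_bit_offset(x, measure_l, measure_u):
    # x: (B, 3, H, W) in [0, 255]; measure_l / measure_u: learnable scalar thresholds
    gray = x.mean(dim=1, keepdim=True) / 255.0                 # stand-in for rgb_to_grayscale
    gx = gray[..., :, 1:] - gray[..., :, :-1]                  # horizontal finite differences
    gy = gray[..., 1:, :] - gray[..., :-1, :]                  # vertical finite differences
    image_grad = (gx.abs().mean((1, 2, 3)) + gy.abs().mean((1, 2, 3))) * 1e3

    # soft decision: differentiable, lets measure_l / measure_u keep learning
    soft = torch.tanh((image_grad - (measure_u + measure_l) / 2) * (2 / (measure_u - measure_l)))
    # hard decision: -1 (flat image), 0 (in between), +1 (detailed image)
    hard = (image_grad < measure_l) * (-1.0) + (image_grad > measure_u) * (1.0)
    # straight-through estimator: hard value forward, soft gradient backward
    return (soft - soft.detach() + hard).view(-1, 1, 1, 1)

flat = torch.full((1, 3, 48, 48), 128.0)     # low-detail image -> offset -1 (fewer bits)
noisy = torch.rand(1, 3, 48, 48) * 255.0     # high-detail image -> offset +1 (more bits)
l = torch.tensor(5.0, requires_grad=True)
u = torch.tensor(20.0, requires_grad=True)

bit_img = image_bit_offset(torch.cat([flat, noisy]), l, u)
print(bit_img.view(-1))                      # -1. for the flat image, +1. for the noisy one
bit_img.sum().backward()                     # gradients reach l and u through the soft path
```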
/src/model/srresnet.py:
--------------------------------------------------------------------------------
1 | from model import common
2 | import torch
3 | import torch.nn as nn
4 | import math
5 | from model import quantize
6 | import kornia as K
7 |
8 | def make_model(args, parent=False):
9 | return SRResNet(args)
10 |
11 | class SRResNet(nn.Module):
12 | def __init__(self, args, conv=common.default_conv):
13 | super(SRResNet, self).__init__()
14 |
15 | n_resblocks = args.n_resblocks
16 | n_feats = args.n_feats
17 | kernel_size = 3
18 | scale = args.scale[0]
19 |
20 | # Head module
21 | self.fq = args.fq
22 | if args.fq:
23 | m_head = [quantize.QConv2d(args, args.n_colors, n_feats, kernel_size=9, bias=False, non_adaptive=True, to_8bit=True)]
24 | else:
25 | m_head = [conv(args.n_colors, n_feats, kernel_size=9, bias=False)]
26 | m_head.append(nn.PReLU())
27 |
28 | # Body module
29 | act = 'prelu'
30 | m_body = [
31 | common.ResBlock(
32 | args, conv, n_feats, kernel_size, bias=False, bn=True, act=act, res_scale=args.res_scale
33 | ) for _ in range(n_resblocks)
34 | ]
35 |
36 | if args.fq:
37 | m_body.append(quantize.QConv2d(args, n_feats, n_feats, kernel_size, bias=False))
38 | else:
39 | m_body.append(conv(n_feats, n_feats, kernel_size, bias=False))
40 | m_body.append(nn.BatchNorm2d(n_feats))
41 |
42 | # Tail module
43 | if args.fq:
44 | m_tail = [
45 | common.Upsampler(args, quantize.QConv2d, scale, n_feats, act=act, fq=args.fq, bias=False),
46 | quantize.QConv2d(args, n_feats, args.n_colors, kernel_size=9, bias=False, non_adaptive=True, to_8bit=True)
47 | ]
48 | else:
49 | m_tail = [
50 | common.Upsampler(args, conv, scale, n_feats, act=act, bias=False),
51 | conv(n_feats, args.n_colors, kernel_size=9, bias=False)
52 | ]
53 |
54 | self.head = nn.Sequential(*m_head)
55 | self.body = nn.Sequential(*m_body)
56 | self.tail = nn.Sequential(*m_tail)
57 |
58 | if args.imgwise:
59 | self.measure_l = nn.Parameter(torch.FloatTensor([128]).cuda())
60 | self.measure_u = nn.Parameter(torch.FloatTensor([128]).cuda())
61 | self.tanh = nn.Tanh()
62 | self.ema_epoch = 1
63 | self.init = False
64 |
65 | self.args = args
66 |
67 |
68 | for m in self.modules():
69 | if isinstance(m, nn.Conv2d):
70 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
71 | m.weight.data.normal_(0, math.sqrt(2. / n))
72 | if m.bias is not None:
73 | m.bias.data.zero_()
74 |
75 | def forward(self, x):
76 | if self.args.imgwise:
77 | image = x
78 | grads: torch.Tensor = K.filters.spatial_gradient(K.color.rgb_to_grayscale(image/255.), order=1)
79 | image_grad = torch.mean(torch.abs(grads.squeeze(1)),(1,2,3)) *1e+3
80 |
81 | if self.init:
82 | # print(image_grad)
83 | if self.ema_epoch == 1:
84 | measure_l = torch.quantile(image_grad.detach(), self.args.img_percentile/100.0)
85 | measure_u = torch.quantile(image_grad.detach(), 1-self.args.img_percentile/100.0)
86 | nn.init.constant_(self.measure_l, measure_l)
87 | nn.init.constant_(self.measure_u, measure_u)
88 | else:
89 | beta = self.args.ema_beta
90 | new_measure_l = self.measure_l * beta + torch.quantile(image_grad.detach(), self.args.img_percentile/100.0) * (1-beta)
91 | new_measure_u = self.measure_u * beta + torch.quantile(image_grad.detach(), 1-self.args.img_percentile/100.0) * (1-beta)
92 | nn.init.constant_(self.measure_l, new_measure_l.item())
93 | nn.init.constant_(self.measure_u, new_measure_u.item())
94 |
95 | self.ema_epoch += 1
96 | bit_img = torch.Tensor([0.0]).cuda()
97 |
98 | else:
99 | bit_img_soft = (image_grad - (self.measure_u + self.measure_l)/2) * (2/(self.measure_u - self.measure_l))
100 | bit_img_soft = self.tanh(bit_img_soft)
101 | bit_img_hard = (image_grad < self.measure_l) * (-1.0) + (image_grad >= self.measure_l) * (image_grad <= self.measure_u) * (0.0) + (image_grad > self.measure_u) * (1.0)
102 | bit_img = bit_img_soft - bit_img_soft.detach() + bit_img_hard.detach()  # straight-through estimator: hard value in the forward pass, soft (tanh) gradient in the backward pass; the order matters
103 | bit_img = bit_img.view(bit_img.shape[0], 1, 1, 1)
104 |
105 | bit_fq = torch.zeros(x.shape[0]).cuda()
106 |
107 | if self.fq:
108 | x, bit_fq = self.head[0]([x, bit_fq])
109 | x = self.head[1:](x)
110 | else:
111 | x = self.head(x)
112 |
113 | feat = None
114 | bit = torch.zeros(x.shape[0]).cuda()
115 |
116 | res = x
117 |
118 |
119 | if self.args.imgwise:
120 | res, feat, bit, bit_img = self.body[:-2]([res, feat, bit, bit_img])
121 | else:
122 | res, feat, bit = self.body[:-2]([res, feat, bit])
123 |
124 |
125 | if self.fq:
126 | if self.args.imgwise:
127 | res, bit = self.body[-2]([res, bit, bit_img])
128 | else:
129 | res, bit = self.body[-2]([res, bit])
130 |
131 | res = self.body[-1](res)
132 | else:
133 | res = self.body[-2:](res)
134 |
135 |
136 | res += x
137 |
138 | if self.fq:
139 | res, bit_fq = self.tail[0][0]([res, bit_fq]) # conv
140 | res = self.tail[0][1](res) # PS
141 | res = self.tail[0][2](res) # prelu
142 | res1 = res
143 |
144 | if len(self.tail[0]) > 3: # the x4 upsampler has a second (conv, PS, prelu) stage; the x2/x3 upsampler stops after the first
145 | res, bit_fq = self.tail[0][3]([res, bit_fq]) # conv
146 | res = self.tail[0][4](res) # PS
147 | res = self.tail[0][5](res) # prelu
148 | res2 = res
149 | x, bit_fq = self.tail[-1]([res, bit_fq]) # conv
150 | else:
151 | x = self.tail(res)
152 |
153 | return x, feat, bit
154 |
155 |
--------------------------------------------------------------------------------
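
A convention worth noting in both `RDN` and `SRResNet`: quantized layers are chained with `nn.Sequential`, which forwards only a single argument, so the models pack the activation together with the running bit count (plus the per-image offset when `--imgwise` is set) into a list, and every quantized module unpacks and repacks that list. A toy sketch of the pattern is below; `ToyQConv` is an illustrative stand-in, not the repository's `quantize.QConv2d`.

```
import torch
import torch.nn as nn

class ToyQConv(nn.Module):
    """Stand-in for a quantized conv that also accumulates a per-image bit count."""
    def __init__(self, c_in, c_out, bits=4.0):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, 3, padding=1)
        self.bits = bits

    def forward(self, inputs):
        x, bit = inputs                  # unpack [activation, running bit count]
        bit = bit + self.bits            # pretend every call costs `self.bits` bits
        return [self.conv(x), bit]       # repack so nn.Sequential can keep chaining

body = nn.Sequential(ToyQConv(3, 8), ToyQConv(8, 8), ToyQConv(8, 3))

x = torch.randn(2, 3, 16, 16)
bit = torch.zeros(x.shape[0])            # one running counter per image in the batch
out, bit = body([x, bit])
print(out.shape, bit)                    # bit == 12.0 per image (3 layers x 4 bits each)
```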
/src/option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser(description='EDSR and MDSR')
4 |
5 |
6 | parser.add_argument('--debug', action='store_true',
7 | help='Enables debug mode')
8 | parser.add_argument('--template', default='.',
9 | help='You can set various templates in option.py')
10 |
11 | # Hardware specifications
12 | parser.add_argument('--n_threads', type=int, default=6,
13 | help='number of threads for data loading')
14 | parser.add_argument('--cpu', action='store_true',
15 | help='use cpu only')
16 | parser.add_argument('--n_GPUs', type=int, default=1,
17 | help='number of GPUs')
18 | parser.add_argument('--seed', type=int, default=1,
19 | help='random seed')
20 |
21 | # Data specifications
22 | parser.add_argument('--dir_data', type=str, default='../dataset',
23 | help='dataset directory')
24 | parser.add_argument('--dir_demo', type=str, default='../test',
25 | help='demo image directory')
26 | parser.add_argument('--data_train', type=str, default='DIV2K',
27 | help='train dataset name')
28 | parser.add_argument('--data_test', type=str, default='DIV2K',
29 | help='test dataset name')
30 | parser.add_argument('--data_range', type=str, default='1-800/801-810',
31 | help='train/test data range')
32 | parser.add_argument('--ext', type=str, default='sep',
33 | help='dataset file extension')
34 | parser.add_argument('--scale', type=str, default='4',
35 | help='super resolution scale')
36 | parser.add_argument('--patch_size', type=int, default=192,
37 | help='output patch size')
38 | parser.add_argument('--rgb_range', type=int, default=255,
39 | help='maximum value of RGB')
40 | parser.add_argument('--n_colors', type=int, default=3,
41 | help='number of color channels to use')
42 | parser.add_argument('--chop', action='store_true',
43 | help='enable memory-efficient forward')
44 | parser.add_argument('--no_augment', action='store_true',
45 | help='do not use data augmentation')
46 |
47 | # Model specifications
48 | parser.add_argument('--model', default='EDSR',
49 | help='model name')
50 | parser.add_argument('--act', type=str, default='relu',
51 | help='activation function')
52 | parser.add_argument('--pre_train', type=str, default='',
53 | help='pre-trained model directory')
54 | parser.add_argument('--extend', type=str, default='.',
55 | help='pre-trained model directory')
56 | parser.add_argument('--n_resblocks', type=int, default=16,
57 | help='number of residual blocks')
58 | parser.add_argument('--n_feats', type=int, default=64,
59 | help='number of feature maps')
60 | parser.add_argument('--res_scale', type=float, default=1,
61 | help='residual scaling')
62 | parser.add_argument('--shift_mean', default=True,
63 | help='subtract pixel mean from the input')
64 | parser.add_argument('--dilation', action='store_true',
65 | help='use dilated convolution')
66 | parser.add_argument('--precision', type=str, default='single',
67 | choices=('single', 'half'),
68 | help='FP precision for test (single | half)')
69 |
70 | # Option for Residual dense network (RDN)
71 | parser.add_argument('--G0', type=int, default=64,
72 | help='default number of filters. (Use in RDN)')
73 | parser.add_argument('--RDNkSize', type=int, default=3,
74 | help='default kernel size. (Use in RDN)')
75 | parser.add_argument('--RDNconfig', type=str, default='B',
76 | help='parameters config of RDN. (Use in RDN)')
77 |
78 |
79 | # Training specifications
80 | parser.add_argument('--reset', action='store_true',
81 | help='reset the training')
82 | parser.add_argument('--test_every', type=int, default=1000,
83 | help='run the test every N training batches')
84 | # this also determines the number of training iterations per epoch
85 | parser.add_argument('--epochs', type=int, default=300,
86 | help='number of epochs to train')
87 | parser.add_argument('--batch_size', type=int, default=16,
88 | help='input batch size for training')
89 | parser.add_argument('--split_batch', type=int, default=1,
90 | help='split the batch into smaller chunks')
91 | parser.add_argument('--self_ensemble', action='store_true',
92 | help='use self-ensemble method for test')
93 | parser.add_argument('--test_only', action='store_true',
94 | help='set this option to test the model')
95 | parser.add_argument('--gan_k', type=int, default=1,
96 | help='k value for adversarial loss')
97 |
98 | # Optimization specifications
99 | parser.add_argument('--step', type=int, default='1',
100 | help='learning rate step size')
101 | parser.add_argument('--gamma', type=float, default=0.9,
102 | help='learning rate decay factor for step decay')
103 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
104 | help='ADAM beta')
105 | parser.add_argument('--epsilon', type=float, default=1e-8,
106 | help='ADAM epsilon for numerical stability')
107 |
108 | # Log specifications
109 | parser.add_argument('--save', type=str, default='test',
110 | help='file name to save')
111 | parser.add_argument('--load', type=str, default='',
112 | help='file name to load')
113 | parser.add_argument('--resume', type=int, default=0,
114 | help='resume from specific checkpoint')
115 | parser.add_argument('--save_models', action='store_true',
116 | help='save all intermediate models')
117 | parser.add_argument('--print_every', type=int, default=100,
118 | help='how many batches to wait before logging training status')
119 | parser.add_argument('--save_results', action='store_true',
120 | help='save output results')
121 | parser.add_argument('--save_gt', action='store_true',
122 | help='save low-resolution and high-resolution images together')
123 |
124 | parser.add_argument('--quantize_a', type=float, default=32, help='activation_bit')
125 | parser.add_argument('--quantize_w', type=float, default=32, help='weight_bit')
126 |
127 | parser.add_argument('--batch_size_calib', type=int, default=16, help='input batch size for calib')
128 | parser.add_argument('--batch_size_update', type=int, default=2, help='input batch size for update')
129 | parser.add_argument('--num_data', type=int, default=800, help='number of data for PTQ')
130 |
131 | parser.add_argument('--fq', action='store_true', help='fully quantize')
132 | parser.add_argument('--imgwise', action='store_true', help='imgwise')
133 | parser.add_argument('--layerwise', action='store_true', help='layerwise')
134 | parser.add_argument('--bac', action='store_true', help='bac')
135 |
136 | parser.add_argument('--lr_w', type=float, default=0.001, help='learning rate')
137 | parser.add_argument('--lr_a', type=float, default=0.05, help='learning rate')
138 | parser.add_argument('--lr_measure_layer', type=float, default=0.005, help='learning rate')
139 | parser.add_argument('--lr_measure_img', type=float, default=0.01, help='learning rate')
140 | parser.add_argument('--w_bitloss', type=float, default=50.0, help='weight for bit loss')
141 | parser.add_argument('--w_sktloss', type=float, default=10.0, help='weight for skt loss')
142 |
143 | parser.add_argument('--test_patch', action='store_true', help='test patch')
144 | parser.add_argument('--test_patch_size', type=int, default=96, help='test patch size')
145 | parser.add_argument('--test_step_size', type=int, default=90, help='test step size')
146 | parser.add_argument('--test_own', type=str, default=None, help='directory for own test image')
147 | parser.add_argument('--n_parallel', type=int, default=1, help='number of patches for parallel processing')
148 |
149 | parser.add_argument('--quantizer', default='minmax', choices=('minmax', 'percentile', 'omse'), help='activation quantizer to use')
150 | parser.add_argument('--quantizer_w', default='minmax', choices=('minmax', 'percentile', 'omse'), help='weight quantizer to use')
151 | parser.add_argument('--percentile_alpha', type=float, default=0.99, help='used when quantizer is percentile')
152 |
153 | parser.add_argument('--ema_beta', type=float, default=0.9, help='beta for EMA')
154 | parser.add_argument('--bac_beta', type=float, default=0.5, help='beta for EMA in BaC')
155 |
156 | parser.add_argument('--img_percentile', type=float, default=10.0, help='percentile (0-100) used to set the image-wise lower/upper measures (l, u)')
157 | parser.add_argument('--layer_percentile', type=float, default=30.0, help='percentile (0-100) used to set the layer-wise lower/upper measures (l, u)')
158 |
159 | args = parser.parse_args()
160 |
161 | args.scale = list(map(lambda x: int(x), args.scale.split('+')))
162 | args.data_train = args.data_train.split('+')
163 | args.data_test = args.data_test.split('+')
164 |
165 | if args.epochs == 0:
166 | args.epochs = 1e8
167 |
168 | for arg in vars(args):
169 | if vars(args)[arg] == 'True':
170 | vars(args)[arg] = True
171 | elif vars(args)[arg] == 'False':
172 | vars(args)[arg] = False
173 |
174 |
--------------------------------------------------------------------------------
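
For reference, the post-processing at the bottom of `option.py` turns the '+'-separated CLI strings into Python lists (and normalizes literal 'True'/'False' strings into booleans). A tiny illustration of the resulting values; the inputs shown are examples, not the defaults.

```
# illustrative only: mirrors the post-processing at the end of option.py
scale = '2+4'
data_test = 'Set5+Urban100'

scale = list(map(int, scale.split('+')))    # '2+4' -> [2, 4]
data_test = data_test.split('+')            # 'Set5+Urban100' -> ['Set5', 'Urban100']
print(scale, data_test)                     # [2, 4] ['Set5', 'Urban100']
```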
/src/run.sh:
--------------------------------------------------------------------------------
1 | DATA_DIR=/mnt/disk1/cheeun914/datasets/
2 | scale=4
3 |
4 | edsr() {
5 | CUDA_VISIBLE_DEVICES=$1 python main.py \
6 | --model EDSR --scale $scale \
7 | --n_feats 64 --n_resblocks 16 --res_scale 1.0 \
8 | --pre_train ../pretrained_model/edsr_baseline_x$scale.pt \
9 | --epochs 10 --test_every 50 --print_every 10 \
10 | --batch_size_update 2 --batch_size_calib 16 --num_data 100 --patch_size 384 \
11 | --data_test Set5 --dir_data $DATA_DIR \
12 | --quantize_a $2 --quantize_w $3 \
13 | --quantizer 'minmax' --ema_beta 0.9 --quantizer_w 'omse' \
14 | --lr_w 0.01 --lr_a 0.01 --lr_measure_img 0.1 --lr_measure_layer 0.01 \
15 | --w_bitloss 50.0 --w_sktloss 10.0 --imgwise --layerwise --bac \
16 | --img_percentile 10.0 --layer_percentile 30.0 \
17 | --save edsrbaseline_x$scale/w$3a$2-adabm-nonfq \
18 | --seed 1 \
19 | #
20 | }
21 |
22 | edsr_fq() {
23 | CUDA_VISIBLE_DEVICES=$1 python main.py \
24 | --model EDSR --scale $scale \
25 | --n_feats 64 --n_resblocks 16 --res_scale 1.0 \
26 | --pre_train ../pretrained_model/edsr_baseline_x$scale.pt \
27 | --epochs 10 --test_every 50 --print_every 10 \
28 | --batch_size_update 2 --batch_size_calib 16 --num_data 100 --patch_size 384 \
29 | --data_test Set5+Urban100 --dir_data $DATA_DIR \
30 | --quantize_a $2 --quantize_w $3 \
31 | --quantizer 'minmax' --ema_beta 0.9 --quantizer_w 'omse' \
32 | --lr_w 0.01 --lr_a 0.01 --lr_measure_img 0.1 --lr_measure_layer 0.01 \
33 | --w_bitloss 50.0 --w_sktloss 10.0 --imgwise --layerwise --bac \
34 | --img_percentile 10.0 --layer_percentile 30.0 \
35 | --save edsrbaseline_x$scale/w$3a$2-adabm-fq \
36 | --seed 1 \
37 | --fq \
38 | #
39 | }
40 |
41 | edsr_eval() {
42 | CUDA_VISIBLE_DEVICES=$1 python main.py \
43 | --model EDSR --scale $scale \
44 | --n_feats 64 --n_resblocks 16 --res_scale 1.0 \
45 | --data_test Urban100+test2k+test4k --dir_data $DATA_DIR \
46 | --quantize_a $2 --quantize_w $3 \
47 | --imgwise --layerwise \
48 | --test_only \
49 | --save edsrbaseline_x$scale/w$3a$2-adabm-nonfq \
50 | --test_patch --test_patch_size 96 --test_step_size 96 \
51 | --pre_train ../pretrained_model/edsr_baseline_x$scale-w$3a$2-adabm-nonfq.pt \
52 | # --pre_train ../experiment/edsrbaseline_x$scale/w$3a$2-adabm-nonfq/model/checkpoint.pt \
53 | # --save_results \
54 | }
55 |
56 | edsr_fq_eval() {
57 | CUDA_VISIBLE_DEVICES=$1 python main.py \
58 | --model EDSR --scale $scale \
59 | --n_feats 64 --n_resblocks 16 --res_scale 1.0 \
60 | --data_test Set5+Set14+B100+Urban100 --dir_data $DATA_DIR \
61 | --quantize_a $2 --quantize_w $3 \
62 | --imgwise --layerwise \
63 | --fq \
64 | --test_only \
65 | --save edsrbaseline_x$scale/w$3a$2-adabm-fq-test \
66 | --pre_train ../pretrained_model/edsr_baseline_x$scale-w$3a$2-adabm-fq.pt \
67 | # --pre_train ../experiment/edsrbaseline_x$scale/w$3a$2-adabm-fq/model/checkpoint.pt \
68 | # --save_results \
69 | }
70 |
71 | edsr_fq_eval_own() {
72 | CUDA_VISIBLE_DEVICES=$1 python main.py \
73 | --model EDSR --scale $scale \
74 | --n_feats 64 --n_resblocks 16 --res_scale 1.0 \
75 | --data_test Set5 --dir_data $DATA_DIR \
76 | --quantize_a $2 --quantize_w $3 \
77 | --imgwise --layerwise \
78 | --test_only \
79 | --save edsrbaseline_x$scale/w$3a$2-adabm-fq-test \
80 | --fq \
81 | --test_patch --test_patch_size 96 --test_step_size 96 \
82 | --test_own '/dir/to/own/test/img' \
83 | --pre_train ../pretrained_model/edsr_baseline_x$scale-w$3a$2-adabm-fq.pt \
84 | # --pre_train ../experiment/edsrbaseline_x$scale/w$3a$2-adabm-fq/model/checkpoint.pt \
85 | # --save_results \
86 | }
87 |
88 | rdn_fq(){
89 | CUDA_VISIBLE_DEVICES=$1 python main.py \
90 | --model RDN --scale $scale \
91 | --pre_train ../pretrained_model/rdn_baseline_x$scale.pt \
92 | --epochs 10 --test_every 50 --print_every 10 \
93 | --batch_size_update 2 --batch_size_calib 16 --num_data 100 --patch_size 288 \
94 | --data_test Set5 --dir_data $DATA_DIR \
95 | --quantize_a $2 --quantize_w $3 \
96 | --quantizer 'minmax' --ema_beta 0.9 --quantizer_w 'omse' \
97 | --lr_w 0.01 --lr_a 0.01 --lr_measure_img 0.1 --lr_measure_layer 0.01 \
98 | --w_bitloss 50.0 --imgwise --layerwise --bac \
99 | --img_percentile 10.0 --layer_percentile 30.0 \
100 | --save rdn_x$scale/w$3a$2-adabm-fq \
101 | --seed 1 \
102 | --fq \
103 | #
104 | }
105 |
106 | rdn_fq_eval(){
107 | CUDA_VISIBLE_DEVICES=$1 python main.py \
108 | --model RDN --scale $scale \
109 | --data_test Set5+Set14+B100+Urban100 --dir_data $DATA_DIR \
110 | --quantize_a $2 --quantize_w $3 \
111 | --imgwise --layerwise \
112 | --test_only \
113 | --fq \
114 | --save rdn_x$scale/w$3a$2-adabm-fq-test \
115 | --pre_train ../pretrained_model/rdn_x$scale-w$3a$2-adabm-fq.pt \
116 | # --pre_train ../experiment/rdn_x$scale/w$3a$2-adabm-fq/model/checkpoint.pt \
117 | # --save_results \
118 | }
119 |
120 | srresnet() {
121 | CUDA_VISIBLE_DEVICES=$1 python main.py \
122 | --model SRResNet --scale $scale \
123 | --n_feats 64 --n_resblocks 16 --res_scale 1.0 \
124 | --pre_train ../pretrained_model/bnsrresnet_x$scale.pt \
125 | --epochs 10 --test_every 50 --print_every 10 \
126 | --batch_size_update 2 --batch_size_calib 16 --num_data 100 --patch_size 384 \
127 | --data_test Set5 --dir_data $DATA_DIR \
128 | --quantize_a $2 --quantize_w $3 \
129 | --quantizer 'minmax' --ema_beta 0.9 --quantizer_w 'omse' \
130 | --lr_w 0.01 --lr_a 0.01 --lr_measure_img 0.1 --lr_measure_layer 0.01 \
131 | --w_bitloss 50.0 --imgwise --layerwise --bac \
132 | --img_percentile 10.0 --layer_percentile 30.0 \
133 | --save srresnet_x$scale/w$3a$2-adabm-nonfq \
134 | --seed 1 \
135 | #
136 | }
137 |
138 | srresnet_eval() {
139 | CUDA_VISIBLE_DEVICES=$1 python main.py \
140 | --model SRResNet --scale $scale \
141 | --n_feats 64 --n_resblocks 16 --res_scale 1.0 \
142 | --data_test Urban100 --dir_data $DATA_DIR \
143 | --quantize_a $2 --quantize_w $3 \
144 | --imgwise --layerwise \
145 | --test_only \
146 | --save srresnet_x$scale/w$3a$2-adabm-nonfq-test \
147 | --test_patch --test_patch_size 96 --test_step_size 96 \
148 | --pre_train ../pretrained_model/srresnet_x$scale-w$3a$2-adabm-nonfq.pt \
149 | # --pre_train ../experiment/srresnet_x$scale/w$3a$2-adabm-nonfq/model/checkpoint.pt \
150 | # --save_results \
151 | }
152 |
153 | "$@"
154 |
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import sys
4 | import math
5 | import time
6 | import datetime
7 | import shutil
8 |
9 | import numpy as np
10 |
11 | import torch
12 | import torch.nn.utils as utils
13 | from torch.utils.data import Dataset
14 | import torch.nn.functional as F
15 | import torch.optim as optim
16 | import torch.optim.lr_scheduler as lrs
17 | from torchvision.utils import save_image
18 |
19 | from decimal import Decimal
20 | from tqdm import tqdm
21 | import cv2
22 |
23 | import utility
24 | from model.quantize import QConv2d
25 | import kornia as K
26 |
27 | class Trainer():
28 | def __init__(self, args, loader, my_model, ckp):
29 | self.args = args
30 | self.scale = args.scale
31 | self.ckp = ckp
32 | self.loader_init = loader.loader_init
33 | self.loader_train = loader.loader_train
34 | self.loader_test = loader.loader_test
35 |
36 | self.model = my_model
37 | self.epoch = 0
38 |
39 | shutil.copyfile('./trainer.py', os.path.join(self.ckp.dir, 'trainer.py'))
40 | shutil.copyfile('./model/quantize.py', os.path.join(self.ckp.dir, 'quantize.py'))
41 |
42 | quant_params_a = [v for k, v in self.model.model.named_parameters() if '_a' in k]
43 | quant_params_w = [v for k, v in self.model.model.named_parameters() if '_w' in k]
44 |
45 | if args.layerwise or args.imgwise:
46 | quant_params_measure= []
47 | if args.layerwise:
48 | quant_params_measure_layer = [v for k, v in self.model.model.named_parameters() if 'measure_layer' in k]
49 | quant_params_measure.append({'params': quant_params_measure_layer, 'lr': args.lr_measure_layer})
50 | if args.imgwise:
51 | quant_params_measure_image = [v for k, v in self.model.model.named_parameters() if 'measure' in k and 'measure_layer' not in k]
52 | quant_params_measure.append({'params': quant_params_measure_image, 'lr': args.lr_measure_img})
53 |
54 | self.optimizer_measure = torch.optim.Adam(quant_params_measure, betas=args.betas, eps=args.epsilon)
55 | self.scheduler_measure = lrs.StepLR(self.optimizer_measure, step_size=args.step, gamma=args.gamma)
56 |
57 | self.optimizer_a = torch.optim.Adam(quant_params_a, lr=args.lr_a, betas=args.betas, eps=args.epsilon)
58 | self.optimizer_w = torch.optim.Adam(quant_params_w, lr=args.lr_w, betas=args.betas, eps=args.epsilon)
59 | self.scheduler_a = lrs.StepLR(self.optimizer_a, step_size=args.step, gamma=args.gamma)
60 | self.scheduler_w = lrs.StepLR(self.optimizer_w, step_size=args.step, gamma=args.gamma)
61 |
62 | self.skt_losses = utility.AverageMeter()
63 | self.pix_losses = utility.AverageMeter()
64 | self.bit_losses = utility.AverageMeter()
65 |
66 | self.num_quant_modules = 0
67 | for n, m in self.model.named_modules():
68 | if isinstance(m, QConv2d):
69 | if not m.to_8bit: # 8-bit (first or last) modules are excluded for the bit count
70 | self.num_quant_modules +=1
71 |
72 | # for initialization
73 | if not args.test_only:
74 | for n, m in self.model.named_modules():
75 | if isinstance(m, QConv2d):
76 | setattr(m, 'w_bit', 32.0)
77 | setattr(m, 'a_bit', 32.0)
78 | setattr(m, 'init', True)
79 | if args.imgwise:
80 | setattr(self.model.model, 'init', True)
81 |
82 | def get_stage_optimizer_scheduler(self):
83 | if self.args.imgwise or self.args.layerwise:
84 | # w -> a -> measure
85 | if (self.epoch-1) % 3 == 0:
86 | param_name = '_w'
87 | optimizer = self.optimizer_w
88 | scheduler = self.scheduler_w
89 | elif (self.epoch-1) % 3 == 1:
90 | param_name = '_a'
91 | optimizer = self.optimizer_a
92 | scheduler = self.scheduler_a
93 | else:
94 | param_name = 'measure'
95 | optimizer = self.optimizer_measure
96 | scheduler = self.scheduler_measure
97 | else:
98 | if (self.epoch-1) % 2 == 0:
99 | param_name = '_w'
100 | optimizer = self.optimizer_w
101 | scheduler = self.scheduler_w
102 | else:
103 | param_name = '_a'
104 | optimizer = self.optimizer_a
105 | scheduler = self.scheduler_a
106 |
107 | return param_name, optimizer, scheduler
108 |
109 | def set_bit(self, teacher=False):
110 | for n, m in self.model.named_modules():
111 | if isinstance(m, QConv2d):
112 | if teacher:
113 | setattr(m, 'w_bit', 32.0)
114 | setattr(m, 'a_bit', 32.0)
115 | elif m.non_adaptive:
116 | if m.to_8bit:
117 | setattr(m, 'w_bit', 8.0)
118 | setattr(m, 'a_bit', 8.0)
119 | else:
120 | setattr(m, 'w_bit', self.args.quantize_w)
121 | setattr(m, 'a_bit', self.args.quantize_a)
122 | else:
123 | setattr(m, 'w_bit', self.args.quantize_w)
124 | setattr(m, 'a_bit', self.args.quantize_a)
125 |
126 | setattr(m, 'init', False)
127 |
128 | if self.args.imgwise:
129 | setattr(self.model.model, 'init', False)
130 |
131 |
132 | def train(self):
133 | if self.epoch > 0:
134 | param_name, optimizer, scheduler = self.get_stage_optimizer_scheduler()
135 | epoch_update = 'Update param ' + param_name
136 | lr = optimizer.state_dict()['param_groups'][0]['lr']
137 | self.ckp.write_log(
138 | '\n[Epoch {}]\t {}\t Learning rate for param: {:.2e}'.format(
139 | self.epoch,
140 | epoch_update,
141 | Decimal(lr))
142 | )
143 |
144 | self.model.train()
145 |
146 | timer_data, timer_model = utility.timer(), utility.timer()
147 | start_time = time.time()
148 |
149 | if self.epoch == 0:
150 | # Initialize quantization parameters using the frozen FP model
151 | params = self.model.named_parameters()
152 | for name1, params1 in params:
153 | params1.requires_grad=False
154 |
155 | for batch, (lr, _, idx_scale,) in enumerate(self.loader_init):
156 | lr, = self.prepare(lr)
157 |
158 | timer_data.hold()
159 | timer_model.tic()
160 | torch.cuda.empty_cache()
161 | with torch.no_grad():
162 | sr_temp, feat_temp, bit_temp = self.model(lr, idx_scale)
163 | display_bit = bit_temp.mean() / self.num_quant_modules
164 |
165 | if self.args.imgwise or self.args.layerwise:
166 | self.ckp.write_log('[{}/{}] [bit:{:.2f}] \t{:.1f}+{:.1f}s'.format(
167 | (batch + 1) * self.loader_init.batch_size,
168 | len(self.loader_init.dataset),
169 | display_bit,
170 | timer_model.release(),
171 | timer_data.release(),
172 | ))
173 |
174 | if self.args.layerwise:
175 | measure_layer_list = []
176 | for n, m in self.model.named_modules():
177 | if isinstance(m, QConv2d) and not m.non_adaptive:
178 | if hasattr(m, 'std_layer') and len(m.std_layer) > 0:
179 | measure_layer_list.append(np.mean(m.std_layer))
180 | mu = np.mean(measure_layer_list)
181 | sig = np.std(measure_layer_list)
182 |
183 | lower = np.percentile(measure_layer_list, self.args.layer_percentile)
184 | upper = np.percentile(measure_layer_list, 100.0-self.args.layer_percentile)
185 |
186 | for n, m in self.model.named_modules():
187 | if isinstance(m, QConv2d):
188 | if not m.non_adaptive:
189 | normalized_measure = (np.mean(m.std_layer) < lower) * (-1.0) + (np.mean(m.std_layer) > upper) * (1.0)
190 | torch.nn.init.constant_(m.measure_layer, normalized_measure)
191 | m.std_layer.clear()
192 |
193 | print('Calibration done!')
194 | # self.set_bit(teacher=False)
195 |
196 | if self.args.imgwise:
197 | print('image-lower:{:.3f}, image-upper:{:.3f}'.format(
198 | self.model.model.measure_l.data.item(),
199 | self.model.model.measure_u.data.item(),
200 | ))
201 | if self.args.layerwise:
202 | bit_layer_list = []
203 | for n, m in self.model.named_modules():
204 | if isinstance(m, QConv2d):
205 | if not m.non_adaptive:
206 | bit_layer_list.append(int(m.measure_layer.data.item()))
207 | print(bit_layer_list)
208 | print(np.mean(bit_layer_list))
209 | else:
210 | # Update quantization parameters
211 | for k, v in self.model.named_parameters():
212 | if param_name in k:
213 | v.requires_grad=True
214 | else:
215 | v.requires_grad=False
216 |
217 | self.bit_losses.reset()
218 | self.pix_losses.reset()
219 | self.skt_losses.reset()
220 |
221 | for batch, (lr, _, idx_scale,) in enumerate(self.loader_train):
222 | lr, = self.prepare(lr)
223 | timer_data.hold()
224 | timer_model.tic()
225 |
226 | optimizer.zero_grad()
227 |
228 | with torch.no_grad():
229 | self.set_bit(teacher=True)
230 | sr_t, feat_t, bit_t = self.model(lr, idx_scale)
231 |
232 | self.set_bit(teacher=False)
233 | sr, feat, bit = self.model(lr, idx_scale)
234 |
235 | loss = 0.0
236 | pix_loss = F.l1_loss(sr, sr_t)
237 |
238 | loss += pix_loss
239 |
240 | skt_loss = 0.0
241 | for block in range(len(feat)):
242 | skt_loss += self.args.w_sktloss * torch.norm(feat[block]-feat_t[block], p=2) / sr.shape[0] / len(feat)
243 | loss += skt_loss
244 |
245 | if self.args.layerwise or self.args.imgwise:
246 | average_bit = bit.mean() / self.num_quant_modules
247 | bit_loss = self.args.w_bitloss * torch.max(average_bit - self.args.quantize_a, torch.zeros_like(average_bit))
248 | if param_name == 'measure':
249 | loss += bit_loss
250 |
251 | loss.backward()
252 |
253 | self.pix_losses.update(pix_loss.item(), lr.size(0))
254 | display_pix_loss = f'L_pix: {self.pix_losses.avg: .3f}'
255 | self.skt_losses.update(skt_loss.item(), lr.size(0))
256 | display_skt_loss = f'L_skt: {self.skt_losses.avg: .3f}'
257 |
258 | if self.args.layerwise or self.args.imgwise:
259 | self.bit_losses.update(bit_loss.item(), lr.size(0))
260 | display_bit_loss = f'L_bit: {self.bit_losses.avg: .3f}'
261 |
262 | optimizer.step()
263 | timer_model.hold()
264 |
265 | if (batch + 1) % self.args.print_every == 0:
266 | display_bit = bit.mean() / self.num_quant_modules
267 | if self.args.layerwise or self.args.imgwise:
268 | self.ckp.write_log('[{}/{}]\t{} \t{} \t{} [bit:{:.2f}] \t{:.1f}+{:.1f}s'.format(
269 | (batch + 1) * self.loader_train.batch_size,
270 | len(self.loader_train.dataset),
271 | display_pix_loss,
272 | display_skt_loss,
273 | display_bit_loss,
274 | display_bit,
275 | timer_model.release(),
276 | timer_data.release(),
277 | ))
278 | else:
279 | self.ckp.write_log('[{}/{}]\t{} \t{} [bit:{:.2f}] \t{:.1f}+{:.1f}s'.format(
280 | (batch + 1) * self.loader_train.batch_size,
281 | len(self.loader_train.dataset),
282 | display_pix_loss,
283 | display_skt_loss,
284 | display_bit,
285 | timer_model.release(),
286 | timer_data.release(),
287 | ))
288 | timer_data.tic()
289 |
290 | scheduler.step()
291 |
292 | self.epoch += 1
293 |
294 | end_time = time.time()
295 | time_interval = end_time - start_time
296 | t_string = "Epoch Running Time is: " + str(datetime.timedelta(seconds=time_interval)) + "\n"
297 | self.ckp.write_log('{}'.format(t_string))
298 |
299 | def patch_inference(self, model, lr, idx_scale):
300 | patch_idx = 0
301 | tot_bit_image = 0
302 | if self.args.n_parallel!=1:
303 | lr_list, num_h, num_w, h, w = utility.crop_parallel(lr, self.args.test_patch_size, self.args.test_step_size)
304 | sr_list = torch.Tensor().cuda()
305 | for lr_sub_index in range(len(lr_list)// self.args.n_parallel + 1):
306 | torch.cuda.empty_cache()
307 | with torch.no_grad():
308 | sr_sub, feat, bit = self.model(lr_list[lr_sub_index* self.args.n_parallel: (lr_sub_index+1)*self.args.n_parallel], idx_scale)
309 | sr_sub = utility.quantize(sr_sub, self.args.rgb_range)
310 | sr_list = torch.cat([sr_list, sr_sub])
311 | average_bit = bit.mean() / self.num_quant_modules
312 | tot_bit_image += average_bit
313 | patch_idx += 1
314 | sr = utility.combine(sr_list, num_h, num_w, h, w, self.args.test_patch_size, self.args.test_step_size, self.scale[0])
315 | else:
316 | lr_list, num_h, num_w, h, w = utility.crop(lr, self.args.test_patch_size, self.args.test_step_size)
317 | sr_list = []
318 | for lr_sub_img in lr_list:
319 | torch.cuda.empty_cache()
320 | with torch.no_grad():
321 | sr_sub, feat, bit = self.model(lr_sub_img, idx_scale)
322 | sr_sub = utility.quantize(sr_sub, self.args.rgb_range)
323 | sr_list.append(sr_sub)
324 | average_bit = bit.mean() / self.num_quant_modules
325 | tot_bit_image += average_bit
326 | patch_idx += 1
327 | sr = utility.combine(sr_list, num_h, num_w, h, w, self.args.test_patch_size, self.args.test_step_size, self.scale[0])
328 |
329 | bit = tot_bit_image / patch_idx
330 |
331 | return sr, feat, bit
332 |
333 | def test(self):
334 | torch.set_grad_enabled(False)
335 |
336 | # if True:
337 | if self.epoch > 1 or self.args.test_only:
338 | self.ckp.write_log('\nEvaluation:')
339 | self.ckp.add_log(
340 | torch.zeros(1, len(self.loader_test), len(self.scale))
341 | )
342 | self.model.eval()
343 | timer_test = utility.timer()
344 |
345 | if self.epoch == 2 or self.args.test_only:
346 | ################### Num of Params, Storage Size ####################
347 | n_params = 0
348 | n_params_q = 0
349 | for k, v in self.model.named_parameters():
350 | nn = np.prod(v.size())
351 | n_params += nn
352 |
353 | if 'weight' in k:
354 | name_split = k.split(".")
355 | del name_split[-1]
356 | module_temp = self.model
357 | for n in name_split:
358 | module_temp = getattr(module_temp, n)
359 | if isinstance(module_temp, QConv2d):
360 | n_params_q += nn * module_temp.w_bit / 32.0
361 | # print(k, module_temp.w_bit)
362 | else:
363 | n_params_q += nn
364 | else:
365 | n_params_q += nn
366 |
367 | self.ckp.write_log('Parameters: {:.3f}K'.format(n_params/(10**3)))
368 | self.ckp.write_log('Model Size: {:.3f}K'.format(n_params_q/(10**3)))
369 |
370 | if self.args.save_results:
371 | self.ckp.begin_background()
372 |
373 | ############################## TEST FOR OWN #############################
374 | if self.args.test_own is not None:
375 | test_img = cv2.imread(self.args.test_own)
376 | lr = torch.tensor(test_img).permute(2,0,1).float().cuda()
377 | lr = torch.flip(lr, (0,)) # for color
378 | lr = lr.unsqueeze(0)
379 |
380 | tot_bit = 0
381 | for idx_scale, scale in enumerate(self.scale):
382 | if self.args.test_patch:
383 | sr, feat, bit = self.patch_inference(self.model, lr, idx_scale)
384 | img_bit = bit
385 | else:
386 | with torch.no_grad():
387 | sr, feat, bit = self.model(lr, idx_scale)
388 | img_bit = bit.mean() / self.num_quant_modules
389 |
390 | sr = utility.quantize(sr, self.args.rgb_range)
391 | save_list = [sr]
392 |
393 |
394 | filename = self.args.test_own.split('/')[-1].split('.')[0]
395 | if self.args.save_results:
396 | save_name = '{}_x{}_{:.2f}bit'.format(filename, scale, img_bit)
397 | self.ckp.save_results('test_own', save_name, save_list)
398 |
399 | self.ckp.write_log('[{} x{}] Average Bit: {:.2f} '.format(filename, scale, img_bit))
400 |
401 | ############################## TEST FOR TEST SET #############################
402 | if self.args.test_own is None:
403 | for idx_data, d in enumerate(self.loader_test):
404 | for idx_scale, scale in enumerate(self.scale):
405 | d.dataset.set_scale(idx_scale)
406 | tot_ssim =0
407 | tot_bit =0
408 | i=0
409 | bitops =0
410 | for lr, hr, filename in tqdm(d, ncols=80):
411 | i+=1
412 | lr, hr = self.prepare(lr, hr)
413 |
414 | if self.args.test_patch:
415 | sr, feat, bit = self.patch_inference(self.model, lr, idx_scale)
416 | if self.args.n_parallel != 1: hr = hr[:, :, :sr.shape[2], :sr.shape[3]] # crop HR to the SR output size (parallel cropping drops the remaining borders)
417 | img_bit = bit.item()
418 | else:
419 | with torch.no_grad():
420 | sr, feat, bit = self.model(lr, idx_scale)
421 | img_bit = bit.mean().item() / self.num_quant_modules
422 |
423 |
424 | sr = utility.quantize(sr, self.args.rgb_range)
425 | save_list = [sr]
426 |
427 | psnr, ssim = utility.calc_psnr(sr, hr, scale, self.args.rgb_range, dataset=d)
428 |
429 | self.ckp.ssim_log[-1, idx_data, idx_scale] += ssim
430 | self.ckp.log[-1, idx_data, idx_scale] += psnr
431 | self.ckp.bit_log[-1, idx_data, idx_scale] += img_bit
432 |
433 | if self.args.save_gt:
434 | save_list.extend([lr, hr])
435 |
436 | if self.args.save_results:
437 | save_name = '{}_x{}_{:.2f}dB_{:.2f}bit'.format(filename[0], scale, psnr, img_bit)
438 | self.ckp.save_results(d, save_name, save_list, scale)
439 |
440 | self.ckp.log[-1, idx_data, idx_scale] /= len(d)
441 | self.ckp.ssim_log[-1, idx_data, idx_scale] /= len(d)
442 | self.ckp.bit_log[-1, idx_data, idx_scale] /= len(d)
443 |
444 | best = self.ckp.log.max(0)
445 | self.ckp.write_log(
446 | '[{} x{}]\tPSNR: {:.3f} \t SSIM: {:.4f} \tBit: {:.2f} \t(Best: {:.3f} @epoch {})'.format(
447 | d.dataset.name,
448 | scale,
449 | self.ckp.log[-1, idx_data, idx_scale],
450 | self.ckp.ssim_log[-1, idx_data, idx_scale],
451 | self.ckp.bit_log[-1, idx_data, idx_scale],
452 | best[0][idx_data, idx_scale],
453 | best[1][idx_data, idx_scale] + 1
454 | )
455 | )
456 |
457 | if self.args.save_results:
458 | self.ckp.end_background()
459 |
460 | # save models
461 | if not self.args.test_only:
462 | self.ckp.save(self, self.epoch, is_best=(best[1][0, 0] + 1 == self.epoch -1))
463 |
464 | torch.set_grad_enabled(True)
465 |
466 | def test_teacher(self):
467 | torch.set_grad_enabled(False)
468 | self.model.eval()
469 | self.ckp.write_log('Teacher Evaluation')
470 |
471 | ############################## Num of Params ####################
472 | n_params = 0
473 | for k, v in self.model.named_parameters():
474 | if '_a' not in k and '_w' not in k and 'measure' not in k: # for teacher model
475 | n_params += np.prod(v.size())
476 | self.ckp.write_log('Parameters: {:.3f}K'.format(n_params/(10**3)))
477 |
478 | if self.args.save_results:
479 | self.ckp.begin_background()
480 |
481 | ############################## TEST FOR OWN #############################
482 | if self.args.test_own is not None:
483 | test_img = cv2.imread(self.args.test_own)
484 | lr = torch.tensor(test_img).permute(2,0,1).float().cuda()
485 | lr = torch.flip(lr, (0,)) # for color
486 | lr = lr.unsqueeze(0)
487 |
488 | tot_bit = 0
489 | for idx_scale, scale in enumerate(self.scale):
490 | self.set_bit(teacher=True)
491 | if self.args.test_patch:
492 | sr, feat, bit = self.patch_inference(self.model, lr, idx_scale)
493 | img_bit = bit
494 | else:
495 | with torch.no_grad():
496 | sr, feat, bit = self.model(lr, idx_scale)
497 | img_bit = bit.mean() / self.num_quant_modules
498 | self.set_bit(teacher=False)
499 |
500 | sr = utility.quantize(sr, self.args.rgb_range)
501 | save_list = [sr]
502 | if self.args.save_results:
503 | filename = self.args.test_own.split('/')[-1].split('.')[0]
504 | save_name = '{}_x{}_{:.2f}bit'.format(filename, scale, img_bit)
505 | self.ckp.save_results('test_own', save_name, save_list)
506 |
507 | ############################## TEST FOR TEST SET #############################
508 | if self.args.test_own is None:
509 | for idx_data, d in enumerate(self.loader_test):
510 | for idx_scale, scale in enumerate(self.scale):
511 | d.dataset.set_scale(idx_scale)
512 | tot_ssim =0
513 | tot_bit =0
514 | tot_psnr =0.0
515 | i=0
516 | for lr, hr, filename in tqdm(d, ncols=80):
517 | i+=1
518 | lr, hr = self.prepare(lr, hr)
519 | self.set_bit(teacher=True)
520 | if self.args.test_patch:
521 | sr, feat, bit = self.patch_inference(self.model, lr, idx_scale)
522 | img_bit = bit
523 | else:
524 | with torch.no_grad():
525 | sr, feat, bit = self.model(lr, idx_scale)
526 | img_bit = bit.mean() / self.num_quant_modules
527 | self.set_bit(teacher=False)
528 |
529 | sr = utility.quantize(sr, self.args.rgb_range)
530 | save_list = [sr]
531 | psnr, ssim = utility.calc_psnr(sr, hr, scale, self.args.rgb_range, dataset=d)
532 |
533 | tot_bit += img_bit
534 | tot_psnr += psnr
535 | tot_ssim += ssim
536 |
537 | if self.args.save_gt:
538 | save_list.extend([lr, hr])
539 |
540 | if self.args.save_results:
541 | save_name = '{}_x{}_{:.2f}dB'.format(filename[0], scale, psnr)
542 | self.ckp.save_results(d, save_name, save_list)
543 |
544 | tot_psnr /= len(d)
545 | tot_ssim /= len(d)
546 | tot_bit /= len(d)
547 |
548 | self.ckp.write_log(
549 | '[{} x{}]\tPSNR: {:.3f} \t SSIM: {:.4f} \tBit: {:.2f}'.format(
550 | d.dataset.name,
551 | scale,
552 | tot_psnr,
553 | tot_ssim,
554 | tot_bit.item(),
555 | )
556 | )
557 |
558 | if self.args.save_results:
559 | self.ckp.end_background()
560 |
561 | torch.set_grad_enabled(True)
562 |
563 |
564 |
565 | def prepare(self, *args):
566 | device = torch.device('cpu' if self.args.cpu else 'cuda')
567 | def _prepare(tensor):
568 | if self.args.precision == 'half': tensor = tensor.half()
569 | return tensor.to(device)
570 |
571 | return [_prepare(a) for a in args]
572 |
573 | def terminate(self):
574 | if self.args.test_only:
575 | self.test()
576 | return True
577 | else:
578 | # return self.epoch >= self.args.epochs
579 | return self.epoch > self.args.epochs
580 |
--------------------------------------------------------------------------------
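
`get_stage_optimizer_scheduler` rotates which group of quantization parameters is updated each epoch: weight step sizes ('_w'), then activation step sizes ('_a'), then the image/layer-wise measures, while `train()` freezes everything else via `requires_grad`. A compact sketch of that round-robin schedule; the toy parameter names and learning rates below are illustrative, not the repository's values.

```
import torch
import torch.nn as nn

# toy parameters that mimic the '_w' / '_a' / 'measure' naming used by the quantized model
params = nn.ParameterDict({
    'alpha_w': nn.Parameter(torch.tensor(1.0)),         # weight quantizer step size
    'alpha_a': nn.Parameter(torch.tensor(1.0)),         # activation quantizer step size
    'measure_layer': nn.Parameter(torch.tensor(0.0)),   # layer-wise bit measure
})

optimizers = {
    '_w': torch.optim.Adam([params['alpha_w']], lr=1e-2),
    '_a': torch.optim.Adam([params['alpha_a']], lr=1e-2),
    'measure': torch.optim.Adam([params['measure_layer']], lr=1e-2),
}

stages = ['_w', '_a', 'measure']                         # w -> a -> measure, one group per epoch
for epoch in range(1, 7):
    stage = stages[(epoch - 1) % 3]
    for name, p in params.items():                       # freeze everything except the current group
        p.requires_grad = stage in name
    print(f'epoch {epoch}: updating parameters matching "{stage}"')
    # ... forward teacher/student, compute the losses, then optimizers[stage].step() ...
```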
/src/utility.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import datetime
5 | from multiprocessing import Process
6 | from multiprocessing import Queue
7 |
8 | import matplotlib
9 | matplotlib.use('Agg')
10 | import matplotlib.pyplot as plt
11 |
12 | import numpy as np
13 | import imageio
14 |
15 | import torch
16 | import torch.optim as optim
17 | import torch.optim.lr_scheduler as lrs
18 |
19 | from pathlib import Path
20 | import shutil
21 | import torch.nn as nn
22 | import torch.nn.functional as F
23 | import logging
24 | import coloredlogs
25 | import cv2
26 | import functools
27 | from torchvision.utils import make_grid
28 | from decimal import Decimal
29 | from math import exp
30 |
31 |
32 | class AverageMeter(object):
33 | def __init__(self):
34 | self.val = 0
35 | self.avg = 0
36 | self.sum = 0
37 | self.count = 0
38 |
39 | def reset(self):
40 | self.val = 0
41 | self.avg = 0
42 | self.sum = 0
43 | self.count = 0
44 |
45 | def update(self, val, n=1):
46 | self.val = val
47 | self.sum += val * n
48 | self.count += n
49 | if self.count > 0:
50 | self.avg = self.sum / self.count
51 |
52 | def accumulate(self, val, n=1):
53 | self.sum += val
54 | self.count += n
55 | if self.count > 0:
56 | self.avg = self.sum / self.count
57 |
58 | class timer():
59 | def __init__(self):
60 | self.acc = 0
61 | self.tic()
62 |
63 | def tic(self):
64 | self.t0 = time.time()
65 |
66 | def toc(self, restart=False):
67 | diff = time.time() - self.t0
68 | if restart: self.t0 = time.time()
69 | return diff
70 |
71 | def hold(self):
72 | self.acc += self.toc()
73 |
74 | def release(self):
75 | ret = self.acc
76 | self.acc = 0
77 |
78 | return ret
79 |
80 | def reset(self):
81 | self.acc = 0
82 |
83 | class checkpoint():
84 | def __init__(self, args):
85 | self.args = args
86 | self.ok = True
87 | self.log = torch.Tensor()
88 | self.ssim_log = torch.Tensor()
89 | self.bit_log = torch.Tensor()
90 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
91 |
92 | if not args.load:
93 | if not args.save:
94 | args.save = now
95 | # self.dir = os.path.join('..', 'experiment', '{}_sd{}'.format(args.save, args.seed))
96 | self.dir = os.path.join('..', 'experiment', '{}'.format(args.save))
97 | else:
98 | self.dir = os.path.join('..', 'experiment', args.load)
99 | if os.path.exists(self.dir):
100 | self.log = torch.load(self.get_path('psnr_log.pt'))
101 | print('Continue from epoch {}...'.format(len(self.log)))
102 | else:
103 | args.load = ''
104 |
105 | if args.reset:
106 | os.system('rm -rf ' + self.dir)
107 | args.load = ''
108 |
109 | os.makedirs(self.dir, exist_ok=True)
110 | os.makedirs(self.get_path('model'), exist_ok=True)
111 | for d in args.data_test:
112 | os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)
113 | if args.test_own is not None:
114 | os.makedirs(self.get_path('results-test_own'), exist_ok=True)
115 |
116 | open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
117 | self.log_file = open(self.get_path('log.txt'), open_type)
118 | with open(self.get_path('config.txt'), open_type) as f:
119 | f.write(now + '\n\n')
120 | for arg in vars(args):
121 | f.write('{}: {}\n'.format(arg, getattr(args, arg)))
122 | f.write('\n')
123 |
124 | self.n_processes = 8
125 |
126 | def get_path(self, *subdir):
127 | return os.path.join(self.dir, *subdir)
128 |
129 | def save(self, trainer, epoch, is_best=False):
130 | trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
131 |
132 | # self.plot_bit(epoch)
133 | # self.plot_psnr(epoch)
134 |
135 | def add_log(self, log):
136 | self.log = torch.cat([self.log, log])
137 | self.ssim_log = torch.cat([self.ssim_log, log])
138 | self.bit_log = torch.cat([self.bit_log, log])
139 |
140 | def write_log(self, log, refresh=False):
141 | print(log)
142 | self.log_file.write(log + '\n')
143 | if refresh:
144 | self.log_file.close()
145 | self.log_file = open(self.get_path('log.txt'), 'a')
146 |
147 | def done(self):
148 | self.log_file.close()
149 |
150 | def plot_psnr(self, epoch):
151 | # axis = np.linspace(1, epoch, epoch)
152 | axis = np.linspace(1, epoch-1, epoch-1)
153 | for idx_data, d in enumerate(self.args.data_test):
154 | label = 'SR on {}'.format(d)
155 | fig = plt.figure()
156 | plt.title(label)
157 | for idx_scale, scale in enumerate(self.args.scale):
158 | plt.plot(
159 | axis,
160 | self.log[:, idx_data, idx_scale].numpy(),
161 | label='Scale {}'.format(scale)
162 | )
163 | plt.legend()
164 | plt.xlabel('Epochs')
165 | plt.ylabel('PSNR')
166 | plt.grid(True)
167 | plt.savefig(self.get_path('test_{}.png'.format(d)))
168 |
169 | plt.close(fig)
170 |
171 | def plot_bit(self, epoch):
172 | # axis = np.linspace(1, epoch, epoch)
173 | axis = np.linspace(1, epoch-1, epoch-1)
174 | for idx_data, d in enumerate(self.args.data_test):
175 | label = 'SR on {}'.format(d)
176 | fig = plt.figure()
177 | plt.title(label)
178 | for idx_scale, scale in enumerate(self.args.scale):
179 | plt.plot(
180 | axis,
181 | self.bit_log[:, idx_data, idx_scale].numpy(),
182 | label='Scale {}'.format(scale)
183 | )
184 | plt.legend()
185 | plt.xlabel('Epochs')
186 | plt.ylabel('Average Bit')
187 | plt.grid(True)
188 | plt.savefig(self.get_path('test_{}_bit.png'.format(d)))
189 |
190 | plt.close(fig)
191 |
192 | def begin_background(self):
193 | self.queue = Queue()
194 |
195 | def bg_target(queue):
196 | while True:
197 | if not queue.empty():
198 | filename, tensor = queue.get()
199 | if filename is None: break
200 | imageio.imwrite(filename, tensor.numpy())
201 |
202 | self.process = [
203 | Process(target=bg_target, args=(self.queue,)) \
204 | for _ in range(self.n_processes)
205 | ]
206 |
207 | for p in self.process: p.start()
208 |
209 | def end_background(self):
210 | for _ in range(self.n_processes): self.queue.put((None, None))
211 | while not self.queue.empty(): time.sleep(1)
212 | for p in self.process: p.join()
213 |
214 | def save_results(self, dataset, filename, save_list):
215 | if self.args.save_results:
216 | if isinstance(dataset, str):
217 | filename = self.get_path(
218 | 'results-{}'.format(dataset),
219 | '{}'.format(filename)
220 | )
221 | else:
222 | filename = self.get_path(
223 | 'results-{}'.format(dataset.dataset.name),
224 | '{}'.format(filename)
225 | )
226 |
227 | postfix = ('', 'LR', 'HR')
228 | for v, p in zip(save_list, postfix):
229 | normalized = v[0].mul(255 / self.args.rgb_range)
230 | tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
231 | self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))
232 |
233 | def quantize(img, rgb_range):
234 | pixel_range = 255 / rgb_range
235 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
236 |
237 | def gaussian(window_size, sigma):
238 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
239 | return gauss/gauss.sum()
240 |
241 | def create_window_3d(window_size, channel=1):
242 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
243 | _2D_window = _1D_window.mm(_1D_window.t())
244 | _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
245 | window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().cuda()
246 | return window
247 |
248 | def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
249 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
250 | if val_range is None:
251 | if torch.max(img1) > 128:
252 | max_val = 255
253 | else:
254 | max_val = 1
255 |
256 | if torch.min(img1) < -0.5:
257 | min_val = -1
258 | else:
259 | min_val = 0
260 | L = max_val - min_val
261 | else:
262 | L = val_range
263 |
264 | padd = 0
265 | (_, _, height, width) = img1.size()
266 | if window is None:
267 | real_size = min(window_size, height, width)
268 | window = create_window_3d(real_size, channel=1).to(img1.device)
269 | # Channel is set to 1 since we consider color images as volumetric images
270 |
271 | img1 = img1.unsqueeze(1)
272 | img2 = img2.unsqueeze(1)
273 |
274 | mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
275 | mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
276 |
277 | mu1_sq = mu1.pow(2)
278 | mu2_sq = mu2.pow(2)
279 | mu1_mu2 = mu1 * mu2
280 |
281 | sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
282 | sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
283 | sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2
284 |
285 | C1 = (0.01 * L) ** 2
286 | C2 = (0.03 * L) ** 2
287 |
288 | v1 = 2.0 * sigma12 + C2
289 | v2 = sigma1_sq + sigma2_sq + C2
290 | cs = torch.mean(v1 / v2) # contrast sensitivity
291 |
292 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
293 |
294 | if size_average:
295 | ret = ssim_map.mean()
296 | else:
297 | ret = ssim_map.mean(1).mean(1).mean(1)
298 |
299 | if full:
300 | return ret, cs
301 | return ret
302 |
303 | def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
304 | if hr.nelement() == 1: return 0
305 |
306 | diff = (sr - hr) / rgb_range
307 | if dataset and dataset.dataset.benchmark:
308 | shave = scale
309 | if diff.size(1) > 1:
310 | gray_coeffs = [65.738, 129.057, 25.064]
311 | convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
312 | diff = diff.mul(convert).sum(dim=1)
313 | sr = sr.mul(convert).sum(dim=1)
314 | hr = hr.mul(convert).sum(dim=1)
315 | else:
316 | shave = scale + 6
317 |
318 | valid = diff[..., shave:-shave, shave:-shave]
319 | mse = valid.pow(2).mean()
320 |
321 | sr = sr[..., shave:-shave, shave:-shave]
322 | hr = hr[..., shave:-shave, shave:-shave]
323 |
324 | ssim_out = ssim_matlab(sr.unsqueeze(0), hr.unsqueeze(0), val_range=255).item()
325 |
326 | return -10 * math.log10(mse), ssim_out
327 |
328 |
329 | def crop(img, crop_sz, step):
330 | b, c, h, w = img.shape
331 | h_space = np.arange(0, max(h - crop_sz,0) + 1, step)
332 | w_space = np.arange(0, max(w - crop_sz,0) + 1, step)
333 | num_h = 0
334 | lr_list=[]
335 | for x in h_space:
336 | num_h += 1
337 | num_w = 0
338 | for y in w_space:
339 | num_w += 1
340 | # remaining borders are NOT clipped
341 | x_end = x + crop_sz if x != h_space[-1] else h
342 | y_end = y + crop_sz if y != w_space[-1] else w
343 |
344 | crop_img = img[:,:, x : x_end, y : y_end]
345 | lr_list.append(crop_img)
346 |
347 | return lr_list, num_h, num_w, h, w
348 |
349 | def crop_parallel(img, crop_sz, step):
350 | # remaining borders are clipped
351 | b, c, h, w = img.shape
352 | h_space = np.arange(0, h - crop_sz + 1, step)
353 | w_space = np.arange(0, w - crop_sz + 1, step)
354 | index = 0
355 | num_h = 0
356 | lr_list=torch.Tensor().to(img.device)
357 | for x in h_space:
358 | num_h += 1
359 | num_w = 0
360 | for y in w_space:
361 | num_w += 1
362 | index += 1
363 | crop_img = img[:, :, x:x + crop_sz, y:y + crop_sz]
364 | lr_list = torch.cat([lr_list, crop_img])
365 | new_h=x + crop_sz # new height after crop
366 | new_w=y + crop_sz # new width after crop
367 | return lr_list, num_h, num_w, new_h, new_w
368 |
369 |
370 | def combine(sr_list, num_h, num_w, h, w, patch_size, step, scale):
371 | index=0
372 | sr_img = torch.zeros((1, 3, h*scale, w*scale)).cuda()
373 | step = step * scale
374 | patch_size = patch_size * scale
375 |
376 | for x in range(num_h):
377 | for y in range(num_w):
378 | x_patch_size = sr_list[index].shape[2]
379 | y_patch_size = sr_list[index].shape[3]
380 | sr_img[:, :, x*step : x*step+x_patch_size, y*step : y*step+y_patch_size] += sr_list[index]
381 | index += 1
382 |
383 | # average the overlap: each overlapped strip is halved once per axis (valid when step >= patch_size/2)
384 | for x in range(1, num_h):
385 | sr_img[:, :, x*step : x*step+ (patch_size - step), :]/=2
386 | for y in range(1, num_w):
387 | sr_img[:, :, :, y*step : y*step+ (patch_size - step)]/=2
388 |
389 | return sr_img
390 |
391 |
--------------------------------------------------------------------------------
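
`crop` and `combine` are the backbone of the patch-wise inference in `trainer.patch_inference`: the LR image is split into overlapping patches, each patch is super-resolved independently, and `combine` pastes the outputs back while averaging the overlapping strips. A small round-trip sketch with an identity "network" at scale 1 is below; it assumes it is run from `src/` so `utility` is importable, and it needs a CUDA device because `combine` allocates its canvas on the GPU.

```
import torch
from utility import crop, combine    # assumes the working directory is src/

lr = torch.arange(1 * 3 * 20 * 20, dtype=torch.float32).reshape(1, 3, 20, 20).cuda()
patch_size, step, scale = 12, 8, 1   # scale 1 so the identity stands in for the SR model

lr_list, num_h, num_w, h, w = crop(lr, patch_size, step)
sr_list = [patch for patch in lr_list]               # "super-resolve" each patch with the identity
sr = combine(sr_list, num_h, num_w, h, w, patch_size, step, scale)

print(torch.allclose(sr, lr))                        # True: the overlap averaging restores the input
```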