├── .gitignore ├── KernelGAN.yml ├── LICENSE.txt ├── README.md ├── ZSSRforKernelGAN │   ├── ZSSR.py │   ├── zssr_configs.py │   └── zssr_utils.py ├── configs.py ├── data.py ├── imresize.py ├── kernelGAN.py ├── learner.py ├── loss.py ├── networks.py ├── train.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies.
91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | 126 | # IDE 127 | .idea/ 128 | .vscode/ 129 | -------------------------------------------------------------------------------- /KernelGAN.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | 5 | dependencies: 6 | - python=3.6.6 7 | - cudatoolkit>=9.0 8 | - numpy=1.15.1 9 | - pytorch=1.0.0 10 | - tensorflow-gpu=1.12.0 11 | - tqdm=4.26.0 12 | - scipy=1.1.0 13 | - pillow=5.2.0 14 | - opencv=3.4.2 15 | - matplotlib=2.2.3 16 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | The Weizmann Institute of Science 3 | Academic Non Commercial Software Code License 4 | [Blind Super-Resolution Kernel Estimation using an Internal-GAN] (the "Work") 5 | © 2018 The Weizmann Institute of Science ("WIS") and Yeda Research and Development Company Ltd. ("Yeda") All Rights Reserved 6 | 7 | 1. YEDA, the commercial arm of WIS, hereby grants you, an individual or a legal entity exercising rights under, and complying with all of the provisions, of this License (“You”) a royalty-free, non-exclusive, sublicensable, worldwide license to: use, copy, modify, create derivative works (including without limiting to: adapt, alter, transform), integrate with other works, distribute, enable access (including without limiting to: communicate copies), publicly display and perform the Work in binary form or in source code, for academic and noncommercial use only and subject to all provisions of this License: 8 | 2. YEDA hereby grants You a royalty-free, non-exclusive, sublicensable, worldwide license under patents claimed or owned by YEDA that are embodied in the Work, to make, have made and use the Work under the License, for avoidance of doubt for academic and noncommercial use only. 9 | 3. Distribution or provision of access to the Work and to derivative works of the Work ("Derivative Works") may be made only under this License, accompanied with a copy of the source code or a reference to an online repository where such source code can be accessed. 10 | 4. Neither the names of WIS or Yeda, nor any of their trademarks or service marks, may be used to endorse or promote Derivative Works or for any other purpose except as expressly permitted hereunder. 11 | 5. Except as expressly stated in this License, nothing in this License grants any license to trademarks, copyrights, patents, trade secrets or any other intellectual property of WIS or Yeda. No license is granted to the trademarks of WIS or Yeda's even if such marks are included in the Work. 12 | 6. Nothing in this License shall be interpreted to prohibit WIS or Yeda from licensing the Work under terms different from this License. For commercial use please e-mail Yeda at: info.yeda@weizmann.ac.il 13 | 7. 
You must retain, in the Source Code of any Derivative Works that You create, all copyright, patent, or trademark notices from the Source Code of the Work, as well as a notice to inform recipients that You have modified the Work with a description of such modifications. 14 | 8. THE WORK IS PROVIDED "AS IS" AND WITHOUT ANY WARRANTIES WHATSOEVER, EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION ANY WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 15 | 9. IN NO EVENT WILL WIS, YEDA OR ANY OF THEIR RELATED ENTITIES, SCIENTISTS, EMPLOYEES, MANAGERS OR ANY OTHER PERSON ACTING ON THEIR BEHALF, BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY OR CAUSE OF ACTION, WHETHER IN CONTRACT, TORT, STRICT LIABILITY, UNJUST ENRICHMENT OR ANY OTHER, ARISING IN ANY WAY OUT OF THE USE OF THE WORK OR THIS LICENSE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 16 | 10. This License will terminate automatically if any of its conditions is not met, or in case You commence an action, including a cross-claim or counterclaim, against WIS or YEDA or any licensee alleging that the Work (except due to combination with other software or hardware) infringes a patent. 17 | 11. This License shall be exclusively governed by the laws of the State of Israel, without giving effect to conflict of laws principles, and the competent courts in Tel Aviv will have exclusive jurisdiction and venue over any matter between You and WIS or YEDA or any of their related entities relating to this License or the Work. 18 | 12. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. 19 | 20 | 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Blind Super-Resolution Kernel Estimation using an Internal-GAN 2 | # "KernelGAN" 3 | ### Sefi Bell-Kligler, Assaf Shocher, Michal Irani 4 | *(Official implementation)* 5 | 6 | Paper: https://arxiv.org/abs/1909.06581 7 | 8 | Project page: http://www.wisdom.weizmann.ac.il/~vision/kernelgan/ 9 | 10 | **Accepted NeurIPS 2019 (oral)** 11 | 12 | 13 | ## Usage: 14 | 15 | ### Quick usage on your data: 16 | To run KernelGAN on all images in `<input_image_path>`: 17 | 18 | ``` python train.py --input-dir <input_image_path> ``` 19 | 20 | 21 | This will produce kernel estimations in the results folder. 22 | 23 | ### Extra configurations: 24 | ```--X4``` : Estimate the X4 kernel 25 | 26 | ```--SR``` : Perform ZSSR using the estimated kernel 27 | 28 | ```--real``` : Real-image configuration (affects only the ZSSR) 29 | 30 | ```--output-dir``` : Output folder for the images (default is results) 31 | 32 | 33 | ### Data: 34 | Download the DIV2KRK dataset: [dropbox](http://www.wisdom.weizmann.ac.il/~vision/kernelgan/DIV2KRK_public.zip) 35 | 36 | Reproduction code for your own Blind-SR dataset: [github](https://github.com/assafshocher/BlindSR_dataset_generator) 37 | -------------------------------------------------------------------------------- /ZSSRforKernelGAN/ZSSR.py: -------------------------------------------------------------------------------- 1 | import matplotlib.image as img 2 | from ZSSRforKernelGAN.zssr_configs import Config 3 | from ZSSRforKernelGAN.zssr_utils import * 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | 8 | class ZSSR: 9 | # Basic current state variables initialization /
declaration 10 | kernel = None 11 | learning_rate = None 12 | hr_father = None 13 | lr_son = None 14 | sr = None 15 | sf = None 16 | gt_per_sf = None 17 | final_sr = None 18 | hr_fathers_sources = [] 19 | 20 | # Output variables initialization / declaration 21 | reconstruct_output = None 22 | train_output = None 23 | output_shape = None 24 | 25 | # Counters and logs initialization 26 | iter = 0 27 | base_sf = 1.0 28 | base_ind = 0 29 | sf_ind = 0 30 | mse = [] 31 | mse_rec = [] 32 | interp_rec_mse = [] 33 | interp_mse = [] 34 | mse_steps = [] 35 | loss = [] 36 | learning_rate_change_iter_nums = [] 37 | fig = None 38 | 39 | # Network tensors (all tensors end with _t to distinguish) 40 | learning_rate_t = None 41 | lr_son_t = None 42 | hr_father_t = None 43 | filters_t = None 44 | layers_t = None 45 | net_output_t = None 46 | loss_t = None 47 | loss_map_t = None 48 | train_op = None 49 | init_op = None 50 | 51 | # Parameters related to plotting and graphics 52 | plots = None 53 | loss_plot_space = None 54 | lr_son_image_space = None 55 | hr_father_image_space = None 56 | out_image_space = None 57 | 58 | # A map representing the gradient magnitude of the image at every crop 59 | prob_map = None 60 | cropped_loss_map = None 61 | avg_grad = 1 62 | loss_map = [] 63 | loss_map_sources = [] 64 | 65 | # Tensorflow graph default 66 | sess = None 67 | 68 | def __init__(self, input_img_path, scale_factor=2, kernels=None, is_real_img=False, noise_scale=1.): 69 | # Acquire meta parameters configuration from configuration class as a class variable 70 | self.conf = Config(scale_factor, is_real_img, noise_scale) 71 | # Read input image 72 | self.input = img.imread(input_img_path) 73 | # Discard the alpha channel from images 74 | if self.input.shape[-1] == 4: 75 | self.input = img.imread(input_img_path)[:, :, :3] 76 | # For gray-scale images - add a 3rd dimension to fit the network 77 | elif len(self.input.shape) == 2: 78 | self.input = np.expand_dims(self.input, -1) 79 | self.input = self.input / 255. if self.input.dtype == 'uint8' else self.input 80 | self.gt = None 81 | # Shift kernel to avoid misalignment 82 | self.kernels = [kernel_shift(kernel, sf) for kernel, sf in zip(kernels, self.conf.scale_factors)] if kernels is not None else [self.conf.downscale_method] * len(self.conf.scale_factors) 83 | 84 | # Prepare TF default computational graph 85 | self.model = tf.Graph() 86 | 87 | # Build network computational graph 88 | self.build_network(self.conf) 89 | 90 | # Initialize network weights and meta parameters 91 | self.init_sess(init_weights=True) 92 | 93 | # The first hr father source is the input (source goes through augmentation to become a father) 94 | # Later on, if we use gradual sr increments, results for intermediate scales will be added as sources. 95 | self.hr_fathers_sources = [self.input] 96 | 97 | # Create a loss map reflecting the weights per pixel of the image 98 | self.loss_map = create_loss_map(im=self.input) if self.conf.grad_based_loss_map else np.ones_like(self.input) 99 | 100 | # loss maps that correspond to the father sources array 101 | self.loss_map_sources = [self.loss_map] 102 | 103 | def run(self): 104 | # Run gradually on all scale factors (if only one jump then this loop only happens once) 105 | for self.sf_ind, (sf, self.kernel) in enumerate(zip(self.conf.scale_factors, self.kernels)): 106 | # Relative_sf (used when base change is enabled. 
this is when input is the output of some previous scale) 107 | sf = [sf, sf] if np.isscalar(sf) else sf 108 | self.sf = np.array(sf) / np.array(self.base_sf) 109 | 110 | self.output_shape = np.uint(np.ceil(np.array(self.input.shape[0:2]) * sf)) 111 | 112 | # Initialize network 113 | self.init_sess(init_weights=self.conf.init_net_for_each_sf) 114 | 115 | # Train the network 116 | self.train() 117 | 118 | # Use augmented outputs and back projection to enhance result. Also save the result. 119 | post_processed_output = self.final_test() 120 | 121 | # Keep the results for the next scale factors SR to use as dataset 122 | self.hr_fathers_sources.append(post_processed_output) 123 | 124 | # append a corresponding map loss 125 | self.loss_map_sources.append(create_loss_map(im=post_processed_output)) if self.conf.grad_based_loss_map else self.loss_map_sources.append(np.ones_like(post_processed_output)) 126 | 127 | # In some cases, the current output becomes the new input. If indicated and if this is the right scale to 128 | # become the new base input. all of these conditions are checked inside the function. 129 | self.base_change() 130 | 131 | # Return the final post processed output. 132 | # noinspection PyUnboundLocalVariable 133 | return post_processed_output 134 | 135 | def build_network(self, meta): 136 | with self.model.as_default(): 137 | # Learning rate tensor 138 | self.learning_rate_t = tf.placeholder(tf.float32, name='learning_rate') 139 | 140 | # Input image 141 | self.lr_son_t = tf.placeholder(tf.float32, name='lr_son') 142 | 143 | # Ground truth (supervision) 144 | self.hr_father_t = tf.placeholder(tf.float32, name='hr_father') 145 | 146 | # Loss map 147 | self.loss_map_t = tf.placeholder(tf.float32, name='loss_map') 148 | 149 | # Filters 150 | self.filters_t = [tf.get_variable(shape=meta.filter_shape[ind], name='filter_%d' % ind, 151 | initializer=tf.random_normal_initializer( 152 | stddev=np.sqrt(meta.init_variance / np.prod( 153 | meta.filter_shape[ind][0:3])))) 154 | for ind in range(meta.depth)] 155 | 156 | # Activate filters on layers one by one (this is just building the graph, no calculation is done here) 157 | self.layers_t = [self.lr_son_t] + [None] * meta.depth 158 | for l in range(meta.depth - 1): 159 | self.layers_t[l + 1] = tf.nn.relu(tf.nn.conv2d(self.layers_t[l], self.filters_t[l], [1, 1, 1, 1], "SAME", name='layer_%d' % (l + 1))) 160 | 161 | # Last conv layer (Separate because no ReLU here) 162 | l = meta.depth - 1 163 | self.layers_t[-1] = tf.nn.conv2d(self.layers_t[l], self.filters_t[l], [1, 1, 1, 1], "SAME", name='layer_%d' % (l + 1)) 164 | 165 | # Output image (Add last conv layer result to input, residual learning with global skip connection) 166 | self.net_output_t = self.layers_t[-1] + self.conf.learn_residual * self.lr_son_t 167 | 168 | # Final loss (L1 loss between label and output layer) 169 | self.loss_t = tf.reduce_mean(tf.reshape(tf.abs(self.net_output_t - self.hr_father_t) * self.loss_map_t, [-1])) 170 | 171 | # Apply adam optimizer 172 | self.train_op = tf.train.AdamOptimizer(learning_rate=self.learning_rate_t).minimize(self.loss_t) 173 | # self.init_op = tf.initialize_all_variables() 174 | self.init_op = tf.global_variables_initializer() 175 | 176 | def init_sess(self, init_weights=True): 177 | # Sometimes we only want to initialize some meta-params but keep the weights as they were 178 | if init_weights: 179 | # These are for GPU consumption, preventing TF to catch all available GPUs 180 | config = tf.ConfigProto() 181 | 
config.gpu_options.allow_growth = True 182 | 183 | # Initialize computational graph session 184 | self.sess = tf.Session(graph=self.model, config=config) 185 | 186 | # Initialize weights 187 | self.sess.run(self.init_op) 188 | 189 | # Initialize all counters etc 190 | self.loss = [None] * self.conf.max_iters 191 | self.mse, self.mse_rec, self.interp_mse, self.interp_rec_mse, self.mse_steps = [], [], [], [], [] 192 | self.iter = 0 193 | self.learning_rate = self.conf.learning_rate 194 | self.learning_rate_change_iter_nums = [0] 195 | 196 | def forward_backward_pass(self, lr_son, hr_father, cropped_loss_map): 197 | # First gate for the lr-son into the network is interpolation to the size of the father 198 | # Note: we specify both output_size and scale_factor. Best explained by example: say the father size is 9 and sf=2, 199 | # so the small son's size is 4. If we upscale by sf=2 we get the wrong size; if we upscale to size 9 we get the wrong sf. 200 | # The current imresize implementation supports specifying both. 201 | interpolated_lr_son = imresize(lr_son, self.sf, hr_father.shape, self.conf.upscale_method) 202 | # Create feed dict 203 | feed_dict = {'learning_rate:0': self.learning_rate, 204 | 'lr_son:0': np.expand_dims(interpolated_lr_son, 0), 205 | 'hr_father:0': np.expand_dims(hr_father, 0), 206 | 'loss_map:0': np.expand_dims(cropped_loss_map, 0)} 207 | 208 | # Run network 209 | _, self.loss[self.iter], train_output = self.sess.run([self.train_op, self.loss_t, self.net_output_t], 210 | feed_dict) 211 | return np.clip(np.squeeze(train_output), 0, 1) 212 | 213 | def forward_pass(self, lr_son, hr_father_shape=None): 214 | # First gate for the lr-son into the network is interpolation to the size of the father 215 | interpolated_lr_son = imresize(lr_son, self.sf, hr_father_shape, self.conf.upscale_method) 216 | 217 | # Create feed dict 218 | feed_dict = {'lr_son:0': np.expand_dims(interpolated_lr_son, 0)} 219 | 220 | # Run network 221 | return np.clip(np.squeeze(self.sess.run([self.net_output_t], feed_dict)), 0, 1) 222 | 223 | def learning_rate_policy(self): 224 | # Fit a linear curve and check its slope to determine whether to do nothing, reduce the learning rate or finish 225 | if (not (1 + self.iter) % self.conf.learning_rate_policy_check_every 226 | and self.iter - self.learning_rate_change_iter_nums[-1] > self.conf.min_iters): 227 | # noinspection PyTupleAssignmentBalance 228 | [slope, _], [[var, _], _] = np.polyfit(self.mse_steps[-int(self.conf.learning_rate_slope_range / 229 | self.conf.run_test_every):], 230 | self.mse_rec[-int(self.conf.learning_rate_slope_range / 231 | self.conf.run_test_every):], 232 | 1, cov=True) 233 | 234 | # We take the standard deviation as a measure of noise 235 | std = np.sqrt(var) 236 | 237 | # Determine whether to maintain or reduce the learning rate by the ratio between the slope and the noise 238 | if -self.conf.learning_rate_change_ratio * slope < std: 239 | self.learning_rate /= 10 240 | 241 | # Keep track of learning rate changes for plotting purposes 242 | self.learning_rate_change_iter_nums.append(self.iter) 243 | 244 | def quick_test(self): 245 | # There are four evaluations that need to be calculated: 246 | 247 | # 1. True MSE (only if ground-truth was given), note: this error is before post-processing.
248 | # Run net on the input to get the output super-resolution (almost final result, only post-processing needed) 249 | self.sr = self.forward_pass(self.input) 250 | self.mse = (self.mse + [np.mean(np.ndarray.flatten(np.square(self.gt_per_sf - self.sr)))] 251 | if self.gt_per_sf is not None else None) 252 | 253 | # 2. Reconstruction MSE, run for reconstruction- try to reconstruct the input from a downscaled version of it 254 | self.reconstruct_output = self.forward_pass(self.father_to_son(self.input), self.input.shape) 255 | self.mse_rec.append(np.mean(np.ndarray.flatten(np.square(self.input - self.reconstruct_output)))) 256 | 257 | # 3. True MSE of simple interpolation for reference (only if ground-truth was given) 258 | if self.gt_per_sf is not None: 259 | interp_sr = imresize(self.input, self.sf, self.output_shape, self.conf.upscale_method) 260 | 261 | self.interp_mse = self.interp_mse + [np.mean(np.ndarray.flatten(np.square(self.gt_per_sf - interp_sr)))] 262 | else: 263 | self.interp_mse = None 264 | 265 | # 4. Reconstruction MSE of simple interpolation over downscaled input 266 | interp_rec = imresize(self.father_to_son(self.input), self.sf, self.input.shape[:], self.conf.upscale_method) 267 | 268 | self.interp_rec_mse.append(np.mean(np.ndarray.flatten(np.square(self.input - interp_rec)))) 269 | 270 | # Track the iters in which tests are made for the graphics x axis 271 | self.mse_steps.append(self.iter) 272 | 273 | def train(self): 274 | # main training loop 275 | for self.iter in range(self.conf.max_iters): 276 | # Use augmentation from original input image to create current father. 277 | # If other scale factors were applied before, their result is also used (hr_fathers_in) 278 | # crop_center = choose_center_of_crop(self.prob_map) if self.conf.choose_varying_crop else None 279 | crop_center = None 280 | 281 | self.hr_father, self.cropped_loss_map = \ 282 | random_augment(ims=self.hr_fathers_sources, 283 | base_scales=[1.0] + self.conf.scale_factors, 284 | leave_as_is_probability=self.conf.augment_leave_as_is_probability, 285 | no_interpolate_probability=self.conf.augment_no_interpolate_probability, 286 | min_scale=self.conf.augment_min_scale, 287 | max_scale=([1.0] + self.conf.scale_factors)[len(self.hr_fathers_sources) - 1], 288 | allow_rotation=self.conf.augment_allow_rotation, 289 | scale_diff_sigma=self.conf.augment_scale_diff_sigma, 290 | shear_sigma=self.conf.augment_shear_sigma, 291 | crop_size=self.conf.crop_size, 292 | allow_scale_in_no_interp=self.conf.allow_scale_in_no_interp, 293 | crop_center=crop_center, 294 | loss_map_sources=self.loss_map_sources) 295 | 296 | # Get lr-son from hr-father 297 | self.lr_son = self.father_to_son(self.hr_father) 298 | # run network forward and back propagation, one iteration (This is the heart of the training) 299 | self.train_output = self.forward_backward_pass(self.lr_son, self.hr_father, self.cropped_loss_map) 300 | 301 | # Test network 302 | if self.conf.run_test and (not self.iter % self.conf.run_test_every): 303 | self.quick_test() 304 | 305 | # Consider changing learning rate or stop according to iteration number and losses slope 306 | self.learning_rate_policy() 307 | 308 | # stop when minimum learning rate was passed 309 | if self.learning_rate < self.conf.min_learning_rate: 310 | break 311 | 312 | def father_to_son(self, hr_father): 313 | # Create son out of the father by downscaling and if indicated adding noise 314 | lr_son = imresize(hr_father, 1.0 / self.sf, kernel=self.kernel) 315 | return np.clip(lr_son + 
np.random.randn(*lr_son.shape) * self.conf.noise_std, 0, 1) 316 | 317 | def final_test(self): 318 | # Run over 8 augmentations of input - 4 rotations and mirror (geometric self ensemble) 319 | # The weird range means we only do it once if output_flip is disabled 320 | # We need to check if scale factor is symmetric to all dimensions, if not we will do 180 jumps rather than 90 321 | 322 | outputs = [] 323 | for k in range(0, 1 + 7 * self.conf.output_flip, 1 + int(self.sf[0] != self.sf[1])): 324 | # Rotate 90*k degrees & mirror flip when k>=4 325 | test_input = np.rot90(self.input, k) if k < 4 else np.fliplr(np.rot90(self.input, k)) 326 | 327 | # Apply network on the rotated input 328 | tmp_output = self.forward_pass(test_input) 329 | 330 | # Undo the rotation for the processed output (mind the opposite order of the flip and the rotation) 331 | tmp_output = np.rot90(tmp_output, -k) if k < 4 else np.rot90(np.fliplr(tmp_output), -k) 332 | 333 | # fix SR output with back projection technique for each augmentation 334 | for bp_iter in range(self.conf.back_projection_iters[self.sf_ind]): 335 | tmp_output = back_projection(tmp_output, self.input, down_kernel=self.kernel, 336 | up_kernel=self.conf.upscale_method, sf=self.sf) 337 | 338 | # save outputs from all augmentations 339 | outputs.append(tmp_output) 340 | 341 | # Take the median over all 8 outputs 342 | almost_final_sr = np.median(outputs, 0) 343 | 344 | # Again back projection for the final fused result 345 | for bp_iter in range(self.conf.back_projection_iters[self.sf_ind]): 346 | almost_final_sr = back_projection(almost_final_sr, self.input, down_kernel=self.kernel, 347 | up_kernel=self.conf.upscale_method, sf=self.sf) 348 | 349 | # Now we can keep the final result (in grayscale case, colors still need to be added, but we don't care 350 | # because it is done before saving and for every other purpose we use this result) 351 | # noinspection PyUnboundLocalVariable 352 | self.final_sr = almost_final_sr 353 | 354 | # Add colors to result image in case net was activated only on grayscale 355 | return self.final_sr 356 | 357 | def base_change(self): 358 | # If there is no base scale large than the current one get out of here 359 | if len(self.conf.base_change_sfs) < self.base_ind + 1: 360 | return 361 | 362 | # Change base input image if required (this means current output becomes the new input) 363 | if abs(self.conf.scale_factors[self.sf_ind] - self.conf.base_change_sfs[self.base_ind]) < 0.001: 364 | if len(self.conf.base_change_sfs) > self.base_ind: 365 | # The new input is the current output 366 | self.input = self.final_sr 367 | 368 | # The new base scale_factor 369 | self.base_sf = self.conf.base_change_sfs[self.base_ind] 370 | 371 | # Keeping track- this is the index inside the base scales list (provided in the config) 372 | self.base_ind += 1 373 | 374 | print('base changed to %.2f' % self.base_sf) 375 | -------------------------------------------------------------------------------- /ZSSRforKernelGAN/zssr_configs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class Config: 5 | # Network meta params 6 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 7 | scale_factors = [[2.0, 2.0]] # list of pairs (vertical, horizontal) for gradual increments in resolution 8 | base_change_sfs = [] # list of scales after which the input is changed to be the output (recommended for high sfs) 9 | max_iters = 3000 10 | min_iters = 256 11 | min_learning_rate = 9e-6 # this tells the algorithm when to 
stop (specify lower than the last learning-rate) 12 | width = 64 13 | depth = 8 14 | # network meta params that are determined by default (by other params) but can be changed 15 | filter_shape = ([[3, 3, 3, width]] + [[3, 3, width, width]] * (depth - 2) + [[3, 3, width, 3]]) 16 | 17 | output_flip = False # changed from True # geometric self-ensemble (see paper) 18 | downscale_method = 'cubic' # a string ('cubic', 'linear'...), has no meaning if kernel given 19 | upscale_method = 'cubic' # this is the base interpolation from which we learn the residual (same options as above) 20 | downscale_gt_method = 'cubic' # when ground-truth given and intermediate scales tested, we shrink gt to wanted size 21 | learn_residual = True # when true, we only learn the residual from base interpolation 22 | init_variance = 0.1 # variance of weight initializations, typically smaller when residual learning is on 23 | back_projection_iters = [2] # for each scale num of bp iterations (same length as scale_factors) 24 | random_crop = True 25 | crop_size = 128 26 | noise_std = 0. # adding noise to lr-sons. small for real images, bigger for noisy images and zero for ideal case 27 | init_net_for_each_sf = False # for gradual sr- should we optimize from the last sf or initialize each time? 28 | 29 | # Params concerning learning rate policy 30 | learning_rate = 0.001 31 | learning_rate_change_ratio = 1.5 # ratio between STD and slope of linear fit, under which lr is reduced 32 | learning_rate_policy_check_every = 60 33 | learning_rate_slope_range = 256 34 | 35 | # Data augmentation related params 36 | augment_leave_as_is_probability = 1 # changed from 0.05 37 | augment_no_interpolate_probability = 0 # changed from 0.45 38 | augment_min_scale = 0 # changed from 0.5 39 | augment_scale_diff_sigma = 0 # changed from 0.25 40 | augment_shear_sigma = 0 # changed from 0.1 41 | augment_allow_rotation = False # changed from True # recommended false for non-symmetric kernels 42 | 43 | # params related to test and display 44 | run_test = True 45 | run_test_every = 50 46 | 47 | allow_scale_in_no_interp = False 48 | grad_based_loss_map = True # in case the loss should be calculated w.r.t. a gradient map 49 | 50 | def __init__(self, scale_factor, is_real_img, noise_scale): 51 | self.scale_factors = [[scale_factor, scale_factor]] if type(scale_factor) is int else scale_factor 52 | if is_real_img: 53 | print('\nZSSR configuration is for a real image') 54 | self.back_projection_iters = [0] # no B.P 55 | self.noise_std = 0.0125 * noise_scale # Add noise to sons 56 | if type(self.scale_factors[0]) is list: # for gradual SR 57 | self.back_projection_iters = [self.back_projection_iters[0], self.back_projection_iters[0]] 58 | 59 |
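A minimal usage sketch of the Config class above (an editor's addition, not a repository file), showing how the real-image flag changes the configuration; the values follow directly from the `__init__` logic shown:

```python
from ZSSRforKernelGAN.zssr_configs import Config

conf = Config(scale_factor=2, is_real_img=True, noise_scale=1.)
# An int scale factor is expanded to one [vertical, horizontal] pair:
assert conf.scale_factors == [[2, 2]]
# Real-image mode disables back projection (duplicated for the list case)
# and adds noise to the lr-sons:
assert conf.back_projection_iters == [0, 0]
assert conf.noise_std == 0.0125
```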
-------------------------------------------------------------------------------- /ZSSRforKernelGAN/zssr_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math import pi, sin, cos 3 | from cv2 import warpPerspective, INTER_CUBIC 4 | from imresize import imresize 5 | from scipy.ndimage import measurements, interpolation 6 | from scipy.io import loadmat 7 | from scipy.signal import convolve2d 8 | import tensorflow as tf 9 | from random import sample 10 | 11 | 12 | def random_augment(ims, 13 | base_scales=None, 14 | leave_as_is_probability=0.2, 15 | no_interpolate_probability=0.3, 16 | min_scale=0.5, 17 | max_scale=1.0, 18 | allow_rotation=True, 19 | scale_diff_sigma=0.01, 20 | shear_sigma=0.01, 21 | crop_size=128, 22 | allow_scale_in_no_interp=False, 23 | crop_center=None, 24 | loss_map_sources=None): 25 | # Determine which kind of augmentation takes place according to probabilities 26 | random_chooser = np.random.rand() 27 | 28 | # Option 1: No augmentation, return the original image 29 | if random_chooser < leave_as_is_probability: 30 | mode = 'leave_as_is' 31 | 32 | # Option 2: Only non-interpolated augmentation, which means 8 possible augmentations (4 rotations X 2 mirror flips) 33 | elif leave_as_is_probability < random_chooser < leave_as_is_probability + no_interpolate_probability: 34 | mode = 'no_interp' 35 | 36 | # Option 3: Affine transformation (uses interpolation) 37 | else: 38 | mode = 'affine' 39 | 40 | # If scales not given, calculate them according to sizes of images. This would be suboptimal, because when scales 41 | # are not integers, different scales can have the same image shape. 42 | if base_scales is None: 43 | base_scales = [np.sqrt(np.prod(im.shape) / np.prod(ims[0].shape)) for im in ims] 44 | 45 | # In case the scale is a list of scales, we take the smallest one to be the allowed maximum 46 | max_scale = np.min([max_scale]) 47 | 48 | # Determine a random scale by probability 49 | if mode == 'leave_as_is': 50 | scale = 1.0 51 | else: 52 | scale = np.random.rand() * (max_scale - min_scale) + min_scale 53 | 54 | # The image we will use is the smallest one that is bigger than the wanted scale 55 | # (Using a small value overlap instead of >= to prevent float issues) 56 | scale_ind, base_scale = next((ind, np.min([base_scale])) for ind, base_scale in enumerate(base_scales) 57 | if np.min([base_scale]) > scale - 1.0e-6) 58 | im = ims[scale_ind] 59 | 60 | # Next are matrices whose multiplication will be the transformation. All are 3x3 matrices. 61 | 62 | # First matrix shifts image to center so that crop is in the center of the image 63 | shift_to_center_mat = np.array([[1, 0, - im.shape[1] / 2.0], 64 | [0, 1, - im.shape[0] / 2.0], 65 | [0, 0, 1]]) 66 | 67 | shift_back_from_center = np.array([[1, 0, im.shape[1] / 2.0], 68 | [0, 1, im.shape[0] / 2.0], 69 | [0, 0, 1]]) 70 | # Keeping the transform interpolation free means only shifting by integers 71 | if mode != 'affine': 72 | shift_to_center_mat = np.round(shift_to_center_mat) 73 | shift_back_from_center = np.round(shift_back_from_center) 74 | 75 | # Scale matrix. if affine, first determine global scale by probability, then determine difference between x scale 76 | # and y scale by gaussian probability. 77 | if mode == 'affine' or (mode == 'no_interp' and allow_scale_in_no_interp): 78 | scale /= base_scale 79 | scale_diff = np.random.randn() * scale_diff_sigma 80 | else: 81 | scale = 1.0 82 | scale_diff = 0.0 83 | 84 | # In this matrix we also incorporate the possibility of mirror reflection (unless leave_as_is). 85 | if mode == 'leave_as_is' or not allow_rotation: 86 | reflect = 1 87 | else: 88 | reflect = np.sign(np.random.randn()) 89 | 90 | scale_mat = np.array([[reflect * (scale + scale_diff / 2), 0, 0], 91 | [0, scale - scale_diff / 2, 0], 92 | [0, 0, 1]]) 93 | # If center of crop was provided 94 | if crop_center is not None: 95 | shift_y, shift_x = crop_center 96 | shift_x = shift_x - crop_size / 2 97 | shift_y = shift_y - crop_size / 2 98 | # Shift matrix, this actually creates the random crop 99 | else: 100 | shift_x = np.random.rand() * np.clip(scale * im.shape[1] - crop_size, 0, 9999) 101 | shift_y = np.random.rand() * np.clip(scale * im.shape[0] - crop_size, 0, 9999) 102 | 103 | # Rotation matrix angle.
If affine, set a random angle; if no_interp, theta can only be an integer multiple of pi/2. 104 | rotation_indicator = 0 # used for finding the correct crop 105 | if mode == 'affine': 106 | theta = np.random.rand() * 2 * pi 107 | elif mode == 'no_interp': 108 | rotation_indicator = np.random.randint(4) 109 | theta = rotation_indicator * pi / 2 110 | else: 111 | theta = 0 112 | 113 | if not allow_rotation: 114 | theta = 0 115 | 116 | # Rotation matrix structure 117 | rotation_mat = np.array([[cos(theta), sin(theta), 0], 118 | [-sin(theta), cos(theta), 0], 119 | [0, 0, 1]]) 120 | 121 | if crop_center is not None: 122 | tmp_shift_y = shift_y 123 | rotation_indicator = (rotation_indicator * reflect) % 4 124 | if rotation_indicator == 1: 125 | shift_y = im.shape[1] - shift_x - crop_size 126 | shift_x = tmp_shift_y 127 | elif rotation_indicator == 2: 128 | shift_y = im.shape[0] - shift_y - crop_size 129 | shift_x = im.shape[1] - shift_x - crop_size 130 | elif rotation_indicator == 3: 131 | shift_y = shift_x 132 | shift_x = im.shape[0] - tmp_shift_y - crop_size 133 | 134 | shift_mat = np.array([[1, 0, - shift_x], 135 | [0, 1, - shift_y], 136 | [0, 0, 1]]) 137 | # Keeping the transform interpolation free means only shifting by integers 138 | if mode != 'affine': 139 | shift_mat = np.round(shift_mat) 140 | 141 | # Shear Matrix, only for affine transformation. 142 | if mode == 'affine' and allow_rotation: 143 | shear_x = np.random.randn() * shear_sigma 144 | shear_y = np.random.randn() * shear_sigma 145 | else: 146 | shear_x = shear_y = 0 147 | shear_mat = np.array([[1, shear_x, 0], 148 | [shear_y, 1, 0], 149 | [0, 0, 1]]) 150 | 151 | # Create the final transformation by multiplying all the transformations. 152 | transform_mat = (shift_back_from_center 153 | .dot(shift_mat) 154 | .dot(shear_mat) 155 | .dot(rotation_mat) 156 | .dot(scale_mat) 157 | .dot(shift_to_center_mat)) 158 | 159 | # Apply transformation to image and return the transformed image clipped between 0-1 160 | return np.clip(warpPerspective(im, transform_mat, (crop_size, crop_size), flags=INTER_CUBIC), 0, 1), \ 161 | warpPerspective(loss_map_sources[scale_ind], transform_mat, (crop_size, crop_size), flags=INTER_CUBIC) 162 | 163 | 164 | def preprocess_kernels(kernels, conf): 165 | # Load kernels if files are given. If not, just use the downscaling method from the configs. 166 | # output is a list of kernel-arrays or a list of strings indicating the downscaling method. 167 | # In case of arrays, we shift the kernels (see next function for explanation why). 168 | # Kernel is a .mat file (MATLAB) containing a variable called 'Kernel' which is a 2-dim matrix. 169 | if kernels is not None: 170 | return [kernel_shift(loadmat(kernel)['Kernel'], sf) 171 | for kernel, sf in zip(kernels, conf.scale_factors)] 172 | else: 173 | return [conf.downscale_method] * len(conf.scale_factors) 174 | 175 | 176 | def kernel_shift(kernel, sf): 177 | # There are two reasons for shifting the kernel: 178 | # 1. Center of mass is not in the center of the kernel which creates ambiguity. There is no possible way to know 179 | # whether the degradation process included shifting so we always assume center of mass is center of the kernel. 180 | # 2. We further shift kernel center so that top left result pixel corresponds to the middle of the sfXsf first 181 | # pixels. Default is for odd size to be in the middle of the first pixel and for even sized kernel to be at the 182 | # top left corner of the first pixel. that is why a different shift size is needed for odd and even sizes.
183 | # Given that these two conditions are fulfilled, we are happy and aligned, the way to test it is as follows: 184 | # The input image, when interpolated (regular bicubic) is exactly aligned with ground truth. 185 | 186 | # First calculate the current center of mass for the kernel 187 | current_center_of_mass = measurements.center_of_mass(kernel) 188 | 189 | # The second term ("+ 0.5 * ....") is for applying condition 2 from the comments above 190 | wanted_center_of_mass = np.array(kernel.shape) // 2 + 0.5 * (np.array(sf) - (np.array(kernel.shape) % 2)) 191 | # Define the shift vector for the kernel shifting (x,y) 192 | shift_vec = wanted_center_of_mass - current_center_of_mass 193 | # Before applying the shift, we first pad the kernel so that nothing is lost due to the shift 194 | # (biggest shift among dims + 1 for safety) 195 | kernel = np.pad(kernel, np.int(np.ceil(np.max(np.abs(shift_vec)))) + 1, 'constant') 196 | 197 | # Finally shift the kernel and return 198 | kernel = interpolation.shift(kernel, shift_vec) 199 | 200 | return kernel 201 | 202 | 203 | def tensorshave(im, margin): 204 | shp = tf.shape(im) 205 | if shp[3] == 3: 206 | return im[:, margin:-margin, margin:-margin, :] 207 | else: 208 | return im[:, margin:-margin, margin:-margin] 209 | 210 | 211 | def rgb_augment(im, rndm=True, shuff_ind=0): 212 | if rndm: 213 | shuffle = sample(range(3), 3) 214 | else: 215 | shuffle = [[0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 0, 1], [2, 1, 0]] 216 | shuffle = shuffle[shuff_ind] 217 | 218 | return im[:, :, shuffle] 219 | 220 | 221 | def probability_map(im, crop_size): 222 | # margin of probabilities that will be zero 223 | margin = crop_size // 2 - 1 224 | prob_map = np.zeros(im.shape[0:2]) 225 | # Gradient calculation 226 | gx, gy, _ = np.gradient(im) 227 | grad_magnitude = np.sum(np.sqrt(gx ** 2 + gy ** 2), axis=2) 228 | # Convolving with rect to get a map of probabilities per crop 229 | rect = np.ones([crop_size - 3, crop_size - 3]) 230 | grad_magnitude_conv = convolve2d(grad_magnitude, rect, 'same') 231 | # Copying the values without the margins of the image 232 | prob_map[margin:-margin, margin:-margin] = grad_magnitude_conv[margin:-margin, margin:-margin] 233 | # normalize for probabilities 234 | sum_of_grads = np.sum(prob_map) 235 | prob_map = prob_map / sum_of_grads 236 | 237 | return prob_map 238 | 239 | 240 | def choose_center_of_crop(prob_map): 241 | # Retrieving a probability map and reshaping to be a vector 242 | prob_vector = np.reshape(prob_map, prob_map.shape[0] * prob_map.shape[1]) 243 | # creating a vector of indices to match the image 244 | indices = np.arange(start=0, stop=prob_map.shape[0] * prob_map.shape[1]) 245 | # Choosing an index according to the probabilities 246 | index_choice = np.random.choice(indices, p=prob_vector) 247 | # Translating to an index in the image - row, column 248 | return index_choice // prob_map.shape[1], index_choice % prob_map.shape[1] 249 | 250 | 251 | def create_loss_map(im, window=5, clip_rng=np.array([0.0, 255.0])): 252 | # Counting number of pixels for normalization issues 253 | numel = im.shape[0] * im.shape[1] 254 | # rgb2gray if image is in color 255 | gray = np.dot(im[:, :, 0:3], [0.299, 0.587, 0.114]) if len(im.shape) == 3 else im 256 | # Gradient calculation 257 | gx, gy = np.gradient(gray) 258 | gmag = np.sqrt(gx ** 2 + gy ** 2) 259 | processed_gmag = convolve2d(gmag, np.ones(shape=(window, window)), 'same') 260 | # pad the gmag with zeros the size of the process to eliminate artifacts 261 | margin = int((window + 
window % 2) / 2) 262 | loss_map = np.zeros_like(processed_gmag) 263 | # ignoring edges + clipping 264 | loss_map[margin:-margin, margin:-margin] = np.clip(processed_gmag[margin:-margin, margin:-margin], clip_rng[0], clip_rng[1]) 265 | # Normalizing the grad magnitude to sum to numel 266 | norm_factor = np.sum(loss_map) / numel 267 | loss_map = loss_map / norm_factor 268 | 269 | # In case the image is color, return 3 channels with the loss map duplicated 270 | if len(im.shape) == 3: 271 | loss_map = np.expand_dims(loss_map, axis=2) 272 | loss_map = np.append(np.append(loss_map, loss_map, axis=2), loss_map, axis=2) 273 | 274 | return loss_map 275 | 276 | 277 | def image_float2int(im): 278 | """converts a float image to uint""" 279 | if np.max(im) < 2: 280 | im = im * 255. 281 | return np.uint8(im) 282 | 283 | 284 | def image_int2float(im): 285 | """converts a uint image to float""" 286 | return np.float32(im) / 255. if np.max(im) > 2 else im 287 | 288 | 289 | def back_project_image(lr, sf=2, output_shape=None, down_kernel='cubic', up_kernel='cubic', bp_iters=8): 290 | """Runs 'bp_iters' iterations of the back-projection SR technique""" 291 | tmp_sr = imresize(lr, scale_factor=sf, output_shape=output_shape, kernel=up_kernel) 292 | for _ in range(bp_iters): 293 | tmp_sr = back_projection(y_sr=tmp_sr, y_lr=lr, down_kernel=down_kernel, up_kernel=up_kernel, sf=sf) 294 | return tmp_sr 295 | 296 | 297 | def back_projection(y_sr, y_lr, down_kernel, up_kernel, sf=None): 298 | """Projects the error between the downscaled SR image and the LR image""" 299 | y_sr += imresize(y_lr - imresize(y_sr, 300 | scale_factor=1.0 / sf, 301 | output_shape=y_lr.shape, 302 | kernel=down_kernel), 303 | scale_factor=sf, 304 | output_shape=y_sr.shape, 305 | kernel=up_kernel) 306 | return np.clip(y_sr, 0, 1) 307 |
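A hedged usage sketch of the back-projection helpers above (an editor's addition, not a repository file): each iteration downscales the current SR estimate, measures its disagreement with the LR input, and adds the upscaled error back.

```python
import numpy as np
from ZSSRforKernelGAN.zssr_utils import back_project_image

lr = np.random.rand(32, 32, 3)                 # dummy LR image in [0, 1]
sr = back_project_image(lr, sf=2, bp_iters=8)  # 'cubic' up/down kernels by default
assert sr.shape[:2] == (64, 64)                # the output is sf times larger
```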
-------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | 5 | 6 | # noinspection PyPep8 7 | class Config: 8 | def __init__(self): 9 | self.parser = argparse.ArgumentParser() 10 | self.conf = None 11 | 12 | # Paths 13 | self.parser.add_argument('--img_name', default='image1', help='image name for saving purposes') 14 | self.parser.add_argument('--input_image_path', default=os.path.dirname(__file__) + '/training_data/input.png', help='path to one specific image file') 15 | self.parser.add_argument('--output_dir_path', default=os.path.dirname(__file__) + '/results', help='results path') 16 | 17 | # Sizes 18 | self.parser.add_argument('--input_crop_size', type=int, default=64, help='Generators crop size') 19 | self.parser.add_argument('--scale_factor', type=float, default=0.5, help='The downscaling scale factor') 20 | self.parser.add_argument('--X4', action='store_true', help='when activated - estimate the X4 kernel (default is X2)') 21 | 22 | # Network architecture 23 | self.parser.add_argument('--G_chan', type=int, default=64, help='# of channels in hidden layer in the G') 24 | self.parser.add_argument('--D_chan', type=int, default=64, help='# of channels in hidden layer in the D') 25 | self.parser.add_argument('--G_kernel_size', type=int, default=13, help='The kernel size G is estimating') 26 | self.parser.add_argument('--D_n_layers', type=int, default=7, help='Discriminators depth') 27 | self.parser.add_argument('--D_kernel_size', type=int, default=7, help='Discriminators convolution kernels size') 28 | 29 | # Iterations 30 | self.parser.add_argument('--max_iters', type=int, default=3000, help='# of iterations') 31 | 32 | # Optimization hyper-parameters 33 | self.parser.add_argument('--g_lr', type=float, default=2e-4, help='initial learning rate for generator') 34 | self.parser.add_argument('--d_lr', type=float, default=2e-4, help='initial learning rate for discriminator') 35 | self.parser.add_argument('--beta1', type=float, default=0.5, help='Adam momentum') 36 | 37 | # GPU 38 | self.parser.add_argument('--gpu_id', type=int, default=0, help='gpu id number') 39 | 40 | # Kernel post processing 41 | self.parser.add_argument('--n_filtering', type=float, default=40, help='Filtering small values of the kernel') 42 | 43 | # ZSSR configuration 44 | self.parser.add_argument('--do_ZSSR', action='store_true', help='when activated - ZSSR is performed') 45 | self.parser.add_argument('--noise_scale', type=float, default=1., help='ZSSR uses this to partially de-noise images') 46 | self.parser.add_argument('--real_image', action='store_true', help='ZSSRs configuration is for real images') 47 | 48 | def parse(self, args=None): 49 | """Parse the configuration""" 50 | self.conf = self.parser.parse_args(args=args) 51 | self.set_gpu_device() 52 | self.clean_file_name() 53 | self.set_output_directory() 54 | self.conf.G_structure = [7, 5, 3, 1, 1, 1] 55 | print("Scale Factor: %s \tZSSR: %s \tReal Image: %s" % (('X4' if self.conf.X4 else 'X2'), str(self.conf.do_ZSSR), str(self.conf.real_image))) 56 | return self.conf 57 | 58 | def clean_file_name(self): 59 | """Retrieves the clean image file_name for saving purposes""" 60 | self.conf.img_name = self.conf.input_image_path.split('/')[-1].replace('ZSSR', '') \ 61 | .replace('real', '').replace('__', '').split('_.')[0].split('.')[0] 62 | 63 | def set_gpu_device(self): 64 | """Sets the GPU device if one is given""" 65 | if os.environ.get('CUDA_VISIBLE_DEVICES', '') == '': 66 | os.environ['CUDA_VISIBLE_DEVICES'] = str(self.conf.gpu_id) 67 | torch.cuda.set_device(0) 68 | else: 69 | torch.cuda.set_device(self.conf.gpu_id) 70 | 71 | def set_output_directory(self): 72 | """Define the output directory name and create the folder""" 73 | self.conf.output_dir_path = os.path.join(self.conf.output_dir_path, self.conf.img_name) 74 | # In case the folder exists - stack 'l's to the folder name 75 | while os.path.isdir(self.conf.output_dir_path): 76 | self.conf.output_dir_path += 'l' 77 | os.makedirs(self.conf.output_dir_path) 78 |
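A hedged sketch of driving the Config class above programmatically (an editor's addition; note that parse() selects a CUDA device and creates the output folder as side effects, so a GPU and write access are assumed):

```python
from configs import Config

# Equivalent to: python train.py --input_image_path training_data/input.png --X4 --do_ZSSR
conf = Config().parse(args=['--input_image_path', 'training_data/input.png',
                            '--X4', '--do_ZSSR'])
print(conf.img_name)     # clean name derived from the file name: 'input'
print(conf.G_structure)  # generator kernel sizes set in parse(): [7, 5, 3, 1, 1, 1]
```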
-------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset 3 | from imresize import imresize 4 | from util import read_image, create_gradient_map, im2tensor, create_probability_map, nn_interpolation 5 | 6 | 7 | class DataGenerator(Dataset): 8 | """ 9 | The data generator loads an image once, calculates its gradient map on initialization and then outputs a cropped version 10 | of that image whenever called. 11 | """ 12 | 13 | def __init__(self, conf, gan): 14 | # Default shapes 15 | self.g_input_shape = conf.input_crop_size 16 | self.d_input_shape = gan.G.output_size # shape entering D downscaled by G 17 | self.d_output_shape = self.d_input_shape - gan.D.forward_shave 18 | 19 | # Read input image 20 | self.input_image = read_image(conf.input_image_path) / 255. 21 | self.shave_edges(scale_factor=conf.scale_factor, real_image=conf.real_image) 22 | 23 | self.in_rows, self.in_cols = self.input_image.shape[0:2] 24 | 25 | # Create prob map for choosing the crop 26 | self.crop_indices_for_g, self.crop_indices_for_d = self.make_list_of_crop_indices(conf=conf) 27 | 28 | def __len__(self): 29 | return 1 30 | 31 | def __getitem__(self, idx): 32 | """Get a crop for both G and D """ 33 | g_in = self.next_crop(for_g=True, idx=idx) 34 | d_in = self.next_crop(for_g=False, idx=idx) 35 | 36 | return g_in, d_in 37 | 38 | def next_crop(self, for_g, idx): 39 | """Return a crop according to the pre-determined list of indices. Noise is added to crops for D""" 40 | size = self.g_input_shape if for_g else self.d_input_shape 41 | top, left = self.get_top_left(size, for_g, idx) 42 | crop_im = self.input_image[top:top + size, left:left + size, :] 43 | if not for_g: # Add noise to the image for d 44 | crop_im += np.random.randn(*crop_im.shape) / 255.0 45 | return im2tensor(crop_im) 46 | 47 | def make_list_of_crop_indices(self, conf): 48 | iterations = conf.max_iters 49 | prob_map_big, prob_map_sml = self.create_prob_maps(scale_factor=conf.scale_factor) 50 | crop_indices_for_g = np.random.choice(a=len(prob_map_sml), size=iterations, p=prob_map_sml) 51 | crop_indices_for_d = np.random.choice(a=len(prob_map_big), size=iterations, p=prob_map_big) 52 | return crop_indices_for_g, crop_indices_for_d 53 | 54 | def create_prob_maps(self, scale_factor): 55 | # Create loss maps for input image and downscaled one 56 | loss_map_big = create_gradient_map(self.input_image) 57 | loss_map_sml = create_gradient_map(imresize(im=self.input_image, scale_factor=scale_factor, kernel='cubic')) 58 | # Create corresponding probability maps 59 | prob_map_big = create_probability_map(loss_map_big, self.d_input_shape) 60 | prob_map_sml = create_probability_map(nn_interpolation(loss_map_sml, int(1 / scale_factor)), self.g_input_shape) 61 | return prob_map_big, prob_map_sml 62 | 63 | def shave_edges(self, scale_factor, real_image): 64 | """Shave pixels from the edges to avoid boundary-related bugs""" 65 | # Crop 10 pixels to avoid boundary effects in synthetically generated examples 66 | if not real_image: 67 | self.input_image = self.input_image[10:-10, 10:-10, :] 68 | # Crop pixels for the shape to be divisible by the scale factor 69 | sf = int(1 / scale_factor) 70 | shape = self.input_image.shape 71 | self.input_image = self.input_image[:-(shape[0] % sf), :, :] if shape[0] % sf > 0 else self.input_image 72 | self.input_image = self.input_image[:, :-(shape[1] % sf), :] if shape[1] % sf > 0 else self.input_image 73 | 74 | def get_top_left(self, size, for_g, idx): 75 | """Translate the center index of the crop to its corresponding top-left""" 76 | center = self.crop_indices_for_g[idx] if for_g else self.crop_indices_for_d[idx] 77 | row, col = int(center / self.in_cols), center % self.in_cols 78 | top, left = min(max(0, row - size // 2), self.in_rows - size), min(max(0, col - size // 2), self.in_cols - size) 79 | # Choose even indices (to avoid misalignment with the loss map for_g) 80 | return top - top % 2, left - left % 2
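A worked mini-example of the index translation in get_top_left above (an editor's addition with made-up numbers):

```python
# A flat center index is mapped to (row, col), then clamped so the crop fits.
in_rows, in_cols, size = 100, 200, 64
center = 5070                                        # flat index from the probability map
row, col = center // in_cols, center % in_cols       # -> (25, 70)
top = min(max(0, row - size // 2), in_rows - size)   # -> 0 (clamped at the top border)
left = min(max(0, col - size // 2), in_cols - size)  # -> 38
top, left = top - top % 2, left - left % 2           # even indices -> (0, 38)
```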
-------------------------------------------------------------------------------- /imresize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage import filters, measurements, interpolation 3 | from math import pi 4 | 5 | 6 | def imresize(im, scale_factor=None, output_shape=None, kernel=None, antialiasing=True, kernel_shift_flag=False): 7 | # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa 8 | scale_factor, output_shape = fix_scale_and_size(im.shape, output_shape, scale_factor) 9 | 10 | # For a given numeric kernel case, just do convolution and sub-sampling (downscaling only) 11 | if type(kernel) == np.ndarray and scale_factor[0] <= 1: 12 | return numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag) 13 | 14 | # Choose interpolation method, each method has the matching kernel size 15 | method, kernel_width = { 16 | "cubic": (cubic, 4.0), 17 | "lanczos2": (lanczos2, 4.0), 18 | "lanczos3": (lanczos3, 6.0), 19 | "box": (box, 1.0), 20 | "linear": (linear, 2.0), 21 | None: (cubic, 4.0) # Default interpolation method is cubic 22 | }.get(kernel) 23 | 24 | # Antialiasing is only used when downscaling 25 | antialiasing *= (scale_factor[0] < 1) 26 | 27 | # Sort indices of dimensions according to scale of each dimension. Since we are going dim by dim, this is efficient 28 | sorted_dims = np.argsort(np.array(scale_factor)).tolist() 29 | 30 | # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction 31 | out_im = np.copy(im) 32 | for dim in sorted_dims: 33 | # No point doing calculations for scale-factor 1. Nothing will happen anyway 34 | if scale_factor[dim] == 1.0: 35 | continue 36 | 37 | # for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the 38 | # weights that multiply the values there to get its result. 39 | weights, field_of_view = contributions(im.shape[dim], output_shape[dim], scale_factor[dim], 40 | method, kernel_width, antialiasing) 41 | 42 | # Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim 43 | out_im = resize_along_dim(out_im, dim, weights, field_of_view) 44 | 45 | return out_im 46 | 47 | 48 | def fix_scale_and_size(input_shape, output_shape, scale_factor): 49 | # First fix the scale-factor (if given) to the standardized form the function expects: a list of scale factors, 50 | # one per input dimension 51 | if scale_factor is not None: 52 | # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it. 53 | if np.isscalar(scale_factor): 54 | scale_factor = [scale_factor, scale_factor] 55 | 56 | # We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales 57 | scale_factor = list(scale_factor) 58 | scale_factor.extend([1] * (len(input_shape) - len(scale_factor))) 59 | 60 | # Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size 61 | # to all the unspecified dimensions 62 | if output_shape is not None: 63 | output_shape = list(np.uint(np.array(output_shape))) + list(input_shape[len(output_shape):]) 64 | 65 | # Dealing with the case of a non-given scale-factor, calculating according to output-shape. Note that this is 66 | # sub-optimal, because there can be different scales to the same output-shape. 67 | if scale_factor is None: 68 | scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape) 69 | 70 | # Dealing with a missing output-shape, calculating according to the scale-factor 71 | if output_shape is None: 72 | output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor))) 73 | 74 | return scale_factor, output_shape 75 | 76 |
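An editor's example of the normalization done by fix_scale_and_size (the values follow directly from the code above):

```python
from imresize import fix_scale_and_size

# A scalar scale on an RGB image: the scale is duplicated for the two spatial
# dims and padded with 1 for the channel dim; the shape is ceil(shape * scale).
sf, out_shape = fix_scale_and_size((100, 200, 3), output_shape=None, scale_factor=0.5)
# sf == [0.5, 0.5, 1]; out_shape == [50, 100, 3]
```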
77 | def contributions(in_length, out_length, scale, kernel, kernel_width, antialiasing): 78 | # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied 79 | # such that each position from the field_of_view will be multiplied with a matching filter from the 80 | # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers 81 | # around it. This is only done for one dimension of the image. 82 | 83 | # When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of 84 | # 1/sf. This makes the filtering more of a 'low-pass filter'. 85 | fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel 86 | kernel_width *= 1.0 / scale if antialiasing else 1.0 87 | 88 | # These are the coordinates of the output image 89 | out_coordinates = np.arange(1, out_length + 1) 90 | 91 | # These are the matching positions of the output-coordinates on the input image coordinates. 92 | # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels: 93 | # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel. 94 | # The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to 95 | # the right boundary of pixel 2. pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big 96 | # one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor). 97 | # So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is 98 | # at d=1, and so on (d = p - 0.5). we calculate (d_new = d_old / sf) which means: 99 | # (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf) 100 | match_coordinates = 1.0 * out_coordinates / scale + 0.5 * (1 - 1.0 / scale) 101 |
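# (Editor's worked check of the formula above, not part of the original code:
# with scale=0.5 and out_coordinates=[1, 2], match_coordinates = [1/0.5 - 0.5, 2/0.5 - 0.5] = [1.5, 3.5],
# i.e. output pixel 1 sits exactly on the boundary between input pixels 1 and 2,
# matching the [1,2,3,4] -> [1,2] example in the comment above.)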
being careful not to divide by 0
120 | sum_weights = np.sum(weights, axis=1)
121 | sum_weights[sum_weights == 0] = 1.0
122 | weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)
123 |
124 | # We use this mirror structure as a trick for reflection padding at the boundaries
125 | mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
126 | field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]
127 |
128 | # Get rid of weights and pixel positions that are of zero weight
129 | non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))
130 | weights = np.squeeze(weights[:, non_zero_out_pixels])
131 | field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])
132 |
133 | # Final products are the relative positions and the matching weights, both are output_size X fixed_kernel_size
134 | return weights, field_of_view
135 |
136 |
137 | def resize_along_dim(im, dim, weights, field_of_view):
138 | # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize
139 | tmp_im = np.swapaxes(im, dim, 0)
140 |
141 | # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for
142 | # tmp_im[field_of_view.T], (bsxfun style)
143 | weights = np.reshape(weights.T, list(weights.T.shape) + (np.ndim(im) - 1) * [1])
144 |
145 | # This is a bit of a complicated multiplication: tmp_im[field_of_view.T] is a tensor of order image_dims+1.
146 | # For each pixel in the output-image it matches the positions that influence it from the input image (along 1 dim
147 | # only, this is why it only adds 1 dim to the shape). We then multiply, for each pixel, its set of positions with
148 | # the matching set of weights. We do this with one big tensor element-wise multiplication (MATLAB bsxfun style:
149 | # matching dims are multiplied element-wise while singletons mean that the matching dim is all multiplied by the
150 | # same number)
151 | tmp_out_im = np.sum(tmp_im[field_of_view.T] * weights, axis=0)
152 |
153 | # Finally we swap back the axes to the original order
154 | return np.swapaxes(tmp_out_im, dim, 0)
155 |
156 |
157 | def numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag):
158 | # See kernel_shift function to understand what this is
159 | if kernel_shift_flag:
160 | kernel = kernel_shift(kernel, scale_factor)
161 |
162 | # First run a correlation (convolution with flipped kernel) on each color channel
163 | out_im = np.zeros_like(im)
164 | for channel in range(im.shape[2]):
165 | out_im[:, :, channel] = filters.correlate(im[:, :, channel], kernel)
166 |
167 | # Then subsample and return
168 | return out_im[np.round(np.linspace(0, im.shape[0] - 1 / scale_factor[0], output_shape[0])).astype(int)[:, None],
169 | np.round(np.linspace(0, im.shape[1] - 1 / scale_factor[1], output_shape[1])).astype(int), :]
170 |
171 |
172 | def kernel_shift(kernel, sf):
173 | # There are two reasons for shifting the kernel:
174 | # 1. The center of mass is not in the center of the kernel, which creates ambiguity. There is no way to know
175 | # whether the degradation process included shifting, so we always assume the center of mass is the center of the kernel.
176 | # 2. We further shift the kernel center so that the top left result pixel corresponds to the middle of the sfXsf first
177 | # pixels. The default is for an odd-sized kernel to be centered in the middle of the first pixel and for an even-sized
178 | # kernel to be at the top left corner of the first pixel. That is why a different shift size is needed for odd and even sizes.
179 | # Given that these two conditions are fulfilled, we are aligned; the way to test this is:
180 | # the input image, when interpolated (regular bicubic), is exactly aligned with the ground truth.
181 |
182 | # First calculate the current center of mass for the kernel
183 | current_center_of_mass = measurements.center_of_mass(kernel)
184 |
185 | # The second term ("+ 0.5 * ...") is for applying condition 2 from the comments above
186 | wanted_center_of_mass = np.array(kernel.shape) // 2 + 0.5 * (np.array(sf)[0:2] - (kernel.shape[0] % 2))
187 | # sf is converted to an array so that the list scale-factor produced by fix_scale_and_size also works here
188 |
189 | # Define the shift vector for the kernel shifting (x,y)
190 | shift_vec = wanted_center_of_mass - current_center_of_mass
191 |
192 | # Before applying the shift, we first pad the kernel so that nothing is lost due to the shift
193 | # (biggest shift among dims + 1 for safety)
194 | kernel = np.pad(kernel, np.int(np.ceil(np.max(np.abs(shift_vec)))) + 1, 'constant')
195 |
196 | # Finally shift the kernel and return
197 | return interpolation.shift(kernel, shift_vec)
198 |
199 |
200 | # These next functions are all interpolation methods. x is the distance from the left pixel center
201 |
202 |
203 | def cubic(x):
204 | absx = np.abs(x)
205 | absx2 = absx ** 2
206 | absx3 = absx ** 3
207 | return ((1.5 * absx3 - 2.5 * absx2 + 1) * (absx <= 1) +
208 | (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((1 < absx) & (absx <= 2)))
209 |
210 |
211 | def lanczos2(x):
212 | return (((np.sin(pi * x) * np.sin(pi * x / 2) + np.finfo(np.float32).eps) /
213 | ((pi ** 2 * x ** 2 / 2) + np.finfo(np.float32).eps))
214 | * (abs(x) < 2))
215 |
216 |
217 | def box(x):
218 | return ((-0.5 <= x) & (x < 0.5)) * 1.0
219 |
220 |
221 | def lanczos3(x):
222 | return (((np.sin(pi * x) * np.sin(pi * x / 3) + np.finfo(np.float32).eps) /
223 | ((pi ** 2 * x ** 2 / 3) + np.finfo(np.float32).eps))
224 | * (abs(x) < 3))
225 |
226 |
227 | def linear(x):
228 | return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1))
229 |
--------------------------------------------------------------------------------
/kernelGAN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import loss
3 | import networks
4 | import torch.nn.functional as F
5 | from util import save_final_kernel, run_zssr, post_process_k
6 |
7 |
8 | class KernelGAN:
9 | # Constraint coefficients
10 | lambda_sum2one = 0.5
11 | lambda_bicubic = 5
12 | lambda_boundaries = 0.5
13 | lambda_centralized = 0
14 | lambda_sparse = 0
15 |
16 | def __init__(self, conf):
17 | # Acquire configuration
18 | self.conf = conf
19 |
20 | # Define the GAN
21 | self.G = networks.Generator(conf).cuda()
22 | self.D = networks.Discriminator(conf).cuda()
23 |
24 | # Calculate D's input & output shape according to the shaving done by the networks
25 | self.d_input_shape = self.G.output_size
26 | self.d_output_shape = self.d_input_shape - self.D.forward_shave
27 |
28 | # Input tensors
29 | self.g_input = torch.FloatTensor(1, 3, conf.input_crop_size, conf.input_crop_size).cuda()
30 | self.d_input = torch.FloatTensor(1, 3, self.d_input_shape, self.d_input_shape).cuda()
31 |
32 | # The kernel G is imitating
33 | self.curr_k = torch.FloatTensor(conf.G_kernel_size, conf.G_kernel_size).cuda()
34 |
35 | # Losses
36 | self.GAN_loss_layer = loss.GANLoss(d_last_layer_size=self.d_output_shape).cuda()
37 | self.bicubic_loss = loss.DownScaleLoss(scale_factor=conf.scale_factor).cuda()
38 |
self.sum2one_loss = loss.SumOfWeightsLoss().cuda()
39 | self.boundaries_loss = loss.BoundariesLoss(k_size=conf.G_kernel_size).cuda()
40 | self.centralized_loss = loss.CentralizedLoss(k_size=conf.G_kernel_size, scale_factor=conf.scale_factor).cuda()
41 | self.sparse_loss = loss.SparsityLoss().cuda()
42 | self.loss_bicubic = 0
43 |
44 | # Define loss function
45 | self.criterionGAN = self.GAN_loss_layer.forward
46 |
47 | # Initialize the networks' weights
48 | self.G.apply(networks.weights_init_G)
49 | self.D.apply(networks.weights_init_D)
50 |
51 | # Optimizers
52 | self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=conf.g_lr, betas=(conf.beta1, 0.999))
53 | self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=conf.d_lr, betas=(conf.beta1, 0.999))
54 |
55 | print('*' * 60 + '\nSTARTED KernelGAN on: \"%s\"...' % conf.input_image_path)
56 |
57 | # noinspection PyUnboundLocalVariable
58 | def calc_curr_k(self):
59 | """Given a generator network, the function calculates the kernel it is imitating"""
60 | delta = torch.Tensor([1.]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).cuda()
61 | for ind, w in enumerate(self.G.parameters()):
62 | curr_k = F.conv2d(delta, w, padding=self.conf.G_kernel_size - 1) if ind == 0 else F.conv2d(curr_k, w)
63 | self.curr_k = curr_k.squeeze().flip([0, 1])
64 |
65 | def train(self, g_input, d_input):
66 | self.set_input(g_input, d_input)
67 | self.train_g()
68 | self.train_d()
69 |
70 | def set_input(self, g_input, d_input):
71 | self.g_input = g_input.contiguous()
72 | self.d_input = d_input.contiguous()
73 |
74 | def train_g(self):
75 | # Zeroize gradients
76 | self.optimizer_G.zero_grad()
77 | # Generator forward pass
78 | g_pred = self.G.forward(self.g_input)
79 | # Pass the Generator's output through the Discriminator
80 | d_pred_fake = self.D.forward(g_pred)
81 | # Calculate the generator loss, based on the discriminator's prediction on the generator's result
82 | loss_g = self.criterionGAN(d_last_layer=d_pred_fake, is_d_input_real=True)
83 | # Sum all losses
84 | total_loss_g = loss_g + self.calc_constraints(g_pred)
85 | # Calculate gradients
86 | total_loss_g.backward()
87 | # Update weights
88 | self.optimizer_G.step()
89 |
90 | def calc_constraints(self, g_pred):
91 | # Calculate K which is equivalent to G
92 | self.calc_curr_k()
93 | # Calculate constraints
94 | self.loss_bicubic = self.bicubic_loss.forward(g_input=self.g_input, g_output=g_pred)
95 | loss_boundaries = self.boundaries_loss.forward(kernel=self.curr_k)
96 | loss_sum2one = self.sum2one_loss.forward(kernel=self.curr_k)
97 | loss_centralized = self.centralized_loss.forward(kernel=self.curr_k)
98 | loss_sparse = self.sparse_loss.forward(kernel=self.curr_k)
99 | # Apply the constraint coefficients
100 | return self.loss_bicubic * self.lambda_bicubic + loss_sum2one * self.lambda_sum2one + \
101 | loss_boundaries * self.lambda_boundaries + loss_centralized * self.lambda_centralized + \
102 | loss_sparse * self.lambda_sparse
103 |
104 | def train_d(self):
105 | # Zeroize gradients
106 | self.optimizer_D.zero_grad()
107 | # Discriminator forward pass over a real example
108 | d_pred_real = self.D.forward(self.d_input)
109 | # Discriminator forward pass over a fake example (generated by the generator)
110 | # Note that the generator's result is detached so that gradients do not propagate back through the generator
111 | g_output = self.G.forward(self.g_input)
112 | d_pred_fake = self.D.forward((g_output + torch.randn_like(g_output) / 255.).detach())
113 | # Calculate the discriminator loss
114 | loss_d_fake = self.criterionGAN(d_pred_fake,
is_d_input_real=False) 115 | loss_d_real = self.criterionGAN(d_pred_real, is_d_input_real=True) 116 | loss_d = (loss_d_fake + loss_d_real) * 0.5 117 | # Calculate gradients, note that gradients are not propagating back through generator 118 | loss_d.backward() 119 | # Update weights, note that only discriminator weights are updated (by definition of the D optimizer) 120 | self.optimizer_D.step() 121 | 122 | def finish(self): 123 | final_kernel = post_process_k(self.curr_k, n=self.conf.n_filtering) 124 | save_final_kernel(final_kernel, self.conf) 125 | print('KernelGAN estimation complete!') 126 | run_zssr(final_kernel, self.conf) 127 | print('FINISHED RUN (see --%s-- folder)\n' % self.conf.output_dir_path + '*' * 60 + '\n\n') 128 | -------------------------------------------------------------------------------- /learner.py: -------------------------------------------------------------------------------- 1 | class Learner: 2 | # Default hyper-parameters 3 | lambda_update_freq = 200 4 | bic_loss_to_start_change = 0.4 5 | lambda_bicubic_decay_rate = 100. 6 | update_l_rate_freq = 750 7 | update_l_rate_rate = 10. 8 | lambda_sparse_end = 5 9 | lambda_centralized_end = 1 10 | lambda_bicubic_min = 5e-6 11 | 12 | def __init__(self): 13 | self.bic_loss_counter = 0 14 | self.similar_to_bicubic = False # Flag indicating when the bicubic similarity is achieved 15 | self.insert_constraints = True # Flag is switched to false once constraints are added to the loss 16 | 17 | def update(self, iteration, gan): 18 | if iteration == 0: 19 | return 20 | # Update learning rate every update_l_rate freq 21 | if iteration % self.update_l_rate_freq == 0: 22 | for params in gan.optimizer_G.param_groups: 23 | params['lr'] /= self.update_l_rate_rate 24 | for params in gan.optimizer_D.param_groups: 25 | params['lr'] /= self.update_l_rate_rate 26 | 27 | # Until similar to bicubic is satisfied, don't update any other lambdas 28 | if not self.similar_to_bicubic: 29 | if gan.loss_bicubic < self.bic_loss_to_start_change: 30 | if self.bic_loss_counter >= 2: 31 | self.similar_to_bicubic = True 32 | else: 33 | self.bic_loss_counter += 1 34 | else: 35 | self.bic_loss_counter = 0 36 | # Once similar to bicubic is satisfied, consider inserting other constraints 37 | elif iteration % self.lambda_update_freq == 0 and gan.lambda_bicubic > self.lambda_bicubic_min: 38 | gan.lambda_bicubic = max(gan.lambda_bicubic / self.lambda_bicubic_decay_rate, self.lambda_bicubic_min) 39 | if self.insert_constraints and gan.lambda_bicubic < 5e-3: 40 | gan.lambda_centralized = self.lambda_centralized_end 41 | gan.lambda_sparse = self.lambda_sparse_end 42 | self.insert_constraints = False 43 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from util import shave_a2b, resize_tensor_w_kernel, create_penalty_mask, map2tensor 5 | 6 | 7 | # noinspection PyUnresolvedReferences 8 | class GANLoss(nn.Module): 9 | """D outputs a [0,1] map of size of the input. This map is compared in a pixel-wise manner to 1/0 according to 10 | whether the input is real (i.e. from the input image) or fake (i.e. 
from the Generator)"""
11 |
12 | def __init__(self, d_last_layer_size):
13 | super(GANLoss, self).__init__()
14 | # The loss function is applied after the pixel-wise comparison to the true label (0/1)
15 | self.loss = nn.L1Loss(reduction='mean')
16 | # Make a shape
17 | d_last_layer_shape = [1, 1, d_last_layer_size, d_last_layer_size]
18 | # The two possible label maps are pre-prepared
19 | self.label_tensor_fake = Variable(torch.zeros(d_last_layer_shape).cuda(), requires_grad=False)
20 | self.label_tensor_real = Variable(torch.ones(d_last_layer_shape).cuda(), requires_grad=False)
21 |
22 | def forward(self, d_last_layer, is_d_input_real):
23 | # Determine the label map according to whether the current input to the discriminator is real or fake
24 | label_tensor = self.label_tensor_real if is_d_input_real else self.label_tensor_fake
25 | # Compute the loss
26 | return self.loss(d_last_layer, label_tensor)
27 |
28 |
29 | class DownScaleLoss(nn.Module):
30 | """ Computes the difference between the Generator's downscaling and an ideal (bicubic) downscaling"""
31 |
32 | def __init__(self, scale_factor):
33 | super(DownScaleLoss, self).__init__()
34 | self.loss = nn.MSELoss()
35 | bicubic_k = [[0.0001373291015625, 0.0004119873046875, -0.0013275146484375, -0.0050811767578125, -0.0050811767578125, -0.0013275146484375, 0.0004119873046875, 0.0001373291015625],
36 | [0.0004119873046875, 0.0012359619140625, -0.0039825439453125, -0.0152435302734375, -0.0152435302734375, -0.0039825439453125, 0.0012359619140625, 0.0004119873046875],
37 | [-0.0013275146484375, -0.0039825439453125, 0.0128326416015625, 0.0491180419921875, 0.0491180419921875, 0.0128326416015625, -0.0039825439453125, -0.0013275146484375],
38 | [-0.0050811767578125, -0.0152435302734375, 0.0491180419921875, 0.1880035400390625, 0.1880035400390625, 0.0491180419921875, -0.0152435302734375, -0.0050811767578125],
39 | [-0.0050811767578125, -0.0152435302734375, 0.0491180419921875, 0.1880035400390625, 0.1880035400390625, 0.0491180419921875, -0.0152435302734375, -0.0050811767578125],
40 | [-0.0013275146484375, -0.0039825439453125, 0.0128326416015625, 0.0491180419921875, 0.0491180419921875, 0.0128326416015625, -0.0039825439453125, -0.0013275146484375],
41 | [0.0004119873046875, 0.0012359619140625, -0.0039825439453125, -0.0152435302734375, -0.0152435302734375, -0.0039825439453125, 0.0012359619140625, 0.0004119873046875],
42 | [0.0001373291015625, 0.0004119873046875, -0.0013275146484375, -0.0050811767578125, -0.0050811767578125, -0.0013275146484375, 0.0004119873046875, 0.0001373291015625]]
43 | self.bicubic_kernel = Variable(torch.Tensor(bicubic_k).cuda(), requires_grad=False)
44 | self.scale_factor = scale_factor
45 |
46 | def forward(self, g_input, g_output):
47 | downscaled = resize_tensor_w_kernel(im_t=g_input, k=self.bicubic_kernel, sf=self.scale_factor)
48 | # Shave the downscaled to fit g_output
49 | return self.loss(g_output, shave_a2b(downscaled, g_output))
50 |
51 |
52 | class SumOfWeightsLoss(nn.Module):
53 | """ Encourages the kernel that G is imitating to sum to 1 """
54 |
55 | def __init__(self):
56 | super(SumOfWeightsLoss, self).__init__()
57 | self.loss = nn.L1Loss()
58 |
59 | def forward(self, kernel):
60 | return self.loss(torch.ones(1).to(kernel.device), torch.sum(kernel))
61 |
62 |
63 | class CentralizedLoss(nn.Module):
64 | """ Penalizes the distance of the center of mass from K's center"""
65 |
66 | def __init__(self, k_size, scale_factor=.5):
67 | super(CentralizedLoss, self).__init__()
68 | self.indices = Variable(torch.arange(0., float(k_size)).cuda(),
requires_grad=False)
69 | wanted_center_of_mass = k_size // 2 + 0.5 * (int(1 / scale_factor) - k_size % 2)
70 | self.center = Variable(torch.FloatTensor([wanted_center_of_mass, wanted_center_of_mass]).cuda(), requires_grad=False)
71 | self.loss = nn.MSELoss()
72 |
73 | def forward(self, kernel):
74 | """Return the loss over the distance of the center of mass from the kernel's center """
75 | r_sum, c_sum = torch.sum(kernel, dim=1).reshape(1, -1), torch.sum(kernel, dim=0).reshape(1, -1)
76 | return self.loss(torch.stack((torch.matmul(r_sum, self.indices) / torch.sum(kernel),
77 | torch.matmul(c_sum, self.indices) / torch.sum(kernel))), self.center)
78 |
79 |
80 | class BoundariesLoss(nn.Module):
81 | """ Encourages sparsity of the boundaries by penalizing non-zeros far from the center """
82 |
83 | def __init__(self, k_size):
84 | super(BoundariesLoss, self).__init__()
85 | self.mask = map2tensor(create_penalty_mask(k_size, 30))
86 | self.zero_label = Variable(torch.zeros(k_size).cuda(), requires_grad=False)
87 | self.loss = nn.L1Loss()
88 |
89 | def forward(self, kernel):
90 | return self.loss(kernel * self.mask, self.zero_label)
91 |
92 |
93 | class SparsityLoss(nn.Module):
94 | """ Penalizes small non-zero values (via a concave power) to encourage sparsity """
95 | def __init__(self):
96 | super(SparsityLoss, self).__init__()
97 | self.power = 0.2
98 | self.loss = nn.L1Loss()
99 |
100 | def forward(self, kernel):
101 | return self.loss(torch.abs(kernel) ** self.power, torch.zeros_like(kernel))
102 |
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from util import swap_axis
4 |
5 |
6 | class Generator(nn.Module):
7 | def __init__(self, conf):
8 | super(Generator, self).__init__()
9 | struct = conf.G_structure
10 | # First layer - Converting RGB image to latent space
11 | self.first_layer = nn.Conv2d(in_channels=1, out_channels=conf.G_chan, kernel_size=struct[0], bias=False)
12 |
13 | feature_block = [] # Stacking intermediate layers
14 | for layer in range(1, len(struct) - 1):
15 | feature_block += [nn.Conv2d(in_channels=conf.G_chan, out_channels=conf.G_chan, kernel_size=struct[layer], bias=False)]
16 | self.feature_block = nn.Sequential(*feature_block)
17 | # Final layer - Down-sampling and converting back to image
18 | self.final_layer = nn.Conv2d(in_channels=conf.G_chan, out_channels=1, kernel_size=struct[-1],
19 | stride=int(1 / conf.scale_factor), bias=False)
20 |
21 | # Calculate number of pixels shaved in the forward pass
22 | self.output_size = self.forward(torch.FloatTensor(torch.ones([1, 1, conf.input_crop_size, conf.input_crop_size]))).shape[-1]
23 | self.forward_shave = int(conf.input_crop_size * conf.scale_factor) - self.output_size
24 |
25 | def forward(self, input_tensor):
26 | # Swap the axes of the RGB image so the network gets a "batch" of size = 3 rather than 3 channels
27 | input_tensor = swap_axis(input_tensor)
28 | downscaled = self.first_layer(input_tensor)
29 | features = self.feature_block(downscaled)
30 | output = self.final_layer(features)
31 | return swap_axis(output)
32 |
33 |
34 | class Discriminator(nn.Module):
35 |
36 | def __init__(self, conf):
37 | super(Discriminator, self).__init__()
38 |
39 | # First layer - Convolution (with no ReLU)
40 | self.first_layer = nn.utils.spectral_norm(nn.Conv2d(in_channels=3, out_channels=conf.D_chan, kernel_size=conf.D_kernel_size, bias=True))
41 | feature_block = [] # Stacking layers with 1x1
kernels
42 | for _ in range(1, conf.D_n_layers - 1):
43 | feature_block += [nn.utils.spectral_norm(nn.Conv2d(in_channels=conf.D_chan, out_channels=conf.D_chan, kernel_size=1, bias=True)),
44 | nn.BatchNorm2d(conf.D_chan),
45 | nn.ReLU(True)]
46 | self.feature_block = nn.Sequential(*feature_block)
47 | self.final_layer = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=conf.D_chan, out_channels=1, kernel_size=1, bias=True)),
48 | nn.Sigmoid())
49 |
50 | # Calculate number of pixels shaved in the forward pass
51 | self.forward_shave = conf.input_crop_size - self.forward(torch.FloatTensor(torch.ones([1, 3, conf.input_crop_size, conf.input_crop_size]))).shape[-1]
52 |
53 | def forward(self, input_tensor):
54 | receptive_extraction = self.first_layer(input_tensor)
55 | features = self.feature_block(receptive_extraction)
56 | return self.final_layer(features)
57 |
58 |
59 | def weights_init_D(m):
60 | """ Initialize the weights of the discriminator """
61 | class_name = m.__class__.__name__
62 | if class_name.find('Conv') != -1:
63 | nn.init.xavier_normal_(m.weight, 0.1)
64 | if hasattr(m.bias, 'data'):
65 | m.bias.data.fill_(0)
66 | elif class_name.find('BatchNorm2d') != -1:
67 | m.weight.data.normal_(1.0, 0.02)
68 | m.bias.data.fill_(0)
69 |
70 |
71 | def weights_init_G(m):
72 | """ Initialize the weights of the generator """
73 | if m.__class__.__name__.find('Conv') != -1:
74 | nn.init.xavier_normal_(m.weight, 0.1)
75 | if hasattr(m.bias, 'data'):
76 | m.bias.data.fill_(0)
77 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tqdm
3 |
4 | from configs import Config
5 | from data import DataGenerator
6 | from kernelGAN import KernelGAN
7 | from learner import Learner
8 |
9 |
10 | def train(conf):
11 | gan = KernelGAN(conf)
12 | learner = Learner()
13 | data = DataGenerator(conf, gan)
14 | for iteration in tqdm.tqdm(range(conf.max_iters), ncols=60):
15 | [g_in, d_in] = data.__getitem__(iteration)
16 | gan.train(g_in, d_in)
17 | learner.update(iteration, gan)
18 | gan.finish()
19 |
20 |
21 | def main():
22 | """The main function - performs kernel estimation (+ ZSSR) for all images in the 'test_images' folder"""
23 | import argparse
24 | # Parse the command line arguments
25 | prog = argparse.ArgumentParser()
26 | prog.add_argument('--input-dir', '-i', type=str, default='test_images', help='path to image input directory.')
27 | prog.add_argument('--output-dir', '-o', type=str, default='results', help='path to image output directory.')
28 | prog.add_argument('--X4', action='store_true', help='when activated - the SR scale factor is 4 instead of the default 2')
29 | prog.add_argument('--SR', action='store_true', help='when activated - ZSSR is performed')
30 | prog.add_argument('--real', action='store_true', help='ZSSR\'s configuration is for real images')
31 | prog.add_argument('--noise_scale', type=float, default=1., help='ZSSR uses this to partially de-noise images')
32 | args = prog.parse_args()
33 | # Run KernelGAN sequentially on all images in the input directory
34 | for filename in os.listdir(os.path.abspath(args.input_dir)):
35 | conf = Config().parse(create_params(filename, args))
36 | train(conf)
37 | prog.exit(0)
38 |
39 |
40 | def create_params(filename, args):
41 | params = ['--input_image_path', os.path.join(args.input_dir, filename),
42 | '--output_dir_path', os.path.abspath(args.output_dir),
43 | '--noise_scale', str(args.noise_scale)]
44 | if args.X4:
45 |
params.append('--X4')
46 | if args.SR:
47 | params.append('--do_ZSSR')
48 | if args.real:
49 | params.append('--real_image')
50 | return params
51 |
52 |
53 | if __name__ == '__main__':
54 | main()
55 |
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | import scipy.io as sio
7 | import matplotlib.pyplot as plt
8 | from scipy.signal import convolve2d
9 | from torch.nn import functional as F
10 | from scipy.ndimage import measurements, interpolation
11 |
12 | from ZSSRforKernelGAN.ZSSR import ZSSR
13 |
14 |
15 | def move2cpu(d):
16 | """Move data from the gpu to the cpu"""
17 | return d.detach().cpu().float().numpy()
18 |
19 |
20 | def tensor2im(im_t):
21 | """Copy the tensor to the cpu & convert to the range [0,255]"""
22 | im_np = np.clip(np.round((np.transpose(move2cpu(im_t).squeeze(0), (1, 2, 0)) + 1) / 2.0 * 255.0), 0, 255)
23 | return im_np.astype(np.uint8)
24 |
25 |
26 | def im2tensor(im_np):
27 | """Copy the image to the gpu & convert to the range [-1,1]"""
28 | im_np = im_np / 255.0 if im_np.dtype == 'uint8' else im_np
29 | return torch.FloatTensor(np.transpose(im_np, (2, 0, 1)) * 2.0 - 1.0).unsqueeze(0).cuda()
30 |
31 |
32 | def map2tensor(gray_map):
33 | """Move gray maps to the GPU, no normalization is done"""
34 | return torch.FloatTensor(gray_map).unsqueeze(0).unsqueeze(0).cuda()
35 |
36 |
37 | def resize_tensor_w_kernel(im_t, k, sf=None):
38 | """Convolves a tensor with a given kernel according to the scale factor"""
39 | # Expand dimensions to fit convolution: [out_channels, in_channels, k_height, k_width]
40 | k = k.expand(im_t.shape[1], im_t.shape[1], k.shape[0], k.shape[1])
41 | # Calculate padding
42 | padding = (k.shape[-1] - 1) // 2
43 | return F.conv2d(im_t, k, stride=round(1 / sf), padding=padding)
44 |
45 |
46 | def read_image(path):
47 | """Loads an image"""
48 | im = Image.open(path).convert('RGB')
49 | im = np.array(im, dtype=np.uint8)
50 | return im
51 |
52 |
53 | def rgb2gray(im):
54 | """Convert an RGB image to gray-scale"""
55 | return np.dot(im, [0.299, 0.587, 0.114]) if len(im.shape) == 3 else im
56 |
57 |
58 | def swap_axis(im):
59 | """Swap the axes of a tensor from a 3-channel tensor to a batch of 3 single-channel tensors and vice versa"""
60 | return im.transpose(0, 1) if type(im) == torch.Tensor else np.moveaxis(im, 0, 1)
61 |
62 |
63 | def shave_a2b(a, b):
64 | """Given a big image or tensor 'a', shave it symmetrically into b's shape"""
65 | # If dealing with a tensor we shave the 3rd & 4th dimensions, o.w.
the 1st and 2nd
66 | is_tensor = (type(a) == torch.Tensor)
67 | r = 2 if is_tensor else 0
68 | c = 3 if is_tensor else 1
69 | # Calculate the shaving of each dimension
70 | shave_r, shave_c = max(0, a.shape[r] - b.shape[r]), max(0, a.shape[c] - b.shape[c])
71 | return a[:, :, shave_r // 2:a.shape[r] - shave_r // 2 - shave_r % 2, shave_c // 2:a.shape[c] - shave_c // 2 - shave_c % 2] if is_tensor \
72 | else a[shave_r // 2:a.shape[r] - shave_r // 2 - shave_r % 2, shave_c // 2:a.shape[c] - shave_c // 2 - shave_c % 2]
73 |
74 |
75 | def create_gradient_map(im, window=5, percent=.97):
76 | """Create a gradient map of the image blurred with a rect of size window and clip extreme values"""
77 | # Calculate gradients
78 | gx, gy = np.gradient(rgb2gray(im))
79 | # Calculate gradient magnitude
80 | gmag, gx, gy = np.sqrt(gx ** 2 + gy ** 2), np.abs(gx), np.abs(gy)
81 | # Pad edges to avoid artifacts in the edge of the image
82 | gx_pad, gy_pad, gmag = pad_edges(gx, int(window)), pad_edges(gy, int(window)), pad_edges(gmag, int(window))
83 | lm_x, lm_y, lm_gmag = clip_extreme(gx_pad, percent), clip_extreme(gy_pad, percent), clip_extreme(gmag, percent)
84 | # Combine the three clipped gradient maps
85 | grads_comb = lm_x / lm_x.sum() + lm_y / lm_y.sum() + lm_gmag / lm_gmag.sum()
86 | # Blur the gradients and normalize to original values
87 | loss_map = convolve2d(grads_comb, np.ones(shape=(window, window)), 'same') / (window ** 2)
88 | # Normalizing: sum of map = numel
89 | return loss_map / np.mean(loss_map)
90 |
91 |
92 | def create_probability_map(loss_map, crop):
93 | """Create a vector of probabilities corresponding to the loss map"""
94 | # Blur the gradients to get the sum of gradients in the crop
95 | blurred = convolve2d(loss_map, np.ones([crop // 2, crop // 2]), 'same') / ((crop // 2) ** 2)
96 | # Zero pad s.t. probabilities are non-zero only in valid crop centers
97 | prob_map = pad_edges(blurred, crop // 2)
98 | # Normalize to sum to 1
99 | prob_vec = prob_map.flatten() / prob_map.sum() if prob_map.sum() != 0 else np.ones_like(prob_map.flatten()) / prob_map.flatten().shape[0]
100 | return prob_vec
101 |
102 |
103 | def pad_edges(im, edge):
104 | """Replace image boundaries with 0 without changing the size"""
105 | zero_padded = np.zeros_like(im)
106 | zero_padded[edge:-edge, edge:-edge] = im[edge:-edge, edge:-edge]
107 | return zero_padded
108 |
109 |
110 | def clip_extreme(im, percent):
111 | """Zeroize values below a threshold and clip all those above"""
112 | # Sort the image
113 | im_sorted = np.sort(im.flatten())
114 | # Choose a pivot index that holds the min value to be clipped
115 | pivot = int(percent * len(im_sorted))
116 | v_min = im_sorted[pivot]
117 | # The max value will be the next value in the sorted array;
if it equals the min, a small threshold is added
118 | v_max = im_sorted[pivot + 1] if im_sorted[pivot + 1] > v_min else v_min + 10e-6
119 | # Clip and zeroize all the lower values
120 | return np.clip(im, v_min, v_max) - v_min
121 |
122 |
123 | def post_process_k(k, n):
124 | """Move the kernel to the CPU, eliminate negligible values, and centralize k"""
125 | k = move2cpu(k)
126 | # Zeroize negligible values
127 | significant_k = zeroize_negligible_val(k, n)
128 | # Force centralization on the kernel
129 | centralized_k = kernel_shift(significant_k, sf=2)
130 | # return shave_a2b(centralized_k, k)
131 | return centralized_k
132 |
133 |
134 | def zeroize_negligible_val(k, n):
135 | """Zeroize values that are negligible w.r.t. the values in k"""
136 | # Sort K's values in order to find the n-th largest
137 | k_sorted = np.sort(k.flatten())
138 | # Define the minimum value as 0.75 * the n-th largest value
139 | k_n_min = 0.75 * k_sorted[-n - 1]
140 | # Clip values lower than the minimum value
141 | filtered_k = np.clip(k - k_n_min, a_min=0, a_max=100)
142 | # Normalize to sum to 1
143 | return filtered_k / filtered_k.sum()
144 |
145 |
146 | def create_penalty_mask(k_size, penalty_scale):
147 | """Generate a mask of weights penalizing values close to the boundaries"""
148 | center_size = k_size // 2 + k_size % 2
149 | mask = create_gaussian(size=k_size, sigma1=k_size, is_tensor=False)
150 | mask = 1 - mask / np.max(mask)
151 | margin = (k_size - center_size) // 2 - 1
152 | mask[margin:-margin, margin:-margin] = 0
153 | return penalty_scale * mask
154 |
155 |
156 | def create_gaussian(size, sigma1, sigma2=-1, is_tensor=False):
157 | """Return a Gaussian"""
158 | func1 = [np.exp(-z ** 2 / (2 * sigma1 ** 2)) / np.sqrt(2 * np.pi * sigma1 ** 2) for z in range(-size // 2 + 1, size // 2 + 1)]
159 | func2 = func1 if sigma2 == -1 else [np.exp(-z ** 2 / (2 * sigma2 ** 2)) / np.sqrt(2 * np.pi * sigma2 ** 2) for z in range(-size // 2 + 1, size // 2 + 1)]
160 | return torch.FloatTensor(np.outer(func1, func2)).cuda() if is_tensor else np.outer(func1, func2)
161 |
162 |
163 | def nn_interpolation(im, sf):
164 | """Nearest neighbour interpolation"""
165 | pil_im = Image.fromarray(im)
166 | return np.array(pil_im.resize((im.shape[1] * sf, im.shape[0] * sf), Image.NEAREST), dtype=im.dtype)
167 |
168 |
169 | def analytic_kernel(k):
170 | """Calculate the X4 kernel from the X2 kernel (for proof see the appendix in the paper)"""
171 | k_size = k.shape[0]
172 | # Calculate the big kernel's size
173 | big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
174 | # Loop over the small kernel to fill the big one
175 | for r in range(k_size):
176 | for c in range(k_size):
177 | big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
178 | # Crop the edges of the big kernel to ignore very small values and reduce the run time of SR
179 | crop = k_size // 2
180 | cropped_big_k = big_k[crop:-crop, crop:-crop]
181 | # Normalize to 1
182 | return cropped_big_k / cropped_big_k.sum()
183 |
184 |
185 | def kernel_shift(kernel, sf):
186 | # There are two reasons for shifting the kernel:
187 | # 1. The center of mass is not in the center of the kernel, which creates ambiguity. There is no way to know
188 | # whether the degradation process included shifting, so we always assume the center of mass is the center of the kernel.
189 | # 2. We further shift the kernel center so that the top left result pixel corresponds to the middle of the sfXsf first
190 | # pixels.
The default is for an odd-sized kernel to be centered in the middle of the first pixel and for an even-sized
191 | # kernel to be at the top left corner of the first pixel. That is why a different shift size is needed for odd and even sizes.
192 | # Given that these two conditions are fulfilled, we are aligned; the way to test this is:
193 | # the input image, when interpolated (regular bicubic), is exactly aligned with the ground truth.
194 |
195 | # First calculate the current center of mass for the kernel
196 | current_center_of_mass = measurements.center_of_mass(kernel)
197 |
198 | # The second term ("+ 0.5 * ....") is for applying condition 2 from the comments above
199 | wanted_center_of_mass = np.array(kernel.shape) // 2 + 0.5 * (np.array(sf) - (np.array(kernel.shape) % 2))
200 | # Define the shift vector for the kernel shifting (x,y)
201 | shift_vec = wanted_center_of_mass - current_center_of_mass
202 | # Before applying the shift, we first pad the kernel so that nothing is lost due to the shift
203 | # (biggest shift among dims + 1 for safety)
204 | kernel = np.pad(kernel, np.int(np.ceil(np.max(np.abs(shift_vec)))) + 1, 'constant')
205 |
206 | # Finally shift the kernel and return
207 | kernel = interpolation.shift(kernel, shift_vec)
208 |
209 | return kernel
210 |
211 |
212 | def save_final_kernel(k_2, conf):
213 | """Saves the final kernel and the analytic kernel to the results folder"""
214 | sio.savemat(os.path.join(conf.output_dir_path, '%s_kernel_x2.mat' % conf.img_name), {'Kernel': k_2})
215 | if conf.X4:
216 | k_4 = analytic_kernel(k_2)
217 | sio.savemat(os.path.join(conf.output_dir_path, '%s_kernel_x4.mat' % conf.img_name), {'Kernel': k_4})
218 |
219 |
220 | def run_zssr(k_2, conf):
221 | """Performs ZSSR with the estimated kernel for the wanted scale factor"""
222 | if conf.do_ZSSR:
223 | start_time = time.time()
224 | print('~' * 30 + '\nRunning ZSSR X%d...' % (4 if conf.X4 else 2))
225 | if conf.X4:
226 | sr = ZSSR(conf.input_image_path, scale_factor=[[2, 2], [4, 4]], kernels=[k_2, analytic_kernel(k_2)], is_real_img=conf.real_image, noise_scale=conf.noise_scale).run()
227 | else:
228 | sr = ZSSR(conf.input_image_path, scale_factor=2, kernels=[k_2], is_real_img=conf.real_image, noise_scale=conf.noise_scale).run()
229 | max_val = 255 if sr.dtype == 'uint8' else 1.
230 | plt.imsave(os.path.join(conf.output_dir_path, 'ZSSR_%s.png' % conf.img_name), sr, vmin=0, vmax=max_val, dpi=1)
231 | runtime = int(time.time() - start_time)
232 | print('Completed! runtime=%d:%d\n' % (runtime // 60, runtime % 60) + '~' * 30)
--------------------------------------------------------------------------------
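
A quick, illustrative sanity check for the center-of-mass logic that kernel_shift implements (in both imresize.py and util.py). This snippet is not part of the repository; it is a minimal sketch, assuming only numpy and the same scipy.ndimage submodules that util.py already imports. It builds a kernel whose mass sits off-center, applies the same wanted-center formula, and shows that the shifted kernel's center of mass moves by the computed shift vector (only approximately, since scipy's spline-based shift introduces slight ringing):

import numpy as np
from scipy.ndimage import interpolation, measurements

# A 5x5 kernel with all of its mass off-center
sf = 2
k = np.zeros((5, 5))
k[1, 1] = 1.0

# The wanted center of mass, per the formula in kernel_shift's comments:
# kernel.shape // 2 centers the mass; the "+ 0.5 * ..." term applies condition 2
wanted_center = np.array(k.shape) // 2 + 0.5 * (sf - (np.array(k.shape) % 2))  # -> [2.5, 2.5]
shift_vec = wanted_center - np.array(measurements.center_of_mass(k))           # -> [1.5, 1.5]

# Pad so nothing is lost, then shift (mirroring the steps in kernel_shift above)
pad = int(np.ceil(np.max(np.abs(shift_vec)))) + 1
k_shifted = interpolation.shift(np.pad(k, pad, 'constant'), shift_vec)

# The center of mass moved by pad + shift_vec relative to the original (1, 1)
print(measurements.center_of_mass(k_shifted))  # approximately (5.5, 5.5)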