├── LICENSE ├── README.md ├── ZSSR.py ├── configs.py ├── example_with_gt ├── bsd_001.png └── bsd_001_gt.png ├── figs └── sketch.png ├── imresize.py ├── kernel_example ├── BSD100_100_lr_rand_ker_c_X2.png └── BSD100_100_lr_rand_ker_c_X2_0.mat ├── real_example ├── charlie.png └── charlie_0.mat ├── run_ZSSR.py ├── run_ZSSR_single_input.py ├── set14 ├── img_001_SRF_2_LR.png ├── img_002_SRF_2_LR.png ├── img_003_SRF_2_LR.png ├── img_004_SRF_2_LR.png ├── img_005_SRF_2_LR.png ├── img_006_SRF_2_LR.png ├── img_007_SRF_2_LR.png ├── img_008_SRF_2_LR.png ├── img_009_SRF_2_LR.png ├── img_010_SRF_2_LR.png ├── img_011_SRF_2_LR.png ├── img_012_SRF_2_LR.png ├── img_013_SRF_2_LR.png └── img_014_SRF_2_LR.png └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | The Weizmann Institute of Science 2 | Academic Non Commercial Software Code License 3 | [ZSSR – Zero-Shot Super-Resolution using Internal Deep Learning] (the "Work") 4 | © 2018 The Weizmann Institute of Science ("WIS") and Yeda Research and Development Company Ltd. ("Yeda") All Rights Reserved 5 | 6 | 1. YEDA, the commercial arm of WIS, hereby grants you, an individual or a legal entity exercising rights under, and complying with all of the provisions, of this License (“You”) a royalty-free, non-exclusive, sublicensable, worldwide license to: use, copy, modify, create derivative works (including without limiting to: adapt, alter, transform), integrate with other works, distribute, enable access (including without limiting to: communicate copies), publicly display and perform the Work in binary form or in source code, for academic and noncommercial use only and subject to all provisions of this License: 7 | 2. 
YEDA hereby grants You a royalty-free, non-exclusive, sublicensable, worldwide license under patents claimed or owned by YEDA that are embodied in the Work, to make, have made and use the Work under the License, for avoidance of doubt for academic and noncommercial use only. 8 | 3. Distribution or provision of access to the Work and to derivative works of the Work ("Derivative Works") may be made only under this License, accompanied with a copy of the source code or a reference to an online repository where such source code can be accessed. 9 | 4. Neither the names of WIS or Yeda, nor any of their trademarks or service marks, may be used to endorse or promote Derivative Works or for any other purpose except as expressly permitted hereunder. 10 | 5. Except as expressly stated in this License, nothing in this License grants any license to trademarks, copyrights, patents, trade secrets or any other intellectual property of WIS or Yeda. No license is granted to the trademarks of WIS or Yeda's even if such marks are included in the Work. 11 | 6. Nothing in this License shall be interpreted to prohibit WIS or Yeda from licensing the Work under terms different from this License. For commercial use please e-mail Yeda at: info.yeda@weizmann.ac.il 12 | 7. You must retain, in the Source Code of any Derivative Works that You create, all copyright, patent, or trademark notices from the Source Code of the Work, as well as a notice to inform recipients that You have modified the Work with a description of such modifications. 13 | 8. THE WORK IS PROVIDED "AS IS" AND WITHOUT ANY WARRANTIES WHATSOEVER, EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION ANY WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 14 | 9. 
IN NO EVENT WILL WIS, YEDA OR ANY OF THEIR RELATED ENTITES, SCIENTISTS, EMPLOYEES, MANAGERS OR ANY OTHE PERSON ACTING ON THEIR BEHALF, BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY OR CAUSE OF ACTION, WHETHER IN CONTRACT, TORT, STRICT LIABILITY, UNJUST ENRICHMENT OR ANY OTHER, ARISING IN ANY WAY OUT OF THE USE OF THE WORK OR THIS LICENSE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 15 | 10. This License will terminate automatically if any of its conditions is not met, or in case You commence an action, including a cross-claim or counterclaim, against WIS or YEDA or any licensee alleging that the Work (except due to combination with other software or hardware) infringes a patent. 16 | 11. This License shall be exclusively governed by the laws of the State of Israel, without giving effect to conflict of laws principles, and the competent courts in Tel Aviv will have exclusive jurisdiction and venue over any matter between You and WIS or YEDA or any of their related entities relating to this License or the Work. 17 | 12. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | Yeda - Academic Non Commercial Software Code License.txt 31 | Displaying Yeda - Academic Non Commercial Software Code License.txt. 
32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # "Zero-Shot" Super-Resolution using Deep Internal Learning (ZSSR) 2 | ### Official implementation for paper by: Assaf Shocher, Nadav Cohen, Michal Irani 3 | 4 | Paper: https://arxiv.org/abs/1712.06087 5 | Project page: http://www.wisdom.weizmann.ac.il/~vision/zssr/ (See our results and visual comparison to other methods) 6 | 7 | **Accepted CVPR'18** 8 | 9 | ---------- 10 | This current provided version of ZSSR actually achieves better results on benchmarks than indicated in the paper. 11 | For example, when current version is applied to 'Set14' without use of gradual SR increments, it achieves slightly higher PSNR than specified in the paper (when 6 gradual increments are applied). When gradual increments similar to those specified in the paper are applied, then +0.3dB is obtained. 12 | 13 | ---------- 14 | ![sketch](/figs/sketch.png) 15 | ---------- 16 | If you find our work useful in your research or publication, please cite our work: 17 | 18 | ``` 19 | @InProceedings{ZSSR, 20 | author = {Assaf Shocher, Nadav Cohen, Michal Irani}, 21 | title = {"Zero-Shot" Super-Resolution using Deep Internal Learning}, 22 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 23 | month = {June}, 24 | year = {2018} 25 | } 26 | ``` 27 | ---------- 28 | # Usage: 29 | 30 | ## Quick usage on your data: 31 | (First, put your desired low-res files in ```/test_data/```. 32 | Results will be generated to ```/results//```. 33 | data must be *.png type) 34 | ``` 35 | python run_ZSSR.py 36 | ``` 37 | 38 | ## General usage: 39 | ``` 40 | python run_ZSSR.py 41 | ``` 42 | While ``` ``` is an instance of configs.Config class (at configs.py) or 0 for default configuration. 43 | Please see configs.py to determine configuration (data paths, scale-factors etc.) 
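As a hedged illustration of what such a configuration might look like, the sketch below builds a custom instance the same way the pre-made examples in configs.py do: create a `Config()` and override the attributes that differ from the defaults. The `Config` class here is only a trimmed stand-in that mirrors a few fields of configs.py so the snippet runs on its own; `MY_X2_CONF` and the `my_images` folder are made-up names, not part of the repository.

```python
import os


class Config:
    # Stand-in for configs.Config: only a few of its fields are mirrored here
    scale_factors = [[2.0, 2.0]]    # list of (vertical, horizontal) pairs
    back_projection_iters = [10]    # one entry per scale factor
    noise_std = 0.0
    output_flip = True
    input_path = os.path.join(os.getcwd(), 'test_data')


# Hypothetical custom configuration: instantiate, then override attributes.
MY_X2_CONF = Config()
MY_X2_CONF.input_path = os.path.join(os.getcwd(), 'my_images')  # made-up folder
MY_X2_CONF.noise_std = 0.0125            # mild noise, as in X2_REAL_CONF
MY_X2_CONF.back_projection_iters = [0]   # no back projection for noisy input

# back_projection_iters is documented as having the same length as scale_factors
assert len(MY_X2_CONF.back_projection_iters) == len(MY_X2_CONF.scale_factors)
```

In the actual repository such an instance would presumably be defined in configs.py next to the pre-made ones, assuming run_ZSSR.py resolves the name given on the command line there (as the usage of names like X2_REAL_CONF suggests). Overrides are set on the instance, so the class-level defaults stay unchanged.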
44 | ``` ``` is an optional parameter to determine how to use available GPUs (see next section). 45 | 46 | To use given kernels, you must have a kernel for each input file and each scale factor, named as follows: 47 | ``` _.mat ``` 48 | Kernels are MATLAB files containing a matrix named "Kernel". 49 | 50 | If ground truth exists and true-error monitoring is wanted, the ground truth should be named as follows: 51 | ``` _gt.png ``` 52 | 53 | 54 | ## GPU options 55 | Run on a specific GPU: 56 | ``` 57 | python run_ZSSR.py 0 58 | ``` 59 | Run multiple files efficiently on multiple GPUs. 60 | **Before using this option make sure you update the ***python_path*** parameter in the configs.py file** 61 | ``` 62 | python run_ZSSR.py all 63 | ``` 64 | 65 | ## Quick usage examples (applied on provided data examples): 66 | Usage example to test 'Set14', Gradual SR (~0.3dB better results, 6x Runtime) 67 | ``` 68 | python run_ZSSR.py X2_GRADUAL_IDEAL_CONF 69 | ``` 70 | Usage example to test 'Set14' (Non-Gradual SR) 71 | ``` 72 | python run_ZSSR.py X2_ONE_JUMP_IDEAL_CONF 73 | ``` 74 | Visualization while running (Recommended for one image, interactive mode, for debugging) 75 | ``` 76 | python run_ZSSR.py X2_IDEAL_WITH_PLOT_CONF 77 | ``` 78 | Applying a given kernel 79 | ``` 80 | python run_ZSSR.py X2_GIVEN_KERNEL_CONF 81 | ``` 82 | Run on a real image 83 | ``` 84 | python run_ZSSR.py X2_REAL_CONF 85 | ``` 86 | 87 | ---------- 88 | Example kernels were generated from the input images using: 89 | [T. Michaeli and M. Irani, Nonparametric Blind Super-Resolution.
International Conference on Computer Vision (ICCV), October 2013.](http://www.wisdom.weizmann.ac.il/~vision/BlindSR.html) 90 | -------------------------------------------------------------------------------- /ZSSR.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import matplotlib.pyplot as plt 3 | import matplotlib.image as img 4 | from matplotlib.gridspec import GridSpec 5 | from configs import Config 6 | from utils import * 7 | 8 | 9 | class ZSSR: 10 | # Basic current state variables initialization / declaration 11 | kernel = None 12 | learning_rate = None 13 | hr_father = None 14 | lr_son = None 15 | sr = None 16 | sf = None 17 | gt_per_sf = None 18 | final_sr = None 19 | hr_fathers_sources = [] 20 | 21 | # Output variables initialization / declaration 22 | reconstruct_output = None 23 | train_output = None 24 | output_shape = None 25 | 26 | # Counters and logs initialization 27 | iter = 0 28 | base_sf = 1.0 29 | base_ind = 0 30 | sf_ind = 0 31 | mse = [] 32 | mse_rec = [] 33 | interp_rec_mse = [] 34 | interp_mse = [] 35 | mse_steps = [] 36 | loss = [] 37 | learning_rate_change_iter_nums = [] 38 | fig = None 39 | 40 | # Network tensors (all tensors end with _t to distinguish) 41 | learning_rate_t = None 42 | lr_son_t = None 43 | hr_father_t = None 44 | filters_t = None 45 | layers_t = None 46 | net_output_t = None 47 | loss_t = None 48 | train_op = None 49 | init_op = None 50 | 51 | # Parameters related to plotting and graphics 52 | plots = None 53 | loss_plot_space = None 54 | lr_son_image_space = None 55 | hr_father_image_space = None 56 | out_image_space = None 57 | 58 | # Tensorflow graph default 59 | sess = None 60 | 61 | def __init__(self, input_img, conf=Config(), ground_truth=None, kernels=None): 62 | # Acquire meta parameters configuration from configuration class as a class variable 63 | self.conf = conf 64 | 65 | # Read input image (can be either a numpy array or a path to an image 
file) 66 | self.input = input_img if type(input_img) is not str else img.imread(input_img) 67 | 68 | # For evaluation purposes, ground-truth image can be supplied. 69 | self.gt = ground_truth if type(ground_truth) is not str else img.imread(ground_truth) 70 | 71 | # Preprocess the kernels (see the function for what it includes). 72 | self.kernels = preprocess_kernels(kernels, conf) 73 | 74 | # Prepare TF default computational graph 75 | self.model = tf.Graph() 76 | 77 | # Build network computational graph 78 | self.build_network(conf) 79 | 80 | # Initialize network weights and meta parameters 81 | self.init_sess(init_weights=True) 82 | 83 | # The first hr father source is the input (source goes through augmentation to become a father) 84 | # Later on, if we use gradual sr increments, results for intermediate scales will be added as sources. 85 | self.hr_fathers_sources = [self.input] 86 | 87 | # We keep the input file name to save the output with a similar name. If an array was given rather than a path 88 | # then we use the default name provided by the configs 89 | self.file_name = input_img if type(input_img) is str else conf.name 90 | 91 | def run(self): 92 | # Run gradually on all scale factors (if only one jump then this loop only happens once) 93 | for self.sf_ind, (sf, self.kernel) in enumerate(zip(self.conf.scale_factors, self.kernels)): 94 | # verbose 95 | print '** Start training for sf=', sf, ' **' 96 | 97 | # Relative_sf (used when base change is enabled. this is when input is the output of some previous scale) 98 | if np.isscalar(sf): 99 | sf = [sf, sf] 100 | self.sf = np.array(sf) / np.array(self.base_sf) 101 | self.output_shape = np.uint(np.ceil(np.array(self.input.shape[0:2]) * sf)) 102 | 103 | # Initialize network 104 | self.init_sess(init_weights=self.conf.init_net_for_each_sf) 105 | 106 | # Train the network 107 | self.train() 108 | 109 | # Use augmented outputs and back projection to enhance result. Also save the result.
110 | post_processed_output = self.final_test() 111 | 112 | # Keep the results for the next scale factors SR to use as dataset 113 | self.hr_fathers_sources.append(post_processed_output) 114 | 115 | # In some cases, the current output becomes the new input. If indicated and if this is the right scale to 116 | # become the new base input. all of these conditions are checked inside the function. 117 | self.base_change() 118 | 119 | # Save the final output if indicated 120 | if self.conf.save_results: 121 | sf_str = ''.join('X%.2f' % s for s in self.conf.scale_factors[self.sf_ind]) 122 | plt.imsave('%s/%s_zssr_%s.png' % 123 | (self.conf.result_path, os.path.basename(self.file_name)[:-4], sf_str), 124 | post_processed_output, vmin=0, vmax=1) 125 | 126 | # verbose 127 | print '** Done training for sf=', sf, ' **' 128 | 129 | # Return the final post processed output. 130 | # noinspection PyUnboundLocalVariable 131 | return post_processed_output 132 | 133 | def build_network(self, meta): 134 | with self.model.as_default(): 135 | 136 | # Learning rate tensor 137 | self.learning_rate_t = tf.placeholder(tf.float32, name='learning_rate') 138 | 139 | # Input image 140 | self.lr_son_t = tf.placeholder(tf.float32, name='lr_son') 141 | 142 | # Ground truth (supervision) 143 | self.hr_father_t = tf.placeholder(tf.float32, name='hr_father') 144 | 145 | # Filters 146 | self.filters_t = [tf.get_variable(shape=meta.filter_shape[ind], name='filter_%d' % ind, 147 | initializer=tf.random_normal_initializer( 148 | stddev=np.sqrt(meta.init_variance/np.prod( 149 | meta.filter_shape[ind][0:3])))) 150 | for ind in range(meta.depth)] 151 | 152 | # Activate filters on layers one by one (this is just building the graph, no calculation is done here) 153 | self.layers_t = [self.lr_son_t] + [None] * meta.depth 154 | for l in range(meta.depth - 1): 155 | self.layers_t[l + 1] = tf.nn.relu(tf.nn.conv2d(self.layers_t[l], self.filters_t[l], 156 | [1, 1, 1, 1], "SAME", name='layer_%d' % (l + 1))) 157 | 
158 | # Last conv layer (Separate because no ReLU here) 159 | l = meta.depth - 1 160 | self.layers_t[-1] = tf.nn.conv2d(self.layers_t[l], self.filters_t[l], 161 | [1, 1, 1, 1], "SAME", name='layer_%d' % (l + 1)) 162 | 163 | # Output image (Add last conv layer result to input, residual learning with global skip connection) 164 | self.net_output_t = self.layers_t[-1] + self.conf.learn_residual * self.lr_son_t 165 | 166 | # Final loss (L1 loss between label and output layer) 167 | self.loss_t = tf.reduce_mean(tf.reshape(tf.abs(self.net_output_t - self.hr_father_t), [-1])) 168 | 169 | # Apply adam optimizer 170 | self.train_op = tf.train.AdamOptimizer(learning_rate=self.learning_rate_t).minimize(self.loss_t) 171 | self.init_op = tf.initialize_all_variables() 172 | 173 | def init_sess(self, init_weights=True): 174 | # Sometimes we only want to initialize some meta-params but keep the weights as they were 175 | if init_weights: 176 | 177 | # These are for GPU consumption, preventing TF to catch all available GPUs 178 | config = tf.ConfigProto() 179 | config.gpu_options.allow_growth = True 180 | 181 | # Initialize computational graph session 182 | self.sess = tf.Session(graph=self.model, config=config) 183 | 184 | # Initialize weights 185 | self.sess.run(self.init_op) 186 | 187 | # Initialize all counters etc 188 | self.loss = [None] * self.conf.max_iters 189 | self.mse, self.mse_rec, self.interp_mse, self.interp_rec_mse, self.mse_steps = [], [], [], [], [] 190 | self.iter = 0 191 | self.learning_rate = self.conf.learning_rate 192 | self.learning_rate_change_iter_nums = [0] 193 | 194 | # Downscale ground-truth to the intermediate sf size (for gradual SR). 195 | # This only happens if there exists ground-truth and sf is not the last one (or too close to it). 196 | # We use imresize with both scale and output-size, see comment in forward_backward_pass. 
197 | # noinspection PyTypeChecker 198 | self.gt_per_sf = (imresize(self.gt, 199 | scale_factor=self.sf / self.conf.scale_factors[-1], 200 | output_shape=self.output_shape, 201 | kernel=self.conf.downscale_gt_method) 202 | if (self.gt is not None and 203 | self.sf is not None and 204 | np.any(np.abs(self.sf - self.conf.scale_factors[-1]) > 0.01)) 205 | else self.gt) 206 | 207 | def forward_backward_pass(self, lr_son, hr_father): 208 | # First gate for the lr-son into the network is interpolation to the size of the father 209 | # Note: we specify both output_size and scale_factor. best explained by example: say father size is 9 and sf=2, 210 | # small_son size is 4. if we upscale by sf=2 we get wrong size, if we upscale to size 9 we get wrong sf. 211 | # The current imresize implementation supports specifying both. 212 | interpolated_lr_son = imresize(lr_son, self.sf, hr_father.shape, self.conf.upscale_method) 213 | 214 | # Create feed dict 215 | feed_dict = {'learning_rate:0': self.learning_rate, 216 | 'lr_son:0': np.expand_dims(interpolated_lr_son, 0), 217 | 'hr_father:0': np.expand_dims(hr_father, 0)} 218 | 219 | # Run network 220 | _, self.loss[self.iter], train_output = self.sess.run([self.train_op, self.loss_t, self.net_output_t], 221 | feed_dict) 222 | return np.clip(np.squeeze(train_output), 0, 1) 223 | 224 | def forward_pass(self, lr_son, hr_father_shape=None): 225 | # First gate for the lr-son into the network is interpolation to the size of the father 226 | interpolated_lr_son = imresize(lr_son, self.sf, hr_father_shape, self.conf.upscale_method) 227 | 228 | # Create feed dict 229 | feed_dict = {'lr_son:0': np.expand_dims(interpolated_lr_son, 0)} 230 | 231 | # Run network 232 | return np.clip(np.squeeze(self.sess.run([self.net_output_t], feed_dict)), 0, 1) 233 | 234 | def learning_rate_policy(self): 235 | # fit linear curve and check slope to determine whether to do nothing, reduce learning rate or finish 236 | if (not (1 + self.iter) % 
self.conf.learning_rate_policy_check_every 237 | and self.iter - self.learning_rate_change_iter_nums[-1] > self.conf.min_iters): 238 | # noinspection PyTupleAssignmentBalance 239 | [slope, _], [[var, _], _] = np.polyfit(self.mse_steps[-(self.conf.learning_rate_slope_range / 240 | self.conf.run_test_every):], 241 | self.mse_rec[-(self.conf.learning_rate_slope_range / 242 | self.conf.run_test_every):], 243 | 1, cov=True) 244 | 245 | # We take the standard deviation as a measure 246 | std = np.sqrt(var) 247 | 248 | # Verbose 249 | print 'slope: ', slope, 'STD: ', std 250 | 251 | # Decide whether to keep or reduce the learning rate by the ratio between slope and noise 252 | if -self.conf.learning_rate_change_ratio * slope < std: 253 | self.learning_rate /= 10 254 | print "learning rate updated: ", self.learning_rate 255 | 256 | # Keep track of learning rate changes for plotting purposes 257 | self.learning_rate_change_iter_nums.append(self.iter) 258 | 259 | def quick_test(self): 260 | # Four evaluations need to be calculated: 261 | 262 | # 1. True MSE (only if ground-truth was given), note: this error is before post-processing. 263 | # Run net on the input to get the output super-resolution (almost final result, only post-processing needed) 264 | self.sr = self.forward_pass(self.input) 265 | self.mse = (self.mse + [np.mean(np.ndarray.flatten(np.square(self.gt_per_sf - self.sr)))] 266 | if self.gt_per_sf is not None else None) 267 | 268 | # 2. Reconstruction MSE, run for reconstruction- try to reconstruct the input from a downscaled version of it 269 | self.reconstruct_output = self.forward_pass(self.father_to_son(self.input), self.input.shape) 270 | self.mse_rec.append(np.mean(np.ndarray.flatten(np.square(self.input - self.reconstruct_output)))) 271 | 272 | # 3.
True MSE of simple interpolation for reference (only if ground-truth was given) 273 | interp_sr = imresize(self.input, self.sf, self.output_shape, self.conf.upscale_method) 274 | self.interp_mse = (self.interp_mse + [np.mean(np.ndarray.flatten(np.square(self.gt_per_sf - interp_sr)))] 275 | if self.gt_per_sf is not None else None) 276 | 277 | # 4. Reconstruction MSE of simple interpolation over downscaled input 278 | interp_rec = imresize(self.father_to_son(self.input), self.sf, self.input.shape[0:2], self.conf.upscale_method) 279 | self.interp_rec_mse.append(np.mean(np.ndarray.flatten(np.square(self.input - interp_rec)))) 280 | 281 | # Track the iters in which tests are made for the graphics x axis 282 | self.mse_steps.append(self.iter) 283 | 284 | # Display test results if indicated 285 | if self.conf.display_test_results: 286 | print 'iteration: ', self.iter, 'reconstruct mse:', self.mse_rec[-1], ', true mse:', (self.mse[-1] 287 | if self.mse else None) 288 | 289 | # plot losses if needed 290 | if self.conf.plot_losses: 291 | self.plot() 292 | 293 | def train(self): 294 | # main training loop 295 | for self.iter in xrange(self.conf.max_iters): 296 | # Use augmentation from original input image to create current father. 
297 | # If other scale factors were applied before, their result is also used (hr_fathers_in) 298 | self.hr_father = random_augment(ims=self.hr_fathers_sources, 299 | base_scales=[1.0] + self.conf.scale_factors, 300 | leave_as_is_probability=self.conf.augment_leave_as_is_probability, 301 | no_interpolate_probability=self.conf.augment_no_interpolate_probability, 302 | min_scale=self.conf.augment_min_scale, 303 | max_scale=([1.0] + self.conf.scale_factors)[len(self.hr_fathers_sources)-1], 304 | allow_rotation=self.conf.augment_allow_rotation, 305 | scale_diff_sigma=self.conf.augment_scale_diff_sigma, 306 | shear_sigma=self.conf.augment_shear_sigma, 307 | crop_size=self.conf.crop_size) 308 | 309 | # Get lr-son from hr-father 310 | self.lr_son = self.father_to_son(self.hr_father) 311 | 312 | # run network forward and back propagation, one iteration (This is the heart of the training) 313 | self.train_output = self.forward_backward_pass(self.lr_son, self.hr_father) 314 | 315 | # Display info and save weights 316 | if not self.iter % self.conf.display_every: 317 | print 'sf:', self.sf*self.base_sf, ', iteration: ', self.iter, ', loss: ', self.loss[self.iter] 318 | 319 | # Test network 320 | if self.conf.run_test and (not self.iter % self.conf.run_test_every): 321 | self.quick_test() 322 | 323 | # Consider changing learning rate or stop according to iteration number and losses slope 324 | self.learning_rate_policy() 325 | 326 | # stop when minimum learning rate was passed 327 | if self.learning_rate < self.conf.min_learning_rate: 328 | break 329 | 330 | def father_to_son(self, hr_father): 331 | # Create son out of the father by downscaling and if indicated adding noise 332 | lr_son = imresize(hr_father, 1.0 / self.sf, kernel=self.kernel) 333 | return np.clip(lr_son + np.random.randn(*lr_son.shape) * self.conf.noise_std, 0, 1) 334 | 335 | def final_test(self): 336 | # Run over 8 augmentations of input - 4 rotations and mirror (geometric self ensemble) 337 | outputs = [] 
338 | 339 | # The weird range means we only do it once if output_flip is disabled 340 | # We need to check if scale factor is symmetric to all dimensions, if not we will do 180 jumps rather than 90 341 | for k in range(0, 1 + 7 * self.conf.output_flip, 1 + int(self.sf[0] != self.sf[1])): 342 | # Rotate 90*k degrees and mirror flip when k>=4 343 | test_input = np.rot90(self.input, k) if k < 4 else np.fliplr(np.rot90(self.input, k)) 344 | 345 | # Apply network on the rotated input 346 | tmp_output = self.forward_pass(test_input) 347 | 348 | # Undo the rotation for the processed output (mind the opposite order of the flip and the rotation) 349 | tmp_output = np.rot90(tmp_output, -k) if k < 4 else np.rot90(np.fliplr(tmp_output), -k) 350 | 351 | # fix SR output with back projection technique for each augmentation 352 | for bp_iter in range(self.conf.back_projection_iters[self.sf_ind]): 353 | tmp_output = back_projection(tmp_output, self.input, down_kernel=self.kernel, 354 | up_kernel=self.conf.upscale_method, sf=self.sf) 355 | 356 | # save outputs from all augmentations 357 | outputs.append(tmp_output) 358 | 359 | # Take the median over all 8 outputs 360 | almost_final_sr = np.median(outputs, 0) 361 | 362 | # Again back projection for the final fused result 363 | for bp_iter in range(self.conf.back_projection_iters[self.sf_ind]): 364 | almost_final_sr = back_projection(almost_final_sr, self.input, down_kernel=self.kernel, 365 | up_kernel=self.conf.upscale_method, sf=self.sf) 366 | 367 | # Now we can keep the final result (in grayscale case, colors still need to be added, but we don't care 368 | # because it is done before saving and for every other purpose we use this result) 369 | self.final_sr = almost_final_sr 370 | 371 | # Add colors to result image in case net was activated only on grayscale 372 | return self.final_sr 373 | 374 | def base_change(self): 375 | # If there is no base scale larger than the current one, get out of here 376 | if
len(self.conf.base_change_sfs) < self.base_ind + 1: 377 | return 378 | 379 | # Change base input image if required (this means current output becomes the new input) 380 | if abs(self.conf.scale_factors[self.sf_ind] - self.conf.base_change_sfs[self.base_ind]) < 0.001: 381 | if len(self.conf.base_change_sfs) > self.base_ind: 382 | 383 | # The new input is the current output 384 | self.input = self.final_sr 385 | 386 | # The new base scale_factor 387 | self.base_sf = self.conf.base_change_sfs[self.base_ind] 388 | 389 | # Keeping track- this is the index inside the base scales list (provided in the config) 390 | self.base_ind += 1 391 | 392 | print 'base changed to %.2f' % self.base_sf 393 | 394 | def plot(self): 395 | plots_data, labels = zip(*[(np.array(x), l) for (x, l) 396 | in zip([self.mse, self.mse_rec, self.interp_mse, self.interp_rec_mse], 397 | ['True MSE', 'Reconstruct MSE', 'Bicubic to ground truth MSE', 398 | 'Bicubic to reconstruct MSE']) if x is not None]) 399 | 400 | # For the first iteration create the figure 401 | if not self.iter: 402 | # Create figure and split it using GridSpec. Name each region as needed 403 | self.fig = plt.figure(figsize=(9.5, 9)) 404 | grid = GridSpec(4, 4) 405 | self.loss_plot_space = plt.subplot(grid[:-1, :]) 406 | self.lr_son_image_space = plt.subplot(grid[3, 0]) 407 | self.hr_father_image_space = plt.subplot(grid[3, 3]) 408 | self.out_image_space = plt.subplot(grid[3, 1]) 409 | 410 | # Activate interactive mode for live plot updating 411 | plt.ion() 412 | 413 | # Set some parameters for the plots 414 | self.loss_plot_space.set_xlabel('step') 415 | self.loss_plot_space.set_ylabel('MSE') 416 | self.loss_plot_space.grid(True) 417 | self.loss_plot_space.set_yscale('log') 418 | self.loss_plot_space.legend() 419 | self.plots = [None] * 4 420 | 421 | # loop over all needed plot types. 
if some data is None then skip, if some data is one value tile it 422 | self.plots = self.loss_plot_space.plot(*[[0]] * 2 * len(plots_data)) 423 | 424 | # Update plots 425 | for plot, plot_data in zip(self.plots, plots_data): 426 | plot.set_data(self.mse_steps, plot_data) 427 | 428 | self.loss_plot_space.set_xlim([0, self.iter + 1]) 429 | all_losses = np.array(plots_data) 430 | self.loss_plot_space.set_ylim([np.min(all_losses)*0.9, np.max(all_losses)*1.1]) 431 | 432 | # Mark learning rate changes 433 | for iter_num in self.learning_rate_change_iter_nums: 434 | self.loss_plot_space.axvline(iter_num) 435 | 436 | # Add legend to graphics 437 | self.loss_plot_space.legend(labels) 438 | 439 | # Show current input and output images 440 | self.lr_son_image_space.imshow(self.lr_son, vmin=0.0, vmax=1.0) 441 | self.out_image_space.imshow(self.train_output, vmin=0.0, vmax=1.0) 442 | self.hr_father_image_space.imshow(self.hr_father, vmin=0.0, vmax=1.0) 443 | 444 | # These lines are needed in order to see the graphics in real time 445 | self.fig.canvas.draw() 446 | plt.pause(0.01) 447 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class Config: 5 | # network meta params 6 | python_path = '/home/assafsho/PycharmProjects/network/venv/bin/python2.7' 7 | scale_factors = [[2.0, 2.0]] # list of pairs (vertical, horizontal) for gradual increments in resolution 8 | base_change_sfs = [] # list of scales after which the input is changed to be the output (recommended for high sfs) 9 | max_iters = 3000 10 | min_iters = 256 11 | min_learning_rate = 9e-6 # this tells the algorithm when to stop (specify lower than the last learning-rate) 12 | width = 64 13 | depth = 8 14 | output_flip = True # geometric self-ensemble (see paper) 15 | downscale_method = 'cubic' # a string ('cubic', 'linear'...), has no meaning if kernel given 16 |
upscale_method = 'cubic' # this is the base interpolation from which we learn the residual (same options as above) 17 | downscale_gt_method = 'cubic' # when ground-truth given and intermediate scales tested, we shrink gt to wanted size 18 | learn_residual = True # when true, we only learn the residual from base interpolation 19 | init_variance = 0.1 # variance of weight initializations, typically smaller when residual learning is on 20 | back_projection_iters = [10] # for each scale num of bp iterations (same length as scale_factors) 21 | random_crop = True 22 | crop_size = 128 23 | noise_std = 0.0 # adding noise to lr-sons. small for real images, bigger for noisy images and zero for ideal case 24 | init_net_for_each_sf = False # for gradual sr- should we optimize from the last sf or initialize each time? 25 | 26 | # Params concerning learning rate policy 27 | learning_rate = 0.001 28 | learning_rate_change_ratio = 1.5 # ratio between STD and slope of linear fit, under which lr is reduced 29 | learning_rate_policy_check_every = 60 30 | learning_rate_slope_range = 256 31 | 32 | # Data augmentation related params 33 | augment_leave_as_is_probability = 0.05 34 | augment_no_interpolate_probability = 0.45 35 | augment_min_scale = 0.5 36 | augment_scale_diff_sigma = 0.25 37 | augment_shear_sigma = 0.1 38 | augment_allow_rotation = True # recommended false for non-symmetric kernels 39 | 40 | # params related to test and display 41 | run_test = True 42 | run_test_every = 50 43 | display_every = 20 44 | name = 'test' 45 | plot_losses = False 46 | result_path = os.path.dirname(__file__) + '/results' 47 | create_results_dir = True 48 | input_path = local_dir = os.path.dirname(__file__) + '/test_data' 49 | create_code_copy = True # save a copy of the code in the results folder to easily match code changes to results 50 | display_test_results = True 51 | save_results = True 52 | 53 | def __init__(self): 54 | # network meta params that by default are determined (by other params) 
but can be changed 55 | self.filter_shape = ([[3, 3, 3, self.width]] + 56 | [[3, 3, self.width, self.width]] * (self.depth-2) + 57 | [[3, 3, self.width, 3]]) 58 | 59 | 60 | ######################################## 61 | # Some pre-made useful example configs # 62 | ######################################## 63 | 64 | # Basic default config (same as not specifying), non-gradual SRx2 with default bicubic kernel (Ideal case) 65 | # example is set to run on set14 66 | X2_ONE_JUMP_IDEAL_CONF = Config() 67 | X2_ONE_JUMP_IDEAL_CONF.input_path = os.path.dirname(__file__) + '/set14' 68 | 69 | # Same as above but with visualization (Recommended for one image, interactive mode, for debugging) 70 | X2_IDEAL_WITH_PLOT_CONF = Config() 71 | X2_IDEAL_WITH_PLOT_CONF.plot_losses = True 72 | X2_IDEAL_WITH_PLOT_CONF.run_test_every = 20 73 | X2_IDEAL_WITH_PLOT_CONF.input_path = os.path.dirname(__file__) + '/example_with_gt' 74 | 75 | # Gradual SRx2, to achieve superior results in the ideal case 76 | X2_GRADUAL_IDEAL_CONF = Config() 77 | X2_GRADUAL_IDEAL_CONF.scale_factors = [[1.0, 1.5], [1.5, 1.0], [1.5, 1.5], [1.5, 2.0], [2.0, 1.5], [2.0, 2.0]] 78 | X2_GRADUAL_IDEAL_CONF.back_projection_iters = [6, 6, 8, 10, 10, 12] 79 | X2_GRADUAL_IDEAL_CONF.input_path = os.path.dirname(__file__) + '/set14' 80 | 81 | # Applying a given kernel. Rotations are disabled since the kernel may be non-symmetric 82 | X2_GIVEN_KERNEL_CONF = Config() 83 | X2_GIVEN_KERNEL_CONF.output_flip = False 84 | X2_GIVEN_KERNEL_CONF.augment_allow_rotation = False 85 | X2_GIVEN_KERNEL_CONF.back_projection_iters = [2] 86 | X2_GIVEN_KERNEL_CONF.input_path = os.path.dirname(__file__) + '/kernel_example' 87 | 88 | # An example of a typical setup for real images. (Kernel needed + mild unknown noise) 89 | # back-projection is not recommended because of the noise.
90 | X2_REAL_CONF = Config() 91 | X2_REAL_CONF.output_flip = False 92 | X2_REAL_CONF.back_projection_iters = [0] 93 | X2_REAL_CONF.input_path = os.path.dirname(__file__) + '/real_example' 94 | X2_REAL_CONF.noise_std = 0.0125 95 | X2_REAL_CONF.augment_allow_rotation = False 96 | X2_REAL_CONF.augment_scale_diff_sigma = 0 97 | X2_REAL_CONF.augment_shear_sigma = 0 98 | X2_REAL_CONF.augment_min_scale = 0.75 99 | -------------------------------------------------------------------------------- /example_with_gt/bsd_001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/example_with_gt/bsd_001.png -------------------------------------------------------------------------------- /example_with_gt/bsd_001_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/example_with_gt/bsd_001_gt.png -------------------------------------------------------------------------------- /figs/sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/figs/sketch.png -------------------------------------------------------------------------------- /imresize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage import filters, measurements, interpolation 3 | from math import pi 4 | 5 | 6 | def imresize(im, scale_factor=None, output_shape=None, kernel=None, antialiasing=True, kernel_shift_flag=False): 7 | # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa 8 | scale_factor, output_shape = fix_scale_and_size(im.shape, output_shape, scale_factor) 9 | 10 | # For a given numeric kernel 
case, just do convolution and sub-sampling (downscaling only) 11 | if type(kernel) == np.ndarray and scale_factor[0] <= 1: 12 | return numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag) 13 | 14 | # Choose interpolation method, each method has the matching kernel size 15 | method, kernel_width = { 16 | "cubic": (cubic, 4.0), 17 | "lanczos2": (lanczos2, 4.0), 18 | "lanczos3": (lanczos3, 6.0), 19 | "box": (box, 1.0), 20 | "linear": (linear, 2.0), 21 | None: (cubic, 4.0) # set default interpolation method as cubic 22 | }.get(kernel) 23 | 24 | # Antialiasing is only used when downscaling 25 | antialiasing *= (scale_factor[0] < 1) 26 | 27 | # Sort indices of dimensions according to scale of each dimension. since we are going dim by dim this is efficient 28 | sorted_dims = np.argsort(np.array(scale_factor)).tolist() 29 | 30 | # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction 31 | out_im = np.copy(im) 32 | for dim in sorted_dims: 33 | # No point doing calculations for scale-factor 1. nothing will happen anyway 34 | if scale_factor[dim] == 1.0: 35 | continue 36 | 37 | # for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the 38 | # weights that multiply the values there to get its result. 
39 | weights, field_of_view = contributions(im.shape[dim], output_shape[dim], scale_factor[dim], 40 | method, kernel_width, antialiasing) 41 | 42 | # Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim 43 | out_im = resize_along_dim(out_im, dim, weights, field_of_view) 44 | 45 | return out_im 46 | 47 | 48 | def fix_scale_and_size(input_shape, output_shape, scale_factor): 49 | # First standardize the scale-factor (if given) to the form the function expects (a list of scale factors with the 50 | # same length as the number of input dimensions) 51 | if scale_factor is not None: 52 | # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it. 53 | if np.isscalar(scale_factor): 54 | scale_factor = [scale_factor, scale_factor] 55 | 56 | # We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales 57 | scale_factor = list(scale_factor) 58 | scale_factor.extend([1] * (len(input_shape) - len(scale_factor))) 59 | 60 | # Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size 61 | # to all the unspecified dimensions 62 | if output_shape is not None: 63 | output_shape = list(np.uint(np.array(output_shape))) + list(input_shape[len(output_shape):]) 64 | 65 | # Dealing with the case of a non-given scale-factor, calculating according to output-shape. Note that this is 66 | # sub-optimal, because there can be different scales to the same output-shape. 67 | if scale_factor is None: 68 | scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape) 69 | 70 | # Dealing with missing output-shape. 
calculating according to scale-factor 71 | if output_shape is None: 72 | output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor))) 73 | 74 | return scale_factor, output_shape 75 | 76 | 77 | def contributions(in_length, out_length, scale, kernel, kernel_width, antialiasing): 78 | # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied 79 | # such that each position from the field_of_view will be multiplied with a matching filter from the 80 | # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers 81 | # around it. This is only done for one dimension of the image. 82 | 83 | # When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of 84 | # 1/sf. this means filtering is more 'low-pass filter'. 85 | fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel 86 | kernel_width *= 1.0 / scale if antialiasing else 1.0 87 | 88 | # These are the coordinates of the output image 89 | out_coordinates = np.arange(1, out_length+1) 90 | 91 | # These are the matching positions of the output-coordinates on the input image coordinates. 92 | # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels: 93 | # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel. 94 | # The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to 95 | # the right boundary of pixel 2. pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big 96 | # one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor). 97 | # So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is 98 | # at d=1, and so on (d = p - 0.5). 
we calculate (d_new = d_old / sf) which means: 99 | # (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf) 100 | match_coordinates = 1.0 * out_coordinates / scale + 0.5 * (1 - 1.0 / scale) 101 | 102 | # This is the left boundary to start multiplying the filter from, it depends on the size of the filter 103 | left_boundary = np.floor(match_coordinates - kernel_width / 2) 104 | 105 | # Kernel width needs to be enlarged because when the covered area has sub-pixel borders, it must 'see' the pixel 106 | # centers of the pixels it covers only partially. So we add one pixel at each side to consider (weights can zero them) 107 | expanded_kernel_width = np.ceil(kernel_width) + 2 108 | 109 | # Determine a set of field_of_view for each output position, these are the pixels in the input image 110 | # that the pixel in the output image 'sees'. We get a matrix whose horizontal dim is the output pixels (big) and the 111 | # vertical dim is the pixels it 'sees' (kernel_size + 2) 112 | field_of_view = np.squeeze(np.uint(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1)) 113 | 114 | # Assign weight to each pixel in the field of view. A matrix whose horizontal dim is the output pixels and the 115 | # vertical dim is a list of weights matching to the pixel in the field of view (that are specified in 116 | # 'field_of_view') 117 | weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1) 118 | 119 | # Normalize weights to sum up to 1. 
be careful not to divide by 0 120 | sum_weights = np.sum(weights, axis=1) 121 | sum_weights[sum_weights == 0] = 1.0 122 | weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1) 123 | 124 | # We use this mirror structure as a trick for reflection padding at the boundaries 125 | mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1)))) 126 | field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])] 127 | 128 | # Get rid of weights and pixel positions that are of zero weight 129 | non_zero_out_pixels = np.nonzero(np.any(weights, axis=0)) 130 | weights = np.squeeze(weights[:, non_zero_out_pixels]) 131 | field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels]) 132 | 133 | # Final products are the relative positions and the matching weights, both are output_size X fixed_kernel_size 134 | return weights, field_of_view 135 | 136 | 137 | def resize_along_dim(im, dim, weights, field_of_view): 138 | # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize 139 | tmp_im = np.swapaxes(im, dim, 0) 140 | 141 | # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for 142 | # tmp_im[field_of_view.T], (bsxfun style) 143 | weights = np.reshape(weights.T, list(weights.T.shape) + (np.ndim(im) - 1) * [1]) 144 | 145 | # This is a bit of a complicated multiplication: tmp_im[field_of_view.T] is a tensor of order image_dims+1. 146 | # for each pixel in the output-image it matches the positions that influence it from the input image (along 1 dim 147 | # only, this is why it only adds 1 dim to the shape). We then multiply, for each pixel, its set of positions with 148 | # the matching set of weights. 
We do this with one big tensor element-wise multiplication (MATLAB bsxfun style: 149 | # matching dims are multiplied element-wise while singletons mean that the matching dim is all multiplied by the 150 | # same number) 151 | tmp_out_im = np.sum(tmp_im[field_of_view.T] * weights, axis=0) 152 | 153 | # Finally we swap back the axes to the original order 154 | return np.swapaxes(tmp_out_im, dim, 0) 155 | 156 | 157 | def numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag): 158 | # See kernel_shift function to understand what this is 159 | if kernel_shift_flag: 160 | kernel = kernel_shift(kernel, scale_factor) 161 | 162 | # First run a correlation (convolution with flipped kernel) 163 | out_im = np.zeros_like(im) 164 | for channel in range(im.shape[2]): # loop over channels (np.ndim(im) is the number of dims, not of channels) 165 | out_im[:, :, channel] = filters.correlate(im[:, :, channel], kernel) 166 | 167 | # Then subsample and return 168 | return out_im[np.round(np.linspace(0, im.shape[0] - 1 / scale_factor[0], output_shape[0])).astype(int)[:, None], 169 | np.round(np.linspace(0, im.shape[1] - 1 / scale_factor[1], output_shape[1])).astype(int), :] 170 | 171 | 172 | def kernel_shift(kernel, sf): 173 | # There are two reasons for shifting the kernel: 174 | # 1. Center of mass is not in the center of the kernel which creates ambiguity. There is no way to know whether 175 | # the degradation process included shifting so we always assume center of mass is center of the kernel. 176 | # 2. We further shift kernel center so that top left result pixel corresponds to the middle of the sfXsf first 177 | # pixels. Default is for odd size to be in the middle of the first pixel and for even sized kernel to be at the 178 | # top left corner of the first pixel. That is why a different shift size is needed for odd and even sizes. 
179 | # Given that these two conditions are fulfilled, we are happy and aligned, the way to test it is as follows: 180 | # The input image, when interpolated (regular bicubic) is exactly aligned with ground truth. 181 | 182 | # First calculate the current center of mass for the kernel 183 | current_center_of_mass = measurements.center_of_mass(kernel) 184 | 185 | # The second term ("+ 0.5 * ....") is for applying condition 2 from the comments above 186 | wanted_center_of_mass = np.array(kernel.shape) // 2 + 0.5 * (sf - (kernel.shape[0] % 2)) 187 | 188 | # Define the shift vector for the kernel shifting (x,y) 189 | shift_vec = wanted_center_of_mass - current_center_of_mass 190 | 191 | # Before applying the shift, we first pad the kernel so that nothing is lost due to the shift 192 | # (biggest absolute shift among dims + 1 for safety) 193 | kernel = np.pad(kernel, int(np.ceil(np.max(np.abs(shift_vec)))) + 1, 'constant') 194 | 195 | # Finally shift the kernel and return 196 | return interpolation.shift(kernel, shift_vec) 197 | 198 | 199 | # These next functions are all interpolation methods. 
x is the distance from the left pixel center 200 | 201 | 202 | def cubic(x): 203 | absx = np.abs(x) 204 | absx2 = absx ** 2 205 | absx3 = absx ** 3 206 | return ((1.5*absx3 - 2.5*absx2 + 1) * (absx <= 1) + 207 | (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * ((1 < absx) & (absx <= 2))) 208 | 209 | 210 | def lanczos2(x): 211 | return (((np.sin(pi*x) * np.sin(pi*x/2) + np.finfo(np.float32).eps) / 212 | ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps)) 213 | * (abs(x) < 2)) 214 | 215 | 216 | def box(x): 217 | return ((-0.5 <= x) & (x < 0.5)) * 1.0 218 | 219 | 220 | def lanczos3(x): 221 | return (((np.sin(pi*x) * np.sin(pi*x/3) + np.finfo(np.float32).eps) / 222 | ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps)) 223 | * (abs(x) < 3)) 224 | 225 | 226 | def linear(x): 227 | return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1)) 228 | -------------------------------------------------------------------------------- /kernel_example/BSD100_100_lr_rand_ker_c_X2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/kernel_example/BSD100_100_lr_rand_ker_c_X2.png -------------------------------------------------------------------------------- /kernel_example/BSD100_100_lr_rand_ker_c_X2_0.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/kernel_example/BSD100_100_lr_rand_ker_c_X2_0.mat -------------------------------------------------------------------------------- /real_example/charlie.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/real_example/charlie.png -------------------------------------------------------------------------------- /real_example/charlie_0.mat: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/real_example/charlie_0.mat -------------------------------------------------------------------------------- /run_ZSSR.py: -------------------------------------------------------------------------------- 1 | import GPUtil 2 | import glob 3 | import os 4 | from utils import prepare_result_dir 5 | import configs 6 | from time import sleep 7 | import sys 8 | import run_ZSSR_single_input 9 | 10 | 11 | def main(conf_name, gpu): 12 | # Initialize configs and prepare result dir with date 13 | if conf_name is None: 14 | conf = configs.Config() 15 | else: 16 | conf = None 17 | exec ('conf = configs.%s' % conf_name) 18 | res_dir = prepare_result_dir(conf) 19 | local_dir = os.path.dirname(__file__) 20 | 21 | # We take all png files that are not ground truth 22 | files = [file_path for file_path in glob.glob('%s/*.png' % conf.input_path) 23 | if not file_path[-7:-4] == '_gt'] 24 | 25 | # Loop over all the files 26 | for file_ind, input_file in enumerate(files): 27 | 28 | # Ground-truth file needs to be like the input file with _gt (if exists) 29 | ground_truth_file = input_file[:-4] + '_gt.png' 30 | if not os.path.isfile(ground_truth_file): 31 | ground_truth_file = '0' 32 | 33 | # Numeric kernel files need to be like the input file with serial number 34 | kernel_files = ['%s_%d.mat;' % (input_file[:-4], ind) for ind in range(len(conf.scale_factors))] 35 | kernel_files_str = ''.join(kernel_files) 36 | for kernel_file in kernel_files: 37 | if not os.path.isfile(kernel_file[:-1]): 38 | kernel_files_str = '0' 39 | print 'no kernel loaded' 40 | break 41 | 42 | print kernel_files 43 | 44 | # This option uses all the gpu resources efficiently 45 | if gpu == 'all': 46 | 47 | # Stay stuck in this loop until there is some gpu available with at least half capacity 48 | gpus = [] 49 | while not gpus: 50 | gpus = 
GPUtil.getAvailable(order='memory') 51 | 52 | # Take the gpu with the most free memory 53 | cur_gpu = gpus[-1] 54 | 55 | # Run ZSSR from command line, open xterm for each run 56 | os.system("xterm -hold -e " + conf.python_path + 57 | " %s/run_ZSSR_single_input.py '%s' '%s' '%s' '%s' '%s' '%s' alias python &" 58 | % (local_dir, input_file, ground_truth_file, kernel_files_str, cur_gpu, conf_name, res_dir)) 59 | 60 | # Verbose 61 | print 'Ran file #%d: %s on GPU %d\n' % (file_ind, input_file, cur_gpu) 62 | 63 | # Wait 5 seconds for the previous process to start using GPU. if we wouldn't wait then GPU memory will not 64 | # yet be taken and all process will start on the same GPU at once and later collapse. 65 | sleep(5) 66 | 67 | # The other option is just to run sequentially on a chosen GPU. 68 | else: 69 | run_ZSSR_single_input.main(input_file, ground_truth_file, kernel_files_str, gpu, conf_name, res_dir) 70 | 71 | 72 | if __name__ == '__main__': 73 | conf_str = sys.argv[1] if len(sys.argv) > 1 else None 74 | gpu_str = sys.argv[2] if len(sys.argv) > 2 else None 75 | main(conf_str, gpu_str) 76 | -------------------------------------------------------------------------------- /run_ZSSR_single_input.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import configs 4 | import ZSSR 5 | 6 | 7 | def main(input_img, ground_truth, kernels, gpu, conf_str, results_path): 8 | # Choose the wanted GPU 9 | if gpu is not None: 10 | os.environ["CUDA_VISIBLE_DEVICES"] = '%s' % gpu 11 | 12 | # 0 input for ground-truth or kernels means None 13 | ground_truth = None if ground_truth == '0' else ground_truth 14 | print '*****', kernels 15 | kernels = None if kernels == '0' else kernels.split(';')[:-1] 16 | 17 | # Setup configuration and results directory 18 | conf = configs.Config() 19 | if conf_str is not None: 20 | exec ('conf = configs.%s' % conf_str) 21 | conf.result_path = results_path 22 | 23 | # Run ZSSR on the image 
24 | net = ZSSR.ZSSR(input_img, conf, ground_truth, kernels) 25 | net.run() 26 | 27 | 28 | if __name__ == '__main__': 29 | main(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5], sys.argv[6]) 30 | -------------------------------------------------------------------------------- /set14/img_001_SRF_2_LR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/set14/img_001_SRF_2_LR.png -------------------------------------------------------------------------------- /set14/img_002_SRF_2_LR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/set14/img_002_SRF_2_LR.png -------------------------------------------------------------------------------- /set14/img_003_SRF_2_LR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/set14/img_003_SRF_2_LR.png -------------------------------------------------------------------------------- /set14/img_004_SRF_2_LR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/set14/img_004_SRF_2_LR.png -------------------------------------------------------------------------------- /set14/img_005_SRF_2_LR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/set14/img_005_SRF_2_LR.png -------------------------------------------------------------------------------- /set14/img_006_SRF_2_LR.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/set14/img_006_SRF_2_LR.png -------------------------------------------------------------------------------- /set14/img_007_SRF_2_LR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/set14/img_007_SRF_2_LR.png -------------------------------------------------------------------------------- /set14/img_008_SRF_2_LR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/set14/img_008_SRF_2_LR.png -------------------------------------------------------------------------------- /set14/img_009_SRF_2_LR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/set14/img_009_SRF_2_LR.png -------------------------------------------------------------------------------- /set14/img_010_SRF_2_LR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/set14/img_010_SRF_2_LR.png -------------------------------------------------------------------------------- /set14/img_011_SRF_2_LR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/set14/img_011_SRF_2_LR.png -------------------------------------------------------------------------------- /set14/img_012_SRF_2_LR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/set14/img_012_SRF_2_LR.png 
-------------------------------------------------------------------------------- /set14/img_013_SRF_2_LR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/set14/img_013_SRF_2_LR.png -------------------------------------------------------------------------------- /set14/img_014_SRF_2_LR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assafshocher/ZSSR/d2948526fde799a0c111469b78cd9e1b8659cb0f/set14/img_014_SRF_2_LR.png -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math import pi, sin, cos 3 | from cv2 import warpPerspective, INTER_CUBIC 4 | from imresize import imresize 5 | from shutil import copy 6 | from time import strftime, localtime 7 | import os 8 | import glob 9 | from scipy.ndimage import measurements, interpolation 10 | from scipy.io import loadmat 11 | 12 | 13 | def random_augment(ims, 14 | base_scales=None, 15 | leave_as_is_probability=0.2, 16 | no_interpolate_probability=0.3, 17 | min_scale=0.5, 18 | max_scale=1.0, 19 | allow_rotation=True, 20 | scale_diff_sigma=0.01, 21 | shear_sigma=0.01, 22 | crop_size=128): 23 | 24 | # Determine which kind of augmentation takes place according to probabilities 25 | random_chooser = np.random.rand() 26 | 27 | # Option 1: No augmentation, return the original image 28 | if random_chooser < leave_as_is_probability: 29 | mode = 'leave_as_is' 30 | 31 | # Option 2: Only non-interpolated augmentation, which means 8 possible augs (4 rotations X 2 mirror flips) 32 | elif leave_as_is_probability < random_chooser < leave_as_is_probability + no_interpolate_probability: 33 | mode = 'no_interp' 34 | 35 | # Option 3: Affine transformation (uses interpolation) 36 | else: 37 | mode 
= 'affine' 38 | 39 | # If scales not given, calculate them according to sizes of images. This would be suboptimal, because when scales 40 | # are not integers, different scales can have the same image shape. 41 | if base_scales is None: 42 | base_scales = [np.sqrt(1.0 * np.prod(im.shape) / np.prod(ims[0].shape)) for im in ims] 43 | 44 | # In case max_scale is a list of scales, we take the smallest one to be the allowed maximum 45 | max_scale = np.min([max_scale]) 46 | 47 | # Determine a random scale by probability 48 | if mode == 'leave_as_is': 49 | scale = 1.0 50 | else: 51 | scale = np.random.rand() * (max_scale - min_scale) + min_scale 52 | 53 | # The image we will use is the smallest one that is bigger than the wanted scale 54 | # (Using a small value overlap instead of >= to prevent float issues) 55 | scale_ind, base_scale = next((ind, np.min([base_scale])) for ind, base_scale in enumerate(base_scales) 56 | if np.min([base_scale]) > scale - 1.0e-6) 57 | im = ims[scale_ind] 58 | 59 | # Next are matrices whose multiplication will be the transformation. All are 3x3 matrices. 60 | 61 | # First matrix shifts image to center so that crop is in the center of the image 62 | shift_to_center_mat = np.array([[1, 0, - im.shape[1] / 2.0], 63 | [0, 1, - im.shape[0] / 2.0], 64 | [0, 0, 1]]) 65 | 66 | shift_back_from_center = np.array([[1, 0, im.shape[1] / 2.0], 67 | [0, 1, im.shape[0] / 2.0], 68 | [0, 0, 1]]) 69 | # Keeping the transform interpolation free means only shifting by integers 70 | if mode != 'affine': 71 | shift_to_center_mat = np.round(shift_to_center_mat) 72 | shift_back_from_center = np.round(shift_back_from_center) 73 | 74 | # Scale matrix. if affine, first determine global scale by probability, then determine difference between x scale 75 | # and y scale by gaussian probability. 
76 | if mode == 'affine': 77 | scale /= base_scale 78 | scale_diff = np.random.randn() * scale_diff_sigma 79 | else: 80 | scale = 1.0 81 | scale_diff = 0.0 82 | # In this matrix we also incorporate the possibility of mirror reflection (unless leave_as_is). 83 | if mode == 'leave_as_is' or not allow_rotation: 84 | reflect = 1 85 | else: 86 | reflect = np.sign(np.random.randn()) 87 | 88 | scale_mat = np.array([[reflect * (scale + scale_diff / 2), 0, 0], 89 | [0, scale - scale_diff / 2, 0], 90 | [0, 0, 1]]) 91 | 92 | # Shift matrix, this actually creates the random crop 93 | shift_x = np.random.rand() * np.clip(scale * im.shape[1] - crop_size, 0, 9999) 94 | shift_y = np.random.rand() * np.clip(scale * im.shape[0] - crop_size, 0, 9999) 95 | shift_mat = np.array([[1, 0, - shift_x], 96 | [0, 1, - shift_y], 97 | [0, 0, 1]]) 98 | 99 | # Keeping the transform interpolation free means only shifting by integers 100 | if mode != 'affine': 101 | shift_mat = np.round(shift_mat) 102 | 103 | # Rotation matrix angle. if affine, set a random angle. if no_interp then theta can only be pi/2 times int. 104 | if mode == 'affine': 105 | theta = np.random.rand() * 2 * pi 106 | elif mode == 'no_interp': 107 | theta = np.random.randint(4) * pi / 2 108 | else: 109 | theta = 0 110 | if not allow_rotation: 111 | theta = 0 112 | 113 | # Rotation matrix structure 114 | rotation_mat = np.array([[cos(theta), sin(theta), 0], 115 | [-sin(theta), cos(theta), 0], 116 | [0, 0, 1]]) 117 | 118 | # Shear Matrix, only for affine transformation. 119 | if mode == 'affine': 120 | shear_x = np.random.randn() * shear_sigma 121 | shear_y = np.random.randn() * shear_sigma 122 | else: 123 | shear_x = shear_y = 0 124 | shear_mat = np.array([[1, shear_x, 0], 125 | [shear_y, 1, 0], 126 | [0, 0, 1]]) 127 | 128 | # Create the final transformation by multiplying all the transformations. 
129 | transform_mat = (shift_back_from_center 130 | .dot(shift_mat) 131 | .dot(shear_mat) 132 | .dot(rotation_mat) 133 | .dot(scale_mat) 134 | .dot(shift_to_center_mat)) 135 | 136 | # Apply transformation to image and return the transformed image clipped between 0-1 137 | return np.clip(warpPerspective(im, transform_mat, (crop_size, crop_size), flags=INTER_CUBIC), 0, 1) 138 | 139 | 140 | def back_projection(y_sr, y_lr, down_kernel, up_kernel, sf=None): 141 | y_sr += imresize(y_lr - imresize(y_sr, 142 | scale_factor=1.0/sf, 143 | output_shape=y_lr.shape, 144 | kernel=down_kernel), 145 | scale_factor=sf, 146 | output_shape=y_sr.shape, 147 | kernel=up_kernel) 148 | return np.clip(y_sr, 0, 1) 149 | 150 | 151 | def preprocess_kernels(kernels, conf): 152 | # Load kernels if files are given. If not, just use the downscaling method from the configs. 153 | # Output is a list of kernel-arrays or a list of strings indicating downscaling method. 154 | # In case of arrays, we shift the kernels (see next function for explanation why). 155 | # Kernel is a .mat file (MATLAB) containing a variable called 'Kernel' which is a 2-dim matrix. 156 | if kernels is not None: 157 | return [kernel_shift(loadmat(kernel)['Kernel'], sf) 158 | for kernel, sf in zip(kernels, conf.scale_factors)] 159 | else: 160 | return [conf.downscale_method] * len(conf.scale_factors) 161 | 162 | 163 | def kernel_shift(kernel, sf): 164 | # There are two reasons for shifting the kernel: 165 | # 1. Center of mass is not in the center of the kernel which creates ambiguity. There is no possible way to know 166 | # the degradation process included shifting so we always assume center of mass is center of the kernel. 167 | # 2. We further shift kernel center so that top left result pixel corresponds to the middle of the sfXsf first 168 | # pixels. Default is for odd size to be in the middle of the first pixel and for even sized kernel to be at the 169 | # top left corner of the first pixel. 
That is why a different shift size is needed for odd and even sizes. 170 | # Given that these two conditions are fulfilled, we are happy and aligned, the way to test it is as follows: 171 | # The input image, when interpolated (regular bicubic) is exactly aligned with ground truth. 172 | 173 | # First calculate the current center of mass for the kernel 174 | current_center_of_mass = measurements.center_of_mass(kernel) 175 | 176 | # The second term ("+ 0.5 * ....") is for applying condition 2 from the comments above 177 | wanted_center_of_mass = np.array(kernel.shape) // 2 + 0.5 * (np.array(sf) - (np.array(kernel.shape) % 2)) 178 | 179 | # Define the shift vector for the kernel shifting (x,y) 180 | shift_vec = wanted_center_of_mass - current_center_of_mass 181 | 182 | # Before applying the shift, we first pad the kernel so that nothing is lost due to the shift 183 | # (biggest shift among dims + 1 for safety) 184 | kernel = np.pad(kernel, int(np.ceil(np.max(np.abs(shift_vec)))) + 1, 'constant') 185 | 186 | # Finally shift the kernel and return 187 | return interpolation.shift(kernel, shift_vec) 188 | 189 | 190 | def prepare_result_dir(conf): 191 | # Create results directory 192 | if conf.create_results_dir: 193 | conf.result_path += '/' + conf.name + strftime('_%b_%d_%H_%M_%S', localtime()) 194 | os.makedirs(conf.result_path) 195 | 196 | # Put a copy of all *.py files in results path, to be able to reproduce experimental results 197 | if conf.create_code_copy: 198 | local_dir = os.path.dirname(__file__) 199 | for py_file in glob.glob(local_dir + '/*.py'): 200 | copy(py_file, conf.result_path) 201 | 202 | return conf.result_path 203 | --------------------------------------------------------------------------------
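The shift that `kernel_shift` in utils.py computes can be checked by hand on tiny kernels. What follows is a minimal pure-Python sketch of the same arithmetic, written without NumPy or SciPy so that it runs on its own. The helper names (`center_of_mass`, `wanted_center`, `shift_vector`) are hypothetical and not part of this repository. The sketch also assumes that halving the kernel shape floors the result, which is what `/` does on integer arrays under the Python 2 that the print statements in this code base suggest.

```python
# Illustrative sketch only: restates the shift arithmetic of utils.kernel_shift in plain Python.
# All helper names here are hypothetical; they do not exist in the repository.


def center_of_mass(kernel):
    """Return the (row, col) center of mass of a 2-D kernel given as nested lists."""
    total = float(sum(sum(row) for row in kernel))
    row_com = sum(r * sum(row) for r, row in enumerate(kernel)) / total
    col_com = sum(c * v for row in kernel for c, v in enumerate(row)) / total
    return row_com, col_com


def wanted_center(shape, sf):
    """Target center of mass per dimension: shape // 2 + 0.5 * (sf - shape % 2).

    Assumption: floor halving, as integer arrays behave with `/` in Python 2.
    """
    return tuple(n // 2 + 0.5 * (s - n % 2) for n, s in zip(shape, sf))


def shift_vector(kernel, sf):
    """Shift that moves the kernel's current center of mass to the wanted one."""
    shape = (len(kernel), len(kernel[0]))
    wanted = wanted_center(shape, sf)
    current = center_of_mass(kernel)
    return tuple(w - c for w, c in zip(wanted, current))


# Odd-sized kernel: all mass on the middle pixel of a 3x3 grid.
delta = [[0, 0, 0],
         [0, 1, 0],
         [0, 0, 0]]

# Even-sized kernel: uniform 4x4 box.
uniform = [[1] * 4 for _ in range(4)]

print(shift_vector(delta, (2, 2)))    # (0.5, 0.5)
print(shift_vector(uniform, (2, 2)))  # (1.5, 1.5)
```

For scale factor 2 the centered 3x3 kernel is moved by half a pixel in each dimension, while the uniform 4x4 kernel is moved by 1.5 pixels, which illustrates why the formula treats odd and even sizes differently.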